├── .gitignore ├── .vscode └── launch.json ├── LICENSE ├── MANIFEST.in ├── Makefile ├── README.md ├── __init__.py ├── app ├── __init__.py ├── api │ ├── __init__.py │ └── v1 │ │ ├── __init__.py │ │ ├── buia.py │ │ ├── geojson.py │ │ ├── job.py │ │ ├── predict.py │ │ ├── predict_buildings.py │ │ ├── task.py │ │ ├── task_admin.py │ │ ├── test.py │ │ ├── tools.py │ │ ├── train.py │ │ └── wmts.py ├── app.py ├── config │ ├── __init__.py │ ├── secure.py │ └── setting.py ├── libs │ ├── enums.py │ ├── error.py │ ├── error_code.py │ ├── redprint.py │ ├── scope.py │ ├── token_auth.py │ ├── utils.py │ └── utils_geom.py └── models │ ├── __init__.py │ ├── base.py │ ├── buia.py │ ├── predict_buildings.py │ ├── task.py │ └── task_admin.py ├── arcpy_geoc ├── .vscode │ ├── launch.json │ └── settings.json ├── regular_build.py ├── regular_command.py └── setting.py ├── batch_cover.py ├── config.toml ├── data ├── config.toml └── dataset-parking.toml ├── docs ├── 101.md ├── BUIA.sql ├── Linux安装指南.md ├── config.md ├── extensibility_by_design.md ├── from_opendata_to_opendataset.md ├── img │ ├── from_opendata_to_opendataset │ │ ├── compare.png │ │ ├── compare_clean.png │ │ ├── compare_side.png │ │ ├── compare_side_clean.png │ │ ├── compare_zoom_out.png │ │ ├── images.png │ │ ├── labels.png │ │ └── masks.png │ ├── quality_analysis │ │ ├── compare.png │ │ ├── compare_side.png │ │ ├── images.png │ │ ├── labels.png │ │ ├── masks.png │ │ ├── osm.png │ │ ├── predict_compare.png │ │ ├── predict_compare_side.png │ │ ├── predict_images.png │ │ └── predict_mask.png │ └── readme │ │ ├── data_preparation.png │ │ ├── draw_me_robosat_pink.png │ │ ├── minimal.png │ │ ├── stacks.png │ │ ├── top_example.jpeg │ │ └── 模型优化.png ├── makefile.md ├── predict_buildings.sql ├── task.sql ├── task_admin.sql └── tools.md ├── gunicorn_config.py ├── main.py ├── requirements.txt ├── robosat_pink ├── __init__.py ├── core.py ├── da │ ├── __init__.py │ ├── core.py │ └── strong.py ├── geoc │ ├── RSPcover.py │ ├── 
RSPpredict.py │ ├── RSPreturn_predict.py │ ├── RSPtrain.py │ ├── __init__.py │ ├── config.py │ ├── params.py │ ├── pg生成乡为单位的中心点geojson.sql │ └── utils.py ├── geojson.py ├── graph │ ├── __init__.py │ └── core.py ├── loaders │ ├── __init__.py │ └── semsegtiles.py ├── losses │ ├── __init__.py │ └── lovasz.py ├── metrics │ ├── __init__.py │ ├── core.py │ ├── iou.py │ ├── mcc.py │ └── qod.py ├── models │ ├── __init__.py │ └── albunet.py ├── osm │ ├── __init__.py │ ├── building.py │ └── road.py ├── spatial │ ├── __init__.py │ └── core.py ├── tiles.py ├── tools │ ├── __init__.py │ ├── __main__.py │ ├── compare.py │ ├── cover.py │ ├── download.py │ ├── export.py │ ├── extract.py │ ├── features.py │ ├── info.py │ ├── merge.py │ ├── predict.py │ ├── rasterize.py │ ├── subset.py │ ├── tile.py │ ├── train.py │ └── vectorize.py └── web_ui │ ├── compare.html │ └── leaflet.html ├── setup.py ├── test.py ├── tests ├── __init__.py ├── fixtures │ ├── images │ │ └── 18 │ │ │ ├── 69105 │ │ │ └── 105093.jpg │ │ │ └── 69108 │ │ │ ├── 105091.jpg │ │ │ └── 105092.jpg │ ├── labels │ │ └── 18 │ │ │ ├── 69105 │ │ │ └── 105093.png │ │ │ └── 69108 │ │ │ ├── 105091.png │ │ │ └── 105092.png │ ├── osm │ │ └── 18 │ │ │ ├── 69105 │ │ │ └── 105093.png │ │ │ └── 69108 │ │ │ ├── 105091.png │ │ │ └── 105092.png │ ├── parking │ │ ├── features.geojson │ │ ├── images │ │ │ └── 18 │ │ │ │ ├── 69623 │ │ │ │ └── 104946.webp │ │ │ │ ├── 70761 │ │ │ │ └── 104120.webp │ │ │ │ ├── 70762 │ │ │ │ └── 104119.webp │ │ │ │ └── 70763 │ │ │ │ └── 104119.webp │ │ ├── labels │ │ │ └── 18 │ │ │ │ ├── 69623 │ │ │ │ └── 104946.png │ │ │ │ ├── 70761 │ │ │ │ └── 104120.png │ │ │ │ ├── 70762 │ │ │ │ └── 104119.png │ │ │ │ └── 70763 │ │ │ │ └── 104119.png │ │ └── tiles.csv │ └── tiles.csv ├── loaders │ └── test_semsegtiles.py ├── test_tiles.py └── tools │ └── test_rasterize.py ├── webmap ├── .gitignore ├── README.md ├── babel.config.js ├── dist │ ├── config.js │ ├── css │ │ ├── app.823ee787.css │ │ └── chunk-vendors.a84ffaf8.css 
│ ├── favicon.ico │ ├── index.html │ ├── js │ │ ├── app.8bf85688.js │ │ ├── app.8bf85688.js.map │ │ ├── chunk-vendors.5109e7b2.js │ │ └── chunk-vendors.5109e7b2.js.map │ ├── style.json │ └── test.json ├── package-lock.json ├── package.json ├── public │ ├── config.js │ ├── favicon.ico │ ├── index.html │ ├── style.json │ └── test.json └── src │ ├── App.vue │ ├── assets │ └── logo.png │ ├── components │ └── HomeMap.vue │ └── main.js └── xyz_proxy.py /.gitignore: -------------------------------------------------------------------------------- 1 | .python-version 2 | __pycache__ 3 | 4 | build 5 | dist 6 | RoboSat_geoc.egg-info 7 | 8 | test 9 | .idea 10 | ds 11 | .DS_Store 12 | .vscode/settings.json 13 | dp 14 | data/model 15 | dataset 16 | out.log 17 | .vscode/launch.json 18 | .env 19 | log -------------------------------------------------------------------------------- /.vscode/launch.json: -------------------------------------------------------------------------------- 1 | { 2 | // 使用 IntelliSense 了解相关属性。 3 | // 悬停以查看现有属性的描述。 4 | // 欲了解更多信息,请访问: https://go.microsoft.com/fwlink/?linkid=830387 5 | "version": "0.2.0", 6 | "configurations": [ 7 | { 8 | "name": "Python: 当前文件", 9 | "type": "python", 10 | "request": "launch", 11 | "program": "${file}", 12 | "console": "integratedTerminal" 13 | } 14 | ] 15 | } 16 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) <2018> MapBox 4 | Copyright (c) <2018-2019> DataPink 5 | 6 | Permission is hereby granted, free of charge, to any person obtaining a copy 7 | of this software and associated documentation files (the "Software"), to deal 8 | in the Software without restriction, including without limitation the rights 9 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 10 | copies of the Software, and to permit persons to whom the Software is 11 | 
furnished to do so, subject to the following conditions: 12 | 13 | The above copyright notice and this permission notice shall be included in all 14 | copies or substantial portions of the Software. 15 | 16 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 17 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 18 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 19 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 20 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 21 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 22 | SOFTWARE. 23 | -------------------------------------------------------------------------------- /MANIFEST.in: -------------------------------------------------------------------------------- 1 | include LICENSE 2 | include requirements.txt 3 | include robosat_pink/web_ui/compare.html 4 | include robosat_pink/web_ui/leaflet.html 5 | recursive-include data * 6 | -------------------------------------------------------------------------------- /Makefile: -------------------------------------------------------------------------------- 1 | help: 2 | @echo "This Makefile rules are designed for RoboSat.pink devs and power-users." 3 | @echo "For plain user installation follow README.md instructions, instead." 4 | @echo "" 5 | @echo "" 6 | @echo " make install To install, few Python dev tools and RoboSat.pink in editable mode." 7 | @echo " So any further RoboSat.pink Python code modification will be usable at once," 8 | @echo " throught either rsp tools commands or robosat_pink.* modules." 9 | @echo "" 10 | @echo " make check Launchs code tests, and tools doc updating." 11 | @echo " Do it, at least, before sending a Pull Request." 12 | @echo "" 13 | @echo " make check_tuto Launchs rsp commands embeded in tutorials, to be sure everything still up to date." 
14 | @echo " Do it, at least, on each CLI modifications, and before a release." 15 | @echo " NOTA: It takes a while." 16 | @echo "" 17 | @echo " make pink Python code beautifier," 18 | @echo " as Pink is the new Black ^^" 19 | 20 | 21 | 22 | # Dev install 23 | install: 24 | pip3 install pytest black flake8 twine 25 | pip3 install -e . 26 | 27 | 28 | # Lauch all tests 29 | check: ut it doc 30 | @echo "===================================================================================" 31 | @echo "All tests passed !" 32 | @echo "===================================================================================" 33 | 34 | 35 | # Python code beautifier 36 | pink: 37 | black -l 125 *.py robosat_pink/*.py robosat_pink/*/*.py tests/*py tests/*/*.py 38 | 39 | 40 | # Perform units tests, and linter checks 41 | ut: 42 | @echo "===================================================================================" 43 | black -l 125 --check *.py robosat_pink/*.py robosat_pink/*/*.py 44 | @echo "===================================================================================" 45 | flake8 --max-line-length 125 --ignore=E203,E241,E226,E272,E261,E221,W503,E722 46 | @echo "===================================================================================" 47 | pytest tests -W ignore::UserWarning 48 | 49 | 50 | # Launch Integration Tests 51 | it: it_pre it_train it_post 52 | 53 | 54 | # Integration Tests: Data Preparation 55 | it_pre: 56 | @echo "===================================================================================" 57 | rm -rf it 58 | rsp cover --zoom 18 --bbox 4.8,45.7,4.82,45.72 it/cover 59 | rsp download --rate 20 --type WMS 'https://download.data.grandlyon.com/wms/grandlyon?SERVICE=WMS&REQUEST=GetMap&VERSION=1.3.0&LAYERS=Ortho2015_vue_ensemble_16cm_CC46&WIDTH=512&HEIGHT=512&CRS=EPSG:3857&BBOX={xmin},{ymin},{xmax},{ymax}&FORMAT=image/jpeg' it/cover it/images 60 | echo "Download GeoJSON" && wget --show-progress -q -nc -O it/lyon_roofprint.json 
'https://download.data.grandlyon.com/wfs/grandlyon?SERVICE=WFS&REQUEST=GetFeature&TYPENAME=ms:fpc_fond_plan_communaut.fpctoit&VERSION=1.1.0&srsName=EPSG:4326&BBOX=4.8,45.7,4.82,45.72&outputFormat=application/json; subtype=geojson' | true 61 | rsp rasterize --type Building --geojson it/lyon_roofprint.json --config config.toml --cover it/cover it/labels 62 | echo "Download PBF" && wget --show-progress -q -O it/lyon.pbf http://datapink.tools/rsp/it/lyon.pbf 63 | rsp extract --type Building it/lyon.pbf it/osm_lyon_footprint.json 64 | rsp rasterize --type Building --geojson it/lyon_roofprint.json --config config.toml --cover it/cover it/labels_osm 65 | rsp cover --dir it/images --splits 80/20 it/training/cover it/validation/cover 66 | rsp subset --dir it/images --cover it/training/cover it/training/images 67 | rsp subset --dir it/labels --cover it/training/cover it/training/labels 68 | rsp subset --dir it/images --cover it/validation/cover it/validation/images 69 | rsp subset --dir it/labels --cover it/validation/cover it/validation/labels 70 | wget -nc -O it/tanzania.tif http://datapink.tools/rsp/it/tanzania.tif 71 | rsp tile --zoom 19 it/tanzania.tif it/prediction/images 72 | rsp cover --zoom 19 --dir it/prediction/images it/prediction/cover 73 | wget -nc -O it/tanzania.geojson http://datapink.tools/rsp/it/tanzania.geojson 74 | rsp rasterize --type Building --geojson it/tanzania.geojson --config config.toml --cover it/prediction/cover it/prediction/labels 75 | 76 | 77 | 78 | # Integration Tests: Training 79 | it_train: 80 | @echo "===================================================================================" 81 | rsp train --config config.toml --workers 2 --bs 2 --lr 0.00025 --epochs 2 it it/pth 82 | rsp train --config config.toml --workers 2 --bs 2 --lr 0.00025 --epochs 3 --resume --checkpoint it/pth/checkpoint-00002.pth it it/pth 83 | 84 | 85 | # Integration Tests: Post Training 86 | it_post: 87 | @echo 
"===================================================================================" 88 | rsp export --checkpoint it/pth/checkpoint-00003.pth --type jit it/pth/export.jit 89 | rsp export --checkpoint it/pth/checkpoint-00003.pth --type onnx it/pth/export.onnx 90 | rsp predict --config config.toml --bs 4 --checkpoint it/pth/checkpoint-00003.pth it/prediction it/prediction/masks 91 | rsp compare --images it/prediction/images it/prediction/labels it/prediction/masks --mode stack --labels it/prediction/labels --masks it/prediction/masks it/prediction/compare 92 | rsp compare --images it/prediction/images it/prediction/compare --mode side it/prediction/compare_side 93 | rsp compare --mode list --labels it/prediction/labels --maximum_qod 75 --minimum_fg 5 --masks it/prediction/masks --geojson it/prediction/compare/tiles.json 94 | cp it/prediction/compare/tiles.json it/prediction/compare_side/tiles.json 95 | rsp vectorize --type Building --config config.toml it/prediction/masks it/prediction/vector.json 96 | 97 | 98 | # Documentation generation (tools and config file) 99 | doc: 100 | @echo "===================================================================================" 101 | @echo "# RoboSat.pink tools documentation" > docs/tools.md 102 | @for tool in `ls robosat_pink/tools/[^_]*py | sed -e 's#.*/##g' -e 's#.py##'`; do \ 103 | echo "Doc generation: $$tool"; \ 104 | echo "## rsp $$tool" >> docs/tools.md; \ 105 | echo '```' >> docs/tools.md; \ 106 | rsp $$tool -h >> docs/tools.md; \ 107 | echo '```' >> docs/tools.md; \ 108 | done 109 | @echo "Doc generation: config.toml" 110 | @echo "## config.toml" > docs/config.md; \ 111 | echo '```' >> docs/config.md; \ 112 | cat config.toml >> docs/config.md; \ 113 | echo '```' >> docs/config.md; 114 | @echo "Doc generation: Makefile" 115 | @echo "## Makefile" > docs/makefile.md; \ 116 | echo '```' >> docs/makefile.md; \ 117 | make --no-print-directory >> docs/makefile.md; \ 118 | echo '```' >> docs/makefile.md; 119 | 120 | 121 | # 
Check rsp commands embeded in Documentation 122 | check_doc: 123 | @echo "===================================================================================" 124 | @echo "Checking README:" 125 | @echo "===================================================================================" 126 | @rm -rf ds && sed -n -e '/```bash/,/```/ p' README.md | sed -e '/```/d' > .CHECK && sh .CHECK 127 | @echo "===================================================================================" 128 | 129 | 130 | # Check rsp commands embeded in Tutorials 131 | check_tuto: 132 | @mkdir tuto 133 | @echo "===================================================================================" 134 | @echo "Checking 101" 135 | @sudo su postgres -c 'dropdb tanzania' || : 136 | @cd tuto && mkdir 101 && sed -n -e '/```bash/,/```/ p' ../docs/101.md | sed -e '/```/d' > 101/.CHECK && cd 101 && sh .CHECK && cd .. 137 | @echo "===================================================================================" 138 | @echo "Checking Tutorial OpenData to OpenDataset:" 139 | @cd tuto && mkdir gl && sed -n -e '/```bash/,/```/ p' ../docs/from_opendata_to_opendataset.md | sed -e '/```/d' > gl/.CHECK && cd gl && sh .CHECK && cd .. 140 | @echo "===================================================================================" 141 | 142 | 143 | # Send a release on PyPI 144 | pypi: 145 | rm -rf dist RoboSat.pink.egg-info 146 | python3 setup.py sdist 147 | twine upload dist/* -r pypi 148 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | 2 | 3 |

RoboSat_geoc

4 |

从标准WMTS影像中提取建筑物的深度学习框架

5 |

forked by mapbox/robosat & Robosat.pink

6 |

7 | RoboSat_Geoc buildings segmentation from Imagery 8 |

9 | 10 | ## 简介: 11 | 12 | `RoboSat.geoc` 由 [mapbox/robosat](https://github.com/mapbox/robosat) 及 [Robosat.pink](https://github.com/datapink/robosat.pink) fork 而来。 13 | 14 | 利用深度学习工具,可以很方便地使用标准 WMTS 影像对建筑物轮廓提取进行训练和预测。 15 | 16 | ## 目的: 17 | 18 | - `Mapbox/Robosat` 是非常不错的建筑物提取工具,`Robosat.pink` 对其做了重构和改造,使其易用性得到了提升。 19 | - `Robosat.geoc` 在 `Robosat.pink` 的基础上,做了自动化和工程化改造,并可以结合 [rs_buildings_extraction](https://github.com/geocompass/rs_buildings_extraction) ,使用可视化界面和接口的方式进行训练和预测,很方便地用于生产环境。 20 | 21 | ## 主要功能: 22 | 23 | - 继承了`RoboSat.pink` 的所有功能: 24 | - 提供了命令行工具,可以很方便地进行批处理 25 | - 遵循了 WMTS 服务标准,方便遥感影像数据的准备 26 | - 内置了最先进的计算机视觉模型,并可以自行扩展 27 | - 支持 RGB 和多波段影像,并允许数据融合 28 | - 提供了 Web 界面工具,可以轻松地显示、对比、选择训练结果 29 | - 高性能 30 | - 能够很轻松地扩展 31 | - 等等 32 | - 将深度学习训练标注(`label`) 数据以 PostGIS 的方式存储,对 GISer 极其友好 33 | - 提供了 WMTS 瓦片服务代理工具,可将天地图、谷歌影像等作为影像数据源(Robosat 不支持类似 `http://this_is_host?x={x}&y={y}&z={z}` 形式的 URL,仅支持类似 `http://this_is_host/z/x/y` 形式的 URL) 34 | - 对 `RoboSat.pink` 做了自动化改造,无需手动逐个输入命令行,一键式训练或预测 35 | - 简化调试方式,仅需提供待训练或预测的范围(`extent`) 36 | - 自动化训练限定为 `PostgreSQL + PostGIS` 数据源作为深度学习标注 37 | 38 | ## 说明文档: 39 | 40 | ### 训练数据准备: 41 | 42 | - 安装 `PostgreSQL + PostGIS`,创建数据库,添加 `PostGIS` 扩展 `create extension postgis;` 43 | - 使用 `shp2pgsql` 等工具将已有的建筑物轮廓数据导入 `PostGIS` 作为深度学习标注数据,或者使用 `QGIS` 等工具连接 `PostGIS` 并加载遥感影像底图,绘制建筑物轮廓 44 | 45 | ### 如何安装: 46 | 47 | - 对于 MacOS 或 Linux: 48 | 见[docs/Linux安装指南](https://github.com/geocompass/robosat_geoc/blob/master/docs/Linux安装指南.md) 49 | - 对于 Windows: 50 | - 在 Windows 安装依赖时会报 `GDAL` 相关错误,目前没有比较好的解决办法 51 | - 建议使用 WSL,[在 Windows 中安装 Ubuntu 子系统](https://docs.microsoft.com/zh-cn/windows/wsl/install-win10) 52 | - 配合 [Windows Terminal](https://www.microsoft.com/zh-cn/p/windows-terminal-preview/9n0dx20hk701) ,使用 Ubuntu 命令行工具 53 | - 使用上述 MacOS 或 Linux 安装方式进行部署 54 | 55 | ### 如何运行: 56 | 57 | - 设置已有的建筑物轮廓标注数据 58 | - 设置 PostGIS 连接: `robosat_pink/geoc/config.py` 中的 `POSTGRESQL` 59 | - 设置已有建筑物轮廓数据表:`robosat_pink/geoc/config.py` 中的 `BUILDING_TABLE` 60 | - 后台运行 WMTS
代理工具:`python xyz_proxy.py &` 61 | - 设置训练或预测范围:`./test.py` 中的 `extent` 62 | - 开始训练或预测:`python test.py` 63 | 64 | ### Windows 中如何开发: 65 | 66 | - 使用 VSCode: 67 | - 使用 [Remote-WSL](https://marketplace.visualstudio.com/items?itemName=ms-vscode-remote.remote-wsl) 扩展连接 WSL 中的 Ubuntu,打开该项目文件夹进行开发 68 | - 使用 PyCharm: 69 | - 在 PyCharm 的 `Settings` 中[配置 Project Interpreter](https://www.jetbrains.com/help/pycharm/using-wsl-as-a-remote-interpreter.html) 的 WSL 参数。 70 | 71 | ### 如何作为 Package 使用: 72 | 73 | - 构建:`python setup.py build` 74 | - 安装:`python setup.py install` 75 | - 在工程中调用:`from robosat_pink.geoc import RSPtrain` & `from robosat_pink.geoc import RSPpredict` 76 | - `Shapely` 安装:`pip install shapely==1.7.0 -U` 或 `pip install shapely==1.6.4.post2 -U` 77 | 78 | ## RoboSat.pink 使用教程 79 | 80 | ### 必读教程: 81 | 82 | - RoboSat.pink 101 83 | - How to use plain OpenData to create a decent training OpenDataSet 84 | - How to extend RoboSat.pink features, and use it as a Framework 85 | 86 | ### Tools: 87 | 88 | - `rsp cover` Generate a tile cover, in csv format: X,Y,Z 89 | - `rsp download` Download tiles from a remote server (XYZ, WMS, or TMS) 90 | - `rsp extract` Extract GeoJSON features from OpenStreetMap .pbf 91 | - `rsp rasterize` Rasterize vector features (GeoJSON or PostGIS) to raster tiles 92 | - `rsp subset` Filter images in a slippy map dir using a csv tiles cover 93 | - `rsp tile` Tile raster coverage 94 | - `rsp train` Train a model on a dataset 95 | - `rsp export` Export a model to ONNX or Torch JIT 96 | - `rsp predict` Predict masks from given inputs and an already trained model 97 | - `rsp compare` Compute composite images and/or metrics to compare several XYZ dirs 98 | - `rsp vectorize` Extract simplified GeoJSON features from segmentation masks 99 | - `rsp info` Print RoboSat.pink version information 100 | 101 | ### 模型优化 102 | 103 | - 利用 robosat.merge 和 features 对预测结果进行规范化,参数调整包括: 104 | - merge: 105 | - threshold=1(融合阈值,单位:像素) 106 | - features: 107 | -
denoise=10(除噪,对要素预处理,单位:像素) 108 | - grow=20(填坑,作用类似除噪,单位:像素) 109 | - simplify=0.01(新要素与原要素的简化比) 110 | - 优化效果: 111 |

112 | RoboSat_Geoc buildings segmentation from Imagery 113 |

114 | 115 | ## 本项目作者: 116 | 117 | - 吴灿 [https://github.com/wucangeo](https://github.com/wucangeo) 118 | - Liii18 [https://github.com/liii18](https://github.com/liii18) 119 | 120 | ## 欢迎 Issues 121 | 122 | 欢迎提一个 [Issue](https://github.com/geocompass/robosat_geoc/issues) 123 | -------------------------------------------------------------------------------- /__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/geocompass/robosat_geoc/44d95183322ba68c0728de44a66e50f510bfd919/__init__.py -------------------------------------------------------------------------------- /app/__init__.py: -------------------------------------------------------------------------------- 1 | from flask.json import JSONEncoder 2 | from flask_apscheduler import APScheduler 3 | from datetime import date 4 | from flask_cors import CORS 5 | from .app import Flask 6 | 7 | 8 | def register_blueprints(app): 9 | from app.api.v1 import create_blueprint_v1 10 | app.register_blueprint(create_blueprint_v1(), url_prefix='/v1') 11 | 12 | 13 | def register_plugin(app): 14 | from app.models.base import db 15 | from app.api.v1.job import scheduler 16 | db.app = app 17 | db.init_app(app) 18 | with app.app_context(): 19 | db.create_all(app=app) 20 | scheduler = APScheduler(app=app) 21 | 22 | 23 | # class CustomJSONEncoder(JSONEncoder): 24 | # def default(self, obj): 25 | # try: 26 | # if isinstance(obj, date): 27 | # return obj.isoformat() 28 | # iterable = iter(obj) 29 | # except TypeError: 30 | # pass 31 | # else: 32 | # return list(iterable) 33 | # return JSONEncoder.default(self, obj) 34 | 35 | 36 | def create_app(): 37 | app = Flask(__name__, static_folder='../webmap/dist', 38 | static_url_path='') 39 | CORS(app, supports_credentials=True) 40 | app.config['SQLALCHEMY_TRACK_MODIFICATIONS'] = False 41 | app.config.from_object('app.config.setting') 42 | app.config.from_object('app.config.secure') 43 | # app.json_encoder = 
CustomJSONEncoder 44 | 45 | register_blueprints(app) 46 | register_plugin(app) 47 | 48 | return app 49 | -------------------------------------------------------------------------------- /app/api/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/geocompass/robosat_geoc/44d95183322ba68c0728de44a66e50f510bfd919/app/api/__init__.py -------------------------------------------------------------------------------- /app/api/v1/__init__.py: -------------------------------------------------------------------------------- 1 | from flask import Blueprint 2 | from app.api.v1 import test, predict_buildings, buia, train, predict, wmts, geojson, tools, task, task_admin, job 3 | 4 | 5 | def create_blueprint_v1(): 6 | bp_v1 = Blueprint('v1', __name__) 7 | 8 | buia.api.register(bp_v1) 9 | geojson.api.register(bp_v1) 10 | job.api.register(bp_v1) 11 | predict.api.register(bp_v1) 12 | predict_buildings.api.register(bp_v1) 13 | task.api.register(bp_v1) 14 | task_admin.api.register(bp_v1) 15 | test.api.register(bp_v1) 16 | train.api.register(bp_v1) 17 | tools.api.register(bp_v1) 18 | wmts.api.register(bp_v1) 19 | 20 | return bp_v1 21 | -------------------------------------------------------------------------------- /app/api/v1/buia.py: -------------------------------------------------------------------------------- 1 | from sqlalchemy import or_, text 2 | from app.libs.redprint import Redprint 3 | from app.models.buia import BUIA 4 | from app.models.base import queryBySQL 5 | from flask import jsonify, request 6 | import json 7 | 8 | api = Redprint('buia') 9 | 10 | 11 | @api.route("", methods=['GET']) 12 | def geojson_by_extent(): 13 | result = { 14 | "code": 1, 15 | "data": None, 16 | "msg": "ok" 17 | } 18 | extent = request.args.get("extent") 19 | if not extent: 20 | result["code"] = 0 21 | result["msg"] = "参数有误" 22 | return jsonify(result) 23 | # coords = extent.split(',') 24 | sql = '''SELECT 25 | 
jsonb_build_object ( 'type', 'FeatureCollection', 'features', jsonb_agg ( features.feature ) ) 26 | FROM 27 | ( 28 | SELECT 29 | jsonb_build_object ( 'type', 'Feature', 'id', gid, 'geometry', ST_AsGeoJSON ( geom ) :: jsonb, 'properties', to_jsonb ( inputs ) - 'geom' ) AS feature 30 | FROM 31 | ( 32 | SELECT gid,"CNAME",geom AS geom 33 | FROM "BUIA" WHERE 34 | geom @ 35 | ST_MakeEnvelope ( {extent}, {srid} )) inputs 36 | ) features; ''' 37 | queryData = queryBySQL(sql.format(extent=extent, srid=4326)) 38 | if not queryData: 39 | result["code"] = 0 40 | result["msg"] = "查询语句有问题" 41 | return jsonify(result) 42 | row = queryData.fetchone() 43 | result["data"] = row 44 | 45 | return jsonify(result) 46 | 47 | 48 | @api.route("/", methods=['GET']) 49 | def get(gid): 50 | result = { 51 | "code": 1, 52 | "data": None, 53 | "msg": "ok" 54 | } 55 | sql = '''select st_asgeojson(geom) as geojson from "BUIA" where gid ={gid}''' 56 | queryData = queryBySQL(sql.format(gid=gid)) 57 | if not queryData: 58 | result["code"] = 0 59 | result["msg"] = "查询语句有问题" 60 | return jsonify(result) 61 | if queryData.rowcount == 0: 62 | result["code"] = 0 63 | result["msg"] = "未查询到内容" 64 | return jsonify(result) 65 | row = queryData.fetchone() 66 | result["data"] = json.loads(row["geojson"]) 67 | return jsonify(result) 68 | -------------------------------------------------------------------------------- /app/api/v1/geojson.py: -------------------------------------------------------------------------------- 1 | from flask import request, Response, jsonify 2 | from app.models.base import queryBySQL 3 | from app.libs.redprint import Redprint 4 | from app.config import setting 5 | api = Redprint('geojson') 6 | 7 | 8 | @api.route('', methods=['GET']) 9 | def geojson(): 10 | extent = request.args.get("extent") 11 | if not extent: 12 | return jsonify("") 13 | extentArr = extent.split(',') 14 | if len(extentArr) != 4: 15 | return jsonify("") 16 | if float(extentArr[2]) - float(extentArr[0]) > 0.05 or 
float(extentArr[3]) - float(extentArr[1]) > 0.04: 17 | return jsonify("") 18 | 19 | sql = '''SELECT 20 | jsonb_build_object ( 'type', 'FeatureCollection', 'features', jsonb_agg ( features.feature ) ) 21 | FROM 22 | ( 23 | SELECT 24 | jsonb_build_object ( 'type', 'Feature', 'id', gid, 'geometry', ST_AsGeoJSON ( geom ) :: jsonb, 'properties', to_jsonb ( inputs ) - 'geom' ) AS feature 25 | FROM 26 | ( 27 | SELECT gid,geom AS geom 28 | FROM "{buildings_table}" WHERE 29 | geom @ 30 | ST_MakeEnvelope ( {extent}, {srid} )) inputs 31 | ) features; ''' 32 | queryData = queryBySQL(sql.format( 33 | buildings_table=setting.BUILDINGS_TABLE, extent=extent, srid=4326)) 34 | row = queryData.fetchone() 35 | return jsonify(row["jsonb_build_object"]) 36 | -------------------------------------------------------------------------------- /app/api/v1/job.py: -------------------------------------------------------------------------------- 1 | from flask import jsonify, request 2 | from flask_apscheduler import APScheduler 3 | import json 4 | from app.models.base import queryBySQL, db as DB 5 | from app.api.v1 import task as TASK, predict as PREDICT 6 | from app.libs.redprint import Redprint 7 | 8 | api = Redprint('job') 9 | 10 | scheduler = APScheduler() 11 | 12 | 13 | @scheduler.task(trigger='interval', id='predict_job', seconds=5) 14 | def task_job(): 15 | # check doing failed job. 
16 | TASK.job_listen() 17 | 18 | # check doing job by this server 19 | isDoingJob = TASK.doing_job() 20 | if isDoingJob: 21 | return 22 | 23 | # get one new job 24 | newTask = TASK.get_one_job() 25 | if not newTask: 26 | return 27 | 28 | # start one job 29 | print("start one job.") 30 | TASK.do_job(newTask.task_id, 2) # update task state 31 | # do the predict by robosat 32 | result = PREDICT.predict_job(newTask) 33 | 34 | if result['code'] == 0: 35 | TASK.do_job(newTask.task_id, 4) # 任务失败 36 | print('job faild!') 37 | else: 38 | TASK.do_job(newTask.task_id, 3) # 任务完成并修改完成时间 39 | print('job success!') 40 | 41 | 42 | @api.route('/pause', methods=['GET']) 43 | def pause_job(id): # 暂停 44 | job_id = request.args.get('id') or id 45 | scheduler.pause_job(str(job_id)) 46 | return "pause success!" 47 | 48 | 49 | @api.route('/resume', methods=['GET']) 50 | def resume_job(id): # 恢复 51 | job_id = request.args.get('id') or id 52 | scheduler.resume_job(str(job_id)) 53 | return "resume success!" 54 | 55 | 56 | @api.route('/get_jobs', methods=['GET']) 57 | def get_task(): # 获取 58 | # job_id = request.args.get('id') 59 | jobs = scheduler.get_jobs() 60 | print(jobs) 61 | return 'jobs:'+str(jobs) 62 | 63 | 64 | @api.route('/remove_job', methods=['GET']) 65 | def remove_job(): # 移除 66 | job_id = request.args.get('id') 67 | scheduler.remove_job(str(job_id)) 68 | return 'remove success' 69 | 70 | # /add_job?id=2 71 | @api.route('/add_job', methods=['GET']) 72 | def add_task(): 73 | data = request.args.get('id') 74 | if data == '1': 75 | # trigger='cron' 表示是一个定时任务 76 | scheduler.add_job(func=task_job, id='1', args=(1, 1), trigger='cron', day_of_week='0-6', hour=18, minute=24, 77 | second=10, replace_existing=True) 78 | return 'add job success' 79 | -------------------------------------------------------------------------------- /app/api/v1/predict.py: -------------------------------------------------------------------------------- 1 | import time 2 | import requests 3 | import json 4 | 
import shutil 5 | from flask import jsonify, request 6 | from app.libs import redprint, utils, utils_geom 7 | from app.config import setting as SETTING 8 | from robosat_pink.geoc import RSPpredict, RSPreturn_predict 9 | from app.api.v1 import tools as TOOLS, task as TASK, predict_buildings as BUILDINGS, job as JOB 10 | 11 | api = redprint.Redprint('predict') 12 | 13 | 14 | @api.route('', methods=['GET']) 15 | def predict(): 16 | # check extent 17 | extent = request.args.get("extent") 18 | result = TOOLS.check_extent(extent, "predict", True) 19 | # result = TOOLS.check_extent(extent, "predict") 20 | if result["code"] == 0: 21 | return jsonify(result) 22 | 23 | # 使用robosat_geoc开始预测 24 | dataPath = SETTING.ROBOSAT_DATA_PATH 25 | datasetPath = SETTING.ROBOSAT_DATASET_PATH 26 | ts = time.time() 27 | 28 | dsPredictPath = datasetPath+"/predict_"+str(ts) 29 | geojson = RSPpredict.main( 30 | # extent, dataPath, dsPredictPath, map="tdt") 31 | extent, dataPath, dsPredictPath, map="google") 32 | 33 | if not geojson: 34 | result["code"] = 0 35 | result["msg"] = "预测失败" 36 | return jsonify(result) 37 | 38 | # 给geojson添加properties 39 | for feature in geojson["features"]: 40 | feature["properties"] = {} 41 | 42 | result["data"] = geojson 43 | return jsonify(result) 44 | 45 | 46 | def predict_job(task): 47 | result = { 48 | "code": 1, 49 | "data": None, 50 | "msg": "do job success" 51 | } 52 | extent = task.extent 53 | if task.user_id != "ADMIN": 54 | result = TOOLS.check_extent(extent, "predict", True) 55 | if result["code"] == 0: 56 | return result 57 | 58 | # 使用robosat_geoc开始预测 59 | dataPath = SETTING.ROBOSAT_DATA_PATH 60 | datasetPath = SETTING.ROBOSAT_DATASET_PATH 61 | ts = time.time() 62 | dsPredictPath = datasetPath+"/predict_"+str(ts) 63 | geojson_predcit = RSPpredict.main( 64 | extent, dataPath, dsPredictPath, map="google") 65 | 66 | if not geojson_predcit or not isinstance(geojson_predcit, dict) or 'features' not in geojson_predcit: 67 | result["code"] = 0 68 | 
result["msg"] = "预测失败" 69 | return result 70 | 71 | # 转换为3857坐标系 72 | # geojson3857 = utils_geom.geojson_project( 73 | # geojson_predcit, "epsg:4326", "epsg:3857") 74 | 75 | # geojson 转 shapefile 76 | building_predcit_path = dsPredictPath+"/building1_predict.shp" 77 | utils_geom.geojson2shp(geojson_predcit, building_predcit_path) 78 | 79 | # regularize-building-footprint 80 | # site:https://pro.arcgis.com/zh-cn/pro-app/tool-reference/3d-analyst/regularize-building-footprint.htm 81 | shp_regularized = dsPredictPath + "/building5_4326.shp" 82 | arcpy_requests = requests.get( 83 | SETTING.ARCPY_HOST.format(path="predict_" + str(ts))) 84 | arcpy_result = arcpy_requests.json() 85 | if arcpy_result['code'] == 0: 86 | result["code"] = 0 87 | result["msg"] = "arcpy regularize faild." 88 | return result 89 | 90 | # shp to geojson 91 | geojson4326 = utils_geom.shp2geojson(shp_regularized) 92 | 93 | # project from 3857 to 4326 94 | # geojson4326 = utils_geom.geojson_project( 95 | # geojson3857, "epsg:3857", "epsg:4326") 96 | 97 | # 给geojson添加properties 98 | handler = SETTING.IPADDR 99 | # for feature in geojson_predcit["features"]: 100 | for feature in geojson4326["features"]: 101 | feature["properties"] = { 102 | "task_id": task.task_id, 103 | "extent": task.extent, 104 | "user_id": task.user_id, 105 | "area_code": task.area_code, 106 | "handler": handler 107 | } 108 | 109 | # delete temp dir after done job. 
110 | if SETTING.DEBUG_MODE: 111 | shutil.rmtree(dsPredictPath) 112 | 113 | # 插入数据库 114 | result_create = BUILDINGS.insert_buildings(geojson4326) 115 | if not result_create: 116 | result["code"] = 0 117 | result["msg"] = "预测失败" 118 | return result 119 | 120 | return result 121 | -------------------------------------------------------------------------------- /app/api/v1/task_admin.py: -------------------------------------------------------------------------------- 1 | # import json 2 | from flask import jsonify, request 3 | import json 4 | import collections 5 | from rasterio.warp import transform_bounds 6 | from mercantile import tiles, xy_bounds 7 | 8 | from robosat_pink.geojson import geojson_parse_feature 9 | from app.models.base import queryBySQL, db as DB 10 | from app.libs.redprint import Redprint 11 | from app.config import setting as SETTING 12 | from app.models.task_admin import task_admin as TASK_ADMIN 13 | 14 | 15 | api = Redprint('task_admin') 16 | 17 | 18 | @api.route('', methods=['GET']) 19 | def create_task_by_areacode(): 20 | result = { 21 | "code": 1, 22 | "data": None, 23 | "msg": "create bat task success!" 
24 | } 25 | areacode = request.args.get('areacode') 26 | zoom = request.args.get('zoom') or '14' # 将区域范围分割成zoom级别瓦片大小的任务 27 | zoom = eval(zoom) 28 | 29 | if not areacode: 30 | result['code'] = 0 31 | result['msg'] = "no areacode params" 32 | return jsonify(result) 33 | quhuaTable = '' 34 | if len(areacode) == 9:#FIXME:bug when null areacode 35 | quhuaTable = SETTING.QUHUA_XIANG 36 | elif len(areacode) == 6: 37 | quhuaTable = SETTING.QUHUA_XIAN 38 | elif len(areacode) == 4: 39 | quhuaTable = SETTING.QUHUA_SHI 40 | elif len(areacode) == 2: 41 | quhuaTable = SETTING.QUHUA_SHENG 42 | else: 43 | result['code'] = 0 44 | result['msg'] = "areacode not support" 45 | return jsonify(result) 46 | 47 | areacode = areacode.ljust(12, '0') 48 | sql = """ 49 | SELECT 50 | '{{"type": "Feature", "geometry": ' 51 | || ST_AsGeoJSON(st_simplify(geom,0.001)) 52 | || '}}' AS features 53 | FROM {quhuaTable} WHERE code = '{areacode}' 54 | """ 55 | queryData = queryBySQL(sql.format( 56 | areacode=areacode, quhuaTable=quhuaTable)) 57 | if not queryData: 58 | result["code"] = 0 59 | result["msg"] = "not found this area,areacode:"+areacode 60 | return jsonify(result) 61 | area_json = queryData.fetchone() 62 | 63 | feature_map = collections.defaultdict(list) 64 | 65 | # FIXME: fetchall will not always fit in memory... 
66 | for feature in area_json: 67 | feature_map = geojson_parse_feature( 68 | zoom, 4326, feature_map, json.loads(feature)) 69 | 70 | cover = feature_map.keys() 71 | 72 | extents = [] 73 | for tile in cover: 74 | w, s, n, e = transform_bounds( 75 | "epsg:3857", "epsg:4326", *xy_bounds(tile)) 76 | extent = [w, s, n, e] 77 | extents.append(','.join([str(elem) for elem in extent])) 78 | 79 | for extent in extents: 80 | # originalExtent = extent 81 | user_id = "ADMIN" 82 | area_code = areacode 83 | with DB.auto_commit(): 84 | task = TASK_ADMIN() 85 | task.extent = extent 86 | # task.originalextent = originalExtent 87 | task.user_id = user_id 88 | task.area_code = area_code 89 | DB.session.add(task) 90 | 91 | result['data'] = { 92 | "count": len(extents), 93 | "extent": extents 94 | } 95 | return jsonify(result) 96 | -------------------------------------------------------------------------------- /app/api/v1/test.py: -------------------------------------------------------------------------------- 1 | from app.libs.redprint import Redprint 2 | api = Redprint('test') 3 | 4 | 5 | @api.route('', methods=['GET']) 6 | def get_test(): 7 | return "testttt" 8 | 9 | 10 | @api.route('/1', methods=['GET']) 11 | def get_tt(): 12 | return "ddddd" 13 | -------------------------------------------------------------------------------- /app/api/v1/tools.py: -------------------------------------------------------------------------------- 1 | import os 2 | from flask import jsonify 3 | from app.libs.redprint import Redprint 4 | from app.config import setting as SETTING 5 | api = Redprint('tools') 6 | 7 | 8 | @api.route('/log', methods=['GET']) 9 | def get_log(): 10 | dataPath = SETTING.ROBOSAT_DATA_PATH 11 | logPath = dataPath+"/model/log" 12 | if not os.path.isfile(logPath): 13 | return "未找到日志文件!,路径为:"+logPath 14 | with open(logPath) as f: 15 | f = f.readlines() 16 | logContent = ["这个是日志文件,路径为:"+logPath, "", ""] 17 | for line in f: 18 | logContent.append("

"+line+"

") 19 | logStr = " ".join(logContent) 20 | return logStr 21 | 22 | 23 | @api.route('/log/clear', methods=['GET']) 24 | def clear_log(): 25 | dataPath = SETTING.ROBOSAT_DATA_PATH 26 | logPath = dataPath+"/model/log" 27 | if not os.path.isfile(logPath): 28 | return "未找到日志文件!,路径为:"+logPath 29 | open(logPath, "w").close() 30 | result = { 31 | "code": 1, 32 | "msg": "log is clean now." 33 | } 34 | return jsonify(result) 35 | 36 | 37 | def check_extent(extent, train_or_predict, set_maximum=False): 38 | result = { 39 | "code": 1, 40 | "data": None, 41 | "msg": "ok" 42 | } 43 | if not extent: 44 | result["code"] = 0 45 | result["msg"] = "参数有误" 46 | return result 47 | coords = extent.split(',') 48 | if len(coords) != 4: 49 | result["code"] = 0 50 | result["msg"] = "参数有误" 51 | return result 52 | if "train" in train_or_predict: 53 | if float(coords[2]) - float(coords[0]) < SETTING.MIN_T_EXTENT or float(coords[3]) - float(coords[1]) < SETTING.MIN_T_EXTENT: 54 | result["code"] = 0 55 | result["msg"] = "Extent for training is too small. Training stopped." 56 | elif "predict" in train_or_predict: 57 | if float(coords[2]) - float(coords[0]) < SETTING.MIN_P_EXTENT or float(coords[3]) - float(coords[1]) < SETTING.MIN_P_EXTENT: 58 | result["code"] = 0 59 | result["msg"] = "Extent for prediction is too small. Predicting stopped." 60 | elif float(coords[2]) - float(coords[0]) > SETTING.MAX_P_EXTENT or float(coords[3]) - float(coords[1]) > SETTING.MAX_P_EXTENT: 61 | result["code"] = 0 62 | result["msg"] = "Extent for prediction is too small. Predicting stopped." 63 | elif set_maximum and float(coords[2]) - float(coords[0]) > 0.02 or float(coords[3]) - float(coords[1]) > 0.02: 64 | result["code"] = 0 65 | result["msg"] = "Extent for prediction is too big. Predicting stopped." 66 | else: 67 | result["code"] = 0 68 | result["msg"] = "got wrong params." 
69 | return result 70 | -------------------------------------------------------------------------------- /app/api/v1/train.py: -------------------------------------------------------------------------------- 1 | import time 2 | from app.libs.redprint import Redprint 3 | from flask import jsonify, request 4 | from robosat_pink.geoc import RSPtrain 5 | from app.api.v1 import tools 6 | from app.config import setting as SETTING 7 | api = Redprint('train') 8 | 9 | 10 | @api.route('', methods=['GET']) 11 | def train(): 12 | # check extent 13 | extent = request.args.get("extent") 14 | result = tools.check_extent(extent, "train") 15 | print(result) 16 | if result["code"] == 0: 17 | return jsonify(result) 18 | 19 | # 通过robosat_pink训练 20 | dataPath = SETTING.ROBOSAT_DATA_PATH 21 | datasetPath = SETTING.ROBOSAT_DATASET_PATH 22 | ts = time.time() 23 | dsTrainPath = datasetPath+"/train_"+str(ts) 24 | trainResult = RSPtrain.main(extent, dataPath, dsTrainPath, 1, map="tdt") 25 | result["data"] = trainResult 26 | return jsonify(result) 27 | -------------------------------------------------------------------------------- /app/api/v1/wmts.py: -------------------------------------------------------------------------------- 1 | import requests 2 | from flask import request, Response 3 | from app.libs.redprint import Redprint 4 | from app.config import setting as SETTING 5 | api = Redprint('wmts') 6 | 7 | 8 | @api.route('///', methods=['GET']) 9 | def wmts(x, y, z): 10 | map = request.args.get("type") 11 | if not x or not y or not z: 12 | return None 13 | if not map and map != "tdt" and map != "google": 14 | return "faild to set map type, neither tianditu nor google" 15 | url = SETTING.URL_TDT 16 | url_google = SETTING.URL_GOOGLE 17 | if map == 'google': 18 | url = url_google 19 | image = requests.get(url.format(x=x, y=y, z=z)) 20 | 21 | print(url.format(x=x, y=y, z=z)) 22 | return Response(image, mimetype='image/jpeg') 23 | 
-------------------------------------------------------------------------------- /app/app.py: -------------------------------------------------------------------------------- 1 | from flask import Flask as _Flask 2 | from flask.json import JSONEncoder as _JSONEncoder 3 | 4 | from app.libs.error_code import ServerError 5 | from datetime import date 6 | 7 | 8 | class JSONEncoder(_JSONEncoder): 9 | def default(self, o): 10 | if hasattr(o, 'keys') and hasattr(o, '__getitem__'): 11 | return dict(o) 12 | if isinstance(o, date): 13 | return o.strftime('%Y-%m-%d %H:%M:%S') 14 | raise ServerError() 15 | 16 | 17 | class Flask(_Flask): 18 | json_encoder = JSONEncoder 19 | -------------------------------------------------------------------------------- /app/config/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/geocompass/robosat_geoc/44d95183322ba68c0728de44a66e50f510bfd919/app/config/__init__.py -------------------------------------------------------------------------------- /app/config/secure.py: -------------------------------------------------------------------------------- 1 | SQLALCHEMY_DATABASE_URI = \ 2 | 'postgres+psycopg2://postgres:postgres@172.16.100.143/china_census7' 3 | # 'postgres+psycopg2://postgres:postgres@172.16.100.140/tdt2018' 4 | # 'postgres+psycopg2://postgres:postgres@localhost/tdt2018' 5 | 6 | SECRET_KEY = '\x88D\xf09\x91\x07\x98\x89\x87\x96\xa0A\xc68\xf9\xecJ:U\x17\xc5V\xbe\x8b\xef\xd7\xd8\xd3\xe6\x98*4' 7 | -------------------------------------------------------------------------------- /app/config/setting.py: -------------------------------------------------------------------------------- 1 | from app.libs import utils as UTILS 2 | 3 | # building outline PostGIS data table using by training label 4 | BUILDINGS_TABLE = "BUIA" 5 | 6 | # USER OR ADMIN MODE 7 | USER_OR_ADMIN = "USER" 8 | # USER_OR_ADMIN = "admin" 9 | 10 | QUHUA_SHENG = "data_1741" 11 | QUHUA_SHI = 
"data_1745" 12 | QUHUA_XIAN = "data_1746" 13 | QUHUA_XIANG = "data_1744" 14 | 15 | ARCPY_HOST = "http://172.16.105.70:5001/regularize?path={path}" 16 | # ARCPY_HOST = "http://localhost:5001/regularize?path={path}" 17 | 18 | # config.toml and checkpoint.pth files path 19 | ROBOSAT_DATA_PATH = "/data/datamodel" 20 | # ROBOSAT_DATA_PATH = "data" 21 | 22 | # dataset to training or predicting 23 | ROBOSAT_DATASET_PATH = "/data/dataset" 24 | # ROBOSAT_DATASET_PATH = "dataset" 25 | # ROBOSAT_DATASET_PATH = "/mnt/c/Users/WUCAN/Documents/dataset" 26 | 27 | 28 | # tianditu and google map remote sensing wmts url 29 | URL_TDT = '''https://t1.tianditu.gov.cn/DataServer?T=img_w&x={x}&y={y}&l={z}&tk=4830425f5d789b48b967b1062deb8c71''' 30 | URL_GOOGLE = '''http://ditu.google.cn/maps/vt/lyrs=s&x={x}&y={y}&z={z}''' 31 | # URL_TDT = '''http://yingxiang2019.geo-compass.com/api/wmts?layer=s%3Azjw&style=time%3D1576222648262&tilematrixset=w&Service=WMTS&Request=GetTile&Version=1.0.0&Format=image%2Fjpeg&TileMatrix={z}&TileCol={x}&TileRow={y}&threshold=100''' 32 | 33 | TOKEN_EXPIRATION = 30 * 24 * 3600 34 | 35 | # ip address 36 | IPADDR = UTILS.get_host_ip() 37 | 38 | # extent 39 | MIN_T_EXTENT = 0.0042 40 | MIN_P_EXTENT = 0.0014 41 | MAX_P_EXTENT = 0.0098 42 | 43 | # minimum building area 44 | MIN_BUILDING_AREA = 50 45 | 46 | # if open debug mode 47 | # DEBUG_MODE = True 48 | DEBUG_MODE = False 49 | -------------------------------------------------------------------------------- /app/libs/enums.py: -------------------------------------------------------------------------------- 1 | from enum import Enum 2 | 3 | 4 | class ClientTypeEnum(Enum): 5 | USER_EMAIL = 100 6 | USER_MOBILE = 101 7 | 8 | # 微信小程序 9 | USER_MINA = 200 10 | # 微信公众号 11 | USER_WX = 201 12 | -------------------------------------------------------------------------------- /app/libs/error.py: -------------------------------------------------------------------------------- 1 | 2 | from flask import request, json 3 | from 
werkzeug.exceptions import HTTPException 4 | 5 | 6 | class APIException(HTTPException): 7 | code = 500 8 | msg = 'sorry, we made a mistake (* ̄︶ ̄)!' 9 | error_code = 999 10 | 11 | def __init__(self, msg=None, code=None, error_code=None, 12 | headers=None): 13 | if code: 14 | self.code = code 15 | if error_code: 16 | self.error_code = error_code 17 | if msg: 18 | self.msg = msg 19 | super(APIException, self).__init__(msg, None) 20 | 21 | def get_body(self, environ=None): 22 | body = dict( 23 | msg=self.msg, 24 | error_code=self.error_code, 25 | request=request.method + ' ' + self.get_url_no_param() 26 | ) 27 | text = json.dumps(body) 28 | return text 29 | 30 | def get_headers(self, environ=None): 31 | """Get a list of headers.""" 32 | return [('Content-Type', 'application/json')] 33 | 34 | @staticmethod 35 | def get_url_no_param(): 36 | full_path = str(request.full_path) 37 | main_path = full_path.split('?') 38 | return main_path[0] 39 | -------------------------------------------------------------------------------- /app/libs/error_code.py: -------------------------------------------------------------------------------- 1 | from werkzeug.exceptions import HTTPException 2 | 3 | from app.libs.error import APIException 4 | 5 | 6 | class Success(APIException): 7 | code = 201 8 | msg = 'ok' 9 | error_code = 0 10 | 11 | 12 | class DeleteSuccess(Success): 13 | code = 202 14 | error_code = 1 15 | 16 | 17 | class ServerError(APIException): 18 | code = 500 19 | msg = 'sorry, we made a mistake (* ̄︶ ̄)!' 20 | error_code = 999 21 | 22 | 23 | class ClientTypeError(APIException): 24 | # 400 401 403 404 25 | # 500 26 | # 200 201 204 27 | # 301 302 28 | code = 400 29 | msg = 'client is invalid' 30 | error_code = 1006 31 | 32 | 33 | class ParameterException(APIException): 34 | code = 400 35 | msg = 'invalid parameter' 36 | error_code = 1000 37 | 38 | 39 | class NotFound(APIException): 40 | code = 404 41 | msg = 'the resource are not found O__O...' 
42 | error_code = 1001 43 | 44 | 45 | class AuthFailed(APIException): 46 | code = 401 47 | error_code = 1005 48 | msg = 'authorization failed' 49 | 50 | 51 | class Forbidden(APIException): 52 | code = 403 53 | error_code = 1004 54 | msg = 'forbidden, not in scope' 55 | 56 | 57 | class DuplicateGift(APIException): 58 | code = 400 59 | error_code = 2001 60 | msg = 'the current book has already in gift' 61 | -------------------------------------------------------------------------------- /app/libs/redprint.py: -------------------------------------------------------------------------------- 1 | 2 | class Redprint: 3 | def __init__(self, name): 4 | self.name = name 5 | self.mound = [] 6 | 7 | def route(self, rule, **options): 8 | def decorator(f): 9 | self.mound.append((f, rule, options)) 10 | return f 11 | 12 | return decorator 13 | 14 | def register(self, bp, url_prefix=None): 15 | if url_prefix is None: 16 | url_prefix = '/' + self.name 17 | for f, rule, options in self.mound: 18 | endpoint = self.name + '+' + \ 19 | options.pop("endpoint", f.__name__) 20 | bp.add_url_rule(url_prefix + rule, endpoint, f, **options) 21 | -------------------------------------------------------------------------------- /app/libs/scope.py: -------------------------------------------------------------------------------- 1 | class Scope: 2 | allow_api = [] 3 | allow_module = [] 4 | forbidden = [] 5 | 6 | def __add__(self, other): 7 | self.allow_api = self.allow_api + other.allow_api 8 | self.allow_api = list(set(self.allow_api)) 9 | # 运算符重载 10 | 11 | self.allow_module = self.allow_module + \ 12 | other.allow_module 13 | self.allow_module = list(set(self.allow_module)) 14 | 15 | self.forbidden = self.forbidden + other.forbidden 16 | self.forbidden = list(set(self.forbidden)) 17 | 18 | return self 19 | 20 | 21 | class AdminScope(Scope): 22 | # allow_api = ['v1.user+super_get_user', 23 | # 'v1.user+super_delete_user'] 24 | allow_module = ['v1.user'] 25 | 26 | def __init__(self): 27 | # 排除 28 
| pass 29 | # self + UserScope() 30 | 31 | 32 | class UserScope(Scope): 33 | allow_module = ['v1.gift'] 34 | forbidden = ['v1.user+super_get_user', 35 | 'v1.user+super_delete_user'] 36 | 37 | def __init__(self): 38 | self + AdminScope() 39 | # allow_api = ['v1.user+get_user', 'v1.user+delete_user'] 40 | 41 | 42 | def is_in_scope(scope, endpoint): 43 | # scope() 44 | # 反射 45 | # globals 46 | # v1.view_func v1.module_name+view_func 47 | # v1.red_name+view_func 48 | scope = globals()[scope]() 49 | splits = endpoint.split('+') 50 | red_name = splits[0] 51 | if endpoint in scope.forbidden: 52 | return False 53 | if endpoint in scope.allow_api: 54 | return True 55 | if red_name in scope.allow_module: 56 | return True 57 | else: 58 | return False 59 | -------------------------------------------------------------------------------- /app/libs/token_auth.py: -------------------------------------------------------------------------------- 1 | from collections import namedtuple 2 | 3 | from flask import current_app, g, request 4 | from flask_httpauth import HTTPBasicAuth 5 | from itsdangerous import TimedJSONWebSignatureSerializer \ 6 | as Serializer, BadSignature, SignatureExpired 7 | 8 | from app.libs.error_code import AuthFailed, Forbidden 9 | from app.libs.scope import is_in_scope 10 | 11 | 12 | auth = HTTPBasicAuth() 13 | User = namedtuple('User', ['uid', 'ac_type', 'scope']) 14 | 15 | 16 | @auth.verify_password 17 | def verify_password(token, password): 18 | # token 19 | # HTTP 账号密码 20 | # header key:value 21 | # account qiyue 22 | # 123456 23 | # key=Authorization 24 | # value =basic base64(qiyue:123456) 25 | user_info = verify_auth_token(token) 26 | if not user_info: 27 | return False 28 | else: 29 | # request 30 | g.user = user_info 31 | return True 32 | 33 | 34 | def verify_auth_token(token): 35 | s = Serializer(current_app.config['SECRET_KEY']) 36 | try: 37 | data = s.loads(token) 38 | except BadSignature: 39 | raise AuthFailed(msg='token is invalid', 40 | 
error_code=1002) 41 | except SignatureExpired: 42 | raise AuthFailed(msg='token is expired', 43 | error_code=1003) 44 | uid = data['uid'] 45 | ac_type = data['type'] 46 | scope = data['scope'] 47 | # request 视图函数 48 | allow = is_in_scope(scope, request.endpoint) 49 | if not allow: 50 | raise Forbidden() 51 | return User(uid, ac_type, scope) 52 | -------------------------------------------------------------------------------- /app/libs/utils.py: -------------------------------------------------------------------------------- 1 | import socket 2 | 3 | 4 | def get_host_ip(): 5 | try: 6 | s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) 7 | s.connect(('8.8.8.8', 80)) 8 | ip = s.getsockname()[0] 9 | finally: 10 | s.close() 11 | 12 | return ip 13 | 14 | 15 | # if __name__ == "__main__": 16 | # pass 17 | -------------------------------------------------------------------------------- /app/libs/utils_geom.py: -------------------------------------------------------------------------------- 1 | import shapefile 2 | import geojson 3 | import shapely 4 | from shapely import geometry 5 | import pyproj 6 | import functools 7 | import json 8 | import fiona 9 | 10 | 11 | def project(shape, source, target): 12 | """Projects a geometry from one coordinate system into another. 13 | 14 | Args: 15 | shape: the geometry to project. 16 | source: the source EPSG spatial reference system identifier. 17 | target: the target EPSG spatial reference system identifier. 18 | 19 | Returns: 20 | The projected geometry in the target coordinate system. 
21 | """ 22 | 23 | project = functools.partial(pyproj.transform, pyproj.Proj( 24 | init=source), pyproj.Proj(init=target)) 25 | 26 | return shapely.ops.transform(project, shape) 27 | 28 | 29 | # feature to shape,projection,shape to feature.yul 30 | def geojson_project(collection, source, target): 31 | # with open(geojson_path) as fp: 32 | # collection = geojson.load(fp) 33 | 34 | shapes = [shapely.geometry.shape(feature["geometry"]) 35 | for feature in collection["features"]] 36 | features = [] 37 | for shape in shapes: 38 | if not shape.is_simple or not shape.is_valid: 39 | continue 40 | projected = project(shape, source, target) 41 | feature = geojson.Feature(geometry=shapely.geometry.mapping( 42 | projected)) 43 | features.append(feature) 44 | collection_projected = geojson.FeatureCollection(features) 45 | return collection_projected 46 | 47 | 48 | def shp2geojson(shp_path): 49 | # read the shapefile 50 | reader = shapefile.Reader(shp_path) 51 | fields = reader.fields[1:] 52 | field_names = [field[0] for field in fields] 53 | buffer = [] 54 | # shapeRecords = reader.shapeRecords() 55 | for sr in reader.shapeRecords(): 56 | atr = dict(zip(field_names, sr.record)) 57 | try: 58 | geom = sr.shape.__geo_interface__ 59 | buffer.append(dict(type="Feature", 60 | geometry=geom, properties=atr)) 61 | except: 62 | print('要素不可用。要素信息:%s' % str(sr.record)) 63 | jsonstr = json.dumps({"type": "FeatureCollection", 64 | "features": buffer}, indent=2) 65 | return json.loads(jsonstr) 66 | 67 | 68 | def geojson2shp(collection, shp_path): 69 | shapes = [geometry.shape(feature["geometry"]) 70 | for feature in collection["features"]] 71 | schema = { 72 | 'geometry': 'Polygon', 73 | 'properties': {}, 74 | } 75 | 76 | # Write a new Shapefile 77 | with fiona.open(shp_path, 'w', 'ESRI Shapefile', schema) as c: 78 | for shape in shapes: 79 | # delete area smaller than 0.00**1 80 | if shape.area < 0.00000001: 81 | continue 82 | c.write({ 83 | 'geometry': shapely.geometry.mapping(shape), 84 | 
'properties': {}, 85 | }) 86 | # Write a prj file 87 | prj_str = '''GEOGCS["GCS_WGS_1984",DATUM["D_WGS_1984",SPHEROID["WGS_1984",6378137,298.257223563]],PRIMEM["Greenwich",0],UNIT["Degree",0.017453292519943295]]''' 88 | prj_path = shp_path.replace('.shp', '.prj') 89 | with open(prj_path, 'w') as f: 90 | f.write(prj_str) 91 | f.close() 92 | 93 | # if __name__ == "__main__": 94 | # DATA_PATH = SETTING.ROBOSAT_DATASET_PATH + "./dataset/predict_1583054722.123778" 95 | # DATA_PATH = "." 96 | # shp_path = DATA_PATH + "/regularized_footprints.shp" 97 | # shp4326_path = DATA_PATH + "/building4326.shp" 98 | # regularized_json_path = DATA_PATH+"/regularized.json" 99 | # json4326 = DATA_PATH+"/regular_4326.json" 100 | # shp_to_geojson(shp_path,regularized_json_path) 101 | # projected_json = geojson_project( 102 | # regularized_json_path, "epsg:3857", "epsg:4326") 103 | # with open(json4326, 'r') as rg: 104 | # geojson4326 = json.load(rg) 105 | # geojson2shp(geojson4326, shp4326_path) 106 | -------------------------------------------------------------------------------- /app/models/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/geocompass/robosat_geoc/44d95183322ba68c0728de44a66e50f510bfd919/app/models/__init__.py -------------------------------------------------------------------------------- /app/models/base.py: -------------------------------------------------------------------------------- 1 | from app.libs.error_code import NotFound 2 | 3 | from datetime import datetime 4 | 5 | from flask_sqlalchemy import SQLAlchemy as _SQLAlchemy, BaseQuery 6 | from sqlalchemy import inspect, Column, Integer, SmallInteger, orm, DateTime 7 | from contextlib import contextmanager 8 | 9 | 10 | class SQLAlchemy(_SQLAlchemy): 11 | @contextmanager 12 | def auto_commit(self): 13 | try: 14 | yield 15 | self.session.commit() 16 | except Exception as e: 17 | db.session.rollback() 18 | raise e 19 | 20 | 21 | class 
Query(BaseQuery): 22 | def filter_by(self, **kwargs): 23 | return super(Query, self).filter_by(**kwargs) 24 | 25 | def get_or_404(self, ident): 26 | rv = self.get(ident) 27 | # if not rv: 28 | # raise NotFound() 29 | return rv 30 | 31 | def first_or_404(self): 32 | rv = self.first() 33 | # if not rv: 34 | # raise NotFound() 35 | return rv 36 | 37 | 38 | db = SQLAlchemy(query_class=Query) 39 | 40 | # 通过SQL语句查询数据库 41 | 42 | 43 | def queryBySQL(sql): 44 | return db.session.execute(sql) 45 | 46 | 47 | class Base(db.Model): 48 | __abstract__ = True 49 | # created_at = Column(DateTime) 50 | status = Column(SmallInteger, default=1) 51 | 52 | def __init__(self): 53 | self.created_at = datetime.now() 54 | pass 55 | 56 | def __getitem__(self, item): 57 | return getattr(self, item) 58 | 59 | # @property 60 | # def created_datetime(self): 61 | # if self.created_at: 62 | # return datetime.fromtimestamp(self.created_at) 63 | # else: 64 | # return None 65 | 66 | def set_attrs(self, attrs_dict): 67 | for key, value in attrs_dict.items(): 68 | if hasattr(self, key) and key != 'id': 69 | setattr(self, key, value) 70 | 71 | def delete(self): 72 | self.status = 0 73 | 74 | def keys(self): 75 | return self.fields 76 | 77 | def hide(self, *keys): 78 | for key in keys: 79 | self.fields.remove(key) 80 | return self 81 | 82 | def append(self, *keys): 83 | for key in keys: 84 | self.fields.append(key) 85 | return self 86 | 87 | 88 | class MixinJSONSerializer: 89 | @orm.reconstructor 90 | def init_on_load(self): 91 | self._fields = [] 92 | # self._include = [] 93 | self._exclude = [] 94 | 95 | self._set_fields() 96 | self.__prune_fields() 97 | 98 | def _set_fields(self): 99 | pass 100 | 101 | def __prune_fields(self): 102 | columns = inspect(self.__class__).columns 103 | if not self._fields: 104 | all_columns = set(columns.keys()) 105 | self._fields = list(all_columns - set(self._exclude)) 106 | 107 | def hide(self, *args): 108 | for key in args: 109 | self._fields.remove(key) 110 | return 
self 111 | 112 | def keys(self): 113 | return self._fields 114 | 115 | def __getitem__(self, key): 116 | return getattr(self, key) 117 | -------------------------------------------------------------------------------- /app/models/buia.py: -------------------------------------------------------------------------------- 1 | from sqlalchemy import Column, String, Boolean, Integer, ForeignKey, Text 2 | from sqlalchemy.orm import relationship 3 | 4 | from app.models.base import Base 5 | 6 | 7 | class BUIA(Base): 8 | gid = Column(Integer, primary_key=True) 9 | CNAME = Column(String(48)) 10 | LEVEL = Column(String(16)) 11 | # geom = Column(Text) 12 | -------------------------------------------------------------------------------- /app/models/predict_buildings.py: -------------------------------------------------------------------------------- 1 | from sqlalchemy import Column, String, Boolean, Integer, ForeignKey, Text 2 | from sqlalchemy.orm import relationship 3 | 4 | from app.models.base import Base 5 | 6 | 7 | class PredictBuildings(Base): 8 | gid = Column(Integer, primary_key=True, autoincrement=True) 9 | geom = Column(Text) 10 | task_id = Column(Integer, primary_key=True) 11 | extent = Column(String(256)) 12 | user_id = Column(String(50)) 13 | area_code = Column(String(50)) 14 | state = Column(Integer, default=1) 15 | status = Column(Integer, default=1) 16 | -------------------------------------------------------------------------------- /app/models/task.py: -------------------------------------------------------------------------------- 1 | from sqlalchemy import Column, String, Boolean, Integer, ForeignKey, Text, DateTime 2 | from sqlalchemy.orm import relationship 3 | 4 | from app.models.base import Base 5 | 6 | 7 | class task(Base): 8 | task_id = Column(Integer, primary_key=True) 9 | extent = Column(String(256)) 10 | user_id = Column(String(50)) 11 | area_code = Column(String(50)) 12 | state = Column(Integer, default=1) 13 | status = Column(Integer, default=1) 
14 | end_at = Column(DateTime) 15 | handler = Column(Integer, default=1) 16 | originalextent = Column(String(256)) 17 | -------------------------------------------------------------------------------- /app/models/task_admin.py: -------------------------------------------------------------------------------- 1 | from sqlalchemy import Column, String, Boolean, Integer, ForeignKey, Text, DateTime 2 | from sqlalchemy.orm import relationship 3 | 4 | from app.models.base import Base 5 | 6 | 7 | class task_admin(Base): 8 | task_id = Column(Integer, primary_key=True) 9 | extent = Column(String(256)) 10 | user_id = Column(String(50)) 11 | area_code = Column(String(50)) 12 | state = Column(Integer, default=1) 13 | status = Column(Integer, default=1) 14 | end_at = Column(DateTime) 15 | handler = Column(String(50), default=1) 16 | originalextent = Column(String(256)) 17 | updated_at = Column(DateTime) 18 | -------------------------------------------------------------------------------- /arcpy_geoc/.vscode/launch.json: -------------------------------------------------------------------------------- 1 | { 2 | // 使用 IntelliSense 了解相关属性。 3 | // 悬停以查看现有属性的描述。 4 | // 欲了解更多信息,请访问: https://go.microsoft.com/fwlink/?linkid=830387 5 | "version": "0.2.0", 6 | "configurations": [ 7 | { 8 | "name": "Python: 当前文件", 9 | "type": "python", 10 | "request": "launch", 11 | "program": "${file}", 12 | "console": "integratedTerminal" 13 | } 14 | ] 15 | } 16 | -------------------------------------------------------------------------------- /arcpy_geoc/.vscode/settings.json: -------------------------------------------------------------------------------- 1 | { 2 | "python.pythonPath": "C:\\Python27\\ArcGIS10.7\\python.exe" 3 | } -------------------------------------------------------------------------------- /arcpy_geoc/regular_build.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | from flask import Flask, 
request 4 | import os 5 | import subprocess 6 | import setting as SETTING 7 | 8 | 9 | app = Flask(__name__) 10 | 11 | current_path = os.path.abspath(os.getcwd()) 12 | 13 | command_path = current_path + "/" + SETTING.COMMAND_FILE 14 | config_path = SETTING.CONFIG_PATH_TXT 15 | 16 | 17 | @app.route('/') 18 | def hello_world(): 19 | return 'Hello flask!' 20 | 21 | 22 | @app.route('/regularize', methods=['GET']) 23 | def wmts(): 24 | print("start flask") 25 | result = { 26 | "code": 1, 27 | "data": None, 28 | "msg": "ok" 29 | } 30 | path = request.args.get("path") 31 | with open(config_path, 'w') as f: 32 | f.write(path) 33 | f.close() 34 | # FNULL = open(os.devnull, 'w') 35 | try: 36 | proc = subprocess.check_output( 37 | [SETTING.CONFIG_ARCPY, command_path], stderr=subprocess.STDOUT) 38 | except subprocess.CalledProcessError: 39 | print(subprocess.STDOUT) 40 | print(proc) 41 | if not proc or "Failed" in proc: 42 | result['code'] = 0 43 | result['msg'] = 'execute arcpy command failed.' 44 | if "regularized.shp already exists" in proc: 45 | result['code'] = 1 46 | result['msg'] = 'regularized.shp already exists.' 
47 | return result 48 | 49 | 50 | if __name__ == '__main__': 51 | app.run(host='0.0.0.0', port=5001, debug=True) 52 | -------------------------------------------------------------------------------- /arcpy_geoc/regular_command.py: -------------------------------------------------------------------------------- 1 | import os 2 | import arcpy 3 | import time 4 | import setting as SETTING 5 | 6 | print("start load arcpy") 7 | startTime = time.clock() 8 | 9 | 10 | arcpy.env.workspace = SETTING.DATA_PATH 11 | 12 | envTime = time.clock() 13 | print("envTime:" + str(envTime-startTime)) 14 | 15 | 16 | config_path = SETTING.DATA_PATH + "config.txt" 17 | 18 | with open(config_path, 'r') as f: 19 | lines = f.readlines() 20 | path = lines[0].strip() 21 | 22 | 23 | def reguar(): 24 | # print("start regularize") 25 | try: 26 | DIR_PATH = SETTING.DATA_PATH + path 27 | building1_path = os.path.join(DIR_PATH, 'building1_predict.shp') 28 | building2_path = os.path.join(DIR_PATH, 'building2_3857.shp') 29 | building3_path = os.path.join(DIR_PATH, 'building3_merged.shp') 30 | building4_path = os.path.join(DIR_PATH, 'building4_regularized.shp') 31 | building5_path = os.path.join(DIR_PATH, 'building5_4326.shp') 32 | 33 | # project 34 | WKT4326 = 'GEOGCS["GCS_WGS_1984",DATUM["D_WGS_1984",SPHEROID["WGS_1984",6378137,298.257223563]],PRIMEM["Greenwich",0],UNIT["Degree",0.017453292519943295]]' 35 | WKT3857 = 'PROJCS["WGS_1984_Web_Mercator_Auxiliary_Sphere",GEOGCS["GCS_WGS_1984",DATUM["D_WGS_1984",SPHEROID["WGS_1984",6378137.0,298.257223563]],PRIMEM["Greenwich",0.0],UNIT["Degree",0.0174532925199433]],PROJECTION["Mercator_Auxiliary_Sphere"],PARAMETER["False_Easting",0.0],PARAMETER["False_Northing",0.0],PARAMETER["Central_Meridian",0.0],PARAMETER["Standard_Parallel_1",0.0],PARAMETER["Auxiliary_Sphere_Type",0.0],UNIT["Meter",1.0]]' 36 | CS4326 = arcpy.SpatialReference() 37 | CS3857 = arcpy.SpatialReference() 38 | CS4326.loadFromString(WKT4326) 39 | CS3857.loadFromString(WKT3857) 40 | 41 | # 
project 42 | print("1.start project") 43 | arcpy.Project_management( 44 | building1_path, building2_path, CS3857, "", CS4326) 45 | 46 | print("2.start merge") 47 | # merge 48 | arcpy.Dissolve_management(building2_path, building3_path, 49 | "", "", "", 50 | "DISSOLVE_LINES") 51 | print("3.start regularize") 52 | # regularize 53 | arcpy.ddd.RegularizeBuildingFootprint(building3_path, 54 | building4_path, 55 | method='RIGHT_ANGLES', 56 | tolerance=10, 57 | precision=0.25, 58 | min_radius=0.1, 59 | max_radius=1000000) 60 | # unpreject 61 | print("4.start unproject") 62 | arcpy.Project_management( 63 | building4_path, building5_path, CS4326, "", CS3857) 64 | except arcpy.ExecuteError: 65 | print(arcpy.GetMessages()) 66 | 67 | endTime = time.clock() 68 | print("end regular") 69 | print("spendssss:" + str(endTime-startTime)) 70 | print("okka") 71 | return "okkkb" 72 | 73 | 74 | if __name__ == "__main__": 75 | reguar() 76 | -------------------------------------------------------------------------------- /arcpy_geoc/setting.py: -------------------------------------------------------------------------------- 1 | # dataset path 2 | DATA_PATH = "C:\\Users\\WUCAN\\Documents\\dataset\\" 3 | 4 | # arcpy path path 5 | CONFIG_ARCPY = 'C:\\Python27\\ArcGIS10.6\\python.exe' 6 | 7 | CONFIG_PATH_TXT = DATA_PATH + "config.txt" 8 | 9 | COMMAND_FILE = "regular_command.py" 10 | -------------------------------------------------------------------------------- /batch_cover.py: -------------------------------------------------------------------------------- 1 | from datetime import datetime 2 | import time 3 | import json 4 | from robosat_pink.geoc import RSPcover, utils 5 | from app.libs import utils_geom 6 | 7 | def cover(dsPath,geojson,out): 8 | return RSPcover.main(dsPath,geojson,out) 9 | 10 | if __name__ == "__main__": 11 | # # if cover by dir 12 | # dsPath = "/data/dataset/train/train_2/tdttianjin/training/labels" 13 | # geojson = None 14 | # out = None 15 | 16 | # if cover by geojson 17 
| dsPath = None 18 | dir = '/data/dataset/train/train_3_0527' 19 | # jsonFile = open(dir + "/centroid_buffer_union.json", 'r') 20 | geojson = dir + "/centroid_buffer_union.json" 21 | out = [dir+'/cover'] 22 | 23 | # training dataset directory 24 | startTime = datetime.now() 25 | ts = time.time() 26 | 27 | result = cover(dsPath,geojson,out) 28 | 29 | endTime = datetime.now() 30 | timeSpend = (endTime-startTime).seconds 31 | print("Cover DONE!All spends:", timeSpend, "seconds!") 32 | -------------------------------------------------------------------------------- /config.toml: -------------------------------------------------------------------------------- 1 | # RoboSat.pink Configuration 2 | 3 | 4 | # Input channels configuration 5 | # You can, add several channels blocks to compose your input Tensor. Order is meaningful. 6 | # 7 | # name: dataset subdirectory name 8 | # bands: bands to keep from sub source. Order is meaningful 9 | # mean: bands mean value [default, if model pretrained: [0.485, 0.456, 0.406] ] 10 | # std: bands std value [default, if model pretrained: [0.229, 0.224, 0.225] ] 11 | 12 | [[channels]] 13 | name = "images" 14 | bands = [1, 2, 3] 15 | 16 | 17 | # Output Classes configuration 18 | # Nota: available colors are either CSS3 colors names or #RRGGBB hexadecimal representation. 19 | # Nota: only support binary classification for now. 
20 | 21 | [[classes]] 22 | title = "Building" 23 | color = "deeppink" 24 | 25 | 26 | 27 | [model] 28 | # Neurals Network name 29 | nn = "Albunet" 30 | 31 | # Pretrained on ImageNet 32 | #pretrained = true 33 | 34 | # Loss function name 35 | loss = "Lovasz" 36 | 37 | # Batch size for training 38 | #bs = 4 39 | 40 | # Learning rate for the optimizer 41 | #lr = 0.000025 42 | 43 | # Model internal input tile size [W, H] 44 | #ts = [512, 512] 45 | 46 | # Dataset loader name 47 | loader = "SemSegTiles" 48 | 49 | # Kind of data augmentation to apply while training 50 | da = "Strong" 51 | 52 | # Data Augmentation probability 53 | #dap = 1.0 54 | 55 | # Metrics 56 | metrics = ["iou", "mcc"] 57 | -------------------------------------------------------------------------------- /data/config.toml: -------------------------------------------------------------------------------- 1 | # Dataset specific common attributes. 2 | [common] 3 | 4 | # Human representation for classes. 5 | classes = ['background', 'parking'] 6 | 7 | 8 | # RoboSat.pink Configuration 9 | 10 | 11 | # Input channels configuration 12 | # You can, add several channels blocks to compose your input Tensor. Order is meaningful. 13 | # 14 | # name: dataset subdirectory name 15 | # bands: bands to keep from sub source. Order is meaningful 16 | # mean: bands mean value [default, if model pretrained: [0.485, 0.456, 0.406] ] 17 | # std: bands std value [default, if model pretrained: [0.229, 0.224, 0.225] ] 18 | 19 | [[channels]] 20 | name = "images" 21 | bands = [1, 2, 3] 22 | 23 | 24 | # Output Classes configuration 25 | # Nota: available colors are either CSS3 colors names or #RRGGBB hexadecimal representation. 26 | # Nota: only support binary classification for now. 
27 | 28 | [[classes]] 29 | title = "Building" 30 | color = "deeppink" 31 | 32 | 33 | 34 | [model] 35 | # Neurals Network name 36 | nn = "Albunet" 37 | 38 | # Pretrained on ImageNet 39 | pretrained = true 40 | 41 | # Loss function name 42 | loss = "Lovasz" 43 | 44 | # Batch size for training 45 | bs = 4 46 | 47 | # Learning rate for the optimizer 48 | lr = 0.000025 49 | 50 | # Model internal input tile size [W, H] 51 | # ts = [256, 256] 52 | 53 | # Dataset loader name 54 | loader = "SemSegTiles" 55 | 56 | # Kind of data augmentation to apply while training 57 | da = "Strong" 58 | 59 | # Data Augmentation probability 60 | dap = 1.0 61 | 62 | # Metrics 63 | metrics = ["iou", "mcc"] 64 | -------------------------------------------------------------------------------- /data/dataset-parking.toml: -------------------------------------------------------------------------------- 1 | # Configuration related to a specific dataset. 2 | # For syntax see: https://github.com/toml-lang/toml#table-of-contents 3 | 4 | 5 | # Dataset specific common attributes. 6 | [common] 7 | 8 | # The slippy map dataset's base directory. 9 | dataset = '/tmp/slippy-map-dir/' 10 | 11 | # Human representation for classes. 12 | classes = ['background', 'parking'] 13 | 14 | # Color map for visualization and representing classes in masks. 15 | # Note: available colors can be found in `robosat/colors.py` 16 | colors = ['denim', 'orange'] 17 | 18 | 19 | # Dataset specific class weights computes on the training data. 20 | # Needed by 'mIoU' and 'CrossEntropy' losses to deal with unbalanced classes. 21 | # Note: use `./rs weights -h` to compute these for new datasets. 
22 | [weights] 23 | values = [1.6248, 5.762827] 24 | -------------------------------------------------------------------------------- /docs/BUIA.sql: -------------------------------------------------------------------------------- 1 | -- add pgis extension 2 | CREATE extension postgis; 3 | 4 | -- create table 5 | CREATE TABLE "public"."BUIA" ( 6 | "gid" serial4, 7 | "CNAME" varchar(255) COLLATE "pg_catalog"."default", 8 | "LEVEL" varchar(20) COLLATE "pg_catalog"."default", 9 | "geom" "public"."geometry", 10 | CONSTRAINT "BUIA_pkey" PRIMARY KEY ("gid") 11 | ) 12 | ; 13 | 14 | ALTER TABLE "public"."BUIA" 15 | OWNER TO "postgres"; 16 | 17 | -- create index 18 | CREATE INDEX "BUIA_geom_idx" ON "public"."BUIA" USING gist ( 19 | "geom" "public"."gist_geometry_ops_2d" 20 | ); 21 | 22 | CREATE INDEX "BUIA_gid_idx" ON "public"."BUIA" USING btree ( 23 | "gid" "pg_catalog"."int4_ops" ASC NULLS LAST 24 | ) 25 | -------------------------------------------------------------------------------- /docs/Linux安装指南.md: -------------------------------------------------------------------------------- 1 | - linux环境准备 2 | - $ yum update 3 | - $ yum install git 4 | 5 | - 下代码 6 | - git clone https://github.com/geocompass/robosat_geoc.git 7 | 8 | - 安装anaconda 9 | - 下载 Anaconda.sh 64-Bit (x86) Installer (506 MB) 10 | - $ bash Anaconda3-4.4.0-Linux-x86_64.sh 11 | - 默认安装路径:/root/anaconda3 12 | - initialize conda? 
$ yes 13 | - 增加环境变量 plan A 14 | - $ source /bin/activate(官网推荐) 15 | - 增加环境变量 plan B 16 | - $ vim /root/.bashrc 17 | - added by Anaconda3 4.4.0 installer 18 | - export PATH="/root/anaconda3/bin: - $PATH" 19 | - 保存退出 20 | - $ source /root/.bashrc 21 | - 检查是否安装成功 - $ conda customized 22 | - 更新 - $ conda update -n base -c defaults conda 23 | 24 | - $ conda init或者退出xshell重连 25 | 26 | - 创建虚拟环境 27 | - $ conda create -n robosat 28 | - $ conda activate robosat 29 | 30 | - 安装pip 31 | - $ yum -y install epel-release 32 | - $ yum -y install python-pip 33 | 34 | - 安装rtree 35 | 36 | - 安装rtree依赖 37 | 38 | - 安装libspatialindex 39 | - $ conda install -c conda-forge libspatialindex=1.9.3 40 | - 若libspatial安装成功跳过此步,失败安装cmake 41 | - 下载cmake并移动到linux根目录:https://github.com/Kitware/CMake/releases/download/v3.13.2/cmake-3.13.2.tar.gz 42 | - $ tar -zxvf cmake-3.13.2.tar.gz 43 | - $ cd cmake-3.13.2 44 | - $ ./bootstrap && make && make install 45 | - $ cmake version 3.10.2(失败) 46 | 47 | - 安装rtree 48 | - $ conda install rtree 49 | 50 | - 安装torch 51 | - $ pip install torch 52 | - $ pip install torchvision 53 | 54 | - 安装robosat_geoc依赖 55 | - $ pip install --upgrade pip 56 | - $ pip install -r requirements.txt (因为torch包700m下载过慢,放在安装依赖步骤最后,避免重复耗时) -------------------------------------------------------------------------------- /docs/config.md: -------------------------------------------------------------------------------- 1 | ## config.toml 2 | ``` 3 | # RoboSat.pink Configuration 4 | 5 | 6 | # Input channels configuration 7 | # You can, add several channels blocks to compose your input Tensor. Order is meaningful. 8 | # 9 | # name: dataset subdirectory name 10 | # bands: bands to keep from sub source. 
Order is meaningful 11 | # mean: bands mean value [default, if model pretrained: [0.485, 0.456, 0.406] ] 12 | # std: bands std value [default, if model pretrained: [0.229, 0.224, 0.225] ] 13 | 14 | [[channels]] 15 | name = "images" 16 | bands = [1, 2, 3] 17 | 18 | 19 | # Output Classes configuration 20 | # Nota: available colors are either CSS3 colors names or #RRGGBB hexadecimal representation. 21 | # Nota: only support binary classification for now. 22 | 23 | [[classes]] 24 | title = "Building" 25 | color = "deeppink" 26 | 27 | 28 | 29 | [model] 30 | # Neurals Network name 31 | nn = "Albunet" 32 | 33 | # Pretrained on ImageNet 34 | #pretrained = true 35 | 36 | # Loss function name 37 | loss = "Lovasz" 38 | 39 | # Batch size for training 40 | #bs = 4 41 | 42 | # Learning rate for the optimizer 43 | #lr = 0.000025 44 | 45 | # Model internal input tile size [W, H] 46 | #ts = [512, 512] 47 | 48 | # Dataset loader name 49 | loader = "SemSegTiles" 50 | 51 | # Kind of data augmentation to apply while training 52 | da = "Strong" 53 | 54 | # Data Augmentation probability 55 | #dap = 1.0 56 | 57 | # Metrics 58 | metrics = ["iou", "mcc"] 59 | ``` 60 | -------------------------------------------------------------------------------- /docs/img/from_opendata_to_opendataset/compare.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/geocompass/robosat_geoc/44d95183322ba68c0728de44a66e50f510bfd919/docs/img/from_opendata_to_opendataset/compare.png -------------------------------------------------------------------------------- /docs/img/from_opendata_to_opendataset/compare_clean.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/geocompass/robosat_geoc/44d95183322ba68c0728de44a66e50f510bfd919/docs/img/from_opendata_to_opendataset/compare_clean.png -------------------------------------------------------------------------------- 
/docs/img/from_opendata_to_opendataset/compare_side.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/geocompass/robosat_geoc/44d95183322ba68c0728de44a66e50f510bfd919/docs/img/from_opendata_to_opendataset/compare_side.png -------------------------------------------------------------------------------- /docs/img/from_opendata_to_opendataset/compare_side_clean.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/geocompass/robosat_geoc/44d95183322ba68c0728de44a66e50f510bfd919/docs/img/from_opendata_to_opendataset/compare_side_clean.png -------------------------------------------------------------------------------- /docs/img/from_opendata_to_opendataset/compare_zoom_out.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/geocompass/robosat_geoc/44d95183322ba68c0728de44a66e50f510bfd919/docs/img/from_opendata_to_opendataset/compare_zoom_out.png -------------------------------------------------------------------------------- /docs/img/from_opendata_to_opendataset/images.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/geocompass/robosat_geoc/44d95183322ba68c0728de44a66e50f510bfd919/docs/img/from_opendata_to_opendataset/images.png -------------------------------------------------------------------------------- /docs/img/from_opendata_to_opendataset/labels.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/geocompass/robosat_geoc/44d95183322ba68c0728de44a66e50f510bfd919/docs/img/from_opendata_to_opendataset/labels.png -------------------------------------------------------------------------------- /docs/img/from_opendata_to_opendataset/masks.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/geocompass/robosat_geoc/44d95183322ba68c0728de44a66e50f510bfd919/docs/img/from_opendata_to_opendataset/masks.png -------------------------------------------------------------------------------- /docs/img/quality_analysis/compare.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/geocompass/robosat_geoc/44d95183322ba68c0728de44a66e50f510bfd919/docs/img/quality_analysis/compare.png -------------------------------------------------------------------------------- /docs/img/quality_analysis/compare_side.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/geocompass/robosat_geoc/44d95183322ba68c0728de44a66e50f510bfd919/docs/img/quality_analysis/compare_side.png -------------------------------------------------------------------------------- /docs/img/quality_analysis/images.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/geocompass/robosat_geoc/44d95183322ba68c0728de44a66e50f510bfd919/docs/img/quality_analysis/images.png -------------------------------------------------------------------------------- /docs/img/quality_analysis/labels.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/geocompass/robosat_geoc/44d95183322ba68c0728de44a66e50f510bfd919/docs/img/quality_analysis/labels.png -------------------------------------------------------------------------------- /docs/img/quality_analysis/masks.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/geocompass/robosat_geoc/44d95183322ba68c0728de44a66e50f510bfd919/docs/img/quality_analysis/masks.png -------------------------------------------------------------------------------- /docs/img/quality_analysis/osm.png: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/geocompass/robosat_geoc/44d95183322ba68c0728de44a66e50f510bfd919/docs/img/quality_analysis/osm.png -------------------------------------------------------------------------------- /docs/img/quality_analysis/predict_compare.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/geocompass/robosat_geoc/44d95183322ba68c0728de44a66e50f510bfd919/docs/img/quality_analysis/predict_compare.png -------------------------------------------------------------------------------- /docs/img/quality_analysis/predict_compare_side.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/geocompass/robosat_geoc/44d95183322ba68c0728de44a66e50f510bfd919/docs/img/quality_analysis/predict_compare_side.png -------------------------------------------------------------------------------- /docs/img/quality_analysis/predict_images.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/geocompass/robosat_geoc/44d95183322ba68c0728de44a66e50f510bfd919/docs/img/quality_analysis/predict_images.png -------------------------------------------------------------------------------- /docs/img/quality_analysis/predict_mask.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/geocompass/robosat_geoc/44d95183322ba68c0728de44a66e50f510bfd919/docs/img/quality_analysis/predict_mask.png -------------------------------------------------------------------------------- /docs/img/readme/data_preparation.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/geocompass/robosat_geoc/44d95183322ba68c0728de44a66e50f510bfd919/docs/img/readme/data_preparation.png 
-------------------------------------------------------------------------------- /docs/img/readme/draw_me_robosat_pink.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/geocompass/robosat_geoc/44d95183322ba68c0728de44a66e50f510bfd919/docs/img/readme/draw_me_robosat_pink.png -------------------------------------------------------------------------------- /docs/img/readme/minimal.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/geocompass/robosat_geoc/44d95183322ba68c0728de44a66e50f510bfd919/docs/img/readme/minimal.png -------------------------------------------------------------------------------- /docs/img/readme/stacks.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/geocompass/robosat_geoc/44d95183322ba68c0728de44a66e50f510bfd919/docs/img/readme/stacks.png -------------------------------------------------------------------------------- /docs/img/readme/top_example.jpeg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/geocompass/robosat_geoc/44d95183322ba68c0728de44a66e50f510bfd919/docs/img/readme/top_example.jpeg -------------------------------------------------------------------------------- /docs/img/readme/模型优化.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/geocompass/robosat_geoc/44d95183322ba68c0728de44a66e50f510bfd919/docs/img/readme/模型优化.png -------------------------------------------------------------------------------- /docs/makefile.md: -------------------------------------------------------------------------------- 1 | ## Makefile 2 | ``` 3 | This Makefile rules are designed for RoboSat.pink devs and power-users. 4 | For plain user installation follow README.md instructions, instead. 
5 | 6 | 7 | make install To install, few Python dev tools and RoboSat.pink in editable mode. 8 | So any further RoboSat.pink Python code modification will be usable at once, 9 | throught either rsp tools commands or robosat_pink.* modules. 10 | 11 | make check Launchs code tests, and tools doc updating. 12 | Do it, at least, before sending a Pull Request. 13 | 14 | make check_tuto Launchs rsp commands embeded in tutorials, to be sure everything still up to date. 15 | Do it, at least, on each CLI modifications, and before a release. 16 | NOTA: It takes a while. 17 | 18 | make pink Python code beautifier, 19 | as Pink is the new Black ^^ 20 | ``` 21 | -------------------------------------------------------------------------------- /docs/predict_buildings.sql: -------------------------------------------------------------------------------- 1 | CREATE TABLE "public"."predict_buildings" ( 2 | "gid" serial4, 3 | "task_id" serial4, 4 | "extent" varchar(255) COLLATE "pg_catalog"."default", 5 | "user_id" int4, 6 | "area_code" varchar(20) COLLATE "pg_catalog"."default", 7 | "handler" varchar(50) COLLATE "pg_catalog"."default", 8 | "state" int2 DEFAULT 1, 9 | "status" int2 DEFAULT 1, 10 | "created_at" timestamp(6) DEFAULT CURRENT_TIMESTAMP, 11 | "updated_at" timestamp(6) DEFAULT CURRENT_TIMESTAMP, 12 | "geom" "public"."geometry", 13 | CONSTRAINT "predict_buildings_pkey" PRIMARY KEY ("gid") 14 | ) 15 | ; 16 | 17 | ALTER TABLE "public"."predict_buildings" 18 | OWNER TO "postgres"; 19 | 20 | 21 | -- add COMMENT 22 | COMMENT ON COLUMN "public"."predict_buildings"."task_id" IS '自增序列'; 23 | COMMENT ON COLUMN "public"."predict_buildings"."extent" IS '预测范围'; 24 | COMMENT ON COLUMN "public"."predict_buildings"."user_id" IS '用户编号'; 25 | COMMENT ON COLUMN "public"."predict_buildings"."state" IS '当前状态'; 26 | COMMENT ON COLUMN "public"."predict_buildings"."status" IS '是否删除'; 27 | COMMENT ON COLUMN "public"."predict_buildings"."created_at" IS '创建时间'; 28 | COMMENT ON COLUMN 
"public"."predict_buildings"."updated_at" IS '更新时间'; 29 | 30 | 31 | CREATE INDEX "predict_buildings_geom_idx" ON "public"."predict_buildings" USING gist ( 32 | "geom" "public"."gist_geometry_ops_2d" 33 | ); 34 | 35 | CREATE INDEX "predict_buildings_gid_idx" ON "public"."predict_buildings" USING btree ( 36 | "gid" "pg_catalog"."int4_ops" ASC NULLS LAST 37 | ); 38 | 39 | -- create update function 40 | CREATE OR REPLACE FUNCTION predict_buildings_update_timestamp () RETURNS TRIGGER AS $$ BEGIN 41 | NEW.updated_at = CURRENT_TIMESTAMP; 42 | RETURN NEW; 43 | END $$ LANGUAGE plpgsql; 44 | 45 | -- create trigger on task 46 | CREATE TRIGGER "predict_buildings_upd" BEFORE UPDATE ON "public"."predict_buildings" FOR EACH ROW 47 | EXECUTE PROCEDURE "public"."predict_buildings_update_timestamp"(); -------------------------------------------------------------------------------- /docs/task.sql: -------------------------------------------------------------------------------- 1 | -- create table 2 | CREATE TABLE "public"."task" ( 3 | "task_id" serial, 4 | "extent" varchar(255) COLLATE "pg_catalog"."default", 5 | "originalextent" varchar(255) COLLATE "pg_catalog"."default", 6 | "user_id" varchar(50), 7 | "area_code" varchar(50), 8 | "handler" varchar(50) COLLATE "pg_catalog"."default", 9 | "state" int2 DEFAULT 1, 10 | "status" int2 DEFAULT 1, 11 | "created_at" timestamp(6) DEFAULT CURRENT_TIMESTAMP, 12 | "updated_at" timestamp(6) DEFAULT CURRENT_TIMESTAMP, 13 | "end_at" timestamp(6) DEFAULT CURRENT_TIMESTAMP, 14 | CONSTRAINT "task_pkey" PRIMARY KEY ("task_id") 15 | ) 16 | ; 17 | ALTER TABLE "public"."task" 18 | OWNER TO "postgres"; 19 | 20 | -- add COMMENT 21 | COMMENT ON COLUMN "public"."task"."task_id" IS '自增序列'; 22 | COMMENT ON COLUMN "public"."task"."extent" IS '预测范围'; 23 | COMMENT ON COLUMN "public"."task"."originalextent" IS '初始范围'; 24 | COMMENT ON COLUMN "public"."task"."user_id" IS '用户编号'; 25 | COMMENT ON COLUMN "public"."task"."area_code" IS '区划代码'; 26 | COMMENT ON COLUMN 
"public"."task"."state" IS '当前状态'; 27 | COMMENT ON COLUMN "public"."task"."status" IS '是否删除'; 28 | COMMENT ON COLUMN "public"."task"."created_at" IS '创建时间'; 29 | COMMENT ON COLUMN "public"."task"."updated_at" IS '更新时间'; 30 | COMMENT ON COLUMN "public"."task"."end_at" IS '完成时间'; 31 | COMMENT ON COLUMN "public"."task"."handler" IS '预测主机'; 32 | 33 | -- create update function 34 | CREATE OR REPLACE FUNCTION task_update_timestamp () RETURNS TRIGGER AS $$ BEGIN 35 | NEW.updated_at = CURRENT_TIMESTAMP; 36 | RETURN NEW; 37 | END $$ LANGUAGE plpgsql; 38 | 39 | -- create trigger on task 40 | CREATE TRIGGER "task_upd" BEFORE UPDATE ON "public"."task" FOR EACH ROW 41 | EXECUTE PROCEDURE "public"."task_update_timestamp"(); -------------------------------------------------------------------------------- /docs/task_admin.sql: -------------------------------------------------------------------------------- 1 | -- create table 2 | CREATE TABLE "public"."task_admin" ( 3 | "task_id" serial, 4 | "extent" varchar(255) COLLATE "pg_catalog"."default", 5 | "originalextent" varchar(255) COLLATE "pg_catalog"."default", 6 | "user_id" varchar(50), 7 | "area_code" varchar(50), 8 | "handler" varchar(50) COLLATE "pg_catalog"."default", 9 | "state" int2 DEFAULT 1, 10 | "status" int2 DEFAULT 1, 11 | "created_at" timestamp(6) DEFAULT CURRENT_TIMESTAMP, 12 | "updated_at" timestamp(6) DEFAULT CURRENT_TIMESTAMP, 13 | "end_at" timestamp(6) DEFAULT CURRENT_TIMESTAMP, 14 | CONSTRAINT "task_admin_pkey" PRIMARY KEY ("task_id") 15 | ) 16 | ; 17 | ALTER TABLE "public"."task_admin" 18 | OWNER TO "postgres"; 19 | 20 | -- add COMMENT 21 | COMMENT ON COLUMN "public"."task_admin"."task_id" IS '自增序列'; 22 | COMMENT ON COLUMN "public"."task_admin"."extent" IS '预测范围'; 23 | COMMENT ON COLUMN "public"."task_admin"."originalextent" IS '初始范围'; 24 | COMMENT ON COLUMN "public"."task_admin"."user_id" IS '用户编号'; 25 | COMMENT ON COLUMN "public"."task_admin"."area_code" IS '区划代码'; 26 | COMMENT ON COLUMN 
"public"."task_admin"."state" IS '当前状态'; 27 | COMMENT ON COLUMN "public"."task_admin"."status" IS '是否删除'; 28 | COMMENT ON COLUMN "public"."task_admin"."created_at" IS '创建时间'; 29 | COMMENT ON COLUMN "public"."task_admin"."updated_at" IS '更新时间'; 30 | COMMENT ON COLUMN "public"."task_admin"."end_at" IS '完成时间'; 31 | COMMENT ON COLUMN "public"."task_admin"."handler" IS '预测主机'; 32 | 33 | -- create update function 34 | CREATE OR REPLACE FUNCTION task_update_timestamp () RETURNS TRIGGER AS $$ BEGIN 35 | NEW.updated_at = CURRENT_TIMESTAMP; 36 | RETURN NEW; 37 | END $$ LANGUAGE plpgsql; 38 | 39 | -- create trigger on task 40 | CREATE TRIGGER "task_upd" BEFORE UPDATE ON "public"."task_admin" FOR EACH ROW 41 | EXECUTE PROCEDURE "public"."task_update_timestamp"(); -------------------------------------------------------------------------------- /gunicorn_config.py: -------------------------------------------------------------------------------- 1 | # config.py 教程:https://www.jianshu.com/p/fecf15ad0c9a 2 | import os 3 | import gevent.monkey 4 | gevent.monkey.patch_all() 5 | 6 | import multiprocessing 7 | 8 | # debug = True 9 | loglevel = 'debug' 10 | bind = "0.0.0.0:5000" 11 | pidfile = "log/gunicorn.pid" 12 | accesslog = "log/access.log" 13 | errorlog = "log/debug.log" 14 | daemon = True 15 | 16 | # 启动的进程数 17 | workers = multiprocessing.cpu_count()*2+1 18 | worker_class = 'gevent' 19 | x_forwarded_for_header = 'X-FORWARDED-FOR' -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | """ 2 | Created by wucan on 2019-10-15. 
3 | """ 4 | from flask_apscheduler import APScheduler 5 | from werkzeug.exceptions import HTTPException 6 | from app import create_app 7 | from app.libs.error import APIException 8 | from app.libs.error_code import ServerError 9 | from app.api.v1.job import scheduler 10 | 11 | app = create_app() 12 | 13 | 14 | @app.errorhandler(Exception) 15 | def framework_error(e): 16 | if isinstance(e, APIException): 17 | return e 18 | if isinstance(e, HTTPException): 19 | code = e.code 20 | msg = e.description 21 | error_code = 1007 22 | return APIException(msg, code, error_code) 23 | else: 24 | # 调试模式 25 | # log 26 | if not app.config['DEBUG']: 27 | return ServerError() 28 | else: 29 | raise e 30 | 31 | 32 | if __name__ == '__main__': 33 | scheduler.start() 34 | app.run(host='0.0.0.0', port=5000, debug=False) 35 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | numpy 2 | pillow 3 | albumentations>=0.2.2 4 | opencv-python-headless 5 | tqdm>=4.29.0 6 | geojson 7 | mercantile>=1.0.4 8 | osmium>=2.15.0 9 | rasterio>=1.0.26 10 | supermercado>=0.0.5 11 | shapely>=1.6.4.post2 12 | pyproj>=1.9.6 13 | toml 14 | webcolors 15 | psycopg2-binary 16 | flask>=1.0 17 | flask-sqlalchemy>=2.3.2 18 | flask-wtf>=0.14.2 19 | cymysql>=0.9.1 20 | flask_cors>=2.1.0 21 | flask-httpauth>=2.7.0 22 | Flask-APScheduler>=1.11.0 23 | requests>=2.18.4 24 | # psycopg2>=2.8.3 25 | wtforms>=2.2.1 26 | sqlalchemy 27 | # rtree 28 | # libspatialindex 29 | # torch>=1.2.0 30 | # torchvision>=0.4.0 31 | geomet>=0.2.1.post1 32 | fiona>=1.8.11 33 | pyshp>=2.1.0 34 | # gdal -------------------------------------------------------------------------------- /robosat_pink/__init__.py: -------------------------------------------------------------------------------- 1 | __version__ = "0.5.2" 2 | -------------------------------------------------------------------------------- 
/robosat_pink/core.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import glob 4 | import toml 5 | from importlib import import_module 6 | 7 | import re 8 | import math 9 | import colorsys 10 | import webcolors 11 | from pathlib import Path 12 | 13 | from robosat_pink.tiles import tile_pixel_to_location, tiles_to_geojson 14 | 15 | 16 | # 17 | # Import module 18 | # 19 | def load_module(module): 20 | module = import_module(module) 21 | assert module, "Unable to import module {}".format(module) 22 | return module 23 | 24 | 25 | # 26 | # Config 27 | # 28 | def load_config(path): 29 | """Loads a dictionary from configuration file.""" 30 | 31 | if not path: 32 | path = os.environ["RSP_CONFIG"] if "RSP_CONFIG" in os.environ else None 33 | if not path: 34 | path = os.path.expanduser("~/.rsp_config") if os.path.isfile(os.path.expanduser("~/.rsp_config")) else None 35 | if not path: 36 | sys.exit("CONFIG ERROR: Either ~/.rsp_config or RSP_CONFIG env var or --config parameter, is required.") 37 | 38 | config = toml.load(os.path.expanduser(path)) 39 | assert config, "Unable to parse config file" 40 | config["classes"].insert(0, {"title": "Background", "color": "white"}) # Insert white Background 41 | 42 | # Set default values 43 | if "model" not in config.keys(): 44 | config["model"] = {} 45 | 46 | if "ts" not in config["model"].keys(): 47 | config["model"]["ts"] = (512, 512) 48 | 49 | if "pretrained" not in config["model"].keys(): 50 | config["model"]["pretrained"] = True 51 | 52 | return config 53 | 54 | 55 | def check_channels(config): 56 | assert "channels" in config.keys(), "At least one Channel is mandatory" 57 | 58 | # TODO Add name check 59 | 60 | # for channel in config["channels"]: 61 | # if not (len(channel["bands"]) == len(channel["mean"]) == len(channel["std"])): 62 | # sys.exit("CONFIG ERROR: Inconsistent channel bands, mean or std lenght in config file") 63 | 64 | 65 | def 
check_classes(config): 66 | """Check if config file classes subpart is consistent. Exit on error if not.""" 67 | 68 | assert "classes" in config.keys(), "At least one class is mandatory" 69 | 70 | for classe in config["classes"]: 71 | assert "title" in classe.keys() and len(classe["title"]), "Missing or Empty classes.title.value" 72 | assert "color" in classe.keys() and check_color(classe["color"]), "Missing or Invalid classes.color value" 73 | 74 | 75 | def check_model(config): 76 | 77 | hps = {"nn": "str", "pretrained": "bool", "loss": "str", "da": "str"} 78 | for hp in hps: 79 | assert hp in config["model"].keys() and type(config["model"][hp]).__name__ == hps[hp], "Missing or Invalid model" 80 | 81 | 82 | # 83 | # Logs 84 | # 85 | class Logs: 86 | def __init__(self, path, out=sys.stderr): 87 | """Create a logs instance on a logs file.""" 88 | 89 | self.fp = None 90 | self.out = out 91 | if path: 92 | if not os.path.isdir(os.path.dirname(path)): 93 | os.makedirs(os.path.dirname(path), exist_ok=True) 94 | self.fp = open(path, mode="a") 95 | 96 | def log(self, msg): 97 | """Log a new message to the opened logs file, and optionnaly on stdout or stderr too.""" 98 | if self.fp: 99 | self.fp.write(msg + os.linesep) 100 | self.fp.flush() 101 | 102 | if self.out: 103 | print(msg, file=self.out) 104 | 105 | 106 | # 107 | # Colors 108 | # 109 | def make_palette(colors, complementary=False): 110 | """Builds a One Hot PIL color palette from Classes CSS3 color names, or hex values patterns as #RRGGBB.""" 111 | 112 | assert 0 < len(colors) < 8 # 8bits One Hot encoding 113 | 114 | hex_colors = [webcolors.CSS3_NAMES_TO_HEX[color.lower()] if color[0] != "#" else color for color in colors] 115 | rgb_colors = [(int(h[1:3], 16), int(h[3:5], 16), int(h[5:7], 16)) for h in hex_colors] 116 | 117 | one_hot_colors = [(0, 0, 0) for i in range(256)] 118 | one_hot_colors[0] = rgb_colors[0] 119 | for c, color in enumerate(rgb_colors[1:]): 120 | one_hot_colors[int(math.pow(2, c))] = color 121 
| 122 | for i in range(3, int(math.pow(2, len(colors) - 1))): 123 | if i not in (4, 8, 16, 32, 64, 128): 124 | one_hot_colors[i] = (0, 0, 0) # TODO compute colors for overlapping classes 125 | 126 | palette = list(sum(one_hot_colors, ())) # flatten 127 | 128 | return palette if not complementary else complementary_palette(palette) 129 | 130 | 131 | def complementary_palette(palette): 132 | """Creates a PIL complementary colors palette based on an initial PIL palette.""" 133 | 134 | comp_palette = [] 135 | colors = [palette[i : i + 3] for i in range(0, len(palette), 3)] 136 | 137 | for color in colors: 138 | r, g, b = [v for v in color] 139 | h, s, v = colorsys.rgb_to_hsv(r, g, b) 140 | comp_palette.extend(map(int, colorsys.hsv_to_rgb((h + 0.5) % 1, s, v))) 141 | 142 | return comp_palette 143 | 144 | 145 | def check_color(color): 146 | """Check if an input color is or not valid (i.e CSS3 color name or #RRGGBB).""" 147 | 148 | hex_color = webcolors.CSS3_NAMES_TO_HEX[color.lower()] if color[0] != "#" else color 149 | return bool(re.match(r"^#([0-9a-fA-F]){6}$", hex_color)) 150 | 151 | 152 | # 153 | # Web UI 154 | # 155 | def web_ui(out, base_url, coverage_tiles, selected_tiles, ext, template, union_tiles=True): 156 | 157 | out = os.path.expanduser(out) 158 | template = os.path.expanduser(template) 159 | 160 | templates = glob.glob(os.path.join(Path(__file__).parent, "web_ui", "*")) 161 | if os.path.isfile(template): 162 | templates.append(template) 163 | if os.path.lexists(os.path.join(out, "index.html")): 164 | os.remove(os.path.join(out, "index.html")) # if already existing output dir, as symlink can't be overwriten 165 | os.symlink(os.path.basename(template), os.path.join(out, "index.html")) 166 | 167 | def process_template(template): 168 | web_ui = open(template, "r").read() 169 | web_ui = re.sub("{{base_url}}", base_url, web_ui) 170 | web_ui = re.sub("{{ext}}", ext, web_ui) 171 | web_ui = re.sub("{{tiles}}", "tiles.json" if selected_tiles else "''", web_ui) 172 | 
173 | if coverage_tiles: 174 | tile = list(coverage_tiles)[0] # Could surely be improved, but for now, took the first tile to center on 175 | x, y, z = map(int, [tile.x, tile.y, tile.z]) 176 | web_ui = re.sub("{{zoom}}", str(z), web_ui) 177 | web_ui = re.sub("{{center}}", str(list(tile_pixel_to_location(tile, 0.5, 0.5))[::-1]), web_ui) 178 | 179 | with open(os.path.join(out, os.path.basename(template)), "w", encoding="utf-8") as fp: 180 | fp.write(web_ui) 181 | 182 | for template in templates: 183 | process_template(template) 184 | 185 | if selected_tiles: 186 | with open(os.path.join(out, "tiles.json"), "w", encoding="utf-8") as fp: 187 | fp.write(tiles_to_geojson(selected_tiles, union_tiles)) 188 | -------------------------------------------------------------------------------- /robosat_pink/da/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/geocompass/robosat_geoc/44d95183322ba68c0728de44a66e50f510bfd919/robosat_pink/da/__init__.py -------------------------------------------------------------------------------- /robosat_pink/da/core.py: -------------------------------------------------------------------------------- 1 | """PyTorch-compatible Data Augmentation.""" 2 | 3 | import sys 4 | import cv2 5 | import torch 6 | import numpy as np 7 | from importlib import import_module 8 | 9 | 10 | def to_normalized_tensor(config, ts, mode, image, mask=None): 11 | 12 | assert mode == "train" or mode == "predict" 13 | assert len(ts) == 2 14 | assert image is not None 15 | 16 | # Resize, ToTensor and Data Augmentation 17 | if mode == "train": 18 | try: 19 | module = import_module("robosat_pink.da.{}".format(config["model"]["da"].lower())) 20 | except: 21 | sys.exit("Unable to load data augmentation module") 22 | 23 | transform = module.transform(config, image, mask) 24 | image = cv2.resize(image, ts, interpolation=cv2.INTER_LINEAR) 25 | image = torch.from_numpy(np.moveaxis(transform["image"], 2, 
0)).float() 26 | mask = cv2.resize(mask, ts, interpolation=cv2.INTER_NEAREST) 27 | mask = torch.from_numpy(transform["mask"]).long() 28 | 29 | elif mode == "predict": 30 | image = cv2.resize(image, ts, interpolation=cv2.INTER_LINEAR) 31 | image = torch.from_numpy(np.moveaxis(image, 2, 0)).float() 32 | 33 | # Normalize 34 | std = [] 35 | mean = [] 36 | 37 | try: 38 | for channel in config["channels"]: 39 | std.extend(channel["std"]) 40 | mean.extend(channel["mean"]) 41 | except: 42 | if config["model"]["pretrained"] and not len(std) and not len(mean): 43 | mean = [0.485, 0.456, 0.406] # RGB ImageNet default 44 | std = [0.229, 0.224, 0.225] # RGB ImageNet default 45 | 46 | assert len(std) and len(mean) 47 | image.sub_(torch.as_tensor(mean, device=image.device)[:, None, None]) 48 | image.div_(torch.as_tensor(std, device=image.device)[:, None, None]) 49 | 50 | if mode == "train": 51 | assert image is not None and mask is not None 52 | return image, mask 53 | 54 | elif mode == "predict": 55 | assert image is not None 56 | return image 57 | -------------------------------------------------------------------------------- /robosat_pink/da/strong.py: -------------------------------------------------------------------------------- 1 | from albumentations import ( 2 | Compose, 3 | IAAAdditiveGaussianNoise, 4 | GaussNoise, 5 | OneOf, 6 | Flip, 7 | Transpose, 8 | MotionBlur, 9 | Blur, 10 | ShiftScaleRotate, 11 | IAASharpen, 12 | IAAEmboss, 13 | RandomBrightnessContrast, 14 | MedianBlur, 15 | HueSaturationValue, 16 | ) 17 | 18 | 19 | def transform(config, image, mask): 20 | 21 | try: 22 | p = config["model"]["dap"] 23 | except: 24 | p = 1 25 | 26 | assert 0 <= p <= 1 27 | 28 | # Inspire by: https://albumentations.readthedocs.io/en/latest/examples.html 29 | return Compose( 30 | [ 31 | Flip(), 32 | Transpose(), 33 | OneOf([IAAAdditiveGaussianNoise(), GaussNoise()], p=0.2), 34 | OneOf([MotionBlur(p=0.2), MedianBlur(blur_limit=3, p=0.1), Blur(blur_limit=3, p=0.1)], p=0.2), 35 | 
ShiftScaleRotate(shift_limit=0.0625, scale_limit=0.2, rotate_limit=45, p=0.2), 36 | OneOf([IAASharpen(), IAAEmboss(), RandomBrightnessContrast()], p=0.3), 37 | HueSaturationValue(p=0.3), 38 | ] 39 | )(image=image, mask=mask, p=p) 40 | -------------------------------------------------------------------------------- /robosat_pink/geoc/RSPcover.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import shutil 4 | import multiprocessing 5 | 6 | from robosat_pink.tools import cover 7 | from robosat_pink.geoc import config as CONFIG, params, utils 8 | 9 | multiprocessing.set_start_method('spawn', True) 10 | 11 | 12 | def main(dsPath,geojson,out): 13 | params_cover = params.Cover( 14 | dir=dsPath, 15 | bbox=None, 16 | geojson=geojson, 17 | cover=None, 18 | raster=None, 19 | sql=None, 20 | pg=None, 21 | no_xyz=None, 22 | zoom=18, 23 | extent=None, 24 | splits=None, 25 | out=out) 26 | cover.main(params_cover) 27 | 28 | return True 29 | # 2 mins 30 | -------------------------------------------------------------------------------- /robosat_pink/geoc/RSPpredict.py: -------------------------------------------------------------------------------- 1 | import os 2 | from robosat_pink.tools import cover, download, rasterize, predict, vectorize, merge # , features 3 | # from robosat.tools import feature, merge 4 | 5 | import time 6 | import shutil 7 | import json 8 | import multiprocessing 9 | 10 | from robosat_pink.geoc import config as CONFIG, params, utils 11 | 12 | multiprocessing.set_start_method('spawn', True) 13 | 14 | 15 | def main(extent, dataPath, dsPath, map="google", auto_delete=False): 16 | # training or predict checkpoint.pth number 17 | pthNum = utils.getLastPth(dataPath) 18 | if pthNum == 0: 19 | return 'No model was found in directory for prediction' 20 | 21 | params_cover = params.Cover( 22 | bbox=extent, 23 | zoom=18, out=[dsPath + "/cover"]) 24 | cover.main(params_cover) 25 | 26 | params_download = 
params.Download( 27 | type="XYZ", 28 | url=CONFIG.WMTS_HOST+"/{z}/{x}/{y}?type="+map, 29 | # url='https://server.arcgisonline.com/ArcGIS/rest/services/World_Imagery/MapServer/tile/{z}/{y}/{x}', 30 | # url='https://b.tiles.mapbox.com/v4/mapbox.satellite/{z}/{x}/{y}.png?access_token=pk.eyJ1IjoibWFwYm94IiwiYSI6ImNpejY4NXFhYTA2bTMyeW44ZG0ybXBkMHkifQ.gUGbDOPUN1v1fTs5SeOR4A' 31 | cover=dsPath + "/cover", 32 | out=dsPath + "/images", 33 | timeout=20) 34 | download.main(params_download) 35 | 36 | pthPath = dataPath + "/model/checkpoint-" + \ 37 | str(pthNum).zfill(5)+".pth" 38 | 39 | params_predict = params.Predict( 40 | dataset=dsPath, 41 | checkpoint=pthPath, 42 | config=dataPath+"/config.toml", 43 | out=dsPath + "/masks" 44 | ) 45 | predict.main(params_predict) 46 | 47 | params_vectorize = params.Vectorize( 48 | masks=dsPath + "/masks", 49 | type="Building", 50 | config=dataPath+"/config.toml", 51 | out=dsPath + "/vectors.json" 52 | ) 53 | vectorize.main(params_vectorize) 54 | 55 | jsonFile = open(dsPath + "/vectors.json", 'r') 56 | jsonObj = json.load(jsonFile) 57 | if jsonObj["features"] == []: 58 | return jsonObj 59 | 60 | # # # 解析预测结果并返回 61 | jsonFile = open(dsPath + "/vectors.json", 'r') 62 | jsonObj = json.load(jsonFile) 63 | 64 | # params_features = params.Features( 65 | # masks=dsPath + "/masks", 66 | # type="parking", 67 | # dataset=dataPath+"/config.toml", 68 | # out=dsPath + "/features.json" 69 | # ) 70 | # features.main(params_features) 71 | 72 | # # 解析预测结果并返回 73 | # jsonFile = open(dsPath + "/features.json", 'r') 74 | # jsonObj = json.load(jsonFile) 75 | 76 | # params_merge = params.Merge( 77 | # features=dsPath + "/vectors.json", 78 | # threshold=2, 79 | # out=dsPath + "/merged_features.json" 80 | # ) 81 | # merge.main(params_merge) 82 | 83 | # 解析预测结果并返回 84 | # jsonFile = open(dsPath + "/merged_features.json", 'r') 85 | # jsonObj = json.load(jsonFile) 86 | 87 | # if auto_delete: 88 | # shutil.rmtree(dsPath) 89 | 90 | return jsonObj 91 | 
-------------------------------------------------------------------------------- /robosat_pink/geoc/RSPreturn_predict.py: -------------------------------------------------------------------------------- 1 | import os 2 | from robosat_pink.tools import cover, download, rasterize, predict, vectorize 3 | # from robosat.tools import feature, merge 4 | import time 5 | import shutil 6 | import json 7 | import multiprocessing 8 | 9 | from robosat_pink.geoc import config as CONFIG, params, utils 10 | 11 | multiprocessing.set_start_method('spawn', True) 12 | 13 | 14 | def main(extent, dataPath, dsPath, map="google", auto_delete=False): 15 | # training or predict checkpoint.pth number 16 | # pthNum = utils.getLastPth(dataPath) 17 | # if pthNum == 0: 18 | # return 'No model was found in directory for prediction' 19 | 20 | params_cover = params.Cover( 21 | bbox=extent, 22 | zoom=18, out=[dsPath + "/cover"]) 23 | cover.main(params_cover) 24 | 25 | # params_download = params.Download( 26 | # type="XYZ", 27 | # url=CONFIG.WMTS_HOST+"/{z}/{x}/{y}?type="+map, 28 | # # url='https://server.arcgisonline.com/ArcGIS/rest/services/World_Imagery/MapServer/tile/{z}/{y}/{x}', 29 | # # url='https://b.tiles.mapbox.com/v4/mapbox.satellite/{z}/{x}/{y}.png?access_token=pk.eyJ1IjoibWFwYm94IiwiYSI6ImNpejY4NXFhYTA2bTMyeW44ZG0ybXBkMHkifQ.gUGbDOPUN1v1fTs5SeOR4A' 30 | # cover=dsPath + "/cover", 31 | # out=dsPath + "/images", 32 | # timeout=20) 33 | # download.main(params_download) 34 | 35 | # pthPath = dataPath + "/model/checkpoint-" + \ 36 | # str(pthNum).zfill(5)+".pth" 37 | 38 | # params_predict = params.Predict( 39 | # dataset=dsPath, 40 | # checkpoint=pthPath, 41 | # config=dataPath+"/config.toml", 42 | # out=dsPath + "/masks" 43 | # ) 44 | # predict.main(params_predict) 45 | 46 | # params_vectorize = params.Vectorize( 47 | # masks=dsPath + "/masks", 48 | # type="Building", 49 | # config=dataPath+"/config.toml", 50 | # out=dsPath + "/vectors.json" 51 | # ) 52 | # vectorize.main(params_vectorize) 
53 | 54 | # params_features = params.Features( 55 | # masks=dsPath + "/masks", 56 | # type="parking", 57 | # dataset=dataPath+"/config.toml", 58 | # out=dsPath + "/features.json" 59 | # ) 60 | # feature.main(params_features) 61 | 62 | # # 解析预测结果并返回 63 | # jsonFile = open(dsPath + "/features.json", 'r') 64 | # jsonObj = json.load(jsonFile) 65 | 66 | # params_merge = params.Merge( 67 | # features=dsPath + "/features.json", 68 | # threshold=1, 69 | # out=dsPath + "/merged_features.json" 70 | # ) 71 | # merge.main(params_merge) 72 | 73 | # params_subset_masks = params.Subset( 74 | # dir=dsPath+'/masks', 75 | # cover=dsPath+'/cover', 76 | # out=dsPath+'/masks' 77 | # ) 78 | # subset.main(params_subset_masks) 79 | 80 | # # img to json 81 | # out = subset_features.json 82 | 83 | # 解析预测结果并返回 84 | jsonFile = open(dsPath + "/merged_features.json", 'r') 85 | jsonObj = json.load(jsonFile) 86 | 87 | # if auto_delete: 88 | # shutil.rmtree(dsPath) 89 | 90 | return jsonObj 91 | -------------------------------------------------------------------------------- /robosat_pink/geoc/RSPtrain.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import shutil 4 | import multiprocessing 5 | 6 | from robosat_pink.tools import cover, download, rasterize, subset, train 7 | from robosat_pink.geoc import config as CONFIG, params, utils 8 | 9 | multiprocessing.set_start_method('spawn', True) 10 | 11 | 12 | def main(extent, dataPath, dsPath, epochs=10, map="tdt", auto_delete=False): 13 | # training or predict checkpoint.pth number 14 | pthNum = utils.getLastPth(dataPath) 15 | 16 | params_cover = params.Cover( 17 | bbox=extent, 18 | zoom=18, out=[dsPath + "/cover"]) 19 | cover.main(params_cover) 20 | 21 | params_download = params.Download( 22 | url=CONFIG.WMTS_HOST+"/{z}/{x}/{y}?type="+map, 23 | cover=dsPath + "/cover", 24 | out=dsPath + "/images") 25 | download.main(params_download) 26 | 27 | params_rasterize = params.Rasterize( 28 | 
config=dataPath+"/config.toml", 29 | type="Building", 30 | ts="256,256", 31 | pg=CONFIG.POSTGRESQL, 32 | sql='SELECT geom FROM "'+CONFIG.BUILDING_TABLE + 33 | '" WHERE ST_Intersects(TILE_GEOM, geom)', 34 | cover=dsPath + "/cover", 35 | out=dsPath + "/labels" 36 | ) 37 | rasterize.main(params_rasterize) 38 | 39 | params_cover2 = params.Cover( 40 | dir=dsPath+'/images', 41 | splits='70/30', 42 | out=[dsPath+'/training/cover', dsPath+'/validation/cover'] 43 | ) 44 | cover.main(params_cover2) 45 | 46 | params_subset_train_images = params.Subset( 47 | dir=dsPath+'/images', 48 | cover=dsPath+'/training/cover', 49 | out=dsPath+'/training/images' 50 | ) 51 | subset.main(params_subset_train_images) 52 | 53 | params_subset_train_labels = params.Subset( 54 | dir=dsPath+'/labels', 55 | cover=dsPath+'/training/cover', 56 | out=dsPath+'/training/labels' 57 | ) 58 | subset.main(params_subset_train_labels) 59 | 60 | params_subset_validation_images = params.Subset( 61 | dir=dsPath+'/images', 62 | cover=dsPath+'/validation/cover', 63 | out=dsPath+'/validation/images' 64 | ) 65 | subset.main(params_subset_validation_images) 66 | 67 | params_subset_validation_labels = params.Subset( 68 | dir=dsPath+'/labels', 69 | cover=dsPath+'/validation/cover', 70 | out=dsPath+'/validation/labels' 71 | ) 72 | subset.main(params_subset_validation_labels) 73 | 74 | params_train = params.Train( 75 | config=dataPath+'/config.toml', 76 | epochs=epochs, 77 | ts="256, 256", 78 | dataset=dsPath, 79 | out=dataPath+'/model' 80 | ) 81 | if pthNum: 82 | pthPath = dataPath + "/model/checkpoint-" + \ 83 | str(pthNum).zfill(5)+".pth" 84 | params_train.checkpoint = pthPath 85 | params_train.resume = True 86 | params_train.epochs = pthNum+epochs 87 | train.main(params_train) 88 | 89 | if auto_delete: 90 | shutil.rmtree(dsPath) 91 | 92 | return True 93 | # 2 mins 94 | -------------------------------------------------------------------------------- /robosat_pink/geoc/__init__.py: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/geocompass/robosat_geoc/44d95183322ba68c0728de44a66e50f510bfd919/robosat_pink/geoc/__init__.py -------------------------------------------------------------------------------- /robosat_pink/geoc/config.py: -------------------------------------------------------------------------------- 1 | # PostgreSQL Connection String 2 | POSTGRESQL = "host='localhost' dbname='tdt2018' user='postgres' password='postgres'" 3 | # POSTGRESQL = "host='172.16.100.140' dbname='tdt2018' user='postgres' password='postgres'" 4 | 5 | # building outline PostGIS data table using by training label 6 | BUILDING_TABLE = "BUIA" 7 | 8 | # remote sensing image tiles download host url 9 | WMTS_HOST = "http://127.0.0.1:5000/v1/wmts" 10 | 11 | # tianditu and google map remote sensing wmts url 12 | URL_TDT = '''https://t1.tianditu.gov.cn/DataServer?T=img_w&x={x}&y={y}&l={z}&tk=4830425f5d789b48b967b1062deb8c71''' 13 | URL_GOOGLE = '''http://ditu.google.cn/maps/vt/lyrs=s&x={x}&y={y}&z={z}''' 14 | 15 | # wmts_xyz_proxy port 16 | FLASK_PORT = 5000 17 | -------------------------------------------------------------------------------- /robosat_pink/geoc/pg生成乡为单位的中心点geojson.sql: -------------------------------------------------------------------------------- 1 | --从BUIA整理一个乡为单位的中心点extent 2 | ALTER table data_xiang add COLUMN centroid "public"."geometry"; 3 | ALTER table data_xiang add COLUMN centroid_buff "public"."geometry"; 4 | update data_xiang set centroid = st_centroid(geom) where st_geometrytype(geom)='ST_GeometryCollection' 5 | update data_xiang set centroid_buff = st_buffer(centroid,0.01) where centroid is not NULL 6 | --多个geometry变单个geometry,导出geojson 7 | select st_asgeojson(st_union(centroid_buff)) from data_xiang -------------------------------------------------------------------------------- /robosat_pink/geoc/utils.py: 
-------------------------------------------------------------------------------- 1 | import glob 2 | import os 3 | 4 | 5 | def getLastPth(path): 6 | if not os.path.isdir(path): 7 | return 0 8 | files = glob.glob(path + "/model/*.pth") 9 | maxNum = 0 10 | for file in files: 11 | fileName = os.path.basename(file) 12 | name_and_ext = os.path.splitext(fileName) 13 | names = name_and_ext[0].split('-') 14 | if maxNum < int(names[1]): 15 | maxNum = int(names[1]) 16 | return maxNum 17 | -------------------------------------------------------------------------------- /robosat_pink/geojson.py: -------------------------------------------------------------------------------- 1 | from rasterio.crs import CRS 2 | from rasterio.warp import transform 3 | from rasterio.features import rasterize 4 | from rasterio.transform import from_bounds 5 | 6 | import mercantile 7 | from supermercado import burntiles 8 | 9 | from robosat_pink.tiles import tile_bbox 10 | 11 | 12 | def geojson_reproject(feature, srid_in, srid_out): 13 | """Reproject GeoJSON Polygon feature coords 14 | Inspired by: https://gist.github.com/dnomadb/5cbc116aacc352c7126e779c29ab7abe 15 | """ 16 | 17 | if feature["geometry"]["type"] == "Polygon": 18 | xys = (zip(*ring) for ring in feature["geometry"]["coordinates"]) 19 | xys = (list(zip(*transform(CRS.from_epsg(srid_in), CRS.from_epsg(srid_out), *xy))) for xy in xys) 20 | 21 | yield {"coordinates": list(xys), "type": "Polygon"} 22 | 23 | 24 | def geojson_parse_feature(zoom, srid, feature_map, feature): 25 | def geojson_parse_polygon(zoom, srid, feature_map, polygon): 26 | 27 | if srid != 4326: 28 | polygon = [xy for xy in geojson_reproject({"type": "feature", "geometry": polygon}, srid, 4326)][0] 29 | 30 | for i, ring in enumerate(polygon["coordinates"]): # GeoJSON coordinates could be N dimensionals 31 | polygon["coordinates"][i] = [[x, y] for point in ring for x, y in zip([point[0]], [point[1]])] 32 | 33 | if polygon["coordinates"]: 34 | for tile in 
burntiles.burn([{"type": "feature", "geometry": polygon}], zoom=zoom): 35 | feature_map[mercantile.Tile(*tile)].append({"type": "feature", "geometry": polygon}) 36 | 37 | return feature_map 38 | 39 | def geojson_parse_geometry(zoom, srid, feature_map, geometry): 40 | 41 | if geometry["type"] == "Polygon": 42 | feature_map = geojson_parse_polygon(zoom, srid, feature_map, geometry) 43 | 44 | elif geometry["type"] == "MultiPolygon": 45 | for polygon in geometry["coordinates"]: 46 | feature_map = geojson_parse_polygon(zoom, srid, feature_map, {"type": "Polygon", "coordinates": polygon}) 47 | 48 | return feature_map 49 | 50 | if feature["geometry"]["type"] == "GeometryCollection": 51 | for geometry in feature["geometry"]["geometries"]: 52 | feature_map = geojson_parse_geometry(zoom, srid, feature_map, geometry) 53 | else: 54 | feature_map = geojson_parse_geometry(zoom, srid, feature_map, feature["geometry"]) 55 | 56 | return feature_map 57 | 58 | 59 | def geojson_srid(feature_collection): 60 | 61 | try: 62 | crs_mapping = {"CRS84": "4326", "900913": "3857"} 63 | srid = feature_collection["crs"]["properties"]["name"].split(":")[-1] 64 | srid = int(srid) if srid not in crs_mapping else int(crs_mapping[srid]) 65 | except: 66 | srid = int(4326) 67 | 68 | return srid 69 | 70 | 71 | def geojson_tile_burn(tile, features, srid, ts, burn_value=1): 72 | """Burn tile with GeoJSON features.""" 73 | 74 | shapes = ((geometry, burn_value) for feature in features for geometry in geojson_reproject(feature, srid, 3857)) 75 | 76 | bounds = tile_bbox(tile, mercator=True) 77 | transform = from_bounds(*bounds, *ts) 78 | 79 | try: 80 | return rasterize(shapes, out_shape=ts, transform=transform) 81 | except: 82 | return None 83 | -------------------------------------------------------------------------------- /robosat_pink/graph/__init__.py: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/geocompass/robosat_geoc/44d95183322ba68c0728de44a66e50f510bfd919/robosat_pink/graph/__init__.py -------------------------------------------------------------------------------- /robosat_pink/graph/core.py: -------------------------------------------------------------------------------- 1 | import collections 2 | 3 | 4 | class UndirectedGraph: 5 | """Simple undirected graph. 6 | 7 | Note: stores edges; can not store vertices without edges. 8 | """ 9 | 10 | def __init__(self): 11 | """Creates an empty `UndirectedGraph` instance. 12 | """ 13 | 14 | # Todo: We might need a compressed sparse row graph (i.e. adjacency array) 15 | # to make this scale. Let's circle back when we run into this limitation. 16 | self.edges = collections.defaultdict(set) 17 | 18 | def add_edge(self, s, t): 19 | """Adds an edge to the graph. 20 | 21 | Args: 22 | s: the source vertex. 23 | t: the target vertex. 24 | 25 | Note: because this is an undirected graph for every edge `s, t` an edge `t, s` is added. 26 | """ 27 | 28 | self.edges[s].add(t) 29 | self.edges[t].add(s) 30 | 31 | def targets(self, v): 32 | """Returns all outgoing targets for a vertex. 33 | 34 | Args: 35 | v: the vertex to return targets for. 36 | 37 | Returns: 38 | A list of all outgoing targets for the vertex. 39 | """ 40 | 41 | return self.edges[v] 42 | 43 | def vertices(self): 44 | """Returns all vertices in the graph. 45 | 46 | Returns: 47 | A set of all vertices in the graph. 48 | """ 49 | 50 | return self.edges.keys() 51 | 52 | def empty(self): 53 | """Returns true if the graph is empty, false otherwise. 54 | 55 | Returns: 56 | True if the graph has no edges or vertices, false otherwise. 57 | """ 58 | return len(self.edges) == 0 59 | 60 | def dfs(self, v): 61 | """Applies a depth-first search to the graph. 62 | 63 | Args: 64 | v: the vertex to start the depth-first search at. 65 | 66 | Yields: 67 | The visited graph vertices in depth-first search order. 
68 | 69 | Note: does not include the start vertex `v` (except if an edge targets it). 70 | """ 71 | 72 | stack = [] 73 | stack.append(v) 74 | 75 | seen = set() 76 | 77 | while stack: 78 | s = stack.pop() 79 | 80 | if s not in seen: 81 | seen.add(s) 82 | 83 | for t in self.targets(s): 84 | stack.append(t) 85 | 86 | yield s 87 | 88 | def components(self): 89 | """Computes connected components for the graph. 90 | 91 | Yields: 92 | The connected component sub-graphs consisting of vertices; in no particular order. 93 | """ 94 | 95 | seen = set() 96 | 97 | for v in self.vertices(): 98 | if v not in seen: 99 | component = set(self.dfs(v)) 100 | component.add(v) 101 | 102 | seen.update(component) 103 | 104 | yield component 105 | -------------------------------------------------------------------------------- /robosat_pink/loaders/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/geocompass/robosat_geoc/44d95183322ba68c0728de44a66e50f510bfd919/robosat_pink/loaders/__init__.py -------------------------------------------------------------------------------- /robosat_pink/loaders/semsegtiles.py: -------------------------------------------------------------------------------- 1 | """PyTorch-compatible datasets. 
Cf: https://pytorch.org/docs/stable/data.html """ 2 | 3 | import os 4 | import numpy as np 5 | import torch.utils.data 6 | 7 | from robosat_pink.tiles import tiles_from_dir, tile_image_from_file, tile_label_from_file 8 | from robosat_pink.da.core import to_normalized_tensor 9 | 10 | 11 | class SemSegTiles(torch.utils.data.Dataset): 12 | def __init__(self, config, ts, root, mode): 13 | super().__init__() 14 | 15 | self.root = os.path.expanduser(root) 16 | self.config = config 17 | self.mode = mode 18 | 19 | assert mode == "train" or mode == "predict" 20 | 21 | num_channels = 0 22 | self.tiles = {} 23 | for channel in config["channels"]: 24 | path = os.path.join(self.root, channel["name"]) 25 | self.tiles[channel["name"]] = [(tile, path) for tile, path in tiles_from_dir(path, xyz_path=True)] 26 | self.tiles[channel["name"]].sort(key=lambda tile: tile[0]) 27 | num_channels += len(channel["bands"]) 28 | 29 | self.shape_in = (num_channels,) + ts # C,W,H 30 | self.shape_out = (len(config["classes"]),) + ts # C,W,H 31 | 32 | if self.mode == "train": 33 | path = os.path.join(self.root, "labels") 34 | self.tiles["labels"] = [(tile, path) for tile, path in tiles_from_dir(path, xyz_path=True)] 35 | self.tiles["labels"].sort(key=lambda tile: tile[0]) 36 | 37 | def __len__(self): 38 | return len(self.tiles[self.config["channels"][0]["name"]]) 39 | 40 | def __getitem__(self, i): 41 | 42 | tile = None 43 | mask = None 44 | image = None 45 | 46 | for channel in self.config["channels"]: 47 | 48 | image_channel = None 49 | bands = None if not channel["bands"] else channel["bands"] 50 | 51 | if tile is None: 52 | tile, path = self.tiles[channel["name"]][i] 53 | else: 54 | assert tile == self.tiles[channel["name"]][i][0], "Dataset channel inconsistency" 55 | tile, path = self.tiles[channel["name"]][i] 56 | 57 | image_channel = tile_image_from_file(path, bands) 58 | 59 | assert image_channel is not None, "Dataset channel {} not retrieved: {}".format(channel["name"], path) 60 | image = 
np.concatenate((image, image_channel), axis=2) if image is not None else image_channel 61 | 62 | if self.mode == "train": 63 | assert tile == self.tiles["labels"][i][0], "Dataset mask inconsistency" 64 | mask = tile_label_from_file(self.tiles["labels"][i][1]) 65 | assert mask is not None, "Dataset mask not retrieved" 66 | 67 | image, mask = to_normalized_tensor(self.config, self.shape_in[1:3], self.mode, image, mask) 68 | return image, mask, tile 69 | 70 | if self.mode == "predict": 71 | image = to_normalized_tensor(self.config, self.shape_in[1:3], self.mode, image) 72 | return image, torch.IntTensor([tile.x, tile.y, tile.z]) 73 | -------------------------------------------------------------------------------- /robosat_pink/losses/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/geocompass/robosat_geoc/44d95183322ba68c0728de44a66e50f510bfd919/robosat_pink/losses/__init__.py -------------------------------------------------------------------------------- /robosat_pink/losses/lovasz.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | 5 | class Lovasz(nn.Module): 6 | """Lovasz Loss. 
Cf: https://arxiv.org/abs/1705.08790 """ 7 | 8 | def __init__(self): 9 | super().__init__() 10 | 11 | def forward(self, inputs, targets, config): 12 | 13 | N, C, H, W = inputs.size() 14 | masks = torch.zeros(N, C, H, W).to(targets.device).scatter_(1, targets.view(N, 1, H, W), 1) 15 | 16 | loss = 0.0 17 | 18 | for mask, input in zip(masks.view(N, -1), inputs.view(N, -1)): 19 | 20 | max_margin_errors = 1.0 - ((mask * 2 - 1) * input) 21 | errors_sorted, indices = torch.sort(max_margin_errors, descending=True) 22 | labels_sorted = mask[indices.data] 23 | 24 | inter = labels_sorted.sum() - labels_sorted.cumsum(0) 25 | union = labels_sorted.sum() + (1.0 - labels_sorted).cumsum(0) 26 | iou = 1.0 - inter / union 27 | 28 | p = len(labels_sorted) 29 | if p > 1: 30 | iou[1:p] = iou[1:p] - iou[0:-1] 31 | 32 | loss += torch.dot(nn.functional.relu(errors_sorted), iou) 33 | 34 | return loss / N 35 | -------------------------------------------------------------------------------- /robosat_pink/metrics/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/geocompass/robosat_geoc/44d95183322ba68c0728de44a66e50f510bfd919/robosat_pink/metrics/__init__.py -------------------------------------------------------------------------------- /robosat_pink/metrics/core.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from robosat_pink.core import load_module 3 | 4 | 5 | class Metrics: 6 | def __init__(self, metrics, config=None): 7 | self.config = config 8 | self.metrics = {metric: 0.0 for metric in metrics} 9 | self.modules = {metric: load_module("robosat_pink.metrics." 
+ metric) for metric in metrics} 10 | self.n = 0 11 | 12 | def add(self, mask, output): 13 | assert self.modules 14 | assert self.metrics 15 | self.n += 1 16 | for metric, module in self.modules.items(): 17 | m = module.get(mask, output, self.config) 18 | self.metrics[metric] += m 19 | 20 | def get(self): 21 | assert self.metrics 22 | assert self.n 23 | return {metric: value / self.n for metric, value in self.metrics.items()} 24 | 25 | 26 | def confusion(predicted, label): 27 | confusion = predicted.view(-1).float() / label.view(-1).float() 28 | 29 | tn = torch.sum(torch.isnan(confusion)).item() 30 | fn = torch.sum(confusion == float("inf")).item() 31 | fp = torch.sum(confusion == 0).item() 32 | tp = torch.sum(confusion == 1).item() 33 | 34 | return tn, fn, fp, tp 35 | -------------------------------------------------------------------------------- /robosat_pink/metrics/iou.py: -------------------------------------------------------------------------------- 1 | from robosat_pink.metrics.core import confusion 2 | 3 | 4 | def get(label, predicted, config=None): 5 | 6 | tn, fn, fp, tp = confusion(label, predicted) 7 | 8 | try: 9 | iou = tp / (tp + fn + fp) 10 | except ZeroDivisionError: 11 | iou = 1.0 12 | 13 | return iou 14 | -------------------------------------------------------------------------------- /robosat_pink/metrics/mcc.py: -------------------------------------------------------------------------------- 1 | import math 2 | from robosat_pink.metrics.core import confusion 3 | 4 | 5 | def get(label, predicted, config=None): 6 | 7 | tn, fn, fp, tp = confusion(label, predicted) 8 | 9 | try: 10 | mcc = (tp * tn - fp * fn) / math.sqrt((tp + fp) * (tp + fn) * (tn + fp) * (tn + fn)) 11 | except ZeroDivisionError: 12 | mcc = 0.0 13 | 14 | return mcc 15 | -------------------------------------------------------------------------------- /robosat_pink/metrics/qod.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 
import math 3 | 4 | from robosat_pink.metrics.core import confusion 5 | 6 | 7 | def get(label, mask, config=None): 8 | 9 | tn, fn, fp, tp = confusion(label, mask) 10 | 11 | try: 12 | iou = tp / (tp + fn + fp) 13 | except ZeroDivisionError: 14 | iou = float("NaN") 15 | 16 | W, H = mask.size() 17 | ratio = float(100 * torch.max(torch.sum(mask != 0), torch.sum(label != 0)) / (W * H)) 18 | dist = 0.0 if iou != iou else 1.0 - iou 19 | 20 | qod = 100 - (dist * (math.log(ratio + 1.0) + 1e-7) * (100 / math.log(100))) 21 | qod = 0.0 if qod < 0.0 else qod # Corner case prophilaxy 22 | 23 | return (dist, ratio, qod) 24 | -------------------------------------------------------------------------------- /robosat_pink/models/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/geocompass/robosat_geoc/44d95183322ba68c0728de44a66e50f510bfd919/robosat_pink/models/__init__.py -------------------------------------------------------------------------------- /robosat_pink/models/albunet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torchvision.models import resnet50 4 | 5 | 6 | class ConvRelu(nn.Module): 7 | """3x3 convolution followed by ReLU activation building block.""" 8 | 9 | def __init__(self, num_in, num_out): 10 | super().__init__() 11 | self.block = nn.Conv2d(num_in, num_out, kernel_size=3, padding=1, bias=False) 12 | 13 | def forward(self, x): 14 | return nn.functional.relu(self.block(x), inplace=True) 15 | 16 | 17 | class DecoderBlock(nn.Module): 18 | """Decoder building block upsampling resolution by a factor of two.""" 19 | 20 | def __init__(self, num_in, num_out): 21 | super().__init__() 22 | self.block = ConvRelu(num_in, num_out) 23 | 24 | def forward(self, x): 25 | return self.block(nn.functional.interpolate(x, scale_factor=2, mode="nearest")) 26 | 27 | 28 | class Albunet(nn.Module): 29 | def __init__(self, 
shape_in, shape_out, train_config=None): 30 | self.doc = """ 31 | U-Net inspired encoder-decoder architecture with a ResNet encoder as proposed by Alexander Buslaev. 32 | 33 | - https://arxiv.org/abs/1505.04597 - U-Net: Convolutional Networks for Biomedical Image Segmentation 34 | - https://arxiv.org/pdf/1804.08024 - Angiodysplasia Detection and Localization Using DCNN 35 | - https://arxiv.org/abs/1806.00844 - TernausNetV2: Fully Convolutional Network for Instance Segmentation 36 | """ 37 | self.version = 1 38 | 39 | num_filters = 32 40 | num_channels = shape_in[0] 41 | num_classes = shape_out[0] 42 | 43 | super().__init__() 44 | 45 | try: 46 | pretrained = train_config["model"]["pretrained"] 47 | except: 48 | pretrained = False 49 | 50 | self.resnet = resnet50(pretrained=pretrained) 51 | 52 | assert num_channels 53 | if num_channels != 3: 54 | weights = nn.init.xavier_uniform_(torch.zeros((64, num_channels, 7, 7))) 55 | if pretrained: 56 | for c in range(min(num_channels, 3)): 57 | weights.data[:, c, :, :] = self.resnet.conv1.weight.data[:, c, :, :] 58 | self.resnet.conv1 = nn.Conv2d(num_channels, 64, kernel_size=7, stride=2, padding=3, bias=False) 59 | self.resnet.conv1.weight = nn.Parameter(weights) 60 | 61 | # No encoder reference, cf: https://github.com/pytorch/pytorch/issues/8392 62 | 63 | self.center = DecoderBlock(2048, num_filters * 8) 64 | 65 | self.dec0 = DecoderBlock(2048 + num_filters * 8, num_filters * 8) 66 | self.dec1 = DecoderBlock(1024 + num_filters * 8, num_filters * 8) 67 | self.dec2 = DecoderBlock(512 + num_filters * 8, num_filters * 2) 68 | self.dec3 = DecoderBlock(256 + num_filters * 2, num_filters * 2 * 2) 69 | self.dec4 = DecoderBlock(num_filters * 2 * 2, num_filters) 70 | self.dec5 = ConvRelu(num_filters, num_filters) 71 | 72 | self.final = nn.Conv2d(num_filters, num_classes, kernel_size=1) 73 | 74 | def forward(self, x): 75 | 76 | enc0 = self.resnet.conv1(x) 77 | enc0 = self.resnet.bn1(enc0) 78 | enc0 = self.resnet.relu(enc0) 79 | enc0 = 
self.resnet.maxpool(enc0) 80 | 81 | enc1 = self.resnet.layer1(enc0) 82 | enc2 = self.resnet.layer2(enc1) 83 | enc3 = self.resnet.layer3(enc2) 84 | enc4 = self.resnet.layer4(enc3) 85 | 86 | center = self.center(nn.functional.max_pool2d(enc4, kernel_size=2, stride=2)) 87 | 88 | dec0 = self.dec0(torch.cat([enc4, center], dim=1)) 89 | dec1 = self.dec1(torch.cat([enc3, dec0], dim=1)) 90 | dec2 = self.dec2(torch.cat([enc2, dec1], dim=1)) 91 | dec3 = self.dec3(torch.cat([enc1, dec2], dim=1)) 92 | dec4 = self.dec4(dec3) 93 | dec5 = self.dec5(dec4) 94 | 95 | return self.final(dec5) 96 | -------------------------------------------------------------------------------- /robosat_pink/osm/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/geocompass/robosat_geoc/44d95183322ba68c0728de44a66e50f510bfd919/robosat_pink/osm/__init__.py -------------------------------------------------------------------------------- /robosat_pink/osm/building.py: -------------------------------------------------------------------------------- 1 | import sys 2 | 3 | import osmium 4 | import geojson 5 | import shapely.geometry 6 | 7 | 8 | class BuildingHandler(osmium.SimpleHandler): 9 | """Extracts building polygon features (visible in satellite imagery) from the map.""" 10 | 11 | # building=* to discard because these features are not vislible in satellite imagery 12 | building_filter = set( 13 | ["construction", "houseboat", "static_caravan", "stadium", "conservatory", "digester", "greenhouse", "ruins"] 14 | ) 15 | 16 | # location=* to discard because these features are not vislible in satellite imagery 17 | location_filter = set(["underground", "underwater"]) 18 | 19 | def __init__(self): 20 | super().__init__() 21 | self.features = [] 22 | 23 | def way(self, w): 24 | if not w.is_closed() or len(w.nodes) < 4: 25 | return 26 | 27 | if "building" not in w.tags: 28 | return 29 | 30 | if w.tags["building"] in 
self.building_filter: 31 | return 32 | 33 | if "location" in w.tags and w.tags["location"] in self.location_filter: 34 | return 35 | 36 | geometry = geojson.Polygon([[(n.lon, n.lat) for n in w.nodes]]) 37 | shape = shapely.geometry.shape(geometry) 38 | 39 | if shape.is_valid: 40 | feature = geojson.Feature(geometry=geometry) 41 | self.features.append(feature) 42 | else: 43 | print("Warning: invalid feature: https://www.openstreetmap.org/way/{}".format(w.id), file=sys.stderr) 44 | 45 | def save(self, out): 46 | collection = geojson.FeatureCollection(self.features) 47 | 48 | with open(out, "w") as fp: 49 | geojson.dump(collection, fp) 50 | -------------------------------------------------------------------------------- /robosat_pink/osm/road.py: -------------------------------------------------------------------------------- 1 | import sys 2 | 3 | import math 4 | import osmium 5 | import geojson 6 | import shapely.geometry 7 | 8 | 9 | class RoadHandler(osmium.SimpleHandler): 10 | """Extracts road polygon features (visible in satellite imagery) from the map. 
11 | """ 12 | 13 | highway_attributes = { 14 | "motorway": {"lanes": 4, "lane_width": 3.75, "left_hard_shoulder_width": 0.75, "right_hard_shoulder_width": 3.0}, 15 | "trunk": {"lanes": 3, "lane_width": 3.75, "left_hard_shoulder_width": 0.75, "right_hard_shoulder_width": 3.0}, 16 | "primary": {"lanes": 2, "lane_width": 3.75, "left_hard_shoulder_width": 0.50, "right_hard_shoulder_width": 1.50}, 17 | "secondary": {"lanes": 1, "lane_width": 3.50, "left_hard_shoulder_width": 0.00, "right_hard_shoulder_width": 0.75}, 18 | "tertiary": {"lanes": 1, "lane_width": 3.50, "left_hard_shoulder_width": 0.00, "right_hard_shoulder_width": 0.75}, 19 | "unclassified": { 20 | "lanes": 1, 21 | "lane_width": 3.50, 22 | "left_hard_shoulder_width": 0.00, 23 | "right_hard_shoulder_width": 0.00, 24 | }, 25 | "residential": {"lanes": 1, "lane_width": 3.50, "left_hard_shoulder_width": 0.00, "right_hard_shoulder_width": 0.75}, 26 | "service": {"lanes": 1, "lane_width": 3.00, "left_hard_shoulder_width": 0.00, "right_hard_shoulder_width": 0.00}, 27 | "motorway_link": { 28 | "lanes": 2, 29 | "lane_width": 3.75, 30 | "left_hard_shoulder_width": 0.75, 31 | "right_hard_shoulder_width": 3.00, 32 | }, 33 | "trunk_link": {"lanes": 2, "lane_width": 3.75, "left_hard_shoulder_width": 0.50, "right_hard_shoulder_width": 1.50}, 34 | "primary_link": { 35 | "lanes": 1, 36 | "lane_width": 3.50, 37 | "left_hard_shoulder_width": 0.00, 38 | "right_hard_shoulder_width": 0.75, 39 | }, 40 | "secondary_link": { 41 | "lanes": 1, 42 | "lane_width": 3.50, 43 | "left_hard_shoulder_width": 0.00, 44 | "right_hard_shoulder_width": 0.75, 45 | }, 46 | "tertiary_link": { 47 | "lanes": 1, 48 | "lane_width": 3.50, 49 | "left_hard_shoulder_width": 0.00, 50 | "right_hard_shoulder_width": 0.00, 51 | }, 52 | } 53 | 54 | road_filter = set(highway_attributes.keys()) 55 | 56 | EARTH_MEAN_RADIUS = 6371004.0 57 | 58 | def __init__(self): 59 | super().__init__() 60 | self.features = [] 61 | 62 | def way(self, w): 63 | if "highway" not in 
w.tags: 64 | return 65 | 66 | if w.tags["highway"] not in self.road_filter: 67 | return 68 | 69 | left_hard_shoulder_width = self.highway_attributes[w.tags["highway"]]["left_hard_shoulder_width"] 70 | lane_width = self.highway_attributes[w.tags["highway"]]["lane_width"] 71 | lanes = self.highway_attributes[w.tags["highway"]]["lanes"] 72 | right_hard_shoulder_width = self.highway_attributes[w.tags["highway"]]["right_hard_shoulder_width"] 73 | 74 | if "oneway" not in w.tags: 75 | lanes = lanes * 2 76 | elif w.tags["oneway"] == "no": 77 | lanes = lanes * 2 78 | 79 | if "lanes" in w.tags: 80 | try: 81 | # Roads have at least one lane; guard against data issues. 82 | lanes = max(int(w.tags["lanes"]), 1) 83 | 84 | # Todo: take into account related lane tags 85 | # https://wiki.openstreetmap.org/wiki/Tag:busway%3Dlane 86 | # https://wiki.openstreetmap.org/wiki/Tag:cycleway%3Dlane 87 | # https://wiki.openstreetmap.org/wiki/Key:parking:lane 88 | except ValueError: 89 | print("Warning: invalid feature: https://www.openstreetmap.org/way/{}".format(w.id), file=sys.stderr) 90 | 91 | road_width = left_hard_shoulder_width + lane_width * lanes + right_hard_shoulder_width 92 | 93 | if "width" in w.tags: 94 | try: 95 | # At least one meter wide, for road classes specified above 96 | road_width = max(float(w.tags["width"]), 1.0) 97 | 98 | # Todo: handle optional units such as "2 m" 99 | # https://wiki.openstreetmap.org/wiki/Key:width 100 | except ValueError: 101 | print("Warning: invalid feature: https://www.openstreetmap.org/way/{}".format(w.id), file=sys.stderr) 102 | 103 | geometry = geojson.LineString([(n.lon, n.lat) for n in w.nodes]) 104 | shape = shapely.geometry.shape(geometry) 105 | geometry_buffer = shape.buffer(math.degrees(road_width / 2.0 / self.EARTH_MEAN_RADIUS)) 106 | 107 | if shape.is_valid: 108 | feature = geojson.Feature(geometry=shapely.geometry.mapping(geometry_buffer)) 109 | self.features.append(feature) 110 | else: 111 | print("Warning: invalid feature: 
https://www.openstreetmap.org/way/{}".format(w.id), file=sys.stderr) 112 | 113 | def save(self, out): 114 | collection = geojson.FeatureCollection(self.features) 115 | 116 | with open(out, "w") as fp: 117 | geojson.dump(collection, fp) 118 | -------------------------------------------------------------------------------- /robosat_pink/spatial/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/geocompass/robosat_geoc/44d95183322ba68c0728de44a66e50f510bfd919/robosat_pink/spatial/__init__.py -------------------------------------------------------------------------------- /robosat_pink/spatial/core.py: -------------------------------------------------------------------------------- 1 | import functools 2 | 3 | import pyproj 4 | import shapely.ops 5 | 6 | from rtree.index import Index, Property 7 | 8 | 9 | def project(shape, source, target): 10 | """Projects a geometry from one coordinate system into another. 11 | 12 | Args: 13 | shape: the geometry to project. 14 | source: the source EPSG spatial reference system identifier. 15 | target: the target EPSG spatial reference system identifier. 16 | 17 | Returns: 18 | The projected geometry in the target coordinate system. 19 | """ 20 | 21 | project = functools.partial(pyproj.transform, pyproj.Proj(init=source), pyproj.Proj(init=target)) 22 | 23 | return shapely.ops.transform(project, shape) 24 | 25 | 26 | def union(shapes): 27 | """Returns the union of all shapes. 28 | 29 | Args: 30 | shapes: the geometries to merge into one. 31 | 32 | Returns: 33 | The union of all shapes as one shape. 34 | """ 35 | 36 | assert shapes 37 | 38 | def fn(lhs, rhs): 39 | return lhs.union(rhs) 40 | 41 | return functools.reduce(fn, shapes) 42 | 43 | 44 | def iou(lhs, rhs): 45 | """Calculates intersection over union metric between two shapes.. 46 | 47 | Args: 48 | lhs: first shape for IoU calculation. 49 | rhs: second shape for IoU calculation. 
50 | 51 | Returns: 52 | IoU metric in range [0, 1] 53 | """ 54 | 55 | # equal-area projection for comparing shape areas 56 | lhs = project(lhs, "epsg:4326", "esri:54009") 57 | rhs = project(rhs, "epsg:4326", "esri:54009") 58 | 59 | intersection = lhs.intersection(rhs) 60 | union = lhs.union(rhs) 61 | 62 | rv = intersection.area / union.area 63 | assert 0 <= rv <= 1 64 | 65 | return rv 66 | 67 | 68 | def make_index(shapes): 69 | """Creates an index for fast and efficient spatial queries. 70 | 71 | Args: 72 | shapes: shapely shapes to bulk-insert bounding boxes for into the spatial index. 73 | 74 | Returns: 75 | The spatial index created from the shape's bounding boxes. 76 | """ 77 | 78 | # Todo: benchmark these for our use-cases 79 | prop = Property() 80 | prop.dimension = 2 81 | prop.leaf_capacity = 1000 82 | prop.fill_factor = 0.9 83 | 84 | def bounded(): 85 | for i, shape in enumerate(shapes): 86 | yield (i, shape.bounds, None) 87 | 88 | return Index(bounded(), properties=prop) 89 | -------------------------------------------------------------------------------- /robosat_pink/tools/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/geocompass/robosat_geoc/44d95183322ba68c0728de44a66e50f510bfd919/robosat_pink/tools/__init__.py -------------------------------------------------------------------------------- /robosat_pink/tools/__main__.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import argparse 4 | 5 | import glob 6 | import shutil 7 | from importlib import import_module 8 | 9 | 10 | def main(): 11 | 12 | if not sys.version_info >= (3, 6): 13 | sys.exit("ERROR: rsp needs Python 3.6 or later.") 14 | 15 | if not len(sys.argv) > 1: 16 | print("rsp: RoboSat.pink command line tools") 17 | print("") 18 | print("Usages:") 19 | print("rsp -h, --help show tools availables") 20 | print("rsp -h, --help show options availables 
for a tool") 21 | print("rsp [...] launch an rsp tool command") 22 | sys.exit() 23 | 24 | tools = [os.path.basename(tool)[:-3] for tool in glob.glob(os.path.join(os.path.dirname(__file__), "[a-z]*.py"))] 25 | tools = [sys.argv[1]] if sys.argv[1] in tools else tools 26 | 27 | os.environ["COLUMNS"] = str(shutil.get_terminal_size().columns) # cf https://bugs.python.org/issue13041 28 | fc = lambda prog: argparse.RawTextHelpFormatter(prog, max_help_position=40, indent_increment=1) # noqa: E731 29 | for i, arg in enumerate(sys.argv): # handle negative values cf #64 30 | if (arg[0] == "-") and arg[1].isdigit(): 31 | sys.argv[i] = " " + arg 32 | parser = argparse.ArgumentParser(prog="rsp", formatter_class=fc) 33 | subparser = parser.add_subparsers(title="RoboSat.pink tools", metavar="") 34 | 35 | for tool in tools: 36 | module = import_module("robosat_pink.tools.{}".format(tool)) 37 | module.add_parser(subparser, formatter_class=fc) 38 | 39 | args = parser.parse_args() 40 | 41 | if "RSP_DEBUG" in os.environ and os.environ["RSP_DEBUG"] == "1": 42 | args.func(args) 43 | 44 | else: 45 | 46 | try: 47 | args.func(args) 48 | except (Exception) as error: 49 | sys.exit("{}ERROR: {}".format(os.linesep, error)) 50 | -------------------------------------------------------------------------------- /robosat_pink/tools/download.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import time 4 | import math 5 | import concurrent.futures as futures 6 | 7 | import requests 8 | from tqdm import tqdm 9 | from mercantile import xy_bounds 10 | 11 | from robosat_pink.core import web_ui, Logs 12 | from robosat_pink.tiles import tiles_from_csv, tile_image_from_url, tile_image_to_file 13 | 14 | 15 | def add_parser(subparser, formatter_class): 16 | parser = subparser.add_parser( 17 | "download", help="Downloads tiles from a remote server (XYZ or WMS)", formatter_class=formatter_class 18 | ) 19 | 20 | ws = 
parser.add_argument_group("Web Server") 21 | ws.add_argument("url", type=str, help="URL server endpoint, with: {z}/{x}/{y} or {xmin},{ymin},{xmax},{ymax} [required]") 22 | ws.add_argument("--type", type=str, default="XYZ", choices=["XYZ", "WMS"], help="service type [default: XYZ]") 23 | ws.add_argument("--rate", type=int, default=10, help="download rate limit in max requests/seconds [default: 10]") 24 | ws.add_argument("--timeout", type=int, default=10, help="download request timeout (in seconds) [default: 10]") 25 | ws.add_argument("--workers", type=int, help="number of workers [default: CPU / 2]") 26 | 27 | cover = parser.add_argument_group("Coverage to download") 28 | cover.add_argument("cover", type=str, help="path to .csv tiles list [required]") 29 | 30 | out = parser.add_argument_group("Output") 31 | out.add_argument("--format", type=str, default="webp", help="file format to save images in [default: webp]") 32 | out.add_argument("out", type=str, help="output directory path [required]") 33 | 34 | ui = parser.add_argument_group("Web UI") 35 | ui.add_argument("--web_ui_base_url", type=str, help="alternate Web UI base URL") 36 | ui.add_argument("--web_ui_template", type=str, help="alternate Web UI template path") 37 | ui.add_argument("--no_web_ui", action="store_true", help="desactivate Web UI output") 38 | 39 | parser.set_defaults(func=main) 40 | 41 | 42 | def main(args): 43 | 44 | tiles = list(tiles_from_csv(args.cover)) 45 | os.makedirs(os.path.expanduser(args.out), exist_ok=True) 46 | 47 | if not args.workers: 48 | args.workers = max(1, math.floor(os.cpu_count() * 0.5)) 49 | 50 | log = Logs(os.path.join(args.out, "log"), out=sys.stderr) 51 | log.log("RoboSat.pink - download with {} workers, at max {} req/s, from: {}".format(args.workers, args.rate, args.url)) 52 | 53 | already_dl = 0 54 | dl = 0 55 | 56 | with requests.Session() as session: 57 | 58 | progress = tqdm(total=len(tiles), ascii=True, unit="image") 59 | with futures.ThreadPoolExecutor(args.workers) 
as executor: 60 | 61 | def worker(tile): 62 | tick = time.monotonic() 63 | progress.update() 64 | 65 | try: 66 | x, y, z = map(str, [tile.x, tile.y, tile.z]) 67 | os.makedirs(os.path.join(args.out, z, x), exist_ok=True) 68 | except: 69 | return tile, None, False 70 | 71 | path = os.path.join(args.out, z, x, "{}.{}".format(y, args.format)) 72 | if os.path.isfile(path): # already downloaded 73 | return tile, None, True 74 | 75 | if args.type == "XYZ": 76 | url = args.url.format(x=tile.x, y=tile.y, z=tile.z) 77 | elif args.type == "WMS": 78 | xmin, ymin, xmax, ymax = xy_bounds(tile) 79 | url = args.url.format(xmin=xmin, ymin=ymin, xmax=xmax, ymax=ymax) 80 | 81 | res = tile_image_from_url(session, url, args.timeout) 82 | if res is None: # let's retry once 83 | res = tile_image_from_url(session, url, args.timeout) 84 | if res is None: 85 | return tile, url, False 86 | 87 | try: 88 | tile_image_to_file(args.out, tile, res) 89 | except OSError: 90 | return tile, url, False 91 | 92 | tock = time.monotonic() 93 | 94 | time_for_req = tock - tick 95 | time_per_worker = args.workers / args.rate 96 | 97 | if time_for_req < time_per_worker: 98 | time.sleep(time_per_worker - time_for_req) 99 | 100 | return tile, url, True 101 | 102 | for tile, url, ok in executor.map(worker, tiles): 103 | if url and ok: 104 | dl += 1 105 | elif not url and ok: 106 | already_dl += 1 107 | else: 108 | log.log("Warning:\n {} failed, skipping.\n {}\n".format(tile, url)) 109 | 110 | if already_dl: 111 | log.log("Notice: {} tiles were already downloaded previously, and so skipped now.".format(already_dl)) 112 | if already_dl + dl == len(tiles): 113 | log.log("Notice: Coverage is fully downloaded.") 114 | 115 | if not args.no_web_ui: 116 | template = "leaflet.html" if not args.web_ui_template else args.web_ui_template 117 | base_url = args.web_ui_base_url if args.web_ui_base_url else "." 
118 | web_ui(args.out, base_url, tiles, tiles, args.format, template) 119 | -------------------------------------------------------------------------------- /robosat_pink/tools/export.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import uuid 3 | import torch 4 | import torch.onnx 5 | import torch.autograd 6 | 7 | import robosat_pink as rsp 8 | from robosat_pink.core import load_module 9 | 10 | 11 | def add_parser(subparser, formatter_class): 12 | parser = subparser.add_parser("export", help="Export a model to ONNX or Torch JIT", formatter_class=formatter_class) 13 | 14 | inp = parser.add_argument_group("Inputs") 15 | inp.add_argument("--checkpoint", type=str, required=True, help="model checkpoint to load [required]") 16 | inp.add_argument("--type", type=str, choices=["onnx", "jit", "pth"], default="onnx", help="output type [default: onnx]") 17 | 18 | pth = parser.add_argument_group("To set or override metadata pth parameters:") 19 | pth.add_argument("--nn", type=str, help="nn name") 20 | pth.add_argument("--loader", type=str, help="nn loader") 21 | pth.add_argument("--doc_string", type=str, help="nn documentation abstract") 22 | pth.add_argument("--shape_in", type=str, help="nn shape in (e.g 3,512,512)") 23 | pth.add_argument("--shape_out", type=str, help="nn shape_out (e.g 2,512,512)") 24 | 25 | out = parser.add_argument_group("Output") 26 | out.add_argument("out", type=str, help="path to save export model to [required]") 27 | 28 | parser.set_defaults(func=main) 29 | 30 | 31 | def main(args): 32 | 33 | chkpt = torch.load(args.checkpoint, map_location=torch.device("cpu")) 34 | UUID = chkpt["uuid"] if "uuid" in chkpt else uuid.uuid1() 35 | 36 | try: 37 | nn_name = chkpt["nn"] 38 | except: 39 | assert args.nn, "--nn mandatory as not already in input .pth" 40 | nn_name = args.nn 41 | 42 | try: 43 | loader = chkpt["loader"] 44 | except: 45 | assert args.loader, "--loader mandatory as not already in input .pth" 
46 | doc_string = args.doc_string 47 | 48 | try: 49 | doc_string = chkpt["doc_string"] 50 | except: 51 | assert args.doc_string, "--doc_string mandatory as not already in input .pth" 52 | doc_string = args.doc_string 53 | 54 | try: 55 | shape_in = chkpt["shape_in"] 56 | except: 57 | assert args.shape_in, "--shape_in mandatory as not already in input .pth" 58 | shape_in = tuple(map(int, args.shape_in.split(","))) 59 | 60 | try: 61 | shape_out = chkpt["shape_out"] 62 | except: 63 | assert args.shape_out, "--shape_out mandatory as not already in input .pth" 64 | shape_out = tuple(map(int, args.shape_out.split(","))) 65 | 66 | model_module = load_module("robosat_pink.models.{}".format(nn_name.lower())) 67 | nn = getattr(model_module, nn_name)(shape_in, shape_out).to("cpu") 68 | 69 | print("RoboSat.pink - export model to {}".format(args.type), file=sys.stderr) 70 | print("Model: {}".format(nn_name, file=sys.stderr)) 71 | print("UUID: {}".format(UUID, file=sys.stderr)) 72 | 73 | if args.type == "pth": 74 | 75 | states = { 76 | "uuid": UUID, 77 | "model_version": None, 78 | "producer_name": "RoboSat.pink", 79 | "producer_version": rsp.__version__, 80 | "model_licence": "MIT", 81 | "domain": "pink.RoboSat", # reverse-DNS 82 | "doc_string": doc_string, 83 | "shape_in": shape_in, 84 | "shape_out": shape_out, 85 | "state_dict": nn.state_dict(), 86 | "epoch": 0, 87 | "nn": nn_name, 88 | "optimizer": None, 89 | "loader": loader, 90 | } 91 | 92 | torch.save(states, args.out) 93 | 94 | else: 95 | 96 | try: # https://github.com/pytorch/pytorch/issues/9176 97 | nn.module.state_dict(chkpt["state_dict"]) 98 | except AttributeError: 99 | nn.state_dict(chkpt["state_dict"]) 100 | 101 | nn.eval() 102 | 103 | batch = torch.rand(1, *shape_in) 104 | 105 | if args.type == "onnx": 106 | torch.onnx.export( 107 | nn, 108 | torch.autograd.Variable(batch), 109 | args.out, 110 | input_names=["input", "shape_in", "shape_out"], 111 | output_names=["output"], 112 | dynamic_axes={"input": {0: 
"num_batch"}, "output": {0: "num_batch"}}, 113 | ) 114 | 115 | if args.type == "jit": 116 | torch.jit.trace(nn, batch).save(args.out) 117 | -------------------------------------------------------------------------------- /robosat_pink/tools/extract.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | 4 | from robosat_pink.core import load_module 5 | 6 | 7 | def add_parser(subparser, formatter_class): 8 | parser = subparser.add_parser("extract", help="Extracts GeoJSON features from OSM .pbf", formatter_class=formatter_class) 9 | 10 | inp = parser.add_argument_group("Inputs") 11 | inp.add_argument("--type", type=str, required=True, help="type of feature to extract (e.g Building, Road) [required]") 12 | inp.add_argument("pbf", type=str, help="path to .osm.pbf file [required]") 13 | 14 | out = parser.add_argument_group("Output") 15 | out.add_argument("out", type=str, help="GeoJSON output file path [required]") 16 | 17 | parser.set_defaults(func=main) 18 | 19 | 20 | def main(args): 21 | 22 | print("RoboSat.pink - extract {} from {}. 
Could take time.".format(args.type, args.pbf), file=sys.stderr, flush=True) 23 | 24 | module = load_module("robosat_pink.osm.{}".format(args.type.lower())) 25 | osmium_handler = getattr(module, "{}Handler".format(args.type))() 26 | osmium_handler.apply_file(filename=os.path.expanduser(args.pbf), locations=True) 27 | osmium_handler.save(os.path.expanduser(args.out)) 28 | -------------------------------------------------------------------------------- /robosat_pink/tools/features.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | import numpy as np 4 | 5 | from PIL import Image 6 | from tqdm import tqdm 7 | 8 | from robosat_pink.tiles import tiles_from_slippy_map 9 | # from robosat.config import load_config 10 | from robosat_pink.core import load_config 11 | 12 | from robosat.features.parking import ParkingHandler 13 | 14 | 15 | # Register post-processing handlers here; they need to support a `apply(tile, mask)` function 16 | # for handling one mask and a `save(path)` function for GeoJSON serialization to a file. 
17 | handlers = {"parking": ParkingHandler} 18 | 19 | 20 | def add_parser(subparser): 21 | parser = subparser.add_parser( 22 | "features", 23 | help="extracts simplified GeoJSON features from segmentation masks", 24 | formatter_class=argparse.ArgumentDefaultsHelpFormatter, 25 | ) 26 | 27 | parser.add_argument("masks", type=str, 28 | help="slippy map directory with segmentation masks") 29 | parser.add_argument("--type", type=str, required=True, 30 | choices=handlers.keys(), help="type of feature to extract") 31 | parser.add_argument("--dataset", type=str, required=True, 32 | help="path to dataset configuration file") 33 | parser.add_argument( 34 | "out", type=str, help="path to GeoJSON file to store features in") 35 | 36 | parser.set_defaults(func=main) 37 | 38 | 39 | def main(args): 40 | dataset = load_config(args.dataset) 41 | 42 | labels = dataset["common"]["classes"] 43 | assert set(labels).issuperset( 44 | set(handlers.keys())), "handlers have a class label" 45 | index = labels.index(args.type) 46 | 47 | handler = handlers[args.type]() 48 | 49 | tiles = list(tiles_from_slippy_map(args.masks)) 50 | 51 | for tile, path in tqdm(tiles, ascii=True, unit="mask"): 52 | image = np.array(Image.open(path).convert("P"), dtype=np.uint8) 53 | mask = (image == index).astype(np.uint8) 54 | 55 | handler.apply(tile, mask) 56 | 57 | handler.save(args.out) 58 | -------------------------------------------------------------------------------- /robosat_pink/tools/info.py: -------------------------------------------------------------------------------- 1 | import os 2 | import re 3 | import sys 4 | import cv2 5 | import torch 6 | import rasterio 7 | import robosat_pink as rsp 8 | 9 | 10 | def add_parser(subparser, formatter_class): 11 | epilog = "Usages:\nTo kill GPU processes: rsp info --processes | xargs sudo kill -9" 12 | parser = subparser.add_parser("info", help="Provide informations", formatter_class=formatter_class, epilog=epilog) 13 | parser.add_argument("--processes", 
action="store_true", help="if set, output GPU processes list") 14 | parser.set_defaults(func=main) 15 | 16 | 17 | def main(args): 18 | 19 | if args.processes: 20 | devices = os.getenv("CUDA_VISIBLE_DEVICES") 21 | assert devices, "CUDA_VISIBLE_DEVICES not set." 22 | pids = set() 23 | for i in devices.split(","): 24 | lsof = os.popen("lsof /dev/nvidia{}".format(i)).read() 25 | for row in re.sub("( )+", "|", lsof).split("\n"): 26 | try: 27 | pid = row.split("|")[1] 28 | pids.add(int(pid)) 29 | except: 30 | continue 31 | 32 | for pid in sorted(pids): 33 | print("{} ".format(pid), end="") 34 | 35 | sys.exit() 36 | 37 | print("========================================") 38 | print("RoboSat.pink: " + rsp.__version__) 39 | print("========================================") 40 | print("Python " + sys.version[:5]) 41 | print("Torch " + torch.__version__) 42 | print("OpenCV " + cv2.__version__) 43 | print("GDAL " + rasterio._base.gdal_version()) 44 | print("Cuda " + torch.version.cuda) 45 | print("Cudnn " + str(torch.backends.cudnn.version())) 46 | print("========================================") 47 | print("CPUs " + str(os.cpu_count())) 48 | print("GPUs " + str(torch.cuda.device_count())) 49 | for i in range(torch.cuda.device_count()): 50 | print(" - " + torch.cuda.get_device_name(i)) 51 | print("========================================") 52 | -------------------------------------------------------------------------------- /robosat_pink/tools/merge.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import argparse 3 | 4 | import geojson 5 | 6 | from tqdm import tqdm 7 | import shapely.geometry 8 | 9 | from robosat_pink.spatial.core import make_index, project, union 10 | from robosat_pink.graph.core import UndirectedGraph 11 | from app.config import setting as SETTING 12 | 13 | 14 | def add_parser(subparser): 15 | parser = subparser.add_parser( 16 | "merge", help="merged adjacent GeoJSON features", 
formatter_class=argparse.ArgumentDefaultsHelpFormatter 17 | ) 18 | 19 | parser.add_argument("features", type=str, 20 | help="GeoJSON file to read features from") 21 | parser.add_argument("--threshold", type=int, required=True, 22 | help="minimum distance to adjacent features, in m") 23 | parser.add_argument( 24 | "out", type=str, help="path to GeoJSON to save merged features to") 25 | 26 | parser.set_defaults(func=main) 27 | 28 | 29 | def main(args): 30 | with open(args.features) as fp: 31 | collection = geojson.load(fp) 32 | 33 | shapes = [shapely.geometry.shape(feature["geometry"]) 34 | for feature in collection["features"]] 35 | del collection 36 | 37 | graph = UndirectedGraph() 38 | idx = make_index(shapes) 39 | 40 | def buffered(shape): 41 | projected = project(shape, "epsg:4326", "epsg:3857") 42 | buffered = projected.buffer(args.threshold) 43 | unprojected = project(buffered, "epsg:3857", "epsg:4326") 44 | return unprojected 45 | 46 | def unbuffered(shape): 47 | projected = project(shape, "epsg:4326", "epsg:3857") 48 | # if int(round(projected.area)) < 100: 49 | # return None 50 | unbuffered = projected.buffer(-1 * args.threshold) 51 | unprojected = project(unbuffered, "epsg:3857", "epsg:4326") 52 | return unprojected 53 | 54 | for i, shape in enumerate(tqdm(shapes, desc="Building graph", unit="shapes", ascii=True)): 55 | embiggened = buffered(shape) 56 | 57 | graph.add_edge(i, i) 58 | 59 | nearest = [j for j in idx.intersection( 60 | embiggened.bounds, objects=False) if i != j] 61 | 62 | for t in nearest: 63 | if embiggened.intersects(shapes[t]): 64 | graph.add_edge(i, t) 65 | 66 | components = list(graph.components()) 67 | assert sum([len(v) for v in components]) == len( 68 | shapes), "components capture all shape indices" 69 | 70 | features = [] 71 | 72 | for component in tqdm(components, desc="Merging components", unit="component", ascii=True): 73 | embiggened = [buffered(shapes[v]) for v in component] 74 | merged = unbuffered(union(embiggened)) 75 | if 
not merged: 76 | continue 77 | if merged.is_valid: 78 | # Orient exterior ring of the polygon in counter-clockwise direction. 79 | if isinstance(merged, shapely.geometry.polygon.Polygon): 80 | merged = shapely.geometry.polygon.orient(merged, sign=1.0) 81 | elif isinstance(merged, shapely.geometry.multipolygon.MultiPolygon): 82 | merged = [shapely.geometry.polygon.orient( 83 | geom, sign=1.0) for geom in merged.geoms] 84 | merged = shapely.geometry.MultiPolygon(merged) 85 | else: 86 | print( 87 | "Warning: merged feature is neither Polygon nor MultiPoylgon, skipping", file=sys.stderr) 88 | continue 89 | 90 | # equal-area projection; round to full m^2, we're not that precise anyway 91 | area = int(round(project(merged, "epsg:4326", "epsg:3857").area)) 92 | if area < SETTING.MIN_BUILDING_AREA: 93 | continue 94 | 95 | feature = geojson.Feature(geometry=shapely.geometry.mapping( 96 | merged)) # , properties={"area": area} 97 | features.append(feature) 98 | else: 99 | print("Warning: merged feature is not valid, skipping", 100 | file=sys.stderr) 101 | 102 | collection = geojson.FeatureCollection(features) 103 | 104 | with open(args.out, "w") as fp: 105 | geojson.dump(collection, fp) 106 | -------------------------------------------------------------------------------- /robosat_pink/tools/predict.py: -------------------------------------------------------------------------------- 1 | import os 2 | from tqdm import tqdm 3 | 4 | import numpy as np 5 | import mercantile 6 | 7 | import torch 8 | import torch.backends.cudnn 9 | from torch.utils.data import DataLoader 10 | 11 | from robosat_pink.core import load_config, load_module, check_classes, check_channels, make_palette, web_ui, Logs 12 | from robosat_pink.tiles import tiles_from_dir, tile_label_to_file 13 | 14 | 15 | def add_parser(subparser, formatter_class): 16 | parser = subparser.add_parser( 17 | "predict", help="Predict masks, from given inputs and an already trained model", formatter_class=formatter_class 18 | ) 19 
| 20 | inp = parser.add_argument_group("Inputs") 21 | inp.add_argument("dataset", type=str, help="predict dataset directory path [required]") 22 | inp.add_argument("--checkpoint", type=str, required=True, help="path to the trained model to use [required]") 23 | inp.add_argument("--config", type=str, help="path to config file [required]") 24 | 25 | out = parser.add_argument_group("Outputs") 26 | out.add_argument("out", type=str, help="output directory path [required]") 27 | 28 | perf = parser.add_argument_group("Data Loaders") 29 | perf.add_argument("--workers", type=int, help="number of workers to load images [default: GPU x 2]") 30 | perf.add_argument("--bs", type=int, default=4, help="batch size value for data loader [default: 4]") 31 | 32 | ui = parser.add_argument_group("Web UI") 33 | ui.add_argument("--web_ui_base_url", type=str, help="alternate Web UI base URL") 34 | ui.add_argument("--web_ui_template", type=str, help="alternate Web UI template path") 35 | ui.add_argument("--no_web_ui", action="store_true", help="desactivate Web UI output") 36 | 37 | parser.set_defaults(func=main) 38 | 39 | 40 | def main(args): 41 | config = load_config(args.config) 42 | check_channels(config) 43 | check_classes(config) 44 | palette = make_palette([classe["color"] for classe in config["classes"]]) 45 | args.workers = torch.cuda.device_count() * 2 if torch.device("cuda") and not args.workers else args.workers 46 | 47 | log = Logs(os.path.join(args.out, "log")) 48 | 49 | if torch.cuda.is_available(): 50 | log.log("RoboSat.pink - predict on {} GPUs, with {} workers".format(torch.cuda.device_count(), args.workers)) 51 | log.log("(Torch:{} Cuda:{} CudNN:{})".format(torch.__version__, torch.version.cuda, torch.backends.cudnn.version())) 52 | device = torch.device("cuda") 53 | torch.backends.cudnn.enabled = True 54 | torch.backends.cudnn.benchmark = True 55 | else: 56 | log.log("RoboSat.pink - predict on CPU, with {} workers".format(args.workers)) 57 | device = torch.device("cpu") 
58 | 59 | chkpt = torch.load(args.checkpoint, map_location=device) 60 | model_module = load_module("robosat_pink.models.{}".format(chkpt["nn"].lower())) 61 | nn = getattr(model_module, chkpt["nn"])(chkpt["shape_in"], chkpt["shape_out"]).to(device) 62 | nn = torch.nn.DataParallel(nn) 63 | nn.load_state_dict(chkpt["state_dict"]) 64 | nn.eval() 65 | 66 | log.log("Model {} - UUID: {}".format(chkpt["nn"], chkpt["uuid"])) 67 | 68 | loader_module = load_module("robosat_pink.loaders.{}".format(chkpt["loader"].lower())) 69 | loader_predict = getattr(loader_module, chkpt["loader"])(config, chkpt["shape_in"][1:3], args.dataset, mode="predict") 70 | 71 | loader = DataLoader(loader_predict, batch_size=args.bs, num_workers=args.workers) 72 | assert len(loader), "Empty predict dataset directory. Check your path." 73 | 74 | with torch.no_grad(): # don't track tensors with autograd during prediction 75 | 76 | for images, tiles in tqdm(loader, desc="Eval", unit="batch", ascii=True): 77 | 78 | images = images.to(device) 79 | 80 | outputs = nn(images) 81 | probs = torch.nn.functional.softmax(outputs, dim=1).data.cpu().numpy() 82 | 83 | for tile, prob in zip(tiles, probs): 84 | x, y, z = list(map(int, tile)) 85 | mask = np.around(prob[1:, :, :]).astype(np.uint8).squeeze() 86 | tile_label_to_file(args.out, mercantile.Tile(x, y, z), palette, mask) 87 | 88 | if not args.no_web_ui: 89 | template = "leaflet.html" if not args.web_ui_template else args.web_ui_template 90 | base_url = args.web_ui_base_url if args.web_ui_base_url else "." 
91 | tiles = [tile for tile in tiles_from_dir(args.out)] 92 | web_ui(args.out, base_url, tiles, tiles, "png", template) 93 | -------------------------------------------------------------------------------- /robosat_pink/tools/rasterize.py: -------------------------------------------------------------------------------- 1 | import os 2 | import re 3 | import sys 4 | import math 5 | import json 6 | import collections 7 | 8 | import numpy as np 9 | from tqdm import tqdm 10 | 11 | import psycopg2 12 | 13 | from robosat_pink.core import load_config, check_classes, make_palette, web_ui, Logs 14 | from robosat_pink.tiles import tiles_from_csv, tile_label_to_file, tile_bbox 15 | from robosat_pink.geojson import geojson_srid, geojson_tile_burn, geojson_parse_feature 16 | 17 | 18 | def add_parser(subparser, formatter_class): 19 | parser = subparser.add_parser( 20 | "rasterize", help="Rasterize GeoJSON or PostGIS features to tiles", formatter_class=formatter_class 21 | ) 22 | 23 | inp = parser.add_argument_group("Inputs [either --postgis or --geojson is required]") 24 | inp.add_argument("--cover", type=str, help="path to csv tiles cover file [required]") 25 | inp.add_argument("--config", type=str, help="path to config file [required]") 26 | inp.add_argument("--type", type=str, required=True, help="type of feature to rasterize (e.g Building, Road) [required]") 27 | inp.add_argument("--pg", type=str, help="PostgreSQL dsn using psycopg2 syntax (e.g 'dbname=db user=postgres')") 28 | help = "SQL to retrieve geometry features [e.g SELECT geom FROM a_table WHERE ST_Intersects(TILE_GEOM, geom)]" 29 | inp.add_argument("--sql", type=str, help=help) 30 | inp.add_argument("--geojson", type=str, nargs="+", help="path to GeoJSON features files") 31 | 32 | out = parser.add_argument_group("Outputs") 33 | out.add_argument("out", type=str, help="output directory path [required]") 34 | out.add_argument("--append", action="store_true", help="Append to existing tile if any, useful to multiclass 
labels") 35 | out.add_argument("--ts", type=str, default="512,512", help="output tile size [default: 512,512]") 36 | 37 | ui = parser.add_argument_group("Web UI") 38 | ui.add_argument("--web_ui_base_url", type=str, help="alternate Web UI base URL") 39 | ui.add_argument("--web_ui_template", type=str, help="alternate Web UI template path") 40 | ui.add_argument("--no_web_ui", action="store_true", help="desactivate Web UI output") 41 | 42 | parser.set_defaults(func=main) 43 | 44 | 45 | def main(args): 46 | 47 | assert not (args.sql and args.geojson), "You can only use at once --pg OR --geojson." 48 | assert not (args.pg and not args.sql), "With PostgreSQL --pg, --sql must also be provided" 49 | assert len(args.ts.split(",")) == 2, "--ts expect width,height value (e.g 512,512)" 50 | 51 | config = load_config(args.config) 52 | check_classes(config) 53 | 54 | palette = make_palette([classe["color"] for classe in config["classes"]], complementary=True) 55 | index = [config["classes"].index(classe) for classe in config["classes"] if classe["title"] == args.type] 56 | assert index, "Requested type is not contains in your config file classes." 57 | burn_value = int(math.pow(2, index[0] - 1)) # 8bits One Hot Encoding 58 | assert 0 <= burn_value <= 128 59 | 60 | args.out = os.path.expanduser(args.out) 61 | os.makedirs(args.out, exist_ok=True) 62 | log = Logs(os.path.join(args.out, "log"), out=sys.stderr) 63 | 64 | if args.geojson: 65 | 66 | tiles = [tile for tile in tiles_from_csv(os.path.expanduser(args.cover))] 67 | assert tiles, "Empty cover" 68 | 69 | zoom = tiles[0].z 70 | assert not [tile for tile in tiles if tile.z != zoom], "Unsupported zoom mixed cover. 
Use PostGIS instead" 71 | 72 | feature_map = collections.defaultdict(list) 73 | 74 | log.log("RoboSat.pink - rasterize - Compute spatial index") 75 | for geojson_file in args.geojson: 76 | 77 | with open(os.path.expanduser(geojson_file)) as geojson: 78 | feature_collection = json.load(geojson) 79 | srid = geojson_srid(feature_collection) 80 | 81 | feature_map = collections.defaultdict(list) 82 | 83 | for i, feature in enumerate(tqdm(feature_collection["features"], ascii=True, unit="feature")): 84 | feature_map = geojson_parse_feature(zoom, srid, feature_map, feature) 85 | 86 | features = args.geojson 87 | 88 | if args.pg: 89 | 90 | conn = psycopg2.connect(args.pg) 91 | db = conn.cursor() 92 | 93 | assert "limit" not in args.sql.lower(), "LIMIT is not supported" 94 | assert "TILE_GEOM" in args.sql, "TILE_GEOM filter not found in your SQL" 95 | sql = re.sub(r"ST_Intersects( )*\((.*)?TILE_GEOM(.*)?\)", "1=1", args.sql, re.I) 96 | assert sql and sql != args.sql 97 | 98 | db.execute("""SELECT ST_Srid("1") AS srid FROM ({} LIMIT 1) AS t("1")""".format(sql)) 99 | srid = db.fetchone()[0] 100 | assert srid and int(srid) > 0, "Unable to retrieve geometry SRID." 
101 | 102 | features = args.sql 103 | 104 | log.log("RoboSat.pink - rasterize - rasterizing {} from {} on cover {}".format(args.type, features, args.cover)) 105 | with open(os.path.join(os.path.expanduser(args.out), "instances_" + args.type.lower() + ".cover"), mode="w") as cover: 106 | 107 | for tile in tqdm(list(tiles_from_csv(os.path.expanduser(args.cover))), ascii=True, unit="tile"): 108 | 109 | geojson = None 110 | 111 | if args.pg: 112 | 113 | w, s, e, n = tile_bbox(tile) 114 | tile_geom = "ST_Transform(ST_MakeEnvelope({},{},{},{}, 4326), {})".format(w, s, e, n, srid) 115 | 116 | query = """ 117 | WITH 118 | sql AS ({}), 119 | geom AS (SELECT "1" AS geom FROM sql AS t("1")), 120 | json AS (SELECT '{{"type": "Feature", "geometry": ' 121 | || ST_AsGeoJSON((ST_Dump(ST_Transform(ST_Force2D(geom.geom), 4326))).geom, 6) 122 | || '}}' AS features 123 | FROM geom) 124 | SELECT '{{"type": "FeatureCollection", "features": [' || Array_To_String(array_agg(features), ',') || ']}}' 125 | FROM json 126 | """.format( 127 | args.sql.replace("TILE_GEOM", tile_geom) 128 | ) 129 | 130 | db.execute(query) 131 | row = db.fetchone() 132 | try: 133 | geojson = json.loads(row[0])["features"] if row and row[0] else None 134 | except Exception: 135 | log.log("Warning: Invalid geometries, skipping {}".format(tile)) 136 | conn = psycopg2.connect(args.pg) 137 | db = conn.cursor() 138 | 139 | if args.geojson: 140 | geojson = feature_map[tile] if tile in feature_map else None 141 | 142 | if geojson: 143 | num = len(geojson) 144 | out = geojson_tile_burn(tile, geojson, 4326, list(map(int, args.ts.split(","))), burn_value) 145 | 146 | if not geojson or out is None: 147 | num = 0 148 | out = np.zeros(shape=list(map(int, args.ts.split(","))), dtype=np.uint8) 149 | 150 | tile_label_to_file(args.out, tile, palette, out, append=args.append) 151 | cover.write("{},{},{} {}{}".format(tile.x, tile.y, tile.z, num, os.linesep)) 152 | 153 | if not args.no_web_ui: 154 | template = "leaflet.html" if not 
args.web_ui_template else args.web_ui_template 155 | base_url = args.web_ui_base_url if args.web_ui_base_url else "." 156 | tiles = [tile for tile in tiles_from_csv(args.cover)] 157 | web_ui(args.out, base_url, tiles, tiles, "png", template) 158 | -------------------------------------------------------------------------------- /robosat_pink/tools/subset.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import shutil 4 | 5 | import mercantile 6 | from tqdm import tqdm 7 | 8 | from robosat_pink.tiles import tiles_from_csv, tile_from_xyz 9 | from robosat_pink.core import web_ui 10 | 11 | 12 | def add_parser(subparser, formatter_class): 13 | parser = subparser.add_parser( 14 | "subset", help="Filter images in a slippy map dir using a csv tiles cover", formatter_class=formatter_class 15 | ) 16 | inp = parser.add_argument_group("Inputs") 17 | inp.add_argument("--dir", type=str, required=True, help="to XYZ tiles input dir path [required]") 18 | inp.add_argument("--cover", type=str, required=True, help="path to csv cover file to filter dir by [required]") 19 | 20 | mode = parser.add_argument_group("Alternate modes, as default is to create relative symlinks.") 21 | mode.add_argument("--copy", action="store_true", help="copy tiles from input to output") 22 | mode.add_argument("--delete", action="store_true", help="delete tiles listed in cover") 23 | 24 | out = parser.add_argument_group("Output") 25 | out.add_argument("out", type=str, nargs="?", default=os.getcwd(), help="output dir path [required for copy or move]") 26 | 27 | ui = parser.add_argument_group("Web UI") 28 | ui.add_argument("--web_ui_base_url", type=str, help="alternate Web UI base URL") 29 | ui.add_argument("--web_ui_template", type=str, help="alternate Web UI template path") 30 | ui.add_argument("--no_web_ui", action="store_true", help="desactivate Web UI output") 31 | 32 | parser.set_defaults(func=main) 33 | 34 | 35 | def main(args): 36 | 
assert args.out or args.delete, "out parameter is required" 37 | args.out = os.path.expanduser(args.out) 38 | 39 | print("RoboSat.pink - subset {} with cover {}, on CPU".format(args.dir, args.cover), file=sys.stderr, flush=True) 40 | 41 | ext = set() 42 | tiles = set(tiles_from_csv(os.path.expanduser(args.cover))) 43 | 44 | for tile in tqdm(tiles, ascii=True, unit="tiles"): 45 | 46 | if isinstance(tile, mercantile.Tile): 47 | src_tile = tile_from_xyz(args.dir, tile.x, tile.y, tile.z) 48 | if not src_tile: 49 | print("WARNING: skipping tile {}".format(tile), file=sys.stderr, flush=True) 50 | continue 51 | _, src = src_tile 52 | dst_dir = os.path.join(args.out, str(tile.z), str(tile.x)) 53 | else: 54 | src = tile 55 | dst_dir = os.path.join(args.out, os.path.dirname(tile)) 56 | 57 | assert os.path.isfile(src) 58 | dst = os.path.join(dst_dir, os.path.basename(src)) 59 | ext.add(os.path.splitext(src)[1][1:]) 60 | 61 | if not os.path.isdir(dst_dir): 62 | os.makedirs(dst_dir, exist_ok=True) 63 | 64 | if args.delete: 65 | os.remove(src) 66 | assert not os.path.lexists(src) 67 | elif args.copy: 68 | shutil.copyfile(src, dst) 69 | assert os.path.exists(dst) 70 | else: 71 | if os.path.islink(dst): 72 | os.remove(dst) 73 | os.symlink(os.path.relpath(src, os.path.dirname(dst)), dst) 74 | assert os.path.islink(dst) 75 | 76 | if tiles and not args.no_web_ui and not args.delete: 77 | assert len(ext) == 1, "ERROR: Mixed extensions, can't generate Web UI" 78 | template = "leaflet.html" if not args.web_ui_template else args.web_ui_template 79 | base_url = args.web_ui_base_url if args.web_ui_base_url else "." 
80 | web_ui(args.out, base_url, tiles, tiles, list(ext)[0], template) 81 | -------------------------------------------------------------------------------- /robosat_pink/tools/vectorize.py: -------------------------------------------------------------------------------- 1 | import sys 2 | from tqdm import tqdm 3 | 4 | import numpy as np 5 | from PIL import Image 6 | 7 | import json 8 | import mercantile 9 | import rasterio.features 10 | import rasterio.transform 11 | 12 | from robosat_pink.core import load_config, check_classes 13 | from robosat_pink.tiles import tiles_from_dir 14 | 15 | 16 | def add_parser(subparser, formatter_class): 17 | parser = subparser.add_parser("vectorize", help="Extract GeoJSON from tiles masks", formatter_class=formatter_class) 18 | 19 | inp = parser.add_argument_group("Inputs") 20 | inp.add_argument("masks", type=str, help="input masks directory path [required]") 21 | inp.add_argument("--type", type=str, required=True, help="type of features to extract (i.e class title) [required]") 22 | inp.add_argument("--config", type=str, help="path to config file [required]") 23 | 24 | out = parser.add_argument_group("Outputs") 25 | out.add_argument("out", type=str, help="path to output file to store features in [required]") 26 | 27 | parser.set_defaults(func=main) 28 | 29 | 30 | def main(args): 31 | config = load_config(args.config) 32 | check_classes(config) 33 | index = [i for i in (list(range(len(config["classes"])))) if config["classes"][i]["title"] == args.type] 34 | assert index, "Requested type {} not found among classes title in the config file.".format(args.type) 35 | print("RoboSat.pink - vectorize {} from {}".format(args.type, args.masks), file=sys.stderr, flush=True) 36 | 37 | out = open(args.out, "w", encoding="utf-8") 38 | assert out, "Unable to write in output file" 39 | 40 | out.write('{"type":"FeatureCollection","features":[') 41 | 42 | first = True 43 | for tile, path in tqdm(list(tiles_from_dir(args.masks, xyz_path=True)), 
ascii=True, unit="mask"): 44 | mask = (np.array(Image.open(path).convert("P"), dtype=np.uint8) == index).astype(np.uint8) 45 | try: 46 | C, W, H = mask.shape 47 | except: 48 | W, H = mask.shape 49 | transform = rasterio.transform.from_bounds((*mercantile.bounds(tile.x, tile.y, tile.z)), W, H) 50 | 51 | for shape, value in rasterio.features.shapes(mask, transform=transform, mask=mask): 52 | geom = '"geometry":{{"type": "Polygon", "coordinates":{}}}'.format(json.dumps(shape["coordinates"])) 53 | out.write('{}{{"type":"Feature",{}}}'.format("" if first else ",", geom)) 54 | first = False 55 | 56 | out.write("]}") 57 | -------------------------------------------------------------------------------- /robosat_pink/web_ui/compare.html: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | RoboSat.pink Compare WebUI 5 | 6 | 7 | 8 | 9 | 10 |
11 |
12 |

RoboSat.pink Compare Side

13 |

Shift   select or unselect, the current image.

14 |

Esc     copy selected images list, to clipboard.

15 |

h         hide or display, this help message.

16 |

        previous image to compare, if any.

17 |

        next image to compare, if any.

18 |
19 | 22 |
23 | 65 | 66 | 67 | -------------------------------------------------------------------------------- /robosat_pink/web_ui/leaflet.html: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | RoboSat.pink Leaflet WebUI 5 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | 13 |
14 | 30 | 31 | 32 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | """ 2 | RoboSat_geoc 3 | ---------- 4 | Computer Vision framework for GeoSpatial Imagery 5 | 6 | """ 7 | 8 | from setuptools import setup, find_packages 9 | from os import path 10 | import re 11 | 12 | here = path.dirname(__file__) 13 | 14 | with open(path.join(here, "robosat_pink", "__init__.py"), encoding="utf-8") as f: 15 | version = re.sub("( )*__version__( )*=( )*", "", f.read()).replace('"', "") 16 | 17 | with open(path.join(here, "README.md"), encoding="utf-8") as f: 18 | long_description = f.read() 19 | 20 | with open(path.join(here, "requirements.txt")) as f: 21 | install_requires = f.read().splitlines() 22 | 23 | setup( 24 | name="RoboSat_geoc", 25 | version=version, 26 | url="https://github.com/geocompass/robosat_geoc", 27 | download_url="https://github.com/geocompass/robosat_geoc/releases", 28 | license="MIT", 29 | maintainer="GEO-COMPASS", 30 | maintainer_email="wucangeo@gmail.com", 31 | description="Computer Vision framework for GeoSpatial Imagery", 32 | long_description=long_description, 33 | long_description_content_type="text/markdown", 34 | packages=find_packages( 35 | exclude=["tests", ".idea", ".vscode", "data", "ds"]), 36 | install_requires=install_requires, 37 | entry_points={"console_scripts": [ 38 | "rsp = robosat_pink.tools.__main__:main"]}, 39 | include_package_data=True, 40 | python_requires=">=3.6", 41 | classifiers=[ 42 | "Development Status :: 4 - Beta", 43 | "Environment :: Console", 44 | "Intended Audience :: Science/Research", 45 | "Intended Audience :: Developers", 46 | "Intended Audience :: Information Technology", 47 | "License :: OSI Approved :: MIT License", 48 | "Natural Language :: English", 49 | "Operating System :: POSIX :: Linux", 50 | "Programming Language :: Python :: 3.6", 51 | "Programming Language :: Python :: 3.7", 52 | 
"Topic :: Scientific/Engineering :: Image Recognition", 53 | "Topic :: Scientific/Engineering :: GIS", 54 | "Topic :: Software Development :: Libraries :: Application Frameworks", 55 | ], 56 | ) 57 | -------------------------------------------------------------------------------- /test.py: -------------------------------------------------------------------------------- 1 | from datetime import datetime 2 | import time 3 | import json 4 | from robosat_pink.geoc import RSPtrain, RSPpredict, utils 5 | from app.libs import utils_geom 6 | 7 | 8 | def train(extent, dataPath, dsPath, epochs=10, map="google", auto_delete=False): 9 | return RSPtrain.main(extent, dataPath, dsPath, epochs, map, auto_delete) 10 | 11 | 12 | def predict(extent, dataPath, dsPath, map="google", auto_delete=False): 13 | return RSPpredict.main(extent, dataPath, dsPath, map, auto_delete) 14 | 15 | 16 | if __name__ == "__main__": 17 | # config.toml & checkpoint.pth data directory 18 | # dataPath = "data" 19 | dataPath = "/data/datamodel" 20 | 21 | # training dataset directory 22 | startTime = datetime.now() 23 | ts = time.time() 24 | 25 | # map extent for training or predicting 26 | #extent = "116.286626640306,39.93972566103653,116.29035683687295,39.942521109411445" 27 | #extent = "104.7170 31.5125 104.7834 31.4430"#mianyang 28 | extent = "116.3094,39.9313,116.3114,39.9323" 29 | # extent = "116.2159,39.7963,116.5240,40.0092" 30 | 31 | result = "" 32 | # trainging 33 | # result = train(extent, dataPath, "ds/train_" + str(ts), 1) 34 | 35 | # predicting 36 | result = predict(extent, dataPath, "ds/predict_" + str(ts)) 37 | # print(result) 38 | # geojson 转 shapefile 39 | building_predcit_path = "ds/predict_" + str(ts)+"/building1_predict.shp" 40 | utils_geom.geojson2shp(result, building_predcit_path) 41 | 42 | 43 | endTime = datetime.now() 44 | timeSpend = (endTime-startTime).seconds 45 | print("Training or Predicting DONE!All spends:", timeSpend, "seconds!") 46 | 
-------------------------------------------------------------------------------- /tests/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/geocompass/robosat_geoc/44d95183322ba68c0728de44a66e50f510bfd919/tests/__init__.py -------------------------------------------------------------------------------- /tests/fixtures/images/18/69105/105093.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/geocompass/robosat_geoc/44d95183322ba68c0728de44a66e50f510bfd919/tests/fixtures/images/18/69105/105093.jpg -------------------------------------------------------------------------------- /tests/fixtures/images/18/69108/105091.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/geocompass/robosat_geoc/44d95183322ba68c0728de44a66e50f510bfd919/tests/fixtures/images/18/69108/105091.jpg -------------------------------------------------------------------------------- /tests/fixtures/images/18/69108/105092.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/geocompass/robosat_geoc/44d95183322ba68c0728de44a66e50f510bfd919/tests/fixtures/images/18/69108/105092.jpg -------------------------------------------------------------------------------- /tests/fixtures/labels/18/69105/105093.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/geocompass/robosat_geoc/44d95183322ba68c0728de44a66e50f510bfd919/tests/fixtures/labels/18/69105/105093.png -------------------------------------------------------------------------------- /tests/fixtures/labels/18/69108/105091.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/geocompass/robosat_geoc/44d95183322ba68c0728de44a66e50f510bfd919/tests/fixtures/labels/18/69108/105091.png -------------------------------------------------------------------------------- /tests/fixtures/labels/18/69108/105092.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/geocompass/robosat_geoc/44d95183322ba68c0728de44a66e50f510bfd919/tests/fixtures/labels/18/69108/105092.png -------------------------------------------------------------------------------- /tests/fixtures/osm/18/69105/105093.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/geocompass/robosat_geoc/44d95183322ba68c0728de44a66e50f510bfd919/tests/fixtures/osm/18/69105/105093.png -------------------------------------------------------------------------------- /tests/fixtures/osm/18/69108/105091.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/geocompass/robosat_geoc/44d95183322ba68c0728de44a66e50f510bfd919/tests/fixtures/osm/18/69108/105091.png -------------------------------------------------------------------------------- /tests/fixtures/osm/18/69108/105092.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/geocompass/robosat_geoc/44d95183322ba68c0728de44a66e50f510bfd919/tests/fixtures/osm/18/69108/105092.png -------------------------------------------------------------------------------- /tests/fixtures/parking/features.geojson: -------------------------------------------------------------------------------- 1 | { 2 | "type": "FeatureCollection", 3 | "features": [ 4 | { 5 | "type": "Feature", 6 | "geometry": { 7 | "type": "Polygon", 8 | "coordinates": [ 9 | [ 10 | [ 11 | -82.8224934, 12 | 34.6787452 13 | ], 14 | [ 15 | -82.8216356, 16 | 34.6787385 17 | ], 18 | [ 19 | -82.8215841, 20 | 
34.6778632 21 | ], 22 | [ 23 | -82.8218244, 24 | 34.6775386 25 | ], 26 | [ 27 | -82.8220047, 28 | 34.6773692 29 | ], 30 | [ 31 | -82.8234209, 32 | 34.6773974 33 | ], 34 | [ 35 | -82.8234818, 36 | 34.6774475 37 | ], 38 | [ 39 | -82.8235839, 40 | 34.6775315 41 | ], 42 | [ 43 | -82.8236513, 44 | 34.6781899 45 | ], 46 | [ 47 | -82.8230346, 48 | 34.6784279 49 | ], 50 | [ 51 | -82.8226999, 52 | 34.6785903 53 | ], 54 | [ 55 | -82.8224934, 56 | 34.6787452 57 | ] 58 | ] 59 | ] 60 | }, 61 | "properties": {} 62 | }, 63 | { 64 | "type": "Feature", 65 | "geometry": { 66 | "type": "Polygon", 67 | "coordinates": [ 68 | [ 69 | [ 70 | -106.5503557, 71 | 35.1168049 72 | ], 73 | [ 74 | -106.5503088, 75 | 35.1167621 76 | ], 77 | [ 78 | -106.5501478, 79 | 35.1167522 80 | ], 81 | [ 82 | -106.5500325, 83 | 35.1167511 84 | ], 85 | [ 86 | -106.5500271, 87 | 35.1168959 88 | ], 89 | [ 90 | -106.5500285, 91 | 35.1170813 92 | ], 93 | [ 94 | -106.5500244, 95 | 35.1171098 96 | ], 97 | [ 98 | -106.5499386, 99 | 35.117112 100 | ], 101 | [ 102 | -106.5499476, 103 | 35.117322 104 | ], 105 | [ 106 | -106.5500982, 107 | 35.1173248 108 | ], 109 | [ 110 | -106.5502135, 111 | 35.1174938 112 | ], 113 | [ 114 | -106.5502377, 115 | 35.1175256 116 | ], 117 | [ 118 | -106.5502699, 119 | 35.117541 120 | ], 121 | [ 122 | -106.5504858, 123 | 35.1175453 124 | ], 125 | [ 126 | -106.5506865, 127 | 35.117536 128 | ], 129 | [ 130 | -106.5506741, 131 | 35.1172861 132 | ], 133 | [ 134 | -106.5506729, 135 | 35.1171422 136 | ], 137 | [ 138 | -106.550573, 139 | 35.1171366 140 | ], 141 | [ 142 | -106.5505423, 143 | 35.1170818 144 | ], 145 | [ 146 | -106.5505412, 147 | 35.1170446 148 | ], 149 | [ 150 | -106.5502641, 151 | 35.1170428 152 | ], 153 | [ 154 | -106.55023, 155 | 35.1169657 156 | ], 157 | [ 158 | -106.5502289, 159 | 35.1168654 160 | ], 161 | [ 162 | -106.5503061, 163 | 35.1168412 164 | ], 165 | [ 166 | -106.5503557, 167 | 35.1168049 168 | ] 169 | ] 170 | ] 171 | }, 172 | "properties": {} 173 | } 174 | ] 175 | } 
176 | -------------------------------------------------------------------------------- /tests/fixtures/parking/images/18/69623/104946.webp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/geocompass/robosat_geoc/44d95183322ba68c0728de44a66e50f510bfd919/tests/fixtures/parking/images/18/69623/104946.webp -------------------------------------------------------------------------------- /tests/fixtures/parking/images/18/70761/104120.webp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/geocompass/robosat_geoc/44d95183322ba68c0728de44a66e50f510bfd919/tests/fixtures/parking/images/18/70761/104120.webp -------------------------------------------------------------------------------- /tests/fixtures/parking/images/18/70762/104119.webp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/geocompass/robosat_geoc/44d95183322ba68c0728de44a66e50f510bfd919/tests/fixtures/parking/images/18/70762/104119.webp -------------------------------------------------------------------------------- /tests/fixtures/parking/images/18/70763/104119.webp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/geocompass/robosat_geoc/44d95183322ba68c0728de44a66e50f510bfd919/tests/fixtures/parking/images/18/70763/104119.webp -------------------------------------------------------------------------------- /tests/fixtures/parking/labels/18/69623/104946.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/geocompass/robosat_geoc/44d95183322ba68c0728de44a66e50f510bfd919/tests/fixtures/parking/labels/18/69623/104946.png -------------------------------------------------------------------------------- /tests/fixtures/parking/labels/18/70761/104120.png: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/geocompass/robosat_geoc/44d95183322ba68c0728de44a66e50f510bfd919/tests/fixtures/parking/labels/18/70761/104120.png -------------------------------------------------------------------------------- /tests/fixtures/parking/labels/18/70762/104119.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/geocompass/robosat_geoc/44d95183322ba68c0728de44a66e50f510bfd919/tests/fixtures/parking/labels/18/70762/104119.png -------------------------------------------------------------------------------- /tests/fixtures/parking/labels/18/70763/104119.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/geocompass/robosat_geoc/44d95183322ba68c0728de44a66e50f510bfd919/tests/fixtures/parking/labels/18/70763/104119.png -------------------------------------------------------------------------------- /tests/fixtures/parking/tiles.csv: -------------------------------------------------------------------------------- 1 | 70762,104119,18 2 | 69623,104946,18 3 | 70763,104119,18 4 | 70761,104120,18 5 | -------------------------------------------------------------------------------- /tests/fixtures/tiles.csv: -------------------------------------------------------------------------------- 1 | 69623,104945,18 2 | 69622,104945,18 3 | 69623,104946,18 4 | -------------------------------------------------------------------------------- /tests/loaders/test_semsegtiles.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | 3 | import torch 4 | import mercantile 5 | 6 | from robosat_pink.loaders.semsegtiles import SemSegTiles 7 | 8 | 9 | class TestSemSegTiles(unittest.TestCase): 10 | def test_len(self): 11 | path = "tests/fixtures" 12 | config = { 13 | "channels": [{"name": "images", "bands": [1, 2, 3]}], 
14 | "classes": [{"title": "Building", "color": "deeppink"}], 15 | "model": {"pretrained": True, "da": "Strong", "ts": 512}, 16 | } 17 | 18 | # mode train 19 | dataset = SemSegTiles(config, (512, 512), path, "train") 20 | self.assertEqual(len(dataset), 3) 21 | 22 | # mode predict 23 | dataset = SemSegTiles(config, (512, 512), path, "predict") 24 | self.assertEqual(len(dataset), 3) 25 | 26 | def test_getitem(self): 27 | path = "tests/fixtures" 28 | config = { 29 | "channels": [{"name": "images", "bands": [1, 2, 3]}], 30 | "classes": [{"title": "Building", "color": "deeppink"}], 31 | "model": {"pretrained": True, "da": "Strong", "ts": 512}, 32 | } 33 | 34 | # mode train 35 | dataset = SemSegTiles(config, (512, 512), path, "train") 36 | image, mask, tile = dataset[0] 37 | 38 | assert tile == mercantile.Tile(69105, 105093, 18) 39 | self.assertEqual(image.shape, torch.Size([3, 512, 512])) 40 | 41 | # mode predict 42 | dataset = SemSegTiles(config, (512, 512), path, "predict") 43 | images, tiles = dataset[0] 44 | 45 | self.assertEqual(type(images), torch.Tensor) 46 | x, y, z = tiles.numpy() 47 | self.assertEqual(mercantile.Tile(x, y, z), mercantile.Tile(69105, 105093, 18)) 48 | -------------------------------------------------------------------------------- /tests/test_tiles.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | 3 | import mercantile 4 | 5 | from robosat_pink.tiles import tiles_from_dir, tiles_from_csv 6 | 7 | 8 | class TestSlippyMapTiles(unittest.TestCase): 9 | def test_slippy_map_directory(self): 10 | root = "tests/fixtures/images" 11 | tiles = [(tile, path) for tile, path in tiles_from_dir(root, xyz_path=True)] 12 | tiles.sort() 13 | 14 | self.assertEqual(len(tiles), 3) 15 | 16 | tile, path = tiles[0] 17 | self.assertEqual(type(tile), mercantile.Tile) 18 | self.assertEqual(path, "tests/fixtures/images/18/69105/105093.jpg") 19 | 20 | 21 | class TestReadTiles(unittest.TestCase): 22 | def 
test_read_tiles(self): 23 | filename = "tests/fixtures/tiles.csv" 24 | tiles = [tile for tile in tiles_from_csv(filename)] 25 | tiles.sort() 26 | 27 | self.assertEqual(len(tiles), 3) 28 | self.assertEqual(tiles[1], mercantile.Tile(69623, 104945, 18)) 29 | -------------------------------------------------------------------------------- /tests/tools/test_rasterize.py: -------------------------------------------------------------------------------- 1 | import json 2 | import unittest 3 | 4 | import numpy as np 5 | import mercantile 6 | 7 | from PIL import Image 8 | 9 | from robosat_pink.geojson import geojson_tile_burn, geojson_reproject 10 | 11 | 12 | def get_parking(): 13 | with open("tests/fixtures/parking/features.geojson") as f: 14 | parking_fc = json.load(f) 15 | 16 | assert len(parking_fc["features"]) == 2 17 | return parking_fc 18 | 19 | 20 | class TestBurn(unittest.TestCase): 21 | def test_burn_with_feature(self): 22 | parking_fc = get_parking() 23 | 24 | # The tile below has a parking lot in our fixtures. 25 | tile = mercantile.Tile(70762, 104119, 18) 26 | 27 | rasterized = geojson_tile_burn(tile, parking_fc["features"], 4326, (512, 512)) 28 | rasterized = Image.fromarray(rasterized, mode="P") 29 | 30 | # rasterized.save('rasterized.png') 31 | 32 | self.assertEqual(rasterized.size, (512, 512)) 33 | 34 | # Tile has a parking feature in our fixtures, thus sum should be non-zero. 35 | self.assertNotEqual(np.sum(rasterized), 0) 36 | 37 | def test_burn_without_feature(self): 38 | parking_fc = get_parking() 39 | 40 | # This tile does not have a parking lot in our fixtures. 41 | tile = mercantile.Tile(69623, 104946, 18) 42 | 43 | rasterized = geojson_tile_burn(tile, parking_fc["features"], 4326, (512, 512)) 44 | rasterized = Image.fromarray(rasterized, mode="P") 45 | 46 | self.assertEqual(rasterized.size, (512, 512)) 47 | 48 | # Tile does not have a parking feature in our fixture, the sum of pixels is zero. 
49 | self.assertEqual(np.sum(rasterized), 0) 50 | 51 | 52 | class TestFeatureToMercator(unittest.TestCase): 53 | def test_feature_to_mercator(self): 54 | parking_fc = get_parking() 55 | 56 | parking = parking_fc["features"][0] 57 | mercator = next(geojson_reproject(parking, 4326, 3857)) 58 | 59 | self.assertEqual(mercator["type"], "Polygon") 60 | self.assertEqual(int(mercator["coordinates"][0][0][0]), -9219757) 61 | -------------------------------------------------------------------------------- /webmap/.gitignore: -------------------------------------------------------------------------------- 1 | .DS_Store 2 | node_modules 3 | 4 | 5 | # local env files 6 | .env.local 7 | .env.*.local 8 | 9 | # Log files 10 | npm-debug.log* 11 | yarn-debug.log* 12 | yarn-error.log* 13 | 14 | # Editor directories and files 15 | .idea 16 | .vscode 17 | *.suo 18 | *.ntvs* 19 | *.njsproj 20 | *.sln 21 | *.sw? 22 | -------------------------------------------------------------------------------- /webmap/README.md: -------------------------------------------------------------------------------- 1 | # webmap 2 | 3 | ## Project setup 4 | ``` 5 | npm install 6 | ``` 7 | 8 | ### Compiles and hot-reloads for development 9 | ``` 10 | npm run serve 11 | ``` 12 | 13 | ### Compiles and minifies for production 14 | ``` 15 | npm run build 16 | ``` 17 | 18 | ### Lints and fixes files 19 | ``` 20 | npm run lint 21 | ``` 22 | 23 | ### Customize configuration 24 | See [Configuration Reference](https://cli.vuejs.org/config/). 
25 | -------------------------------------------------------------------------------- /webmap/babel.config.js: -------------------------------------------------------------------------------- 1 | module.exports = { 2 | presets: [ 3 | '@vue/cli-plugin-babel/preset' 4 | ] 5 | } 6 | -------------------------------------------------------------------------------- /webmap/dist/config.js: -------------------------------------------------------------------------------- 1 | const CONFIG = { 2 | SERVER: "http://127.0.0.1:5000", 3 | HOST: "http://127.0.0.1:5000", 4 | 5 | // SERVER: "http://127.0.0.1:5000", 6 | // HOST: "http://127.0.0.1:8080", 7 | BUILDING_STYLE: { 8 | /* 为地图添加layer */ 9 | id: "tempbuild" /* layer id是route */, 10 | type: "fill" /* line类型layer*/, 11 | source: "tempbuild" /* 资源引用的是上面定义的source*/, 12 | paint: { 13 | "fill-color": "rgba(194,186,62,0.4)", 14 | "fill-outline-color": "rgba(255,0,0,0.5)" 15 | } 16 | }, 17 | TDT_TILE: { 18 | type: "raster", 19 | tiles: [ 20 | "http://t0.tianditu.gov.cn/DataServer?T=img_w&x={x}&y={y}&l={z}&tk=d47b4c1d699be4ed6a68132338b741b2" 21 | ], 22 | tileSize: 256 23 | }, 24 | GOOGLE_TILE: { 25 | type: "raster", 26 | tiles: ["http://ditu.google.cn/maps/vt/lyrs=s&x={x}&y={y}&z={z}"], 27 | tileSize: 256 28 | } 29 | }; 30 | -------------------------------------------------------------------------------- /webmap/dist/css/app.823ee787.css: -------------------------------------------------------------------------------- 1 | .mapboxgl-map[data-v-693ec376]{height:100%;width:100%}.xunlianIcon[data-v-693ec376]{position:absolute;top:20px;left:400px}.xunlianIcon button[data-v-693ec376]{background-position:0;width:50px;text-align:right;padding:4px}.yuceIcon[data-v-693ec376]{position:absolute;top:20px;left:480px}.yuceIcon button[data-v-693ec376]{background-position:0;width:50px;text-align:right;padding:4px}.qingkongIcon[data-v-693ec376]{position:absolute;top:20px;left:560px}.qingkongIcon 
button[data-v-693ec376]{background-position:0;width:50px;text-align:right;padding:4px}.msg[data-v-693ec376]{position:absolute;top:23px;left:650px;max-width:500px;text-align:left;background-color:hsla(0,0%,100%,.7);padding:3px}.buildIcon[data-v-693ec376]{position:absolute;top:23px;left:250px}.buildIcon label[data-v-693ec376]{background-color:hsla(0,0%,100%,.7);padding:3px}.tdtIcon[data-v-693ec376]{position:absolute;top:20px;left:150px}.tdtIcon button[data-v-693ec376]{background-position:0;width:60px;text-align:right;padding:4px}.logIcon[data-v-693ec376]{position:absolute;top:20px;right:130px}.logIcon button[data-v-693ec376]{background-position:0;width:50px;text-align:right;padding:4px}.logClearIcon[data-v-693ec376]{position:absolute;top:20px;right:75px}.logClearIcon button[data-v-693ec376]{background-position:0;width:50px;text-align:right;padding:4px}#app{font-family:Avenir,Helvetica,Arial,sans-serif;-webkit-font-smoothing:antialiased;-moz-osx-font-smoothing:grayscale;text-align:center;color:#2c3e50;position:absolute;top:0;bottom:0;width:100%}.homeMap{width:100%;height:100%}body{margin:0;padding:0} -------------------------------------------------------------------------------- /webmap/dist/favicon.ico: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/geocompass/robosat_geoc/44d95183322ba68c0728de44a66e50f510bfd919/webmap/dist/favicon.ico -------------------------------------------------------------------------------- /webmap/dist/index.html: -------------------------------------------------------------------------------- 1 | 建筑物提取训练与预测工具
-------------------------------------------------------------------------------- /webmap/dist/style.json: -------------------------------------------------------------------------------- 1 | { 2 | "version": 8, 3 | "sources": { 4 | "raster-tiles": { 5 | "type": "raster", 6 | "tiles": [ 7 | "http://t0.tianditu.gov.cn/DataServer?T=img_w&x={x}&y={y}&l={z}&tk=d47b4c1d699be4ed6a68132338b741b2" 8 | ], 9 | "tileSize": 256 10 | } 11 | }, 12 | "layers": [ 13 | { 14 | "id": "tdt-img-tiles", 15 | "type": "raster", 16 | "source": "raster-tiles", 17 | "minzoom": 0, 18 | "maxzoom": 22 19 | } 20 | ] 21 | } 22 | -------------------------------------------------------------------------------- /webmap/dist/test.json: -------------------------------------------------------------------------------- 1 | { 2 | "type": "FeatureCollection", 3 | "features": [ 4 | { 5 | "type": "Feature", 6 | "properties": {}, 7 | "geometry": { 8 | "type": "Polygon", 9 | "coordinates": [ 10 | [ 11 | [116.30204200744629, 39.9391590321541], 12 | [116.30470275878906, 39.9391590321541], 13 | [116.30470275878906, 39.94080422911386], 14 | [116.30204200744629, 39.94080422911386], 15 | [116.30204200744629, 39.9391590321541] 16 | ] 17 | ] 18 | } 19 | }, 20 | { 21 | "type": "Feature", 22 | "properties": {}, 23 | "geometry": { 24 | "type": "Polygon", 25 | "coordinates": [ 26 | [ 27 | [116.29629135131836, 39.93817189499188], 28 | [116.29852294921876, 39.93817189499188], 29 | [116.29852294921876, 39.93988292368957], 30 | [116.29629135131836, 39.93988292368957], 31 | [116.29629135131836, 39.93817189499188] 32 | ] 33 | ] 34 | } 35 | } 36 | ] 37 | } 38 | -------------------------------------------------------------------------------- /webmap/package.json: -------------------------------------------------------------------------------- 1 | { 2 | "name": "webmap", 3 | "version": "0.1.0", 4 | "private": true, 5 | "scripts": { 6 | "serve": "vue-cli-service serve", 7 | "build": "vue-cli-service build", 8 | "lint": 
"vue-cli-service lint" 9 | }, 10 | "dependencies": { 11 | "@mapbox/mapbox-gl-draw": "^1.1.2", 12 | "@mapbox/mapbox-gl-language": "^0.10.1", 13 | "axios": "^0.19.0", 14 | "core-js": "^3.1.2", 15 | "mapbox-gl": "^1.4.1", 16 | "mapbox-gl-draw-rectangle-mode": "^1.0.4", 17 | "vue": "^2.6.10" 18 | }, 19 | "devDependencies": { 20 | "@vue/cli-plugin-babel": "^4.0.0", 21 | "@vue/cli-plugin-eslint": "^4.0.0", 22 | "@vue/cli-service": "^4.0.0", 23 | "babel-eslint": "^10.0.1", 24 | "eslint": "^5.16.0", 25 | "eslint-plugin-vue": "^5.0.0", 26 | "vue-template-compiler": "^2.6.10" 27 | }, 28 | "eslintConfig": { 29 | "root": true, 30 | "env": { 31 | "node": true 32 | }, 33 | "extends": [ 34 | "plugin:vue/essential", 35 | "eslint:recommended" 36 | ], 37 | "rules": { 38 | "no-console": "off", 39 | "no-debugger": "off", 40 | "no-undef": "off", 41 | "no-implicit-globals": "off" 42 | }, 43 | "parserOptions": { 44 | "parser": "babel-eslint" 45 | }, 46 | "globals": { 47 | "config": true 48 | } 49 | }, 50 | "postcss": { 51 | "plugins": { 52 | "autoprefixer": {} 53 | } 54 | }, 55 | "browserslist": [ 56 | "> 1%", 57 | "last 2 versions" 58 | ] 59 | } 60 | -------------------------------------------------------------------------------- /webmap/public/config.js: -------------------------------------------------------------------------------- 1 | const CONFIG = { 2 | SERVER: "http://127.0.0.1:5000", 3 | HOST: "http://127.0.0.1:5000", 4 | 5 | // SERVER: "http://127.0.0.1:5000", 6 | // HOST: "http://127.0.0.1:8080", 7 | BUILDING_STYLE: { 8 | /* 为地图添加layer */ 9 | id: "tempbuild" /* layer id是route */, 10 | type: "fill" /* line类型layer*/, 11 | source: "tempbuild" /* 资源引用的是上面定义的source*/, 12 | paint: { 13 | "fill-color": "rgba(194,186,62,0.4)", 14 | "fill-outline-color": "rgba(255,0,0,0.5)" 15 | } 16 | }, 17 | TDT_TILE: { 18 | type: "raster", 19 | tiles: [ 20 | "http://t0.tianditu.gov.cn/DataServer?T=img_w&x={x}&y={y}&l={z}&tk=4830425f5d789b48b967b1062deb8c71" 21 | ], 22 | tileSize: 256 23 | }, 24 | 
GOOGLE_TILE: { 25 | type: "raster", 26 | tiles: ["http://ditu.google.cn/maps/vt/lyrs=s&x={x}&y={y}&z={z}"], 27 | tileSize: 256 28 | } 29 | }; 30 | -------------------------------------------------------------------------------- /webmap/public/favicon.ico: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/geocompass/robosat_geoc/44d95183322ba68c0728de44a66e50f510bfd919/webmap/public/favicon.ico -------------------------------------------------------------------------------- /webmap/public/index.html: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 建筑物提取训练与预测工具 10 | 11 | 12 | 18 |
19 | 20 | 21 | 22 | -------------------------------------------------------------------------------- /webmap/public/style.json: -------------------------------------------------------------------------------- 1 | { 2 | "version": 8, 3 | "sources": { 4 | "raster-tiles": { 5 | "type": "raster", 6 | "tiles": [ 7 | "http://t0.tianditu.gov.cn/DataServer?T=img_w&x={x}&y={y}&l={z}&tk=d47b4c1d699be4ed6a68132338b741b2" 8 | ], 9 | "tileSize": 256 10 | } 11 | }, 12 | "layers": [ 13 | { 14 | "id": "tdt-img-tiles", 15 | "type": "raster", 16 | "source": "raster-tiles", 17 | "minzoom": 0, 18 | "maxzoom": 22 19 | } 20 | ] 21 | } 22 | -------------------------------------------------------------------------------- /webmap/public/test.json: -------------------------------------------------------------------------------- 1 | { 2 | "type": "FeatureCollection", 3 | "features": [ 4 | { 5 | "type": "Feature", 6 | "properties": {}, 7 | "geometry": { 8 | "type": "Polygon", 9 | "coordinates": [ 10 | [ 11 | [116.30204200744629, 39.9391590321541], 12 | [116.30470275878906, 39.9391590321541], 13 | [116.30470275878906, 39.94080422911386], 14 | [116.30204200744629, 39.94080422911386], 15 | [116.30204200744629, 39.9391590321541] 16 | ] 17 | ] 18 | } 19 | }, 20 | { 21 | "type": "Feature", 22 | "properties": {}, 23 | "geometry": { 24 | "type": "Polygon", 25 | "coordinates": [ 26 | [ 27 | [116.29629135131836, 39.93817189499188], 28 | [116.29852294921876, 39.93817189499188], 29 | [116.29852294921876, 39.93988292368957], 30 | [116.29629135131836, 39.93988292368957], 31 | [116.29629135131836, 39.93817189499188] 32 | ] 33 | ] 34 | } 35 | } 36 | ] 37 | } 38 | -------------------------------------------------------------------------------- /webmap/src/App.vue: -------------------------------------------------------------------------------- 1 | 6 | 7 | 17 | 18 | 39 | -------------------------------------------------------------------------------- /webmap/src/assets/logo.png: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/geocompass/robosat_geoc/44d95183322ba68c0728de44a66e50f510bfd919/webmap/src/assets/logo.png -------------------------------------------------------------------------------- /webmap/src/main.js: -------------------------------------------------------------------------------- 1 | import Vue from 'vue' 2 | import App from './App.vue' 3 | 4 | Vue.config.productionTip = false 5 | 6 | new Vue({ 7 | render: h => h(App), 8 | }).$mount('#app') 9 | -------------------------------------------------------------------------------- /xyz_proxy.py: -------------------------------------------------------------------------------- 1 | import requests 2 | from robosat_pink.geoc import config as CONFIG 3 | from flask import Flask, request, Response 4 | app = Flask(__name__) 5 | 6 | 7 | @app.route('/') 8 | def hello_world(): 9 | return 'Hello flask!' 10 | 11 | 12 | @app.route('/v1/wmts///', methods=['GET']) 13 | def wmts(x, y, z): 14 | map = request.args.get("type") 15 | if not x or not y or not z: 16 | return None 17 | if not map and map != "tdt" and map != "google": 18 | return "faild to set map type, neither tianditu nor google" 19 | url = CONFIG.URL_TDT 20 | url_google = CONFIG.URL_GOOGLE 21 | if map == 'google': 22 | url = url_google 23 | image = requests.get(url.format(x=x, y=y, z=z)) 24 | 25 | print(url.format(x=x, y=y, z=z)) 26 | return Response(image, mimetype='image/jpeg') 27 | 28 | 29 | if __name__ == '__main__': 30 | app.run(port=CONFIG.FLASK_PORT) 31 | 32 | # How to run this server backend? 33 | # >: python xyz_proxy.py & 34 | --------------------------------------------------------------------------------