├── .gitignore ├── Building-Footprint-Lite.ipynb ├── Building-Footprint.ipynb ├── CODE_OF_CONDUCT.md ├── CONTRIBUTING.md ├── LICENSE ├── README.md ├── Road-Network-Lite.ipynb ├── Road-Network.ipynb ├── THIRD-PARTY ├── conda-requirements.txt ├── configs ├── buildings │ ├── ELEV-only.yml │ ├── RGB+ELEV.yml │ └── RGB-only.yml └── roads │ ├── INTEN-only.yml │ ├── RGB+INTEN.yml │ └── RGB-only.yml ├── download-from-s3.sh ├── libs ├── README.md ├── apls │ ├── apls.py │ ├── apls_utils.py │ ├── create_submission.py │ ├── graphTools.py │ ├── infer_speed.py │ ├── osmnx_funcs.py │ ├── plot_road.py │ ├── skeletonize.py │ ├── utils │ │ ├── __init__.py │ │ ├── rdp.py │ │ ├── sknw.py │ │ └── sknw_int64.py │ └── wkt_to_G.py └── solaris │ ├── LICENSE.txt │ ├── README.md │ ├── __init__.py │ ├── bin │ ├── __init__.py │ ├── geotransform_footprints.py │ ├── make_graphs.py │ ├── make_masks.py │ ├── mask_to_polygons.py │ ├── solaris_run_ml.py │ └── spacenet_eval.py │ ├── data │ ├── __init__.py │ └── coco.py │ ├── eval │ ├── __init__.py │ ├── base.py │ ├── challenges.py │ ├── iou.py │ └── pixel.py │ ├── nets │ ├── __init__.py │ ├── _keras_losses.py │ ├── _torch_losses.py │ ├── callbacks.py │ ├── configs │ │ ├── config_skeleton.yml │ │ ├── selimsef_densenet121unet_spacenet4.yml │ │ ├── selimsef_densenet161unet_spacenet4.yml │ │ ├── selimsef_resnet34unet_spacenet4.yml │ │ ├── selimsef_scse50unet_spacenet4.yml │ │ └── xdxd_spacenet4.yml │ ├── datagen.py │ ├── infer.py │ ├── losses.py │ ├── metrics.py │ ├── model_io.py │ ├── optimizers.py │ ├── torch_callbacks.py │ ├── train.py │ ├── transform.py │ └── zoo │ │ ├── __init__.py │ │ ├── selim_sef_sn4.py │ │ └── xdxd_sn4.py │ ├── raster │ ├── __init__.py │ └── image.py │ ├── tile │ ├── __init__.py │ ├── raster_tile.py │ └── vector_tile.py │ ├── utils │ ├── __init__.py │ ├── cli.py │ ├── config.py │ ├── core.py │ ├── data.py │ ├── geo.py │ ├── io.py │ ├── log.py │ ├── raster.py │ ├── tdigest.py │ └── tile.py │ └── vector │ ├── __init__.py │ ├── graph.py │ ├── mask.py │ └── polygon.py ├── networks ├── resnet_unet.py └── vgg16_unet.py ├── pip-requirements.txt └── setup-env.sh /.gitignore: -------------------------------------------------------------------------------- 1 | # Custom items for this repo 2 | data 3 | results 4 | models 5 | libs/solaris/nets/weights/* 6 | 7 | # Byte-compiled / optimized / DLL files 8 | __pycache__/ 9 | *.py[cod] 10 | *$py.class 11 | 12 | # C extensions 13 | *.so 14 | 15 | # Distribution / packaging 16 | .Python 17 | build/ 18 | develop-eggs/ 19 | dist/ 20 | downloads/ 21 | eggs/ 22 | .eggs/ 23 | lib/ 24 | lib64/ 25 | parts/ 26 | sdist/ 27 | var/ 28 | wheels/ 29 | pip-wheel-metadata/ 30 | share/python-wheels/ 31 | *.egg-info/ 32 | .installed.cfg 33 | *.egg 34 | MANIFEST 35 | 36 | # PyInstaller 37 | # Usually these files are written by a python script from a template 38 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
39 | *.manifest 40 | *.spec 41 | 42 | # Installer logs 43 | pip-log.txt 44 | pip-delete-this-directory.txt 45 | 46 | # Unit test / coverage reports 47 | htmlcov/ 48 | .tox/ 49 | .nox/ 50 | .coverage 51 | .coverage.* 52 | .cache 53 | nosetests.xml 54 | coverage.xml 55 | *.cover 56 | *.py,cover 57 | .hypothesis/ 58 | .pytest_cache/ 59 | 60 | # Translations 61 | *.mo 62 | *.pot 63 | 64 | # Django stuff: 65 | *.log 66 | local_settings.py 67 | db.sqlite3 68 | db.sqlite3-journal 69 | 70 | # Flask stuff: 71 | instance/ 72 | .webassets-cache 73 | 74 | # Scrapy stuff: 75 | .scrapy 76 | 77 | # Sphinx documentation 78 | docs/_build/ 79 | 80 | # PyBuilder 81 | target/ 82 | 83 | # Jupyter Notebook 84 | .ipynb_checkpoints 85 | 86 | # IPython 87 | profile_default/ 88 | ipython_config.py 89 | 90 | # pyenv 91 | .python-version 92 | 93 | # pipenv 94 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 95 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 96 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 97 | # install all needed dependencies. 98 | #Pipfile.lock 99 | 100 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow 101 | __pypackages__/ 102 | 103 | # Celery stuff 104 | celerybeat-schedule 105 | celerybeat.pid 106 | 107 | # SageMath parsed files 108 | *.sage.py 109 | 110 | # Environments 111 | .env 112 | .venv 113 | env/ 114 | venv/ 115 | ENV/ 116 | env.bak/ 117 | venv.bak/ 118 | 119 | # Spyder project settings 120 | .spyderproject 121 | .spyproject 122 | 123 | # Rope project settings 124 | .ropeproject 125 | 126 | # mkdocs documentation 127 | /site 128 | 129 | # mypy 130 | .mypy_cache/ 131 | .dmypy.json 132 | dmypy.json 133 | 134 | # Pyre type checker 135 | .pyre/ 136 | -------------------------------------------------------------------------------- /CODE_OF_CONDUCT.md: -------------------------------------------------------------------------------- 1 | ## Code of Conduct 2 | This project has adopted the [Amazon Open Source Code of Conduct](https://aws.github.io/code-of-conduct). 3 | For more information see the [Code of Conduct FAQ](https://aws.github.io/code-of-conduct-faq) or contact 4 | opensource-codeofconduct@amazon.com with any additional questions or comments. 5 | -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # Contributing Guidelines 2 | 3 | Thank you for your interest in contributing to our project. Whether it's a bug report, new feature, correction, or additional 4 | documentation, we greatly value feedback and contributions from our community. 5 | 6 | Please read through this document before submitting any issues or pull requests to ensure we have all the necessary 7 | information to effectively respond to your bug report or contribution. 8 | 9 | 10 | ## Reporting Bugs/Feature Requests 11 | 12 | We welcome you to use the GitHub issue tracker to report bugs or suggest features. 13 | 14 | When filing an issue, please check existing open, or recently closed, issues to make sure somebody else hasn't already 15 | reported the issue. Please try to include as much information as you can. 
Details like these are incredibly useful: 16 | 17 | * A reproducible test case or series of steps 18 | * The version of our code being used 19 | * Any modifications you've made relevant to the bug 20 | * Anything unusual about your environment or deployment 21 | 22 | 23 | ## Contributing via Pull Requests 24 | Contributions via pull requests are much appreciated. Before sending us a pull request, please ensure that: 25 | 26 | 1. You are working against the latest source on the *master* branch. 27 | 2. You check existing open, and recently merged, pull requests to make sure someone else hasn't addressed the problem already. 28 | 3. You open an issue to discuss any significant work - we would hate for your time to be wasted. 29 | 30 | To send us a pull request, please: 31 | 32 | 1. Fork the repository. 33 | 2. Modify the source; please focus on the specific change you are contributing. If you also reformat all the code, it will be hard for us to focus on your change. 34 | 3. Ensure local tests pass. 35 | 4. Commit to your fork using clear commit messages. 36 | 5. Send us a pull request, answering any default questions in the pull request interface. 37 | 6. Pay attention to any automated CI failures reported in the pull request, and stay involved in the conversation. 38 | 39 | GitHub provides additional document on [forking a repository](https://help.github.com/articles/fork-a-repo/) and 40 | [creating a pull request](https://help.github.com/articles/creating-a-pull-request/). 41 | 42 | 43 | ## Finding contributions to work on 44 | Looking at the existing issues is a great way to find something to contribute on. As our projects, by default, use the default GitHub issue labels (enhancement/bug/duplicate/help wanted/invalid/question/wontfix), looking at any 'help wanted' issues is a great place to start. 45 | 46 | 47 | ## Code of Conduct 48 | This project has adopted the [Amazon Open Source Code of Conduct](https://aws.github.io/code-of-conduct). 49 | For more information see the [Code of Conduct FAQ](https://aws.github.io/code-of-conduct-faq) or contact 50 | opensource-codeofconduct@amazon.com with any additional questions or comments. 51 | 52 | 53 | ## Security issue notifications 54 | If you discover a potential security issue in this project we ask that you notify AWS/Amazon Security via our [vulnerability reporting page](http://aws.amazon.com/security/vulnerability-reporting/). Please do **not** create a public github issue. 55 | 56 | 57 | ## Licensing 58 | 59 | See the [LICENSE](LICENSE) file for our project's licensing. We will ask you to confirm the licensing of your contribution. 60 | 61 | We may ask you to sign a [Contributor License Agreement (CLA)](http://en.wikipedia.org/wiki/Contributor_License_Agreement) for larger changes. 62 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 2 | 3 | Permission is hereby granted, free of charge, to any person obtaining a copy of 4 | this software and associated documentation files (the "Software"), to deal in 5 | the Software without restriction, including without limitation the rights to 6 | use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of 7 | the Software, and to permit persons to whom the Software is furnished to do so. 
8 | 9 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 10 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS 11 | FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR 12 | COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER 13 | IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN 14 | CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 15 | 16 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Deep Learning on AWS Open Data Registry: Automatic Building and Road Extraction from Satellite and LiDAR 2 | 3 | ### For [SpatialAPI 20](https://sites.google.com/ucr.edu/spatialapi20) participants: we recommend registering an AWS account to allow an immersive, hands-on tutorial experience. 4 | - [Create a regular AWS account](https://aws.amazon.com/premiumsupport/knowledge-center/create-and-activate-aws-account/) 5 | - [Create an educational AWS account](https://aws.amazon.com/education/awseducate/apply/) 6 | ### All tutorial contents can be reproduced within free tier services at no cost. If you have difficulty registering an AWS account, we offer a limited number of temporary event accounts on a first-come, first-served basis. 7 | 8 | This is the repository for OpenData tutorial content by MLSL. 9 | 10 | ## Setup 11 | 12 | ### Create a SageMaker instance 13 | The tutorial can be run with any SageMaker instance type, but we highly recommend an instance type with GPU support, for example the `ml.p?.?xlarge` series. The EBS volume size should be more than 60GB in order to store all necessary data. 14 | 15 | Network training/inference is a memory-intensive process. If you run into out-of-GPU-memory or out-of-RAM errors, consider decreasing the `batch_size` value in the `yml` config files in the `configs` folder. 16 | 17 | ### Clone this repository 18 | Once the SageMaker instance is successfully launched, open a terminal and follow the commands below: 19 | ```shell 20 | $ cd ~/SageMaker/ 21 | $ git clone https://github.com/aws-samples/aws-open-data-satellite-lidar-tutorial.git 22 | $ cd aws-open-data-satellite-lidar-tutorial 23 | ``` 24 | This will download the repository and take you to the repository directory. 25 | 26 | ### Create Conda environment 27 | Next, set up a Conda environment by running `setup-env.sh` as shown below. You can change the environment name from `tutorial_env` to any other name. 28 | ```shell 29 | $ ./setup-env.sh tutorial_env 30 | ``` 31 | This may take 10--15 minutes to complete. 32 | 33 | Then check to make sure you have a new Jupyter kernel called `conda_tutorial_env`, or `conda_[name]` if you changed the environment name to `[name]`. You may need to wait for a couple of minutes and refresh the Jupyter page. 34 | 35 | ### Download from S3 buckets 36 | Next, download the necessary files ([data browser](https://aws-satellite-lidar-tutorial.s3.amazonaws.com/index.html)) from the S3 bucket prepared for this tutorial by running `download-from-s3.sh`: 37 | ```shell 38 | $ ./download-from-s3.sh 39 | ``` 40 | This may take 5 minutes to complete, and requires at least 23GB of EBS disk space. 41 | 42 | ## Launch notebook 43 | Finally, you can launch the notebooks `Building-Footprint.ipynb` or `Road-Network.ipynb` and learn to reproduce the tutorial.
Note that if the notebook shows "No Kernel", or prompts to "Select Kernel", select the Jupyter kernel created in the previous step. 44 | 45 | ## Security 46 | 47 | See [CONTRIBUTING](CONTRIBUTING.md#security-issue-notifications) for more information. 48 | 49 | ## License 50 | 51 | This library is licensed under the MIT-0 License. See the [LICENSE](LICENSE) file. 52 | The [NOTICE](THIRD-PARTY) includes third-party licenses used in this repository. 53 | 54 | -------------------------------------------------------------------------------- /conda-requirements.txt: -------------------------------------------------------------------------------- 1 | pip>=19.0.3 2 | affine>=2.3.0 3 | fiona>=1.7.13 4 | gdal>=3.0.2 5 | matplotlib>=3.1.2 6 | networkx>=2.4 7 | numpy>=1.17.3 8 | pandas>=0.25.3 9 | pyyaml>=5.4 10 | rasterio>=1.0.23 11 | requests==2.22.0 12 | rtree>=0.9.3 13 | scikit-image>=0.16.2 14 | scipy>=1.3.2 15 | tqdm>=4.40.0 16 | urllib3>=1.25.7 17 | -------------------------------------------------------------------------------- /configs/buildings/ELEV-only.yml: -------------------------------------------------------------------------------- 1 | model_name: xdxd_spacenet4 2 | 3 | model_path: 4 | train: true 5 | infer: true 6 | 7 | pretrained: false 8 | nn_framework: torch 9 | batch_size: 20 10 | 11 | data_specs: 12 | width: 512 13 | height: 512 14 | image_type: 16bit 15 | channels: 1 16 | label_type: mask 17 | is_categorical: false 18 | mask_channels: 1 19 | val_holdout_frac: 0.05 20 | dtype: 21 | 22 | training_data_csv: 'csv_file_path' 23 | validation_data_csv: 'auto_split_from_training_dataset' 24 | inference_data_csv: 'csv_file_path' 25 | 26 | training_augmentation: 27 | augmentations: 28 | DropChannel: 29 | axis: 2 30 | idx: 31 | - 0 32 | - 1 33 | - 2 34 | p: 1.0 35 | HorizontalFlip: 36 | p: 0.5 37 | RandomRotate90: 38 | p: 0.5 39 | RandomCrop: 40 | height: 512 41 | width: 512 42 | p: 1.0 43 | Normalize: 44 | mean: 45 | - 0.0 46 | std: 47 | - 20.0 48 | max_pixel_value: 255 49 | p: 1.0 50 | p: 1.0 51 | shuffle: true 52 | validation_augmentation: 53 | augmentations: 54 | DropChannel: 55 | axis: 2 56 | idx: 57 | - 0 58 | - 1 59 | - 2 60 | p: 1.0 61 | CenterCrop: 62 | height: 512 63 | width: 512 64 | p: 1.0 65 | Normalize: 66 | mean: 67 | - 0.0 68 | std: 69 | - 20.0 70 | max_pixel_value: 255 71 | p: 1.0 72 | p: 1.0 73 | inference_augmentation: 74 | augmentations: 75 | DropChannel: 76 | axis: 2 77 | idx: 78 | - 0 79 | - 1 80 | - 2 81 | p: 1.0 82 | Normalize: 83 | mean: 84 | - 0.0 85 | std: 86 | - 20.0 87 | max_pixel_value: 255 88 | p: 1.0 89 | p: 1.0 90 | 91 | training: 92 | epochs: 100 93 | optimizer: AdamW 94 | lr: 1e-4 95 | opt_args: 96 | loss: 97 | bcewithlogits: 98 | jaccard: 99 | loss_weights: 100 | bcewithlogits: 10 101 | jaccard: 2.5 102 | metrics: 103 | training: 104 | validation: 105 | callbacks: 106 | model_checkpoint: 107 | filepath: './models/buildings/ELEV-only/checkpoint.pth' 108 | monitor: periodic 109 | period: 10 110 | model_dest_path: './models/buildings/ELEV-only.pth' 111 | verbose: true 112 | 113 | inference: 114 | output_dir: './results/buildings/ELEV-only/pred_mask' 115 | window_step_size_x: 116 | window_step_size_y: 117 | 118 | -------------------------------------------------------------------------------- /configs/buildings/RGB+ELEV.yml: -------------------------------------------------------------------------------- 1 | model_name: xdxd_spacenet4 2 | 3 | model_path: 4 | train: true 5 | infer: true 6 | 7 | pretrained: false 8 | nn_framework: torch 9 | batch_size: 20 10 
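# The 4 input channels configured for this model are the three RGB bands plus the LiDAR-derived elevation band.
# If training runs out of GPU memory, lower batch_size above, as suggested in the README.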
| 11 | data_specs: 12 | width: 512 13 | height: 512 14 | image_type: 16bit 15 | channels: 4 16 | label_type: mask 17 | is_categorical: false 18 | mask_channels: 1 19 | val_holdout_frac: 0.05 20 | dtype: 21 | 22 | training_data_csv: 'csv_file_path' 23 | validation_data_csv: 'auto_split_from_training_dataset' 24 | inference_data_csv: 'csv_file_path' 25 | 26 | training_augmentation: 27 | augmentations: 28 | HorizontalFlip: 29 | p: 0.5 30 | RandomRotate90: 31 | p: 0.5 32 | RandomCrop: 33 | height: 512 34 | width: 512 35 | p: 1.0 36 | Normalize: 37 | mean: 38 | - 0.0 39 | - 0.0 40 | - 0.0 41 | - 0.0 42 | std: 43 | - 1.0 44 | - 1.0 45 | - 1.0 46 | - 20.0 47 | max_pixel_value: 255 48 | p: 1.0 49 | p: 1.0 50 | shuffle: true 51 | validation_augmentation: 52 | augmentations: 53 | CenterCrop: 54 | height: 512 55 | width: 512 56 | p: 1.0 57 | Normalize: 58 | mean: 59 | - 0.0 60 | - 0.0 61 | - 0.0 62 | - 0.0 63 | std: 64 | - 1.0 65 | - 1.0 66 | - 1.0 67 | - 20.0 68 | max_pixel_value: 255 69 | p: 1.0 70 | p: 1.0 71 | inference_augmentation: 72 | augmentations: 73 | Normalize: 74 | mean: 75 | - 0.0 76 | - 0.0 77 | - 0.0 78 | - 0.0 79 | std: 80 | - 1.0 81 | - 1.0 82 | - 1.0 83 | - 20.0 84 | max_pixel_value: 255 85 | p: 1.0 86 | p: 1.0 87 | 88 | training: 89 | epochs: 100 90 | optimizer: AdamW 91 | lr: 1e-4 92 | opt_args: 93 | loss: 94 | bcewithlogits: 95 | jaccard: 96 | loss_weights: 97 | bcewithlogits: 10 98 | jaccard: 2.5 99 | metrics: 100 | training: 101 | validation: 102 | callbacks: 103 | model_checkpoint: 104 | filepath: './models/buildings/RGB+ELEV/checkpoint.pth' 105 | monitor: periodic 106 | period: 10 107 | model_dest_path: './models/buildings/RGB+ELEV.pth' 108 | verbose: true 109 | 110 | inference: 111 | output_dir: './results/buildings/RGB+ELEV/pred_mask' 112 | window_step_size_x: 113 | window_step_size_y: 114 | 115 | -------------------------------------------------------------------------------- /configs/buildings/RGB-only.yml: -------------------------------------------------------------------------------- 1 | model_name: xdxd_spacenet4 2 | 3 | model_path: 4 | train: true 5 | infer: true 6 | 7 | pretrained: false 8 | nn_framework: torch 9 | batch_size: 20 10 | 11 | data_specs: 12 | width: 512 13 | height: 512 14 | image_type: 16bit 15 | channels: 3 16 | label_type: mask 17 | is_categorical: false 18 | mask_channels: 1 19 | val_holdout_frac: 0.05 20 | dtype: 21 | 22 | training_data_csv: 'csv_file_path' 23 | validation_data_csv: 'auto_split_from_training_dataset' 24 | inference_data_csv: 'csv_file_path' 25 | 26 | training_augmentation: 27 | augmentations: 28 | DropChannel: 29 | axis: 2 30 | idx: 3 31 | p: 1.0 32 | HorizontalFlip: 33 | p: 0.5 34 | RandomRotate90: 35 | p: 0.5 36 | RandomCrop: 37 | height: 512 38 | width: 512 39 | p: 1.0 40 | Normalize: 41 | mean: 42 | - 0.0 43 | - 0.0 44 | - 0.0 45 | std: 46 | - 1.0 47 | - 1.0 48 | - 1.0 49 | max_pixel_value: 255 50 | p: 1.0 51 | p: 1.0 52 | shuffle: true 53 | validation_augmentation: 54 | augmentations: 55 | DropChannel: 56 | axis: 2 57 | idx: 3 58 | p: 1.0 59 | CenterCrop: 60 | height: 512 61 | width: 512 62 | p: 1.0 63 | Normalize: 64 | mean: 65 | - 0.0 66 | - 0.0 67 | - 0.0 68 | std: 69 | - 1.0 70 | - 1.0 71 | - 1.0 72 | max_pixel_value: 255 73 | p: 1.0 74 | p: 1.0 75 | inference_augmentation: 76 | augmentations: 77 | DropChannel: 78 | axis: 2 79 | idx: 3 80 | p: 1.0 81 | Normalize: 82 | mean: 83 | - 0.0 84 | - 0.0 85 | - 0.0 86 | std: 87 | - 1.0 88 | - 1.0 89 | - 1.0 90 | max_pixel_value: 255 91 | p: 1.0 92 | p: 1.0 93 | 94 | training: 95 
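# Training setup below: AdamW at lr 1e-4 for 100 epochs; the loss is a weighted sum of
# BCEWithLogits (weight 10) and Jaccard (weight 2.5), with a checkpoint written every 10 epochs.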
| epochs: 100 96 | optimizer: AdamW 97 | lr: 1e-4 98 | opt_args: 99 | loss: 100 | bcewithlogits: 101 | jaccard: 102 | loss_weights: 103 | bcewithlogits: 10 104 | jaccard: 2.5 105 | metrics: 106 | training: 107 | validation: 108 | callbacks: 109 | model_checkpoint: 110 | filepath: './models/buildings/RGB-only/checkpoint.pth' 111 | monitor: periodic 112 | period: 10 113 | model_dest_path: './models/buildings/RGB-only.pth' 114 | verbose: true 115 | 116 | inference: 117 | output_dir: './results/buildings/RGB-only/pred_mask' 118 | window_step_size_x: 119 | window_step_size_y: 120 | 121 | -------------------------------------------------------------------------------- /configs/roads/INTEN-only.yml: -------------------------------------------------------------------------------- 1 | model_name: cresi 2 | 3 | model_path: 4 | train: true 5 | infer: true 6 | 7 | pretrained: false 8 | nn_framework: torch 9 | batch_size: 12 10 | 11 | data_specs: 12 | width: 1280 13 | height: 1280 14 | image_type: 16bit 15 | channels: 1 16 | label_type: mask 17 | is_categorical: false 18 | mask_channels: 8 19 | val_holdout_frac: 0.05 20 | dtype: 21 | 22 | training_data_csv: 'csv_file_path' 23 | validation_data_csv: 'auto_split_from_training_dataset' 24 | inference_data_csv: 'csv_file_path' 25 | 26 | training_augmentation: 27 | augmentations: 28 | DropChannel: 29 | axis: 2 30 | idx: 31 | - 0 32 | - 1 33 | - 2 34 | p: 1.0 35 | HorizontalFlip: 36 | p: 0.5 37 | RandomRotate90: 38 | p: 0.5 39 | RandomCrop: 40 | height: 1280 41 | width: 1280 42 | p: 1.0 43 | Normalize: 44 | mean: 45 | - 0.0 46 | std: 47 | - 20.0 48 | max_pixel_value: 255 49 | p: 1.0 50 | p: 1.0 51 | shuffle: true 52 | validation_augmentation: 53 | augmentations: 54 | DropChannel: 55 | axis: 2 56 | idx: 57 | - 0 58 | - 1 59 | - 2 60 | p: 1.0 61 | RandomCrop: 62 | height: 1280 63 | width: 1280 64 | p: 1.0 65 | Normalize: 66 | mean: 67 | - 0.0 68 | std: 69 | - 20.0 70 | max_pixel_value: 255 71 | p: 1.0 72 | p: 1.0 73 | inference_augmentation: 74 | augmentations: 75 | DropChannel: 76 | axis: 2 77 | idx: 78 | - 0 79 | - 1 80 | - 2 81 | p: 1.0 82 | Normalize: 83 | mean: 84 | - 0.0 85 | std: 86 | - 20.0 87 | max_pixel_value: 255 88 | p: 1.0 89 | p: 1.0 90 | 91 | training: 92 | epochs: 100 93 | optimizer: AdamW 94 | lr: 1e-4 95 | opt_args: 96 | loss: 97 | bcewithlogits: 98 | jaccard: 99 | loss_weights: 100 | bcewithlogits: 10 101 | jaccard: 2.5 102 | metrics: 103 | training: 104 | validation: 105 | callbacks: 106 | model_checkpoint: 107 | filepath: './models/roads/INTEN-only/checkpoint.pth' 108 | monitor: periodic 109 | period: 10 110 | model_dest_path: './models/roads/INTEN-only.pth' 111 | verbose: true 112 | 113 | inference: 114 | output_dir: './results/roads/INTEN-only/pred_mask' 115 | window_step_size_x: 116 | window_step_size_y: 117 | -------------------------------------------------------------------------------- /configs/roads/RGB+INTEN.yml: -------------------------------------------------------------------------------- 1 | model_name: cresi 2 | 3 | model_path: 4 | train: true 5 | infer: true 6 | 7 | pretrained: false 8 | nn_framework: torch 9 | batch_size: 12 10 | 11 | data_specs: 12 | width: 1280 13 | height: 1280 14 | image_type: 16bit 15 | channels: 4 16 | label_type: mask 17 | is_categorical: false 18 | mask_channels: 8 19 | val_holdout_frac: 0.05 20 | dtype: 21 | 22 | training_data_csv: 'csv_file_path' 23 | validation_data_csv: 'auto_split_from_training_dataset' 24 | inference_data_csv: 'csv_file_path' 25 | 26 | training_augmentation: 27 | 
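# Training-time augmentations below: random horizontal flips and 90-degree rotations,
# 1280x1280 random crops, and per-channel normalization that uses std 20 for the fourth
# (LiDAR intensity) channel and std 1 for the RGB channels.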
augmentations: 28 | HorizontalFlip: 29 | p: 0.5 30 | RandomRotate90: 31 | p: 0.5 32 | RandomCrop: 33 | height: 1280 34 | width: 1280 35 | p: 1.0 36 | Normalize: 37 | mean: 38 | - 0.0 39 | - 0.0 40 | - 0.0 41 | - 0.0 42 | std: 43 | - 1.0 44 | - 1.0 45 | - 1.0 46 | - 20.0 47 | max_pixel_value: 255 48 | p: 1.0 49 | p: 1.0 50 | shuffle: true 51 | validation_augmentation: 52 | augmentations: 53 | RandomCrop: 54 | height: 1280 55 | width: 1280 56 | p: 1.0 57 | Normalize: 58 | mean: 59 | - 0.0 60 | - 0.0 61 | - 0.0 62 | - 0.0 63 | std: 64 | - 1.0 65 | - 1.0 66 | - 1.0 67 | - 20.0 68 | max_pixel_value: 255 69 | p: 1.0 70 | p: 1.0 71 | inference_augmentation: 72 | augmentations: 73 | Normalize: 74 | mean: 75 | - 0.0 76 | - 0.0 77 | - 0.0 78 | - 0.0 79 | std: 80 | - 1.0 81 | - 1.0 82 | - 1.0 83 | - 20.0 84 | max_pixel_value: 255 85 | p: 1.0 86 | p: 1.0 87 | 88 | training: 89 | epochs: 100 90 | optimizer: AdamW 91 | lr: 1e-4 92 | opt_args: 93 | loss: 94 | bcewithlogits: 95 | jaccard: 96 | loss_weights: 97 | bcewithlogits: 10 98 | jaccard: 2.5 99 | metrics: 100 | training: 101 | validation: 102 | callbacks: 103 | model_checkpoint: 104 | filepath: './models/roads/RGB+INTEN/checkpoint.pth' 105 | monitor: periodic 106 | period: 10 107 | model_dest_path: './models/roads/RGB+INTEN.pth' 108 | verbose: true 109 | 110 | inference: 111 | output_dir: './results/roads/RGB+INTEN/pred_mask' 112 | window_step_size_x: 113 | window_step_size_y: 114 | -------------------------------------------------------------------------------- /configs/roads/RGB-only.yml: -------------------------------------------------------------------------------- 1 | model_name: cresi 2 | 3 | model_path: 4 | train: true 5 | infer: true 6 | 7 | pretrained: false 8 | nn_framework: torch 9 | batch_size: 12 10 | 11 | data_specs: 12 | width: 1280 13 | height: 1280 14 | image_type: 16bit 15 | channels: 3 16 | label_type: mask 17 | is_categorical: false 18 | mask_channels: 8 19 | val_holdout_frac: 0.05 20 | dtype: 21 | 22 | training_data_csv: 'csv_file_path' 23 | validation_data_csv: 'auto_split_from_training_dataset' 24 | inference_data_csv: 'csv_file_path' 25 | 26 | training_augmentation: 27 | augmentations: 28 | DropChannel: 29 | axis: 2 30 | idx: 3 31 | p: 1.0 32 | HorizontalFlip: 33 | p: 0.5 34 | RandomRotate90: 35 | p: 0.5 36 | RandomCrop: 37 | height: 1280 38 | width: 1280 39 | p: 1.0 40 | Normalize: 41 | mean: 42 | - 0.0 43 | - 0.0 44 | - 0.0 45 | std: 46 | - 1.0 47 | - 1.0 48 | - 1.0 49 | max_pixel_value: 255 50 | p: 1.0 51 | p: 1.0 52 | shuffle: true 53 | validation_augmentation: 54 | augmentations: 55 | DropChannel: 56 | axis: 2 57 | idx: 3 58 | p: 1.0 59 | RandomCrop: 60 | height: 1280 61 | width: 1280 62 | p: 1.0 63 | Normalize: 64 | mean: 65 | - 0.0 66 | - 0.0 67 | - 0.0 68 | std: 69 | - 1.0 70 | - 1.0 71 | - 1.0 72 | max_pixel_value: 255 73 | p: 1.0 74 | p: 1.0 75 | inference_augmentation: 76 | augmentations: 77 | DropChannel: 78 | axis: 2 79 | idx: 3 80 | p: 1.0 81 | Normalize: 82 | mean: 83 | - 0.0 84 | - 0.0 85 | - 0.0 86 | std: 87 | - 1.0 88 | - 1.0 89 | - 1.0 90 | max_pixel_value: 255 91 | p: 1.0 92 | p: 1.0 93 | 94 | training: 95 | epochs: 100 96 | optimizer: AdamW 97 | lr: 1e-4 98 | opt_args: 99 | loss: 100 | bcewithlogits: 101 | jaccard: 102 | loss_weights: 103 | bcewithlogits: 10 104 | jaccard: 2.5 105 | metrics: 106 | training: 107 | validation: 108 | callbacks: 109 | model_checkpoint: 110 | filepath: './models/roads/RGB-only/checkpoint.pth' 111 | monitor: periodic 112 | period: 10 113 | model_dest_path: 
'./models/roads/RGB-only.pth' 114 | verbose: true 115 | 116 | inference: 117 | output_dir: './results/roads/RGB-only/pred_mask' 118 | window_step_size_x: 119 | window_step_size_y: 120 | -------------------------------------------------------------------------------- /download-from-s3.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 4 | # SPDX-License-Identifier: MIT-0 5 | 6 | # This script downloads data used in this tutorial from S3 buckets. 7 | 8 | # Exit when error occurs 9 | set -e 10 | 11 | echo "===== Downloading RGB+LiDAR merged data (~22GB) ... =====" 12 | aws s3 cp s3://aws-satellite-lidar-tutorial/data/ ./data/ --recursive --no-sign-request 13 | 14 | echo "===== Downloading pretrained model weights (617MB) ... =====" 15 | aws s3 cp s3://aws-satellite-lidar-tutorial/models/ ./models/ --recursive --no-sign-request 16 | 17 | echo "===== Downloading completes. =====" 18 | -------------------------------------------------------------------------------- /libs/README.md: -------------------------------------------------------------------------------- 1 | # External Python Libraries 2 | 3 | This directory includes external Python libraries after some customization. Other dependencies without customization should be set up by `setup-env.sh` via `conda` and `pip`. 4 | 5 | - `solaris`: An open source ML pipeline for overhead imagery. [Link to GitHub repo](https://github.com/CosmiQ/solaris) 6 | - `apls`: APLS (Average Path Length Similarity), a toolkit to extract road graph network from prediction masks and compute the metric scores against ground truth. [Link to GitHub repo](https://github.com/CosmiQ/apls) 7 | 8 | See [NOTICE](../THIRD-PARTY) for third party licenses. 9 | -------------------------------------------------------------------------------- /libs/apls/create_submission.py: -------------------------------------------------------------------------------- 1 | """ 2 | Modified on Sun Jul 27 2020 by Yunzhi Shi, DS @ AWS MLSL 3 | 4 | Cleaned up for the tutorial. 
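The script converts a directory of road-graph .gpickle files into a single WKT submission
CSV with columns ImageId, WKT_Pix, and the requested edge-weight columns.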
5 | 6 | Original author: avanetten 7 | """ 8 | 9 | import os, time 10 | import argparse 11 | 12 | import pandas as pd 13 | import networkx as nx 14 | 15 | 16 | ################################################################################ 17 | def pkl_dir_to_wkt(pkl_dir, output_csv_path='', 18 | weight_keys=['length', 'travel_time_s'], 19 | verbose=False): 20 | """ Create submission wkt from directory full of graph pickles """ 21 | 22 | wkt_list = [] 23 | 24 | pkl_list = sorted([z for z in os.listdir(pkl_dir) if z.endswith('.gpickle')]) 25 | for i, pkl_name in enumerate(pkl_list): 26 | G = nx.read_gpickle(os.path.join(pkl_dir, pkl_name)) 27 | 28 | # ensure an undirected graph 29 | if verbose: 30 | print(i, "/", len(pkl_list), "num G.nodes:", len(G.nodes())) 31 | 32 | name_root = pkl_name.split('.')[0] 33 | if verbose: 34 | print("name_root:", name_root) 35 | 36 | # if empty, still add to submission 37 | if len(G.nodes()) == 0: 38 | wkt_item_root = [name_root, 'LINESTRING EMPTY'] 39 | if len(weight_keys) > 0: 40 | weights = [0 for w in weight_keys] 41 | wkt_list.append(wkt_item_root + weights) 42 | else: 43 | wkt_list.append(wkt_item_root) 44 | 45 | # extract geometry pix wkt, save to list 46 | seen_edges = set([]) 47 | for i, (u, v, attr_dict) in enumerate(G.edges(data=True)): 48 | # make sure we haven't already seen this edge 49 | if (u, v) in seen_edges or (v, u) in seen_edges: 50 | if verbose: 51 | print(u, v, "already catalogued!") 52 | continue 53 | else: 54 | seen_edges.add((u, v)) 55 | seen_edges.add((v, u)) 56 | geom_pix = attr_dict['geometry_pix'] 57 | if type(geom_pix) != str: 58 | geom_pix_wkt = attr_dict['geometry_pix'].wkt 59 | else: 60 | geom_pix_wkt = geom_pix 61 | 62 | # check edge lnegth 63 | if attr_dict['length'] > 5000: 64 | print("Edge too long!, u,v,data:", u,v,attr_dict) 65 | return 66 | 67 | if verbose: 68 | print(i, "/", len(G.edges()), "u, v:", u, v) 69 | print(" attr_dict:", attr_dict) 70 | print(" geom_pix_wkt:", geom_pix_wkt) 71 | 72 | wkt_item_root = [name_root, geom_pix_wkt] 73 | if len(weight_keys) > 0: 74 | weights = [attr_dict[w] for w in weight_keys] 75 | if verbose: 76 | print(" weights:", weights) 77 | wkt_list.append(wkt_item_root + weights) 78 | else: 79 | wkt_list.append(wkt_item_root) 80 | 81 | if verbose: 82 | print("wkt_list:", wkt_list) 83 | 84 | # create dataframe 85 | if len(weight_keys) > 0: 86 | cols = ['ImageId', 'WKT_Pix'] + weight_keys 87 | else: 88 | cols = ['ImageId', 'WKT_Pix'] 89 | 90 | # use 'length_m' and 'travel_time_s' instead? 
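    # Rename weight columns to the submission header names: 'length' -> 'length_m',
    # 'travel_time' -> 'travel_time_s'; any other column name is kept unchanged.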
91 | cols_new = [] 92 | for z in cols: 93 | if z == 'length': 94 | cols_new.append('length_m') 95 | elif z == 'travel_time': 96 | cols_new.append('travel_time_s') 97 | else: 98 | cols_new.append(z) 99 | cols = cols_new 100 | 101 | df = pd.DataFrame(wkt_list, columns=cols) 102 | if len(output_csv_path) > 0: 103 | df.to_csv(output_csv_path, index=False) 104 | 105 | return df 106 | 107 | 108 | ################################################################################ 109 | def main(): 110 | wkt_csv_file = 'wkt_speed.csv' 111 | weight_keys = ['length', 'travel_time_s'] 112 | verbose = False 113 | 114 | parser = argparse.ArgumentParser() 115 | parser.add_argument('--graph_speed_dir', default=None, 116 | help='dir contains speed graph gpickle files') 117 | parser.add_argument('--results_dir', required=True, 118 | help='dir to write output file into') 119 | args = parser.parse_args() 120 | assert os.path.exists(args.results_dir) 121 | if args.graph_speed_dir is None: 122 | args.graph_speed_dir = os.path.join(args.results_dir, 'graph_speed_gpickle') 123 | output_csv_path = os.path.join(args.results_dir, wkt_csv_file) 124 | 125 | t0 = time.time() 126 | df = pkl_dir_to_wkt(args.graph_speed_dir, 127 | output_csv_path=output_csv_path, 128 | weight_keys=weight_keys, 129 | verbose=verbose) 130 | 131 | print("WKT-w/-speed csv file: ", output_csv_path) 132 | t1 = time.time() 133 | print("Time to create speed WKT: {:6.2f} s".format(t1-t0)) 134 | 135 | if __name__ == "__main__": 136 | main() 137 | -------------------------------------------------------------------------------- /libs/apls/utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aws-samples/aws-open-data-satellite-lidar-tutorial/928196f105df202e04d5bcdc64ad449cf65b183d/libs/apls/utils/__init__.py -------------------------------------------------------------------------------- /libs/apls/utils/rdp.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | 4 | """ 5 | Code borrowd from: 6 | https://github.com/mitroadmaps/roadtracer/blob/master/lib/discoverlib/rdp.py 7 | 8 | The Ramer-Douglas-Peucker algorithm roughly ported from the pseudo-code provided 9 | by http://en.wikipedia.org/wiki/Ramer-Douglas-Peucker_algorithm 10 | """ 11 | 12 | from math import sqrt 13 | 14 | def distance(a, b): 15 | return sqrt((a[0] - b[0]) ** 2 + (a[1] - b[1]) ** 2) 16 | 17 | def point_line_distance(point, start, end): 18 | if (start == end): 19 | return distance(point, start) 20 | else: 21 | n = abs( 22 | (end[0] - start[0]) * (start[1] - point[1]) - (start[0] - point[0]) * (end[1] - start[1]) 23 | ) 24 | d = sqrt( 25 | (end[0] - start[0]) ** 2 + (end[1] - start[1]) ** 2 26 | ) 27 | return n / d 28 | 29 | def rdp(points, epsilon=1): 30 | """ 31 | Reduces a series of points to a simplified version that loses detail, but 32 | maintains the general shape of the series. 
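
    For example, a middle point lying within epsilon of the segment between the two
    endpoints is dropped:
        rdp([[0, 0], [1, 0.1], [2, 0]], epsilon=0.5) -> [[0, 0], [2, 0]]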
33 | """ 34 | dmax = 0.0 35 | index = 0 36 | for i in range(1, len(points) - 1): 37 | d = point_line_distance(points[i], points[0], points[-1]) 38 | if d > dmax: 39 | index = i 40 | dmax = d 41 | if dmax >= epsilon: 42 | results = rdp(points[:index+1], epsilon)[:-1] + rdp(points[index:], epsilon) 43 | else: 44 | results = [points[0], points[-1]] 45 | return results 46 | 47 | 48 | #def simplify_graph(graph, max_distance=1): 49 | # """ 50 | # https://github.com/anilbatra2185/road_connectivity/blob/master/data_utils/graph_utils.py 51 | # @params graph: MultiGraph object of networkx 52 | # @return: simplified graph after applying RDP algorithm. 53 | # """ 54 | # all_segments = [] 55 | # # Iterate over Graph Edges 56 | # for (s, e) in graph.edges(): 57 | # for _, val in graph[s][e].items(): 58 | # # get all pixel points i.e. (x,y) between the edge 59 | # ps = val["pts"] 60 | # # create a full segment 61 | # full_segments = np.row_stack([graph.node[s]["o"], ps, graph.node[e]["o"]]) 62 | # # simply the graph. 63 | # segments = rdp.rdp(full_segments.tolist(), max_distance) 64 | # all_segments.append(segments) 65 | # 66 | # return all_segments -------------------------------------------------------------------------------- /libs/apls/utils/sknw.py: -------------------------------------------------------------------------------- 1 | # https://github.com/yxdragon/sknw 2 | 3 | import numpy as np 4 | from numba import jit 5 | import networkx as nx 6 | 7 | def neighbors(shape): 8 | dim = len(shape) 9 | block = np.ones([3]*dim) 10 | block[tuple([1]*dim)] = 0 11 | idx = np.where(block>0) 12 | idx = np.array(idx, dtype=np.uint8).T 13 | idx = np.array(idx-[1]*dim) 14 | acc = np.cumprod((1,)+shape[::-1][:-1]) 15 | return np.dot(idx, acc[::-1]) 16 | 17 | @jit(nopython=True) # my mark 18 | def mark(img, nbs): # mark the array use (0, 1, 2) 19 | img = img.ravel() 20 | for p in range(len(img)): 21 | if img[p]==0:continue 22 | s = 0 23 | for dp in nbs: 24 | if img[p+dp]!=0:s+=1 25 | if s==2:img[p]=1 26 | else:img[p]=2 27 | 28 | @jit(nopython=True) # trans index to r, c... 
29 | def idx2rc(idx, acc): 30 | rst = np.zeros((len(idx), len(acc)), dtype=np.int16) 31 | for i in range(len(idx)): 32 | for j in range(len(acc)): 33 | rst[i,j] = idx[i]//acc[j] 34 | idx[i] -= rst[i,j]*acc[j] 35 | rst -= 1 36 | return rst 37 | 38 | @jit(nopython=True) # fill a node (may be two or more points) 39 | def fill(img, p, num, nbs, acc, buf): 40 | back = img[p] 41 | img[p] = num 42 | buf[0] = p 43 | cur = 0; s = 1; 44 | 45 | while True: 46 | p = buf[cur] 47 | for dp in nbs: 48 | cp = p+dp 49 | if img[cp]==back: 50 | img[cp] = num 51 | buf[s] = cp 52 | s+=1 53 | cur += 1 54 | if cur==s:break 55 | return idx2rc(buf[:s], acc) 56 | 57 | @jit(nopython=True) # trace the edge and use a buffer, then buf.copy, if use [] numba not works 58 | def trace(img, p, nbs, acc, buf): 59 | c1 = 0; c2 = 0; 60 | newp = 0 61 | cur = 0 62 | 63 | while True: 64 | buf[cur] = p 65 | img[p] = 0 66 | cur += 1 67 | for dp in nbs: 68 | cp = p + dp 69 | if img[cp] >= 10: 70 | if c1==0:c1=img[cp] 71 | else: c2 = img[cp] 72 | if img[cp] == 1: 73 | newp = cp 74 | p = newp 75 | if c2!=0:break 76 | return (c1-10, c2-10, idx2rc(buf[:cur], acc)) 77 | 78 | @jit(nopython=True) # parse the image then get the nodes and edges 79 | def parse_struc(img, pts, nbs, acc): 80 | img = img.ravel() 81 | buf = np.zeros(131072, dtype=np.int64) 82 | num = 10 83 | nodes = [] 84 | for p in pts: 85 | if img[p] == 2: 86 | nds = fill(img, p, num, nbs, acc, buf) 87 | num += 1 88 | nodes.append(nds) 89 | 90 | edges = [] 91 | for p in pts: 92 | for dp in nbs: 93 | if img[p+dp]==1: 94 | edge = trace(img, p+dp, nbs, acc, buf) 95 | edges.append(edge) 96 | return nodes, edges 97 | 98 | # use nodes and edges build a networkx graph 99 | def build_graph(nodes, edges, multi=False): 100 | graph = nx.MultiGraph() if multi else nx.Graph() 101 | for i in range(len(nodes)): 102 | graph.add_node(i, pts=nodes[i], o=nodes[i].mean(axis=0)) 103 | for s,e,pts in edges: 104 | l = np.linalg.norm(pts[1:]-pts[:-1], axis=1).sum() 105 | graph.add_edge(s,e, pts=pts, weight=l) 106 | return graph 107 | 108 | def buffer(ske): 109 | buf = np.zeros(tuple(np.array(ske.shape)+2), dtype=np.uint16) 110 | buf[tuple([slice(1,-1)]*buf.ndim)] = ske 111 | return buf 112 | 113 | def build_sknw(ske, multi=False): 114 | buf = buffer(ske) 115 | nbs = neighbors(buf.shape) 116 | acc = np.cumprod((1,)+buf.shape[::-1][:-1])[::-1] 117 | mark(buf, nbs) 118 | pts = np.array(np.where(buf.ravel()==2))[0] 119 | nodes, edges = parse_struc(buf, pts, nbs, acc) 120 | return build_graph(nodes, edges, multi) 121 | 122 | # draw the graph 123 | def draw_graph(img, graph, cn=255, ce=128): 124 | acc = np.cumprod((1,)+img.shape[::-1][:-1])[::-1] 125 | img = img.ravel() 126 | for idx in graph.nodes(): 127 | pts = graph.node[idx]['pts'] 128 | img[np.dot(pts, acc)] = cn 129 | for (s, e) in graph.edges(): 130 | eds = graph[s][e] 131 | for i in eds: 132 | pts = eds[i]['pts'] 133 | img[np.dot(pts, acc)] = ce 134 | 135 | if __name__ == '__main__': 136 | g = nx.MultiGraph() 137 | g.add_nodes_from([1,2,3,4,5]) 138 | g.add_edges_from([(1,2),(1,3),(2,3),(4,5),(5,4)]) 139 | print(g.nodes()) 140 | print(g.edges()) 141 | a = g.subgraph(1) 142 | print('d') 143 | print(a) 144 | print('d') 145 | -------------------------------------------------------------------------------- /libs/apls/utils/sknw_int64.py: -------------------------------------------------------------------------------- 1 | ''' 2 | https://github.com/yxdragon/sknw 3 | change datatypes to 64 to accomodate very large files with coords over 4 | 32767''' 5 | 6 | 
import numpy as np 7 | from numba import jit 8 | import networkx as nx 9 | 10 | 11 | 12 | 13 | # get neighbors d index 14 | def neighbors(shape): 15 | dim = len(shape) 16 | block = np.ones([3] * dim) 17 | block[tuple([1] * dim)] = 0 18 | idx = np.where(block > 0) 19 | idx = np.array(idx, dtype=np.uint16).T 20 | idx = np.array(idx - [1] * dim) 21 | acc = np.cumprod((1,) + shape[::-1][:-1]) 22 | return np.dot(idx, acc[::-1]) 23 | 24 | 25 | @jit # my mark 26 | def mark(img): # mark the array use (0, 1, 2) 27 | nbs = neighbors(img.shape) 28 | img = img.ravel() 29 | for p in range(len(img)): 30 | if img[p] == 0: continue 31 | s = 0 32 | for dp in nbs: 33 | if img[p + dp] != 0: s += 1 34 | if s == 2: 35 | img[p] = 1 36 | else: 37 | img[p] = 2 38 | 39 | 40 | @jit # trans index to r, c... 41 | def idx2rc(idx, acc): 42 | rst = np.zeros((len(idx), len(acc)), dtype=np.int32) 43 | for i in range(len(idx)): 44 | for j in range(len(acc)): 45 | rst[i, j] = idx[i] // acc[j] 46 | idx[i] -= rst[i, j] * acc[j] 47 | rst -= 1 48 | return rst 49 | 50 | 51 | @jit # fill a node (may be two or more points) 52 | def fill(img, p, num, nbs, acc, buf): 53 | back = img[p] 54 | img[p] = num 55 | buf[0] = p 56 | cur = 0; 57 | s = 1; 58 | 59 | while True: 60 | p = buf[cur] 61 | for dp in nbs: 62 | cp = p + dp 63 | if img[cp] == back: 64 | img[cp] = num 65 | buf[s] = cp 66 | s += 1 67 | cur += 1 68 | if cur == s: break 69 | return idx2rc(buf[:s], acc) 70 | 71 | 72 | @jit # trace the edge and use a buffer, then buf.copy, if use [] numba not works 73 | def trace(img, p, nbs, acc, buf): 74 | c1 = 0; 75 | c2 = 0; 76 | newp = 0 77 | cur = 0 78 | 79 | while True: 80 | buf[cur] = p 81 | img[p] = 0 82 | cur += 1 83 | for dp in nbs: 84 | cp = p + dp 85 | if img[cp] >= 10: 86 | if c1 == 0: 87 | c1 = img[cp] 88 | else: 89 | c2 = img[cp] 90 | if img[cp] == 1: 91 | newp = cp 92 | p = newp 93 | if c2 != 0: break 94 | return (c1 - 10, c2 - 10, idx2rc(buf[:cur], acc)) 95 | 96 | 97 | @jit # parse the image then get the nodes and edges 98 | def parse_struc(img): 99 | nbs = neighbors(img.shape) 100 | acc = np.cumprod((1,) + img.shape[::-1][:-1])[::-1] 101 | img = img.ravel() 102 | pts = np.array(np.where(img == 2))[0] 103 | buf = np.zeros(131072, dtype=np.int64) 104 | num = 10 105 | nodes = [] 106 | for p in pts: 107 | if img[p] == 2: 108 | nds = fill(img, p, num, nbs, acc, buf) 109 | num += 1 110 | nodes.append(nds) 111 | 112 | edges = [] 113 | for p in pts: 114 | for dp in nbs: 115 | if img[p + dp] == 1: 116 | edge = trace(img, p + dp, nbs, acc, buf) 117 | edges.append(edge) 118 | return nodes, edges 119 | 120 | 121 | # use nodes and edges build a networkx graph 122 | def build_graph(nodes, edges, multi=False): 123 | graph = nx.MultiGraph() if multi else nx.Graph() 124 | for i in range(len(nodes)): 125 | graph.add_node(i, pts=nodes[i], o=np.int64(nodes[i].mean(axis=0))) 126 | for s, e, pts in edges: 127 | l = np.linalg.norm(pts[1:] - pts[:-1], axis=1).sum() 128 | graph.add_edge(s, e, pts=pts, weight=l) 129 | return graph 130 | 131 | 132 | def buffer(ske): 133 | buf = np.zeros(tuple(np.array(ske.shape) + 2), dtype=np.uint32) 134 | buf[tuple([slice(1, -1)] * buf.ndim)] = ske 135 | return buf 136 | 137 | 138 | def build_sknw(ske, multi=False): 139 | buf = buffer(ske) 140 | mark(buf) 141 | nodes, edges = parse_struc(buf) 142 | return build_graph(nodes, edges, multi) 143 | 144 | 145 | # draw the graph 146 | def draw_graph(img, graph, cn=255, ce=128): 147 | acc = np.cumprod((1,) + img.shape[::-1][:-1])[::-1] 148 | img = img.ravel() 149 | for idx 
in graph.nodes(): 150 | pts = graph.node[idx]['pts'] 151 | img[np.dot(pts, acc)] = cn 152 | for (s, e) in graph.edges(): 153 | eds = graph[s][e] 154 | for i in eds: 155 | pts = eds[i]['pts'] 156 | img[np.dot(pts, acc)] = ce 157 | 158 | 159 | if __name__ == '__main__': 160 | g = nx.MultiGraph() 161 | g.add_nodes_from([1, 2, 3, 4, 5]) 162 | g.add_edges_from([(1, 2), (1, 3), (2, 3), (4, 5), (5, 4)]) 163 | print(g.nodes()) 164 | print(g.edges()) 165 | a = g.subgraph(1) 166 | print('d') 167 | print(a) 168 | print('d') 169 | -------------------------------------------------------------------------------- /libs/solaris/README.md: -------------------------------------------------------------------------------- 1 | # Solaris module 2 | 3 | This is a slightly modified version of CosmiQ Works `solaris` project. Please refer to their [original repository](https://github.com/CosmiQ/solaris) for more information. 4 | -------------------------------------------------------------------------------- /libs/solaris/__init__.py: -------------------------------------------------------------------------------- 1 | from . import bin, data, eval, nets, raster, tile, utils, vector 2 | 3 | __version__ = "0.2.2" 4 | -------------------------------------------------------------------------------- /libs/solaris/bin/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aws-samples/aws-open-data-satellite-lidar-tutorial/928196f105df202e04d5bcdc64ad449cf65b183d/libs/solaris/bin/__init__.py -------------------------------------------------------------------------------- /libs/solaris/bin/geotransform_footprints.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import pandas as pd 3 | from tqdm import tqdm 4 | from multiprocessing import Pool 5 | from ..vector.polygon import geojson_to_px_gdf 6 | from ..vector.polygon import georegister_px_df 7 | from ..utils.cli import _func_wrapper 8 | from itertools import repeat 9 | 10 | 11 | def main(): 12 | 13 | parser = argparse.ArgumentParser( 14 | description='Interconvert footprints between pixel and geographic ' + 15 | 'coordinate systems.', argument_default=None) 16 | 17 | parser.add_argument('--source_file', '-s', type=str, 18 | help='Full path to file to transform') 19 | parser.add_argument('--reference_image', '-r', type=str, 20 | help='Full path to a georegistered image in the same' + 21 | ' coordinate system (for conversion to pixels) or in' + 22 | ' the target coordinate system (for conversion to a' + 23 | ' geographic coordinate reference system).') 24 | parser.add_argument('--output_path', '-o', type=str, 25 | help='Full path to the output file for the converted' + 26 | 'footprints.') 27 | parser.add_argument('--to_pixel', '-p', action='store_true', default=False, 28 | help='Use this argument if you wish to convert' + 29 | ' footprints in --source-file to pixel coordinates.') 30 | parser.add_argument('--to_geo', '-g', action='store_true', default=False, 31 | help='Use this argument if you wish to convert' + 32 | ' footprints in --source-file to a geographic' + 33 | ' coordinate system.') 34 | parser.add_argument('--geometry_column', '-c', type=str, 35 | default='geometry', help='The column containing' + 36 | ' footprint polygons to transform. 
If not provided,' + 37 | ' defaults to "geometry".') 38 | parser.add_argument('--decimal_precision', '-d', type=int, 39 | help='The number of decimals to round to in the' + 40 | ' final footprint geometries. If not provided, they' + 41 | ' will be rounded to float32 precision.') 42 | parser.add_argument('--batch', '-b', action='store_true', default=False, 43 | help='Use this flag if you wish to operate on' + 44 | ' multiple files in batch. In this case,' + 45 | ' --argument-csv must be provided. See help' + 46 | ' for --argument_csv and the codebase docs at' + 47 | ' https://solaris.readthedocs.io for more info.') 48 | parser.add_argument('--argument_csv', '-a', type=str, 49 | help='The reference file for variable values for' + 50 | ' batch processing. It must contain columns to pass' + 51 | ' the source_file and reference_image arguments, and' + 52 | ' can additionally contain columns providing the' + 53 | ' geometry_column and decimal_precision arguments' + 54 | ' if you wish to define them differently for items' + 55 | ' in the batch. These columns must have the same' + 56 | ' names as the corresponding arguments. See the ' + 57 | ' usage recipes at https://cw-geodata.readthedocs.io' + 58 | ' for examples.') 59 | parser.add_argument('--workers', '-w', type=int, default=1, 60 | help='The number of parallel processing workers to' + 61 | ' use. This should not exceed the number of CPU' + 62 | ' cores available.') 63 | 64 | args = parser.parse_args() 65 | # check that the necessary set of arguments are provided. 66 | if args.batch and args.argument_csv is None: 67 | raise ValueError( 68 | 'To perform batch processing, you must provide both --batch and' + 69 | ' --argument_csv.') 70 | if args.argument_csv is None and args.source_file is None: 71 | raise ValueError( 72 | 'You must provide a source file using either --source_file or' + 73 | ' --argument_csv.') 74 | if args.argument_csv is None and args.reference_image is None: 75 | raise ValueError( 76 | 'You must provide a reference image using either' + 77 | ' --reference_image or --argument_csv.') 78 | if args.to_pixel == args.to_geo: 79 | raise ValueError( 80 | 'One, and only one, of --to_pixel and --to_geo must be specified.') 81 | 82 | if args.argument_csv is not None: 83 | arg_df = pd.read_csv(args.argument_csv) 84 | else: 85 | arg_df = pd.DataFrame({}) 86 | 87 | if args.batch: 88 | # add values from individual arguments to the argument df 89 | if args.source_file is not None: 90 | arg_df['source_file'] = args.source_file 91 | if args.reference_image is not None: 92 | arg_df['reference_image'] = args.reference_image 93 | if args.geometry_column is not None: 94 | arg_df['geometry_column'] = args.geometry_column 95 | if args.decimal_precision is not None: 96 | arg_df['decimal_precision'] = args.decimal_precision 97 | else: 98 | # add values from individual arguments to the argument df 99 | if args.source_file is not None: 100 | arg_df['source_file'] = [args.source_file] 101 | if args.reference_image is not None: 102 | arg_df['reference_image'] = [args.reference_image] 103 | if args.geometry_column is not None: 104 | arg_df['geometry_column'] = [args.geometry_column] 105 | if args.decimal_precision is not None: 106 | arg_df['decimal_precision'] = [args.decimal_precision] 107 | if args.output_path is not None: 108 | arg_df['output_path'] = [args.output_path] 109 | 110 | if args.to_pixel: 111 | # rename argument columns for compatibility with the target func 112 | arg_df = arg_df.rename(columns={'source_file': 'geojson', 113 | 
'reference_image': 'im_path', 114 | 'decimal_precision': 'precision', 115 | 'geometry_column': 'geom_col'}) 116 | arg_dict_list = arg_df[ 117 | ['geojson', 'im_path', 'precision', 'geom_col', 'output_path'] 118 | ].to_dict(orient='records') 119 | func_to_call = geojson_to_px_gdf 120 | elif args.to_geo: 121 | # rename argument columns for compatibility with the target func 122 | arg_df = arg_df.rename(columns={'source_file': 'df', 123 | 'reference_image': 'im_path', 124 | 'decimal_precision': 'precision', 125 | 'geometry_column': 'geom_col'}) 126 | arg_dict_list = arg_df[ 127 | ['df', 'im_path', 'precision', 'geom_col', 'output_path'] 128 | ].to_dict(orient='records') 129 | func_to_call = georegister_px_df 130 | 131 | if not args.batch: 132 | result = func_to_call(**arg_dict_list[0]) 133 | if not args.output_path: 134 | return result 135 | else: 136 | with Pool(processes=args.workers) as pool: 137 | result = tqdm(pool.starmap(_func_wrapper, zip(repeat(func_to_call), 138 | arg_dict_list))) 139 | pool.close() 140 | 141 | 142 | if __name__ == '__main__': 143 | main() 144 | -------------------------------------------------------------------------------- /libs/solaris/bin/make_graphs.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import pandas as pd 3 | from tqdm import tqdm 4 | from multiprocessing import Pool 5 | from ..vector.graph import geojson_to_graph 6 | from ..utils.cli import _func_wrapper 7 | from itertools import repeat 8 | 9 | 10 | def main(): 11 | 12 | parser = argparse.ArgumentParser( 13 | description='Create training pixel masks from vector data', 14 | argument_default=None) 15 | 16 | parser.add_argument('--source_file', '-s', type=str, 17 | help='Full path to file to create graph from.') 18 | parser.add_argument('--output_path', '-o', type=str, 19 | help='Full path to the output file for the graph' + 20 | ' object.') 21 | parser.add_argument('--road_type_field', '-r', type=str, 22 | help='The name of the column in --source_file that' + 23 | ' defines the road type of each linestring.') 24 | parser.add_argument('--first_edge_idx', '-e', type=int, default=0, 25 | help='The numeric index to use for the first edge in' + 26 | ' the graph. Defaults to 0.') 27 | parser.add_argument('--first_node_idx', '-n', type=int, default=0, 28 | help='The numeric index to use for the first node in' + 29 | ' the graph. Defaults to 0.') 30 | parser.add_argument('--weight_norm_field', '-wn', type=str, 31 | help='The name of a column in --source_file to' + 32 | ' weight edges with. If not provided, edge weights' + 33 | ' are determined only by Euclidean distance. If' + 34 | ' provided, edge weights are distance*weight.') 35 | parser.add_argument('--batch', '-b', action='store_true', default=False, 36 | help='Use this flag if you wish to operate on' + 37 | ' multiple files in batch. In this case,' + 38 | ' --argument-csv must be provided. See help' + 39 | ' for --argument_csv and the codebase docs at' + 40 | ' https://cw-geodata.readthedocs.io for more info.') 41 | parser.add_argument('--argument_csv', '-a', type=str, 42 | help='The reference file for variable values for' + 43 | ' batch processing. It must contain columns to pass' + 44 | ' the source_file and reference_image arguments, and' + 45 | ' can additionally contain columns providing the' + 46 | ' footprint_column and decimal_precision arguments' + 47 | ' if you wish to define them differently for items' + 48 | ' in the batch. 
These columns must have the same' + 49 | ' names as the corresponding arguments. See the ' + 50 | ' usage recipes at https://cw-geodata.readthedocs.io' + 51 | ' for examples.') 52 | parser.add_argument('--workers', '-w', type=int, default=1, 53 | help='The number of parallel processing workers to' + 54 | ' use. This should not exceed the number of CPU' + 55 | ' cores available.') 56 | 57 | args = parser.parse_args() 58 | 59 | if args.batch and args.argument_csv is None: 60 | raise ValueError( 61 | 'To perform batch processing, you must provide both --batch and' + 62 | ' --argument_csv.') 63 | if args.argument_csv is None and args.source_file is None: 64 | raise ValueError( 65 | 'You must provide a source file using either --source_file or' + 66 | ' --argument_csv.') 67 | 68 | if args.argument_csv is not None: 69 | arg_df = pd.read_csv(args.argument_csv) 70 | else: 71 | arg_df = pd.DataFrame({}) 72 | if args.batch: 73 | if args.source_file is not None: 74 | arg_df['source_file'] = args.source_file 75 | if args.output_path is not None: 76 | arg_df['output_path'] = args.output_path 77 | if args.road_type_field is not None: 78 | arg_df['road_type_field'] = args.road_type_field 79 | arg_df['first_node_idx'] = args.first_node_idx 80 | arg_df['first_edge_idx'] = args.first_edge_idx 81 | if args.weight_norm_field is not None: 82 | arg_df['weight_norm_field'] = args.weight_norm_field 83 | else: 84 | arg_df['source_file'] = [args.source_file] 85 | arg_df['output_path'] = [args.output_path] 86 | arg_df['road_type_field'] = [args.road_type_field] 87 | arg_df['first_node_idx'] = [args.first_node_idx] 88 | arg_df['first_edge_idx'] = [args.first_edge_idx] 89 | arg_df['weight_norm_field'] = [args.weight_norm_field] 90 | 91 | arg_df = arg_df.rename(columns={'source_file': 'geojson', 92 | 'first_edge_idx': 'edge_idx'}) 93 | arg_dict_list = arg_df[['geojson', 'output_path', 'road_type_field', 94 | 'weight_norm_field', 'edge_idx', 'first_node_idx', 95 | 'output_path'] 96 | ].to_dict(orient='records') 97 | if not args.batch: 98 | result = geojson_to_graph(**arg_dict_list[0]) 99 | if not args.output_path: 100 | return result 101 | 102 | else: 103 | with Pool(processes=args.workers) as pool: 104 | result = tqdm(pool.starmap(_func_wrapper, 105 | zip(repeat(geojson_to_graph), 106 | arg_dict_list))) 107 | pool.close() 108 | 109 | 110 | if __name__ == '__main__': 111 | main() 112 | -------------------------------------------------------------------------------- /libs/solaris/bin/make_masks.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import pandas as pd 3 | from tqdm import tqdm 4 | from multiprocessing import Pool 5 | from ..vector.mask import df_to_px_mask 6 | from ..utils.cli import _func_wrapper 7 | from itertools import repeat 8 | 9 | 10 | def main(): 11 | 12 | parser = argparse.ArgumentParser( 13 | description='Create training pixel masks from vector data', 14 | argument_default=None) 15 | 16 | parser.add_argument('--source_file', '-s', type=str, 17 | help='Full path to file to create mask from.') 18 | parser.add_argument('--reference_image', '-r', type=str, 19 | help='Full path to a georegistered image in the same' 20 | ' coordinate system (for conversion to pixels) or in' 21 | ' the target coordinate system (for conversion to a' 22 | ' geographic coordinate reference system).') 23 | parser.add_argument('--output_path', '-o', type=str, 24 | help='Full path to the output file for the converted' 25 | 'footprints.') 26 | 
parser.add_argument('--geometry_column', '-g', type=str, 27 | default='geometry', help='The column containing' 28 | ' footprint polygons to transform. If not provided,' 29 | ' defaults to "geometry".') 30 | parser.add_argument('--transform', '-t', action='store_true', 31 | default=False, help='Use this flag if the geometries' 32 | ' are in a georeferenced coordinate system and' 33 | ' need to be converted to pixel coordinates.') 34 | parser.add_argument('--value', '-v', type=int, default=255, 35 | help='The value to set for labeled pixels in the' 36 | ' mask. Defaults to 255.') 37 | parser.add_argument('--footprint', '-f', action='store_true', 38 | default=False, help='If this flag is set, the mask' 39 | ' will include filled-in building footprints as a' 40 | ' channel.') 41 | parser.add_argument('--edge', '-e', action='store_true', 42 | default=False, help='If this flag is set, the mask' 43 | ' will include the building edges as a channel.') 44 | parser.add_argument('--edge_width', '-ew', type=int, default=3, 45 | help='Pixel thickness of the edges in the edge mask.' 46 | ' Defaults to 3 if not provided.') 47 | parser.add_argument('--edge_type', '-et', type=str, default='inner', 48 | help='Type of edge: either inner or outer. Defaults' 49 | ' to inner if not provided.') 50 | parser.add_argument('--contact', '-c', action='store_true', 51 | default=False, help='If this flag is set, the mask' 52 | ' will include contact points between buildings as a' 53 | ' channel.') 54 | parser.add_argument('--contact_spacing', '-cs', type=int, default=10, 55 | help='Sets the maximum distance between two' 56 | ' buildings, in source file units, that will be' 57 | ' identified as a contact. Defaults to 10.') 58 | parser.add_argument('--metric_widths', '-m', action='store_true', 59 | default=False, help='Use this flag if any widths ' 60 | '(--contact-spacing specifically) should be in metric ' 61 | 'units instead of pixel units.') 62 | parser.add_argument('--batch', '-b', action='store_true', default=False, 63 | help='Use this flag if you wish to operate on' 64 | ' multiple files in batch. In this case,' 65 | ' --argument-csv must be provided. See help' 66 | ' for --argument_csv and the codebase docs at' 67 | ' https://solaris.readthedocs.io for more info.') 68 | parser.add_argument('--argument_csv', '-a', type=str, 69 | help='The reference file for variable values for' 70 | ' batch processing. It must contain columns to pass' 71 | ' the source_file and reference_image arguments, and' 72 | ' can additionally contain columns providing the' 73 | ' footprint_column and decimal_precision arguments' 74 | ' if you wish to define them differently for items' 75 | ' in the batch. These columns must have the same' 76 | ' names as the corresponding arguments. See the ' 77 | ' usage recipes at https://solaris.readthedocs.io' 78 | ' for examples.') 79 | parser.add_argument('--workers', '-w', type=int, default=1, 80 | help='The number of parallel processing workers to' 81 | ' use. 
This should not exceed the number of CPU' 82 | ' cores available.') 83 | 84 | args = parser.parse_args() 85 | 86 | if args.batch and args.argument_csv is None: 87 | raise ValueError( 88 | 'To perform batch processing, you must provide both --batch and' + 89 | ' --argument_csv.') 90 | if args.argument_csv is None and args.source_file is None: 91 | raise ValueError( 92 | 'You must provide a source file using either --source_file or' + 93 | ' --argument_csv.') 94 | if args.argument_csv is None and args.reference_image is None: 95 | raise ValueError( 96 | 'You must provide a reference image using either' + 97 | ' --reference_image or --argument_csv.') 98 | if not args.footprint and not args.edge and not args.contact: 99 | raise ValueError( 100 | 'You must specify --footprint, --edge, and/or --contact. See' + 101 | ' make_masks --help or the CLI documentation at' + 102 | ' cw-geodata.readthedocs.io.') 103 | 104 | if args.argument_csv is not None: 105 | arg_df = pd.read_csv(args.argument_csv) 106 | else: 107 | arg_df = pd.DataFrame({}) 108 | 109 | # generate the channels argument for df_to_px_mask 110 | channels = [] 111 | if args.footprint: 112 | channels.append('footprint') 113 | if args.edge: 114 | channels.append('boundary') 115 | if args.contact: 116 | channels.append('contact') 117 | if len(arg_df) < 2: 118 | arg_df['channels'] = [channels] 119 | else: 120 | arg_df['channels'] = [channels]*len(arg_df) # all channels in each row 121 | 122 | if args.batch: 123 | if args.source_file is not None: 124 | arg_df['source_file'] = args.source_file 125 | if args.reference_image is not None: 126 | arg_df['reference_image'] = args.reference_image 127 | if args.output_path is not None and not args.batch: 128 | arg_df['output_path'] = args.output_path 129 | if args.geometry_column is not None: 130 | arg_df['geometry_column'] = args.geometry_column 131 | if args.transform: 132 | arg_df['transform'] = True 133 | if 'value' not in arg_df.columns: 134 | arg_df['value'] = args.value 135 | if 'edge_width' not in arg_df.columns: 136 | arg_df['edge_width'] = args.edge_width 137 | if 'edge_type' not in arg_df.columns: 138 | arg_df['edge_type'] = args.edge_type 139 | if 'contact_spacing' not in arg_df.columns: 140 | arg_df['contact_spacing'] = args.contact_spacing 141 | else: 142 | arg_df['source_file'] = [args.source_file] 143 | arg_df['reference_image'] = [args.reference_image] 144 | arg_df['output_path'] = [args.output_path] 145 | arg_df['geometry_column'] = [args.geometry_column] 146 | arg_df['transform'] = [args.transform] 147 | arg_df['metric'] = [args.metric_widths] 148 | arg_df['value'] = [args.value] 149 | arg_df['edge_width'] = [args.edge_width] 150 | arg_df['edge_type'] = [args.edge_type] 151 | arg_df['contact_spacing'] = [args.contact_spacing] 152 | 153 | # rename arguments to match API 154 | arg_df = arg_df.rename(columns={'source_file': 'df', 155 | 'output_path': 'out_file', 156 | 'reference_image': 'reference_im', 157 | 'geometry_column': 'geom_col', 158 | 'transform': 'do_transform', 159 | 'value': 'burn_value', 160 | 'edge_width': 'boundary_width', 161 | 'edge_type': 'boundary_type'}) 162 | 163 | arg_dict_list = arg_df[['df', 'out_file', 'reference_im', 'geom_col', 164 | 'do_transform', 'channels', 'burn_value', 165 | 'boundary_width', 'boundary_type', 166 | 'contact_spacing'] 167 | ].to_dict(orient='records') 168 | if not args.batch: 169 | result = df_to_px_mask(**arg_dict_list[0]) 170 | if not args.output_path: 171 | return result 172 | 173 | else: 174 | with Pool(processes=args.workers) as 
pool: 175 | result = tqdm(pool.starmap(_func_wrapper, 176 | zip(repeat(df_to_px_mask), 177 | arg_dict_list))) 178 | pool.close() 179 | 180 | 181 | if __name__ == '__main__': 182 | main() 183 | -------------------------------------------------------------------------------- /libs/solaris/bin/mask_to_polygons.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aws-samples/aws-open-data-satellite-lidar-tutorial/928196f105df202e04d5bcdc64ad449cf65b183d/libs/solaris/bin/mask_to_polygons.py -------------------------------------------------------------------------------- /libs/solaris/bin/solaris_run_ml.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import pandas as pd 4 | from ..utils.config import parse 5 | from ..nets.train import Trainer 6 | from ..nets.infer import Inferer 7 | 8 | 9 | def main(): 10 | 11 | parser = argparse.ArgumentParser( 12 | description='Run a Solaris ML pipeline based on a config YAML', 13 | argument_default=None) 14 | 15 | parser.add_argument('--config', '-c', type=str, required=True, 16 | help="Full path to a YAML-formatted config file " 17 | "specifying parameters for model training and/or " 18 | "inference.") 19 | 20 | args = parser.parse_args() 21 | 22 | if not os.path.exists(args.config): 23 | raise ValueError('The configuration file cannot be found at the path ' 24 | 'specified.') 25 | 26 | config = parse(args.config) 27 | 28 | if config['train']: 29 | trainer = Trainer(config) 30 | trainer.train() 31 | if config['infer']: 32 | inferer = Inferer(config) 33 | inf_df = pd.read_csv(config['inference_data_csv']) 34 | inferer(inf_df) 35 | 36 | 37 | if __name__ == '__main__': 38 | main() 39 | -------------------------------------------------------------------------------- /libs/solaris/bin/spacenet_eval.py: -------------------------------------------------------------------------------- 1 | """Script for executing eval for SpaceNet challenges.""" 2 | from ..eval.challenges import off_nadir_buildings 3 | from ..eval.challenges import spacenet_buildings_2 4 | import argparse 5 | import pandas as pd 6 | supported_challenges = ['off-nadir', 'spacenet-buildings2'] 7 | # , 'spaceNet-buildings1', 'spacenet-roads1', 'buildings', 'roads'] 8 | 9 | 10 | def main(): 11 | parser = argparse.ArgumentParser( 12 | description='Evaluate SpaceNet Competition CSVs') 13 | parser.add_argument('--proposal_csv', '-p', type=str, 14 | help='Proposal CSV') 15 | parser.add_argument('--truth_csv', '-t', type=str, 16 | help='Truth CSV') 17 | parser.add_argument('--challenge', '-c', type=str, 18 | default='off-nadir', 19 | choices=supported_challenges, 20 | help='SpaceNet Challenge eval type') 21 | parser.add_argument('--output_file', '-o', type=str, 22 | default='Off-Nadir', 23 | help='Output file To write results to CSV') 24 | args = parser.parse_args() 25 | 26 | truth_file = args.truth_csv 27 | prop_file = args.proposal_csv 28 | 29 | if args.challenge.lower() == 'off-nadir': 30 | evalSettings = {'miniou': 0.5, 31 | 'min_area': 20} 32 | results_DF, results_DF_Full = off_nadir_buildings( 33 | prop_csv=prop_file, truth_csv=truth_file, **evalSettings) 34 | elif args.challenge.lower() == 'spaceNet-buildings2'.lower(): 35 | evalSettings = {'miniou': 0.5, 36 | 'min_area': 20} 37 | results_DF, results_DF_Full = spacenet_buildings_2( 38 | prop_csv=prop_file, truth_csv=truth_file, **evalSettings) 39 | 40 | with pd.option_context('display.max_rows', None, 41 | 
'display.max_columns', None): 42 | print(results_DF) 43 | 44 | if args.output_file: 45 | print("Writing summary results to {}".format( 46 | args.output_file.rstrip('.csv') + '.csv')) 47 | results_DF.to_csv(args.output_file.rstrip('.csv') + '.csv', 48 | index=False) 49 | print("Writing full results to {}".format( 50 | args.output_file.rstrip('.csv')+"_full.csv")) 51 | results_DF_Full.to_csv(args.output_file.rstrip('.csv')+"_full.csv", 52 | index=False) 53 | 54 | 55 | if __name__ == '__main__': 56 | main() 57 | -------------------------------------------------------------------------------- /libs/solaris/data/__init__.py: -------------------------------------------------------------------------------- 1 | import os 2 | import pandas as pd 3 | import geopandas as gpd 4 | import gdal 5 | import rasterio 6 | 7 | from . import coco 8 | 9 | # define the current directory as `data_dir` 10 | data_dir = os.path.abspath(os.path.dirname(__file__)) 11 | 12 | 13 | def load_geojson(gj_fname): 14 | """Load a geojson into a gdf using GeoPandas.""" 15 | return gpd.read_file(os.path.join(data_dir, gj_fname)) 16 | 17 | 18 | def gt_gdf(): 19 | """Load in a ground truth GDF example.""" 20 | return load_geojson('gt.geojson') 21 | 22 | 23 | def pred_gdf(): 24 | """Load in an example prediction GDF.""" 25 | return load_geojson('pred.geojson') 26 | 27 | 28 | def sample_load_rasterio(): 29 | return rasterio.open(os.path.join(data_dir, 'sample_geotiff.tif')) 30 | 31 | 32 | def sample_load_gdal(): 33 | return gdal.Open(os.path.join(data_dir, 'sample_geotiff.tif')) 34 | 35 | 36 | def sample_load_geojson(): 37 | return gpd.read_file(os.path.join(data_dir, 'sample.geojson')) 38 | 39 | 40 | def sample_load_csv(): 41 | return pd.read_file(os.path.join(data_dir, 'sample.csv')) 42 | -------------------------------------------------------------------------------- /libs/solaris/eval/__init__.py: -------------------------------------------------------------------------------- 1 | from . import base, iou, challenges, pixel 2 | -------------------------------------------------------------------------------- /libs/solaris/eval/challenges.py: -------------------------------------------------------------------------------- 1 | import pandas as pd 2 | from .base import Evaluator 3 | import re 4 | 5 | 6 | def spacenet_buildings_2(prop_csv, truth_csv, miniou=0.5, min_area=20, challenge='spacenet_2'): 7 | """Evaluate a SpaceNet building footprint competition proposal csv. 8 | 9 | Uses :class:`Evaluator` to evaluate SpaceNet challenge proposals. 10 | 11 | Arguments 12 | --------- 13 | prop_csv : str 14 | Path to the proposal polygon CSV file. 15 | truth_csv : str 16 | Path to the ground truth polygon CSV file. 17 | miniou : float, optional 18 | Minimum IoU score between a region proposal and ground truth to define 19 | as a successful identification. Defaults to 0.5. 20 | min_area : float or int, optional 21 | Minimum area of ground truth regions to include in scoring calculation. 22 | Defaults to ``20``. 23 | challenge: str, optional 24 | The challenge id for evaluation. 25 | One of 26 | ``['spacenet_2', 'spacenet_3', 'spacenet_off_nadir', 'spacenet_6']``. 27 | The name of the challenge that `chip_name` came from. Defaults to 28 | ``'spacenet_2'``. 29 | 30 | Returns 31 | ------- 32 | 33 | results_DF, results_DF_Full 34 | 35 | results_DF : :py:class:`pd.DataFrame` 36 | Summary :py:class:`pd.DataFrame` of score outputs grouped by nadir 37 | angle bin, along with the overall score. 
38 | 39 | results_DF_Full : :py:class:`pd.DataFrame` 40 | :py:class:`pd.DataFrame` of scores by individual image chip across 41 | the ground truth and proposal datasets. 42 | 43 | """ 44 | 45 | evaluator = Evaluator(ground_truth_vector_file=truth_csv) 46 | evaluator.load_proposal(prop_csv, 47 | conf_field_list=['Confidence'], 48 | proposalCSV=True 49 | ) 50 | results = evaluator.eval_iou_spacenet_csv(miniou=miniou, 51 | iou_field_prefix="iou_score", 52 | imageIDField="ImageId", 53 | min_area=min_area 54 | ) 55 | results_DF_Full = pd.DataFrame(results) 56 | 57 | results_DF_Full['AOI'] = [get_chip_id(imageID, challenge=challenge) 58 | for imageID in results_DF_Full['imageID'].values] 59 | 60 | results_DF = results_DF_Full.groupby(['AOI']).sum() 61 | 62 | # Recalculate Values after Summation of AOIs 63 | for indexVal in results_DF.index: 64 | rowValue = results_DF[results_DF.index == indexVal] 65 | # Precision = TruePos / float(TruePos + FalsePos) 66 | if float(rowValue['TruePos'] + rowValue['FalsePos']) > 0: 67 | Precision = float( 68 | rowValue['TruePos'] / float( 69 | rowValue['TruePos'] + rowValue['FalsePos'])) 70 | else: 71 | Precision = 0 72 | # Recall = TruePos / float(TruePos + FalseNeg) 73 | if float(rowValue['TruePos'] + rowValue['FalseNeg']) > 0: 74 | Recall = float(rowValue['TruePos'] / float( 75 | rowValue['TruePos'] + rowValue['FalseNeg'])) 76 | else: 77 | Recall = 0 78 | if Recall * Precision > 0: 79 | F1Score = 2 * Precision * Recall / (Precision + Recall) 80 | else: 81 | F1Score = 0 82 | results_DF.loc[results_DF.index == indexVal, 'Precision'] = Precision 83 | results_DF.loc[results_DF.index == indexVal, 'Recall'] = Recall 84 | results_DF.loc[results_DF.index == indexVal, 'F1Score'] = F1Score 85 | 86 | return results_DF, results_DF_Full 87 | 88 | 89 | def off_nadir_buildings(prop_csv, truth_csv, image_columns={}, miniou=0.5, 90 | min_area=20, verbose=False): 91 | """Evaluate an off-nadir competition proposal csv. 92 | 93 | Uses :class:`Evaluator` to evaluate off-nadir challenge proposals. See 94 | ``image_columns`` in the source code for how collects are broken into 95 | Nadir, Off-Nadir, and Very-Off-Nadir bins. 96 | 97 | Arguments 98 | --------- 99 | prop_csv : str 100 | Path to the proposal polygon CSV file. 101 | truth_csv : str 102 | Path to the ground truth polygon CSV file. 103 | image_columns : dict, optional 104 | dict of ``(collect: nadir bin)`` pairs used to separate collects into 105 | sets. Nadir bin values must be one of 106 | ``["Nadir", "Off-Nadir", "Very-Off-Nadir"]`` . See source code for 107 | collect name options. 108 | miniou : float, optional 109 | Minimum IoU score between a region proposal and ground truth to define 110 | as a successful identification. Defaults to 0.5. 111 | min_area : float or int, optional 112 | Minimum area of ground truth regions to include in scoring calculation. 113 | Defaults to ``20``. 114 | 115 | Returnss 116 | ------- 117 | 118 | results_DF, results_DF_Full 119 | 120 | results_DF : :py:class:`pd.DataFrame` 121 | Summary :py:class:`pd.DataFrame` of score outputs grouped by nadir 122 | angle bin, along with the overall score. 123 | 124 | results_DF_Full : :py:class:`pd.DataFrame` 125 | :py:class:`pd.DataFrame` of scores by individual image chip across 126 | the ground truth and proposal datasets. 
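A minimal call sketch, mirroring how spacenet_eval.py earlier in this listing invokes this function (the CSV paths are placeholder names; the keyword values simply restate the documented defaults):

        results_DF, results_DF_Full = off_nadir_buildings(
            prop_csv='proposals.csv', truth_csv='ground_truth.csv',
            miniou=0.5, min_area=20)
        print(results_DF[['Precision', 'Recall', 'F1Score']])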
127 | 128 | """ 129 | 130 | evaluator = Evaluator(ground_truth_vector_file=truth_csv) 131 | evaluator.load_proposal(prop_csv, 132 | conf_field_list=['Confidence'], 133 | proposalCSV=True 134 | ) 135 | results = evaluator.eval_iou_spacenet_csv(miniou=miniou, 136 | iou_field_prefix="iou_score", 137 | imageIDField="ImageId", 138 | min_area=min_area 139 | ) 140 | results_DF_Full = pd.DataFrame(results) 141 | 142 | if not image_columns: 143 | image_columns = { 144 | 'Atlanta_nadir7_catid_1030010003D22F00': "Nadir", 145 | 'Atlanta_nadir8_catid_10300100023BC100': "Nadir", 146 | 'Atlanta_nadir10_catid_1030010003993E00': "Nadir", 147 | 'Atlanta_nadir10_catid_1030010003CAF100': "Nadir", 148 | 'Atlanta_nadir13_catid_1030010002B7D800': "Nadir", 149 | 'Atlanta_nadir14_catid_10300100039AB000': "Nadir", 150 | 'Atlanta_nadir16_catid_1030010002649200': "Nadir", 151 | 'Atlanta_nadir19_catid_1030010003C92000': "Nadir", 152 | 'Atlanta_nadir21_catid_1030010003127500': "Nadir", 153 | 'Atlanta_nadir23_catid_103001000352C200': "Nadir", 154 | 'Atlanta_nadir25_catid_103001000307D800': "Nadir", 155 | 'Atlanta_nadir27_catid_1030010003472200': "Off-Nadir", 156 | 'Atlanta_nadir29_catid_1030010003315300': "Off-Nadir", 157 | 'Atlanta_nadir30_catid_10300100036D5200': "Off-Nadir", 158 | 'Atlanta_nadir32_catid_103001000392F600': "Off-Nadir", 159 | 'Atlanta_nadir34_catid_1030010003697400': "Off-Nadir", 160 | 'Atlanta_nadir36_catid_1030010003895500': "Off-Nadir", 161 | 'Atlanta_nadir39_catid_1030010003832800': "Off-Nadir", 162 | 'Atlanta_nadir42_catid_10300100035D1B00': "Very-Off-Nadir", 163 | 'Atlanta_nadir44_catid_1030010003CCD700': "Very-Off-Nadir", 164 | 'Atlanta_nadir46_catid_1030010003713C00': "Very-Off-Nadir", 165 | 'Atlanta_nadir47_catid_10300100033C5200': "Very-Off-Nadir", 166 | 'Atlanta_nadir49_catid_1030010003492700': "Very-Off-Nadir", 167 | 'Atlanta_nadir50_catid_10300100039E6200': "Very-Off-Nadir", 168 | 'Atlanta_nadir52_catid_1030010003BDDC00': "Very-Off-Nadir", 169 | 'Atlanta_nadir53_catid_1030010003193D00': "Very-Off-Nadir", 170 | 'Atlanta_nadir53_catid_1030010003CD4300': "Very-Off-Nadir", 171 | } 172 | 173 | results_DF_Full['nadir-category'] = [ 174 | image_columns[get_chip_id(imageID, challenge='spacenet_off_nadir')] 175 | for imageID in results_DF_Full['imageID'].values] 176 | 177 | results_DF = results_DF_Full.groupby(['nadir-category']).sum() 178 | 179 | # Recalculate Values after Summation of AOIs 180 | for indexVal in results_DF.index: 181 | if verbose: 182 | print(indexVal) 183 | rowValue = results_DF[results_DF.index == indexVal] 184 | # Precision = TruePos / float(TruePos + FalsePos) 185 | if float(rowValue['TruePos'] + rowValue['FalsePos']) > 0: 186 | Precision = float( 187 | rowValue['TruePos'] / float(rowValue['TruePos'] + 188 | rowValue['FalsePos']) 189 | ) 190 | else: 191 | Precision = 0 192 | # Recall = TruePos / float(TruePos + FalseNeg) 193 | if float(rowValue['TruePos'] + rowValue['FalseNeg']) > 0: 194 | Recall = float(rowValue['TruePos'] / float(rowValue['TruePos'] + 195 | rowValue['FalseNeg'])) 196 | else: 197 | Recall = 0 198 | if Recall * Precision > 0: 199 | F1Score = 2 * Precision * Recall / (Precision + Recall) 200 | else: 201 | F1Score = 0 202 | results_DF.loc[results_DF.index == indexVal, 'Precision'] = Precision 203 | results_DF.loc[results_DF.index == indexVal, 'Recall'] = Recall 204 | results_DF.loc[results_DF.index == indexVal, 'F1Score'] = F1Score 205 | 206 | return results_DF, results_DF_Full 207 | 208 | 209 | def get_chip_id(chip_name, challenge="spacenet_2"): 210 | """Get 
the unique identifier for a chip location from SpaceNet images. 211 | 212 | Arguments 213 | --------- 214 | chip_name: str 215 | The name of the chip to extract the identifier from. 216 | challenge: str, optional 217 | One of 218 | ``['spacenet_2', 'spacenet_3', 'spacenet_off_nadir', 'spacenet_6']``. 219 | The name of the challenge that `chip_name` came from. Defaults to 220 | ``'spacenet_2'``. 221 | 222 | Returns 223 | ------- 224 | chip_id : str 225 | The unique identifier for the chip location. 226 | """ 227 | # AS NEW CHALLENGES ARE ADDED, ADD THE CHIP MATCHING FUNCTIONALITY WITHIN 228 | # THIS FUNCTION. 229 | if challenge in ['spacenet_2', 'spacenet_3']: 230 | chip_id = '_'.join(chip_name.split('_')[:-1]) 231 | elif challenge == 'spacenet_off_nadir': 232 | chip_id = re.findall('Atlanta_nadir[0-9]{1,2}_catid_[0-9A-Z]{16}', 233 | chip_name)[0] 234 | elif challenge == 'spacenet_6': 235 | chip_id = '_'.join(chip_name.split('_')[-4:]).split(".")[0] 236 | 237 | return chip_id 238 | -------------------------------------------------------------------------------- /libs/solaris/eval/iou.py: -------------------------------------------------------------------------------- 1 | import geopandas as gpd 2 | 3 | 4 | def calculate_iou(pred_poly, test_data_GDF): 5 | """Get the best intersection over union for a predicted polygon. 6 | 7 | Arguments 8 | --------- 9 | pred_poly : :py:class:`shapely.Polygon` 10 | Prediction polygon to test. 11 | test_data_GDF : :py:class:`geopandas.GeoDataFrame` 12 | GeoDataFrame of ground truth polygons to test ``pred_poly`` against. 13 | 14 | Returns 15 | ------- 16 | iou_GDF : :py:class:`geopandas.GeoDataFrame` 17 | A subset of ``test_data_GDF`` that overlaps ``pred_poly`` with an added 18 | column ``iou_score`` which indicates the intersection over union value. 19 | 20 | """ 21 | 22 | # Fix bowties and self-intersections 23 | if not pred_poly.is_valid: 24 | pred_poly = pred_poly.buffer(0.0) 25 | 26 | precise_matches = test_data_GDF[test_data_GDF.intersects(pred_poly)] 27 | 28 | iou_row_list = [] 29 | for _, row in precise_matches.iterrows(): 30 | # Load ground truth polygon and check exact iou 31 | test_poly = row.geometry 32 | # Ignore invalid polygons for now 33 | if pred_poly.is_valid and test_poly.is_valid: 34 | intersection = pred_poly.intersection(test_poly).area 35 | union = pred_poly.union(test_poly).area 36 | # Calculate iou 37 | iou_score = intersection / float(union) 38 | else: 39 | iou_score = 0 40 | row['iou_score'] = iou_score 41 | iou_row_list.append(row) 42 | 43 | iou_GDF = gpd.GeoDataFrame(iou_row_list) 44 | return iou_GDF 45 | 46 | 47 | def process_iou(pred_poly, test_data_GDF, remove_matching_element=True): 48 | """Get the maximum intersection over union score for a predicted polygon. 49 | 50 | Arguments 51 | --------- 52 | pred_poly : :py:class:`shapely.geometry.Polygon` 53 | Prediction polygon to test. 54 | test_data_GDF : :py:class:`geopandas.GeoDataFrame` 55 | GeoDataFrame of ground truth polygons to test ``pred_poly`` against. 56 | remove_matching_element : bool, optional 57 | Should the maximum IoU row be dropped from ``test_data_GDF``? Defaults 58 | to ``True``. 
59 | 60 | Returns 61 | ------- 62 | *This function doesn't currently return anything.* 63 | 64 | """ 65 | 66 | iou_GDF = calculate_iou(pred_poly, test_data_GDF) 67 | 68 | max_iou_row = iou_GDF.loc[iou_GDF['iou_score'].idxmax(axis=0, skipna=True)] 69 | 70 | if remove_matching_element: 71 | test_data_GDF.drop(max_iou_row.name, axis=0, inplace=True) 72 | 73 | # Prediction poly had no overlap with anything 74 | # if not iou_list: 75 | # return max_iou_row, 0, test_data_DF 76 | # else: 77 | # max_iou_idx, max_iou = max(iou_list, key=lambda x: x[1]) 78 | # # Remove ground truth polygon from tree 79 | # test_tree.delete(max_iou_idx, Polygon(test_data[max_iou_idx]['geometry']['coordinates'][0]).bounds) 80 | # return max_iou_row['iou_score'], iou_GDF, test_data_DF 81 | -------------------------------------------------------------------------------- /libs/solaris/nets/__init__.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | weights_dir = os.path.join(os.path.abspath(os.path.dirname(__file__)), 4 | 'weights') 5 | 6 | from . import callbacks, datagen, infer, losses, metrics, model_io 7 | from . import optimizers, train, transform, zoo 8 | 9 | 10 | if not os.path.isdir(weights_dir): 11 | os.mkdir(weights_dir) 12 | -------------------------------------------------------------------------------- /libs/solaris/nets/_keras_losses.py: -------------------------------------------------------------------------------- 1 | from tensorflow.keras import losses 2 | from tensorflow.keras import backend as K 3 | from .metrics import dice_coef_binary 4 | import tensorflow as tf 5 | 6 | 7 | def k_dice_loss(y_true, y_pred): 8 | return 1 - dice_coef_binary(y_true, y_pred) 9 | 10 | 11 | def k_jaccard_loss(y_true, y_pred): 12 | """Jaccard distance for semantic segmentation. 13 | 14 | Modified from the `keras-contrib` package. 15 | 16 | """ 17 | eps = 1e-12 # for stability 18 | y_pred = K.clip(y_pred, eps, 1-eps) 19 | intersection = K.sum(K.abs(y_true*y_pred), axis=-1) 20 | sum_ = K.sum(K.abs(y_true) + K.abs(y_pred), axis=-1) 21 | jac = intersection/(sum_ - intersection) 22 | return 1 - jac 23 | 24 | 25 | def k_focal_loss(gamma=2, alpha=0.75): 26 | # from github.com/atomwh/focalloss_keras 27 | 28 | def focal_loss_fixed(y_true, y_pred): # with tensorflow 29 | 30 | eps = 1e-12 # improve the stability of the focal loss 31 | y_pred = K.clip(y_pred, eps, 1.-eps) 32 | pt_1 = tf.where(tf.equal(y_true, 1), y_pred, tf.ones_like(y_pred)) 33 | pt_0 = tf.where(tf.equal(y_true, 0), y_pred, tf.zeros_like(y_pred)) 34 | return -K.sum( 35 | alpha * K.pow(1. - pt_1, gamma) * K.log(pt_1))-K.sum( 36 | (1-alpha) * K.pow(pt_0, gamma) * K.log(1. - pt_0)) 37 | 38 | return focal_loss_fixed 39 | 40 | 41 | def k_lovasz_hinge(per_image=False): 42 | """Wrapper for the Lovasz Hinge Loss Function, for use in Keras. 43 | 44 | This is a mess. I'm sorry. 45 | """ 46 | 47 | def lovasz_hinge_flat(y_true, y_pred): 48 | # modified from Maxim Berman's GitHub repo tf implementation for Lovasz 49 | eps = 1e-12 # for stability 50 | y_pred = K.clip(y_pred, eps, 1-eps) 51 | logits = K.log(y_pred/(1-y_pred)) 52 | logits = tf.reshape(logits, (-1,)) 53 | y_true = tf.reshape(y_true, (-1,)) 54 | y_true = tf.cast(y_true, logits.dtype) 55 | signs = 2. * y_true - 1. 56 | errors = 1. 
- logits * tf.stop_gradient(signs) 57 | errors_sorted, perm = tf.nn.top_k(errors, k=tf.shape(errors)[0], 58 | name="descending_sort") 59 | gt_sorted = tf.gather(y_true, perm) 60 | grad = tf_lovasz_grad(gt_sorted) 61 | loss = tf.tensordot(tf.nn.relu(errors_sorted), 62 | tf.stop_gradient(grad), 63 | 1, name="loss_non_void") 64 | return loss 65 | 66 | def lovasz_hinge_per_image(y_true, y_pred): 67 | # modified from Maxim Berman's GitHub repo tf implementation for Lovasz 68 | losses = tf.map_fn(_treat_image, (y_true, y_pred), dtype=tf.float32) 69 | loss = tf.reduce_mean(losses) 70 | return loss 71 | 72 | def _treat_image(ytrue_ypred): 73 | y_true, y_pred = ytrue_ypred 74 | y_true, y_pred = tf.expand_dims(y_true, 0), tf.expand_dims(y_pred, 0) 75 | return lovasz_hinge_flat(y_true, y_pred) 76 | 77 | if per_image: 78 | return lovasz_hinge_per_image 79 | else: 80 | return lovasz_hinge_flat 81 | 82 | 83 | def tf_lovasz_grad(gt_sorted): 84 | """ 85 | Code from Maxim Berman's GitHub repo for Lovasz. 86 | 87 | Computes gradient of the Lovasz extension w.r.t sorted errors 88 | See Alg. 1 in paper 89 | """ 90 | gts = tf.reduce_sum(gt_sorted) 91 | intersection = gts - tf.cumsum(gt_sorted) 92 | union = gts + tf.cumsum(1. - gt_sorted) 93 | jaccard = 1. - intersection / union 94 | jaccard = tf.concat((jaccard[0:1], jaccard[1:] - jaccard[:-1]), 0) 95 | return jaccard 96 | 97 | 98 | # matching dicts to get the right loss function based on the config file 99 | keras_losses = { 100 | 'binary_crossentropy': losses.binary_crossentropy, 101 | 'bce': losses.binary_crossentropy, 102 | 'categorical_crossentropy': losses.categorical_crossentropy, 103 | 'cce': losses.categorical_crossentropy, 104 | 'cosine': losses.cosine_similarity, 105 | 'hinge': losses.hinge, 106 | 'kullback_leibler_divergence': losses.kullback_leibler_divergence, 107 | 'kld': losses.kullback_leibler_divergence, 108 | 'mean_absolute_error': losses.mean_absolute_error, 109 | 'mae': losses.mean_absolute_error, 110 | 'mean_squared_logarithmic_error': losses.mean_squared_logarithmic_error, 111 | 'msle': losses.mean_squared_logarithmic_error, 112 | 'mean_squared_error': losses.mean_squared_error, 113 | 'mse': losses.mean_squared_error, 114 | 'sparse_categorical_crossentropy': losses.sparse_categorical_crossentropy, 115 | 'squared_hinge': losses.squared_hinge, 116 | 'jaccard': k_jaccard_loss, 117 | 'dice': k_dice_loss 118 | } 119 | 120 | 121 | def k_weighted_bce(y_true, y_pred, weight): 122 | """Weighted binary cross-entropy for Keras. 123 | 124 | Arguments: 125 | ---------- 126 | y_true : ``tf.Tensor`` 127 | passed silently by Keras during model training. 128 | y_pred : ``tf.Tensor`` 129 | passed silently by Keras during model training. 130 | weight : :class:`float` or :class:`int` 131 | Weight to assign to mask foreground pixels. Use values 132 | >1 to over-weight foreground or 0 1: 174 | class_two = class_two*(weight-1) 175 | final_mask = weight_mask + class_two # foreground pixels weighted 176 | return K.binary_crossentropy(y_pred, y_true) * final_mask 177 | 178 | 179 | def k_layered_weighted_bce(y_true, y_pred, weights): 180 | """Binary cross-entropy function with different weights for mask channels. 181 | 182 | Arguments: 183 | ---------- 184 | y_true (tensor): passed silently by Keras during model training. 185 | y_pred (tensor): passed silently by Keras during model training. 186 | weights (list-like): Weights to assign to mask foreground pixels for each 187 | channel in the 3rd axis of the mask. 
188 | 189 | Returns: 190 | -------- 191 | The binary cross-entropy loss function output multiplied by a weighting 192 | mask. 193 | 194 | Usage: 195 | ------ 196 | See implementation instructions for `weighted_bce`. 197 | 198 | This loss function is intended to allow different weighting of different 199 | segmentation outputs - for example, if a model outputs a 3D image mask, 200 | where the first channel corresponds to foreground objects and the second 201 | channel corresponds to object edges. `weights` must be a list of length 202 | equal to the depth of the output mask. The output mask's "z-axis" 203 | corresponding to the mask channel must be the third axis in the output 204 | array. 205 | 206 | """ 207 | weight_mask = K.ones_like(y_true) 208 | submask_list = [] 209 | for i in range(len(weights)): 210 | class_two = K.equal(y_true[:, :, :, i], weight_mask[:, :, :, i]) 211 | class_two = K.cast(class_two, 'float32') 212 | if weights[i] < 1: 213 | class_two = class_two*(1-weights[i]) 214 | layer_mask = weight_mask[:, :, :, i] - class_two 215 | elif weights[i] > 1: 216 | class_two = class_two*(weights[i]-1) 217 | layer_mask = weight_mask[:, :, :, i] + class_two 218 | else: 219 | layer_mask = weight_mask[:, :, :, i] 220 | submask_list.append(layer_mask) 221 | final_mask = K.stack(submask_list, axis=-1) 222 | return K.binary_crossentropy(y_pred, y_true) * final_mask 223 | -------------------------------------------------------------------------------- /libs/solaris/nets/callbacks.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from tensorflow import keras 3 | from tensorflow.keras.callbacks import Callback 4 | from .torch_callbacks import torch_callback_dict 5 | import torch 6 | 7 | 8 | def get_callbacks(framework, config): 9 | """Load callbacks based on a config file for a specific framework. 10 | 11 | Usage 12 | ----- 13 | Note that this function is primarily intended for use with Keras. PyTorch 14 | does not use the same object-oriented training approach as Keras, and 15 | therefore doesn't generally have the same checkpointing objects to pass to 16 | model compilers - instead these are defined in model training code. See 17 | solaris.nets.train for examples of this. The only torch callback 18 | instantiated here is a learning rate scheduler. 19 | 20 | Arguments 21 | --------- 22 | framework : str 23 | Deep learning framework used for the model. Options are 24 | ``['keras', 'torch']`` . 25 | config : dict 26 | Configuration dict generated from the YAML config file. 27 | 28 | Returns 29 | ------- 30 | callbacks : list 31 | A `list` of callbacks to pass to the compiler (Keras) or to wrap the 32 | optimizer (torch learning rate scheduling) for model training. 33 | """ 34 | 35 | callbacks = [] 36 | 37 | if framework == 'keras': 38 | for callback, params in config['training']['callbacks'].items(): 39 | if callback == 'lr_schedule': 40 | callbacks.append(get_lr_schedule(framework, config)) 41 | else: 42 | callbacks.append(keras_callbacks[callback](**params)) 43 | elif framework == 'torch': 44 | for callback, params in config['training']['callbacks'].items(): 45 | if callback == 'lr_schedule': 46 | callbacks.append(get_lr_schedule(framework, config)) 47 | else: 48 | callbacks.append(torch_callback_dict[callback](**params)) 49 | 50 | return callbacks 51 | 52 | 53 | class KerasTerminateOnMetricNaN(Callback): 54 | """Callback to stop training if a metric has value NaN or infinity. 
55 | 56 | Notes 57 | ----- 58 | Instantiate as you would any other keras callback. For example, to end 59 | training if a validation metric called `f1_score` reaches value NaN:: 60 | 61 | m = Model(inputs, outputs) 62 | m.compile() 63 | m.fit(X, y, callbacks=[TerminateOnMetricNaN('val_f1_score')]) 64 | 65 | 66 | Attributes 67 | ---------- 68 | metric : str, optional 69 | Name of the metric being monitored. 70 | checkpoint : str, optional 71 | One of ``['epoch', 'batch']``: Should the metric be checked at the end 72 | of every epoch (default) or every batch? 73 | 74 | Methods 75 | ------- 76 | on_epoch_end : operations to complete at the end of each epoch. 77 | on_batch_end : operations to complete at the end of each batch. 78 | """ 79 | 80 | def __init__(self, metric=None, checkpoint='epoch'): 81 | """ 82 | 83 | Parameters 84 | ---------- 85 | metric (str): The name of the metric to be tested for NaN value. 86 | checkpoint (['epoch', 'batch']): Should the metric be checked at the end of 87 | every epoch (default) or every batch? 88 | 89 | """ 90 | super(KerasTerminateOnMetricNaN, self).__init__() 91 | self.metric = metric 92 | self.ckpt = checkpoint 93 | 94 | def on_epoch_end(self, epoch, logs=None): 95 | if self.ckpt == 'epoch': 96 | logs = logs or {} 97 | metric_score = logs.get(self.metric) 98 | if self.metric is not None: 99 | if np.isnan(metric_score) or np.isinf(metric_score): 100 | print('Epoch {}: Invalid score for metric {}, terminating' 101 | ' training'.format(epoch, self.metric)) 102 | self.model.stop_training = True 103 | 104 | def on_batch_end(self, batch, logs=None): 105 | if self.ckpt == 'batch': 106 | logs = logs or {} 107 | metric_score = logs.get(self.metric) 108 | print('metric score: {}'.format(metric_score)) 109 | if np.isnan(metric_score) or np.isinf(metric_score): 110 | print('Batch {}: Invalid score for metric' 111 | '{}, terminating training'.format(batch, self.metric)) 112 | self.model.stop_training = True 113 | 114 | 115 | keras_callbacks = { 116 | 'terminate_on_nan': keras.callbacks.TerminateOnNaN, 117 | 'terminate_on_metric_nan': KerasTerminateOnMetricNaN, 118 | 'model_checkpoint': keras.callbacks.ModelCheckpoint, 119 | 'early_stopping': keras.callbacks.EarlyStopping, 120 | 'reduce_lr_on_plateau': keras.callbacks.ReduceLROnPlateau, 121 | 'csv_logger': keras.callbacks.CSVLogger 122 | } 123 | 124 | 125 | def get_lr_schedule(framework, config): 126 | """Get a LR scheduling function for model training. 127 | 128 | Arguments 129 | --------- 130 | framework : str 131 | Deep learning framework used for the model. Options are 132 | ``['keras', 'torch']`` . 133 | config : dict 134 | Configuration dict generated from the YAML config file. 135 | 136 | Returns 137 | ------- 138 | lr_scheduler : :class:`tensorflow.keras.callbacks.LearningRateScheduler` or 139 | ``torch.optim.lr_schedule`` scheduler class 140 | A scheduler to provide during training. For Keras, this takes the form 141 | of a callback passed to the optimizer; for PyTorch, it's a class object 142 | that wraps the optimizer. Because the torch version must wrap the 143 | optimizer, it's not instantiated here - the class is returned instead. 
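For the torch case, a minimal sketch of wrapping an optimizer with the returned class (assuming the config requested the 'linear' schedule, so the returned class is torch.optim.lr_scheduler.StepLR; the optimizer, step_size, gamma, and train_one_epoch names are hypothetical stand-ins, and in practice this wiring lives in solaris.nets.train):

        scheduler_cls = get_lr_schedule('torch', config)  # e.g. torch.optim.lr_scheduler.StepLR
        scheduler = scheduler_cls(optimizer, step_size=10, gamma=0.5)
        for epoch in range(config['training']['epochs']):
            train_one_epoch(model, optimizer)  # placeholder for the actual training loop
            scheduler.step()                   # advance the LR schedule once per epoch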
144 | 145 | """ 146 | 147 | schedule_type = config['training'][ 148 | 'callbacks']['lr_schedule']['schedule_type'] 149 | initial_lr = config['training']['lr'] 150 | update_frequency = config['training']['callbacks']['lr_schedule'].get( 151 | 'update_frequency', 1) 152 | factor = config['training']['callbacks']['lr_schedule'].get( 153 | 'factor', 0) 154 | schedule_dict = config['training']['callbacks']['lr_schedule'].get( 155 | 'schedule_dict', None) 156 | if framework == 'keras': 157 | lr_sched_func = keras_lr_schedule(schedule_type, initial_lr, 158 | update_frequency, factor, 159 | schedule_dict) 160 | lr_scheduler = keras.callbacks.LearningRateScheduler(lr_sched_func) 161 | elif framework == 'torch': 162 | # just get the class itself to use; don't instantiate until the 163 | # optimizer has been created. 164 | if config['training'][ 165 | 'callbacks']['lr_schedule']['schedule_type'] == 'linear': 166 | lr_scheduler = torch.optim.lr_scheduler.StepLR 167 | elif config['training'][ 168 | 'callbacks']['lr_schedule']['schedule_type'] == 'exponential': 169 | lr_scheduler = torch.optim.lr_scheduler.ExponentialLR 170 | elif config['training'][ 171 | 'callbacks']['lr_schedule']['schedule_type'] == 'arbitrary': 172 | lr_scheduler = torch.optim.lr_scheduler.MultiStepLR 173 | 174 | return lr_scheduler 175 | 176 | 177 | def keras_lr_schedule(schedule_type, initial_lr=0.001, update_frequency=1, 178 | factor=0, schedule_dict=None): 179 | """Create a learning rate schedule for Keras from a schedule dict. 180 | 181 | Arguments 182 | --------- 183 | schedule_type : str 184 | Type of learning rate schedule to use. Options are: 185 | ``['arbitrary', 'exponential', 'linear']`` . 186 | initial_lr : float, optional 187 | The initial learning rate to use. Defaults to ``0.001`` . 188 | update_frequency : int, optional 189 | How frequently should learning rate be reduced (or increased)? Defaults 190 | to ``1`` (every epoch). Has no effect if ``schedule_type='arbitrary'``. 191 | factor : float, optional 192 | The magnitude by which learning rate should be changed at each update. 193 | Use a positive number to increase learning rate and a negative number 194 | to decrease learning rate. See Usage for more details. 195 | schedule_dict : dict, optional 196 | A dictionary with ``{epoch: learning rate}`` pairs. The learning rate 197 | defined in each pair will be used beginning at the specified epoch and 198 | continuing until the next highest epoch number is reached during 199 | training. 200 | 201 | Returns 202 | ------- 203 | lr_schedule : func 204 | a function that takes epoch number integers as an argument and returns 205 | a learning rate. 206 | 207 | Usage 208 | ----- 209 | ``schedule_type='arbitrary'`` usage is documented in the arguments above. 210 | For ``schedule_type='exponential'``, the following equation is applied to 211 | determine the learning rate at each epoch: 212 | 213 | .. math:: 214 | 215 | lr = initial_lr*e^{factor\times(floor(epoch/update_frequency))} 216 | 217 | For ``schedule_type='linear'``, the following equation is applied: 218 | 219 | .. 
math:: 220 | 221 | lr = initial_lr\times(1+factor\times(floor(epoch/update_frequency))) 222 | 223 | """ 224 | if schedule_type == 'arbitrary': 225 | if schedule_dict is None: 226 | raise ValueError('If using an arbitrary schedule, an epoch: lr ' 227 | 'dict must be provided.') 228 | lookup_dict = {} 229 | epoch_vals = np.array(list(schedule_dict.keys())) 230 | for e in range(0, epoch_vals.max() + 1): 231 | if e < epoch_vals.min(): 232 | lookup_dict[e] = schedule_dict[epoch_vals.min()] 233 | else: 234 | # get all the epochs from the dict <= e 235 | lower_epochs = epoch_vals[epoch_vals <= e] 236 | # get the LR for the highest epoch number <= e 237 | lookup_dict[e] = schedule_dict[lower_epochs.max()] 238 | 239 | def lr_schedule(epoch): 240 | if epoch < epoch_vals.min(): 241 | return initial_lr 242 | elif epoch > epoch_vals.max(): 243 | return lookup_dict[epoch_vals.max()] 244 | else: 245 | return lookup_dict[epoch] 246 | 247 | elif schedule_type == 'exponential': 248 | def lr_schedule(epoch): 249 | if not np.floor(epoch/update_frequency): 250 | return initial_lr 251 | else: 252 | return initial_lr*factor/np.floor(epoch/update_frequency) 253 | 254 | elif schedule_type == 'linear': 255 | def lr_schedule(epoch): 256 | return initial_lr*(1+factor*np.floor(epoch/update_frequency)) 257 | 258 | return lr_schedule 259 | -------------------------------------------------------------------------------- /libs/solaris/nets/configs/config_skeleton.yml: -------------------------------------------------------------------------------- 1 | ################################################################################ 2 | ################# SOLARIS MODEL CONFIGURATION SKELETON ######################### 3 | ################################################################################ 4 | 5 | # This skeleton lays out the required instructions for running a model using 6 | # solaris. See the full documentation at [INCLUDE DOC LINK HERE] for details on 7 | # options, required arguments, and sample usage. 8 | 9 | model_name: # include the name of the model to be used here. See the docs 10 | # for options. 11 | model_path: # leave this blank unless you're using a custom model not 12 | # native to solaris. solaris will automatically find your 13 | # model. 14 | train: true # set to false for inference only 15 | infer: true # set to false for training only 16 | 17 | pretrained: true # use pretrained weights associated with the model? 18 | nn_framework: # if not using a model included with the package, use this 19 | # argument to specify if it uses keras, pytorch, tf, etc. 20 | batch_size: # size of each batch fed into nn. 21 | 22 | data_specs: 23 | width: # width of the input images taken in by the neural net. 24 | height: # height of the input images taken in by the neural net. 25 | dtype: # dtype of the inputs ingested by the neural net. 26 | rescale: false # should image pixel values be rescaled before pre-processing? 27 | # If so, the image will be rescaled to the pixel range defined 28 | # by rescale_min and rescale_max below. 29 | rescale_minima: auto # the minimum values to use in rescaling (if 30 | # rescale=true). If 'auto', the minimum pixel intensity 31 | # in each channel of the image will be subtracted. If 32 | # a single value is provided, that value will be set to 33 | # zero for each channel in the input image. 34 | # if a list of values are provided, then minima in the 35 | # separate channels (in that order) will be set to that 36 | # value PRIOR to any augmentation. 
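# For illustration only (hypothetical channel values, not skeleton defaults),
# the three accepted forms described above would look like:
# rescale_minima: auto        # subtract each channel's own minimum intensity
# rescale_minima: 0           # a single value: that value becomes zero in every channel
# rescale_minima:             # a list: one minimum per channel, in channel order
#   - 0
#   - 0
#   - 5
#   - 10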
37 | rescale_maxima: auto # same as rescale_minima, but for the maximum value for 38 | # each channel in the image. 39 | channels: # number of channels in the input imagery. 40 | label_type: mask # one of ['mask', 'bbox'] (CURRENTLY ONLY MASK IMPLEMENTED) 41 | is_categorical: false # are the labels binary (default) or categorical? 42 | mask_channels: 1 # number of channels in the training mask 43 | val_holdout_frac: # if empty, assumes that separate data ref files define the 44 | # training and validation dataset. If a float between 0 and 45 | # 1, indicates the fraction of training data that's held 46 | # out for validation (and validation_data_csv will be 47 | # ignored) 48 | data_workers: # number of cpu threads to use for loading and preprocessing 49 | # input images. 50 | # other_inputs: # this can provide a list of additional inputs to pass to the 51 | # neural net for training. These inputs should be specified in 52 | # extra columns of the csv files (denoted below), either as 53 | # filepaths to additional data to load or as values to include. 54 | # NOTE: This is not currently implemented. 55 | 56 | training_data_csv: # path to the reference csv that defines training data. 57 | # see the documentation for the specifications of this file. 58 | validation_data_csv: # path to the validation ref csv. See the docs. If 59 | # val_holdout_frac is specified (under data_specs), then 60 | # this argument will be ignored. 61 | inference_data_csv: # path to the reference csv that defines inference data. 62 | # see the documentation for the specs of this file. 63 | 64 | training_augmentation: # augmentations for use with training data 65 | augmentations: 66 | # include augmentations here. See the documentation for options and 67 | # required arguments. 68 | p: 1.0 # probability of applying the entire training augmentation pipeline. 69 | shuffle: true # should the image order be shuffled in each epoch. 70 | validation_augmentation: # augmentations for use with validation data 71 | augmentations: 72 | # include augmentations here 73 | p: # probability of applying the full validation augmentation pipeline. 74 | inference_augmentation: # this is optional. If not provided, 75 | # validation_augmentation will be used instead. 76 | 77 | training: 78 | epochs: # number of epochs. A list can also be provided here indicating 79 | # distinct sets of epochs at different learning rates, etc; if so, 80 | # a list of equal length must be provided in the parameter that varies 81 | # with the values for each set of epochs. 82 | steps_per_epoch: # optional argument defining # steps/epoch. If not provided, 83 | # each epoch will include the number of steps needed to go 84 | # through the entire training dataset. 85 | optimizer: # optimizer function name. see docs for options. 86 | lr: # learning rate. 87 | opt_args: # dictionary of values (e.g. alpha, gamma, momentum) specific to 88 | # the optimizer. 89 | loss: # loss function(s). See docs for options. This should be a list of loss 90 | # names with sublists of loss function hyperparameters (if applicable). 91 | # See the docs for more details and usage examples. 92 | loss_weights: # (optional) weights to use for each loss function if using 93 | # loss: composite. This must be a set of key:value pairs where 94 | # defining the weight for each sub-loss within the composite. 95 | # If using a composite and a value isn't provided here, all 96 | # losses will be weighted equally. 97 | metrics: # metrics to monitor on the training and validation sets. 
98 | training: # training set metrics. 99 | validation: # validation set metrics. 100 | checkpoint_frequency: # how frequently should checkpoints be saved? 101 | # this can be an int, in which case every n epochs 102 | # a checkpoint will be made, or a list, in which case 103 | # checkpoints will be saved on those specific epochs. 104 | # if not provided, only the final model is saved. 105 | callbacks: # a list of callbacks to use. 106 | model_dest_path: # path to save the trained model output and checkpoint(s) 107 | # to. Should be a filename ending in .h5, .hdf5 for keras 108 | # or .pth, .pt for torch. Epoch numbers will be appended 109 | # for checkpoints. 110 | verbose: true # verbose text output during training 111 | 112 | inference: 113 | window_step_size_x: # size of each step for the sliding window for inference. 114 | # set to the same size as the input image size for zero 115 | # overlap; to average predictions across multiple images, 116 | # use a smaller step size. 117 | window_step_size_y: # size of each step for the sliding window for inference. 118 | # set to the same size as the input image size for zero 119 | # overlap; to average predictions across multiple images, 120 | # use a smaller step size. 121 | stitching_method: # the method to use to stitch together tiles used during 122 | # inference. defaults to average if not provided. see 123 | # the documentation for sol.raster.image.stitch_images() 124 | # for more. 125 | output_dir: inference_out # the path to save inference outputs to. 126 | -------------------------------------------------------------------------------- /libs/solaris/nets/configs/selimsef_densenet121unet_spacenet4.yml: -------------------------------------------------------------------------------- 1 | model_name: selimsef_spacenet4_densenet121unet 2 | 3 | model_path: 4 | train: true 5 | infer: false 6 | 7 | pretrained: true 8 | nn_framework: torch 9 | batch_size: 32 10 | 11 | data_specs: 12 | width: 384 13 | height: 384 14 | dtype: 15 | image_type: zscore 16 | rescale: false 17 | rescale_minima: auto 18 | rescale_maxima: auto 19 | additional_inputs: 20 | channels: 4 21 | label_type: mask 22 | is_categorical: false 23 | mask_channels: 3 24 | val_holdout_frac: 0.2 25 | data_workers: 26 | 27 | training_data_csv: '/path/to/training_df.csv' 28 | validation_data_csv: 29 | inference_data_csv: '/path/to/test_df.csv' 30 | 31 | training_augmentation: 32 | augmentations: 33 | RandomScale: 34 | scale_limit: 35 | - 0.5 36 | - 1.5 37 | interpolation: nearest 38 | Rotate: 39 | limit: 40 | - 5 41 | - 6 42 | border_mode: constant 43 | p: 0.3 44 | RandomCrop: 45 | height: 416 46 | width: 416 47 | always_apply: true 48 | p: 1.0 49 | Normalize: 50 | mean: 51 | - 0.006479 52 | - 0.009328 53 | - 0.01123 54 | - 0.02082 55 | std: 56 | - 0.004986 57 | - 0.004964 58 | - 0.004950 59 | - 0.004878 60 | max_pixel_value: 65535.0 61 | p: 1.0 62 | p: 1.0 63 | shuffle: true 64 | 65 | validation_augmentation: 66 | augmentations: 67 | CenterCrop: 68 | height: 384 69 | width: 384 70 | p: 1.0 71 | Normalize: 72 | mean: 73 | - 0.006479 74 | - 0.009328 75 | - 0.01123 76 | - 0.02082 77 | std: 78 | - 0.004986 79 | - 0.004964 80 | - 0.004950 81 | - 0.004878 82 | max_pixel_value: 65535.0 83 | p: 1.0 84 | p: 1.0 85 | 86 | inference_augmentation: 87 | augmentations: 88 | Normalize: 89 | mean: 90 | - 0.006479 91 | - 0.009328 92 | - 0.01123 93 | - 0.02082 94 | std: 95 | - 0.004986 96 | - 0.004964 97 | - 0.004950 98 | - 0.004878 99 | max_pixel_value: 65535.0 100 | p: 1.0 101 | p: 1.0 102 | training: 
103 | epochs: 70 104 | steps_per_epoch: 105 | optimizer: AdamW 106 | lr: 2e-4 107 | opt_args: 108 | weight_decay: 0.0001 109 | loss: 110 | focal: 111 | dice: 112 | loss_weights: 113 | focal: 1 114 | dice: 1 115 | metrics: 116 | training: 117 | validation: 118 | checkpoint_frequency: 10 119 | callbacks: 120 | lr_schedule: 121 | schedule_type: 'arbitrary' 122 | schedule_dict: 123 | milestones: 124 | - 1 125 | - 5 126 | - 15 127 | - 30 128 | - 50 129 | - 60 130 | gamma: 0.5 131 | model_checkpoint: 132 | filepath: 'selimsef_densenet121_best.pth' 133 | monitor: val_loss 134 | model_dest_path: 'selimsef_densenet121.pth' 135 | verbose: true 136 | 137 | inference: 138 | window_step_size_x: 139 | window_step_size_y: 140 | output_dir: 'inference_out' 141 | -------------------------------------------------------------------------------- /libs/solaris/nets/configs/selimsef_densenet161unet_spacenet4.yml: -------------------------------------------------------------------------------- 1 | model_name: selimsef_spacenet4_densenet161unet 2 | 3 | model_path: 4 | train: true 5 | infer: false 6 | 7 | pretrained: true 8 | nn_framework: torch 9 | batch_size: 20 10 | 11 | data_specs: 12 | width: 384 13 | height: 384 14 | dtype: 15 | image_type: zscore 16 | rescale: false 17 | rescale_minima: auto 18 | rescale_maxima: auto 19 | additional_inputs: 20 | channels: 4 21 | label_type: mask 22 | is_categorical: false 23 | mask_channels: 3 24 | val_holdout_frac: 0.2 25 | data_workers: 26 | 27 | training_data_csv: '/path/to/training_df.csv' 28 | validation_data_csv: 29 | inference_data_csv: '/path/to/test_df.csv' 30 | 31 | training_augmentation: 32 | augmentations: 33 | RandomScale: 34 | scale_limit: 35 | - 0.5 36 | - 1.5 37 | interpolation: nearest 38 | Rotate: 39 | limit: 40 | - 5 41 | - 6 42 | border_mode: constant 43 | p: 0.3 44 | RandomCrop: 45 | height: 416 46 | width: 416 47 | always_apply: true 48 | p: 1.0 49 | Normalize: 50 | mean: 51 | - 0.006479 52 | - 0.009328 53 | - 0.01123 54 | - 0.02082 55 | std: 56 | - 0.004986 57 | - 0.004964 58 | - 0.004950 59 | - 0.004878 60 | max_pixel_value: 65535.0 61 | p: 1.0 62 | p: 1.0 63 | shuffle: true 64 | 65 | validation_augmentation: 66 | augmentations: 67 | CenterCrop: 68 | height: 384 69 | width: 384 70 | p: 1.0 71 | Normalize: 72 | mean: 73 | - 0.006479 74 | - 0.009328 75 | - 0.01123 76 | - 0.02082 77 | std: 78 | - 0.004986 79 | - 0.004964 80 | - 0.004950 81 | - 0.004878 82 | max_pixel_value: 65535.0 83 | p: 1.0 84 | p: 1.0 85 | 86 | inference_augmentation: 87 | augmentations: 88 | Normalize: 89 | mean: 90 | - 0.006479 91 | - 0.009328 92 | - 0.01123 93 | - 0.02082 94 | std: 95 | - 0.004986 96 | - 0.004964 97 | - 0.004950 98 | - 0.004878 99 | max_pixel_value: 65535.0 100 | p: 1.0 101 | p: 1.0 102 | training: 103 | epochs: 60 104 | steps_per_epoch: 105 | optimizer: AdamW 106 | lr: 2e-4 107 | opt_args: 108 | weight_decay: 0.0001 109 | loss: 110 | focal: 111 | dice: 112 | loss_weights: 113 | focal: 1 114 | dice: 1 115 | metrics: 116 | training: 117 | validation: 118 | checkpoint_frequency: 10 119 | callbacks: 120 | lr_schedule: 121 | schedule_type: 'arbitrary' 122 | schedule_dict: 123 | milestones: 124 | - 1 125 | - 5 126 | - 15 127 | - 30 128 | - 45 129 | - 55 130 | gamma: 0.5 131 | model_checkpoint: 132 | filepath: 'selimsef_densenet161_best.pth' 133 | monitor: val_loss 134 | model_dest_path: 'selimsef_densenet161.pth' 135 | verbose: true 136 | 137 | inference: 138 | window_step_size_x: 139 | window_step_size_y: 140 | output_dir: 'inference_out' 141 | 
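These selimsef_* configs are consumed the same way as in solaris_run_ml.py earlier in this listing; a minimal sketch of doing that programmatically (assuming the library is importable as solaris, and substituting your own config path and data CSVs):

    import pandas as pd
    from solaris.utils.config import parse
    from solaris.nets.train import Trainer
    from solaris.nets.infer import Inferer

    config = parse('selimsef_densenet161unet_spacenet4.yml')  # path to one of these configs
    if config['train']:
        Trainer(config).train()                       # trains on training_data_csv
    if config['infer']:
        inf_df = pd.read_csv(config['inference_data_csv'])
        Inferer(config)(inf_df)                       # writes predictions to inference output_dir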
-------------------------------------------------------------------------------- /libs/solaris/nets/configs/selimsef_resnet34unet_spacenet4.yml: -------------------------------------------------------------------------------- 1 | model_name: selimsef_spacenet4_resnet34unet 2 | 3 | model_path: 4 | train: false 5 | infer: true 6 | 7 | pretrained: true 8 | nn_framework: torch 9 | batch_size: 42 10 | 11 | data_specs: 12 | width: 384 13 | height: 384 14 | dtype: 15 | image_type: zscore 16 | rescale: false 17 | rescale_minima: auto 18 | rescale_maxima: auto 19 | additional_inputs: 20 | channels: 4 21 | label_type: mask 22 | is_categorical: false 23 | mask_channels: 3 24 | val_holdout_frac: 0.2 25 | data_workers: 12 26 | 27 | training_data_csv: '/path/to/training_df.csv' 28 | validation_data_csv: 29 | inference_data_csv: '/path/to/test_df.csv' 30 | 31 | training_augmentation: 32 | augmentations: 33 | RandomScale: 34 | scale_limit: 35 | - 0.5 36 | - 1.5 37 | interpolation: nearest 38 | Rotate: 39 | limit: 40 | - 5 41 | - 6 42 | border_mode: constant 43 | p: 0.3 44 | RandomCrop: 45 | height: 416 46 | width: 416 47 | always_apply: true 48 | p: 1.0 49 | Normalize: 50 | mean: 51 | - 0.006479 52 | - 0.009328 53 | - 0.01123 54 | - 0.02082 55 | std: 56 | - 0.004986 57 | - 0.004964 58 | - 0.004950 59 | - 0.004878 60 | max_pixel_value: 65535.0 61 | p: 1.0 62 | p: 1.0 63 | shuffle: true 64 | 65 | validation_augmentation: 66 | augmentations: 67 | CenterCrop: 68 | height: 384 69 | width: 384 70 | p: 1.0 71 | Normalize: 72 | mean: 73 | - 0.006479 74 | - 0.009328 75 | - 0.01123 76 | - 0.02082 77 | std: 78 | - 0.004986 79 | - 0.004964 80 | - 0.004950 81 | - 0.004878 82 | max_pixel_value: 65535.0 83 | p: 1.0 84 | p: 1.0 85 | 86 | inference_augmentation: 87 | augmentations: 88 | Normalize: 89 | mean: 90 | - 0.006479 91 | - 0.009328 92 | - 0.01123 93 | - 0.02082 94 | std: 95 | - 0.004986 96 | - 0.004964 97 | - 0.004950 98 | - 0.004878 99 | max_pixel_value: 65535.0 100 | p: 1.0 101 | p: 1.0 102 | training: 103 | epochs: 70 104 | steps_per_epoch: 105 | optimizer: AdamW 106 | lr: 2e-4 107 | opt_args: 108 | weight_decay: 0.0001 109 | loss: 110 | focal: 111 | dice: 112 | loss_weights: 113 | focal: 1 114 | dice: 1 115 | metrics: 116 | training: 117 | validation: 118 | checkpoint_frequency: 10 119 | callbacks: 120 | lr_schedule: 121 | schedule_type: 'arbitrary' 122 | schedule_dict: 123 | milestones: 124 | - 1 125 | - 5 126 | - 15 127 | - 30 128 | - 50 129 | - 60 130 | gamma: 0.5 131 | model_checkpoint: 132 | filepath: 'selimsef_resnet34_best.pth' 133 | monitor: val_loss 134 | model_dest_path: 'selimsef_resnet34.pth' 135 | verbose: true 136 | 137 | inference: 138 | window_step_size_x: 139 | window_step_size_y: 140 | output_dir: 'inference_out/' 141 | -------------------------------------------------------------------------------- /libs/solaris/nets/configs/selimsef_scse50unet_spacenet4.yml: -------------------------------------------------------------------------------- 1 | ################################################################################ 2 | ################# SOLARIS MODEL CONFIGURATION SKELETON ######################### 3 | ################################################################################ 4 | 5 | # This skeleton lays out the required instructions for running a model using 6 | # solaris. See the full documentation at [INCLUDE DOC LINK HERE] for details on 7 | # options, required arguments, and sample usage. 
8 | 9 | model_name: selimsef_scse50unet_spacenet4 10 | 11 | model_path: # leave this blank unless you're using a custom model not 12 | # native to solaris. solaris will automatically find your 13 | # model. 14 | train: true # set to false for inference only 15 | infer: false # set to false for training only 16 | 17 | pretrained: false # use pretrained weights associated with the model? 18 | nn_framework: torch 19 | batch_size: 8 20 | 21 | data_specs: 22 | width: 384 23 | height: 384 24 | image_type: zscore # format of images read into the neural net. options 25 | # are 'normalized', 'zscore', '8bit', '16bit'. 26 | rescale: false # should image pixel values be rescaled before pre-processing? 27 | # If so, the image will be rescaled to the pixel range defined 28 | # by rescale_min and rescale_max below. 29 | rescale_minima: auto # the minimum values to use in rescaling (if 30 | # rescale=true). If 'auto', the minimum pixel intensity 31 | # in each channel of the image will be subtracted. If 32 | # a single value is provided, that value will be set to 33 | # zero for each channel in the input image. 34 | # if a list of values are provided, then minima in the 35 | # separate channels (in that order) will be set to that 36 | # value PRIOR to any augmentation. 37 | rescale_maxima: auto # same as rescale_minima, but for the maximum value for 38 | # each channel in the image. 39 | additional_inputs: # a list of additional columns in the training CSV (and 40 | - angle # validation CSV if applicable) That will be passed to 41 | # the model. Those values will not be augmented. 42 | # This list MUST be in the same order as the additional 43 | # input values are expected by the model. 44 | channels: 4 # number of channels in the input imagery. 45 | label_type: mask # one of ['mask', 'bbox'] 46 | is_categorical: false # are the labels binary (default) or categorical? 47 | mask_channels: 3 # number of channels in the training mask 48 | val_holdout_frac: 0.2 # if empty, assumes that separate data ref files define the 49 | # training and validation dataset. If a float between 0 and 50 | # 1, indicates the fraction of training data that's held 51 | # out for validation (and validation_data_csv will be 52 | # ignored) 53 | data_workers: # number of cpu threads to use for loading and preprocessing 54 | # input images. 55 | # other_inputs: # this can provide a list of additional inputs to pass to the 56 | # neural net for training. These inputs should be specified in 57 | # extra columns of the csv files (denoted below), either as 58 | # filepaths to additional data to load or as values to include. 59 | # NOTE: This is not currently implemented. 60 | 61 | training_data_csv: '/path/to/training_df.csv' 62 | validation_data_csv: 63 | inference_data_csv: '/path/to/test_df.csv' # TODO # path to the reference csv that defines inference data. 64 | # see the documentation for the specs of this file. 65 | 66 | training_augmentation: # augmentations for use with training data 67 | augmentations: 68 | RandomScale: 69 | scale_limit: 70 | - 0.5 71 | - 1.5 72 | interpolation: nearest 73 | RandomCrop: 74 | height: 384 75 | width: 384 76 | p: 1.0 77 | Rotate: 78 | limit: 79 | - 5 80 | - 6 81 | border_mode: constant 82 | p: 0.3 83 | Normalize: 84 | mean: 85 | - 0.006479 86 | - 0.009328 87 | - 0.01123 88 | - 0.02082 89 | std: 90 | - 0.004986 91 | - 0.004964 92 | - 0.004950 93 | - 0.004878 94 | max_pixel_value: 65535.0 95 | p: 1.0 96 | # include augmentations here. See the documentation for options and 97 | # required arguments. 
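# For instance, a flip could be appended to the pipeline above in the same
# style as the entries before it (shown commented out and purely illustrative;
# the xdxd_spacenet4.yml config in this directory uses the same augmentation):
#     HorizontalFlip:
#       p: 0.5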
98 | p: 1.0 # probability of applying the entire training augmentation pipeline. 99 | shuffle: true # should the image order be shuffled in each epoch. 100 | 101 | validation_augmentation: # augmentations for use with validation data 102 | augmentations: 103 | CenterCrop: 104 | height: 384 105 | width: 384 106 | p: 1.0 107 | Normalize: 108 | mean: 109 | - 0.006479 110 | - 0.009328 111 | - 0.01123 112 | - 0.02082 113 | std: 114 | - 0.004986 115 | - 0.004964 116 | - 0.004950 117 | - 0.004878 118 | max_pixel_value: 65535.0 119 | p: 1.0 120 | p: 1.0 121 | 122 | inference_augmentation: 123 | augmentations: 124 | Normalize: 125 | mean: 126 | - 0.006479 127 | - 0.009328 128 | - 0.01123 129 | - 0.02082 130 | std: 131 | - 0.004986 132 | - 0.004964 133 | - 0.004950 134 | - 0.004878 135 | max_pixel_value: 65535.0 136 | p: 1.0 137 | p: 1.0 138 | training: 139 | epochs: 52 # number of epochs. A list can also be provided here indicating 140 | # distinct sets of epochs at different learning rates, etc; if so, 141 | # a list of equal length must be provided in the parameter that varies 142 | # with the values for each set of epochs. 143 | steps_per_epoch: # optional argument defining # steps/epoch. If not provided, 144 | # each epoch will include the number of steps needed to go 145 | # through the entire training dataset. 146 | optimizer: AdamW # optimizer function name. see docs for options. 147 | lr: 2e-4 # learning rate. 148 | opt_args: # dictionary of values (e.g. alpha, gamma, momentum) specific to 149 | weight_decay: 0.0001 150 | loss: 151 | focal: # loss function(s). See docs for options. This should be a list of loss 152 | dice: # names with sublists of loss function hyperparameters (if applicable). 153 | # See the docs for more details and usage examples. 154 | loss_weights: 155 | focal: 1 # (optional) weights to use for each loss function if using 156 | dice: 1 # loss: composite. This must be a set of key:value pairs where 157 | # defining the weight for each sub-loss within the composite. 158 | # If using a composite and a value isn't provided here, all 159 | # losses will be weighted equally. 160 | metrics: # metrics to monitor on the training and validation sets. 161 | training: # training set metrics. 162 | validation: # validation set metrics. 163 | checkpoint_frequency: 10 # how frequently should checkpoints be saved? 164 | # this can be an int, in which case every n epochs 165 | # a checkpoint will be made, or a list, in which case 166 | # checkpoints will be saved on those specific epochs. 167 | # if not provided, only the final model is saved. 168 | callbacks: 169 | lr_schedule: 170 | schedule_type: 'arbitrary' 171 | schedule_dict: 172 | milestones: 173 | - 1 174 | - 5 175 | - 15 176 | - 30 177 | - 40 178 | - 50 179 | gamma: 0.5 180 | model_checkpoint: 181 | filepath: 'selimsef_best.pth' 182 | monitor: val_loss 183 | model_dest_path: 'selimsef.pth' # path to save the trained model output and checkpoint(s) 184 | # to. Should be a filename ending in .h5, .hdf5 for keras 185 | # or .pth, .pt for torch. Epoch numbers will be appended 186 | # for checkpoints. 187 | verbose: true # verbose text output during training 188 | 189 | inference: 190 | window_step_size_x: # size of each step for the sliding window for inference. 191 | # set to the same size as the input image size for zero 192 | # overlap; to average predictions across multiple images, 193 | # use a smaller step size. 194 | window_step_size_y: # size of each step for the sliding window for inference. 
195 | # set to the same size as the input image size for zero 196 | # overlap; to average predictions across multiple images, 197 | # use a smaller step size. 198 | output_dir: 'inference_out/' 199 | -------------------------------------------------------------------------------- /libs/solaris/nets/configs/xdxd_spacenet4.yml: -------------------------------------------------------------------------------- 1 | model_name: xdxd_spacenet4 2 | 3 | model_path: 4 | train: false 5 | infer: true 6 | 7 | pretrained: true 8 | nn_framework: torch 9 | batch_size: 12 10 | 11 | data_specs: 12 | width: 512 13 | height: 512 14 | dtype: 15 | image_type: zscore 16 | rescale: false 17 | rescale_minima: auto 18 | rescale_maxima: auto 19 | channels: 4 20 | label_type: mask 21 | is_categorical: false 22 | mask_channels: 1 23 | val_holdout_frac: 0.2 24 | data_workers: 25 | 26 | training_data_csv: '/path/to/training_df.csv' 27 | validation_data_csv: 28 | inference_data_csv: '/path/to/test_df.csv' 29 | 30 | training_augmentation: 31 | augmentations: 32 | DropChannel: 33 | idx: 3 34 | axis: 2 35 | HorizontalFlip: 36 | p: 0.5 37 | RandomRotate90: 38 | p: 0.5 39 | RandomCrop: 40 | height: 512 41 | width: 512 42 | p: 1.0 43 | Normalize: 44 | mean: 45 | - 0.006479 46 | - 0.009328 47 | - 0.01123 48 | std: 49 | - 0.004986 50 | - 0.004964 51 | - 0.004950 52 | max_pixel_value: 65535.0 53 | p: 1.0 54 | p: 1.0 55 | shuffle: true 56 | validation_augmentation: 57 | augmentations: 58 | DropChannel: 59 | idx: 3 60 | axis: 2 61 | CenterCrop: 62 | height: 512 63 | width: 512 64 | p: 1.0 65 | Normalize: 66 | mean: 67 | - 0.006479 68 | - 0.009328 69 | - 0.01123 70 | std: 71 | - 0.004986 72 | - 0.004964 73 | - 0.004950 74 | max_pixel_value: 65535.0 75 | p: 1.0 76 | p: 1.0 77 | inference_augmentation: 78 | augmentations: 79 | DropChannel: 80 | idx: 3 81 | axis: 2 82 | p: 1.0 83 | Normalize: 84 | mean: 85 | - 0.006479 86 | - 0.009328 87 | - 0.01123 88 | std: 89 | - 0.004986 90 | - 0.004964 91 | - 0.004950 92 | max_pixel_value: 65535.0 93 | p: 1.0 94 | p: 1.0 95 | training: 96 | epochs: 60 97 | steps_per_epoch: 98 | optimizer: Adam 99 | lr: 1e-4 100 | opt_args: 101 | loss: 102 | bcewithlogits: 103 | jaccard: 104 | loss_weights: 105 | bcewithlogits: 10 106 | jaccard: 2.5 107 | metrics: 108 | training: 109 | validation: 110 | checkpoint_frequency: 10 111 | callbacks: 112 | model_checkpoint: 113 | filepath: 'xdxd_best.pth' 114 | monitor: val_loss 115 | model_dest_path: 'xdxd.pth' 116 | verbose: true 117 | 118 | inference: 119 | window_step_size_x: 120 | window_step_size_y: 121 | output_dir: 'inference_out/' 122 | -------------------------------------------------------------------------------- /libs/solaris/nets/infer.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import skimage 4 | import torch 5 | from tqdm import tqdm 6 | from warnings import warn 7 | from .model_io import get_model 8 | from .transform import process_aug_dict 9 | from .datagen import InferenceTiler 10 | from ..raster.image import stitch_images 11 | from ..utils.core import get_data_paths 12 | 13 | 14 | class Inferer(object): 15 | """Object for training `solaris` models using PyTorch or Keras.""" 16 | 17 | def __init__(self, config, custom_model_dict=None): 18 | self.config = config 19 | self.batch_size = self.config['batch_size'] 20 | self.framework = self.config['nn_framework'] 21 | self.model_name = self.config['model_name'] 22 | # check if the model was trained as part of the same pipeline; 
if so, 23 | # use the output from that. If not, use the pre-trained model directly. 24 | if self.config['train']: 25 | warn('Because the configuration specifies both training and ' 26 | 'inference, solaris is switching the model weights path ' 27 | 'to the training output path.') 28 | self.model_path = self.config['training']['model_dest_path'] 29 | if custom_model_dict is not None: 30 | custom_model_dict['weight_path'] = self.config[ 31 | 'training']['model_dest_path'] 32 | else: 33 | self.model_path = self.config.get('model_path', None) 34 | self.model = get_model(self.model_name, self.framework, 35 | self.model_path, pretrained=True, 36 | custom_model_dict=custom_model_dict) 37 | self.window_step_x = self.config['inference'].get('window_step_size_x', 38 | None) 39 | self.window_step_y = self.config['inference'].get('window_step_size_y', 40 | None) 41 | if self.window_step_x is None: 42 | self.window_step_x = self.config['data_specs']['width'] 43 | if self.window_step_y is None: 44 | self.window_step_y = self.config['data_specs']['height'] 45 | self.stitching_method = self.config['inference'].get( 46 | 'stitching_method', 'average') 47 | self.output_dir = self.config['inference']['output_dir'] 48 | if not os.path.isdir(self.output_dir): 49 | os.makedirs(self.output_dir) 50 | 51 | def __call__(self, infer_df=None): 52 | """Run inference. 53 | 54 | Arguments 55 | --------- 56 | infer_df : :class:`pandas.DataFrame` or `str` 57 | A :class:`pandas.DataFrame` with a column, ``'image'``, specifying 58 | paths to images for inference. Alternatively, `infer_df` can be a 59 | path to a CSV file containing the same information. Defaults to 60 | ``None``, in which case the file path specified in the Inferer's 61 | configuration dict is used. 62 | 63 | """ 64 | 65 | if infer_df is None: 66 | infer_df = get_infer_df(self.config) 67 | 68 | inf_tiler = InferenceTiler( 69 | self.framework, 70 | width=self.config['data_specs']['width'], 71 | height=self.config['data_specs']['height'], 72 | x_step=self.window_step_x, 73 | y_step=self.window_step_y, 74 | augmentations=process_aug_dict( 75 | self.config['inference_augmentation']) 76 | ) 77 | for idx, im_path in tqdm(enumerate(infer_df['image']), 78 | total=len(infer_df['image'])): 79 | inf_input, idx_refs, ( 80 | src_im_height, src_im_width) = inf_tiler(im_path) 81 | 82 | if self.framework == 'keras': 83 | subarr_preds = self.model.predict(inf_input, 84 | batch_size=self.batch_size) 85 | 86 | elif self.framework in ['torch', 'pytorch']: 87 | with torch.no_grad(): 88 | self.model.eval() 89 | if torch.cuda.is_available(): 90 | device = torch.device('cuda') 91 | self.model = self.model.cuda() 92 | else: 93 | device = torch.device('cpu') 94 | inf_input = torch.from_numpy(inf_input).float().to(device) 95 | # add additional input data, if applicable 96 | if self.config['data_specs'].get('additional_inputs', 97 | None) is not None: 98 | inf_input = [inf_input] 99 | for i in self.config['data_specs']['additional_inputs']: 100 | inf_input.append( 101 | infer_df[i].iloc[idx].to(device)) 102 | 103 | # Revision: allow batch process to save Mem cost. 104 | subarr_preds_list = [] 105 | for batch_i in range(0, inf_input.shape[0], self.batch_size): 106 | if batch_i + self.batch_size <= inf_input.shape[0]: 107 | subarr_pred = self.model(inf_input[ 108 | batch_i:batch_i+self.batch_size, ... 109 | ]) 110 | else: 111 | subarr_pred = self.model(inf_input[ 112 | batch_i:, ... 
113 | ]) 114 | subarr_preds_list.append(subarr_pred.cpu().data.numpy()) 115 | subarr_preds = np.concatenate(subarr_preds_list, axis=0) 116 | stitched_result = stitch_images(subarr_preds, 117 | idx_refs=idx_refs, 118 | out_width=src_im_width, 119 | out_height=src_im_height, 120 | method=self.stitching_method) 121 | skimage.io.imsave(os.path.join(self.output_dir, 122 | os.path.split(im_path)[1]), 123 | stitched_result, check_contrast=False) 124 | 125 | 126 | def get_infer_df(config): 127 | """Get the inference df based on the contents of ``config`` . 128 | 129 | This function uses the logic described in the documentation for the config 130 | file to determine where to find images to be used for inference. 131 | See the docs and the comments in solaris/data/config_skeleton.yml for 132 | details. 133 | 134 | Arguments 135 | --------- 136 | config : dict 137 | The loaded configuration dict for model training and/or inference. 138 | 139 | Returns 140 | ------- 141 | infer_df : :class:`dict` 142 | :class:`dict` containing at least one column: ``'image'`` . The values 143 | in this column correspond to the path to filenames to perform inference 144 | on. 145 | """ 146 | 147 | infer_df = get_data_paths(config['inference_data_csv'], infer=True) 148 | return infer_df 149 | -------------------------------------------------------------------------------- /libs/solaris/nets/losses.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from tensorflow.keras import backend as K 3 | from ._keras_losses import keras_losses, k_focal_loss 4 | from ._torch_losses import torch_losses 5 | from torch import nn 6 | 7 | 8 | def get_loss(framework, loss, loss_weights=None, custom_losses=None): 9 | """Load a loss function based on a config file for the specified framework. 10 | 11 | Arguments 12 | --------- 13 | framework : string 14 | Which neural network framework to use. 15 | loss : dict 16 | Dictionary of loss functions to use. Each key is a loss function name, 17 | and each entry is a (possibly-empty) dictionary of hyperparameter-value 18 | pairs. 19 | loss_weights : dict, optional 20 | Optional dictionary of weights for loss functions. Each key is a loss 21 | function name (same as in the ``loss`` argument), and the corresponding 22 | entry is its weight. 23 | custom_losses : dict, optional 24 | Optional dictionary of Pytorch classes or Keras functions of any 25 | user-defined loss functions. Each key is a loss function name, and the 26 | corresponding entry is the Python object implementing that loss. 27 | """ 28 | # lots of exception handling here. TODO: Refactor. 29 | if not isinstance(loss, dict): 30 | raise TypeError('The loss description is formatted improperly.' 
31 |                         ' See the docs for details.')
32 |     if len(loss) > 1:
33 | 
34 |         # get the weights for each loss within the composite
35 |         if loss_weights is None:
36 |             # weight all losses equally
37 |             weights = {k: 1 for k in loss.keys()}
38 |         else:
39 |             weights = loss_weights
40 | 
41 |         # check if sublosses dict and weights dict have the same keys
42 |         if sorted(loss.keys()) != sorted(weights.keys()):
43 |             raise ValueError(
44 |                 'The losses and weights must have the same name keys.')
45 | 
46 |         if framework == 'keras':
47 |             return keras_composite_loss(loss, weights, custom_losses)
48 |         elif framework in ['pytorch', 'torch']:
49 |             return TorchCompositeLoss(loss, weights, custom_losses)
50 | 
51 |     else:  # parse individual loss functions
52 |         loss_name, loss_dict = list(loss.items())[0]
53 |         return get_single_loss(framework, loss_name, loss_dict, custom_losses)
54 | 
55 | 
56 | def get_single_loss(framework, loss_name, params_dict, custom_losses=None):
57 |     if framework == 'keras':
58 |         if loss_name.lower() == 'focal':
59 |             return k_focal_loss(**params_dict)
60 |         else:
61 |             # keras_losses in the next line is a matching dict
62 |             # TODO: the next block doesn't handle non-focal loss functions that
63 |             # have hyperparameters associated with them. It would be great to
64 |             # refactor this to handle that possibility.
65 |             if custom_losses is not None and loss_name in custom_losses:
66 |                 return custom_losses.get(loss_name)
67 |             else:
68 |                 return keras_losses.get(loss_name.lower())
69 |     elif framework in ['torch', 'pytorch']:
70 |         if params_dict is None:
71 |             if custom_losses is not None and loss_name in custom_losses:
72 |                 return custom_losses.get(loss_name)()
73 |             else:
74 |                 return torch_losses.get(loss_name.lower())()
75 |         else:
76 |             if custom_losses is not None and loss_name in custom_losses:
77 |                 return custom_losses.get(loss_name)(**params_dict)
78 |             else:
79 |                 return torch_losses.get(loss_name.lower())(**params_dict)
80 | 
81 | 
82 | def keras_composite_loss(loss_dict, weight_dict, custom_losses=None):
83 |     """Wrapper to other loss functions to create keras-compatible composite."""
84 | 
85 |     def composite(y_true, y_pred):
86 |         loss = K.sum(K.flatten(K.stack([weight_dict[loss_name]*get_single_loss(
87 |             'keras', loss_name, loss_params, custom_losses)(y_true, y_pred)
88 |             for loss_name, loss_params in loss_dict.items()], axis=-1)))
89 |         return loss
90 | 
91 |     return composite
92 | 
93 | 
94 | class TorchCompositeLoss(nn.Module):
95 |     """Composite loss function."""
96 | 
97 |     def __init__(self, loss_dict, weight_dict=None, custom_losses=None):
98 |         """Create a composite loss function from a set of pytorch losses."""
99 |         super().__init__()
100 |         self.weights = weight_dict
101 |         self.losses = {loss_name: get_single_loss('pytorch',
102 |                                                   loss_name,
103 |                                                   loss_params,
104 |                                                   custom_losses)
105 |                        for loss_name, loss_params in loss_dict.items()}
106 |         self.values = {}  # values from the individual loss functions
107 | 
108 |     def forward(self, outputs, targets):
109 |         loss = 0
110 |         for func_name, weight in self.weights.items():
111 |             self.values[func_name] = self.losses[func_name](outputs, targets)
112 |             loss += weight*self.values[func_name]
113 | 
114 |         return loss
115 | 
--------------------------------------------------------------------------------
/libs/solaris/nets/metrics.py:
--------------------------------------------------------------------------------
1 | from tensorflow.keras import backend as K
2 | from tensorflow import keras
3 | 
4 | 
5 | def get_metrics(framework, config):
6 |     """Load model training metrics from a config file for a specific framework.
7 |     """
8 |     training_metrics = []
9 |     validation_metrics = []
10 | 
11 |     # TODO: enable passing kwargs to these metrics. This will require
12 |     # writing a wrapper function that'll receive the inputs from the model
13 |     # and pass them along with the kwarg to the metric function.
14 |     if config['training']['metrics'].get('training', []) is None:
15 |         training_metrics = []
16 |     else:
17 |         for m in config['training']['metrics'].get('training', []):
18 |             training_metrics.append(metric_dict[m])
19 |     if config['training']['metrics'].get('validation', []) is None:
20 |         validation_metrics = []
21 |     else:
22 |         for m in config['training']['metrics'].get('validation', []):
23 |             validation_metrics.append(metric_dict[m])
24 | 
25 |     return {'train': training_metrics, 'val': validation_metrics}
26 | 
27 | 
28 | def dice_coef_binary(y_true, y_pred, smooth=1e-7):
29 |     '''
30 |     Dice coefficient for 2 categories. Ignores background pixel label 0
31 |     Pass to model as metric during compile statement
32 |     '''
33 |     y_true_f = K.flatten(K.one_hot(K.cast(y_true, 'int32'),
34 |                                    num_classes=2)[..., 1:])
35 |     y_pred_f = K.flatten(y_pred[..., 1:])
36 |     intersect = K.sum(y_true_f * y_pred_f, axis=-1)
37 |     denom = K.sum(y_true_f + y_pred_f, axis=-1)
38 |     return K.mean((2. * intersect / (denom + smooth)))
39 | 
40 | 
41 | def precision(y_true, y_pred):
42 |     """Precision for foreground pixels.
43 | 
44 |     Calculates pixelwise precision TP/(TP + FP).
45 | 
46 |     """
47 |     # count true positives
48 |     truth = K.round(K.clip(y_true, K.epsilon(), 1))
49 |     pred_pos = K.round(K.clip(y_pred, K.epsilon(), 1))
50 |     true_pos = K.sum(K.cast(K.all(K.stack([truth, pred_pos], axis=2), axis=2),
51 |                             dtype='float64'))
52 |     pred_pos_ct = K.sum(pred_pos) + K.epsilon()
53 |     precision = true_pos/pred_pos_ct
54 | 
55 |     return precision
56 | 
57 | 
58 | def recall(y_true, y_pred):
59 |     """Recall for foreground pixels.
60 | 
61 |     Calculates pixelwise recall TP/(TP + FN).
62 | 
63 |     """
64 |     # count true positives
65 |     truth = K.round(K.clip(y_true, K.epsilon(), 1))
66 |     pred_pos = K.round(K.clip(y_pred, K.epsilon(), 1))
67 |     true_pos = K.sum(K.cast(K.all(K.stack([truth, pred_pos], axis=2), axis=2),
68 |                             dtype='float64'))
69 |     truth_ct = K.sum(K.round(K.clip(y_true, K.epsilon(), 1)))
70 |     if truth_ct == 0:
71 |         return 0
72 |     recall = true_pos/truth_ct
73 | 
74 |     return recall
75 | 
76 | 
77 | def f1_score(y_true, y_pred):
78 |     """F1 score for foreground pixels ONLY.
79 | 
80 |     Calculates pixelwise F1 score for the foreground pixels (mask value == 1).
81 |     Returns NaN if the model does not identify any foreground pixels in the
82 |     image.
83 | 
84 |     """
85 | 
86 |     prec = precision(y_true, y_pred)
87 |     rec = recall(y_true, y_pred)
88 |     # Calculate f1_score
89 |     f1_score = 2 * (prec * rec) / (prec + rec)
90 | 
91 |     return f1_score
92 | 
93 | 
94 | # the keras metrics functions _should_ also work if provided with a
95 | # (y_true, y_pred) pair from pytorch, so I'll use those for both.
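# Metric names are selected in the YAML config and looked up in metric_dict
# below by get_metrics() above. An illustrative (assumed, not shipped) config
# snippet:
#
#     training:
#       metrics:
#         training:
#           - precision
#           - recall
#         validation:
#           - f1_score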
96 | metric_dict = { 97 | 'accuracy': keras.metrics.binary_accuracy, 98 | 'binary_accuracy': keras.metrics.binary_accuracy, 99 | 'precision': precision, 100 | 'recall': recall, 101 | 'f1_score': f1_score, 102 | 'categorical_accuracy': keras.metrics.categorical_accuracy, 103 | 'cosine': keras.metrics.CosineSimilarity, 104 | 'cosine_proximity': keras.metrics.CosineSimilarity, 105 | 'hinge': keras.metrics.hinge, 106 | 'squared_hinge': keras.metrics.squared_hinge, 107 | 'kld': keras.metrics.kullback_leibler_divergence, 108 | 'kullback_leibler_divergence': keras.metrics.kullback_leibler_divergence, 109 | 'mae': keras.metrics.mean_absolute_error, 110 | 'mean_absolute_error': keras.metrics.mean_absolute_error, 111 | 'mse': keras.metrics.mean_squared_error, 112 | 'mean_squared_error': keras.metrics.mean_squared_error, 113 | 'msle': keras.metrics.mean_squared_logarithmic_error, 114 | 'mean_squared_logarithmic_error': keras.metrics.mean_squared_logarithmic_error, 115 | 'sparse_categorical_accuracy': keras.metrics.sparse_categorical_accuracy, 116 | 'top_k_categorical_accuracy': keras.metrics.top_k_categorical_accuracy 117 | } 118 | -------------------------------------------------------------------------------- /libs/solaris/nets/model_io.py: -------------------------------------------------------------------------------- 1 | import os 2 | from tensorflow import keras 3 | import torch 4 | from warnings import warn 5 | import requests 6 | import numpy as np 7 | from tqdm import tqdm 8 | from ..nets import weights_dir 9 | from .zoo import model_dict 10 | 11 | 12 | def get_model(model_name, framework, model_path=None, pretrained=False, 13 | custom_model_dict=None): 14 | """Load a model from a file based on its name.""" 15 | if custom_model_dict is not None: 16 | md = custom_model_dict 17 | else: 18 | md = model_dict.get(model_name, None) 19 | if md is None: # if the model's not provided by solaris 20 | raise ValueError(f"{model_name} can't be found in solaris and no " 21 | "custom_model_dict was provided. Check your " 22 | "model_name in the config file and/or provide a " 23 | "custom_model_dict argument to Trainer().") 24 | if model_path is None or custom_model_dict is not None: 25 | model_path = md.get('weight_path') 26 | model = md.get('arch')() 27 | if model is not None and pretrained: 28 | try: 29 | model = _load_model_weights(model, model_path, framework) 30 | except (OSError, FileNotFoundError): 31 | warn(f'The model weights file {model_path} was not found.' 32 | ' Attempting to download from the SpaceNet repository.') 33 | weight_path = _download_weights(md) 34 | model = _load_model_weights(model, weight_path, framework) 35 | 36 | return model 37 | 38 | 39 | def _load_model_weights(model, path, framework): 40 | """Backend for loading the model.""" 41 | 42 | if framework.lower() == 'keras': 43 | try: 44 | model.load_weights(path) 45 | except OSError: 46 | # first, check to see if the weights are in the default sol dir 47 | default_path = os.path.join(weights_dir, os.path.split(path)[1]) 48 | try: 49 | model.load_weights(default_path) 50 | except OSError: 51 | # if they can't be found anywhere, raise the error. 
52 |                 raise FileNotFoundError("{} doesn't exist.".format(path))
53 | 
54 |     elif framework.lower() in ['torch', 'pytorch']:
55 |         # pytorch already throws the right error on failed load, so no need
56 |         # to fix exception
57 |         if torch.cuda.is_available():
58 |             try:
59 |                 loaded = torch.load(path)
60 |             except FileNotFoundError:
61 |                 # first, check to see if the weights are in the default sol dir
62 |                 default_path = os.path.join(weights_dir,
63 |                                             os.path.split(path)[1])
64 |                 loaded = torch.load(default_path)
65 |         else:
66 |             try:
67 |                 loaded = torch.load(path, map_location='cpu')
68 |             except FileNotFoundError:
69 |                 default_path = os.path.join(weights_dir,
70 |                                             os.path.split(path)[1])
71 |                 loaded = torch.load(default_path, map_location='cpu')
72 | 
73 |         if isinstance(loaded, torch.nn.Module):  # if it's a full model already
74 |             model.load_state_dict(loaded.state_dict())
75 |         else:
76 |             model.load_state_dict(loaded)
77 | 
78 |     return model
79 | 
80 | 
81 | def reset_weights(model, framework):
82 |     """Re-initialize model weights for training.
83 | 
84 |     Arguments
85 |     ---------
86 |     model : :class:`tensorflow.keras.Model` or :class:`torch.nn.Module`
87 |         A pre-trained, compiled model with weights saved.
88 |     framework : str
89 |         The deep learning framework used. Currently valid options are
90 |         ``['torch', 'keras']`` .
91 | 
92 |     Returns
93 |     -------
94 |     reinit_model : model object
95 |         The model with weights re-initialized. Note this model object will also
96 |         lack an optimizer, loss function, etc., which will need to be added.
97 |     """
98 | 
99 |     if framework == 'keras':
100 |         model_json = model.to_json()
101 |         reinit_model = keras.models.model_from_json(model_json)
102 |     elif framework == 'torch':
103 |         reinit_model = model.apply(_reset_torch_weights)
104 | 
105 |     return reinit_model
106 | 
107 | 
108 | def _reset_torch_weights(torch_layer):
109 |     if isinstance(torch_layer, torch.nn.Conv2d) or \
110 |             isinstance(torch_layer, torch.nn.Linear):
111 |         torch_layer.reset_parameters()
112 | 
113 | 
114 | def _download_weights(model_dict):
115 |     """Download pretrained weights for a model."""
116 |     weight_url = model_dict.get('weight_url', None)
117 |     weight_dest_path = model_dict.get('weight_path', os.path.join(
118 |         weights_dir, weight_url.split('/')[-1]))
119 |     if weight_url is None:
120 |         raise KeyError("Can't find the weights file.")
121 |     else:
122 |         r = requests.get(weight_url, stream=True)
123 |         if r.status_code != 200:
124 |             raise ValueError('The file could not be downloaded. Check the URL'
125 |                              ' and network connections.')
126 |         total_size = int(r.headers.get('content-length', 0))
127 |         block_size = 1024
128 |         with open(weight_dest_path, 'wb') as f:
129 |             for chunk in tqdm(r.iter_content(block_size),
130 |                               total=np.ceil(total_size / block_size),
131 |                               unit='KB', unit_scale=False):
132 |                 if chunk:
133 |                     f.write(chunk)
134 | 
135 |     return weight_dest_path
136 | 
--------------------------------------------------------------------------------
/libs/solaris/nets/optimizers.py:
--------------------------------------------------------------------------------
1 | """Wrappers for training optimizers."""
2 | import math
3 | import torch
4 | from tensorflow import keras
5 | 
6 | 
7 | def get_optimizer(framework, config):
8 |     """Get the optimizer specified in config for model training.
9 | 
10 |     Arguments
11 |     ---------
12 |     framework : str
13 |         Name of the deep learning framework used. Current options are
14 |         ``['torch', 'keras']``.
15 |     config : dict
16 |         The config dict generated from the YAML config file.
17 | 18 | Returns 19 | ------- 20 | An optimizer object for the specified deep learning framework. 21 | """ 22 | 23 | if config['training']['optimizer'] is None: 24 | raise ValueError('An optimizer must be specified in the config ' 25 | 'file.') 26 | 27 | if framework in ['torch', 'pytorch']: 28 | return torch_optimizers.get(config['training']['optimizer'].lower()) 29 | elif framework == 'keras': 30 | return keras_optimizers.get(config['training']['optimizer'].lower()) 31 | 32 | 33 | class TorchAdamW(torch.optim.Optimizer): 34 | """AdamW algorithm as implemented in `Torch_AdamW`_. 35 | 36 | The original Adam algorithm was proposed in `Adam: A Method for Stochastic Optimization`_. 37 | The AdamW variant was proposed in `Decoupled Weight Decay Regularization`_. 38 | Arguments: 39 | params (iterable): iterable of parameters to optimize or dicts defining 40 | parameter groups 41 | lr (float, optional): learning rate (default: 1e-3) 42 | betas (Tuple[float, float], optional): coefficients used for computing 43 | running averages of gradient and its square (default: (0.9, 0.999)) 44 | eps (float, optional): term added to the denominator to improve 45 | numerical stability (default: 1e-8) 46 | weight_decay (float, optional): weight decay coefficient (default: 1e-2) 47 | amsgrad (boolean, optional): whether to use the AMSGrad variant of this 48 | algorithm from the paper `On the Convergence of Adam and Beyond`_ 49 | (default: False) 50 | .. _Torch_AdamW: https://github.com/pytorch/pytorch/pull/3740 51 | .. _Adam\: A Method for Stochastic Optimization: 52 | https://arxiv.org/abs/1412.6980 53 | .. _Decoupled Weight Decay Regularization: 54 | https://arxiv.org/abs/1711.05101 55 | .. _On the Convergence of Adam and Beyond: 56 | https://openreview.net/forum?id=ryQu7f-RZ 57 | """ 58 | 59 | def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, 60 | weight_decay=1e-2, amsgrad=False): 61 | if not 0.0 <= lr: 62 | raise ValueError("Invalid learning rate: {}".format(lr)) 63 | if not 0.0 <= eps: 64 | raise ValueError("Invalid epsilon value: {}".format(eps)) 65 | if not 0.0 <= betas[0] < 1.0: 66 | raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0])) 67 | if not 0.0 <= betas[1] < 1.0: 68 | raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1])) 69 | defaults = dict(lr=lr, betas=betas, eps=eps, 70 | weight_decay=weight_decay, amsgrad=amsgrad) 71 | super(TorchAdamW, self).__init__(params, defaults) 72 | 73 | def __setstate__(self, state): 74 | super(TorchAdamW, self).__setstate__(state) 75 | for group in self.param_groups: 76 | group.setdefault('amsgrad', False) 77 | 78 | def step(self, closure=None): 79 | """Performs a single optimization step. 80 | Arguments: 81 | closure (callable, optional): A closure that reevaluates the model 82 | and returns the loss. 
83 | """ 84 | loss = None 85 | if closure is not None: 86 | loss = closure() 87 | 88 | for group in self.param_groups: 89 | for p in group['params']: 90 | if p.grad is None: 91 | continue 92 | 93 | # Perform stepweight decay 94 | p.data.mul_(1 - group['lr'] * group['weight_decay']) 95 | 96 | # Perform optimization step 97 | grad = p.grad.data 98 | if grad.is_sparse: 99 | raise RuntimeError('Adam does not support sparse' 100 | 'gradients, please consider SparseAdam' 101 | ' instead') 102 | amsgrad = group['amsgrad'] 103 | 104 | state = self.state[p] 105 | 106 | # State initialization 107 | if len(state) == 0: 108 | state['step'] = 0 109 | # Exponential moving average of gradient values 110 | state['exp_avg'] = torch.zeros_like(p.data) 111 | # Exponential moving average of squared gradient values 112 | state['exp_avg_sq'] = torch.zeros_like(p.data) 113 | if amsgrad: 114 | # Maintains max of all exp. moving avg. of sq. grad. values 115 | state['max_exp_avg_sq'] = torch.zeros_like(p.data) 116 | 117 | exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq'] 118 | if amsgrad: 119 | max_exp_avg_sq = state['max_exp_avg_sq'] 120 | beta1, beta2 = group['betas'] 121 | 122 | state['step'] += 1 123 | 124 | # Decay the first and second moment running average coefficient 125 | exp_avg.mul_(beta1).add_(1 - beta1, grad) 126 | exp_avg_sq.mul_(beta2).addcmul_(1 - beta2, grad, grad) 127 | if amsgrad: 128 | # Maintains the maximum of all 2nd moment running avg. till now 129 | torch.max(max_exp_avg_sq, exp_avg_sq, out=max_exp_avg_sq) 130 | # Use the max. for normalizing running avg. of gradient 131 | denom = max_exp_avg_sq.sqrt().add_(group['eps']) 132 | else: 133 | denom = exp_avg_sq.sqrt().add_(group['eps']) 134 | 135 | bias_correction1 = 1 - beta1 ** state['step'] 136 | bias_correction2 = 1 - beta2 ** state['step'] 137 | step_size = group['lr'] * math.sqrt(bias_correction2) / bias_correction1 138 | 139 | p.data.addcdiv_(-step_size, exp_avg, denom) 140 | 141 | return loss 142 | 143 | 144 | torch_optimizers = { 145 | 'adadelta': torch.optim.Adadelta, 146 | 'adam': torch.optim.Adam, 147 | 'adamw': TorchAdamW, 148 | 'sparseadam': torch.optim.SparseAdam, 149 | 'adamax': torch.optim.Adamax, 150 | 'asgd': torch.optim.ASGD, 151 | 'rmsprop': torch.optim.RMSprop, 152 | 'sgd': torch.optim.SGD, 153 | } 154 | 155 | keras_optimizers = { 156 | 'adadelta': keras.optimizers.Adadelta, 157 | 'adagrad': keras.optimizers.Adagrad, 158 | 'adam': keras.optimizers.Adam, 159 | 'adamax': keras.optimizers.Adamax, 160 | 'nadam': keras.optimizers.Nadam, 161 | 'rmsprop': keras.optimizers.RMSprop, 162 | 'sgd': keras.optimizers.SGD 163 | } 164 | -------------------------------------------------------------------------------- /libs/solaris/nets/torch_callbacks.py: -------------------------------------------------------------------------------- 1 | """PyTorch Callbacks.""" 2 | 3 | import os 4 | import numpy as np 5 | from .metrics import metric_dict 6 | import torch 7 | 8 | 9 | class TorchEarlyStopping(object): 10 | """Tracks if model training should stop based on rate of improvement. 11 | 12 | Arguments 13 | --------- 14 | patience : int, optional 15 | The number of epochs to wait before stopping the model if the metric 16 | didn't improve. Defaults to 5. 17 | threshold : float, optional 18 | The minimum metric improvement required to count as "improvement". 19 | Defaults to ``0.0`` (any improvement satisfies the requirement). 20 | verbose : bool, optional 21 | Verbose text output. Defaults to off (``False``). 
_NOTE_ : This 22 | currently does nothing. 23 | """ 24 | 25 | def __init__(self, patience=5, threshold=0.0, verbose=False): 26 | self.patience = patience 27 | self.threshold = threshold 28 | self.counter = 0 29 | self.best = None 30 | self.stop = False 31 | 32 | def __call__(self, metric_score): 33 | 34 | if self.best is None: 35 | self.best = metric_score 36 | self.counter = 0 37 | else: 38 | if self.best - self.threshold < metric_score: 39 | self.counter += 1 40 | else: 41 | self.best = metric_score 42 | self.counter = 0 43 | 44 | if self.counter >= self.patience: 45 | self.stop = True 46 | 47 | 48 | class TorchTerminateOnNaN(object): 49 | """Sets a stop condition if the model loss achieves an NaN or inf value. 50 | 51 | Arguments 52 | --------- 53 | patience : int, optional 54 | The number of epochs that must display an NaN loss value before 55 | stopping. Defaults to ``1``. 56 | verbose : bool, optional 57 | Verbose text output. Defaults to off (``False``). _NOTE_ : This 58 | currently does nothing. 59 | """ 60 | 61 | def __init__(self, patience=1, verbose=False): 62 | self.patience = patience 63 | self.counter = 0 64 | self.stop = False 65 | 66 | def __call__(self, loss): 67 | if np.isnan(loss) or np.isinf(loss): 68 | self.counter += 1 69 | if self.counter >= self.patience: 70 | self.stop = True 71 | else: 72 | self.counter = 0 73 | 74 | 75 | class TorchTerminateOnMetricNaN(object): 76 | """Sets a stop condition if a training metric achieves an NaN or inf value. 77 | 78 | Arguments 79 | --------- 80 | stopping_metric : str 81 | The name of the metric to stop on. The name must match a key in 82 | :const:`solaris.nets.metrics.metric_dict` . 83 | patience : int, optional 84 | The number of epochs that must display an NaN loss value before 85 | stopping. Defaults to ``1``. 86 | verbose : bool, optional 87 | Verbose text output. Defaults to off (``False``). _NOTE_ : This 88 | currently does nothing. 89 | """ 90 | 91 | def __init__(self, stopping_metric, patience=1, verbose=False): 92 | self.metric = metric_dict[stopping_metric] 93 | self.patience = patience 94 | self.counter = 0 95 | self.stop = False 96 | 97 | def __call__(self, y_true, y_pred): 98 | if np.isinf(self.metric(y_true, y_pred)) or \ 99 | np.isnan(self.metric(y_true, y_pred)): 100 | self.counter += 1 101 | if self.counter >= self.patience: 102 | self.stop = True 103 | else: 104 | self.counter = 0 105 | 106 | 107 | class TorchModelCheckpoint(object): 108 | """Save the model at specific points using Keras checkpointing args. 109 | 110 | Arguments 111 | --------- 112 | filepath : str, optional 113 | Path to save the model file to. The end of the path (before the 114 | file extension) will have ``'_[epoch]'`` added to it to ID specific 115 | checkpoints. 116 | monitor : str, optional 117 | The loss value to monitor. Options are 118 | ``['loss', 'val_loss', 'periodic']`` or a metric from the keys in 119 | :const:`solaris.nets.metrics.metric_dict` . Defaults to ``'loss'`` . If 120 | ``'periodic'``, it saves every n epochs (see `period` below). 121 | verbose : bool, optional 122 | Verbose text output. Defaults to ``False`` . 123 | save_best_only : bool, optional 124 | Save only the model with the best value? Defaults to no (``False`` ). 125 | mode : str, optional 126 | One of ``['auto', 'min', 'max']``. Is a better value higher or lower? 127 | Defaults to ``'auto'`` in which case it tries to infer it (if 128 | ``monitor='loss'`` or ``monitor='val_loss'`` , it assumes ``'min'`` , 129 | if it's a metric it assumes ``'max'`` .) 
If ``'min'``, it assumes lower 130 | values are better; if ``'max'`` , it assumes higher values are better. 131 | period : int, optional 132 | If using ``monitor='periodic'`` , this saves models every `period` 133 | epochs. Otherwise, it sets the minimum number of epochs between 134 | checkpoints. 135 | """ 136 | 137 | def __init__(self, filepath='', monitor='loss', verbose=False, 138 | save_best_only=False, mode='auto', period=1, 139 | weights_only=True): 140 | 141 | self.filepath = filepath 142 | self.monitor = monitor 143 | if self.monitor not in ['loss', 'val_loss', 'periodic']: 144 | self.monitor = metric_dict[self.monitor] 145 | self.verbose = verbose 146 | self.save_best_only = save_best_only 147 | self.period = period 148 | self.weights_only = weights_only 149 | self.mode = mode 150 | if self.mode == 'auto': 151 | if self.monitor in ['loss', 'val_loss']: 152 | self.mode = 'min' 153 | else: 154 | self.mode = 'max' 155 | 156 | self.epoch = 0 157 | self.last_epoch = 0 158 | self.last_saved_value = None 159 | 160 | def __call__(self, model, loss_value=None, y_true=None, y_pred=None): 161 | """Run a round of model checkpointing for an epoch. 162 | 163 | Arguments 164 | --------- 165 | model : model object 166 | The model to be saved during checkpoints. Must be a PyTorch model. 167 | loss_value : numeric, optional 168 | The numeric output of the loss function. Only required if using 169 | ``monitor='loss'`` or ``monitor='val_loss'`` . 170 | y_true : :class:`np.array` , optional 171 | The labels for the validation data. Only required if using 172 | a metric as the monitored value. 173 | y_pred : :class:`np.array` , optional 174 | The predicted values from the model. Only required if using 175 | a metric as the monitored value. 176 | """ 177 | 178 | self.epoch += 1 179 | if self.monitor == 'periodic': 180 | if self.last_epoch + self.period <= self.epoch: 181 | self.last_saved_value = loss_value 182 | self.save(model, self.weights_only) 183 | self.last_epoch = self.epoch 184 | 185 | elif self.monitor in ['loss', 'val_loss']: 186 | if self.last_saved_value is None: 187 | self.last_saved_value = loss_value 188 | if self.last_epoch + self.period <= self.epoch: 189 | self.save(model, self.weights_only) 190 | self.last_epoch = self.epoch 191 | if self.last_epoch + self.period <= self.epoch: 192 | if self.check_is_best_value(loss_value): 193 | self.last_saved_value = loss_value 194 | self.save(model, self.weights_only) 195 | self.last_epoch = self.epoch 196 | 197 | else: 198 | if self.last_saved_value is None: 199 | self.last_saved_value = self.monitor(y_true, y_pred) 200 | if self.last_epoch + self.period <= self.epoch: 201 | self.save(model, self.weights_only) 202 | self.last_epoch = self.epoch 203 | if self.last_epoch + self.period <= self.epoch: 204 | metric_value = self.monitor(y_true, y_pred) 205 | if self.check_is_best_value(metric_value): 206 | self.last_saved_value = metric_value 207 | self.save(model, self.weights_only) 208 | self.last_epoch = self.epoch 209 | 210 | def check_is_best_value(self, value): 211 | """Check if `value` is better than the best stored value.""" 212 | if self.mode == 'min' and self.last_saved_value > value: 213 | return True 214 | elif self.mode == 'max' and self.last_saved_value < value: 215 | return True 216 | else: 217 | return False 218 | 219 | def save(self, model, weights_only): 220 | """Save the model. 221 | 222 | Arguments 223 | --------- 224 | model : :class:`torch.nn.Module` 225 | A PyTorch model instance to save. 
226 | weights_only : bool, optional 227 | Should the entire model be saved, or only its weights (also known 228 | as the state_dict)? Defaults to ``False`` (saves entire model). The 229 | entire model must be saved to resume training without re-defining 230 | the model architecture, optimizer, and loss function. 231 | """ 232 | save_name = os.path.splitext(self.filepath)[0] + '_epoch{}_{}'.format( 233 | self.epoch, np.round(self.last_saved_value, 3)) 234 | save_name = save_name + os.path.splitext(self.filepath)[1] 235 | if isinstance(model, torch.nn.DataParallel): 236 | to_save = model.module 237 | else: 238 | to_save = model 239 | if weights_only: 240 | torch.save(to_save.state_dict(), save_name) 241 | else: 242 | torch.save(to_save, save_name) 243 | 244 | 245 | torch_callback_dict = { 246 | "early_stopping": TorchEarlyStopping, 247 | "model_checkpoint": TorchModelCheckpoint, 248 | "terminate_on_nan": TorchTerminateOnNaN, 249 | "terminate_on_metric_nan": TorchTerminateOnMetricNaN 250 | } 251 | -------------------------------------------------------------------------------- /libs/solaris/nets/zoo/__init__.py: -------------------------------------------------------------------------------- 1 | import os 2 | from .. import weights_dir 3 | from .xdxd_sn4 import XDXD_SpaceNet4_UNetVGG16 4 | from .selim_sef_sn4 import SelimSef_SpaceNet4_ResNet34UNet 5 | from .selim_sef_sn4 import SelimSef_SpaceNet4_DenseNet121UNet 6 | from .selim_sef_sn4 import SelimSef_SpaceNet4_DenseNet161UNet 7 | 8 | model_dict = { 9 | 'xdxd_spacenet4': { 10 | 'weight_path': os.path.join(weights_dir, 11 | 'xdxd_spacenet4_solaris_weights.pth'), 12 | 'weight_url': 'https://s3.amazonaws.com/spacenet-dataset/spacenet-model-weights/spacenet-4/xdxd_spacenet4_solaris_weights.pth', 13 | 'arch': XDXD_SpaceNet4_UNetVGG16 14 | }, 15 | 'selimsef_spacenet4_resnet34unet': { 16 | 'weight_path': os.path.join( 17 | weights_dir, 'selimsef_spacenet4_resnet34unet_solaris_weights.pth' 18 | ), 19 | 'weight_url': 'https://s3.amazonaws.com/spacenet-dataset/spacenet-model-weights/spacenet-4/selimsef_spacenet4_resnet34unet_solaris_weights.pth', 20 | 'arch': SelimSef_SpaceNet4_ResNet34UNet 21 | }, 22 | 'selimsef_spacenet4_densenet121unet': { 23 | 'weight_path': os.path.join( 24 | weights_dir, 'selimsef_spacenet4_densenet121unet_solaris_weights.pth' 25 | ), 26 | 'weight_url': 'https://s3.amazonaws.com/spacenet-dataset/spacenet-model-weights/spacenet-4/selimsef_spacenet4_densenet121unet_solaris_weights.pth', 27 | 'arch': SelimSef_SpaceNet4_DenseNet121UNet 28 | }, 29 | 'selimsef_spacenet4_densenet161unet': { 30 | 'weight_path': os.path.join( 31 | weights_dir, 'selimsef_spacenet4_densenet161unet_solaris_weights.pth' 32 | ), 33 | 'weight_url': 'https://s3.amazonaws.com/spacenet-dataset/spacenet-model-weights/spacenet-4/selimsef_spacenet4_densenet161unet_solaris_weights.pth', 34 | 'arch': SelimSef_SpaceNet4_DenseNet161UNet 35 | } 36 | } 37 | -------------------------------------------------------------------------------- /libs/solaris/nets/zoo/xdxd_sn4.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | from torch import nn 4 | from torchvision.models import vgg16 5 | 6 | 7 | class XDXD_SpaceNet4_UNetVGG16(nn.Module): 8 | def __init__(self, num_filters=32, pretrained=False): 9 | super().__init__() 10 | self.encoder = vgg16(pretrained=pretrained).features 11 | self.pool = nn.MaxPool2d(2, 2) 12 | 13 | self.relu = nn.ReLU(inplace=True) 14 | self.conv1 = nn.Sequential( 15 | self.encoder[0], 
self.relu, self.encoder[2], self.relu) 16 | self.conv2 = nn.Sequential( 17 | self.encoder[5], self.relu, self.encoder[7], self.relu) 18 | self.conv3 = nn.Sequential( 19 | self.encoder[10], self.relu, self.encoder[12], self.relu, 20 | self.encoder[14], self.relu) 21 | self.conv4 = nn.Sequential( 22 | self.encoder[17], self.relu, self.encoder[19], self.relu, 23 | self.encoder[21], self.relu) 24 | self.conv5 = nn.Sequential( 25 | self.encoder[24], self.relu, self.encoder[26], self.relu, 26 | self.encoder[28], self.relu) 27 | 28 | self.center = XDXD_SN4_DecoderBlock(512, num_filters * 8 * 2, 29 | num_filters * 8) 30 | self.dec5 = XDXD_SN4_DecoderBlock( 31 | 512 + num_filters * 8, num_filters * 8 * 2, num_filters * 8) 32 | self.dec4 = XDXD_SN4_DecoderBlock( 33 | 512 + num_filters * 8, num_filters * 8 * 2, num_filters * 8) 34 | self.dec3 = XDXD_SN4_DecoderBlock( 35 | 256 + num_filters * 8, num_filters * 4 * 2, num_filters * 2) 36 | self.dec2 = XDXD_SN4_DecoderBlock( 37 | 128 + num_filters * 2, num_filters * 2 * 2, num_filters) 38 | self.dec1 = XDXD_SN4_ConvRelu(64 + num_filters, num_filters) 39 | self.final = nn.Conv2d(num_filters, 1, kernel_size=1) 40 | 41 | def forward(self, x): 42 | conv1 = self.conv1(x) 43 | conv2 = self.conv2(self.pool(conv1)) 44 | conv3 = self.conv3(self.pool(conv2)) 45 | conv4 = self.conv4(self.pool(conv3)) 46 | conv5 = self.conv5(self.pool(conv4)) 47 | center = self.center(self.pool(conv5)) 48 | dec5 = self.dec5(torch.cat([center, conv5], 1)) 49 | dec4 = self.dec4(torch.cat([dec5, conv4], 1)) 50 | dec3 = self.dec3(torch.cat([dec4, conv3], 1)) 51 | dec2 = self.dec2(torch.cat([dec3, conv2], 1)) 52 | dec1 = self.dec1(torch.cat([dec2, conv1], 1)) 53 | x_out = self.final(dec1) 54 | return x_out 55 | 56 | 57 | class XDXD_SN4_ConvRelu(nn.Module): 58 | def __init__(self, in_, out): 59 | super().__init__() 60 | self.conv = nn.Conv2d(in_, out, 3, padding=1) 61 | self.activation = nn.ReLU(inplace=True) 62 | 63 | def forward(self, x): 64 | x = self.conv(x) 65 | x = self.activation(x) 66 | return x 67 | 68 | 69 | class XDXD_SN4_DecoderBlock(nn.Module): 70 | def __init__(self, in_channels, middle_channels, out_channels): 71 | super(XDXD_SN4_DecoderBlock, self).__init__() 72 | self.in_channels = in_channels 73 | self.block = nn.Sequential( 74 | nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False), 75 | XDXD_SN4_ConvRelu(in_channels, middle_channels), 76 | XDXD_SN4_ConvRelu(middle_channels, out_channels), 77 | ) 78 | 79 | def forward(self, x): 80 | return self.block(x) 81 | 82 | # below dictionary lists models compatible with solaris. alternatively, your 83 | # own model can be used by using the path to the model as the value for 84 | # model_name in the config file. 85 | -------------------------------------------------------------------------------- /libs/solaris/raster/__init__.py: -------------------------------------------------------------------------------- 1 | from . import image 2 | -------------------------------------------------------------------------------- /libs/solaris/tile/__init__.py: -------------------------------------------------------------------------------- 1 | from . import raster_tile, vector_tile 2 | -------------------------------------------------------------------------------- /libs/solaris/utils/__init__.py: -------------------------------------------------------------------------------- 1 | from . 
import cli, config, core, geo, io, tile, data 2 | -------------------------------------------------------------------------------- /libs/solaris/utils/cli.py: -------------------------------------------------------------------------------- 1 | def _func_wrapper(func_to_call, arg_dict): 2 | return func_to_call(**arg_dict) 3 | -------------------------------------------------------------------------------- /libs/solaris/utils/config.py: -------------------------------------------------------------------------------- 1 | import yaml 2 | from ..nets import zoo 3 | 4 | 5 | def parse(path): 6 | """Parse a config file for running a model. 7 | 8 | Arguments 9 | --------- 10 | path : str 11 | Path to the YAML-formatted config file to parse. 12 | 13 | Returns 14 | ------- 15 | config : dict 16 | A `dict` containing the information from the config file at `path`. 17 | 18 | """ 19 | with open(path, 'r') as f: 20 | config = yaml.safe_load(f) 21 | f.close() 22 | if not config['train'] and not config['infer']: 23 | raise ValueError('"train", "infer", or both must be true.') 24 | if config['train'] and config['training_data_csv'] is None: 25 | raise ValueError('"training_data_csv" must be provided if training.') 26 | if config['infer'] and config['inference_data_csv'] is None: 27 | raise ValueError('"inference_data_csv" must be provided if "infer".') 28 | if config['training']['lr'] is not None: 29 | config['training']['lr'] = float(config['training']['lr']) 30 | 31 | # TODO: IMPLEMENT UPDATING VALUES BASED ON EMPTY ELEMENTS HERE! 32 | 33 | if config['validation_augmentation'] is not None \ 34 | and config['inference_augmentation'] is None: 35 | config['inference_augmentation'] = config['validation_augmentation'] 36 | 37 | return config 38 | -------------------------------------------------------------------------------- /libs/solaris/utils/core.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | from shapely.wkt import loads 4 | from shapely.geometry import Point 5 | from shapely.geometry.base import BaseGeometry 6 | import pandas as pd 7 | import geopandas as gpd 8 | import pyproj 9 | import rasterio 10 | from distutils.version import LooseVersion 11 | import skimage 12 | from fiona._err import CPLE_OpenFailedError 13 | from fiona.errors import DriverError 14 | from warnings import warn 15 | 16 | 17 | def _check_rasterio_im_load(im): 18 | """Check if `im` is already loaded in; if not, load it in.""" 19 | if isinstance(im, str): 20 | return rasterio.open(im) 21 | elif isinstance(im, rasterio.DatasetReader): 22 | return im 23 | else: 24 | raise ValueError( 25 | "{} is not an accepted image format for rasterio.".format(im)) 26 | 27 | 28 | def _check_skimage_im_load(im): 29 | """Check if `im` is already loaded in; if not, load it in.""" 30 | if isinstance(im, str): 31 | return skimage.io.imread(im) 32 | elif isinstance(im, np.ndarray): 33 | return im 34 | else: 35 | raise ValueError( 36 | "{} is not an accepted image format for scikit-image.".format(im)) 37 | 38 | 39 | def _check_df_load(df): 40 | """Check if `df` is already loaded in, if not, load from file.""" 41 | if isinstance(df, str): 42 | if df.lower().endswith('json'): 43 | return _check_gdf_load(df) 44 | else: 45 | return pd.read_csv(df) 46 | elif isinstance(df, pd.DataFrame): 47 | return df 48 | else: 49 | raise ValueError(f"{df} is not an accepted DataFrame format.") 50 | 51 | 52 | def _check_gdf_load(gdf): 53 | """Check if `gdf` is already loaded in, if not, load from 
geojson.""" 54 | if isinstance(gdf, str): 55 | # as of geopandas 0.6.2, using the OGR CSV driver requires some add'nal 56 | # kwargs to create a valid geodataframe with a geometry column. see 57 | # https://github.com/geopandas/geopandas/issues/1234 58 | if gdf.lower().endswith('csv'): 59 | return gpd.read_file(gdf, GEOM_POSSIBLE_NAMES="geometry", 60 | KEEP_GEOM_COLUMNS="NO") 61 | try: 62 | return gpd.read_file(gdf) 63 | except (DriverError, CPLE_OpenFailedError): 64 | warn(f"GeoDataFrame couldn't be loaded: either {gdf} isn't a valid" 65 | " path or it isn't a valid vector file. Returning an empty" 66 | " GeoDataFrame.") 67 | return gpd.GeoDataFrame() 68 | elif isinstance(gdf, gpd.GeoDataFrame): 69 | return gdf 70 | else: 71 | raise ValueError(f"{gdf} is not an accepted GeoDataFrame format.") 72 | 73 | 74 | def _check_geom(geom): 75 | """Check if a geometry is loaded in. 76 | 77 | Returns the geometry if it's a shapely geometry object. If it's a wkt 78 | string or a list of coordinates, convert to a shapely geometry. 79 | """ 80 | if isinstance(geom, BaseGeometry): 81 | return geom 82 | elif isinstance(geom, str): # assume it's a wkt 83 | return loads(geom) 84 | elif isinstance(geom, list) and len(geom) == 2: # coordinates 85 | return Point(geom) 86 | 87 | 88 | def _check_crs(input_crs, return_rasterio=False): 89 | """Convert CRS to the ``pyproj.CRS`` object passed by ``solaris``.""" 90 | if not isinstance(input_crs, pyproj.CRS) and input_crs is not None: 91 | out_crs = pyproj.CRS(input_crs) 92 | else: 93 | out_crs = input_crs 94 | 95 | if return_rasterio: 96 | if LooseVersion(rasterio.__gdal_version__) >= LooseVersion("3.0.0"): 97 | out_crs = rasterio.crs.CRS.from_wkt(out_crs.to_wkt()) 98 | else: 99 | out_crs = rasterio.crs.CRS.from_wkt(out_crs.to_wkt("WKT1_GDAL")) 100 | 101 | return out_crs 102 | 103 | 104 | def get_data_paths(path, infer=False): 105 | """Get a pandas dataframe of images and labels from a csv. 106 | 107 | This file is designed to parse image:label reference CSVs (or just image) 108 | for inferencde) as defined in the documentation. Briefly, these should be 109 | CSVs containing two columns: 110 | 111 | ``'image'``: the path to images. 112 | ``'label'``: the path to the label file that corresponds to the image. 113 | 114 | Arguments 115 | --------- 116 | path : str 117 | Path to a .CSV-formatted reference file defining the location of 118 | training, validation, or inference data. See docs for details. 119 | infer : bool, optional 120 | If ``infer=True`` , the ``'label'`` column will not be returned (as it 121 | is unnecessary for inference), even if it is present. 122 | 123 | Returns 124 | ------- 125 | df : :class:`pandas.DataFrame` 126 | A :class:`pandas.DataFrame` containing the relevant `image` and `label` 127 | information from the CSV at `path` (unless ``infer=True`` , in which 128 | case only the `image` column is returned.) 
129 | 130 | """ 131 | df = pd.read_csv(path) 132 | if infer: 133 | return df[['image']] # no labels in those files 134 | else: 135 | return df[['image', 'label']] # remove anything extraneous 136 | 137 | 138 | def get_files_recursively(path, traverse_subdirs=False, extension='.tif'): 139 | """Get files from subdirs of `path`, joining them to the dir.""" 140 | if traverse_subdirs: 141 | walker = os.walk(path) 142 | path_list = [] 143 | for step in walker: 144 | if not step[2]: # if there are no files in the current dir 145 | continue 146 | path_list += [os.path.join(step[0], fname) 147 | for fname in step[2] if 148 | fname.lower().endswith(extension)] 149 | return path_list 150 | else: 151 | return [os.path.join(path, f) for f in os.listdir(path) 152 | if f.endswith(extension)] 153 | -------------------------------------------------------------------------------- /libs/solaris/utils/data.py: -------------------------------------------------------------------------------- 1 | import os 2 | import pandas as pd 3 | from .log import _get_logging_level 4 | from .core import get_files_recursively 5 | import logging 6 | 7 | 8 | def make_dataset_csv(im_dir, im_ext='tif', label_dir=None, label_ext='json', 9 | output_path='dataset.csv', stage='train', match_re=None, 10 | recursive=False, ignore_mismatch=None, verbose=0): 11 | """Automatically generate dataset CSVs for training. 12 | 13 | This function creates basic CSVs for training and inference automatically. 14 | See `the documentation tutorials `_ 15 | for details on the specification. A regular expression string can be 16 | provided to extract substrings for matching images to labels; if not 17 | provided, it's assumed that the filename for the image and label files is 18 | identical once extensions are stripped. By default, this function will 19 | raise an exception if there are multiple label files that match to a given 20 | image file, or if no label file matches an image file; see the 21 | `ignore_mismatch` argument for alternatives. 22 | 23 | Arguments 24 | --------- 25 | im_dir : str 26 | The path to the directory containing images to be used by your model. 27 | Images in sub-directories can be included by setting 28 | ``recursive=True``. 29 | im_ext : str, optional 30 | The file extension used by your images. Defaults to ``"tif"``. Not case 31 | sensitive. 32 | label_dir : str, optional 33 | The path to the directory containing images to be used by your model. 34 | Images in sub-directories can be included by setting 35 | ``recursive=True``. This argument is required if `stage` is ``"train"`` 36 | (default) or ``"val"``, but has no effect if `stage` is ``"infer"``. 37 | output_path : str, optional 38 | The path to save the generated CSV to. Defaults to ``"dataset.csv"``. 39 | stage : str, optional 40 | The stage that the csv is generated for. Can be ``"train"`` (default), 41 | ``"val"``, or ``"infer"``. If set to ``"train"`` or ``"val"``, 42 | `label_dir` must be provided or an error will occur. 43 | match_re : str, optional 44 | A regular expression pattern to extract substrings from image and 45 | label filenames for matching. If not provided and labels must be 46 | matched to images, it's assumed that image and label filenames are 47 | identical after stripping directory and extension. Has no effect if 48 | ``stage="infer"``. The pattern must contain at least one capture group 49 | for compatibility with :func:`pandas.Series.str.extract`. 
50 | recursive : bool, optional 51 | Should sub-directories in `im_dir` and `label_dir` be traversed to 52 | find images and label files? Defaults to no (``False``). 53 | ignore_mismatch : str, optional 54 | Dictates how mismatches between image files and label files should be 55 | handled. By default, having != 1 label file per image file will raise 56 | a ``ValueError``. If ``ignore_mismatch="skip"``, any image files with 57 | != 1 matching label will be skipped. 58 | verbose : int, optional 59 | Verbose text output. By default, none is provided; if ``True`` or 60 | ``1``, information-level outputs are provided; if ``2``, extremely 61 | verbose text is output. 62 | 63 | Returns 64 | ------- 65 | output_df : :class:`pandas.DataFrame` 66 | A :class:`pandas.DataFrame` with one column titled ``"image"`` and 67 | a second titled ``"label"`` (if ``stage != "infer"``). The function 68 | also saves a CSV at `output_path`. 69 | """ 70 | logger = logging.getLogger(__name__) 71 | logger.setLevel(_get_logging_level(int(verbose))) 72 | logger.debug('Checking arguments.') 73 | 74 | if stage != 'infer' and label_dir is None: 75 | raise ValueError("label_dir must be provided if stage is not infer.") 76 | logger.info('Matching images to labels.') 77 | logger.debug('Getting image file paths.') 78 | im_fnames = get_files_recursively(im_dir, traverse_subdirs=recursive, 79 | extension=im_ext) 80 | logger.debug(f"Got {len(im_fnames)} image file paths.") 81 | temp_im_df = pd.DataFrame({'image_path': im_fnames}) 82 | 83 | if stage != 'infer': 84 | logger.debug('Preparing training or validation set.') 85 | logger.debug('Getting label file paths.') 86 | label_fnames = get_files_recursively(label_dir, 87 | traverse_subdirs=recursive, 88 | extension=label_ext) 89 | logger.debug(f"Got {len(label_fnames)} label file paths.") 90 | if len(im_fnames) != len(label_fnames): 91 | logger.warn('The number of images and label files is not equal.') 92 | 93 | logger.debug("Matching image files to label files.") 94 | logger.debug("Extracting image filename substrings for matching.") 95 | temp_label_df = pd.DataFrame({'label_path': label_fnames}) 96 | temp_im_df['image_fname'] = temp_im_df['image_path'].apply( 97 | lambda x: os.path.split(x)[1]) 98 | temp_label_df['label_fname'] = temp_label_df['label_path'].apply( 99 | lambda x: os.path.split(x)[1]) 100 | if match_re: 101 | logger.debug('match_re is True, extracting regex matches') 102 | im_match_strs = temp_im_df['image_fname'].str.extract(match_re) 103 | label_match_strs = temp_label_df['label_fname'].str.extract( 104 | match_re) 105 | if len(im_match_strs.columns) > 1 or \ 106 | len(label_match_strs.columns) > 1: 107 | raise ValueError('Multiple regex matches occurred within ' 108 | 'individual filenames.') 109 | else: 110 | temp_im_df['match_str'] = im_match_strs 111 | temp_label_df['match_str'] = label_match_strs 112 | else: 113 | logger.debug('match_re is False, will match by fname without ext') 114 | temp_im_df['match_str'] = temp_im_df['image_fname'].apply( 115 | lambda x: os.path.splitext(x)[0]) 116 | temp_label_df['match_str'] = temp_label_df['label_fname'].apply( 117 | lambda x: os.path.splitext(x)[0]) 118 | 119 | logger.debug('Aligning label and image dataframes by' 120 | ' match_str.') 121 | temp_join_df = pd.merge(temp_im_df, temp_label_df, on='match_str', 122 | how='inner') 123 | logger.debug(f'Length of joined dataframe: {len(temp_join_df)}') 124 | if len(temp_join_df) < len(temp_im_df) and \ 125 | ignore_mismatch is None: 126 | raise ValueError('There is not a 
perfect 1:1 match of images ' 127 | 'to label files. To allow this behavior, see ' 128 | 'the make_dataset_csv() ignore_mismatch ' 129 | 'argument.') 130 | elif len(temp_join_df) > len(temp_im_df) and ignore_mismatch is None: 131 | raise ValueError('There are multiple label files matching at ' 132 | 'least one image file.') 133 | elif len(temp_join_df) > len(temp_im_df) and ignore_mismatch == 'skip': 134 | logger.info('ignore_mismatch="skip", so dropping any images with ' 135 | f'duplicates. Original images: {len(temp_im_df)}') 136 | dup_rows = temp_join_df.duplicated(subset='match_str', keep=False) 137 | temp_join_df = temp_join_df.loc[~dup_rows, :] 138 | logger.info('Remaining images after dropping duplicates: ' 139 | f'{len(temp_join_df)}') 140 | logger.debug('Dropping extra columns from output dataframe.') 141 | output_df = temp_join_df[['image_path', 'label_path']].rename( 142 | columns={'image_path': 'image', 'label_path': 'label'}) 143 | 144 | elif stage == 'infer': 145 | logger.debug('Preparing inference dataset dataframe.') 146 | output_df = temp_im_df.rename(columns={'image_path': 'image'}) 147 | 148 | logger.debug(f'Saving output dataframe to {output_path} .') 149 | output_df.to_csv(output_path, index=False) 150 | 151 | return output_df 152 | -------------------------------------------------------------------------------- /libs/solaris/utils/log.py: -------------------------------------------------------------------------------- 1 | import logging 2 | 3 | 4 | def _get_logging_level(level_int): 5 | """Convert a logging level integer into a log level.""" 6 | if isinstance(level_int, bool): 7 | level_int = int(level_int) 8 | if level_int < 0: 9 | return logging.CRITICAL + 1 # silence all possible outputs 10 | elif level_int == 0: 11 | return logging.WARNING 12 | elif level_int == 1: 13 | return logging.INFO 14 | elif level_int == 2: 15 | return logging.DEBUG 16 | elif level_int in [10, 20, 30, 40, 50]: # if user provides the logger int 17 | return level_int 18 | elif isinstance(level_int, int): # if it's an int but not one of the above 19 | return level_int 20 | else: 21 | raise ValueError(f"logging level set to {level_int}, " 22 | "but it must be an integer <= 2.") 23 | -------------------------------------------------------------------------------- /libs/solaris/utils/raster.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | import tensorflow as tf 4 | 5 | 6 | def reorder_axes(arr, target='tensorflow'): 7 | """Check order of axes in an array or tensor and convert to desired format. 8 | 9 | Arguments 10 | --------- 11 | arr : :class:`numpy.array` or :class:`torch.Tensor` or :class:`tensorflow.Tensor` 12 | target : str, optional 13 | Desired axis order type. Possible values: 14 | - ``'tensorflow'`` (default): ``[N, Y, X, C]`` or ``[Y, X, C]`` 15 | - ``'torch'`` : ``[N, C, Y, X]`` or ``[C, Y, X]`` 16 | 17 | Returns 18 | ------- 19 | out_arr : an object of the same class as `arr` with axes in the desired 20 | order. 
21 | """ 22 | 23 | if isinstance(arr, torch.Tensor) or isinstance(arr, np.ndarray): 24 | axes = list(arr.shape) 25 | elif isinstance(arr, tf.Tensor): 26 | axes = arr.get_shape().as_list() 27 | 28 | if isinstance(arr, torch.Tensor): 29 | if len(axes) == 3: 30 | if target == 'tensorflow' and axes[0] < axes[1]: 31 | arr = arr.permute(1, 2, 0) 32 | elif target == 'torch' and axes[2] < axes[1]: 33 | arr = arr.permute(2, 0, 1) 34 | elif len(axes) == 4: 35 | if target == 'tensorflow' and axes[1] < axes[2]: 36 | arr = arr.permute(0, 2, 3, 1) 37 | elif target == 'torch' and axes[3] < axes[2]: 38 | arr = arr.permute(0, 3, 1, 2) 39 | 40 | elif isinstance(arr, np.ndarray): 41 | if len(axes) == 3: 42 | if target == 'tensorflow' and axes[0] < axes[1]: 43 | arr = np.moveaxis(arr, 0, -1) 44 | elif target == 'torch' and axes[2] < axes[1]: 45 | arr = np.moveaxis(arr, 2, 0) 46 | elif len(axes) == 4: 47 | if target == 'tensorflow' and axes[1] < axes[2]: 48 | arr = np.moveaxis(arr, 1, -1) 49 | elif target == 'torch' and axes[3] < axes[2]: 50 | arr = np.moveaxis(arr, 3, 1) 51 | 52 | elif isinstance(arr, tf.Tensor): 53 | # permutation is obnoxious in tensorflow; convert to numpy, permute, 54 | # convert back. 55 | np_version = arr.eval() 56 | np_version = reorder_axes(np_version, target=target) 57 | arr = tf.convert_to_tensor(np_version) 58 | 59 | return arr 60 | -------------------------------------------------------------------------------- /libs/solaris/utils/tdigest.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aws-samples/aws-open-data-satellite-lidar-tutorial/928196f105df202e04d5bcdc64ad449cf65b183d/libs/solaris/utils/tdigest.py -------------------------------------------------------------------------------- /libs/solaris/utils/tile.py: -------------------------------------------------------------------------------- 1 | from .core import _check_crs 2 | import geopandas as gpd 3 | import json 4 | from affine import Affine 5 | from rasterio.windows import Window 6 | from rasterio.vrt import WarpedVRT 7 | from rasterio.enums import Resampling 8 | # temporarily removing the below until I can get COG functionality implemented 9 | # from rio_tiler.utils import get_vrt_transform, has_alpha_band 10 | # from rio_tiler.utils import _requested_tile_aligned_with_internal_tile 11 | 12 | 13 | def save_empty_geojson(path, crs): 14 | crs = _check_crs(crs) 15 | empty_geojson_dict = { 16 | "type": "FeatureCollection", 17 | "crs": 18 | { 19 | "type": "name", 20 | "properties": 21 | { 22 | "name": "urn:ogc:def:crs:EPSG:{}".format(crs.to_epsg()) 23 | } 24 | }, 25 | "features": 26 | [] 27 | } 28 | 29 | with open(path, 'w') as f: 30 | json.dump(empty_geojson_dict, f) 31 | f.close() 32 | 33 | 34 | # def read_cog_tile(src, 35 | # bounds, 36 | # tile_size, 37 | # indexes=None, 38 | # nodata=None, 39 | # resampling_method="bilinear", 40 | # tile_edge_padding=2): 41 | # """ 42 | # Read cloud-optimized geotiff tile. 43 | # 44 | # Notes 45 | # ----- 46 | # Modified from `rio-tiler `_. 47 | # License included below per terms of use. 48 | # 49 | # BSD 3-Clause License 50 | # (c) 2017 Mapbox 51 | # All rights reserved. 52 | # 53 | # Redistribution and use in source and binary forms, with or without 54 | # modification, are permitted provided that the following conditions are met: 55 | # 56 | # * Redistributions of source code must retain the above copyright notice, this 57 | # list of conditions and the following disclaimer. 
58 | # 59 | # * Redistributions in binary form must reproduce the above copyright notice, 60 | # this list of conditions and the following disclaimer in the documentation 61 | # and/or other materials provided with the distribution. 62 | # 63 | # * Neither the name of the copyright holder nor the names of its 64 | # contributors may be used to endorse or promote products derived from 65 | # this software without specific prior written permission. 66 | # 67 | # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 68 | # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 69 | # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 70 | # DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 71 | # FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 72 | # DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 73 | # SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 74 | # CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 75 | # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 76 | # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 77 | # 78 | # Arguments 79 | # --------- 80 | # src : rasterio.io.DatasetReader 81 | # rasterio.io.DatasetReader object 82 | # bounds : list 83 | # Tile bounds (left, bottom, right, top) 84 | # tile_size : list 85 | # Output image size 86 | # indexes : list of ints or a single int, optional, (defaults: None) 87 | # If `indexes` is a list, the result is a 3D array, but is 88 | # a 2D array if it is a band index number. 89 | # nodata: int or float, optional (defaults: None) 90 | # resampling_method : str, optional (default: "bilinear") 91 | # Resampling algorithm 92 | # tile_edge_padding : int, optional (default: 2) 93 | # Padding to apply to each edge of the tile when retrieving data 94 | # to assist in reducing resampling artefacts along edges. 95 | # 96 | # Returns 97 | # ------- 98 | # out : array, int 99 | # returns pixel value. 
100 | # """ 101 | # if isinstance(indexes, int): 102 | # indexes = [indexes] 103 | # elif isinstance(indexes, tuple): 104 | # indexes = list(indexes) 105 | # 106 | # vrt_params = dict( 107 | # add_alpha=True, crs='epsg:' + str(src.crs.to_epsg()), 108 | # resampling=Resampling[resampling_method] 109 | # ) 110 | # 111 | # vrt_transform, vrt_width, vrt_height = get_vrt_transform( 112 | # src, bounds, bounds_crs='epsg:' + str(src.crs.to_epsg())) 113 | # out_window = Window(col_off=0, row_off=0, 114 | # width=vrt_width, height=vrt_height) 115 | # 116 | # if tile_edge_padding > 0 and not \ 117 | # _requested_tile_aligned_with_internal_tile(src, bounds, tile_size): 118 | # vrt_transform = vrt_transform * Affine.translation( 119 | # -tile_edge_padding, -tile_edge_padding 120 | # ) 121 | # orig__vrt_height = vrt_height 122 | # orig_vrt_width = vrt_width 123 | # vrt_height = vrt_height + 2 * tile_edge_padding 124 | # vrt_width = vrt_width + 2 * tile_edge_padding 125 | # out_window = Window( 126 | # col_off=tile_edge_padding, 127 | # row_off=tile_edge_padding, 128 | # width=orig_vrt_width, 129 | # height=orig__vrt_height, 130 | # ) 131 | # 132 | # vrt_params.update(dict(transform=vrt_transform, 133 | # width=vrt_width, 134 | # height=vrt_height)) 135 | # 136 | # indexes = indexes if indexes is not None else src.indexes 137 | # out_shape = (len(indexes), tile_size[1], tile_size[0]) 138 | # 139 | # nodata = nodata if nodata is not None else src.nodata 140 | # if nodata is not None: 141 | # vrt_params.update(dict(nodata=nodata, 142 | # add_alpha=False, 143 | # src_nodata=nodata)) 144 | # 145 | # if has_alpha_band(src): 146 | # vrt_params.update(dict(add_alpha=False)) 147 | # 148 | # with WarpedVRT(src, **vrt_params) as vrt: 149 | # data = vrt.read( 150 | # out_shape=out_shape, 151 | # indexes=indexes, 152 | # window=out_window, 153 | # resampling=Resampling[resampling_method], 154 | # ) 155 | # mask = vrt.dataset_mask(out_shape=(tile_size[1], tile_size[0]), 156 | # window=out_window) 157 | # 158 | # return data, mask, out_window, vrt_transform 159 | -------------------------------------------------------------------------------- /libs/solaris/vector/__init__.py: -------------------------------------------------------------------------------- 1 | from . import graph, mask, polygon 2 | -------------------------------------------------------------------------------- /networks/vgg16_unet.py: -------------------------------------------------------------------------------- 1 | # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 2 | # SPDX-License-Identifier: MIT-0 3 | 4 | """ Originally developed by XD_XD in the SpaceNet 4 Challenge , then integrated 5 | by Solaris project. We modified the code to accomodate different number of 6 | input channels, e.g. 5-channel RGB+LIDAR input images. 7 | """ 8 | 9 | import torch 10 | from torch import nn 11 | from torchvision.models import vgg16 12 | 13 | 14 | def get_modified_vgg16_unet(in_channels=3): 15 | """ Get a modified VGG16-Unet model with customized input channel numbers. 16 | For example, we can set in_channels=3 and input RGB 3-channel images. 17 | On the other hand, we can set in_channels=5 if we want to input both RGB 18 | and 2-channel LIDAR data (elevation + intensity). 
19 | """ 20 | class Modified_VGG16Unet(VGG16Unet): 21 | def __init__(self): 22 | super().__init__(in_channels=in_channels) 23 | return Modified_VGG16Unet 24 | 25 | 26 | class DecoderBlock(nn.Module): 27 | def __init__(self, in_channels, middle_channels, out_channels): 28 | super().__init__() 29 | self.in_channels = in_channels 30 | self.block = nn.Sequential( 31 | nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False), 32 | ConvRelu(in_channels, middle_channels), 33 | ConvRelu(middle_channels, out_channels), 34 | ) 35 | 36 | def forward(self, x): 37 | return self.block(x) 38 | 39 | class ConvRelu(nn.Module): 40 | def __init__(self, in_, out): 41 | super().__init__() 42 | self.conv = nn.Conv2d(in_, out, 3, padding=1) 43 | self.activation = nn.ReLU(inplace=True) 44 | 45 | def forward(self, x): 46 | x = self.conv(x) 47 | x = self.activation(x) 48 | return x 49 | 50 | class VGG16Unet(nn.Module): 51 | def __init__(self, in_channels=3, num_filters=32, pretrained=False): 52 | super().__init__() 53 | # Get VGG16 net as encoder 54 | self.encoder = vgg16(pretrained=pretrained).features 55 | self.pool = nn.MaxPool2d(2, 2) 56 | self.relu = nn.ReLU(inplace=True) 57 | 58 | # Modify encoder architecture 59 | self.encoder[0] = nn.Conv2d( 60 | in_channels, 64, kernel_size=3, stride=1, padding=1) 61 | self.conv1 = nn.Sequential( 62 | self.encoder[0], self.relu, self.encoder[2], self.relu) 63 | self.conv2 = nn.Sequential( 64 | self.encoder[5], self.relu, self.encoder[7], self.relu) 65 | self.conv3 = nn.Sequential( 66 | self.encoder[10], self.relu, self.encoder[12], self.relu, 67 | self.encoder[14], self.relu) 68 | self.conv4 = nn.Sequential( 69 | self.encoder[17], self.relu, self.encoder[19], self.relu, 70 | self.encoder[21], self.relu) 71 | self.conv5 = nn.Sequential( 72 | self.encoder[24], self.relu, self.encoder[26], self.relu, 73 | self.encoder[28], self.relu) 74 | 75 | # Build decoder 76 | self.center = DecoderBlock( 77 | 512, num_filters*8*2, num_filters*8) 78 | self.dec5 = DecoderBlock( 79 | 512 + num_filters*8, num_filters*8*2, num_filters*8) 80 | self.dec4 = DecoderBlock( 81 | 512 + num_filters*8, num_filters*8*2, num_filters*8) 82 | self.dec3 = DecoderBlock( 83 | 256 + num_filters*8, num_filters*4*2, num_filters*2) 84 | self.dec2 = DecoderBlock( 85 | 128 + num_filters*2, num_filters*2*2, num_filters) 86 | self.dec1 = ConvRelu(64 + num_filters, num_filters) 87 | 88 | # Final output layer outputs logits, not probability 89 | self.final = nn.Conv2d(num_filters, 1, kernel_size=1) 90 | 91 | def forward(self, x): 92 | conv1 = self.conv1(x) 93 | conv2 = self.conv2(self.pool(conv1)) 94 | conv3 = self.conv3(self.pool(conv2)) 95 | conv4 = self.conv4(self.pool(conv3)) 96 | conv5 = self.conv5(self.pool(conv4)) 97 | center = self.center(self.pool(conv5)) 98 | dec5 = self.dec5(torch.cat([center, conv5], 1)) 99 | dec4 = self.dec4(torch.cat([dec5, conv4], 1)) 100 | dec3 = self.dec3(torch.cat([dec4, conv3], 1)) 101 | dec2 = self.dec2(torch.cat([dec3, conv2], 1)) 102 | dec1 = self.dec1(torch.cat([dec2, conv1], 1)) 103 | x_out = self.final(dec1) 104 | return x_out 105 | -------------------------------------------------------------------------------- /pip-requirements.txt: -------------------------------------------------------------------------------- 1 | albumentations==0.4.3 2 | geopandas>=0.7.0 3 | ipykernel 4 | ipywidgets 5 | numba 6 | opencv-python==4.2.0.34 7 | osmnx 8 | p_tqdm 9 | statsmodels 10 | tensorflow 11 | torch>=1.3.1 12 | torchsummary 13 | torchvision 14 | utm 15 | 
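A minimal usage sketch for networks/vgg16_unet.py (not a file in this repository), assuming the repository root is on the Python path; the 5-channel input (RGB plus LIDAR elevation and intensity) and the 256x256 tile size are illustrative values taken from the docstring above, not fixed requirements. Note that get_modified_vgg16_unet returns a class rather than an instance, so it must be instantiated before use:

# Illustrative sketch only; paths and tensor shapes are example assumptions.
import torch
from networks.vgg16_unet import get_modified_vgg16_unet

ModelClass = get_modified_vgg16_unet(in_channels=5)  # returns a class, not an instance
model = ModelClass()                                 # VGG16-UNet whose first conv accepts 5 channels
dummy_batch = torch.randn(1, 5, 256, 256)            # one hypothetical RGB+ELEV+INTEN tile
with torch.no_grad():
    logits = model(dummy_batch)                      # raw logits; apply torch.sigmoid() for probabilities
print(logits.shape)                                  # torch.Size([1, 1, 256, 256])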
-------------------------------------------------------------------------------- /setup-env.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 4 | # SPDX-License-Identifier: MIT-0 5 | 6 | # This script sets up a conda environment that contains the dependencies of 7 | # the 'solaris' module. 8 | 9 | # Exit when an error occurs 10 | set -e 11 | 12 | # Create conda environment 13 | export ENV_NAME=$@ 14 | conda create -n $ENV_NAME -y --channel conda-forge 15 | 16 | # Activate the environment in Bash shell 17 | . /home/ec2-user/anaconda3/etc/profile.d/conda.sh 18 | conda activate $ENV_NAME 19 | 20 | # Install dependencies 21 | conda install --file conda-requirements.txt -y --channel conda-forge 22 | pip install -r pip-requirements.txt 23 | 24 | --------------------------------------------------------------------------------
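A usage sketch for setup-env.sh (the environment name below is an arbitrary example): run it from the repository root so that conda-requirements.txt and pip-requirements.txt resolve, passing the name of the conda environment to create, e.g. 'bash setup-env.sh solaris-env', and then activate that environment with 'conda activate solaris-env' before using it, for example as the kernel for the tutorial notebooks. The conda.sh path sourced by the script (/home/ec2-user/anaconda3/etc/profile.d/conda.sh) assumes an Amazon SageMaker notebook instance; on other machines it should point at the local Anaconda or Miniconda installation.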