├── .gitignore ├── LICENSE ├── README.md ├── data ├── images │ ├── patch0.npy │ ├── patch0.tfrecords │ ├── patch1.npy │ ├── patch1.tfrecords │ ├── patch10.npy │ ├── patch10.tfrecords │ ├── patch11.npy │ ├── patch11.tfrecords │ ├── patch12.npy │ ├── patch12.tfrecords │ ├── patch13.npy │ ├── patch13.tfrecords │ ├── patch14.npy │ ├── patch14.tfrecords │ ├── patch15.npy │ ├── patch15.tfrecords │ ├── patch16.npy │ ├── patch16.tfrecords │ ├── patch17.npy │ ├── patch17.tfrecords │ ├── patch18.npy │ ├── patch18.tfrecords │ ├── patch19.npy │ ├── patch19.tfrecords │ ├── patch2.npy │ ├── patch2.tfrecords │ ├── patch3.npy │ ├── patch3.tfrecords │ ├── patch4.npy │ ├── patch4.tfrecords │ ├── patch5.npy │ ├── patch5.tfrecords │ ├── patch6.npy │ ├── patch6.tfrecords │ ├── patch7.npy │ ├── patch7.tfrecords │ ├── patch8.npy │ ├── patch8.tfrecords │ ├── patch9.npy │ └── patch9.tfrecords ├── masks │ ├── mask_patch0.npy │ ├── mask_patch1.npy │ ├── mask_patch10.npy │ ├── mask_patch11.npy │ ├── mask_patch12.npy │ ├── mask_patch13.npy │ ├── mask_patch14.npy │ ├── mask_patch15.npy │ ├── mask_patch16.npy │ ├── mask_patch17.npy │ ├── mask_patch18.npy │ ├── mask_patch19.npy │ ├── mask_patch2.npy │ ├── mask_patch3.npy │ ├── mask_patch4.npy │ ├── mask_patch5.npy │ ├── mask_patch6.npy │ ├── mask_patch7.npy │ ├── mask_patch8.npy │ └── mask_patch9.npy └── splits │ ├── test_files_tfrecords │ ├── train_files_numpy │ ├── train_files_tfrecords │ ├── val_files_numpy │ └── val_files_tfrecords ├── graphical_abstract.png ├── image_labels ├── experiments │ └── params.json └── src │ ├── UNet.py │ ├── run_UCAM.py │ ├── tfrecords.py │ ├── train_UNet.py │ └── utils.py └── single_pixel_labels ├── experiments └── params.json └── src ├── datasets.py ├── run_masked.py ├── train_masked.py └── unet.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 
89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .env 106 | .venv 107 | env/ 108 | venv/ 109 | ENV/ 110 | env.bak/ 111 | venv.bak/ 112 | 113 | # Spyder project settings 114 | .spyderproject 115 | .spyproject 116 | 117 | # Rope project settings 118 | .ropeproject 119 | 120 | # mkdocs documentation 121 | /site 122 | 123 | # mypy 124 | .mypy_cache/ 125 | .dmypy.json 126 | dmypy.json 127 | 128 | # Pyre type checker 129 | .pyre/ 130 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2020 David Lobell's Lab Repo 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Weakly Supervised Deep Learning for Segmentation of Remote Sensing Imagery 2 | 3 | This repo contains the implementation of weakly supervised segmentation for remote sensing imagery from [our Remote Sensing paper](https://www.mdpi.com/2072-4292/12/2/207/htm). We perform cropland segmentation using two types of labels commonly found in remote sensing datasets that can be considered sources of “weak supervision”: (1) labels comprised of single geotagged points and (2) image-level labels. 4 | 5 |
<img src="graphical_abstract.png" alt="Graphical abstract of the weakly supervised segmentation approach">
6 | 7 |

8 | 9 | ## Usage 10 | 11 | To train a U-Net on single-pixel labels, run 12 | ```bash 13 | cd single_pixel_labels 14 | python src/run_masked.py --model_dir ./experiments 15 | ``` 16 | 17 | To train a U-Net on image-level labels, run 18 | ```bash 19 | cd image_labels 20 | python src/run_UCAM.py --model_dir ./experiments 21 | ``` 22 | 23 | Note that the code for the single-pixel labels was written in PyTorch, while the code for the image-level labels was written in TensorFlow 1.x. 24 | 25 | 26 | ## Data and Experimental Setup 27 | 28 | The models, datasets, and data loaders are currently written to process Landsat tiles in either `.tfrecord` (for the image-level labels) or `.npy` (for the single-pixel labels) format. The last layer of the `.tfrecord` or `.npy` files is assumed to be the segmentation ground truth label (see the sketch at the end of this README). 29 | 30 | In each `experiments` directory is a `params.json` file that defines the data directory, training/val/test split file names, model architecture, number of channels in the input data, data augmentation procedures to use, and which model outputs to save. The `--model_dir` flag must point to a directory with a valid `params.json` file in order for training to proceed. 31 | 32 | ## Citation 33 | 34 | When using the code from this repo, please cite: 35 | * S. Wang, W. Chen, S. M. Xie, G. Azzari, and D. B. Lobell, “Weakly Supervised Deep Learning for Segmentation of Remote Sensing Imagery,” Remote Sensing, vol. 12, no. 2, p. 207, Jan. 2020, doi: 10.3390/rs12020207. 36 | 37 | Please feel free to email sherwang [at] stanford [dot] edu with any questions about the code or suggestions for improvement. 38 |
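As a quick orientation, the snippet below sketches how a tile and its mask fit together. The `data/` paths and the `(height, width, channels)` array layout with the label as the last channel are assumptions based on the description above, so verify them against your own tiles.

```python
import numpy as np

# Illustrative only: the paths follow this repo's data/ layout, and the
# channel ordering is assumed from the README description, not guaranteed.
tile = np.load("data/images/patch0.npy")      # spectral bands + ground truth
image = tile[:, :, :-1]                       # input channels for the model
seg_label = tile[:, :, -1]                    # last layer: segmentation ground truth
mask = np.load("data/masks/mask_patch0.npy")  # supervision mask used by the single-pixel code

print(image.shape, seg_label.shape, mask.shape)
```
-------------------------------------------------------------------------------- /data/images/patch0.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LobellLab/weakly_supervised/d8363b50e2daa2b741985b22597529b4c489c39c/data/images/patch0.npy -------------------------------------------------------------------------------- /data/images/patch0.tfrecords: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LobellLab/weakly_supervised/d8363b50e2daa2b741985b22597529b4c489c39c/data/images/patch0.tfrecords -------------------------------------------------------------------------------- /data/images/patch1.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LobellLab/weakly_supervised/d8363b50e2daa2b741985b22597529b4c489c39c/data/images/patch1.npy -------------------------------------------------------------------------------- /data/images/patch1.tfrecords: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LobellLab/weakly_supervised/d8363b50e2daa2b741985b22597529b4c489c39c/data/images/patch1.tfrecords -------------------------------------------------------------------------------- /data/images/patch10.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LobellLab/weakly_supervised/d8363b50e2daa2b741985b22597529b4c489c39c/data/images/patch10.npy -------------------------------------------------------------------------------- /data/images/patch10.tfrecords: --------------------------------------------------------------------------------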
https://raw.githubusercontent.com/LobellLab/weakly_supervised/d8363b50e2daa2b741985b22597529b4c489c39c/data/images/patch10.tfrecords -------------------------------------------------------------------------------- /data/images/patch11.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LobellLab/weakly_supervised/d8363b50e2daa2b741985b22597529b4c489c39c/data/images/patch11.npy -------------------------------------------------------------------------------- /data/images/patch11.tfrecords: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LobellLab/weakly_supervised/d8363b50e2daa2b741985b22597529b4c489c39c/data/images/patch11.tfrecords -------------------------------------------------------------------------------- /data/images/patch12.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LobellLab/weakly_supervised/d8363b50e2daa2b741985b22597529b4c489c39c/data/images/patch12.npy -------------------------------------------------------------------------------- /data/images/patch12.tfrecords: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LobellLab/weakly_supervised/d8363b50e2daa2b741985b22597529b4c489c39c/data/images/patch12.tfrecords -------------------------------------------------------------------------------- /data/images/patch13.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LobellLab/weakly_supervised/d8363b50e2daa2b741985b22597529b4c489c39c/data/images/patch13.npy -------------------------------------------------------------------------------- /data/images/patch13.tfrecords: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LobellLab/weakly_supervised/d8363b50e2daa2b741985b22597529b4c489c39c/data/images/patch13.tfrecords -------------------------------------------------------------------------------- /data/images/patch14.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LobellLab/weakly_supervised/d8363b50e2daa2b741985b22597529b4c489c39c/data/images/patch14.npy -------------------------------------------------------------------------------- /data/images/patch14.tfrecords: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LobellLab/weakly_supervised/d8363b50e2daa2b741985b22597529b4c489c39c/data/images/patch14.tfrecords -------------------------------------------------------------------------------- /data/images/patch15.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LobellLab/weakly_supervised/d8363b50e2daa2b741985b22597529b4c489c39c/data/images/patch15.npy -------------------------------------------------------------------------------- /data/images/patch15.tfrecords: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LobellLab/weakly_supervised/d8363b50e2daa2b741985b22597529b4c489c39c/data/images/patch15.tfrecords -------------------------------------------------------------------------------- /data/images/patch16.npy: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/LobellLab/weakly_supervised/d8363b50e2daa2b741985b22597529b4c489c39c/data/images/patch16.npy -------------------------------------------------------------------------------- /data/images/patch16.tfrecords: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LobellLab/weakly_supervised/d8363b50e2daa2b741985b22597529b4c489c39c/data/images/patch16.tfrecords -------------------------------------------------------------------------------- /data/images/patch17.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LobellLab/weakly_supervised/d8363b50e2daa2b741985b22597529b4c489c39c/data/images/patch17.npy -------------------------------------------------------------------------------- /data/images/patch17.tfrecords: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LobellLab/weakly_supervised/d8363b50e2daa2b741985b22597529b4c489c39c/data/images/patch17.tfrecords -------------------------------------------------------------------------------- /data/images/patch18.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LobellLab/weakly_supervised/d8363b50e2daa2b741985b22597529b4c489c39c/data/images/patch18.npy -------------------------------------------------------------------------------- /data/images/patch18.tfrecords: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LobellLab/weakly_supervised/d8363b50e2daa2b741985b22597529b4c489c39c/data/images/patch18.tfrecords -------------------------------------------------------------------------------- /data/images/patch19.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LobellLab/weakly_supervised/d8363b50e2daa2b741985b22597529b4c489c39c/data/images/patch19.npy -------------------------------------------------------------------------------- /data/images/patch19.tfrecords: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LobellLab/weakly_supervised/d8363b50e2daa2b741985b22597529b4c489c39c/data/images/patch19.tfrecords -------------------------------------------------------------------------------- /data/images/patch2.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LobellLab/weakly_supervised/d8363b50e2daa2b741985b22597529b4c489c39c/data/images/patch2.npy -------------------------------------------------------------------------------- /data/images/patch2.tfrecords: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LobellLab/weakly_supervised/d8363b50e2daa2b741985b22597529b4c489c39c/data/images/patch2.tfrecords -------------------------------------------------------------------------------- /data/images/patch3.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LobellLab/weakly_supervised/d8363b50e2daa2b741985b22597529b4c489c39c/data/images/patch3.npy -------------------------------------------------------------------------------- /data/images/patch3.tfrecords: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/LobellLab/weakly_supervised/d8363b50e2daa2b741985b22597529b4c489c39c/data/images/patch3.tfrecords -------------------------------------------------------------------------------- /data/images/patch4.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LobellLab/weakly_supervised/d8363b50e2daa2b741985b22597529b4c489c39c/data/images/patch4.npy -------------------------------------------------------------------------------- /data/images/patch4.tfrecords: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LobellLab/weakly_supervised/d8363b50e2daa2b741985b22597529b4c489c39c/data/images/patch4.tfrecords -------------------------------------------------------------------------------- /data/images/patch5.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LobellLab/weakly_supervised/d8363b50e2daa2b741985b22597529b4c489c39c/data/images/patch5.npy -------------------------------------------------------------------------------- /data/images/patch5.tfrecords: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LobellLab/weakly_supervised/d8363b50e2daa2b741985b22597529b4c489c39c/data/images/patch5.tfrecords -------------------------------------------------------------------------------- /data/images/patch6.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LobellLab/weakly_supervised/d8363b50e2daa2b741985b22597529b4c489c39c/data/images/patch6.npy -------------------------------------------------------------------------------- /data/images/patch6.tfrecords: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LobellLab/weakly_supervised/d8363b50e2daa2b741985b22597529b4c489c39c/data/images/patch6.tfrecords -------------------------------------------------------------------------------- /data/images/patch7.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LobellLab/weakly_supervised/d8363b50e2daa2b741985b22597529b4c489c39c/data/images/patch7.npy -------------------------------------------------------------------------------- /data/images/patch7.tfrecords: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LobellLab/weakly_supervised/d8363b50e2daa2b741985b22597529b4c489c39c/data/images/patch7.tfrecords -------------------------------------------------------------------------------- /data/images/patch8.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LobellLab/weakly_supervised/d8363b50e2daa2b741985b22597529b4c489c39c/data/images/patch8.npy -------------------------------------------------------------------------------- /data/images/patch8.tfrecords: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LobellLab/weakly_supervised/d8363b50e2daa2b741985b22597529b4c489c39c/data/images/patch8.tfrecords -------------------------------------------------------------------------------- /data/images/patch9.npy: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/LobellLab/weakly_supervised/d8363b50e2daa2b741985b22597529b4c489c39c/data/images/patch9.npy -------------------------------------------------------------------------------- /data/images/patch9.tfrecords: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LobellLab/weakly_supervised/d8363b50e2daa2b741985b22597529b4c489c39c/data/images/patch9.tfrecords -------------------------------------------------------------------------------- /data/masks/mask_patch0.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LobellLab/weakly_supervised/d8363b50e2daa2b741985b22597529b4c489c39c/data/masks/mask_patch0.npy -------------------------------------------------------------------------------- /data/masks/mask_patch1.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LobellLab/weakly_supervised/d8363b50e2daa2b741985b22597529b4c489c39c/data/masks/mask_patch1.npy -------------------------------------------------------------------------------- /data/masks/mask_patch10.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LobellLab/weakly_supervised/d8363b50e2daa2b741985b22597529b4c489c39c/data/masks/mask_patch10.npy -------------------------------------------------------------------------------- /data/masks/mask_patch11.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LobellLab/weakly_supervised/d8363b50e2daa2b741985b22597529b4c489c39c/data/masks/mask_patch11.npy -------------------------------------------------------------------------------- /data/masks/mask_patch12.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LobellLab/weakly_supervised/d8363b50e2daa2b741985b22597529b4c489c39c/data/masks/mask_patch12.npy -------------------------------------------------------------------------------- /data/masks/mask_patch13.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LobellLab/weakly_supervised/d8363b50e2daa2b741985b22597529b4c489c39c/data/masks/mask_patch13.npy -------------------------------------------------------------------------------- /data/masks/mask_patch14.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LobellLab/weakly_supervised/d8363b50e2daa2b741985b22597529b4c489c39c/data/masks/mask_patch14.npy -------------------------------------------------------------------------------- /data/masks/mask_patch15.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LobellLab/weakly_supervised/d8363b50e2daa2b741985b22597529b4c489c39c/data/masks/mask_patch15.npy -------------------------------------------------------------------------------- /data/masks/mask_patch16.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LobellLab/weakly_supervised/d8363b50e2daa2b741985b22597529b4c489c39c/data/masks/mask_patch16.npy -------------------------------------------------------------------------------- /data/masks/mask_patch17.npy: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/LobellLab/weakly_supervised/d8363b50e2daa2b741985b22597529b4c489c39c/data/masks/mask_patch17.npy -------------------------------------------------------------------------------- /data/masks/mask_patch18.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LobellLab/weakly_supervised/d8363b50e2daa2b741985b22597529b4c489c39c/data/masks/mask_patch18.npy -------------------------------------------------------------------------------- /data/masks/mask_patch19.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LobellLab/weakly_supervised/d8363b50e2daa2b741985b22597529b4c489c39c/data/masks/mask_patch19.npy -------------------------------------------------------------------------------- /data/masks/mask_patch2.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LobellLab/weakly_supervised/d8363b50e2daa2b741985b22597529b4c489c39c/data/masks/mask_patch2.npy -------------------------------------------------------------------------------- /data/masks/mask_patch3.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LobellLab/weakly_supervised/d8363b50e2daa2b741985b22597529b4c489c39c/data/masks/mask_patch3.npy -------------------------------------------------------------------------------- /data/masks/mask_patch4.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LobellLab/weakly_supervised/d8363b50e2daa2b741985b22597529b4c489c39c/data/masks/mask_patch4.npy -------------------------------------------------------------------------------- /data/masks/mask_patch5.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LobellLab/weakly_supervised/d8363b50e2daa2b741985b22597529b4c489c39c/data/masks/mask_patch5.npy -------------------------------------------------------------------------------- /data/masks/mask_patch6.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LobellLab/weakly_supervised/d8363b50e2daa2b741985b22597529b4c489c39c/data/masks/mask_patch6.npy -------------------------------------------------------------------------------- /data/masks/mask_patch7.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LobellLab/weakly_supervised/d8363b50e2daa2b741985b22597529b4c489c39c/data/masks/mask_patch7.npy -------------------------------------------------------------------------------- /data/masks/mask_patch8.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LobellLab/weakly_supervised/d8363b50e2daa2b741985b22597529b4c489c39c/data/masks/mask_patch8.npy -------------------------------------------------------------------------------- /data/masks/mask_patch9.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LobellLab/weakly_supervised/d8363b50e2daa2b741985b22597529b4c489c39c/data/masks/mask_patch9.npy -------------------------------------------------------------------------------- /data/splits/test_files_tfrecords: -------------------------------------------------------------------------------- 1 | 
patch15.tfrecords 2 | patch16.tfrecords 3 | patch17.tfrecords 4 | patch18.tfrecords 5 | patch19.tfrecords 6 | -------------------------------------------------------------------------------- /data/splits/train_files_numpy: -------------------------------------------------------------------------------- 1 | patch0.npy 2 | patch1.npy 3 | patch2.npy 4 | patch3.npy 5 | patch4.npy 6 | patch5.npy 7 | patch6.npy 8 | patch7.npy 9 | patch8.npy 10 | patch9.npy 11 | -------------------------------------------------------------------------------- /data/splits/train_files_tfrecords: -------------------------------------------------------------------------------- 1 | patch0.tfrecords 2 | patch1.tfrecords 3 | patch2.tfrecords 4 | patch3.tfrecords 5 | patch4.tfrecords 6 | patch5.tfrecords 7 | patch6.tfrecords 8 | patch7.tfrecords 9 | patch8.tfrecords 10 | patch9.tfrecords 11 | -------------------------------------------------------------------------------- /data/splits/val_files_numpy: -------------------------------------------------------------------------------- 1 | patch10.npy 2 | patch11.npy 3 | patch12.npy 4 | patch13.npy 5 | patch14.npy 6 | patch15.npy 7 | patch16.npy 8 | patch17.npy 9 | patch18.npy 10 | patch19.npy 11 | -------------------------------------------------------------------------------- /data/splits/val_files_tfrecords: -------------------------------------------------------------------------------- 1 | patch10.tfrecords 2 | patch11.tfrecords 3 | patch12.tfrecords 4 | patch13.tfrecords 5 | patch14.tfrecords 6 | -------------------------------------------------------------------------------- /graphical_abstract.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LobellLab/weakly_supervised/d8363b50e2daa2b741985b22597529b4c489c39c/graphical_abstract.png -------------------------------------------------------------------------------- /image_labels/experiments/params.json: -------------------------------------------------------------------------------- 1 | { 2 | "data_folder": "../data/images", 3 | "train_file": "splits/train_files_tfrecords", 4 | "eval_file": "splits/val_files_tfrecords", 5 | "test_file": "splits/test_files_tfrecords", 6 | "data_type": "landsat", 7 | "task_type": "classification", 8 | "model": "UNet", 9 | "unet_type": "gap_sigmoid", 10 | 11 | "num_layers": 4, 12 | "starting_filters": 32, 13 | "img_size": 50, 14 | "CAM_size": 48, 15 | "has_class_labels": true, 16 | "has_seg_labels": true, 17 | "class_1_segmentation_proportion": 0.424, 18 | "num_channels": 7, 19 | "use_random_flip_and_rotate": true, 20 | "num_labels": 1, 21 | "fixed_split": true, 22 | "get_label_percent": false, 23 | "get_test_metrics": false, 24 | 25 | "threshold_sample_steps": 1, 26 | "save_summary_steps": 100, 27 | "save_cam": false, 28 | "save_CAM_train": false, 29 | "save_CAM_eval": false, 30 | "get_saliency": false, 31 | "save_best_weights": false, 32 | "save_last_weights": false, 33 | "fix_unet_weights": false, 34 | "restore_from_different_model": false, 35 | 36 | "learning_rate": 1e-3, 37 | "use_batch_norm": true, 38 | "batch_size": 32, 39 | "num_epochs": 200, 40 | "bn_momentum": 0.9, 41 | "l2_lambda": 0.01, 42 | "num_parallel_calls": 4 43 | } 44 | -------------------------------------------------------------------------------- /image_labels/src/UNet.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | 3 |
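# This module defines the U-Net used with image-level labels: conv_conv_pool
# builds an encoder block (two 3x3 convolutions, each with batch norm and ReLU,
# plus an optional 2x2 max pool), upconv_concat builds a decoder block
# (transposed convolution followed by a cropped skip concatenation), UNet
# assembles the backbone, and build_UNet attaches the CAM head, losses, and metrics.
4 | def conv_conv_pool(X, f,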
stage, params, is_training=False, pool=True): 5 | 6 | conv_name_base = 'down' + str(stage) 7 | bn_name_base = 'bn' + str(stage) 8 | 9 | bn_momentum = params.bn_momentum 10 | l2_lambda = params.l2_lambda 11 | 12 | F1, F2 = f 13 | 14 | X = tf.layers.conv2d( 15 | inputs=X, 16 | filters=F1, 17 | kernel_size=(3,3), 18 | strides=(1,1), 19 | padding='same', 20 | kernel_initializer=tf.contrib.layers.xavier_initializer(), 21 | kernel_regularizer=tf.contrib.layers.l2_regularizer(l2_lambda), 22 | name=conv_name_base + 'a') 23 | X = tf.layers.batch_normalization( 24 | inputs=X, 25 | axis=3, 26 | momentum=bn_momentum, 27 | training=is_training, 28 | name=bn_name_base + 'a') 29 | X = tf.nn.relu(X) 30 | 31 | X = tf.layers.conv2d( 32 | inputs=X, 33 | filters=F2, 34 | kernel_size=(3,3), 35 | strides=(1,1), 36 | padding='same', 37 | kernel_initializer=tf.contrib.layers.xavier_initializer(), 38 | kernel_regularizer=tf.contrib.layers.l2_regularizer(l2_lambda), 39 | name=conv_name_base + 'b') 40 | X = tf.layers.batch_normalization( 41 | inputs=X, 42 | axis=3, 43 | momentum=bn_momentum, 44 | training=is_training, 45 | name=bn_name_base + 'b') 46 | X = tf.nn.relu(X) 47 | 48 | if not pool: 49 | return X 50 | 51 | pool = tf.layers.max_pooling2d( 52 | inputs=X, 53 | pool_size=(2,2), 54 | strides=(2,2), 55 | padding='valid') 56 | 57 | return X, pool 58 | 59 | 60 | def upconv_concat(X, X_prev, f, stage, params): 61 | 62 | conv_name_base = 'up' + str(stage) 63 | l2_lambda = params.l2_lambda 64 | 65 | upconv = tf.layers.conv2d_transpose( 66 | inputs=X, 67 | filters=f, 68 | kernel_size=2, 69 | strides=2, 70 | padding='valid', 71 | kernel_initializer=tf.contrib.layers.xavier_initializer(), 72 | kernel_regularizer=tf.contrib.layers.l2_regularizer(l2_lambda), 73 | name=conv_name_base) 74 | 75 | prev_shape = tf.shape(X_prev) 76 | curr_shape = tf.shape(upconv) 77 | offsets = [0, (prev_shape[1] - curr_shape[1]) // 2, (prev_shape[2] - curr_shape[2]) // 2, 0] 78 | new_shape = [-1, curr_shape[1], curr_shape[2], -1] 79 | X_prev_cropped = tf.reshape(tf.slice(X_prev, offsets, new_shape), curr_shape) 80 | 81 | return tf.concat([upconv, X_prev_cropped], axis=-1) 82 | 83 | 84 | def UNet(is_training, inputs, params): 85 | 86 | images = inputs['images'] 87 | num_layers = params.num_layers 88 | starting_filters = params.starting_filters 89 | 90 | # U-Net convolutional layers 91 | convs = [] 92 | conv, pool = conv_conv_pool(images, [starting_filters, starting_filters], 1, params, is_training) 93 | convs.append(conv) 94 | 95 | # add more convolutional layers in the loop 96 | for l in range(2, num_layers): 97 | num_filters = starting_filters * 2**(l-1) 98 | conv, pool = conv_conv_pool(pool, [num_filters, num_filters], l, params, is_training) 99 | convs.append(conv) 100 | 101 | if num_layers == 2: # the loop above did not run, so set num_filters explicitly 102 | num_filters = starting_filters 103 | 104 | # last convolutional layer (no pool) 105 | conv = conv_conv_pool(pool, [num_filters * 2, num_filters * 2], num_layers, params, is_training, pool=False) 106 | 107 | # convolution transpose layers (deconvolutional layers) 108 | for l in range(1, num_layers): 109 | num_filters = starting_filters * 2**(num_layers - l - 1) 110 | up = upconv_concat(conv, convs.pop(), num_filters, l, params) 111 | conv = conv_conv_pool(up, [num_filters, num_filters], l+num_layers, params, is_training, pool=False) 112 | 113 | CAM_input = tf.identity(conv) 114 | 115 | return CAM_input 116 | 117 | 118 | 119 | def build_UNet(mode, inputs, params, reuse=False): 120 | """ 121 | Function defines the graph operations.
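The image-level prediction comes either from a per-pixel sigmoid CAM that is averaged over the image (unet_type 'sigmoid_gap') or from global average pooling followed by a dense layer and a sigmoid (unet_type 'gap_sigmoid').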
122 | Input: 123 | mode: (string) can be 'train' or 'eval' 124 | inputs: (dict) contains the inputs of the graph (features, labels, etc.) 125 | params: (params) contains hyperparameters of the model 126 | reuse: (bool) whether to reuse the weights 127 | Returns: 128 | model_spec: (dict) contains the graph operations for training/evaluation 129 | """ 130 | is_training = (mode == 'train') 131 | labels = inputs.get('labels') 132 | seg_labels = inputs.get('seg_labels') 133 | 134 | with tf.variable_scope('model', reuse=reuse): 135 | CAM_input = UNet(is_training, inputs, params) 136 | 137 | if params.unet_type == 'sigmoid_gap': 138 | cam_unnorm = tf.layers.conv2d( 139 | inputs=CAM_input, 140 | filters=1, 141 | kernel_size=(1,1), 142 | strides=(1,1), 143 | padding='valid', 144 | kernel_initializer=tf.contrib.layers.xavier_initializer(), 145 | kernel_regularizer=tf.contrib.layers.l2_regularizer(params.l2_lambda), 146 | name='final_conv') 147 | 148 | cam_unnorm = tf.squeeze(cam_unnorm, 3) # get rid of last dim 149 | cam = tf.sigmoid(cam_unnorm) 150 | cam_predictions = tf.round(cam) 151 | 152 | elif params.unet_type == 'gap_sigmoid': 153 | # out = tf.sigmoid(CAM_input) 154 | cam = tf.reduce_mean(CAM_input, axis=[1,2]) 155 | cam = tf.contrib.layers.flatten(cam) 156 | logits = tf.layers.dense(inputs=cam, units=params.num_labels, name='dense') 157 | 158 | 159 | if params.task_type == 'classification': 160 | 161 | if params.unet_type == 'sigmoid_gap': 162 | expits = tf.reduce_mean(cam, axis=[1, 2]) 163 | expits = tf.expand_dims(expits, -1) 164 | 165 | elif params.unet_type == 'gap_sigmoid': 166 | expits = tf.sigmoid(logits) 167 | 168 | predictions = tf.round(expits) 169 | 170 | else: 171 | predictions = cam_predictions 172 | 173 | # possibly classification task also has segmentation labels 174 | if seg_labels is not None and params.unet_type == 'sigmoid_gap': 175 | # sizes might not match, cut the segmentation label down to size 176 | extra_h = tf.shape(seg_labels)[1] - tf.shape(cam)[1] 177 | extra_h_before = extra_h // 2 178 | extra_h_after = tf.shape(seg_labels)[1] - (extra_h - extra_h_before) 179 | extra_w = tf.shape(seg_labels)[2] - tf.shape(cam)[2] 180 | extra_w_before = extra_w // 2 181 | extra_w_after = tf.shape(seg_labels)[2] - (extra_w - extra_w_before) 182 | seg_labels_cut = seg_labels[:, extra_h_before:extra_h_after, extra_w_before:extra_w_after] 183 | # cam or cam unnorm 184 | ce = tf.nn.sigmoid_cross_entropy_with_logits(labels=seg_labels_cut, logits=cam_unnorm) 185 | weighted_ce = (seg_labels_cut * ce) / params.class_1_segmentation_proportion + (1-seg_labels_cut) * ce 186 | segmentation_loss = tf.reduce_mean(weighted_ce) 187 | 188 | # Define loss and task accuracy 189 | if params.task_type == 'classification': 190 | train_labels = labels 191 | ce = -train_labels * tf.log(expits + 1.0e-8) - (1 - train_labels) * tf.log(1 - expits + 1.0e-8) 192 | weighted_ce = (train_labels * ce) / params.class_1_segmentation_proportion + \ 193 | (1 - train_labels) * ce 194 | loss = tf.reduce_mean(weighted_ce) 195 | logits = -tf.log((1.0 / expits) - 1 + 1e-8) 196 | else: 197 | train_labels = seg_labels_cut 198 | loss = segmentation_loss 199 | 200 | l2_loss = tf.losses.get_regularization_loss() 201 | loss += l2_loss 202 | accuracy = tf.reduce_mean(tf.cast(tf.equal(tf.round(train_labels), predictions), tf.float32)) 203 | 204 | # Define training step that minimizes loss with Adam optimizer 205 | if is_training: 206 | optimizer = tf.train.AdamOptimizer(params.learning_rate) 207 | global_step = 
tf.train.get_or_create_global_step() 208 | if params.use_batch_norm: 209 | # Add a dependency to update the moving mean and variance for batch normalization 210 | update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS) 211 | with tf.control_dependencies(update_ops): 212 | train_op = optimizer.minimize(loss, global_step=global_step) 213 | else: 214 | train_op = optimizer.minimize(loss, global_step=global_step) 215 | 216 | # METRICS AND SUMMARIES 217 | with tf.variable_scope('metrics'): 218 | metrics = { 219 | 'loss': tf.metrics.mean(loss), 220 | 'accuracy': tf.metrics.accuracy(labels=tf.round(train_labels), predictions=predictions) 221 | } 222 | if params.task_type == 'classification' and seg_labels is not None and params.unet_type == 'sigmoid_gap': 223 | metrics.update({'segmentation_loss': tf.metrics.mean(segmentation_loss), 224 | 'segmentation_acc': tf.metrics.accuracy(labels=seg_labels_cut, predictions=cam_predictions) 225 | }) 226 | # Group the update ops for the tf.metrics 227 | # print(metrics.values()) 228 | update_metrics_op = tf.group(*[op for _, op in metrics.values()]) 229 | metric_variables = tf.get_collection(tf.GraphKeys.LOCAL_VARIABLES, scope='metrics') 230 | metrics_init_op = tf.variables_initializer(metric_variables) 231 | 232 | summaries = [] 233 | summaries.append(tf.summary.scalar('loss', loss)) 234 | summaries.append(tf.summary.scalar('accuracy', accuracy)) 235 | summaries.append(tf.summary.scalar('loss_MA', metrics['loss'][0])) 236 | summaries.append(tf.summary.scalar('accuracy_MA', metrics['accuracy'][0])) 237 | if params.task_type == 'classification' and seg_labels is not None and params.unet_type == 'sigmoid_gap': 238 | summaries.append(tf.summary.scalar('segmentation_loss', segmentation_loss)) 239 | summaries.append(tf.summary.scalar('segmentation_loss_MA', metrics['segmentation_loss'][0])) 240 | summaries.append(tf.summary.scalar('segmentation_acc_MA', metrics['segmentation_acc'][0])) 241 | 242 | if params.dict.get('rgb_image') in {None, False}: 243 | image_rgb = 255.0 * 5 * inputs['images'][:,:,:,1:4][:,:,:,::-1] 244 | summaries.append(tf.summary.image('image', image_rgb)) 245 | else: 246 | image_rgb = inputs['images'] 247 | 248 | if params.unet_type == 'sigmoid_gap': 249 | paddings_h = tf.shape(image_rgb)[1] - tf.shape(cam)[1] 250 | paddings_h_before = paddings_h // 2 251 | paddings_h_after = paddings_h - paddings_h_before 252 | paddings_w = tf.shape(image_rgb)[2] - tf.shape(cam)[2] 253 | paddings_w_before = paddings_w // 2 254 | paddings_w_after = paddings_w - paddings_w_before 255 | paddings = tf.convert_to_tensor([[0, 0], [paddings_h_before, paddings_h_after], [paddings_w_before, paddings_w_after]]) 256 | cam_padded = tf.pad(cam, paddings, 'CONSTANT', constant_values=0.5) 257 | cam_rgb = 255.0 * tf.tile(tf.expand_dims(cam_padded,-1), [1,1,1,3]) 258 | 259 | if seg_labels is not None: 260 | # tile the segmentation, label and image 261 | label_rgb = 255.0 * tf.tile(tf.expand_dims(seg_labels,-1), [1,1,1,3]) 262 | concated = tf.concat([image_rgb, label_rgb], axis=2) 263 | else: 264 | concated = image_rgb 265 | 266 | hard_cam_padded = tf.pad(cam_predictions, paddings, 'CONSTANT', constant_values=0.5) 267 | hard_cam_rgb = 255.0 * tf.tile(tf.expand_dims(hard_cam_padded, -1), [1, 1, 1, 3]) 268 | concated = tf.concat([concated, cam_rgb, hard_cam_rgb], axis=2) 269 | summaries.append(tf.summary.image('concatenated_images', concated, 50)) 270 | 271 | model_spec = inputs 272 | model_spec['variable_init_op'] = tf.global_variables_initializer() 273 | 
model_spec['predictions'] = predictions 274 | model_spec['loss'] = loss 275 | model_spec['accuracy'] = accuracy 276 | model_spec['cam'] = cam 277 | model_spec['metrics_init_op'] = metrics_init_op 278 | model_spec['metrics'] = metrics 279 | model_spec['update_metrics'] = update_metrics_op 280 | model_spec['summary_op'] = tf.summary.merge(summaries) 281 | 282 | if params.unet_type == 'gap_sigmoid': 283 | model_spec['CAM_input'] = CAM_input 284 | model_spec['dense_kernel'] = [v for v in tf.trainable_variables() if "dense/kernel" in v.name][0] 285 | model_spec['dense_bias'] = [v for v in tf.trainable_variables() if "dense/bias" in v.name][0] 286 | 287 | if is_training: 288 | model_spec['train_op'] = train_op 289 | 290 | return model_spec 291 | -------------------------------------------------------------------------------- /image_labels/src/run_UCAM.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import logging 3 | import os 4 | import datetime 5 | import uuid 6 | from pathlib import Path 7 | 8 | from tfrecords import tfrecord_iterator 9 | from UNet import build_UNet 10 | from utils import Params, set_logger 11 | from train_UNet import train_and_evaluate 12 | 13 | parser = argparse.ArgumentParser() 14 | parser.add_argument('--model_dir', default='../experiments', 15 | help="Experiment directory containing params.json") 16 | parser.add_argument('--data_dir', default='../data', 17 | help="Directory containing the dataset") 18 | parser.add_argument('--restore_from', default=None, 19 | help="Optional directory or file containing weights to reload before training") 20 | 21 | if __name__ == '__main__': 22 | # Uncomment below if want reproducible experiments 23 | # tf.set_random_seed(230) 24 | 25 | # Load parameters from json file 26 | args = parser.parse_args() 27 | json_path = os.path.join(args.model_dir, 'params.json') 28 | assert os.path.isfile(json_path), "No params.json configuration file found at {}".format(json_path) 29 | params = Params(json_path) 30 | 31 | # Set up logger 32 | ts_uuid = f'{datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S%z")}_{uuid.uuid4().hex[:6]}' 33 | log_dir = Path(args.model_dir) / 'logs' 34 | log_dir.mkdir(exist_ok=True) 35 | set_logger(str(log_dir / f'{ts_uuid}_train.log')) 36 | logging.info("Creating the datasets from TFRecords...") 37 | data_dir = args.data_dir 38 | data_folder = params.data_folder 39 | train_file = os.path.join(data_dir, params.train_file) 40 | val_file = os.path.join(data_dir, params.eval_file) 41 | test_file = os.path.join(data_dir, params.test_file) 42 | 43 | train_filenames = [] 44 | val_filenames = [] 45 | test_filenames = [] 46 | with open(train_file) as f: 47 | for l in f: 48 | train_filenames.append(os.path.join(data_dir, data_folder, l[:-1])) 49 | with open(val_file) as f: 50 | for l in f: 51 | val_filenames.append(os.path.join(data_dir, data_folder, l[:-1])) 52 | with open(test_file) as f: 53 | for l in f: 54 | test_filenames.append(os.path.join(data_dir, data_folder, l[:-1])) 55 | params.train_size = len(train_filenames) 56 | params.eval_size = len(val_filenames) 57 | params.test_size = len(test_filenames) 58 | 59 | train_dataset = tfrecord_iterator(True, train_filenames, params) 60 | eval_dataset = tfrecord_iterator(False, val_filenames, params) 61 | test_dataset = tfrecord_iterator(False, test_filenames, params) 62 | 63 | logging.info("Creating the model...") 64 | train_model = build_UNet('train', train_dataset, params) 65 | eval_model = build_UNet('eval', eval_dataset, 
params, reuse=True) 66 | test_model = build_UNet('eval', test_dataset, params, reuse=True) 67 | 68 | logging.info("Starting training for {} epochs".format(params.num_epochs)) 69 | train_and_evaluate(train_model, eval_model, test_model, args.model_dir, params, args.restore_from, ts_uuid) 70 | 71 | 72 | 73 | -------------------------------------------------------------------------------- /image_labels/src/tfrecords.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | import numpy as np 3 | 4 | 5 | def parse_function(example, num_channels, img_size, 6 | has_class_labels, has_seg_labels): 7 | """ 8 | Parses a serialized example into an image tensor plus its class label 9 | and/or segmentation label, depending on which label types are present. 10 | """ 11 | feature_dict = { 12 | 'image': tf.FixedLenFeature((img_size, img_size, num_channels), tf.float32), 13 | } 14 | if has_class_labels: 15 | feature_dict.update({'label': tf.FixedLenFeature((1,), tf.float32)}) 16 | if has_seg_labels: 17 | feature_dict.update( 18 | {'seg_labels': tf.FixedLenFeature((img_size, img_size), tf.float32)}) 19 | parsed_features = tf.parse_single_example(example, features=feature_dict) 20 | 21 | ret = [parsed_features['image']] 22 | if has_class_labels: 23 | ret.append(parsed_features['label']) 24 | else: 25 | # TF doesn't accept None, we filter this out later anyway 26 | ret.append(0) 27 | 28 | if has_seg_labels: 29 | ret.append(parsed_features['seg_labels']) 30 | else: 31 | ret.append(0) 32 | return tuple(ret) 33 | 34 | 35 | def augment_data(image, label, seg_label, perform_random_flip_and_rotate, 36 | num_channels, has_seg_labels): 37 | """ 38 | Image augmentation for training. Applies the following operations: 39 | - Horizontally flip the image with probability 0.5 40 | - Vertically flip the image with probability 0.5 41 | - Rotate by a random multiple of 90 degrees 42 | """ 43 | if perform_random_flip_and_rotate: 44 | if has_seg_labels: 45 | image = tf.concat([image, tf.expand_dims(seg_label, -1)], 2) # append seg label as a channel so flips/rotations stay aligned with the image 46 | image = tf.image.random_flip_left_right(image) 47 | image = tf.image.random_flip_up_down(image) 48 | rotate_angle = tf.random_shuffle([0.0, 90.0, 180.0, 270.0])[0] 49 | image = tf.contrib.image.rotate( 50 | image, rotate_angle * np.pi / 180.0, interpolation='BILINEAR') 51 | if has_seg_labels: 52 | seg_label = image[:, :, -1] 53 | 54 | image = image[:,:,:num_channels] 55 | 56 | return image, label, seg_label 57 | 58 | 59 | def tfrecord_iterator(is_training, file_names, params): 60 | 61 | # Create a Dataset serving batches of images and labels 62 | # We don't repeat for multiple epochs because we always train and evaluate for one epoch 63 | def parse_fn(p): 64 | return parse_function( 65 | p, params.num_channels, params.img_size, params.has_class_labels, 66 | params.has_seg_labels) 67 | 68 | def train_fn(f, l, sl): 69 | return augment_data( 70 | f, l, sl, params.use_random_flip_and_rotate, params.num_channels, 71 | params.has_seg_labels) 72 | 73 | if is_training: 74 | dataset = (tf.data.TFRecordDataset(file_names) 75 | .map(parse_fn, num_parallel_calls=params.num_parallel_calls) 76 | .map(train_fn, num_parallel_calls=params.num_parallel_calls) 77 | .batch(params.batch_size) 78 | .prefetch(1)) 79 | else: 80 | dataset = (tf.data.TFRecordDataset(file_names) 81 | .map(parse_fn) 82 | .batch(params.batch_size) 83 | .prefetch(1)) 84 | 85 | # Create reinitializable iterator from dataset 86 | iterator = dataset.make_initializable_iterator() 87 | images, labels, seg_labels = iterator.get_next() 88 |
iterator_init_op = iterator.initializer 89 | 90 | inputs = {'images': images, 91 | 'iterator_init_op': iterator_init_op} 92 | if params.has_class_labels: 93 | inputs.update({'labels': labels}) 94 | if params.has_seg_labels: 95 | inputs.update({'seg_labels': seg_labels}) 96 | 97 | return inputs 98 | -------------------------------------------------------------------------------- /image_labels/src/train_UNet.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import os 3 | from pathlib import Path 4 | 5 | import numpy as np 6 | import pandas as pd 7 | from tqdm import trange 8 | import tensorflow as tf 9 | 10 | from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score 11 | 12 | def save_cam(sess, cam, dense_kernel, dense_bias, predictions, model_spec, model_dir, metrics_val, epoch, num_steps): 13 | cam_dir = os.path.join(model_dir, 'eval_cam') 14 | if not os.path.exists(cam_dir): 15 | os.makedirs(cam_dir) 16 | best_acc = 0.0 17 | best_acc_dir = os.path.join(cam_dir, 'best_eval_acc.npy') 18 | if os.path.isfile(best_acc_dir): 19 | best_acc = np.load(best_acc_dir)[0] 20 | 21 | if metrics_val['accuracy'] > best_acc: 22 | print("Saving CAM for best epoch so far...") 23 | np.save(best_acc_dir, np.array([metrics_val['accuracy'], epoch])) 24 | sess.run(model_spec['iterator_init_op']) 25 | for i in range(num_steps): 26 | cam_val, dense_kernel_val, dense_bias_val, predictions_val = sess.run([cam, dense_kernel, dense_bias, predictions]) 27 | np.save(os.path.join(cam_dir, 'eval-cam-batch'+str(i)), cam_val) 28 | np.save(os.path.join(cam_dir, 'eval-dense-kernel-batch'+str(i)), dense_kernel_val) 29 | np.save(os.path.join(cam_dir, 'eval-dense-bias-batch'+str(i)), dense_bias_val) 30 | np.save(os.path.join(cam_dir, 'eval-prediction-batch'+str(i)), predictions_val) 31 | 32 | 33 | def get_best_threshold(sess, model_spec, num_steps, params): 34 | 35 | sample_steps = params.threshold_sample_steps 36 | batch_size = params.batch_size 37 | train_size = params.train_size 38 | if train_size > 20000: 39 | CAMs = np.zeros((int(batch_size * (num_steps // sample_steps + 1)), params.CAM_size, params.CAM_size)) 40 | labels_vals = np.zeros((int(batch_size * (num_steps // sample_steps + 1)), 1)) 41 | else: 42 | CAMs = np.zeros((train_size, params.CAM_size, params.CAM_size)) 43 | labels_vals = np.zeros((train_size, 1)) 44 | sample_steps = 1 45 | 46 | CAM_input = model_spec['CAM_input'] 47 | dense_kernel = model_spec['dense_kernel'] 48 | dense_bias = model_spec['dense_bias'] 49 | 50 | loss = model_spec['loss'] 51 | labels = model_spec['labels'] 52 | 53 | sess.run(model_spec['iterator_init_op']) 54 | 55 | logging.info("- Computing optimal threshold on training set...") 56 | t = trange(num_steps) 57 | j = 0 58 | 59 | for i in t: 60 | if i % sample_steps == 0: 61 | CAM_input_val, dense_kernel_val, dense_bias_val, labels_val = sess.run([CAM_input, dense_kernel, dense_bias, labels]) 62 | if CAM_input_val.shape[0] != batch_size: 63 | continue 64 | 65 | CAM = np.sum(np.matmul(CAM_input_val, dense_kernel_val) + dense_bias_val, axis=-1) 66 | CAMs[j*batch_size:(j+1)*batch_size,:,:] = np.array(CAM) 67 | labels_vals[j*batch_size:(j+1)*batch_size,:] = np.array(labels_val) 68 | 69 | j = j + 1 70 | print("dense kernel shape:", dense_kernel_val.shape) 71 | 72 | def get_classif_acc(threshold, cam, ground_truth): 73 | segpreds = cam > threshold # all cam values > threshold become 1 74 | preds = np.mean(segpreds, axis=(1,2)) > 0.5 # image-level label becomes 1 if 
>0.5 pixels are 1 75 | acc = sum(preds == ground_truth) / len(ground_truth) 76 | return acc 77 | 78 | possible_thresholds = list(np.linspace(-10,10,1001)) 79 | labels_vals = labels_vals.flatten().astype(bool) 80 | classif_accs = [get_classif_acc(x, CAMs, labels_vals) for x in possible_thresholds] 81 | 82 | return np.array(possible_thresholds)[np.argmax(np.array(classif_accs))], np.max(classif_accs) 83 | 84 | 85 | def get_seg_metrics(sess, model_spec, num_steps, params, threshold): 86 | 87 | CAM_input = model_spec['CAM_input'] 88 | dense_kernel = model_spec['dense_kernel'] 89 | dense_bias = model_spec['dense_bias'] 90 | seg_labels = model_spec['seg_labels'] 91 | 92 | acc = 0.0 93 | pre = 0.0 94 | rec = 0.0 95 | f1s = 0.0 96 | 97 | sess.run(model_spec['iterator_init_op']) 98 | 99 | print("- Computing segmentation accuracy...") 100 | t = trange(num_steps) 101 | 102 | for i in t: 103 | CAM_input_val, dense_kernel_val, dense_bias_val, seg_values = sess.run([CAM_input, dense_kernel, dense_bias, seg_labels]) 104 | CAM = np.sum(np.matmul(CAM_input_val, dense_kernel_val) + dense_bias_val, axis=-1) 105 | seg_pred = CAM > threshold 106 | offset = (seg_values.shape[1] - params.CAM_size) // 2 107 | seg_true = seg_values[:,offset:params.CAM_size+offset,offset:params.CAM_size+offset] 108 | seg_true = seg_true.astype(int) 109 | acc += accuracy_score(seg_true.flatten(), seg_pred.flatten()) 110 | pre += precision_score(seg_true.flatten(), seg_pred.flatten()) 111 | rec += recall_score(seg_true.flatten(), seg_pred.flatten()) 112 | f1s += f1_score(seg_true.flatten(), seg_pred.flatten()) 113 | 114 | return acc / num_steps, pre / num_steps, rec / num_steps, f1s / num_steps 115 | 116 | 117 | def train_sess(sess, model_spec, num_steps, writer, params, model_dir, epoch): 118 | 119 | # Get relevant graph operations or nodes needed for training 120 | loss = model_spec['loss'] 121 | train_op = model_spec['train_op'] 122 | update_metrics = model_spec['update_metrics'] 123 | metrics = model_spec['metrics'] 124 | summary_op = model_spec['summary_op'] 125 | global_step = tf.train.get_global_step() 126 | 127 | # Load training dataset into pipeline and initialize the metrics local variables 128 | sess.run(model_spec['iterator_init_op']) 129 | sess.run(model_spec['metrics_init_op']) 130 | 131 | # Use tqdm for progress bar 132 | t = trange(num_steps) 133 | for i in t: 134 | handles = [train_op, update_metrics, loss] 135 | # Evaluate summaries every 100 steps 136 | if i % 100 == 0: 137 | _, _, loss_val, summ, global_step_val = sess.run( 138 | handles + [summary_op, global_step]) 139 | 140 | # Write training summary to tensorboard 141 | writer.add_summary(summ, global_step_val) 142 | else: 143 | _, _, loss_val = sess.run(handles) 144 | 145 | # Log the loss in the tqdm progress bar 146 | t.set_postfix(loss='{:05.3f}'.format(loss_val)) 147 | 148 | metrics_values = {k: v[0] for k, v in metrics.items()} 149 | metrics_val = sess.run(metrics_values) 150 | metrics_string = " ; ".join("{}: {:05.3f}".format(k, v) for k, v in metrics_val.items()) 151 | logging.info("- Train metrics: " + metrics_string) 152 | 153 | return metrics_val 154 | 155 | 156 | def evaluate_sess(sess, model_spec, num_steps, writer, params, model_dir, epoch): 157 | update_metrics = model_spec['update_metrics'] 158 | metrics = model_spec['metrics'] 159 | summary_op = model_spec['summary_op'] 160 | global_step = tf.train.get_global_step() 161 | predictions = model_spec['predictions'] 162 | if params.model == 'UNet': 163 | CAM_input = model_spec['CAM_input'] 164 | 
dense_kernel = model_spec['dense_kernel'] 165 | dense_bias = model_spec['dense_bias'] 166 | 167 | # Load evaluation dataset into pipeline and initialize the metrics local variables 168 | sess.run(model_spec['iterator_init_op']) 169 | sess.run(model_spec['metrics_init_op']) 170 | 171 | for i in range(num_steps): 172 | handles = [update_metrics] 173 | if i % 100 == 0: 174 | handles += [summary_op, global_step] 175 | _, summ, global_step_val = sess.run(handles) 176 | writer.add_summary(summ, global_step_val) 177 | else: 178 | # adds results automatically to metrics 179 | sess.run(handles) 180 | 181 | # Get the values of the metrics 182 | metrics_values = {k: v[0] for k, v in metrics.items()} 183 | metrics_val = sess.run(metrics_values) 184 | metrics_string = " ; ".join("{}: {:05.3f}".format(k, v) for k, v in metrics_val.items()) 185 | logging.info("- Eval metrics: " + metrics_string) 186 | 187 | if params.model == 'UNet' and params.save_cam and epoch > 5: 188 | save_cam(sess, CAM_input, dense_kernel, dense_bias, predictions, model_spec, model_dir, metrics_val, epoch, num_steps) 189 | 190 | return metrics_val 191 | 192 | 193 | def train_and_evaluate(train_model_spec, eval_model_spec, test_model_spec, model_dir, params, restore_from=None, ts_uuid=''): 194 | model_dir = Path(model_dir) 195 | 196 | train_loss = [] 197 | train_acc = [] 198 | train_segacc_list = [] 199 | train_segpre_list = [] 200 | train_segrec_list = [] 201 | train_segf1s_list = [] 202 | 203 | eval_loss = [] 204 | eval_acc = [] 205 | eval_segacc_list = [] 206 | eval_segpre_list = [] 207 | eval_segrec_list = [] 208 | eval_segf1s_list = [] 209 | 210 | test_loss = [] 211 | test_acc = [] 212 | test_segacc_list = [] 213 | test_segpre_list = [] 214 | test_segrec_list = [] 215 | test_segf1s_list = [] 216 | 217 | saver = tf.train.Saver() 218 | 219 | with tf.Session() as sess: 220 | # Initialize model variables 221 | sess.run(train_model_spec['variable_init_op']) 222 | 223 | # Set up tensorboard files 224 | train_summary_dir = model_dir / f'{ts_uuid}_train_summaries' 225 | eval_summary_dir = model_dir / f'{ts_uuid}_eval_summaries' 226 | test_summary_dir = model_dir / f'{ts_uuid}_test_summaries' 227 | train_summary_dir.mkdir(exist_ok=True) 228 | eval_summary_dir.mkdir(exist_ok=True) 229 | test_summary_dir.mkdir(exist_ok=True) 230 | train_writer = tf.summary.FileWriter(str(train_summary_dir), sess.graph) 231 | eval_writer = tf.summary.FileWriter(str(eval_summary_dir), sess.graph) 232 | test_writer = tf.summary.FileWriter(str(test_summary_dir), sess.graph) 233 | 234 | best_accuracy = 0.0 235 | 236 | for epoch in range(params.num_epochs): 237 | # Run one epoch 238 | logging.info("Epoch {}/{}".format(epoch + 1, params.num_epochs)) 239 | 240 | # Compute number of batches in one epoch 241 | num_steps = (params.train_size + params.batch_size - 1) // params.batch_size 242 | train_metrics = train_sess(sess, train_model_spec, num_steps, train_writer, params, model_dir, epoch) 243 | 244 | # if gap_sigmoid model, find best threshold 245 | if params.unet_type == 'gap_sigmoid': 246 | best_thresh, best_acc = get_best_threshold(sess, train_model_spec, num_steps, params) 247 | logging.info('- Best threshold on train samples: {:0.4f}, best accuracy: {:0.4f}'.format(best_thresh, best_acc)) 248 | 249 | train_segacc, train_segpre, train_segrec, train_segf1s = get_seg_metrics(sess, train_model_spec, num_steps, params, best_thresh) 250 | train_segacc_list.append(train_segacc) 251 | train_segpre_list.append(train_segpre) 252 | 
train_segrec_list.append(train_segrec)
253 |                 train_segf1s_list.append(train_segf1s)
254 |                 logging.info('- Train segmentation accuracy: {:0.4f}'.format(train_segacc))
255 |             else:
256 |                 train_segacc_list.append(train_metrics['segmentation_acc'])
257 |                 # Precision/recall/F1 are only computed for the gap_sigmoid variant;
258 |                 # pad with NaN so every column of the metrics DataFrame below has
259 |                 # the same length.
260 |                 train_segpre_list.append(float('nan'))
261 |                 train_segrec_list.append(float('nan'))
262 |                 train_segf1s_list.append(float('nan'))
263 | 
264 |             # Evaluate for one epoch on validation set
265 |             num_steps = (params.eval_size + params.batch_size - 1) // params.batch_size
266 |             eval_metrics = evaluate_sess(sess, eval_model_spec, num_steps, eval_writer, params, model_dir, epoch)
267 | 
268 |             if params.unet_type == 'gap_sigmoid':
269 |                 eval_segacc, eval_segpre, eval_segrec, eval_segf1s = get_seg_metrics(sess, eval_model_spec, num_steps, params, best_thresh)
270 |                 eval_segacc_list.append(eval_segacc)
271 |                 eval_segpre_list.append(eval_segpre)
272 |                 eval_segrec_list.append(eval_segrec)
273 |                 eval_segf1s_list.append(eval_segf1s)
274 |                 logging.info('- Eval segmentation accuracy: {:0.4f}'.format(eval_segacc))
275 |             else:
276 |                 eval_segacc_list.append(eval_metrics['segmentation_acc'])
277 |                 eval_segpre_list.append(float('nan'))
278 |                 eval_segrec_list.append(float('nan'))
279 |                 eval_segf1s_list.append(float('nan'))
280 | 
281 |             # Evaluate for one epoch on test set
282 |             num_steps = (params.test_size + params.batch_size - 1) // params.batch_size
283 |             test_metrics = evaluate_sess(sess, test_model_spec, num_steps, test_writer, params, model_dir, epoch)
284 | 
285 |             if params.unet_type == 'gap_sigmoid':
286 |                 test_segacc, test_segpre, test_segrec, test_segf1s = get_seg_metrics(sess, test_model_spec, num_steps, params, best_thresh)
287 |                 test_segacc_list.append(test_segacc)
288 |                 test_segpre_list.append(test_segpre)
289 |                 test_segrec_list.append(test_segrec)
290 |                 test_segf1s_list.append(test_segf1s)
291 |                 logging.info('- Test segmentation accuracy: {:0.4f}'.format(test_segacc))
292 |             else:
293 |                 test_segacc_list.append(test_metrics['segmentation_acc'])
294 |                 test_segpre_list.append(float('nan'))
295 |                 test_segrec_list.append(float('nan'))
296 |                 test_segf1s_list.append(float('nan'))
297 | 
298 |             train_loss.append(train_metrics['loss'])
299 |             train_acc.append(train_metrics['accuracy'])
300 |             eval_loss.append(eval_metrics['loss'])
301 |             eval_acc.append(eval_metrics['accuracy'])
302 |             test_loss.append(test_metrics['loss'])
303 |             test_acc.append(test_metrics['accuracy'])
304 | 
305 |             # Save best model so far based on task accuracy (image classification)
306 |             if eval_metrics['accuracy'] > best_accuracy:
307 |                 print("Saving best model...")
308 |                 saver.save(sess, os.path.join(model_dir, "best_model.ckpt"))
309 |                 best_accuracy = max(best_accuracy, eval_metrics['accuracy'])
310 | 
311 |         # Write metrics to disk for easy analysis
312 |         df = pd.DataFrame({'train_loss': train_loss, 'eval_loss': eval_loss, 'test_loss': test_loss,
313 |                            'train_accuracy': train_acc, 'eval_accuracy': eval_acc, 'test_accuracy': test_acc,
314 |                            'train_segacc': train_segacc_list, 'eval_segacc': eval_segacc_list, 'test_segacc': test_segacc_list,
315 |                            'train_segpre': train_segpre_list, 'eval_segpre': eval_segpre_list, 'test_segpre': test_segpre_list,
316 |                            'train_segrec': train_segrec_list, 'eval_segrec': eval_segrec_list, 'test_segrec': test_segrec_list,
317 |                            'train_segf1s': train_segf1s_list, 'eval_segf1s': eval_segf1s_list, 'test_segf1s': test_segf1s_list})
318 |         df.to_csv(os.path.join(model_dir, 'metrics.csv'), index=False)
319 | 
320 | 
321 | 
--------------------------------------------------------------------------------
/image_labels/src/utils.py:
--------------------------------------------------------------------------------
1 | """General utility functions"""
2 | 
3 | import json
4 | import logging
5 | import numpy as np
6 | import os
7 | 
8 | 
9 | class Params():
10 |     """Class that loads hyperparameters from a json file.
11 |     Example:
12 |     ```
13 |     params = Params(json_path)
14 |     print(params.learning_rate)
15 |     params.learning_rate = 0.5 # change the value of learning_rate in params
16 |     ```
17 |     """
18 | 
19 |     def __init__(self, json_path):
20 |         self.update(json_path)
21 | 
22 |     def save(self, json_path):
23 |         """Saves parameters to json file"""
24 |         with open(json_path, 'w') as f:
25 |             json.dump(self.__dict__, f, indent=4)
26 | 
27 |     def update(self, json_path):
28 |         """Loads parameters from json file"""
29 |         with open(json_path) as f:
30 |             params = json.load(f)
31 |             self.__dict__.update(params)
32 | 
33 |     @property
34 |     def dict(self):
35 |         """Gives dict-like access to Params instance by `params.dict['learning_rate']`"""
36 |         return self.__dict__
37 | 
38 | 
39 | def set_logger(log_path):
40 |     """Sets the logger to log info in terminal and file `log_path`.
41 |     In general, it is useful to have a logger so that every output to the terminal is saved
42 |     in a permanent file. Here we save it to `model_dir/train.log`.
43 |     Example:
44 |     ```
45 |     logging.info("Starting training...")
46 |     ```
47 |     Args:
48 |         log_path: (string) where to log
49 |     """
50 |     logger = logging.getLogger()
51 |     logger.setLevel(logging.INFO)
52 | 
53 |     if not logger.handlers:
54 |         # Logging to a file
55 |         file_handler = logging.FileHandler(log_path)
56 |         file_handler.setFormatter(logging.Formatter('%(asctime)s:%(levelname)s: %(message)s'))
57 |         logger.addHandler(file_handler)
58 | 
59 |         # Logging to console
60 |         stream_handler = logging.StreamHandler()
61 |         stream_handler.setFormatter(logging.Formatter('%(message)s'))
62 |         logger.addHandler(stream_handler)
--------------------------------------------------------------------------------
/single_pixel_labels/experiments/params.json:
--------------------------------------------------------------------------------
1 | {
2 |     "model": "unet",
3 |     "run": "run0",
4 | 
5 |     "tile_dir": "../data/images",
6 |     "train_file": "../data/splits/train_files_numpy",
7 |     "val_file": "../data/splits/val_files_numpy",
8 |     "mask_dir": "../data/masks",
9 |     "n_train": 10,
10 |     "n_val": 10,
11 | 
12 |     "layers": 4,
13 |     "label_size": 48,
14 |     "batch_size": 32,
15 |     "shuffle": true,
16 |     "num_workers": 4,
17 | 
18 |     "epochs": 50,
19 |     "starting_filters": 32,
20 |     "bn_momentum": 0.1,
21 |     "lr": 0.001,
22 |     "beta1": 0.5,
23 |     "beta2": 0.999,
24 |     "weight_decay": 0.0,
25 |     "decay_steps": 100000,
26 |     "gamma": 0.9,
27 |     "save_model": true,
28 | 
29 |     "visdom_every": 100
30 | }
31 | 
--------------------------------------------------------------------------------
/single_pixel_labels/src/datasets.py:
--------------------------------------------------------------------------------
1 | from torch.utils.data import Dataset, DataLoader
2 | from torchvision import transforms
3 | import torch
4 | import os
5 | import numpy as np
6 | 
7 | 
8 | def cdl_to_binary(cdl):
9 |     # Collapse USDA Cropland Data Layer (CDL) class codes into a binary label,
10 |     # following the class grouping used in the paper.
11 |     return (((cdl <= 60) | (cdl >= 196)) | ((cdl >= 66) & (cdl <= 77)))
12 | 
13 | 
14 | class MaskedTileDataset(Dataset):
15 | 
16 |     def __init__(self, tile_dir, tile_files, mask_dir, transform=None, n_samples=None):
17 | 
18 |         self.tile_dir = tile_dir
19 |         self.tile_files = tile_files
20 |         self.transform = transform
21 |         self.n_samples = n_samples
22 |         self.mask_dir = mask_dir
23 | 
24 | 
25 |     def __len__(self):
26 |         if self.n_samples: return self.n_samples
27 |         else: return len(self.tile_files)
28 | 
29 | 
30 |     def __getitem__(self, idx):
31 |         tile = np.load(os.path.join(self.tile_dir,
self.tile_files[idx]))
32 |         tile = np.nan_to_num(tile)
33 |         tile = np.moveaxis(tile, -1, 0)
34 | 
35 |         mask = np.load(os.path.join(self.mask_dir, 'mask_'+self.tile_files[idx]))
36 |         mask = np.expand_dims(mask, axis=0)
37 |         tile = np.concatenate([tile, mask], axis=0) # attach mask to tile to ensure same transformations are applied
38 | 
39 |         if self.transform:
40 |             tile = self.transform(tile)
41 |         features = tile[:7,:,:]
42 | 
43 |         label = tile[-2,:,:] * 10000  # rescale the stored CDL band back to integer class codes
44 |         label = cdl_to_binary(label)
45 |         label = label.float()
46 | 
47 |         mask = tile[-1,:,:]
48 |         mask = mask.bool()  # boolean mask, as expected by torch.masked_select during training
49 | 
50 |         return features, label, mask
51 | 
52 | 
53 | class RandomFlipAndRotate(object):
54 |     """
55 |     Does data augmentation during training by randomly flipping (horizontal
56 |     and vertical) and randomly rotating (0, 90, 180, 270 degrees). Keep in mind
57 |     that pytorch samples are CxHxW.
58 |     """
59 |     def __call__(self, tile):
60 |         # randomly flip
61 |         if np.random.rand() < 0.5: tile = np.flip(tile, axis=2).copy()
62 |         if np.random.rand() < 0.5: tile = np.flip(tile, axis=1).copy()
63 | 
64 |         # randomly rotate
65 |         rotations = np.random.choice([0,1,2,3])
66 |         if rotations > 0: tile = np.rot90(tile, k=rotations, axes=(1,2)).copy()
67 | 
68 |         return tile
69 | 
70 | 
71 | class ToFloatTensor(object):
72 |     """
73 |     Converts numpy arrays to float tensors in Pytorch.
74 |     """
75 |     def __call__(self, tile):
76 |         tile = torch.from_numpy(tile).float()
77 |         return tile
78 | 
79 | 
80 | def masked_tile_dataloader(tile_dir, tile_files, mask_dir, augment=True, batch_size=4, shuffle=True, num_workers=4, n_samples=None):
81 |     """
82 |     Returns a dataloader with Landsat tiles.
83 |     """
84 |     transform_list = []
85 |     if augment: transform_list.append(RandomFlipAndRotate())
86 |     transform_list.append(ToFloatTensor())
87 |     transform = transforms.Compose(transform_list)
88 | 
89 |     dataset = MaskedTileDataset(tile_dir, tile_files, mask_dir, transform=transform, n_samples=n_samples)
90 |     dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=shuffle, num_workers=num_workers)
91 |     return dataloader
92 | 
--------------------------------------------------------------------------------
/single_pixel_labels/src/run_masked.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import numpy as np
3 | import time
4 | import os
5 | import copy
6 | import random
7 | import json
8 | 
9 | import torch
10 | import torch.nn as nn
11 | import torch.optim as optim
12 | from torch.optim import lr_scheduler
13 | import torchvision
14 | 
15 | from datasets import masked_tile_dataloader
16 | from unet import UNet
17 | from train_masked import train_model
18 | 
19 | 
20 | 
21 | parser = argparse.ArgumentParser(description="Masked UNet for Remote Sensing Image Segmentation")
22 | 
23 | # directories
24 | parser.add_argument('--model_dir', type=str)
25 | parser.add_argument('--gpu', type=int, default=0)
26 | 
27 | args = parser.parse_args()
28 | 
29 | # model and dataset hyperparameters
30 | param_file = os.path.join(args.model_dir, 'params.json')
31 | with open(param_file) as f:
32 |     params = json.load(f)
33 | 
34 | # dataloaders
35 | train_files = []
36 | val_files = []
37 | with open(params['train_file']) as f:
38 |     for l in f:
39 |         if '.npy' in l:
40 |             train_files.append(l.strip())
41 | with open(params['val_file']) as f:
42 |     for l in f:
43 |         if '.npy' in l:
44 |             val_files.append(l.strip())
45 | 
46 | dataloaders = {}
47 | dataloaders['train'] = masked_tile_dataloader(params['tile_dir'],
48 |                                               train_files,
49 |                                               params['mask_dir'],
50 |                                               augment=True,
51 |                                               batch_size=params['batch_size'],
52 |                                               shuffle=params['shuffle'],
53 |                                               num_workers=params['num_workers'],
54 |                                               n_samples=params['n_train'])
55 | dataloaders['val'] = masked_tile_dataloader(params['tile_dir'],
56 |                                             val_files,
57 |                                             params['mask_dir'],
58 |                                             augment=False,
59 |                                             batch_size=params['batch_size'],
60 |                                             shuffle=params['shuffle'],
61 |                                             num_workers=params['num_workers'],
62 |                                             n_samples=params['n_val'])
63 | dataset_sizes = {}
64 | dataset_sizes['train'] = len(train_files)
65 | dataset_sizes['val'] = len(val_files)
66 | 
67 | device = torch.device("cuda:"+str(args.gpu) if torch.cuda.is_available() else "cpu")
68 | 
69 | model = UNet(in_channels=7, out_channels=1,
70 |              starting_filters=params['starting_filters'],
71 |              bn_momentum=params['bn_momentum'])
72 | 
73 | model = model.to(device)
74 | 
75 | criterion = nn.BCEWithLogitsLoss()
76 | 
77 | # Observe that all parameters are being optimized
78 | optimizer = optim.Adam(model.parameters(), lr=params['lr'],
79 |                        betas=(params['beta1'], params['beta2']), weight_decay=params['weight_decay'])
80 | 
81 | # Decay LR by a factor of gamma every decay_steps epochs
82 | exp_lr_scheduler = lr_scheduler.StepLR(optimizer,
83 |                                        step_size=params['decay_steps'],
84 |                                        gamma=params['gamma'])
85 | 
86 | model, metrics = train_model(model, dataloaders, dataset_sizes, criterion, optimizer, exp_lr_scheduler,
87 |                              params, num_epochs=params['epochs'], gpu=args.gpu)
88 | 
89 | if params['save_model']:
90 |     torch.save(model.state_dict(), os.path.join(args.model_dir, 'best_model.pt'))
91 | 
92 | with open(os.path.join(args.model_dir, 'metrics.json'), 'w') as f:
93 |     json.dump(metrics, f)
94 | 
--------------------------------------------------------------------------------
/single_pixel_labels/src/train_masked.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import time
3 | import os
4 | import copy
5 | import random
6 | 
7 | import torch
8 | import torch.nn as nn
9 | import torch.optim as optim
10 | from torch.optim import lr_scheduler
11 | import torchvision
12 | from torchvision import models
13 | 
14 | from tqdm import tqdm
15 | 
16 | def train_model(model, dataloaders, dataset_sizes, criterion, optimizer, scheduler, params, num_epochs=20, gpu=0):
17 |     """
18 |     Trains a model for all epochs using the provided dataloader.
19 |     """
20 |     t0 = time.time()
21 | 
22 |     device = torch.device("cuda:"+str(gpu) if torch.cuda.is_available() else "cpu")
23 | 
24 |     best_model_wts = copy.deepcopy(model.state_dict())
25 |     best_acc, best_segacc = 0.0, 0.0  # initialize both so the final prints are always defined
26 | 
27 |     metrics = {'train_loss': [], 'train_acc': [], 'train_segacc': [],
28 |                'val_loss': [], 'val_acc': [], 'val_segacc': []}
29 | 
30 |     for epoch in range(num_epochs):
31 |         print('Epoch {}/{}'.format(epoch, num_epochs - 1))
32 |         print('-' * 10)
33 | 
34 |         # Each epoch has a training and validation phase
35 |         for phase in ['train', 'val']:
36 |             if phase == 'train':
37 |                 scheduler.step()
38 |                 model.train()  # Set model to training mode
39 |             else:
40 |                 model.eval()  # Set model to evaluate mode
41 | 
42 |             running_loss = 0.0
43 |             running_corrects = 0
44 |             running_corrects_full = 0
45 |             i = 0
46 | 
47 |             # Iterate over data.
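            # Weak supervision happens here: `masks` flags the sparsely labeled
            # pixels in each tile (single geotagged points in this setup), so the
            # BCE loss below is evaluated only at those pixels via
            # torch.masked_select, while comparing `preds` to the full
            # CDL-derived `labels` tracks dense segmentation accuracy as a
            # diagnostic.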
48 |             for inputs, labels, masks in tqdm(dataloaders[phase]):
49 |                 inputs = inputs.to(device)
50 | 
51 |                 labels = labels.to(device)
52 |                 label_size = labels.shape[-1]
53 |                 target_size = params["label_size"]
54 |                 offset = (label_size - target_size)//2
55 |                 labels = labels[:,offset:offset+target_size,offset:offset+target_size]
56 | 
57 |                 masks = masks.to(device)
58 |                 masks = masks[:,offset:offset+target_size,offset:offset+target_size]
59 | 
60 |                 # zero the parameter gradients
61 |                 optimizer.zero_grad()
62 | 
63 |                 # forward
64 |                 # track history only if in train
65 |                 s = nn.Sigmoid()
66 |                 with torch.set_grad_enabled(phase == 'train'):
67 |                     outputs = model(inputs)
68 |                     outputs = outputs.squeeze(1)  # drop only the channel dim, so a batch of size 1 keeps its batch dim
69 |                     preds = s(outputs) >= 0.5
70 |                     preds = preds.float()
71 | 
72 |                     outputs_masked = torch.masked_select(outputs, masks)
73 |                     labels_masked = torch.masked_select(labels, masks)
74 |                     preds_masked = torch.masked_select(preds, masks)
75 | 
76 |                     loss = criterion(outputs_masked, labels_masked)
77 | 
78 |                     # backward + optimize only if in training phase
79 |                     if phase == 'train':
80 |                         loss.backward()
81 |                         optimizer.step()
82 | 
83 |                 # statistics
84 |                 running_loss += loss.item() * inputs.size(0)
85 |                 running_corrects += torch.sum(preds_masked == labels_masked.data)
86 |                 running_corrects_full += torch.sum(preds == labels.data)
87 | 
88 |                 i += 1
89 | 
90 |             epoch_loss = running_loss / dataset_sizes[phase]
91 |             epoch_acc = running_corrects.double() / dataset_sizes[phase]
92 |             epoch_segacc = running_corrects_full.double() / (dataset_sizes[phase] * params["label_size"]**2)
93 | 
94 |             metrics[phase+'_loss'].append(epoch_loss)
95 |             metrics[phase+'_acc'].append(float(epoch_acc.cpu().numpy()))
96 |             metrics[phase+'_segacc'].append(float(epoch_segacc.cpu().numpy()))
97 | 
98 |             print('{} loss: {:.4f}, single pixel accuracy: {:.4f}, full segmentation accuracy: {:.4f}'.format(
99 |                 phase, epoch_loss, epoch_acc, epoch_segacc))
100 | 
101 |             # deep copy the model
102 |             if phase == 'val' and epoch_acc > best_acc:
103 |                 best_acc = epoch_acc
104 |                 best_segacc = epoch_segacc
105 |                 best_model_wts = copy.deepcopy(model.state_dict())
106 | 
107 |         print()
108 | 
109 |     time_elapsed = time.time() - t0
110 |     print('Training complete in {:.0f}m {:.0f}s'.format(
111 |         time_elapsed // 60, time_elapsed % 60))
112 |     print('Best val single pixel accuracy: {:.4f}'.format(best_acc))
113 |     print('Corresponding val full segmentation accuracy: {:.4f}'.format(best_segacc))
114 | 
115 |     # load best model weights
116 |     model.load_state_dict(best_model_wts)
117 |     return model, metrics
118 | 
119 | 
120 | 
--------------------------------------------------------------------------------
/single_pixel_labels/src/unet.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from torch.autograd import Variable
4 | import numpy as np
5 | from PIL import Image
6 | from torch.nn.functional import sigmoid
7 | 
8 | 
9 | class conv_conv(nn.Module):
10 |     ''' (Conv2d => BatchNorm2d => ReLU) * 2 '''
11 | 
12 |     def __init__(self, in_channels, out_channels, bn_momentum=0.1):
13 |         """
14 |         Args:
15 |             in_channels (int): input channel
16 |             out_channels (int): output channel
17 |             bn_momentum (float): batch norm momentum
18 |         """
19 |         super(conv_conv, self).__init__()
20 |         self.conv = nn.Sequential(
21 |             nn.Conv2d(in_channels, out_channels, 3, padding=1, stride=1),
22 |             nn.BatchNorm2d(out_channels, momentum=bn_momentum),
23 |             nn.ReLU(inplace=True),
24 |             nn.Conv2d(out_channels, out_channels, 3, padding=1, stride=1),
25 |             nn.BatchNorm2d(out_channels, momentum=bn_momentum),
26 |             nn.ReLU(inplace=True)
27 |         )
28 | 
29 |     def forward(self, X):
30 |         X = self.conv(X)
31 |         return X
32 | 
33 | 
34 | class downconv(nn.Module):
35 |     def __init__(self, in_channels, out_channels, bn_momentum=0.1):
36 |         super(downconv, self).__init__()
37 |         self.conv = conv_conv(in_channels, out_channels, bn_momentum)
38 |         self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
39 | 
40 |     def forward(self, X):
41 |         X = self.conv(X)
42 |         pool_X = self.pool(X)
43 |         return pool_X, X
44 | 
45 | 
46 | class upconv_concat(nn.Module):
47 |     ''' upconv => crop & concat skip connection => conv_conv '''
48 | 
49 |     def __init__(self, in_channels, out_channels, bn_momentum=0.1):
50 |         """
51 |         Args:
52 |             in_channels (int): input channel
53 |             out_channels (int): output channel
54 |         """
55 |         super(upconv_concat, self).__init__()
56 |         self.upconv = nn.ConvTranspose2d(in_channels, out_channels, kernel_size=2, stride=2)
57 |         self.conv = conv_conv(in_channels, out_channels, bn_momentum)
58 | 
59 |     def forward(self, X1, X2):
60 |         X1 = self.upconv(X1)
61 |         X1_dim = X1.size()[2]
62 |         X2 = extract_img(X1_dim, X2)
63 |         X1 = torch.cat((X1, X2), dim=1)
64 |         X1 = self.conv(X1)
65 |         return X1
66 | 
67 | 
68 | def extract_img(size, in_tensor):
69 |     """
70 |     Args:
71 |         size (int): size of crop
72 |         in_tensor (tensor): tensor to be cropped
73 |     """
74 |     dim1, dim2 = in_tensor.size()[2:]
75 |     in_tensor = in_tensor[:, :, int((dim1-size)/2):int((dim1+size)/2), int((dim2-size)/2):int((dim2+size)/2)]
76 |     return in_tensor
77 | 
78 | 
79 | class UNet(nn.Module):
80 | 
81 |     def __init__(self, in_channels, out_channels, starting_filters=32, bn_momentum=0.1):
82 |         super(UNet, self).__init__()
83 |         self.conv1 = downconv(in_channels, starting_filters, bn_momentum)
84 |         self.conv2 = downconv(starting_filters, starting_filters * 2, bn_momentum)
85 |         self.conv3 = downconv(starting_filters * 2, starting_filters * 4, bn_momentum)
86 |         self.conv4 = conv_conv(starting_filters * 4, starting_filters * 8, bn_momentum)
87 |         self.upconv1 = upconv_concat(starting_filters * 8, starting_filters * 4, bn_momentum)
88 |         self.upconv2 = upconv_concat(starting_filters * 4, starting_filters * 2, bn_momentum)
89 |         self.upconv3 = upconv_concat(starting_filters * 2, starting_filters, bn_momentum)
90 |         self.conv_out = nn.Conv2d(starting_filters, out_channels, 1, padding=0, stride=1)
91 | 
92 |     def forward(self, X):
93 |         X, conv1 = self.conv1(X)
94 |         X, conv2 = self.conv2(X)
95 |         X, conv3 = self.conv3(X)
96 |         X = self.conv4(X)
97 |         X = self.upconv1(X, conv3)
98 |         X = self.upconv2(X, conv2)
99 |         X = self.upconv3(X, conv1)
100 |         X = self.conv_out(X)
101 |         return X
102 | 
103 | 
104 | class UCAM(nn.Module):
105 | 
106 |     def __init__(self, in_channels, out_channels, starting_filters=32, bn_momentum=0.1):
107 |         super(UCAM, self).__init__()
108 |         self.conv1 = downconv(in_channels, starting_filters, bn_momentum)
109 |         self.conv2 = downconv(starting_filters, starting_filters * 2, bn_momentum)
110 |         self.conv3 = downconv(starting_filters * 2, starting_filters * 4, bn_momentum)
111 |         self.conv4 = conv_conv(starting_filters * 4, starting_filters * 8, bn_momentum)
112 |         self.upconv1 = upconv_concat(starting_filters * 8, starting_filters * 4, bn_momentum)
113 |         self.upconv2 = upconv_concat(starting_filters * 4, starting_filters * 2, bn_momentum)
114 |         self.upconv3 = upconv_concat(starting_filters * 2, starting_filters, bn_momentum)
115 |         self.gap = nn.AdaptiveAvgPool2d(1)
116 |         self.dense = nn.Linear(starting_filters, 1)
117 | 
118 |     def forward(self, X):
119 |         X, conv1 = self.conv1(X)
120 |         X, conv2 = self.conv2(X)
121 |         X, conv3 = self.conv3(X)
122 |         X = self.conv4(X)
123 |         X = self.upconv1(X, conv3)
124 |         X = self.upconv2(X, conv2)
125 |         X = self.upconv3(X, conv1)
126 |         out = self.gap(X)                # (N, C, 1, 1)
127 |         out = out.view(out.size(0), -1)  # flatten to (N, C) so nn.Linear sees starting_filters features
128 |         out = self.dense(out)            # image-level logit, shape (N, 1)
129 |         return X, out
130 | 
131 | 
132 | 
133 | 
--------------------------------------------------------------------------------
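For readers comparing the two implementations: the sketch below is not part of the repository, and the helper name `compute_cam` is hypothetical. It shows how a class activation map could be recovered from the PyTorch `UCAM` above, mirroring the `np.matmul(CAM_input_val, dense_kernel_val) + dense_bias_val` computation in `image_labels/src/utils.py`. Thresholding the resulting map, with the threshold tuned as in `get_best_threshold`, would then give the weakly supervised segmentation.

# Hypothetical helper (not part of this repo) projecting UCAM's pre-GAP
# feature map through its final dense layer to obtain a per-pixel CAM.
import torch

from unet import UCAM


def compute_cam(model: UCAM, tiles: torch.Tensor) -> torch.Tensor:
    """Return a CAM of shape (N, H, W) for a batch of tiles (N, 7, H, W)."""
    model.eval()
    with torch.no_grad():
        features, _ = model(tiles)         # features: (N, C, H, W) pre-GAP map
        w = model.dense.weight.squeeze(0)  # dense kernel, shape (C,)
        b = model.dense.bias               # dense bias, shape (1,)
        # Same linear projection the TF code applies with np.matmul(...) + bias
        cam = torch.einsum('nchw,c->nhw', features, w) + b
    return cam


# Example usage: binarize the CAM with a tuned threshold to get a segmentation.
# model = UCAM(in_channels=7, out_channels=1)
# seg = compute_cam(model, torch.randn(4, 7, 48, 48)) > 0.0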