├── .gitignore
├── LICENSE
├── README.md
├── data
│   ├── images
│   │   ├── patch0.npy
│   │   ├── patch0.tfrecords
│   │   ├── patch1.npy
│   │   ├── patch1.tfrecords
│   │   ├── patch10.npy
│   │   ├── patch10.tfrecords
│   │   ├── patch11.npy
│   │   ├── patch11.tfrecords
│   │   ├── patch12.npy
│   │   ├── patch12.tfrecords
│   │   ├── patch13.npy
│   │   ├── patch13.tfrecords
│   │   ├── patch14.npy
│   │   ├── patch14.tfrecords
│   │   ├── patch15.npy
│   │   ├── patch15.tfrecords
│   │   ├── patch16.npy
│   │   ├── patch16.tfrecords
│   │   ├── patch17.npy
│   │   ├── patch17.tfrecords
│   │   ├── patch18.npy
│   │   ├── patch18.tfrecords
│   │   ├── patch19.npy
│   │   ├── patch19.tfrecords
│   │   ├── patch2.npy
│   │   ├── patch2.tfrecords
│   │   ├── patch3.npy
│   │   ├── patch3.tfrecords
│   │   ├── patch4.npy
│   │   ├── patch4.tfrecords
│   │   ├── patch5.npy
│   │   ├── patch5.tfrecords
│   │   ├── patch6.npy
│   │   ├── patch6.tfrecords
│   │   ├── patch7.npy
│   │   ├── patch7.tfrecords
│   │   ├── patch8.npy
│   │   ├── patch8.tfrecords
│   │   ├── patch9.npy
│   │   └── patch9.tfrecords
│   ├── masks
│   │   ├── mask_patch0.npy
│   │   ├── mask_patch1.npy
│   │   ├── mask_patch10.npy
│   │   ├── mask_patch11.npy
│   │   ├── mask_patch12.npy
│   │   ├── mask_patch13.npy
│   │   ├── mask_patch14.npy
│   │   ├── mask_patch15.npy
│   │   ├── mask_patch16.npy
│   │   ├── mask_patch17.npy
│   │   ├── mask_patch18.npy
│   │   ├── mask_patch19.npy
│   │   ├── mask_patch2.npy
│   │   ├── mask_patch3.npy
│   │   ├── mask_patch4.npy
│   │   ├── mask_patch5.npy
│   │   ├── mask_patch6.npy
│   │   ├── mask_patch7.npy
│   │   ├── mask_patch8.npy
│   │   └── mask_patch9.npy
│   └── splits
│       ├── test_files_tfrecords
│       ├── train_files_numpy
│       ├── train_files_tfrecords
│       ├── val_files_numpy
│       └── val_files_tfrecords
├── graphical_abstract.png
├── image_labels
│   ├── experiments
│   │   └── params.json
│   └── src
│       ├── UNet.py
│       ├── run_UCAM.py
│       ├── tfrecords.py
│       ├── train_UNet.py
│       └── utils.py
└── single_pixel_labels
    ├── experiments
    │   └── params.json
    └── src
        ├── datasets.py
        ├── run_masked.py
        ├── train_masked.py
        └── unet.py
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | pip-wheel-metadata/
24 | share/python-wheels/
25 | *.egg-info/
26 | .installed.cfg
27 | *.egg
28 | MANIFEST
29 |
30 | # PyInstaller
31 | # Usually these files are written by a python script from a template
32 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
33 | *.manifest
34 | *.spec
35 |
36 | # Installer logs
37 | pip-log.txt
38 | pip-delete-this-directory.txt
39 |
40 | # Unit test / coverage reports
41 | htmlcov/
42 | .tox/
43 | .nox/
44 | .coverage
45 | .coverage.*
46 | .cache
47 | nosetests.xml
48 | coverage.xml
49 | *.cover
50 | *.py,cover
51 | .hypothesis/
52 | .pytest_cache/
53 |
54 | # Translations
55 | *.mo
56 | *.pot
57 |
58 | # Django stuff:
59 | *.log
60 | local_settings.py
61 | db.sqlite3
62 | db.sqlite3-journal
63 |
64 | # Flask stuff:
65 | instance/
66 | .webassets-cache
67 |
68 | # Scrapy stuff:
69 | .scrapy
70 |
71 | # Sphinx documentation
72 | docs/_build/
73 |
74 | # PyBuilder
75 | target/
76 |
77 | # Jupyter Notebook
78 | .ipynb_checkpoints
79 |
80 | # IPython
81 | profile_default/
82 | ipython_config.py
83 |
84 | # pyenv
85 | .python-version
86 |
87 | # pipenv
88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
91 | # install all needed dependencies.
92 | #Pipfile.lock
93 |
94 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow
95 | __pypackages__/
96 |
97 | # Celery stuff
98 | celerybeat-schedule
99 | celerybeat.pid
100 |
101 | # SageMath parsed files
102 | *.sage.py
103 |
104 | # Environments
105 | .env
106 | .venv
107 | env/
108 | venv/
109 | ENV/
110 | env.bak/
111 | venv.bak/
112 |
113 | # Spyder project settings
114 | .spyderproject
115 | .spyproject
116 |
117 | # Rope project settings
118 | .ropeproject
119 |
120 | # mkdocs documentation
121 | /site
122 |
123 | # mypy
124 | .mypy_cache/
125 | .dmypy.json
126 | dmypy.json
127 |
128 | # Pyre type checker
129 | .pyre/
130 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2020 David Lobell's Lab Repo
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Weakly Supervised Deep Learning for Segmentation of Remote Sensing Imagery
2 |
3 | This repo contains the implementation of weakly supervised segmentation for remote sensing imagery from [our Remote Sensing paper](https://www.mdpi.com/2072-4292/12/2/207/htm). We perform cropland segmentation using two types of labels commonly found in remote sensing datasets that can be considered sources of “weak supervision”: (1) labels consisting of single geotagged points and (2) image-level labels.
4 |
5 |
6 |
7 |
8 |
9 | ## Usage
10 |
11 | To train a U-Net on single-pixel labels, run
12 | ```bash
13 | cd single_pixel_labels
14 | python src/run_masked.py --model_dir ./experiments
15 | ```
16 |
17 | To train a U-Net on image-level labels, run
18 | ```bash
19 | cd image_labels
20 | python src/run_UCAM.py --model_dir ./experiments
21 | ```
22 |
23 | Note that the code for the single-pixel labels was written in PyTorch, while the code for the image-level labels was written in TensorFlow 1.x.
24 |
25 |
26 | ## Data and Experimental Setup
27 |
28 | The models, datasets, and data loaders are currently written to process Landsat tiles in either `.tfrecords` (for the image-level labels) or `.npy` (for the single-pixel labels) format. The last layer of each `.tfrecords` or `.npy` file is assumed to be the segmentation ground truth label.
29 |
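For the `.npy` patches, a minimal sketch of reading this layout (the band count is set by `num_channels` in `params.json`; the path and shapes here are illustrative):

```python
import numpy as np

patch = np.load("data/images/patch0.npy")  # (H, W, num_channels + 1)
image = patch[..., :-1]  # input Landsat bands
label = patch[..., -1]   # last layer: segmentation ground truth
```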
30 | In each `experiments` directory is a `params.json` file that defines the data directory, training/val/test split file names, model architecture, number of channels in the input data, data augmentation procedures to use, and what model outputs to save. The `--model_dir` flag must point to the directory with a valid `params.json` file in order for training to proceed.
31 |
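For reference, a trimmed excerpt of the image-label configuration (see `image_labels/experiments/params.json` for the full set of keys):

```json
{
  "data_folder": "../data/images",
  "train_file": "splits/train_files_tfrecords",
  "model": "UNet",
  "unet_type": "gap_sigmoid",
  "img_size": 50,
  "num_channels": 7,
  "batch_size": 32,
  "num_epochs": 200,
  "learning_rate": 1e-3
}
```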
32 | ## Citation
33 |
34 | When using the code from this repo, please cite:
35 | * S. Wang, W. Chen, S. M. Xie, G. Azzari, and D. B. Lobell, “Weakly Supervised Deep Learning for Segmentation of Remote Sensing Imagery,” Remote Sensing, vol. 12, no. 2, p. 207, Jan. 2020, doi: 10.3390/rs12020207.
36 |
37 | Please feel free to email sherwang [at] stanford [dot] edu with any questions about the code or suggestions for improvement.
38 |
--------------------------------------------------------------------------------
/data/images/patch0.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LobellLab/weakly_supervised/d8363b50e2daa2b741985b22597529b4c489c39c/data/images/patch0.npy
--------------------------------------------------------------------------------
/data/images/patch0.tfrecords:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LobellLab/weakly_supervised/d8363b50e2daa2b741985b22597529b4c489c39c/data/images/patch0.tfrecords
--------------------------------------------------------------------------------
/data/images/patch1.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LobellLab/weakly_supervised/d8363b50e2daa2b741985b22597529b4c489c39c/data/images/patch1.npy
--------------------------------------------------------------------------------
/data/images/patch1.tfrecords:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LobellLab/weakly_supervised/d8363b50e2daa2b741985b22597529b4c489c39c/data/images/patch1.tfrecords
--------------------------------------------------------------------------------
/data/images/patch10.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LobellLab/weakly_supervised/d8363b50e2daa2b741985b22597529b4c489c39c/data/images/patch10.npy
--------------------------------------------------------------------------------
/data/images/patch10.tfrecords:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LobellLab/weakly_supervised/d8363b50e2daa2b741985b22597529b4c489c39c/data/images/patch10.tfrecords
--------------------------------------------------------------------------------
/data/images/patch11.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LobellLab/weakly_supervised/d8363b50e2daa2b741985b22597529b4c489c39c/data/images/patch11.npy
--------------------------------------------------------------------------------
/data/images/patch11.tfrecords:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LobellLab/weakly_supervised/d8363b50e2daa2b741985b22597529b4c489c39c/data/images/patch11.tfrecords
--------------------------------------------------------------------------------
/data/images/patch12.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LobellLab/weakly_supervised/d8363b50e2daa2b741985b22597529b4c489c39c/data/images/patch12.npy
--------------------------------------------------------------------------------
/data/images/patch12.tfrecords:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LobellLab/weakly_supervised/d8363b50e2daa2b741985b22597529b4c489c39c/data/images/patch12.tfrecords
--------------------------------------------------------------------------------
/data/images/patch13.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LobellLab/weakly_supervised/d8363b50e2daa2b741985b22597529b4c489c39c/data/images/patch13.npy
--------------------------------------------------------------------------------
/data/images/patch13.tfrecords:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LobellLab/weakly_supervised/d8363b50e2daa2b741985b22597529b4c489c39c/data/images/patch13.tfrecords
--------------------------------------------------------------------------------
/data/images/patch14.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LobellLab/weakly_supervised/d8363b50e2daa2b741985b22597529b4c489c39c/data/images/patch14.npy
--------------------------------------------------------------------------------
/data/images/patch14.tfrecords:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LobellLab/weakly_supervised/d8363b50e2daa2b741985b22597529b4c489c39c/data/images/patch14.tfrecords
--------------------------------------------------------------------------------
/data/images/patch15.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LobellLab/weakly_supervised/d8363b50e2daa2b741985b22597529b4c489c39c/data/images/patch15.npy
--------------------------------------------------------------------------------
/data/images/patch15.tfrecords:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LobellLab/weakly_supervised/d8363b50e2daa2b741985b22597529b4c489c39c/data/images/patch15.tfrecords
--------------------------------------------------------------------------------
/data/images/patch16.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LobellLab/weakly_supervised/d8363b50e2daa2b741985b22597529b4c489c39c/data/images/patch16.npy
--------------------------------------------------------------------------------
/data/images/patch16.tfrecords:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LobellLab/weakly_supervised/d8363b50e2daa2b741985b22597529b4c489c39c/data/images/patch16.tfrecords
--------------------------------------------------------------------------------
/data/images/patch17.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LobellLab/weakly_supervised/d8363b50e2daa2b741985b22597529b4c489c39c/data/images/patch17.npy
--------------------------------------------------------------------------------
/data/images/patch17.tfrecords:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LobellLab/weakly_supervised/d8363b50e2daa2b741985b22597529b4c489c39c/data/images/patch17.tfrecords
--------------------------------------------------------------------------------
/data/images/patch18.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LobellLab/weakly_supervised/d8363b50e2daa2b741985b22597529b4c489c39c/data/images/patch18.npy
--------------------------------------------------------------------------------
/data/images/patch18.tfrecords:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LobellLab/weakly_supervised/d8363b50e2daa2b741985b22597529b4c489c39c/data/images/patch18.tfrecords
--------------------------------------------------------------------------------
/data/images/patch19.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LobellLab/weakly_supervised/d8363b50e2daa2b741985b22597529b4c489c39c/data/images/patch19.npy
--------------------------------------------------------------------------------
/data/images/patch19.tfrecords:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LobellLab/weakly_supervised/d8363b50e2daa2b741985b22597529b4c489c39c/data/images/patch19.tfrecords
--------------------------------------------------------------------------------
/data/images/patch2.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LobellLab/weakly_supervised/d8363b50e2daa2b741985b22597529b4c489c39c/data/images/patch2.npy
--------------------------------------------------------------------------------
/data/images/patch2.tfrecords:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LobellLab/weakly_supervised/d8363b50e2daa2b741985b22597529b4c489c39c/data/images/patch2.tfrecords
--------------------------------------------------------------------------------
/data/images/patch3.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LobellLab/weakly_supervised/d8363b50e2daa2b741985b22597529b4c489c39c/data/images/patch3.npy
--------------------------------------------------------------------------------
/data/images/patch3.tfrecords:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LobellLab/weakly_supervised/d8363b50e2daa2b741985b22597529b4c489c39c/data/images/patch3.tfrecords
--------------------------------------------------------------------------------
/data/images/patch4.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LobellLab/weakly_supervised/d8363b50e2daa2b741985b22597529b4c489c39c/data/images/patch4.npy
--------------------------------------------------------------------------------
/data/images/patch4.tfrecords:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LobellLab/weakly_supervised/d8363b50e2daa2b741985b22597529b4c489c39c/data/images/patch4.tfrecords
--------------------------------------------------------------------------------
/data/images/patch5.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LobellLab/weakly_supervised/d8363b50e2daa2b741985b22597529b4c489c39c/data/images/patch5.npy
--------------------------------------------------------------------------------
/data/images/patch5.tfrecords:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LobellLab/weakly_supervised/d8363b50e2daa2b741985b22597529b4c489c39c/data/images/patch5.tfrecords
--------------------------------------------------------------------------------
/data/images/patch6.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LobellLab/weakly_supervised/d8363b50e2daa2b741985b22597529b4c489c39c/data/images/patch6.npy
--------------------------------------------------------------------------------
/data/images/patch6.tfrecords:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LobellLab/weakly_supervised/d8363b50e2daa2b741985b22597529b4c489c39c/data/images/patch6.tfrecords
--------------------------------------------------------------------------------
/data/images/patch7.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LobellLab/weakly_supervised/d8363b50e2daa2b741985b22597529b4c489c39c/data/images/patch7.npy
--------------------------------------------------------------------------------
/data/images/patch7.tfrecords:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LobellLab/weakly_supervised/d8363b50e2daa2b741985b22597529b4c489c39c/data/images/patch7.tfrecords
--------------------------------------------------------------------------------
/data/images/patch8.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LobellLab/weakly_supervised/d8363b50e2daa2b741985b22597529b4c489c39c/data/images/patch8.npy
--------------------------------------------------------------------------------
/data/images/patch8.tfrecords:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LobellLab/weakly_supervised/d8363b50e2daa2b741985b22597529b4c489c39c/data/images/patch8.tfrecords
--------------------------------------------------------------------------------
/data/images/patch9.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LobellLab/weakly_supervised/d8363b50e2daa2b741985b22597529b4c489c39c/data/images/patch9.npy
--------------------------------------------------------------------------------
/data/images/patch9.tfrecords:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LobellLab/weakly_supervised/d8363b50e2daa2b741985b22597529b4c489c39c/data/images/patch9.tfrecords
--------------------------------------------------------------------------------
/data/masks/mask_patch0.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LobellLab/weakly_supervised/d8363b50e2daa2b741985b22597529b4c489c39c/data/masks/mask_patch0.npy
--------------------------------------------------------------------------------
/data/masks/mask_patch1.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LobellLab/weakly_supervised/d8363b50e2daa2b741985b22597529b4c489c39c/data/masks/mask_patch1.npy
--------------------------------------------------------------------------------
/data/masks/mask_patch10.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LobellLab/weakly_supervised/d8363b50e2daa2b741985b22597529b4c489c39c/data/masks/mask_patch10.npy
--------------------------------------------------------------------------------
/data/masks/mask_patch11.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LobellLab/weakly_supervised/d8363b50e2daa2b741985b22597529b4c489c39c/data/masks/mask_patch11.npy
--------------------------------------------------------------------------------
/data/masks/mask_patch12.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LobellLab/weakly_supervised/d8363b50e2daa2b741985b22597529b4c489c39c/data/masks/mask_patch12.npy
--------------------------------------------------------------------------------
/data/masks/mask_patch13.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LobellLab/weakly_supervised/d8363b50e2daa2b741985b22597529b4c489c39c/data/masks/mask_patch13.npy
--------------------------------------------------------------------------------
/data/masks/mask_patch14.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LobellLab/weakly_supervised/d8363b50e2daa2b741985b22597529b4c489c39c/data/masks/mask_patch14.npy
--------------------------------------------------------------------------------
/data/masks/mask_patch15.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LobellLab/weakly_supervised/d8363b50e2daa2b741985b22597529b4c489c39c/data/masks/mask_patch15.npy
--------------------------------------------------------------------------------
/data/masks/mask_patch16.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LobellLab/weakly_supervised/d8363b50e2daa2b741985b22597529b4c489c39c/data/masks/mask_patch16.npy
--------------------------------------------------------------------------------
/data/masks/mask_patch17.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LobellLab/weakly_supervised/d8363b50e2daa2b741985b22597529b4c489c39c/data/masks/mask_patch17.npy
--------------------------------------------------------------------------------
/data/masks/mask_patch18.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LobellLab/weakly_supervised/d8363b50e2daa2b741985b22597529b4c489c39c/data/masks/mask_patch18.npy
--------------------------------------------------------------------------------
/data/masks/mask_patch19.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LobellLab/weakly_supervised/d8363b50e2daa2b741985b22597529b4c489c39c/data/masks/mask_patch19.npy
--------------------------------------------------------------------------------
/data/masks/mask_patch2.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LobellLab/weakly_supervised/d8363b50e2daa2b741985b22597529b4c489c39c/data/masks/mask_patch2.npy
--------------------------------------------------------------------------------
/data/masks/mask_patch3.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LobellLab/weakly_supervised/d8363b50e2daa2b741985b22597529b4c489c39c/data/masks/mask_patch3.npy
--------------------------------------------------------------------------------
/data/masks/mask_patch4.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LobellLab/weakly_supervised/d8363b50e2daa2b741985b22597529b4c489c39c/data/masks/mask_patch4.npy
--------------------------------------------------------------------------------
/data/masks/mask_patch5.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LobellLab/weakly_supervised/d8363b50e2daa2b741985b22597529b4c489c39c/data/masks/mask_patch5.npy
--------------------------------------------------------------------------------
/data/masks/mask_patch6.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LobellLab/weakly_supervised/d8363b50e2daa2b741985b22597529b4c489c39c/data/masks/mask_patch6.npy
--------------------------------------------------------------------------------
/data/masks/mask_patch7.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LobellLab/weakly_supervised/d8363b50e2daa2b741985b22597529b4c489c39c/data/masks/mask_patch7.npy
--------------------------------------------------------------------------------
/data/masks/mask_patch8.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LobellLab/weakly_supervised/d8363b50e2daa2b741985b22597529b4c489c39c/data/masks/mask_patch8.npy
--------------------------------------------------------------------------------
/data/masks/mask_patch9.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LobellLab/weakly_supervised/d8363b50e2daa2b741985b22597529b4c489c39c/data/masks/mask_patch9.npy
--------------------------------------------------------------------------------
/data/splits/test_files_tfrecords:
--------------------------------------------------------------------------------
1 | patch15.tfrecords
2 | patch16.tfrecords
3 | patch17.tfrecords
4 | patch18.tfrecords
5 | patch19.tfrecords
6 |
--------------------------------------------------------------------------------
/data/splits/train_files_numpy:
--------------------------------------------------------------------------------
1 | patch0.npy
2 | patch1.npy
3 | patch2.npy
4 | patch3.npy
5 | patch4.npy
6 | patch5.npy
7 | patch6.npy
8 | patch7.npy
9 | patch8.npy
10 | patch9.npy
11 |
--------------------------------------------------------------------------------
/data/splits/train_files_tfrecords:
--------------------------------------------------------------------------------
1 | patch0.tfrecords
2 | patch1.tfrecords
3 | patch2.tfrecords
4 | patch3.tfrecords
5 | patch4.tfrecords
6 | patch5.tfrecords
7 | patch6.tfrecords
8 | patch7.tfrecords
9 | patch8.tfrecords
10 | patch9.tfrecords
11 |
--------------------------------------------------------------------------------
/data/splits/val_files_numpy:
--------------------------------------------------------------------------------
1 | patch10.npy
2 | patch11.npy
3 | patch12.npy
4 | patch13.npy
5 | patch14.npy
6 | patch15.npy
7 | patch16.npy
8 | patch17.npy
9 | patch18.npy
10 | patch19.npy
11 |
--------------------------------------------------------------------------------
/data/splits/val_files_tfrecords:
--------------------------------------------------------------------------------
1 | patch10.tfrecords
2 | patch11.tfrecords
3 | patch12.tfrecords
4 | patch13.tfrecords
5 | patch14.tfrecords
6 |
--------------------------------------------------------------------------------
/graphical_abstract.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LobellLab/weakly_supervised/d8363b50e2daa2b741985b22597529b4c489c39c/graphical_abstract.png
--------------------------------------------------------------------------------
/image_labels/experiments/params.json:
--------------------------------------------------------------------------------
1 | {
2 | "data_folder": "../data/images",
3 | "train_file": "splits/train_files_tfrecords",
4 | "eval_file": "splits/val_files_tfrecords",
5 | "test_file": "splits/test_files_tfrecords",
6 | "data_type": "landsat",
7 | "task_type": "classification",
8 | "model": "UNet",
9 | "unet_type": "gap_sigmoid",
10 |
11 | "num_layers": 4,
12 | "starting_filters": 32,
13 | "img_size": 50,
14 | "CAM_size": 48,
15 | "has_class_labels": true,
16 | "has_seg_labels": true,
17 | "class_1_segmentation_proportion": 0.424,
18 | "num_channels": 7,
19 | "use_random_flip_and_rotate": true,
20 | "num_labels": 1,
21 | "fixed_split": true,
22 | "get_label_percent": false,
23 | "get_test_metrics": false,
24 |
25 | "threshold_sample_steps": 1,
26 | "save_summary_steps": 100,
27 | "save_cam": false,
28 | "save_CAM_train": false,
29 | "save_CAM_eval": false,
30 | "get_test_metrics": false,
31 | "get_saliency": false,
32 | "save_best_weights": false,
33 | "save_last_weights": false,
34 | "fix_unet_weights": false,
35 | "restore_from_different_model": false,
36 |
37 | "learning_rate": 1e-3,
38 | "use_batch_norm": true,
39 | "batch_size": 32,
40 | "num_epochs": 200,
41 | "bn_momentum": 0.9,
42 | "l2_lambda": 0.01,
43 | "num_parallel_calls": 4
44 | }
45 |
--------------------------------------------------------------------------------
/image_labels/src/UNet.py:
--------------------------------------------------------------------------------
1 | import tensorflow as tf
2 |
3 |
4 | def conv_conv_pool(X, f, stage, params, is_training=False, pool=True):
5 |
6 | conv_name_base = 'down' + str(stage)
7 | bn_name_base = 'bn' + str(stage)
8 |
9 | bn_momentum = params.bn_momentum
10 | l2_lambda = params.l2_lambda
11 |
12 | F1, F2 = f
13 |
14 | X = tf.layers.conv2d(
15 | inputs=X,
16 | filters=F1,
17 | kernel_size=(3,3),
18 | strides=(1,1),
19 | padding='same',
20 | kernel_initializer=tf.contrib.layers.xavier_initializer(),
21 | kernel_regularizer=tf.contrib.layers.l2_regularizer(l2_lambda),
22 | name=conv_name_base + 'a')
23 | X = tf.layers.batch_normalization(
24 | inputs=X,
25 | axis=3,
26 | momentum=bn_momentum,
27 | training=is_training,
28 | name=bn_name_base + 'a')
29 | X = tf.nn.relu(X)
30 |
31 | X = tf.layers.conv2d(
32 | inputs=X,
33 | filters=F2,
34 | kernel_size=(3,3),
35 | strides=(1,1),
36 | padding='same',
37 | kernel_initializer=tf.contrib.layers.xavier_initializer(),
38 | kernel_regularizer=tf.contrib.layers.l2_regularizer(l2_lambda),
39 | name=conv_name_base + 'b')
40 | X = tf.layers.batch_normalization(
41 | inputs=X,
42 | axis=3,
43 | momentum=bn_momentum,
44 | training=is_training,
45 | name=bn_name_base + 'b')
46 | X = tf.nn.relu(X)
47 |
48 | if not pool:
49 | return X
50 |
51 | pool = tf.layers.max_pooling2d(
52 | inputs=X,
53 | pool_size=(2,2),
54 | strides=(2,2),
55 | padding='valid')
56 |
57 | return X, pool
58 |
59 |
60 | def upconv_concat(X, X_prev, f, stage, params):
61 |
62 | conv_name_base = 'up' + str(stage)
63 | l2_lambda = params.l2_lambda
64 |
65 | upconv = tf.layers.conv2d_transpose(
66 | inputs=X,
67 | filters=f,
68 | kernel_size=2,
69 | strides=2,
70 | padding='valid',
71 | kernel_initializer=tf.contrib.layers.xavier_initializer(),
72 | kernel_regularizer=tf.contrib.layers.l2_regularizer(l2_lambda),
73 | name=conv_name_base)
74 |
75 | prev_shape = tf.shape(X_prev)
76 | curr_shape = tf.shape(upconv)
77 | offsets = [0, (prev_shape[1] - curr_shape[1]) // 2, (prev_shape[2] - curr_shape[2]) // 2, 0]
78 | new_shape = [-1, curr_shape[1], curr_shape[2], -1]
79 | X_prev_cropped = tf.reshape(tf.slice(X_prev, offsets, new_shape), curr_shape)
80 |
81 | return tf.concat([upconv, X_prev_cropped], axis=-1)
82 |
83 |
84 | def UNet(is_training, inputs, params):
85 |
86 | images = inputs['images']
87 | num_layers = params.num_layers
88 | starting_filters = params.starting_filters
89 |
90 | # U-Net convolutional layers
91 | convs = []
92 | conv, pool = conv_conv_pool(images, [starting_filters, starting_filters], 1, params, is_training)
93 | convs.append(conv)
94 |
95 | # add more convolutional layers in a loop
96 | for l in range(2, num_layers):
97 | num_filters = starting_filters * 2**(l-1)
98 | conv, pool = conv_conv_pool(pool, [num_filters, num_filters], l, params, is_training)
99 | convs.append(conv)
100 |
101 | if num_layers == 2:
102 | num_filters = starting_filters
103 |
104 | # last convolutional layer (no pool)
105 | conv = conv_conv_pool(pool, [num_filters * 2, num_filters * 2], num_layers, params, is_training, pool=False)
106 |
107 | # convolution transpose layers (deconvolutional layers)
108 | for l in range(1, num_layers):
109 | num_filters = starting_filters * 2**(num_layers - l - 1)
110 | up = upconv_concat(conv, convs.pop(), num_filters, l, params)
111 | conv = conv_conv_pool(up, [num_filters, num_filters], l+num_layers, params, is_training, pool=False)
112 |
113 | CAM_input = tf.identity(conv)
114 |
115 | return CAM_input
116 |
117 |
118 |
119 | def build_UNet(mode, inputs, params, reuse=False):
120 | """
121 | Function defines the graph operations.
122 | Input:
123 | mode: (string) can be 'train' or 'eval'
124 | inputs: (dict) contains the inputs of the graph (features, labels, etc.)
125 | params: (params) contains hyperparameters of the model
126 | reuse: (bool) whether to reuse the weights
127 | Returns:
128 | model_spec: (dict) contains the graph operations for training/evaluation
129 | """
130 | is_training = (mode == 'train')
131 | labels = inputs.get('labels')
132 | seg_labels = inputs.get('seg_labels')
133 |
134 | with tf.variable_scope('model', reuse=reuse):
135 | CAM_input = UNet(is_training, inputs, params)
136 |
137 | if params.unet_type == 'sigmoid_gap':
138 | cam_unnorm = tf.layers.conv2d(
139 | inputs=CAM_input,
140 | filters=1,
141 | kernel_size=(1,1),
142 | strides=(1,1),
143 | padding='valid',
144 | kernel_initializer=tf.contrib.layers.xavier_initializer(),
145 | kernel_regularizer=tf.contrib.layers.l2_regularizer(params.l2_lambda),
146 | name='final_conv')
147 |
148 | cam_unnorm = tf.squeeze(cam_unnorm, 3) # get rid of last dim
149 | cam = tf.sigmoid(cam_unnorm)
150 | cam_predictions = tf.round(cam)
151 |
152 | elif params.unet_type == 'gap_sigmoid':
153 | # out = tf.sigmoid(CAM_input)
154 | cam = tf.reduce_mean(CAM_input, axis=[1,2])
155 | cam = tf.contrib.layers.flatten(cam)
156 | logits = tf.layers.dense(inputs=cam, units=params.num_labels, name='dense')
157 |
158 |
159 | if params.task_type == 'classification':
160 |
161 | if params.unet_type == 'sigmoid_gap':
162 | expits = tf.reduce_mean(cam, axis=[1, 2])
163 | expits = tf.expand_dims(expits, -1)
164 |
165 | elif params.unet_type == 'gap_sigmoid':
166 | expits = tf.sigmoid(logits)
167 |
168 | predictions = tf.round(expits)
169 |
170 | else:
171 | predictions = cam_predictions
172 |
173 | # a classification task may also have segmentation labels
174 | if seg_labels is not None and params.unet_type == 'sigmoid_gap':
175 | # sizes might not match, cut the segmentation label down to size
176 | extra_h = tf.shape(seg_labels)[1] - tf.shape(cam)[1]
177 | extra_h_before = extra_h // 2
178 | extra_h_after = tf.shape(seg_labels)[1] - (extra_h - extra_h_before)
179 | extra_w = tf.shape(seg_labels)[2] - tf.shape(cam)[2]
180 | extra_w_before = extra_w // 2
181 | extra_w_after = tf.shape(seg_labels)[2] - (extra_w - extra_w_before)
182 | seg_labels_cut = seg_labels[:, extra_h_before:extra_h_after, extra_w_before:extra_w_after]
183 | # cam or cam unnorm
184 | ce = tf.nn.sigmoid_cross_entropy_with_logits(labels=seg_labels_cut, logits=cam_unnorm)
185 | weighted_ce = (seg_labels_cut * ce) / params.class_1_segmentation_proportion + (1-seg_labels_cut) * ce
186 | segmentation_loss = tf.reduce_mean(weighted_ce)
187 |
188 | # Define loss and task accuracy
189 | if params.task_type == 'classification':
190 | train_labels = labels
191 | ce = -train_labels * tf.log(expits + 1.0e-8) - (1 - train_labels) * tf.log(1 - expits + 1.0e-8)
192 | weighted_ce = (train_labels * ce) / params.class_1_segmentation_proportion + \
193 | (1 - train_labels) * ce
194 | loss = tf.reduce_mean(weighted_ce)
195 | logits = -tf.log((1.0 / expits) - 1 + 1e-8)
196 | else:
197 | train_labels = seg_labels_cut
198 | loss = segmentation_loss
199 |
200 | l2_loss = tf.losses.get_regularization_loss()
201 | loss += l2_loss
202 | accuracy = tf.reduce_mean(tf.cast(tf.equal(tf.round(train_labels), predictions), tf.float32))
203 |
204 | # Define training step that minimizes loss with Adam optimizer
205 | if is_training:
206 | optimizer = tf.train.AdamOptimizer(params.learning_rate)
207 | global_step = tf.train.get_or_create_global_step()
208 | if params.use_batch_norm:
209 | # Add a dependency to update the moving mean and variance for batch normalization
210 | update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
211 | with tf.control_dependencies(update_ops):
212 | train_op = optimizer.minimize(loss, global_step=global_step)
213 | else:
214 | train_op = optimizer.minimize(loss, global_step=global_step)
215 |
216 | # METRICS AND SUMMARIES
217 | with tf.variable_scope('metrics'):
218 | metrics = {
219 | 'loss': tf.metrics.mean(loss),
220 | 'accuracy': tf.metrics.accuracy(labels=tf.round(train_labels), predictions=predictions)
221 | }
222 | if params.task_type == 'classification' and seg_labels is not None and params.unet_type == 'sigmoid_gap':
223 | metrics.update({'segmentation_loss': tf.metrics.mean(segmentation_loss),
224 | 'segmentation_acc': tf.metrics.accuracy(labels=seg_labels_cut, predictions=cam_predictions)
225 | })
226 | # Group the update ops for the tf.metrics
227 | # print(metrics.values())
228 | update_metrics_op = tf.group(*[op for _, op in metrics.values()])
229 | metric_variables = tf.get_collection(tf.GraphKeys.LOCAL_VARIABLES, scope='metrics')
230 | metrics_init_op = tf.variables_initializer(metric_variables)
231 |
232 | summaries = []
233 | summaries.append(tf.summary.scalar('loss', loss))
234 | summaries.append(tf.summary.scalar('accuracy', accuracy))
235 | summaries.append(tf.summary.scalar('loss_MA', metrics['loss'][0]))
236 | summaries.append(tf.summary.scalar('accuracy_MA', metrics['accuracy'][0]))
237 | if params.task_type == 'classification' and seg_labels is not None and params.unet_type == 'sigmoid_gap':
238 | summaries.append(tf.summary.scalar('segmentation_loss', segmentation_loss))
239 | summaries.append(tf.summary.scalar('segmentation_loss_MA', metrics['segmentation_loss'][0]))
240 | summaries.append(tf.summary.scalar('segmentation_acc_MA', metrics['segmentation_acc'][0]))
241 |
242 | if params.dict.get('rgb_image') in {None, False}:
243 | image_rgb = 255.0 * 5 * inputs['images'][:,:,:,1:4][:,:,:,::-1]
244 | summaries.append(tf.summary.image('image', image_rgb))
245 | else:
246 | image_rgb = inputs['images']
247 |
248 | if params.unet_type == 'sigmoid_gap':
249 | paddings_h = tf.shape(image_rgb)[1] - tf.shape(cam)[1]
250 | paddings_h_before = paddings_h // 2
251 | paddings_h_after = paddings_h - paddings_h_before
252 | paddings_w = tf.shape(image_rgb)[2] - tf.shape(cam)[2]
253 | paddings_w_before = paddings_w // 2
254 | paddings_w_after = paddings_w - paddings_w_before
255 | paddings = tf.convert_to_tensor([[0, 0], [paddings_h_before, paddings_h_after], [paddings_w_before, paddings_w_after]])
256 | cam_padded = tf.pad(cam, paddings, 'CONSTANT', constant_values=0.5)
257 | cam_rgb = 255.0 * tf.tile(tf.expand_dims(cam_padded,-1), [1,1,1,3])
258 |
259 | if seg_labels is not None:
260 | # tile the segmentation label to RGB and concatenate with the image
261 | label_rgb = 255.0 * tf.tile(tf.expand_dims(seg_labels,-1), [1,1,1,3])
262 | concated = tf.concat([image_rgb, label_rgb], axis=2)
263 | else:
264 | concated = image_rgb
265 |
266 | hard_cam_padded = tf.pad(cam_predictions, paddings, 'CONSTANT', constant_values=0.5)
267 | hard_cam_rgb = 255.0 * tf.tile(tf.expand_dims(hard_cam_padded, -1), [1, 1, 1, 3])
268 | concated = tf.concat([concated, cam_rgb, hard_cam_rgb], axis=2)
269 | summaries.append(tf.summary.image('concatenated_images', concated, 50))
270 |
271 | model_spec = inputs
272 | model_spec['variable_init_op'] = tf.global_variables_initializer()
273 | model_spec['predictions'] = predictions
274 | model_spec['loss'] = loss
275 | model_spec['accuracy'] = accuracy
276 | model_spec['cam'] = cam
277 | model_spec['metrics_init_op'] = metrics_init_op
278 | model_spec['metrics'] = metrics
279 | model_spec['update_metrics'] = update_metrics_op
280 | model_spec['summary_op'] = tf.summary.merge(summaries)
281 |
282 | if params.unet_type == 'gap_sigmoid':
283 | model_spec['CAM_input'] = CAM_input
284 | model_spec['dense_kernel'] = [v for v in tf.trainable_variables() if "dense/kernel" in v.name][0]
285 | model_spec['dense_bias'] = [v for v in tf.trainable_variables() if "dense/bias" in v.name][0]
286 |
287 | if is_training:
288 | model_spec['train_op'] = train_op
289 |
290 | return model_spec
291 |
--------------------------------------------------------------------------------
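Note: for the `gap_sigmoid` variant, the exported `CAM_input`, `dense_kernel`, and `dense_bias` are combined downstream (see `get_best_threshold` in `train_UNet.py`) into a class activation map. A self-contained NumPy sketch of that computation, with illustrative shapes:

```python
import numpy as np

batch, H, W, filters, num_labels = 2, 48, 48, 32, 1
CAM_input_val = np.random.rand(batch, H, W, filters).astype(np.float32)
dense_kernel_val = np.random.rand(filters, num_labels).astype(np.float32)
dense_bias_val = np.random.rand(num_labels).astype(np.float32)

# Project each pixel's feature vector through the dense layer, then sum over labels.
CAM = np.sum(np.matmul(CAM_input_val, dense_kernel_val) + dense_bias_val, axis=-1)  # (batch, H, W)
seg_pred = CAM > 0.0  # pixels above the searched threshold become class 1
```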
/image_labels/src/run_UCAM.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import logging
3 | import os
4 | import datetime
5 | import uuid
6 | from pathlib import Path
7 |
8 | from tfrecords import tfrecord_iterator
9 | from UNet import build_UNet
10 | from utils import Params, set_logger
11 | from train_UNet import train_and_evaluate
12 |
13 | parser = argparse.ArgumentParser()
14 | parser.add_argument('--model_dir', default='../experiments',
15 | help="Experiment directory containing params.json")
16 | parser.add_argument('--data_dir', default='../data',
17 | help="Directory containing the dataset")
18 | parser.add_argument('--restore_from', default=None,
19 | help="Optional directory or file containing weights to reload before training")
20 |
21 | if __name__ == '__main__':
22 | # Uncomment below for reproducible experiments (also add `import tensorflow as tf` above)
23 | # tf.set_random_seed(230)
24 |
25 | # Load parameters from json file
26 | args = parser.parse_args()
27 | json_path = os.path.join(args.model_dir, 'params.json')
28 | assert os.path.isfile(json_path), "No params.json configuration file found at {}".format(json_path)
29 | params = Params(json_path)
30 |
31 | # Set up logger
32 | ts_uuid = f'{datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S%z")}_{uuid.uuid4().hex[:6]}'
33 | log_dir = Path(args.model_dir) / 'logs'
34 | log_dir.mkdir(exist_ok=True)
35 | set_logger(str(log_dir / f'{ts_uuid}_train.log'))
36 | logging.info("Creating the datasets from TFRecords...")
37 | data_dir = args.data_dir
38 | data_folder = params.data_folder
39 | train_file = os.path.join(data_dir, params.train_file)
40 | val_file = os.path.join(data_dir, params.eval_file)
41 | test_file = os.path.join(data_dir, params.test_file)
42 |
43 | train_filenames = []
44 | val_filenames = []
45 | test_filenames = []
46 | with open(train_file) as f:
47 | for l in f:
48 | train_filenames.append(os.path.join(data_dir, data_folder, l.strip()))
49 | with open(val_file) as f:
50 | for l in f:
51 | val_filenames.append(os.path.join(data_dir, data_folder, l.strip()))
52 | with open(test_file) as f:
53 | for l in f:
54 | test_filenames.append(os.path.join(data_dir, data_folder, l.strip()))
55 | params.train_size = len(train_filenames)
56 | params.eval_size = len(val_filenames)
57 | params.test_size = len(test_filenames)
58 |
59 | train_dataset = tfrecord_iterator(True, train_filenames, params)
60 | eval_dataset = tfrecord_iterator(False, val_filenames, params)
61 | test_dataset = tfrecord_iterator(False, test_filenames, params)
62 |
63 | logging.info("Creating the model...")
64 | train_model = build_UNet('train', train_dataset, params)
65 | eval_model = build_UNet('eval', eval_dataset, params, reuse=True)
66 | test_model = build_UNet('eval', test_dataset, params, reuse=True)
67 |
68 | logging.info("Starting training for {} epochs".format(params.num_epochs))
69 | train_and_evaluate(train_model, eval_model, test_model, args.model_dir, params, args.restore_from, ts_uuid)
70 |
71 |
72 |
73 |
--------------------------------------------------------------------------------
/image_labels/src/tfrecords.py:
--------------------------------------------------------------------------------
1 | import tensorflow as tf
2 | import numpy as np
3 |
4 |
5 | def parse_function(example, num_channels, img_size,
6 | has_class_labels, has_seg_labels):
7 | """
8 | Parses a serialized example into an image tensor plus, when present,
9 | its image-level label and segmentation label map.
10 | """
11 | feature_dict = {
12 | 'image': tf.FixedLenFeature((img_size, img_size, num_channels), tf.float32),
13 | }
14 | if has_class_labels:
15 | feature_dict.update({'label': tf.FixedLenFeature((1,), tf.float32)})
16 | if has_seg_labels:
17 | feature_dict.update(
18 | {'seg_labels': tf.FixedLenFeature((img_size, img_size), tf.float32)})
19 | parsed_features = tf.parse_single_example(example, features=feature_dict)
20 |
21 | ret = [parsed_features['image']]
22 | if has_class_labels:
23 | ret.append(parsed_features['label'])
24 | else:
25 | # TF doesn't accept None, we filter this out later anyway
26 | ret.append(0)
27 |
28 | if has_seg_labels:
29 | ret.append(parsed_features['seg_labels'])
30 | else:
31 | ret.append(0)
32 | return tuple(ret)
33 |
34 |
35 | def augment_data(image, label, seg_label, perform_random_flip_and_rotate,
36 | num_channels, has_seg_labels):
37 | """
38 | Image augmentation for training. Applies the following operations:
39 | - Horizontally flip the image with probability 0.5
40 | - Vertically flip the image with probability 0.5
41 | - Rotate the image by a random multiple of 90 degrees
42 | """
43 | if perform_random_flip_and_rotate:
44 | if has_seg_labels:
45 | image = tf.concat([image, tf.expand_dims(seg_label, -1)], 2)
46 | image = tf.image.random_flip_left_right(image)
47 | image = tf.image.random_flip_up_down(image)
48 | rotate_angle = tf.random_shuffle([0.0, 90.0, 180.0, 270.0])[0]
49 | image = tf.contrib.image.rotate(
50 | image, rotate_angle * np.pi / 180.0, interpolation='BILINEAR')
51 | if has_seg_labels:
52 | seg_label = image[:, :, -1]
53 |
54 | image = image[:,:,:num_channels]
55 |
56 | return image, label, seg_label
57 |
58 |
59 | def tfrecord_iterator(is_training, file_names, params):
60 |
61 | # Create a Dataset serving batches of images and labels
62 | # We don't repeat for multiple epochs because we always train and evaluate for one epoch
63 | def parse_fn(p):
64 | return parse_function(
65 | p, params.num_channels, params.img_size, params.has_class_labels,
66 | params.has_seg_labels)
67 |
68 | def train_fn(f, l, sl):
69 | return augment_data(
70 | f, l, sl, params.use_random_flip_and_rotate, params.num_channels,
71 | params.has_seg_labels)
72 |
73 | if is_training:
74 | dataset = (tf.data.TFRecordDataset(file_names)
75 | .map(parse_fn, num_parallel_calls=params.num_parallel_calls)
76 | .map(train_fn, num_parallel_calls=params.num_parallel_calls)
77 | .batch(params.batch_size)
78 | .prefetch(1))
79 | else:
80 | dataset = (tf.data.TFRecordDataset(file_names)
81 | .map(parse_fn)
82 | .batch(params.batch_size)
83 | .prefetch(1))
84 |
85 | # Create reinitializable iterator from dataset
86 | iterator = dataset.make_initializable_iterator()
87 | images, labels, seg_labels = iterator.get_next()
88 | iterator_init_op = iterator.initializer
89 |
90 | inputs = {'images': images,
91 | 'iterator_init_op': iterator_init_op}
92 | if params.has_class_labels:
93 | inputs.update({'labels': labels})
94 | if params.has_seg_labels:
95 | inputs.update({'seg_labels': seg_labels})
96 |
97 | return inputs
98 |
--------------------------------------------------------------------------------
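Note: records matching the feature spec in `parse_function` above could be produced with a writer along these lines (a minimal TF 1.x sketch, not part of the repo; dtypes and the helper name are assumptions):

```python
import numpy as np
import tensorflow as tf

def write_patch(path, image, label, seg_labels):
    # Hypothetical helper: serialize one patch with the keys parse_function expects.
    feature = {
        'image': tf.train.Feature(
            float_list=tf.train.FloatList(value=image.astype(np.float32).ravel())),
        'label': tf.train.Feature(
            float_list=tf.train.FloatList(value=[float(label)])),
        'seg_labels': tf.train.Feature(
            float_list=tf.train.FloatList(value=seg_labels.astype(np.float32).ravel())),
    }
    example = tf.train.Example(features=tf.train.Features(feature=feature))
    with tf.python_io.TFRecordWriter(path) as writer:
        writer.write(example.SerializeToString())
```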
/image_labels/src/train_UNet.py:
--------------------------------------------------------------------------------
1 | import logging
2 | import os
3 | from pathlib import Path
4 |
5 | import numpy as np
6 | import pandas as pd
7 | from tqdm import trange
8 | import tensorflow as tf
9 |
10 | from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score
11 |
12 | def save_cam(sess, cam, dense_kernel, dense_bias, predictions, model_spec, model_dir, metrics_val, epoch, num_steps):
13 | cam_dir = os.path.join(model_dir, 'eval_cam')
14 | if not os.path.exists(cam_dir):
15 | os.makedirs(cam_dir)
16 | best_acc = 0.0
17 | best_acc_dir = os.path.join(cam_dir, 'best_eval_acc.npy')
18 | if os.path.isfile(best_acc_dir):
19 | best_acc = np.load(best_acc_dir)[0]
20 |
21 | if metrics_val['accuracy'] > best_acc:
22 | print("Saving CAM for best epoch so far...")
23 | np.save(best_acc_dir, np.array([metrics_val['accuracy'], epoch]))
24 | sess.run(model_spec['iterator_init_op'])
25 | for i in range(num_steps):
26 | cam_val, dense_kernel_val, dense_bias_val, predictions_val = sess.run([cam, dense_kernel, dense_bias, predictions])
27 | np.save(os.path.join(cam_dir, 'eval-cam-batch'+str(i)), cam_val)
28 | np.save(os.path.join(cam_dir, 'eval-dense-kernel-batch'+str(i)), dense_kernel_val)
29 | np.save(os.path.join(cam_dir, 'eval-dense-bias-batch'+str(i)), dense_bias_val)
30 | np.save(os.path.join(cam_dir, 'eval-prediction-batch'+str(i)), predictions_val)
31 |
32 |
33 | def get_best_threshold(sess, model_spec, num_steps, params):
34 |
35 | sample_steps = params.threshold_sample_steps
36 | batch_size = params.batch_size
37 | train_size = params.train_size
38 | if train_size > 20000:
39 | CAMs = np.zeros((int(batch_size * (num_steps // sample_steps + 1)), params.CAM_size, params.CAM_size))
40 | labels_vals = np.zeros((int(batch_size * (num_steps // sample_steps + 1)), 1))
41 | else:
42 | CAMs = np.zeros((train_size, params.CAM_size, params.CAM_size))
43 | labels_vals = np.zeros((train_size, 1))
44 | sample_steps = 1
45 |
46 | CAM_input = model_spec['CAM_input']
47 | dense_kernel = model_spec['dense_kernel']
48 | dense_bias = model_spec['dense_bias']
49 |
50 | loss = model_spec['loss']
51 | labels = model_spec['labels']
52 |
53 | sess.run(model_spec['iterator_init_op'])
54 |
55 | logging.info("- Computing optimal threshold on training set...")
56 | t = trange(num_steps)
57 | j = 0
58 |
59 | for i in t:
60 | if i % sample_steps == 0:
61 | CAM_input_val, dense_kernel_val, dense_bias_val, labels_val = sess.run([CAM_input, dense_kernel, dense_bias, labels])
62 | if CAM_input_val.shape[0] != batch_size:
63 | continue
64 |
65 | CAM = np.sum(np.matmul(CAM_input_val, dense_kernel_val) + dense_bias_val, axis=-1)
66 | CAMs[j*batch_size:(j+1)*batch_size,:,:] = np.array(CAM)
67 | labels_vals[j*batch_size:(j+1)*batch_size,:] = np.array(labels_val)
68 |
69 | j = j + 1
70 | print("dense kernel shape:", dense_kernel_val.shape)
71 |
72 | def get_classif_acc(threshold, cam, ground_truth):
73 | segpreds = cam > threshold # all cam values > threshold become 1
74 | preds = np.mean(segpreds, axis=(1,2)) > 0.5 # image-level label becomes 1 if >0.5 pixels are 1
75 | acc = sum(preds == ground_truth) / len(ground_truth)
76 | return acc
77 |
78 | possible_thresholds = list(np.linspace(-10,10,1001))
79 | labels_vals = labels_vals.flatten().astype(bool)
80 | classif_accs = [get_classif_acc(x, CAMs, labels_vals) for x in possible_thresholds]
81 |
82 | return np.array(possible_thresholds)[np.argmax(np.array(classif_accs))], np.max(classif_accs)
83 |
84 |
85 | def get_seg_metrics(sess, model_spec, num_steps, params, threshold):
86 |
87 | CAM_input = model_spec['CAM_input']
88 | dense_kernel = model_spec['dense_kernel']
89 | dense_bias = model_spec['dense_bias']
90 | seg_labels = model_spec['seg_labels']
91 |
92 | acc = 0.0
93 | pre = 0.0
94 | rec = 0.0
95 | f1s = 0.0
96 |
97 | sess.run(model_spec['iterator_init_op'])
98 |
99 | print("- Computing segmentation accuracy...")
100 | t = trange(num_steps)
101 |
102 | for i in t:
103 | CAM_input_val, dense_kernel_val, dense_bias_val, seg_values = sess.run([CAM_input, dense_kernel, dense_bias, seg_labels])
104 | CAM = np.sum(np.matmul(CAM_input_val, dense_kernel_val) + dense_bias_val, axis=-1)
105 | seg_pred = CAM > threshold
106 | offset = (seg_values.shape[1] - params.CAM_size) // 2
107 | seg_true = seg_values[:,offset:params.CAM_size+offset,offset:params.CAM_size+offset]
108 | seg_true = seg_true.astype(int)
109 | acc += accuracy_score(seg_true.flatten(), seg_pred.flatten())
110 | pre += precision_score(seg_true.flatten(), seg_pred.flatten())
111 | rec += recall_score(seg_true.flatten(), seg_pred.flatten())
112 | f1s += f1_score(seg_true.flatten(), seg_pred.flatten())
113 |
114 | return acc / num_steps, pre / num_steps, rec / num_steps, f1s / num_steps
115 |
116 |
117 | def train_sess(sess, model_spec, num_steps, writer, params, model_dir, epoch):
118 |
119 | # Get relevant graph operations or nodes needed for training
120 | loss = model_spec['loss']
121 | train_op = model_spec['train_op']
122 | update_metrics = model_spec['update_metrics']
123 | metrics = model_spec['metrics']
124 | summary_op = model_spec['summary_op']
125 | global_step = tf.train.get_global_step()
126 |
127 | # Load training dataset into pipeline and initialize the metrics local variables
128 | sess.run(model_spec['iterator_init_op'])
129 | sess.run(model_spec['metrics_init_op'])
130 |
131 | # Use tqdm for progress bar
132 | t = trange(num_steps)
133 | for i in t:
134 | handles = [train_op, update_metrics, loss]
135 | # Evaluate summaries every 100 steps
136 | if i % 100 == 0:
137 | _, _, loss_val, summ, global_step_val = sess.run(
138 | handles + [summary_op, global_step])
139 |
140 | # Write training summary to tensorboard
141 | writer.add_summary(summ, global_step_val)
142 | else:
143 | _, _, loss_val = sess.run(handles)
144 |
145 | # Log the loss in the tqdm progress bar
146 | t.set_postfix(loss='{:05.3f}'.format(loss_val))
147 |
148 | metrics_values = {k: v[0] for k, v in metrics.items()}
149 | metrics_val = sess.run(metrics_values)
150 | metrics_string = " ; ".join("{}: {:05.3f}".format(k, v) for k, v in metrics_val.items())
151 | logging.info("- Train metrics: " + metrics_string)
152 |
153 | return metrics_val
154 |
155 |
156 | def evaluate_sess(sess, model_spec, num_steps, writer, params, model_dir, epoch):
157 | update_metrics = model_spec['update_metrics']
158 | metrics = model_spec['metrics']
159 | summary_op = model_spec['summary_op']
160 | global_step = tf.train.get_global_step()
161 | predictions = model_spec['predictions']
162 | if params.model == 'UNet':
163 | CAM_input = model_spec['CAM_input']
164 | dense_kernel = model_spec['dense_kernel']
165 | dense_bias = model_spec['dense_bias']
166 |
167 | # Load evaluation dataset into pipeline and initialize the metrics local variables
168 | sess.run(model_spec['iterator_init_op'])
169 | sess.run(model_spec['metrics_init_op'])
170 |
171 | for i in range(num_steps):
172 | handles = [update_metrics]
173 | if i % 100 == 0:
174 | handles += [summary_op, global_step]
175 | _, summ, global_step_val = sess.run(handles)
176 | writer.add_summary(summ, global_step_val)
177 | else:
178 | # adds results automatically to metrics
179 | sess.run(handles)
180 |
181 | # Get the values of the metrics
182 | metrics_values = {k: v[0] for k, v in metrics.items()}
183 | metrics_val = sess.run(metrics_values)
184 | metrics_string = " ; ".join("{}: {:05.3f}".format(k, v) for k, v in metrics_val.items())
185 | logging.info("- Eval metrics: " + metrics_string)
186 |
187 | if params.model == 'UNet' and params.save_cam and epoch > 5:
188 | save_cam(sess, CAM_input, dense_kernel, dense_bias, predictions, model_spec, model_dir, metrics_val, epoch, num_steps)
189 |
190 | return metrics_val
191 |
192 |
193 | def train_and_evaluate(train_model_spec, eval_model_spec, test_model_spec, model_dir, params, restore_from=None, ts_uuid=''):
194 | model_dir = Path(model_dir)
195 |
196 | train_loss = []
197 | train_acc = []
198 | train_segacc_list = []
199 | train_segpre_list = []
200 | train_segrec_list = []
201 | train_segf1s_list = []
202 |
203 | eval_loss = []
204 | eval_acc = []
205 | eval_segacc_list = []
206 | eval_segpre_list = []
207 | eval_segrec_list = []
208 | eval_segf1s_list = []
209 |
210 | test_loss = []
211 | test_acc = []
212 | test_segacc_list = []
213 | test_segpre_list = []
214 | test_segrec_list = []
215 | test_segf1s_list = []
216 |
217 | saver = tf.train.Saver()
218 |
219 | with tf.Session() as sess:
220 | # Initialize model variables
221 | sess.run(train_model_spec['variable_init_op'])
222 |
223 | # Set up tensorboard files
224 | train_summary_dir = model_dir / f'{ts_uuid}_train_summaries'
225 | eval_summary_dir = model_dir / f'{ts_uuid}_eval_summaries'
226 | test_summary_dir = model_dir / f'{ts_uuid}_test_summaries'
227 | train_summary_dir.mkdir(exist_ok=True)
228 | eval_summary_dir.mkdir(exist_ok=True)
229 | test_summary_dir.mkdir(exist_ok=True)
230 | train_writer = tf.summary.FileWriter(str(train_summary_dir), sess.graph)
231 | eval_writer = tf.summary.FileWriter(str(eval_summary_dir), sess.graph)
232 | test_writer = tf.summary.FileWriter(str(test_summary_dir), sess.graph)
233 |
234 | best_accuracy = 0.0
235 |
236 | for epoch in range(params.num_epochs):
237 | # Run one epoch
238 | logging.info("Epoch {}/{}".format(epoch + 1, params.num_epochs))
239 |
240 | # Compute number of batches in one epoch
241 | num_steps = (params.train_size + params.batch_size - 1) // params.batch_size
242 | train_metrics = train_sess(sess, train_model_spec, num_steps, train_writer, params, model_dir, epoch)
243 |
244 | # if gap_sigmoid model, find best threshold
245 | if params.unet_type == 'gap_sigmoid':
246 | best_thresh, best_acc = get_best_threshold(sess, train_model_spec, num_steps, params)
247 | logging.info('- Best threshold on train samples: {:0.4f}, best accuracy: {:0.4f}'.format(best_thresh, best_acc))
248 |
249 | train_segacc, train_segpre, train_segrec, train_segf1s = get_seg_metrics(sess, train_model_spec, num_steps, params, best_thresh)
250 | train_segacc_list.append(train_segacc)
251 | train_segpre_list.append(train_segpre)
252 | train_segrec_list.append(train_segrec)
253 | train_segf1s_list.append(train_segf1s)
254 | logging.info('- Train segmentation accuracy: {:0.4f}'.format(train_segacc))
255 | else:
256 | train_segacc_list.append(train_metrics['segmentation_acc'])
257 |
258 | # Evaluate for one epoch on validation set
259 | num_steps = (params.eval_size + params.batch_size - 1) // params.batch_size
260 | eval_metrics = evaluate_sess(sess, eval_model_spec, num_steps, eval_writer, params, model_dir, epoch)
261 |
262 | if params.unet_type == 'gap_sigmoid':
263 | eval_segacc, eval_segpre, eval_segrec, eval_segf1s = get_seg_metrics(sess, eval_model_spec, num_steps, params, best_thresh)
264 | eval_segacc_list.append(eval_segacc)
265 | eval_segpre_list.append(eval_segpre)
266 | eval_segrec_list.append(eval_segrec)
267 | eval_segf1s_list.append(eval_segf1s)
268 | logging.info('- Eval segmentation accuracy: {:0.4f}'.format(eval_segacc))
269 | else:
270 |                 eval_segacc_list.append(eval_metrics['segmentation_acc'])
271 |
272 | # Evaluate for one epoch on test set
273 | num_steps = (params.test_size + params.batch_size - 1) // params.batch_size
274 | test_metrics = evaluate_sess(sess, test_model_spec, num_steps, test_writer, params, model_dir, epoch)
275 |
276 | if params.unet_type == 'gap_sigmoid':
277 | test_segacc, test_segpre, test_segrec, test_segf1s = get_seg_metrics(sess, test_model_spec, num_steps, params, best_thresh)
278 | test_segacc_list.append(test_segacc)
279 | test_segpre_list.append(test_segpre)
280 | test_segrec_list.append(test_segrec)
281 | test_segf1s_list.append(test_segf1s)
282 | logging.info('- Test segmentation accuracy: {:0.4f}'.format(test_segacc))
283 | else:
284 |                 test_segacc_list.append(test_metrics['segmentation_acc'])
285 |
286 | train_loss.append(train_metrics['loss'])
287 | train_acc.append(train_metrics['accuracy'])
288 | eval_loss.append(eval_metrics['loss'])
289 | eval_acc.append(eval_metrics['accuracy'])
290 | test_loss.append(test_metrics['loss'])
291 | test_acc.append(test_metrics['accuracy'])
292 |
293 | # Save best model so far based on task accuracy (image classification)
294 |             if eval_metrics['accuracy'] > best_accuracy:
295 |                 logging.info("- New best eval accuracy, saving model...")
296 |                 saver.save(sess, os.path.join(model_dir, "best_model.ckpt"))
297 |                 best_accuracy = eval_metrics['accuracy']
298 |
299 |         # Write metrics to disk for easy analysis. The segmentation P/R/F1 lists
300 |         # are only populated for the gap_sigmoid variant, so drop any column that
301 |         # was left empty to keep all DataFrame columns the same length.
302 |         records = {'train_loss': train_loss, 'eval_loss': eval_loss, 'test_loss': test_loss,
303 |                    'train_accuracy': train_acc, 'eval_accuracy': eval_acc, 'test_accuracy': test_acc,
304 |                    'train_segacc': train_segacc_list, 'eval_segacc': eval_segacc_list, 'test_segacc': test_segacc_list,
305 |                    'train_segpre': train_segpre_list, 'eval_segpre': eval_segpre_list, 'test_segpre': test_segpre_list,
306 |                    'train_segrec': train_segrec_list, 'eval_segrec': eval_segrec_list, 'test_segrec': test_segrec_list,
307 |                    'train_segf1s': train_segf1s_list, 'eval_segf1s': eval_segf1s_list, 'test_segf1s': test_segf1s_list}
308 |         df = pd.DataFrame({k: v for k, v in records.items() if len(v) == len(train_loss)})
309 |         df.to_csv(os.path.join(model_dir, 'metrics.csv'), index=False)
--------------------------------------------------------------------------------
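For reference between files, the `get_seg_metrics` loop above boils down to the following self-contained sketch; the batch size, mask size, and the 0.5 threshold are placeholders standing in for `params.CAM_size` and the threshold tuned on the train set:

```python
# Hypothetical rehearsal of the CAM -> segmentation-metrics step (toy data).
import numpy as np
from sklearn.metrics import accuracy_score, f1_score, precision_score, recall_score

cam_size = 48                                       # stands in for params.CAM_size
CAM = np.random.rand(4, cam_size, cam_size)         # class activation maps, one batch
seg_values = np.random.rand(4, 64, 64) > 0.5        # full-size ground-truth masks

seg_pred = (CAM > 0.5).astype(int)                  # threshold chosen on the train set
offset = (seg_values.shape[1] - cam_size) // 2      # center-crop truth down to the CAM
seg_true = seg_values[:, offset:cam_size + offset, offset:cam_size + offset].astype(int)

for score in (accuracy_score, precision_score, recall_score, f1_score):
    print(score.__name__, score(seg_true.flatten(), seg_pred.flatten()))
```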
/image_labels/src/utils.py:
--------------------------------------------------------------------------------
1 | """General utility functions"""
2 |
3 | import json
4 | import logging
5 | import numpy as np
6 | import os
7 |
8 |
9 | class Params():
10 | """Class that loads hyperparameters from a json file.
11 | Example:
12 | ```
13 | params = Params(json_path)
14 | print(params.learning_rate)
15 | params.learning_rate = 0.5 # change the value of learning_rate in params
16 | ```
17 | """
18 |
19 | def __init__(self, json_path):
20 | self.update(json_path)
21 |
22 | def save(self, json_path):
23 | """Saves parameters to json file"""
24 | with open(json_path, 'w') as f:
25 | json.dump(self.__dict__, f, indent=4)
26 |
27 | def update(self, json_path):
28 | """Loads parameters from json file"""
29 | with open(json_path) as f:
30 | params = json.load(f)
31 | self.__dict__.update(params)
32 |
33 | @property
34 | def dict(self):
35 | """Gives dict-like access to Params instance by `params.dict['learning_rate']`"""
36 | return self.__dict__
37 |
38 |
39 | def set_logger(log_path):
40 | """Sets the logger to log info in terminal and file `log_path`.
41 | In general, it is useful to have a logger so that every output to the terminal is saved
42 | in a permanent file. Here we save it to `model_dir/train.log`.
43 | Example:
44 | ```
45 | logging.info("Starting training...")
46 | ```
47 | Args:
48 | log_path: (string) where to log
49 | """
50 | logger = logging.getLogger()
51 | logger.setLevel(logging.INFO)
52 |
53 | if not logger.handlers:
54 | # Logging to a file
55 | file_handler = logging.FileHandler(log_path)
56 | file_handler.setFormatter(logging.Formatter('%(asctime)s:%(levelname)s: %(message)s'))
57 | logger.addHandler(file_handler)
58 |
59 | # Logging to console
60 | stream_handler = logging.StreamHandler()
61 | stream_handler.setFormatter(logging.Formatter('%(message)s'))
62 | logger.addHandler(stream_handler)
--------------------------------------------------------------------------------
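A hypothetical top-of-script usage of the two helpers above; the paths are illustrative and mirror the `experiments/params.json` layout used in this repo:

```python
import logging
import os

from utils import Params, set_logger

model_dir = 'experiments'                                # illustrative path
params = Params(os.path.join(model_dir, 'params.json'))  # hyperparameters as attributes
set_logger(os.path.join(model_dir, 'train.log'))         # log to terminal and file
logging.info('Loaded params: %s', params.dict)
```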
/single_pixel_labels/experiments/params.json:
--------------------------------------------------------------------------------
1 | {
2 | "model": "unet",
3 | "run": "run0",
4 |
5 | "tile_dir": "../data/images",
6 | "train_file": "../data/splits/train_files_numpy",
7 | "val_file": "../data/splits/val_files_numpy",
8 | "mask_dir": "../data/masks",
9 | "n_train": 10,
10 | "n_val": 10,
11 |
12 | "layers": 4,
13 | "label_size": 48,
14 | "batch_size": 32,
15 | "shuffle": true,
16 | "num_workers": 4,
17 |
18 | "epochs": 50,
19 | "starting_filters": 32,
20 | "bn_momentum": 0.1,
21 | "lr": 0.001,
22 | "beta1": 0.5,
23 | "beta2": 0.999,
24 | "weight_decay": 0.0,
25 | "decay_steps": 100000,
26 | "gamma": 0.9,
27 | "save_model": true,
28 |
29 | "visdom_every": 100
30 | }
31 |
--------------------------------------------------------------------------------
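One consequence of these values worth noting: `train_masked.py` (below) steps `StepLR` once per epoch, so with `decay_steps` at 100000 and only 50 epochs the decay never fires and the learning rate stays at `lr` for the whole run:

```python
# StepLR schedule implied by params.json (values copied from above)
lr, gamma, decay_steps, epochs = 0.001, 0.9, 100000, 50
lrs = [lr * gamma ** (epoch // decay_steps) for epoch in range(epochs)]
assert all(x == lr for x in lrs)  # no decay within 50 epochs
```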
/single_pixel_labels/src/datasets.py:
--------------------------------------------------------------------------------
1 | from torch.utils.data import Dataset, DataLoader
2 | from torchvision import transforms
3 | import torch
4 | import glob
5 | import os
6 | import numpy as np
7 |
8 |
9 | def cdl_to_binary(cdl):
10 |     # binarize the CDL layer: codes <= 60, 66-77, and >= 196 map to the positive class
11 |     return ((cdl <= 60) | (cdl >= 196)) | ((cdl >= 66) & (cdl <= 77))
12 |
13 |
14 | class MaskedTileDataset(Dataset):
15 |
16 | def __init__(self, tile_dir, tile_files, mask_dir, transform=None, n_samples=None):
17 |
18 | self.tile_dir = tile_dir
19 | self.tile_files = tile_files
20 | self.transform = transform
21 | self.n_samples = n_samples
22 | self.mask_dir = mask_dir
23 |
24 |
25 | def __len__(self):
26 | if self.n_samples: return self.n_samples
27 | else: return len(self.tile_files)
28 |
29 |
30 | def __getitem__(self, idx):
31 | tile = np.load(os.path.join(self.tile_dir, self.tile_files[idx]))
32 | tile = np.nan_to_num(tile)
33 | tile = np.moveaxis(tile, -1, 0)
34 |
35 | mask = np.load(os.path.join(self.mask_dir, 'mask_'+self.tile_files[idx]))
36 | mask = np.expand_dims(mask, axis=0)
37 | tile = np.concatenate([tile, mask], axis=0) # attach mask to tile to ensure same transformations are applied
38 |
39 | if self.transform:
40 | tile = self.transform(tile)
41 | features = tile[:7,:,:]
42 |
43 |         label = tile[-2,:,:] * 10000  # rescale the CDL band back to integer class codes
44 | label = cdl_to_binary(label)
45 | label = label.float()
46 |
47 | mask = tile[-1,:,:]
48 |         mask = mask.bool()  # boolean mask, as torch.masked_select expects
49 |
50 | return features, label, mask
51 |
52 |
53 | class RandomFlipAndRotate(object):
54 | """
55 | Does data augmentation during training by randomly flipping (horizontal
56 | and vertical) and randomly rotating (0, 90, 180, 270 degrees). Keep in mind
57 |     that PyTorch samples are CxHxW.
58 | """
59 | def __call__(self, tile):
60 | # randomly flip
61 | if np.random.rand() < 0.5: tile = np.flip(tile, axis=2).copy()
62 | if np.random.rand() < 0.5: tile = np.flip(tile, axis=1).copy()
63 |
64 | # randomly rotate
65 | rotations = np.random.choice([0,1,2,3])
66 | if rotations > 0: tile = np.rot90(tile, k=rotations, axes=(1,2)).copy()
67 |
68 | return tile
69 |
70 |
71 | class ToFloatTensor(object):
72 | """
73 | Converts numpy arrays to float Variables in Pytorch.
74 | """
75 | def __call__(self, tile):
76 | tile = torch.from_numpy(tile).float()
77 | return tile
78 |
79 |
80 | def masked_tile_dataloader(tile_dir, tile_files, mask_dir, augment=True, batch_size=4, shuffle=True, num_workers=4, n_samples=None):
81 | """
82 | Returns a dataloader with Landsat tiles.
83 | """
84 | transform_list = []
85 | if augment: transform_list.append(RandomFlipAndRotate())
86 | transform_list.append(ToFloatTensor())
87 | transform = transforms.Compose(transform_list)
88 |
89 | dataset = MaskedTileDataset(tile_dir, tile_files, mask_dir, transform=transform, n_samples=n_samples)
90 | dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=shuffle, num_workers=num_workers)
91 | return dataloader
92 |
--------------------------------------------------------------------------------
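A sketch of wiring the dataloader to the sample patches bundled under `data/`; the file names match the bundled patches, while the batch size and `num_workers=0` are illustrative:

```python
from datasets import masked_tile_dataloader

tile_files = ['patch0.npy', 'patch1.npy']   # masks resolve to data/masks/mask_patch*.npy
loader = masked_tile_dataloader('../data/images', tile_files, '../data/masks',
                                augment=True, batch_size=2, shuffle=True,
                                num_workers=0, n_samples=2)
features, label, mask = next(iter(loader))
print(features.shape, label.shape, mask.shape)  # (2, 7, H, W), (2, H, W), (2, H, W)
```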
/single_pixel_labels/src/run_masked.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import numpy as np
3 | import time
4 | import os
5 | import copy
6 | import random
7 | import json
8 |
9 | import torch
10 | import torch.nn as nn
11 | import torch.optim as optim
12 | from torch.optim import lr_scheduler
13 | import torchvision
14 |
15 | from datasets import masked_tile_dataloader
16 | from unet import UNet
17 | from train_masked import train_model
18 |
19 |
20 |
21 | parser = argparse.ArgumentParser(description="Masked UNet for Remote Sensing Image Segmentation")
22 |
23 | # directories
24 | parser.add_argument('--model_dir', type=str)
25 | parser.add_argument('--gpu', type=int, default=0)
26 |
27 | args = parser.parse_args()
28 |
29 | # model and dataset hyperparameters
30 | param_file = os.path.join(args.model_dir, 'params.json')
31 | with open(param_file) as f:
32 | params = json.load(f)
33 |
34 | # dataloaders
35 | train_files = []
36 | val_files = []
37 | with open(params['train_file']) as f:
38 | for l in f:
39 | if '.npy' in l:
40 |             train_files.append(l.strip())
41 | with open(params['val_file']) as f:
42 | for l in f:
43 | if '.npy' in l:
44 |             val_files.append(l.strip())
45 |
46 | dataloaders = {}
47 | dataloaders['train'] = masked_tile_dataloader(params['tile_dir'],
48 | train_files,
49 | params['mask_dir'],
50 | augment=True,
51 | batch_size=params['batch_size'],
52 | shuffle=params['shuffle'],
53 | num_workers=params['num_workers'],
54 | n_samples=params['n_train'])
55 | dataloaders['val'] = masked_tile_dataloader(params['tile_dir'],
56 | val_files,
57 | params['mask_dir'],
58 | augment=False,
59 | batch_size=params['batch_size'],
60 | shuffle=params['shuffle'],
61 | num_workers=params['num_workers'],
62 | n_samples=params['n_val'])
63 | dataset_sizes = {}  # n_train/n_val cap the dataset length in MaskedTileDataset.__len__
64 | dataset_sizes['train'] = params['n_train'] if params['n_train'] else len(train_files)
65 | dataset_sizes['val'] = params['n_val'] if params['n_val'] else len(val_files)
66 |
67 | device = torch.device("cuda:"+str(args.gpu) if torch.cuda.is_available() else "cpu")
68 |
69 | model = UNet(in_channels=7, out_channels=1,
70 | starting_filters=params['starting_filters'],
71 | bn_momentum=params['bn_momentum'])
72 |
73 | model = model.to(device)
74 |
75 | criterion = nn.BCEWithLogitsLoss()
76 |
77 | # Observe that all parameters are being optimized
78 | optimizer = optim.Adam(model.parameters(), lr=params['lr'],
79 | betas=(params['beta1'], params['beta2']), weight_decay=params['weight_decay'])
80 |
81 | # Decay LR by a factor of gamma every X epochs
82 | exp_lr_scheduler = lr_scheduler.StepLR(optimizer,
83 | step_size=params['decay_steps'],
84 | gamma=params['gamma'])
85 |
86 | model, metrics = train_model(model, dataloaders, dataset_sizes, criterion, optimizer, exp_lr_scheduler,
87 | params, num_epochs=params['epochs'], gpu=args.gpu)
88 |
89 | if params['save_model']:
90 | torch.save(model.state_dict(), os.path.join(args.model_dir, 'best_model.pt'))
91 |
92 | with open(os.path.join(args.model_dir, 'metrics.json'), 'w') as f:
93 | json.dump(metrics, f)
94 |
--------------------------------------------------------------------------------
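Usage note: point `--model_dir` at a directory containing a `params.json` like the one above, e.g. `python run_masked.py --model_dir ../experiments --gpu 0`. The script trains with the masked loss and writes `best_model.pt` (when `save_model` is set) and `metrics.json` back into that directory.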
/single_pixel_labels/src/train_masked.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import time
3 | import os
4 | import copy
5 | import random
6 |
7 | import torch
8 | import torch.nn as nn
9 | import torch.optim as optim
10 | from torch.optim import lr_scheduler
11 | import torchvision
12 | from torchvision import models
13 |
14 | from tqdm import tqdm
15 |
16 | def train_model(model, dataloaders, dataset_sizes, criterion, optimizer, scheduler, params, num_epochs=20, gpu=0):
17 | """
18 | Trains a model for all epochs using the provided dataloader.
19 | """
20 | t0 = time.time()
21 |
22 | device = torch.device("cuda:"+str(gpu) if torch.cuda.is_available() else "cpu")
23 |
24 | best_model_wts = copy.deepcopy(model.state_dict())
25 |     best_acc, best_segacc = 0.0, 0.0  # initialize both so the final print never hits an unset name
26 |
27 | metrics = {'train_loss': [], 'train_acc': [], 'train_segacc': [],
28 | 'val_loss': [], 'val_acc': [], 'val_segacc': []}
29 |
30 | for epoch in range(num_epochs):
31 | print('Epoch {}/{}'.format(epoch, num_epochs - 1))
32 | print('-' * 10)
33 |
34 | # Each epoch has a training and validation phase
35 | for phase in ['train', 'val']:
36 | if phase == 'train':
37 | scheduler.step()
38 | model.train() # Set model to training mode
39 | else:
40 | model.eval() # Set model to evaluate mode
41 |
42 | running_loss = 0.0
43 | running_corrects = 0
44 | running_corrects_full = 0
45 | i = 0
46 |
47 | # Iterate over data.
48 | for inputs, labels, masks in tqdm(dataloaders[phase]):
49 | inputs = inputs.to(device)
50 |
51 | labels = labels.to(device)
52 | label_size = labels.shape[-1]
53 | target_size = params["label_size"]
54 | offset = (label_size - target_size)//2
55 | labels = labels[:,offset:offset+target_size,offset:offset+target_size]
56 |
57 | masks = masks.to(device)
58 | masks = masks[:,offset:offset+target_size,offset:offset+target_size]
59 |
60 | # zero the parameter gradients
61 | optimizer.zero_grad()
62 |
63 | # forward
64 | # track history if only in train
65 | s = nn.Sigmoid()
66 | with torch.set_grad_enabled(phase == 'train'):
67 | outputs = model(inputs)
68 |                 outputs = outputs.squeeze(1)  # drop only the channel dim; squeeze() would also eat a batch dim of 1
69 | preds = s(outputs) >= 0.5
70 | preds = preds.float()
71 |
72 | outputs_masked = torch.masked_select(outputs, masks)
73 | labels_masked = torch.masked_select(labels, masks)
74 | preds_masked = torch.masked_select(preds, masks)
75 |
76 | loss = criterion(outputs_masked, labels_masked)
77 |
78 | # backward + optimize only if in training phase
79 | if phase == 'train':
80 | loss.backward()
81 | optimizer.step()
82 |
83 | # statistics
84 | running_loss += loss.item() * inputs.size(0)
85 | running_corrects += torch.sum(preds_masked == labels_masked.data)
86 | running_corrects_full += torch.sum(preds == labels.data)
87 |
88 | i += 1
89 |
90 | epoch_loss = running_loss / dataset_sizes[phase]
91 | epoch_acc = running_corrects.double() / dataset_sizes[phase]
92 | epoch_segacc = running_corrects_full.double() / (dataset_sizes[phase] * params["label_size"]**2)
93 |
94 | metrics[phase+'_loss'].append(epoch_loss)
95 | metrics[phase+'_acc'].append(float(epoch_acc.cpu().numpy()))
96 | metrics[phase+'_segacc'].append(float(epoch_segacc.cpu().numpy()))
97 |
98 | print('{} loss: {:.4f}, single pixel accuracy: {:.4f}, full segmentation accuracy: {:.4f}'.format(
99 | phase, epoch_loss, epoch_acc, epoch_segacc))
100 |
101 | # deep copy the model
102 | if phase == 'val' and epoch_acc > best_acc:
103 | best_acc = epoch_acc
104 | best_segacc = epoch_segacc
105 | best_model_wts = copy.deepcopy(model.state_dict())
106 |
107 | print()
108 |
109 | time_elapsed = time.time() - t0
110 | print('Training complete in {:.0f}m {:.0f}s'.format(
111 | time_elapsed // 60, time_elapsed % 60))
112 |     print('Best val single pixel accuracy: {:.4f}'.format(best_acc))
113 | print('Corresponding val full segmentation accuracy: {:.4f}'.format(best_segacc))
114 |
115 | # load best model weights
116 | model.load_state_dict(best_model_wts)
117 | return model, metrics
118 |
119 |
120 |
--------------------------------------------------------------------------------
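The single-pixel trick above rests on `torch.masked_select`: only logits picked out by the mask enter the loss, so only those pixels receive gradient. A toy illustration with arbitrary shapes and pixel locations:

```python
import torch
import torch.nn as nn

criterion = nn.BCEWithLogitsLoss()
outputs = torch.randn(2, 48, 48, requires_grad=True)  # stand-in for UNet logits
labels = torch.randint(0, 2, (2, 48, 48)).float()
masks = torch.zeros(2, 48, 48, dtype=torch.bool)
masks[0, 10, 10] = True                               # one labeled pixel per tile
masks[1, 20, 30] = True

loss = criterion(torch.masked_select(outputs, masks),
                 torch.masked_select(labels, masks))
loss.backward()
print(int((outputs.grad != 0).sum()))                 # 2: gradient only at labeled pixels
```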
/single_pixel_labels/src/unet.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from torch.autograd import Variable
4 | import numpy as np
5 | from PIL import Image
6 | from torch.nn.functional import sigmoid
7 |
8 |
9 | class conv_conv(nn.Module):
10 |     ''' (conv => BN => ReLU) * 2 '''
11 |
12 | def __init__(self, in_channels, out_channels, bn_momentum=0.1):
13 | """
14 | Args:
15 | in_channels (int): input channel
16 | out_channels (int): output channel
17 | bn_momentum (float): batch norm momentum
18 | """
19 | super(conv_conv, self).__init__()
20 | self.conv = nn.Sequential(
21 | nn.Conv2d(in_channels, out_channels, 3, padding=1, stride=1),
22 | nn.BatchNorm2d(out_channels, momentum=bn_momentum),
23 | nn.ReLU(inplace=True),
24 | nn.Conv2d(out_channels, out_channels, 3, padding=1, stride=1),
25 | nn.BatchNorm2d(out_channels, momentum=bn_momentum),
26 | nn.ReLU(inplace=True)
27 | )
28 |
29 | def forward(self, X):
30 | X = self.conv(X)
31 | return X
32 |
33 |
34 | class downconv(nn.Module):
35 | def __init__(self, in_channels, out_channels, bn_momentum=0.1):
36 | super(downconv, self).__init__()
37 | self.conv = conv_conv(in_channels, out_channels, bn_momentum)
38 | self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
39 |
40 | def forward(self, X):
41 | X = self.conv(X)
42 | pool_X = self.pool(X)
43 | return pool_X, X
44 |
45 |
46 | class upconv_concat(nn.Module):
47 |     ''' upconv => crop & concat skip connection => (conv => BN => ReLU) * 2 '''
48 |
49 | def __init__(self, in_channels, out_channels, bn_momentum=0.1):
50 | """
51 | Args:
52 | in_channels (int): input channel
53 | out_channels (int): output channel
54 | """
55 | super(upconv_concat, self).__init__()
56 | self.upconv = nn.ConvTranspose2d(in_channels, out_channels, kernel_size=2, stride=2)
57 | self.conv = conv_conv(in_channels, out_channels, bn_momentum)
58 |
59 | def forward(self, X1, X2):
60 | X1 = self.upconv(X1)
61 | X1_dim = X1.size()[2]
62 | X2 = extract_img(X1_dim, X2)
63 | X1 = torch.cat((X1, X2), dim=1)
64 | X1 = self.conv(X1)
65 | return X1
66 |
67 |
68 | def extract_img(size, in_tensor):
69 | """
70 | Args:
71 | size (int): size of crop
72 | in_tensor (tensor): tensor to be cropped
73 | """
74 | dim1, dim2 = in_tensor.size()[2:]
75 | in_tensor = in_tensor[:, :, int((dim1-size)/2):int((dim1+size)/2), int((dim2-size)/2):int((dim2+size)/2)]
76 | return in_tensor
77 |
78 |
79 | class UNet(nn.Module):
80 |
81 | def __init__(self, in_channels, out_channels, starting_filters=32, bn_momentum=0.1):
82 | super(UNet, self).__init__()
83 | self.conv1 = downconv(in_channels, starting_filters, bn_momentum)
84 | self.conv2 = downconv(starting_filters, starting_filters * 2, bn_momentum)
85 | self.conv3 = downconv(starting_filters * 2, starting_filters * 4, bn_momentum)
86 | self.conv4 = conv_conv(starting_filters * 4, starting_filters * 8, bn_momentum)
87 | self.upconv1 = upconv_concat(starting_filters * 8, starting_filters * 4, bn_momentum)
88 | self.upconv2 = upconv_concat(starting_filters * 4, starting_filters * 2, bn_momentum)
89 | self.upconv3 = upconv_concat(starting_filters * 2, starting_filters, bn_momentum)
90 | self.conv_out = nn.Conv2d(starting_filters, out_channels, 1, padding=0, stride=1)
91 |
92 | def forward(self, X):
93 | X, conv1 = self.conv1(X)
94 | X, conv2 = self.conv2(X)
95 | X, conv3 = self.conv3(X)
96 | X = self.conv4(X)
97 | X = self.upconv1(X, conv3)
98 | X = self.upconv2(X, conv2)
99 | X = self.upconv3(X, conv1)
100 | X = self.conv_out(X)
101 | return X
102 |
103 |
104 | class UCAM(nn.Module):
105 |
106 | def __init__(self, in_channels, out_channels, starting_filters=32, bn_momentum=0.1):
107 | super(UCAM, self).__init__()
108 | self.conv1 = downconv(in_channels, starting_filters, bn_momentum)
109 | self.conv2 = downconv(starting_filters, starting_filters * 2, bn_momentum)
110 | self.conv3 = downconv(starting_filters * 2, starting_filters * 4, bn_momentum)
111 | self.conv4 = conv_conv(starting_filters * 4, starting_filters * 8, bn_momentum)
112 | self.upconv1 = upconv_concat(starting_filters * 8, starting_filters * 4, bn_momentum)
113 | self.upconv2 = upconv_concat(starting_filters * 4, starting_filters * 2, bn_momentum)
114 | self.upconv3 = upconv_concat(starting_filters * 2, starting_filters, bn_momentum)
115 | self.gap = nn.AdaptiveAvgPool2d(1)
116 | self.dense = nn.Linear(starting_filters, 1)
117 |
118 | def forward(self, X):
119 | X, conv1 = self.conv1(X)
120 | X, conv2 = self.conv2(X)
121 | X, conv3 = self.conv3(X)
122 | X = self.conv4(X)
123 | X = self.upconv1(X, conv3)
124 | X = self.upconv2(X, conv2)
125 | X = self.upconv3(X, conv1)
126 |         out = self.gap(X)                            # global average pool: (N, C, 1, 1)
127 |         out = self.dense(out.view(out.size(0), -1))  # flatten to (N, C) for the linear layer
128 | return X, out
129 |
130 |
131 |
132 |
--------------------------------------------------------------------------------
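A quick, hypothetical shape check for the two heads above (it assumes the flattening before `self.dense` in `UCAM.forward`; the 64x64 input is illustrative and should be divisible by 8 so the three poolings invert cleanly):

```python
import torch

from unet import UNet, UCAM

x = torch.randn(1, 7, 64, 64)

seg_logits = UNet(in_channels=7, out_channels=1)(x)
print(seg_logits.shape)                        # torch.Size([1, 1, 64, 64])

feature_maps, image_logit = UCAM(in_channels=7, out_channels=1)(x)
print(feature_maps.shape, image_logit.shape)   # (1, 32, 64, 64) and (1, 1)
```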