├── .github
│   └── workflows
│       └── run-pre-commit.yml
├── .gitignore
├── .pre-commit-config.yaml
├── LICENSE
├── README.md
├── pyproject.toml
├── requirements-cuda11.txt
├── requirements-dev.txt
├── scripts
│   ├── kill_hanging_processes.sh
│   └── sample-wandb-logging.png
├── seg_lapa
│   ├── __init__.py
│   ├── callbacks
│   │   └── log_media.py
│   ├── config
│   │   ├── callbacks
│   │   │   ├── disabled.yaml
│   │   │   ├── standard.yaml
│   │   │   └── with_lr_monitor.yaml
│   │   ├── dataset
│   │   │   └── lapa.yaml
│   │   ├── load_weights
│   │   │   ├── disabled.yaml
│   │   │   └── pretrain.yaml
│   │   ├── logger
│   │   │   ├── disabled.yaml
│   │   │   ├── wandb.yaml
│   │   │   └── wandb_debug.yaml
│   │   ├── model
│   │   │   └── deeplabv3.yaml
│   │   ├── optimizer
│   │   │   ├── adam.yaml
│   │   │   └── sgd.yaml
│   │   ├── scheduler
│   │   │   ├── cyclic.yaml
│   │   │   ├── disabled.yaml
│   │   │   ├── plateau.yaml
│   │   │   ├── poly.yaml
│   │   │   └── step.yaml
│   │   ├── train.yaml
│   │   └── trainer
│   │       ├── debug.yaml
│   │       └── standard.yaml
│   ├── config_parse
│   │   ├── callbacks_available.py
│   │   ├── callbacks_conf.py
│   │   ├── conf_utils.py
│   │   ├── dataset_conf.py
│   │   ├── load_weights_conf.py
│   │   ├── logger_conf.py
│   │   ├── model_conf.py
│   │   ├── optimizer_conf.py
│   │   ├── scheduler_conf.py
│   │   ├── train_conf.py
│   │   └── trainer_conf.py
│   ├── datasets
│   │   └── lapa.py
│   ├── loss_func.py
│   ├── metrics.py
│   ├── networks
│   │   └── deeplab
│   │       ├── aspp.py
│   │       ├── backbone
│   │       │   ├── __init__.py
│   │       │   ├── drn.py
│   │       │   ├── mobilenet.py
│   │       │   ├── resnet.py
│   │       │   └── xception.py
│   │       ├── decoder.py
│   │       ├── decoder_masks.py
│   │       ├── deeplab.py
│   │       ├── readme.md
│   │       └── sync_batchnorm
│   │           ├── __init__.py
│   │           ├── batchnorm.py
│   │           ├── comm.py
│   │           ├── replicate.py
│   │           └── unittest.py
│   ├── train.py
│   └── utils
│       ├── path_check.py
│       ├── segmentation_label2rgb.py
│       └── utils.py
├── setup.cfg
└── setup.py
/.github/workflows/run-pre-commit.yml:
--------------------------------------------------------------------------------
1 | name: run-pre-commit
2 |
3 | on: [pull_request]
4 |
5 | jobs:
6 | run-pre-commit:
7 | runs-on: ubuntu-latest
8 | steps:
9 | - uses: actions/checkout@v2
10 | with:
11 | fetch-depth: 0
12 | - uses: actions/setup-python@v2
13 | with:
14 | python-version: '3.8'
15 | - uses: pre-commit/action@v2.0.0
16 | with:
17 | token: ${{ secrets.GITHUB_TOKEN }}
18 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Custom
2 | lightning_logs/
3 | wandb/
4 | checkpoints/
5 | log-media/
6 |
7 | .DS_Store
8 | *.tar.gz
9 |
10 | # Byte-compiled / optimized / DLL files
11 | __pycache__/
12 | *.py[cod]
13 | *$py.class
14 | *.pyc
15 |
16 | # C extensions
17 | *.so
18 |
19 | # Distribution / packaging
20 | .Python
21 | build/
22 | develop-eggs/
23 | dist/
24 | downloads/
25 | eggs/
26 | .eggs/
27 | lib/
28 | lib64/
29 | parts/
30 | sdist/
31 | var/
32 | wheels/
33 | *.egg-info/
34 | .installed.cfg
35 | *.egg
36 | MANIFEST
37 |
38 | # Lightning /research
39 | test_tube_exp/
40 | tests/tests_tt_dir/
41 | tests/save_dir
42 | default/
43 | data/
44 | test_tube_logs/
45 | test_tube_data/
46 | #datasets/ # Not ignored: there is a dir in the code called 'datasets'.
47 | model_weights/
50 | processed/
51 | raw/
52 |
53 | # PyInstaller
54 | # Usually these files are written by a python script from a template
55 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
56 | *.manifest
57 | *.spec
58 |
59 | # Installer logs
60 | pip-log.txt
61 | pip-delete-this-directory.txt
62 |
63 | # Unit test / coverage reports
64 | htmlcov/
65 | .tox/
66 | .coverage
67 | .coverage.*
68 | .cache
69 | nosetests.xml
70 | coverage.xml
71 | *.cover
72 | .hypothesis/
73 | .pytest_cache/
74 |
75 | # Translations
76 | *.mo
77 | *.pot
78 |
79 | # Django stuff:
80 | *.log
81 | local_settings.py
82 | db.sqlite3
83 |
84 | # Flask stuff:
85 | instance/
86 | .webassets-cache
87 |
88 | # Scrapy stuff:
89 | .scrapy
90 |
91 | # Sphinx documentation
92 | docs/_build/
93 |
94 | # PyBuilder
95 | target/
96 |
97 | # Jupyter Notebook
98 | .ipynb_checkpoints
99 |
100 | # pyenv
101 | .python-version
102 |
103 | # celery beat schedule file
104 | celerybeat-schedule
105 |
106 | # SageMath parsed files
107 | *.sage.py
108 |
109 | # Environments
110 | .env
111 | .venv
112 | env/
113 | venv/
114 | ENV/
115 | env.bak/
116 | venv.bak/
117 |
118 | # Spyder project settings
119 | .spyderproject
120 | .spyproject
121 |
122 | # Rope project settings
123 | .ropeproject
124 |
125 | # mkdocs documentation
126 | /site
127 |
128 | # mypy
129 | .mypy_cache/
130 |
131 | # IDEs
132 | .idea
133 | .vscode
134 |
--------------------------------------------------------------------------------
/.pre-commit-config.yaml:
--------------------------------------------------------------------------------
1 | repos:
2 | - repo: https://github.com/psf/black
3 | rev: 20.8b1 # pin a release; black's mutable 'stable' tag has been removed upstream
4 | hooks:
5 | - id: black
6 | language_version: python3.8
7 | - repo: https://github.com/pre-commit/mirrors-prettier
8 | rev: v2.0.0
9 | hooks:
10 | - id: prettier
11 | files: \.(yaml|md)$ # The filename extensions this formatter edits
12 | - repo: https://github.com/pre-commit/pre-commit-hooks
13 | rev: v3.4.0
14 | hooks:
15 | - id: end-of-file-fixer
16 | - id: requirements-txt-fixer
17 | - id: trailing-whitespace
18 | args: [--markdown-linebreak-ext=md] # preserve Markdown hard linebreaks
19 | - id: check-merge-conflict
20 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | Apache License
2 | Version 2.0, January 2004
3 | http://www.apache.org/licenses/
4 |
5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6 |
7 | 1. Definitions.
8 |
9 | "License" shall mean the terms and conditions for use, reproduction,
10 | and distribution as defined by Sections 1 through 9 of this document.
11 |
12 | "Licensor" shall mean the copyright owner or entity authorized by
13 | the copyright owner that is granting the License.
14 |
15 | "Legal Entity" shall mean the union of the acting entity and all
16 | other entities that control, are controlled by, or are under common
17 | control with that entity. For the purposes of this definition,
18 | "control" means (i) the power, direct or indirect, to cause the
19 | direction or management of such entity, whether by contract or
20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
21 | outstanding shares, or (iii) beneficial ownership of such entity.
22 |
23 | "You" (or "Your") shall mean an individual or Legal Entity
24 | exercising permissions granted by this License.
25 |
26 | "Source" form shall mean the preferred form for making modifications,
27 | including but not limited to software source code, documentation
28 | source, and configuration files.
29 |
30 | "Object" form shall mean any form resulting from mechanical
31 | transformation or translation of a Source form, including but
32 | not limited to compiled object code, generated documentation,
33 | and conversions to other media types.
34 |
35 | "Work" shall mean the work of authorship, whether in Source or
36 | Object form, made available under the License, as indicated by a
37 | copyright notice that is included in or attached to the work
38 | (an example is provided in the Appendix below).
39 |
40 | "Derivative Works" shall mean any work, whether in Source or Object
41 | form, that is based on (or derived from) the Work and for which the
42 | editorial revisions, annotations, elaborations, or other modifications
43 | represent, as a whole, an original work of authorship. For the purposes
44 | of this License, Derivative Works shall not include works that remain
45 | separable from, or merely link (or bind by name) to the interfaces of,
46 | the Work and Derivative Works thereof.
47 |
48 | "Contribution" shall mean any work of authorship, including
49 | the original version of the Work and any modifications or additions
50 | to that Work or Derivative Works thereof, that is intentionally
51 | submitted to Licensor for inclusion in the Work by the copyright owner
52 | or by an individual or Legal Entity authorized to submit on behalf of
53 | the copyright owner. For the purposes of this definition, "submitted"
54 | means any form of electronic, verbal, or written communication sent
55 | to the Licensor or its representatives, including but not limited to
56 | communication on electronic mailing lists, source code control systems,
57 | and issue tracking systems that are managed by, or on behalf of, the
58 | Licensor for the purpose of discussing and improving the Work, but
59 | excluding communication that is conspicuously marked or otherwise
60 | designated in writing by the copyright owner as "Not a Contribution."
61 |
62 | "Contributor" shall mean Licensor and any individual or Legal Entity
63 | on behalf of whom a Contribution has been received by Licensor and
64 | subsequently incorporated within the Work.
65 |
66 | 2. Grant of Copyright License. Subject to the terms and conditions of
67 | this License, each Contributor hereby grants to You a perpetual,
68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69 | copyright license to reproduce, prepare Derivative Works of,
70 | publicly display, publicly perform, sublicense, and distribute the
71 | Work and such Derivative Works in Source or Object form.
72 |
73 | 3. Grant of Patent License. Subject to the terms and conditions of
74 | this License, each Contributor hereby grants to You a perpetual,
75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76 | (except as stated in this section) patent license to make, have made,
77 | use, offer to sell, sell, import, and otherwise transfer the Work,
78 | where such license applies only to those patent claims licensable
79 | by such Contributor that are necessarily infringed by their
80 | Contribution(s) alone or by combination of their Contribution(s)
81 | with the Work to which such Contribution(s) was submitted. If You
82 | institute patent litigation against any entity (including a
83 | cross-claim or counterclaim in a lawsuit) alleging that the Work
84 | or a Contribution incorporated within the Work constitutes direct
85 | or contributory patent infringement, then any patent licenses
86 | granted to You under this License for that Work shall terminate
87 | as of the date such litigation is filed.
88 |
89 | 4. Redistribution. You may reproduce and distribute copies of the
90 | Work or Derivative Works thereof in any medium, with or without
91 | modifications, and in Source or Object form, provided that You
92 | meet the following conditions:
93 |
94 | (a) You must give any other recipients of the Work or
95 | Derivative Works a copy of this License; and
96 |
97 | (b) You must cause any modified files to carry prominent notices
98 | stating that You changed the files; and
99 |
100 | (c) You must retain, in the Source form of any Derivative Works
101 | that You distribute, all copyright, patent, trademark, and
102 | attribution notices from the Source form of the Work,
103 | excluding those notices that do not pertain to any part of
104 | the Derivative Works; and
105 |
106 | (d) If the Work includes a "NOTICE" text file as part of its
107 | distribution, then any Derivative Works that You distribute must
108 | include a readable copy of the attribution notices contained
109 | within such NOTICE file, excluding those notices that do not
110 | pertain to any part of the Derivative Works, in at least one
111 | of the following places: within a NOTICE text file distributed
112 | as part of the Derivative Works; within the Source form or
113 | documentation, if provided along with the Derivative Works; or,
114 | within a display generated by the Derivative Works, if and
115 | wherever such third-party notices normally appear. The contents
116 | of the NOTICE file are for informational purposes only and
117 | do not modify the License. You may add Your own attribution
118 | notices within Derivative Works that You distribute, alongside
119 | or as an addendum to the NOTICE text from the Work, provided
120 | that such additional attribution notices cannot be construed
121 | as modifying the License.
122 |
123 | You may add Your own copyright statement to Your modifications and
124 | may provide additional or different license terms and conditions
125 | for use, reproduction, or distribution of Your modifications, or
126 | for any such Derivative Works as a whole, provided Your use,
127 | reproduction, and distribution of the Work otherwise complies with
128 | the conditions stated in this License.
129 |
130 | 5. Submission of Contributions. Unless You explicitly state otherwise,
131 | any Contribution intentionally submitted for inclusion in the Work
132 | by You to the Licensor shall be under the terms and conditions of
133 | this License, without any additional terms or conditions.
134 | Notwithstanding the above, nothing herein shall supersede or modify
135 | the terms of any separate license agreement you may have executed
136 | with Licensor regarding such Contributions.
137 |
138 | 6. Trademarks. This License does not grant permission to use the trade
139 | names, trademarks, service marks, or product names of the Licensor,
140 | except as required for reasonable and customary use in describing the
141 | origin of the Work and reproducing the content of the NOTICE file.
142 |
143 | 7. Disclaimer of Warranty. Unless required by applicable law or
144 | agreed to in writing, Licensor provides the Work (and each
145 | Contributor provides its Contributions) on an "AS IS" BASIS,
146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147 | implied, including, without limitation, any warranties or conditions
148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149 | PARTICULAR PURPOSE. You are solely responsible for determining the
150 | appropriateness of using or redistributing the Work and assume any
151 | risks associated with Your exercise of permissions under this License.
152 |
153 | 8. Limitation of Liability. In no event and under no legal theory,
154 | whether in tort (including negligence), contract, or otherwise,
155 | unless required by applicable law (such as deliberate and grossly
156 | negligent acts) or agreed to in writing, shall any Contributor be
157 | liable to You for damages, including any direct, indirect, special,
158 | incidental, or consequential damages of any character arising as a
159 | result of this License or out of the use or inability to use the
160 | Work (including but not limited to damages for loss of goodwill,
161 | work stoppage, computer failure or malfunction, or any and all
162 | other commercial damages or losses), even if such Contributor
163 | has been advised of the possibility of such damages.
164 |
165 | 9. Accepting Warranty or Additional Liability. While redistributing
166 | the Work or Derivative Works thereof, You may choose to offer,
167 | and charge a fee for, acceptance of support, warranty, indemnity,
168 | or other liability obligations and/or rights consistent with this
169 | License. However, in accepting such obligations, You may act only
170 | on Your own behalf and on Your sole responsibility, not on behalf
171 | of any other Contributor, and only if You agree to indemnify,
172 | defend, and hold each Contributor harmless for any liability
173 | incurred by, or claims asserted against, such Contributor by reason
174 | of your accepting any such warranty or additional liability.
175 |
176 | END OF TERMS AND CONDITIONS
177 |
178 | APPENDIX: How to apply the Apache License to your work.
179 |
180 | To apply the Apache License to your work, attach the following
181 | boilerplate notice, with the fields enclosed by brackets "[]"
182 | replaced with your own identifying information. (Don't include
183 | the brackets!) The text should be enclosed in the appropriate
184 | comment syntax for the file format. We also recommend that a
185 | file or class name and description of purpose be included on the
186 | same "printed page" as the copyright notice for easier
187 | identification within third-party archives.
188 |
189 | Copyright [yyyy] [name of copyright owner]
190 |
191 | Licensed under the Apache License, Version 2.0 (the "License");
192 | you may not use this file except in compliance with the License.
193 | You may obtain a copy of the License at
194 |
195 | http://www.apache.org/licenses/LICENSE-2.0
196 |
197 | Unless required by applicable law or agreed to in writing, software
198 | distributed under the License is distributed on an "AS IS" BASIS,
199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200 | See the License for the specific language governing permissions and
201 | limitations under the License.
202 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
5 | # Segmentation LaPa
6 |
7 | [Paper](https://www.nature.com/articles/nature14539)
8 | [Conference](https://papers.nips.cc/book/advances-in-neural-information-processing-systems-31-2018)
9 | [Code style: black](https://github.com/psf/black)
10 | [Code style: prettier](https://github.com/prettier/prettier)
11 |
12 |
13 |
14 | ## Description
15 |
16 | This is an example project showcasing [Pytorch Lightning](https://www.pytorchlightning.ai/) for training,
17 | [hydra](https://hydra.cc/) for the configuration system and [wandb](https://wandb.ai/) (Weights and Biases) for logging.
18 |
19 | The project tackles a more realistic setting than MNIST by demonstrating segmentation of facial regions on the
20 | [LaPa dataset](https://github.com/JDAI-CV/lapa-dataset) with Deeplabv3+.
21 |
22 | 
23 |
24 |
25 | ## Install
26 |
27 | If using Ampere GPUs (RTX 3090), then CUDA 11.1 is required. Some libraries throw an error when trying to
28 | install with CUDA 11.0 (unsupported gpu architecture 'compute_86'). This error is solved by moving to
29 | CUDA 11.1.
30 |
31 | ### Install Pytorch as per CUDA (Feb 2021)
32 |
33 | Install Pytorch (`torch` and `torchvision`) before installing the other dependencies.
34 |
35 | #### CUDA 11.1
36 |
37 | `pytorch` and `torchvision` need to be installed from source. Check:
38 |
39 | - https://github.com/pytorch/pytorch#installation
40 | - https://github.com/pytorch/vision
41 |
42 | For torchvision, install system dependencies:
43 |
44 | ```shell script
45 | sudo apt install libturbojpeg libpng-dev libjpeg-dev
46 | ```
47 |
48 | #### CUDA 11.0
49 |
50 | Systems with CUDA 11.0 (such as those with Ampere GPUs):
51 |
52 | ```shell script
53 | pip install -r requirements-cuda11.txt
54 | ```
55 |
56 | #### CUDA 10.x
57 |
58 | Systems with CUDA 10.x:
59 |
60 | ```shell script
61 | pip install -r requirements-cuda10.txt
62 | ```
63 |
64 | ### Install project package and dependencies
65 |
66 | ```shell script
67 | # clone project
68 | git clone git@github.com:Shreeyak/pytorch-lightning-segmentation-lapa.git
69 |
70 | # install project in development mode
71 | cd pytorch-lightning-segmentation-lapa
72 | pip install -e .
73 |
74 | # Setup git precommits
75 | pip install -r requirements-dev.txt
76 | pre-commit install
77 | ```
78 |
79 | #### Developer dependencies
80 |
81 | This repository uses git pre-commit hooks to auto-format code.
82 | These developer dependencies are in requirements-dev.txt.
83 | The other files describing pre-commit hooks are: `pyproject.toml`, `.pre-commit-config.yaml`
84 |
85 | ## Usage
86 |
87 | Download the LaPa dataset from https://github.com/JDAI-CV/lapa-dataset.
88 | It can be placed at `seg_lapa/data`.
89 |
90 | Run training. See the [hydra documentation](https://hydra.cc/docs/advanced/override_grammar/basic)
91 | on how to override the config values from the CLI.
92 |
93 | ```bash
94 | # Run training
95 | python -m seg_lapa.train dataset.data_dir=<path_to_dataset>
96 |
97 | # Run on multiple gpus
98 | python -m seg_lapa.train dataset.data_dir=<path_to_dataset> trainer.gpus=\"0,1\"
99 | ```
100 |
101 | ## Using this template for your own project
102 |
103 | To use this template for your own project:
104 |
105 | 1. Search and replace `seg_lapa` with your project name
106 | 2. Edit setup.py with new package name, requirements and other details
107 | 3. Replace the model, dataloaders, loss function, metric with your own!
108 | 4. Update the readme! Add your own links to your paper at the top, add citation info at bottom.
109 |
110 | This template was based on the Pytorch-Lightning
111 | [seed project](https://github.com/PyTorchLightning/deep-learning-project-template).
112 |
113 | ### Callbacks
114 |
115 | The callbacks can be configured from the config files or
116 | [command line overrides](https://hydra.cc/docs/next/advanced/override_grammar/basic/).
117 | To disable a callback, simply remove its entry from the config. More callbacks can easily be added to the config system
118 | as needed. The following callbacks are added as of now:
119 |
120 | - [Early Stopping](https://pytorch-lightning.readthedocs.io/en/latest/generated/pytorch_lightning.callbacks.EarlyStopping.html#pytorch_lightning.callbacks.EarlyStopping)
121 | - [Model Checkpoint](https://pytorch-lightning.readthedocs.io/en/latest/generated/pytorch_lightning.callbacks.ModelCheckpoint.html#pytorch_lightning.callbacks.ModelCheckpoint)
122 | - [Log Media](#logmedia)
123 |
124 | CLI override Examples:
125 |
126 | ```shell script
127 | # Disable the LogMedia callback.
128 | python -m seg_lapa.train "~callbacks.log_media"
129 |
130 | # Set the EarlyStopping callback to wait for 20 epochs before terminating.
131 | python -m seg_lapa.train callbacks.early_stopping.patience=20
132 | ```
133 |
134 | #### LogMedia
135 |
136 | The LogMedia callback is used to log media, such as images and point clouds, to the logger and to local disk.
137 | It is also used to save the config files for each run. The `LightningModule` adds data to a queue, which is
138 | fetched within the `LogMedia` callback and logged to the logger and/or disk.
139 |
140 | To customize this callback for your application, override or modify the following methods:
141 |
142 | - `LogMedia._get_preds_from_lightningmodule()`
143 | - `LogMedia.__save_media_to_disk()`
144 | - `LogMedia.__save_media_to_logger()`
145 | - The LightningModule should have an attribute of type `LogMediaQueue` called `self.log_media`.
146 | Change the data that you push into the queue in train/val/test steps as per requirement (see the sketch below).
147 |
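A minimal sketch (adapted from the docstring in `seg_lapa/callbacks/log_media.py`) of how the
LightningModule feeds the queue. The dict keys `inputs`/`labels`/`preds` are what
`LogMedia._get_preds_from_lightningmodule()` expects; the `exp_dir` value below is illustrative:

```python
import pytorch_lightning as pl

from seg_lapa.callbacks.log_media import LogMedia, LogMediaQueue, Mode


class MyModel(pl.LightningModule):
    def __init__(self):
        super().__init__()
        # LogMedia looks for this exact attribute on the LightningModule.
        self.log_media: LogMediaQueue = LogMediaQueue(max_len=3)

    def training_step(self, batch, batch_idx):
        inputs, labels = batch
        preds = self(inputs).argmax(dim=1)
        # Push the latest batch into the circular queue. The callback fetches
        # and logs it at the configured epoch/step period.
        self.log_media.append({"inputs": inputs, "labels": labels, "preds": preds}, Mode.TRAIN)
        ...


trainer = pl.Trainer(callbacks=[LogMedia(exp_dir="logs/example-run")])
```
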
148 | ##### Notes:
149 |
150 | - LogMedia currently supports the Weights and Biases logger only.
151 | - By default, LogMedia only saves the latest samples to disk. To save the results from each step/epoch, pass
152 | `save_latest_only=False`.
153 |
154 | #### EarlyStopping
155 |
156 | This is Lightning's built-in callback. Here are some tips on how to configure early stopping:
157 |
158 | ```
159 | Args:
160 | monitor: Monitor a key validation metric (eg: mIoU). Monitoring loss is not a good idea as it is an unreliable
161 | indicator of model performance. Two models might have the same loss but different performance
162 | or the loss might start increasing, even though performance does not decrease.
163 |
164 | min_delta: Project-dependent - choose a value for your metric below which you'd consider the improvement
165 | negligible.
166 | Example: For segmentation, I do not care for improvements less than 0.05% IoU in general.
167 | But in kaggle competitions, even 0.01% would matter.
168 |
169 | patience: Patience is the number of val epochs to wait for an improvement before stopping. It is affected by the
170 | ``check_val_every_n_epoch`` and ``val_check_interval`` params of the PL Trainer.
171 |
172 | It takes experimentation to figure out an appropriate patience for your project. Train the model
173 | once without early stopping and see how long it takes to converge on a given dataset.
174 | Choose a patience between the epoch when the model starts to converge and the epoch when you're
175 | sure it has converged. Reduce the patience if you see the model continues to train for too long.
176 | ```
177 |
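As a concrete reference, the `early_stopping` entry in `config/callbacks/standard.yaml` expands to roughly
the following call via `EarlyStopConf.get_callback()` (a sketch of the mapping, not new code to add):

```python
from pytorch_lightning.callbacks import EarlyStopping

# Values taken from config/callbacks/standard.yaml. The 'name' key in the config
# is stripped by asdict_filtered() before the callback is constructed.
early_stop = EarlyStopping(
    monitor="Val/mIoU",  # monitor a key validation metric, not the loss
    min_delta=0.0005,  # ignore improvements smaller than this
    patience=10,  # val epochs to wait for an improvement
    mode="max",  # higher mIoU is better
    verbose=False,
)
```
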
178 | #### ModelCheckpoint
179 |
180 | This is also Lightning's built-in callback, used to save checkpoints. It can monitor a logged value and save the best
181 | checkpoints, save the latest checkpoint, or save checkpoints every N steps/epochs.
182 | We save checkpoints in our own logs directory structure, which is different from Lightning's default.
183 |
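A sketch of what `CheckpointConf.get_callback()` builds from `config/callbacks/standard.yaml`; the
`dirpath` value below is illustrative (the real one is our generated logs dir):

```python
from pytorch_lightning.callbacks import ModelCheckpoint

checkpoint = ModelCheckpoint(
    dirpath="logs/example-run/checkpoints",  # our own logs dir structure, not Lightning's default
    filename="best",
    save_last=True,
    save_top_k=1,
    monitor="Val/mIoU",
    mode="max",
    period=10,  # interval (epochs) between checkpoints; arg from the PL 1.x API used here
    verbose=False,
)
```
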
184 | ### Loggers
185 |
186 | At this point, this project only supports the WandB logger (Weights and Biases). Other loggers can easily be
187 | added (see the sketch after this list). Modify these methods after adding your logger to the config system:
188 |
189 | - `utils.generate_log_dir_path()` - Generates dir structure to save logs
190 | - `LogMedia._log_media_to_wandb()` - If logging media such as images
191 |
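For example, supporting TensorBoard would mean registering the logger class and adding a branch in
`LogMedia._save_media_to_logger()`. This is a hypothetical sketch: `_log_media_to_tensorboard()` does not
exist in this repo and is something you would write:

```python
from pytorch_lightning import loggers as pl_loggers
from pytorch_lightning.callbacks import Callback


class LogMedia(Callback):
    # Register the new logger class so _logger_is_supported() accepts it.
    SUPPORTED_LOGGERS = [pl_loggers.WandbLogger, pl_loggers.TensorBoardLogger]

    def _save_media_to_logger(self, trainer, pred_data, mode):
        if isinstance(trainer.logger, pl_loggers.WandbLogger):
            self._log_media_to_wandb(trainer, pred_data, mode)
        elif isinstance(trainer.logger, pl_loggers.TensorBoardLogger):
            self._log_media_to_tensorboard(trainer, pred_data, mode)  # hypothetical new method
```
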
192 | ### Notes
193 |
194 | #### Pre-commit workflow
195 |
196 | The project uses pre-commit hooks for `black` and `prettier`, which are auto-format tools for `.py` and
197 | `.yaml | .md` files respectively. The hooks run on each commit; the GitHub Action also runs them on every
198 | pull request and applies changes to the submitted code.
199 |
200 | #### Absolute imports
201 |
202 | This project is set up as a package. One of the advantages of setting it up as a
203 | package is that it is easy to import modules from anywhere.
204 | To avoid errors with pytorch-lightning, always use absolute imports:
205 |
206 | ```python
207 | from seg_lapa.loss_func import CrossEntropy2D
208 | from seg_lapa import metrics
209 | import seg_lapa.metrics as metrics
210 | ```
211 |
212 | ### Citation
213 |
214 | ```
215 | @article{YourName,
216 | title={Your Title},
217 | author={Your team},
218 | journal={Location},
219 | year={Year}
220 | }
221 | ```
222 |
--------------------------------------------------------------------------------
/pyproject.toml:
--------------------------------------------------------------------------------
1 | [tool.black]
2 | line-length = 120
3 | target-version = ["py38"]
4 | exclude = "(.eggs|.git|.hg|.mypy_cache|.nox|.tox|.venv|.svn|_build|buck-out|build|dist)"
5 |
--------------------------------------------------------------------------------
/requirements-cuda11.txt:
--------------------------------------------------------------------------------
1 | # This file is for installing pytorch 1.7 for cuda 11, as it requires different syntax
2 | --find-links https://download.pytorch.org/whl/torch_stable.html
3 | torch==1.7.1+cu110
4 | torchvision==0.8.2+cu110
5 |
--------------------------------------------------------------------------------
/requirements-dev.txt:
--------------------------------------------------------------------------------
1 | # Developer utilities
2 | pre-commit>=2.9.3
3 |
--------------------------------------------------------------------------------
/scripts/kill_hanging_processes.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 | # When using DDP, processes don't always die cleanly. This snippet kills all processes of the seg_lapa package.
3 | kill $(ps aux | grep seg_lapa | grep -v grep | awk '{print $2}')
3 |
--------------------------------------------------------------------------------
/scripts/sample-wandb-logging.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Shreeyak/pytorch-lightning-segmentation-template/851df18c61a6b7354304e3e54b939136378aa142/scripts/sample-wandb-logging.png
--------------------------------------------------------------------------------
/seg_lapa/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Shreeyak/pytorch-lightning-segmentation-template/851df18c61a6b7354304e3e54b939136378aa142/seg_lapa/__init__.py
--------------------------------------------------------------------------------
/seg_lapa/callbacks/log_media.py:
--------------------------------------------------------------------------------
1 | from collections import deque
2 | from dataclasses import dataclass
3 | from enum import Enum
4 | from typing import Any, List, Optional
5 |
6 | import numpy as np
7 | import torch
8 | import wandb
9 | from omegaconf import DictConfig, OmegaConf
10 | from pytorch_lightning import loggers as pl_loggers
11 | from pytorch_lightning.callbacks import Callback
12 | from pytorch_lightning.utilities.distributed import rank_zero_only, rank_zero_warn
13 |
14 | # Project-specific imports for logging media to disk
15 | import cv2
16 | import math
17 | from pathlib import Path
18 | from seg_lapa.utils.segmentation_label2rgb import LabelToRGB, Palette
19 |
20 | CONFIG_FNAME = "config.yaml"
21 |
22 |
23 | class Mode(Enum):
24 | TRAIN = "Train"
25 | VAL = "Val"
26 | TEST = "Test"
27 |
28 |
29 | @dataclass
30 | class PredData:
31 | """Holds the data read and converted from the LightningModule's LogMediaQueue"""
32 |
33 | inputs: np.ndarray
34 | labels: np.ndarray
35 | preds: np.ndarray
36 |
37 |
38 | class LogMediaQueue:
39 | """Holds a circular queue for each of train/val/test modes, each of which contain the latest N batches of data"""
40 |
41 | def __init__(self, max_len: int = 3):
42 | if max_len < 1:
43 | raise ValueError(f"Queue must be length >= 1. Given: {max_len}")
44 |
45 | self.max_len = max_len
46 | self.log_media = {
47 | Mode.TRAIN: deque(maxlen=self.max_len),
48 | Mode.VAL: deque(maxlen=self.max_len),
49 | Mode.TEST: deque(maxlen=self.max_len),
50 | }
51 |
52 | def clear(self):
53 | """Clear all queues"""
54 | for mode, queue in self.log_media.items():
55 | queue.clear()
56 |
57 | @rank_zero_only
58 | def append(self, data: Any, mode: Mode):
59 | """Add a batch of data to a queue. Mode selects train/val/test queue"""
60 | self.log_media[mode].append(data)
61 |
62 | @rank_zero_only
63 | def fetch(self, mode: Mode) -> List[Any]:
64 | """Fetch all the batches available in a queue. Empties the selected queue"""
65 | data_r = []
66 | while len(self.log_media[mode]) > 0:
67 | data_r.append(self.log_media[mode].popleft())
68 |
69 | return data_r
70 |
71 | def len(self, mode: Mode) -> int:
72 | """Get the number of elements in a queue"""
73 | return len(self.log_media[mode])
74 |
75 |
76 | class LogMedia(Callback):
77 | r"""Logs model output images and other media to weights and biases
78 |
79 | This callback requires adding an attribute to the LightningModule called ``self.log_media``. This is a circular
80 | queue that holds the latest N batches. This callback fetches the latest data from the queue for logging.
81 |
82 | Usage:
83 | import pytorch_lightning as pl
84 |
85 | class MyModel(pl.LightningModule):
86 | self.log_media: LogMediaQueue = LogMediaQueue(max_len)
87 |
88 | trainer = pl.Trainer(callbacks=[LogMedia()])
89 |
90 | Args:
91 | cfg (omegaconf.DictConfig, optional): The hydra cfg file given to the run. If passed, it will be saved to the
92 | logs dir.
93 | period_epoch (int): If > 0, log every N epochs
94 | period_step (int): If > 0, log every N steps (i.e. batches)
95 | max_samples (int): Max number of data samples to log
96 | save_to_disk (bool): If True, save results to disk
97 | save_latest_only (bool): If True, will overwrite prev results at each period.
98 | exp_dir (str or Path): Path to directory where results will be saved
99 | verbose (bool): verbosity mode. Default: ``True``.
100 | """
101 |
102 | SUPPORTED_LOGGERS = [pl_loggers.WandbLogger]
103 |
104 | def __init__(
105 | self,
106 | cfg: DictConfig = None,
107 | max_samples: int = 10,
108 | period_epoch: int = 1,
109 | period_step: int = 0,
110 | save_to_disk: bool = True,
111 | save_latest_only: bool = True,
112 | exp_dir: Optional[str] = None,
113 | verbose: bool = True,
114 | ):
115 | super().__init__()
116 | self.cfg = cfg
117 | self.max_samples = max_samples
118 | self.period_epoch = period_epoch
119 | self.period_step = period_step
120 | self.save_to_disk = save_to_disk
121 | self.save_latest_only = save_latest_only
122 | self.verbose = verbose
123 | self.valid_logger = False
124 |
125 | try:
126 | self.exp_dir = Path(exp_dir) if self.save_to_disk else None
127 | except TypeError as e:
128 | raise ValueError(f"Invalid exp_dir: {exp_dir}. \n{e}")
129 |
130 | if self.cfg is not None and not OmegaConf.is_config(self.cfg):
131 | raise ValueError(f"Config file not of type {DictConfig}. Given: {type(cfg)}")
132 |
133 | # Project-specific fields
134 | self.class_labels_lapa = {
135 | 0: "background",
136 | 1: "skin",
137 | 2: "eyebrow_left",
138 | 3: "eyebrow_right",
139 | 4: "eye_left",
140 | 5: "eye_right",
141 | 6: "nose",
142 | 7: "lip_upper",
143 | 8: "inner_mouth",
144 | 9: "lip_lower",
145 | 10: "hair",
146 | }
147 |
148 | def setup(self, trainer, pl_module, stage: str):
149 | # This callback requires a ``.log_media`` attribute in LightningModule
150 | req_attr = "log_media"
151 | if not hasattr(pl_module, req_attr):
152 | raise AttributeError(
153 | f"{pl_module.__class__.__name__}.{req_attr} not found. The {LogMedia.__name__} "
154 | f"callback requires the LightningModule to have the {req_attr} attribute."
155 | )
156 | if not isinstance(pl_module.log_media, LogMediaQueue):
157 | raise AttributeError(f"{pl_module.__class__.__name__}.{req_attr} must be of type {LogMediaQueue.__name__}")
158 |
159 | if self.verbose:
160 | pl_module.print(
161 | f"Initializing Callback {LogMedia.__name__}. "
162 | f"Logging to disk: {self.exp_dir if self.save_to_disk else False}"
163 | )
164 |
165 | self._create_log_dir()
166 | self.valid_logger = True if self._logger_is_supported(trainer) else False
167 |
168 | # Save copy of config to logger
169 | if self.valid_logger:
170 | if isinstance(trainer.logger, pl_loggers.WandbLogger):
171 | OmegaConf.save(self.cfg, Path(trainer.logger.experiment.dir) / "train.yaml")
172 | trainer.logger.experiment.save("*.yaml")
173 | else:
174 | raise NotImplementedError
175 |
176 | def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx):
177 | if self._should_log_step(trainer, batch_idx):
178 | self._log_results(trainer, pl_module, Mode.TRAIN, batch_idx)
179 |
180 | def on_validation_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx):
181 | if self._should_log_step(trainer, batch_idx):
182 | self._log_results(trainer, pl_module, Mode.VAL, batch_idx)
183 |
184 | def on_test_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx):
185 | if self._should_log_step(trainer, batch_idx):
186 | self._log_results(trainer, pl_module, Mode.TEST, batch_idx)
187 |
188 | def on_train_epoch_end(self, trainer, pl_module, outputs):
189 | if self._should_log_epoch(trainer):
190 | self._log_results(trainer, pl_module, Mode.TRAIN)
191 |
192 | def on_validation_epoch_end(self, trainer, pl_module):
193 | if self._should_log_epoch(trainer):
194 | self._log_results(trainer, pl_module, Mode.VAL)
195 |
196 | def on_test_epoch_end(self, trainer, pl_module):
197 | if self._should_log_epoch(trainer):
198 | self._log_results(trainer, pl_module, Mode.TEST)
199 |
200 | def _log_results(self, trainer, pl_module, mode: Mode, batch_idx: Optional[int] = None):
201 | pred_data = self._get_preds_from_lightningmodule(pl_module, mode)
202 | self._save_media_to_logger(trainer, pred_data, mode)
203 | self._save_media_to_disk(trainer, pred_data, mode, batch_idx)
204 |
205 | def _should_log_epoch(self, trainer):
206 | if trainer.running_sanity_check:
207 | return False
208 | if self.period_epoch < 1 or ((trainer.current_epoch + 1) % self.period_epoch != 0):
209 | return False
210 | return True
211 |
212 | def _should_log_step(self, trainer, batch_idx):
213 | if trainer.running_sanity_check:
214 | return False
215 | if self.period_step < 1 or ((batch_idx + 1) % self.period_step != 0):
216 | return False
217 | return True
218 |
219 | @rank_zero_only
220 | def _create_log_dir(self):
221 | if not self.save_to_disk:
222 | return
223 |
224 | self.exp_dir.mkdir(parents=True, exist_ok=True)
225 | if self.cfg is not None:
226 | fname = self.exp_dir / CONFIG_FNAME
227 | OmegaConf.save(config=self.cfg, f=fname, resolve=True)
228 |
229 | @rank_zero_only
230 | def _logger_is_supported(self, trainer):
231 | """This callback only works with wandb logger"""
232 | for logger_type in self.SUPPORTED_LOGGERS:
233 | if isinstance(trainer.logger, logger_type):
234 | return True
235 |
236 | rank_zero_warn(
237 | f"Unsupported logger: '{trainer.logger}', will not log any media to logger this run."
238 | f" Supported loggers: {[sup_log.__name__ for sup_log in self.SUPPORTED_LOGGERS]}."
239 | )
240 | return False
241 |
242 | @rank_zero_only
243 | def _get_preds_from_lightningmodule(self, pl_module, mode: Mode) -> Optional[PredData]:
244 | """Fetch latest N batches from the data queue in LightningModule.
245 | Process the tensors as required (example, convert to numpy arrays and scale)
246 | """
247 | if pl_module.log_media.len(mode) == 0: # Empty queue
248 | rank_zero_warn(f"\nEmpty LogMediaQueue! Mode: {mode}. Epoch: {pl_module.trainer.current_epoch}")
249 | return None
250 |
251 | media_data = pl_module.log_media.fetch(mode)
252 |
253 | inputs = torch.cat([x["inputs"] for x in media_data], dim=0)
254 | labels = torch.cat([x["labels"] for x in media_data], dim=0)
255 | preds = torch.cat([x["preds"] for x in media_data], dim=0)
256 |
257 | # Limit the num of samples and convert to numpy
258 | inputs = inputs[: self.max_samples].detach().cpu().numpy().transpose((0, 2, 3, 1))
259 | inputs = (inputs * 255).astype(np.uint8)
260 | labels = labels[: self.max_samples].detach().cpu().numpy().astype(np.uint8)
261 | preds = preds[: self.max_samples].detach().cpu().numpy().astype(np.uint8)
262 |
263 | out = PredData(inputs=inputs, labels=labels, preds=preds)
264 |
265 | return out
266 |
267 | @rank_zero_only
268 | def _save_media_to_disk(self, trainer, pred_data: Optional[PredData], mode: Mode, batch_idx: Optional[int] = None):
269 | """For a given mode (train/val/test), save the results to disk"""
270 | if not self.save_to_disk:
271 | return
272 | if pred_data is None: # Empty queue
273 | rank_zero_warn(f"Empty queue! Mode: {mode}")
274 | return
275 |
276 | # Create output filename
277 | if self.save_latest_only:
278 | output_filename = f"results.{mode.name.lower()}.png"
279 | else:
280 | if batch_idx is None:
281 | output_filename = f"results-epoch{trainer.current_epoch}.{mode.name.lower()}.png"
282 | else:
283 | output_filename = f"results-epoch{trainer.current_epoch}-step{batch_idx}.{mode.name.lower()}.png"
284 | output_filename = self.exp_dir / output_filename
285 |
286 | # Get the latest batches from the data queue in LightningModule
287 | inputs, labels, preds = pred_data.inputs, pred_data.labels, pred_data.preds
288 |
289 | # Colorize labels and predictions
290 | label2rgb = LabelToRGB()
291 | labels_rgb = [label2rgb.map_color_palette(lbl, Palette.LAPA) for lbl in labels]
292 | preds_rgb = [label2rgb.map_color_palette(pred, Palette.LAPA) for pred in preds]
293 | inputs_l = [ipt for ipt in inputs]
294 |
295 | # Create collage of results
296 | results_l = []
297 | # Combine each inp/lbl/pred triplet into a single image
298 | for inp, lbl, pred in zip(inputs_l, labels_rgb, preds_rgb):
299 | res_combined = np.concatenate((inp, lbl, pred), axis=1)
300 | results_l.append(res_combined)
301 | # Create grid from combined imgs
302 | n_imgs = len(results_l)
303 | n_cols = 4 # Fix num of columns
304 | n_rows = int(math.ceil(n_imgs / n_cols))
305 | img_h, img_w, _ = results_l[0].shape
306 | grid_results = np.zeros((img_h * n_rows, img_w * n_cols, 3), dtype=np.uint8)
307 | for idx_result, res in enumerate(results_l):
308 | idy, idx = divmod(idx_result, n_cols)  # row/col position of this sample in the grid
309 | grid_results[idy * img_h : (idy + 1) * img_h, idx * img_w : (idx + 1) * img_w, :] = res
310 |
311 | # Save collage
312 | if not cv2.imwrite(str(output_filename), cv2.cvtColor(grid_results, cv2.COLOR_RGB2BGR)):
313 | rank_zero_warn(f"Error in writing image: {output_filename}")
314 |
315 | @rank_zero_only
316 | def _save_media_to_logger(self, trainer, pred_data: Optional[PredData], mode: Mode):
317 | """Log images to wandb at the end of a batch. Steps are common for train/val/test"""
318 | if not self.valid_logger:
319 | return
320 | if pred_data is None: # Empty queue
321 | return
322 |
323 | if isinstance(trainer.logger, pl_loggers.WandbLogger):
324 | self._log_media_to_wandb(trainer, pred_data, mode)
325 | else:
326 | raise NotImplementedError(f"No method to log media to logger: {trainer.logger}")
327 |
328 | def _log_media_to_wandb(self, trainer, pred_data: Optional[PredData], mode: Mode):
329 | # Get the latest batches from the data queue in LightningModule
330 | inputs, labels, preds = pred_data.inputs, pred_data.labels, pred_data.preds
331 |
332 | # Create wandb Image for logging
333 | mask_list = []
334 | for img, lbl, pred in zip(inputs, labels, preds):
335 | mask_img = wandb.Image(
336 | img,
337 | masks={
338 | "predictions": {"mask_data": pred, "class_labels": self.class_labels_lapa},
339 | "groud_truth": {"mask_data": lbl, "class_labels": self.class_labels_lapa},
340 | },
341 | )
342 | mask_list.append(mask_img)
343 |
344 | wandb_log_label = f"{mode.name.title()}/Predictions"
345 | trainer.logger.experiment.log({wandb_log_label: mask_list}, commit=False)
346 |
--------------------------------------------------------------------------------
/seg_lapa/config/callbacks/disabled.yaml:
--------------------------------------------------------------------------------
1 | # @package _group_
2 | name: disabled
3 |
--------------------------------------------------------------------------------
/seg_lapa/config/callbacks/standard.yaml:
--------------------------------------------------------------------------------
1 | # @package _group_
2 | name: standard
3 |
4 | early_stopping:
5 | monitor: "Val/mIoU"
6 | min_delta: 0.0005
7 | patience: 10
8 | mode: "max"
9 | verbose: false
10 |
11 | checkpoints:
12 | filename: "best" # PL Default: "{epoch}-{step}". `=` in filename can cause errors when parsing cli overrides.
13 | save_last: true
14 | save_top_k: 1
15 | monitor: "Val/mIoU"
16 | mode: "max"
17 | period: 10
18 | verbose: false
19 |
20 | log_media:
21 | max_samples: 10
22 | period_epoch: 1
23 | period_step: 0
24 | save_to_disk: true
25 | save_latest_only: true
26 | verbose: true
27 |
--------------------------------------------------------------------------------
/seg_lapa/config/callbacks/with_lr_monitor.yaml:
--------------------------------------------------------------------------------
1 | # @package _group_
2 | name: standard
3 |
4 | early_stopping:
5 | monitor: "Val/mIoU"
6 | min_delta: 0.0005
7 | patience: 10
8 | mode: "max"
9 | verbose: false
10 |
11 | checkpoints:
12 | filename: "best" # PL Default: "{epoch}-{step}". `=` in filename can cause errors when parsing cli overrides.
13 | save_last: true
14 | save_top_k: 1
15 | monitor: "Val/mIoU"
16 | mode: "max"
17 | period: 10
18 | verbose: false
19 |
20 | log_media:
21 | max_samples: 10
22 | period_epoch: 1
23 | period_step: 0
24 | save_to_disk: true
25 | save_latest_only: true
26 | verbose: true
27 |
28 | lr_monitor:
29 | # set to epoch or step to log lr of all optimizers at the same interval, set to None to log at individual
30 | # interval according to the interval key of each scheduler.
31 | logging_interval: "step"
32 | log_momentum: false # If true, log the momentum values of the optimizer, if present
33 |
--------------------------------------------------------------------------------
/seg_lapa/config/dataset/lapa.yaml:
--------------------------------------------------------------------------------
1 | # @package _group_
2 | name: lapa
3 | data_dir: "/home/gaini/lightning/LaPa"
4 | batch_size: 16
5 | num_workers: 4
6 | resize_h: 256
7 | resize_w: 256
8 |
--------------------------------------------------------------------------------
/seg_lapa/config/load_weights/disabled.yaml:
--------------------------------------------------------------------------------
1 | # @package _group_
2 | name: disabled
3 |
--------------------------------------------------------------------------------
/seg_lapa/config/load_weights/pretrain.yaml:
--------------------------------------------------------------------------------
1 | # @package _group_
2 | name: load_weights
3 |
4 | # Use this to load pre-train weights (such as model pretrained on pascal dataset)
5 | path: null
6 |
--------------------------------------------------------------------------------
/seg_lapa/config/logger/disabled.yaml:
--------------------------------------------------------------------------------
1 | # @package _group_
2 | name: disabled
3 |
--------------------------------------------------------------------------------
/seg_lapa/config/logger/wandb.yaml:
--------------------------------------------------------------------------------
1 | # @package _group_
2 | name: wandb
3 | entity: cleargrasp2
4 | project: segmentation
5 | run_name: null # If None, will generate random name
6 | run_id: null # Pass run_id to resume logging to that run
7 |
--------------------------------------------------------------------------------
/seg_lapa/config/logger/wandb_debug.yaml:
--------------------------------------------------------------------------------
1 | # @package _group_
2 | name: wandb
3 | entity: cleargrasp2
4 | project: shrek_debug
5 | run_name: null # If None, will generate random name
6 | run_id: null # Pass run_id to resume logging to that run
7 |
--------------------------------------------------------------------------------
/seg_lapa/config/model/deeplabv3.yaml:
--------------------------------------------------------------------------------
1 | # @package _group_
2 | name: deeplabv3
3 | backbone: "drn" # Possible values for backbone: ['xception', 'resnet', 'drn']
4 | num_classes: 11
5 | output_stride: 8 # Not used with 'drn' backbone
6 | sync_bn: False # This enables custom batchnorm code that syncs across gpus.
7 | #enable_amp: False # Should always be false, since PL takes care of 16bit training
8 |
--------------------------------------------------------------------------------
/seg_lapa/config/optimizer/adam.yaml:
--------------------------------------------------------------------------------
1 | # @package _group_
2 | name: adam
3 | lr: 1e-3
4 | weight_decay: 0.0005
5 |
--------------------------------------------------------------------------------
/seg_lapa/config/optimizer/sgd.yaml:
--------------------------------------------------------------------------------
1 | # @package _group_
2 | name: sgd
3 | lr: 1e-6
4 | momentum: 0.9
5 | weight_decay: 0.0
6 | nesterov: False
7 |
--------------------------------------------------------------------------------
/seg_lapa/config/scheduler/cyclic.yaml:
--------------------------------------------------------------------------------
1 | # @package _group_
2 | # Sets the learning rate of each parameter group according to cyclical learning rate policy (CLR).
3 | # The policy cycles the learning rate between two boundaries with a constant frequency
4 | name: cyclic
5 |
6 | # Initial learning rate which is the lower boundary in the cycle for each parameter group.
7 | base_lr: 0.0001
8 |
9 | # Upper learning rate boundaries in the cycle for each parameter group.
10 | max_lr: 0.01
11 |
12 | # Number of training iterations in the increasing half of a cycle. Should be set to 2-10 times the number of iterations in an epoch
13 | step_size_up: 20
14 |
15 | # Number of training iterations in the decreasing half of a cycle. If step_size_down is None, it is set to step_size_up.
16 | step_size_down: null
17 |
--------------------------------------------------------------------------------
/seg_lapa/config/scheduler/disabled.yaml:
--------------------------------------------------------------------------------
1 | # @package _group_
2 | # Disabled learning rate scheduler.
3 | name: disabled
4 |
--------------------------------------------------------------------------------
/seg_lapa/config/scheduler/plateau.yaml:
--------------------------------------------------------------------------------
1 | # @package _group_
2 | # Reduce learning rate when a metric has no improvement for a ‘patience’ number of epochs.
3 | # Models often benefit from reducing the learning rate by a factor of 2-10 once learning stagnates.
4 | name: plateau
5 |
6 | # Factor by which the learning rate will be reduced. new_lr = lr * factor.
7 | factor: 0.8
8 |
9 | # Number of epochs with no improvement after which learning rate will be reduced.
10 | patience: 25
11 |
12 | # A lower bound on the learning rate of all param groups or each group respectively
13 | min_lr: 1e-7
14 |
15 | # One of min, max. In min mode, lr is reduced when the metric stops decreasing; in max mode, when it stops increasing.
16 | mode: "max"
17 |
18 | # Threshold for measuring the new optimum, to only focus on significant changes.
19 | threshold: 1e-4
20 |
21 | # Number of epochs to wait before resuming normal operation after lr has been reduced.
22 | cooldown: 0
23 |
24 | # Minimal decay applied to lr. If the difference between new and old lr is smaller than eps,
25 | # the update is ignored
26 | eps: 1e-8
27 |
28 | # If True, prints a message to stdout for each update.
29 | verbose: False
30 |
--------------------------------------------------------------------------------
/seg_lapa/config/scheduler/poly.yaml:
--------------------------------------------------------------------------------
1 | # @package _group_
2 | # The base lr decreases smoothly up to max_iter as a polynomial function:
3 | # lr = base_lr * (1 - current_iter / max_iter) ^ pow_factor
4 | name: poly
5 |
6 | # Maximum number of iterations for which training will run
7 | max_iter: 500
8 |
9 | # The polynomial factor in the formula above
10 | pow_factor: 0.9
11 |
--------------------------------------------------------------------------------
/seg_lapa/config/scheduler/step.yaml:
--------------------------------------------------------------------------------
1 | # @package _group_
2 | # Decays the learning rate of each parameter group by gamma every step_size epochs.
3 | name: step
4 |
5 | # Period of learning rate decay.
6 | step_size: 30
7 |
8 | # Multiplicative factor of learning rate decay.
9 | gamma: 0.1
10 |
--------------------------------------------------------------------------------
/seg_lapa/config/train.yaml:
--------------------------------------------------------------------------------
1 | random_seed: 0 # If None, seeds not set. If int, uses value to seed.
2 | logs_root_dir: "./" # Where to save logs and checkpoints.
3 |
4 | defaults:
5 | - dataset: lapa
6 | - model: deeplabv3
7 | - optimizer: adam
8 | - trainer: standard
9 | - scheduler: disabled
10 | - logger: wandb
11 | - callbacks: standard
12 | - load_weights: disabled
13 |
14 | # To disable any .log files
15 | - hydra/job_logging: disabled
16 | - hydra/hydra_logging: disabled
17 |
18 | hydra:
19 | output_subdir: null # Disable saving of config files. We'll do that ourselves.
20 | run:
21 | dir: . # Set working dir to current directory
22 |
--------------------------------------------------------------------------------
/seg_lapa/config/trainer/debug.yaml:
--------------------------------------------------------------------------------
1 | # @package _group_
2 | name: trainer
3 |
4 | gpus: 1
5 | accelerator: "ddp"
6 | precision: 16
7 |
8 | max_epochs: 3
9 | resume_from_checkpoint: null
10 | log_every_n_steps: 1
11 |
12 | # For deterministic runs
13 | benchmark: False
14 | deterministic: True
15 |
16 | # Limit batches for debugging
17 | fast_dev_run: False
18 | overfit_batches: 0.0
19 | limit_train_batches: 2
20 | limit_val_batches: 2
21 | limit_test_batches: 2
22 |
--------------------------------------------------------------------------------
/seg_lapa/config/trainer/standard.yaml:
--------------------------------------------------------------------------------
1 | # @package _group_
2 | name: trainer
3 |
4 | gpus: 1 # Denotes the number of gpus to use. Set CUDA_VISIBLE_DEVICES env var to control which gpus are used.
5 | accelerator: "ddp"
6 | precision: 16
7 |
8 | max_epochs: 100
9 | resume_from_checkpoint: null
10 | log_every_n_steps: 1
11 |
12 | # For deterministic runs
13 | benchmark: False # If true enables cudnn.benchmark.
14 | deterministic: True # If true enables cudnn.deterministic.
15 |
16 | # Limit batches for debugging
17 | fast_dev_run: False # If True, runs 1 batch of train, val and test to find any bugs (ie: a sort of unit test).
18 | overfit_batches: 0.0 # Overfit on subset of training data. Use the same as val/test set. (floats = percent, int = num_batches). Warn: 1 will be cast to 1.0.
19 | limit_train_batches: 1.0 # How much of training dataset to check (floats = percent, int = num_batches). Warn: 1 will be cast to 1.0.
20 | limit_val_batches: 1.0
21 | limit_test_batches: 1.0
22 |
--------------------------------------------------------------------------------
/seg_lapa/config_parse/callbacks_available.py:
--------------------------------------------------------------------------------
1 | """Dataclasses just to initialize and return Callback objects"""
2 | from typing import Optional
3 |
4 | from omegaconf import DictConfig
5 | from pydantic.dataclasses import dataclass
6 | from pytorch_lightning.callbacks import Callback, EarlyStopping, LearningRateMonitor, ModelCheckpoint
7 |
8 | from seg_lapa.callbacks.log_media import LogMedia
9 | from seg_lapa.config_parse.conf_utils import asdict_filtered
10 |
11 |
12 | @dataclass(frozen=True)
13 | class EarlyStopConf:
14 | monitor: str
15 | min_delta: float
16 | patience: int
17 | mode: str
18 | verbose: bool = False
19 |
20 | def get_callback(self) -> Callback:
21 | return EarlyStopping(**asdict_filtered(self))
22 |
23 |
24 | @dataclass(frozen=True)
25 | class CheckpointConf:
26 | filename: Optional[str]
27 | monitor: Optional[str]
28 | mode: str
29 | save_last: Optional[bool]
30 | period: int
31 | save_top_k: Optional[int]
32 | verbose: bool = False
33 |
34 | def get_callback(self, logs_dir) -> Callback:
35 | return ModelCheckpoint(dirpath=logs_dir, **asdict_filtered(self))
36 |
37 |
38 | @dataclass(frozen=True)
39 | class LogMediaConf:
40 | max_samples: int
41 | period_epoch: int
42 | period_step: int
43 | save_to_disk: bool
44 | save_latest_only: bool
45 | verbose: bool = False
46 |
47 | def get_callback(self, exp_dir: str, cfg: DictConfig) -> Callback:
48 | return LogMedia(exp_dir=exp_dir, cfg=cfg, **asdict_filtered(self))
49 |
50 |
51 | @dataclass(frozen=True)
52 | class LearningRateMonitorConf:
53 | logging_interval: Optional[str]
54 | log_momentum: bool = False
55 |
56 | def get_callback(self) -> Callback:
57 | return LearningRateMonitor(**asdict_filtered(self))
58 |
--------------------------------------------------------------------------------
/seg_lapa/config_parse/callbacks_conf.py:
--------------------------------------------------------------------------------
1 | from abc import ABC, abstractmethod
2 | from typing import Dict, List, Optional
3 |
4 | from omegaconf import DictConfig
5 | from pydantic.dataclasses import dataclass
6 | from pytorch_lightning.callbacks import Callback
7 |
8 | from seg_lapa.config_parse.callbacks_available import (
9 | CheckpointConf,
10 | EarlyStopConf,
11 | LearningRateMonitorConf,
12 | LogMediaConf,
13 | )
14 | from seg_lapa.config_parse.conf_utils import validate_config_group_generic
15 |
16 |
17 | # The Callbacks config cannot be directly initialized because it contains sub-entries for each callback, each
18 | # of which describes a separate class.
19 | # For each of the callbacks, we define a dataclass and use them to init the list of callbacks
20 |
21 |
22 | @dataclass(frozen=True)
23 | class CallbacksConf(ABC):
24 | name: str
25 |
26 | @abstractmethod
27 | def get_callbacks_list(self, *args) -> List:
28 | return []
29 |
30 |
31 | @dataclass(frozen=True)
32 | class DisabledCallbacksConf(CallbacksConf):
33 | def get_callbacks_list(self, *args) -> List:
34 | return []
35 |
36 |
37 | @dataclass(frozen=True)
38 | class StandardCallbacksConf(CallbacksConf):
39 | """Get a dictionary of all the callbacks."""
40 |
41 | early_stopping: Optional[Dict] = None
42 | checkpoints: Optional[Dict] = None
43 | log_media: Optional[Dict] = None
44 | lr_monitor: Optional[Dict] = None
45 |
46 | def get_callbacks_list(self, exp_dir: str, cfg: DictConfig) -> List[Callback]:
47 | """Get all available callbacks and the Callback Objects in list
48 | If a callback's entry is not present in the config file, it'll not be output in the list
49 | """
50 | callbacks_list = []
51 | if self.early_stopping is not None:
52 | early_stop = EarlyStopConf(**self.early_stopping).get_callback()
53 | callbacks_list.append(early_stop)
54 |
55 | if self.checkpoints is not None:
56 | checkpoint = CheckpointConf(**self.checkpoints).get_callback(exp_dir)
57 | callbacks_list.append(checkpoint)
58 |
59 | if self.log_media is not None:
60 | log_media = LogMediaConf(**self.log_media).get_callback(exp_dir, cfg)
61 | callbacks_list.append(log_media)
62 |
63 | if self.lr_monitor is not None:
64 | lr_monitor = LearningRateMonitorConf(**self.lr_monitor).get_callback()
65 | callbacks_list.append(lr_monitor)
66 |
67 | return callbacks_list
68 |
69 |
70 | valid_names = {
71 | "disabled": DisabledCallbacksConf,
72 | "standard": StandardCallbacksConf,
73 | }
74 |
75 |
76 | def validate_config_group(cfg_subgroup: DictConfig) -> CallbacksConf:
77 | validated_dataclass = validate_config_group_generic(
78 | cfg_subgroup, dataclass_dict=valid_names, config_category="callback"
79 | )
80 | return validated_dataclass
81 |
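82 | 
83 | if __name__ == "__main__":
84 |     # Usage sketch: entries present in the config become callbacks; absent entries
85 |     # (checkpoints, log_media, lr_monitor here) are skipped. Values are illustrative.
86 |     conf = StandardCallbacksConf(
87 |         name="standard",
88 |         early_stopping={"monitor": "val_loss", "min_delta": 0.0, "patience": 10, "mode": "min"},
89 |     )
90 |     print(conf.get_callbacks_list("logs", DictConfig({})))  # [EarlyStopping]
91 | 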
--------------------------------------------------------------------------------
/seg_lapa/config_parse/conf_utils.py:
--------------------------------------------------------------------------------
1 | import dataclasses
2 | from typing import Dict, Optional, Sequence
3 |
4 | from omegaconf import DictConfig, OmegaConf
5 |
6 |
7 | def asdict_filtered(obj, remove_keys: Optional[Sequence[str]] = None) -> Dict:
8 | """Returns the attributes of a dataclass in the form of a dict, with unwanted attributes removed.
9 |     Each config group has a 'name' entry, which identifies the node that was chosen within
10 |     the config group (e.g. config group = optimizer, nodes = adam, sgd).
11 |     However, 'name' is not an argument to any of the classes being initialized, so it must be removed.
12 |
13 | Args:
14 |         obj: The dataclass whose attributes will be converted to a dict
15 | remove_keys: The keys to remove from the dict. The default is ['name'].
16 | """
17 | if not dataclasses.is_dataclass(obj):
18 | raise ValueError(f"Not a dataclass/dataclass instance")
19 |
20 | if remove_keys is None:
21 | remove_keys = ["name"]
22 |
23 | # Clean the arguments
24 | args = dataclasses.asdict(obj)
25 | for key in remove_keys:
26 | if key in args:
27 | args.pop(key)
28 |
29 | return args
30 |
31 |
32 | def validate_config_group_generic(cfg_group: DictConfig, dataclass_dict: Dict, config_category: str = "option"):
33 | """Use a hydra config group to initialize a pydantic dataclass. Initializing it validates the data.
34 | Each of our config groups has a name parameter, which is used to map to valid dataclasses for validation.
35 |
36 | Pydantic will force the parameters to the desired datatype and will throw errors if the config
37 | cannot be cast to the dataclass members.
38 |
39 | Args:
40 | cfg_group: The config group extracted from the hydra config.
41 | dataclass_dict: A dict containing the mapping from 'name' entry in config files to matching
42 | pydantic dataclasses for validation.
43 | config_category: For pretty print statements. Configure the name of the config group when throwing error.
44 |
45 | Raises:
46 | ValueError: If the name parameter does not match to any of the valid options
47 | """
48 | if not OmegaConf.is_config(cfg_group):
49 | raise ValueError(f"Given config not an OmegaConf config. Got: {type(cfg_group)}")
50 |
51 | # Get the "name" entry in config
52 | name = cfg_group.name
53 | if name is None:
54 | raise KeyError(
55 | f"The given config does not contain a 'name' entry. Cannot map to a dataclass.\n"
56 | f" Config:\n {OmegaConf.to_yaml(cfg_group)}"
57 | )
58 |
59 | # Convert hydra config to dict - This dict contains the arguments to init dataclass
60 | cfg_asdict = OmegaConf.to_container(cfg_group, resolve=True)
61 |
62 | # Get the dataclass to init from the mapping. Init the dataclass using hydra config
63 | try:
64 | dataclass_obj = dataclass_dict[name](**cfg_asdict)
65 | except KeyError:
66 | raise ValueError(
67 | f"Invalid Config: '{cfg_group.name}' is not a valid {config_category}. "
68 | f"Valid Options: {list(dataclass_dict.keys())}"
69 | )
70 |
71 | return dataclass_obj
72 |
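73 | 
74 | if __name__ == "__main__":
75 |     # Usage sketch with a throwaway dataclass: 'name' drives the dispatch in
76 |     # validate_config_group_generic() and is then stripped by asdict_filtered(), so the
77 |     # remaining keys can be passed directly as keyword arguments to the real class.
78 |     @dataclasses.dataclass
79 |     class _DemoConf:
80 |         name: str
81 |         lr: float
82 | 
83 |     print(asdict_filtered(_DemoConf(name="adam", lr=1e-3)))  # {'lr': 0.001}
84 | 
85 |     demo_cfg = OmegaConf.create({"name": "adam", "lr": 1e-3})
86 |     print(validate_config_group_generic(demo_cfg, dataclass_dict={"adam": _DemoConf}))
87 | 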
--------------------------------------------------------------------------------
/seg_lapa/config_parse/dataset_conf.py:
--------------------------------------------------------------------------------
1 | from abc import ABC, abstractmethod
2 |
3 | import pytorch_lightning as pl
4 | from omegaconf import DictConfig
5 | from pydantic.dataclasses import dataclass
6 |
7 | from seg_lapa.config_parse.conf_utils import asdict_filtered, validate_config_group_generic
8 | from seg_lapa.datasets.lapa import LaPaDataModule
9 |
10 |
11 | @dataclass(frozen=True)
12 | class DatasetConf(ABC):
13 | name: str
14 |
15 | @abstractmethod
16 | def get_datamodule(self) -> pl.LightningDataModule:
17 | pass
18 |
19 |
20 | @dataclass(frozen=True)
21 | class LapaConf(DatasetConf):
22 | data_dir: str
23 | batch_size: int
24 | num_workers: int
25 | resize_h: int
26 | resize_w: int
27 |
28 | def get_datamodule(self) -> LaPaDataModule:
29 | return LaPaDataModule(**asdict_filtered(self))
30 |
31 |
32 | valid_names = {"lapa": LapaConf}
33 |
34 |
35 | def validate_config_group(cfg_subgroup: DictConfig) -> DatasetConf:
36 | validated_dataclass = validate_config_group_generic(
37 | cfg_subgroup, dataclass_dict=valid_names, config_category="dataset"
38 | )
39 | return validated_dataclass
40 |
--------------------------------------------------------------------------------
/seg_lapa/config_parse/load_weights_conf.py:
--------------------------------------------------------------------------------
1 | from abc import ABC
2 | from typing import Optional
3 |
4 | from omegaconf import DictConfig
5 | from pydantic.dataclasses import dataclass
6 |
7 | from seg_lapa.config_parse.conf_utils import validate_config_group_generic
8 |
9 |
10 | @dataclass(frozen=True)
11 | class LoadWeightsConf(ABC):
12 | name: str
13 | path: Optional[str] = None
14 |
15 |
16 | valid_names = {
17 | "disabled": LoadWeightsConf,
18 | "load_weights": LoadWeightsConf,
19 | }
20 |
21 |
22 | def validate_config_group(cfg_subgroup: DictConfig) -> LoadWeightsConf:
23 | validated_dataclass = validate_config_group_generic(
24 | cfg_subgroup, dataclass_dict=valid_names, config_category="load_weights"
25 | )
26 | return validated_dataclass
27 |
--------------------------------------------------------------------------------
/seg_lapa/config_parse/logger_conf.py:
--------------------------------------------------------------------------------
1 | from abc import ABC, abstractmethod
2 | from pathlib import Path
3 | from typing import Optional
4 |
5 | import wandb
6 | from omegaconf import DictConfig, OmegaConf
7 | from pydantic.dataclasses import dataclass
8 | from pytorch_lightning import loggers as pl_loggers
9 |
10 | from seg_lapa.config_parse.conf_utils import asdict_filtered, validate_config_group_generic
11 |
12 |
13 | @dataclass
14 | class LoggerConf(ABC):
15 | name: str
16 |
17 | @abstractmethod
18 | def get_logger(self, *args):
19 | pass
20 |
21 | @abstractmethod
22 | def get_run_id(self, *args):
23 | """Loggers such as WandB generate a unique run id that can be used to resume runs"""
24 | pass
25 |
26 |
27 | @dataclass
28 | class DisabledLoggerConf(LoggerConf):
29 | @staticmethod
30 | def get_logger(*args):
31 | return False
32 |
33 | @staticmethod
34 | def get_run_id():
35 | return None
36 |
37 |
38 | @dataclass
39 | class WandbConf(LoggerConf):
40 | """Weights and Biases. Ref: wandb.com"""
41 |
42 | entity: str
43 | project: str
44 | run_name: Optional[str]
45 | run_id: Optional[str] = None # Pass run_id to resume logging to that run.
46 |
47 | def get_logger(self, cfg: DictConfig, save_dir: Path) -> pl_loggers.WandbLogger:
48 | """Returns the Weights and Biases (wandb) logger object (really an wandb Run object)
49 | The run object corresponds to a single execution of the script and is returned from `wandb.init()`.
50 |
51 |         Args:
52 |             cfg: The entire config read from hydra, so that each run's config is logged to wandb.
53 |             save_dir: Root dir in which to save wandb log files. The run id itself is taken from
54 |                 `self.run_id`; if set, logging resumes to that existing run.
55 |
56 | Returns:
57 |             pl_loggers.WandbLogger: wandb logger object. Can be used for logging.
58 | """
59 | # Some argument names to wandb are different from the attribute names of the class.
60 | # Pop the offending attributes before passing to init func.
61 | args_dict = asdict_filtered(self)
62 | run_name = args_dict.pop("run_name")
63 | run_id = args_dict.pop("run_id")
64 |
65 | # If `self.save_hyperparams()` is called in LightningModule, it will save the cfg passed as argument
66 | # cfg_dict = OmegaConf.to_container(cfg, resolve=True)
67 |
68 | wb_logger = pl_loggers.WandbLogger(name=run_name, id=run_id, save_dir=str(save_dir), **args_dict)
69 |
70 | return wb_logger
71 |
72 | def get_run_id(self):
73 | """If a run_id has been provided by the user, resume logging to that run.
74 | Otherwise a random run-id will be generated
75 | """
76 | if self.run_id is None:
77 | self.run_id = wandb.util.generate_id()
78 |
79 | return self.run_id
80 |
81 |
82 | valid_names = {
83 | "wandb": WandbConf,
84 | "disabled": DisabledLoggerConf,
85 | }
86 |
87 |
88 | def validate_config_group(cfg_subgroup: DictConfig) -> LoggerConf:
89 | validated_dataclass = validate_config_group_generic(
90 | cfg_subgroup, dataclass_dict=valid_names, config_category="logger"
91 | )
92 | return validated_dataclass
93 |
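94 | 
95 | if __name__ == "__main__":
96 |     # Usage sketch (entity/project are placeholders): when no run_id is supplied,
97 |     # get_run_id() generates one with wandb and caches it, so repeated calls return the
98 |     # same id. This is what lets a resumed job keep logging to the same wandb run.
99 |     conf = WandbConf(name="wandb", entity="my-team", project="seg-lapa", run_name=None)
100 |     print(conf.get_run_id() == conf.get_run_id())  # True
101 | 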
--------------------------------------------------------------------------------
/seg_lapa/config_parse/model_conf.py:
--------------------------------------------------------------------------------
1 | from abc import ABC, abstractmethod
2 |
3 | import torch
4 | from omegaconf import DictConfig
5 | from pydantic.dataclasses import dataclass
6 |
7 | from seg_lapa.config_parse.conf_utils import asdict_filtered, validate_config_group_generic
8 | from seg_lapa.networks.deeplab.deeplab import DeepLab
9 |
10 |
11 | @dataclass(frozen=True)
12 | class ModelConf(ABC):
13 | name: str
14 | num_classes: int
15 |
16 | @abstractmethod
17 | def get_model(self):
18 | pass
19 |
20 |
21 | @dataclass(frozen=True)
22 | class Deeplabv3Conf(ModelConf):
23 | backbone: str
24 | output_stride: int
25 |     sync_bn: bool  # If True, use the bundled SynchronizedBatchNorm layers (PL can also sync batchnorm natively)
26 |     enable_amp: bool = False  # Should always be False, since PL takes care of 16-bit training
27 |
28 | def get_model(self) -> torch.nn.Module:
29 | return DeepLab(**asdict_filtered(self))
30 |
31 |
32 | valid_names = {
33 | "deeplabv3": Deeplabv3Conf,
34 | }
35 |
36 |
37 | def validate_config_group(cfg_subgroup: DictConfig) -> ModelConf:
38 | validated_dataclass = validate_config_group_generic(
39 | cfg_subgroup, dataclass_dict=valid_names, config_category="model"
40 | )
41 | return validated_dataclass
42 |
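43 | 
44 | # Illustrative usage (not executed here, since building DeepLab may download backbone
45 | # weights on first run). A validated conf turns into a model with:
46 | #   Deeplabv3Conf(name="deeplabv3", num_classes=11, backbone="resnet",
47 | #                 output_stride=16, sync_bn=False).get_model()
48 | 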
--------------------------------------------------------------------------------
/seg_lapa/config_parse/optimizer_conf.py:
--------------------------------------------------------------------------------
1 | from abc import ABC, abstractmethod
2 |
3 | import torch
4 | from omegaconf import DictConfig
5 | from pydantic.dataclasses import dataclass
6 |
7 | from seg_lapa.config_parse.conf_utils import asdict_filtered, validate_config_group_generic
8 |
9 |
10 | @dataclass(frozen=True)
11 | class OptimConf(ABC):
12 | name: str
13 |
14 | @abstractmethod
15 | def get_optimizer(self, model_params) -> torch.optim.Optimizer:
16 | pass
17 |
18 |
19 | @dataclass(frozen=True)
20 | class AdamConf(OptimConf):
21 | lr: float
22 | weight_decay: float
23 |
24 | def get_optimizer(self, model_params) -> torch.optim.Optimizer:
25 | return torch.optim.Adam(params=model_params, **asdict_filtered(self))
26 |
27 |
28 | @dataclass(frozen=True)
29 | class SgdConf(OptimConf):
30 | lr: float
31 | momentum: float
32 | weight_decay: float
33 | nesterov: bool
34 |
35 | def get_optimizer(self, model_params) -> torch.optim.Optimizer:
36 | return torch.optim.SGD(params=model_params, **asdict_filtered(self))
37 |
38 |
39 | valid_names = {"adam": AdamConf, "sgd": SgdConf}
40 |
41 |
42 | def validate_config_group(cfg_subgroup: DictConfig) -> OptimConf:
43 | validated_dataclass = validate_config_group_generic(
44 | cfg_subgroup, dataclass_dict=valid_names, config_category="optimizer"
45 | )
46 | return validated_dataclass
47 |
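48 | 
49 | if __name__ == "__main__":
50 |     # Usage sketch: the conf's fields map one-to-one onto torch.optim.Adam's kwargs
51 |     # after asdict_filtered() strips 'name'. The values are illustrative.
52 |     params = [torch.nn.Parameter(torch.zeros(1))]
53 |     print(AdamConf(name="adam", lr=1e-3, weight_decay=1e-4).get_optimizer(params))
54 | 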
--------------------------------------------------------------------------------
/seg_lapa/config_parse/scheduler_conf.py:
--------------------------------------------------------------------------------
1 | from abc import ABC, abstractmethod
2 | from typing import List, Optional, Union
3 |
4 | import torch
5 | from omegaconf import DictConfig
6 | from pydantic.dataclasses import dataclass
7 | from torch.optim.optimizer import Optimizer
8 |
9 | from seg_lapa.config_parse.conf_utils import asdict_filtered, validate_config_group_generic
10 |
11 |
12 | @dataclass(frozen=True)
13 | class SchedulerConf(ABC):
14 | name: str
15 |
16 | @abstractmethod
17 | def get_scheduler(self, optimizer: Optimizer) -> Optional[torch.optim.lr_scheduler._LRScheduler]:
18 | pass
19 |
20 |
21 | @dataclass(frozen=True)
22 | class DisabledConfig(SchedulerConf):
23 | def get_scheduler(self, optimizer: Optimizer) -> None:
24 | return None
25 |
26 |
27 | @dataclass(frozen=True)
28 | class CyclicConfig(SchedulerConf):
29 | base_lr: float
30 | max_lr: float
31 | step_size_up: int
32 | step_size_down: Optional[int]
33 |
34 | def get_scheduler(self, optimizer: Optimizer) -> torch.optim.lr_scheduler.CyclicLR:
35 | return torch.optim.lr_scheduler.CyclicLR(optimizer, cycle_momentum=False, **asdict_filtered(self))
36 |
37 |
38 | @dataclass(frozen=True)
39 | class PolyConfig(SchedulerConf):
40 | max_iter: int
41 | pow_factor: float
42 |
43 | def get_scheduler(self, optimizer: Optimizer) -> torch.optim.lr_scheduler.LambdaLR:
44 | max_iter = float(self.max_iter)
45 | pow_factor = float(self.pow_factor)
46 |
47 | def poly_schedule(n_iter: int) -> float:
48 | return (1 - n_iter / max_iter) ** pow_factor
49 |
50 | return torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=poly_schedule)
51 |
52 |
53 | @dataclass(frozen=True)
54 | class StepConfig(SchedulerConf):
55 | step_size: int
56 | gamma: float
57 |
58 | def get_scheduler(self, optimizer: Optimizer) -> torch.optim.lr_scheduler.StepLR:
59 | return torch.optim.lr_scheduler.StepLR(optimizer, **asdict_filtered(self))
60 |
61 |
62 | @dataclass(frozen=True)
63 | class PlateauConfig(SchedulerConf):
64 | factor: float
65 | patience: int
66 | min_lr: Union[float, List[float]]
67 | mode: str
68 | threshold: float
69 | cooldown: int
70 | eps: float
71 | verbose: bool
72 |
73 | def get_scheduler(self, optimizer: Optimizer) -> torch.optim.lr_scheduler.ReduceLROnPlateau:
74 | return torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, **asdict_filtered(self))
75 |
76 |
77 | valid_names = {
78 | "disabled": DisabledConfig,
79 | "cyclic": CyclicConfig,
80 | "plateau": PlateauConfig,
81 | "poly": PolyConfig,
82 | "step": StepConfig,
83 | }
84 |
85 |
86 | def validate_config_group(cfg_subgroup: DictConfig) -> SchedulerConf:
87 | validated_dataclass = validate_config_group_generic(
88 | cfg_subgroup, dataclass_dict=valid_names, config_category="scheduler"
89 | )
90 | return validated_dataclass
91 |
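92 | 
93 | if __name__ == "__main__":
94 |     # Sketch of the poly decay: the base lr is multiplied by (1 - n_iter / max_iter) ** pow_factor,
95 |     # falling from 1.0x at step 0 to ~0.54x at step 50 and 0.0x at step 100 for these values.
96 |     opt = torch.optim.SGD([torch.nn.Parameter(torch.zeros(1))], lr=0.1)
97 |     sched = PolyConfig(name="poly", max_iter=100, pow_factor=0.9).get_scheduler(opt)
98 |     for _ in range(50):
99 |         sched.step()
100 |     print(opt.param_groups[0]["lr"])  # ~0.054
101 | 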
--------------------------------------------------------------------------------
/seg_lapa/config_parse/train_conf.py:
--------------------------------------------------------------------------------
1 | from typing import Optional
2 |
3 | from omegaconf import DictConfig
4 | from pydantic.dataclasses import dataclass
5 |
6 | from seg_lapa.config_parse import (
7 | callbacks_conf,
8 | dataset_conf,
9 | load_weights_conf,
10 | logger_conf,
11 | model_conf,
12 | optimizer_conf,
13 | scheduler_conf,
14 | trainer_conf,
15 | )
16 | from seg_lapa.config_parse.callbacks_conf import CallbacksConf
17 | from seg_lapa.config_parse.dataset_conf import DatasetConf
18 | from seg_lapa.config_parse.load_weights_conf import LoadWeightsConf
19 | from seg_lapa.config_parse.logger_conf import LoggerConf
20 | from seg_lapa.config_parse.model_conf import ModelConf
21 | from seg_lapa.config_parse.optimizer_conf import OptimConf
22 | from seg_lapa.config_parse.scheduler_conf import SchedulerConf
23 | from seg_lapa.config_parse.trainer_conf import TrainerConf
24 |
25 |
26 | @dataclass(frozen=True)
27 | class TrainConf:
28 | random_seed: Optional[int]
29 | logs_root_dir: str
30 | dataset: DatasetConf
31 | optimizer: OptimConf
32 | model: ModelConf
33 | trainer: TrainerConf
34 | scheduler: SchedulerConf
35 | logger: LoggerConf
36 | callbacks: CallbacksConf
37 | load_weights: LoadWeightsConf
38 |
39 |
40 | class ParseConfig:
41 | @classmethod
42 | def parse_config(cls, cfg: DictConfig) -> TrainConf:
43 | """Parses the config file read from hydra to populate the TrainConfig dataclass"""
44 | config = TrainConf(
45 | random_seed=cfg.random_seed,
46 | logs_root_dir=cfg.logs_root_dir,
47 | dataset=dataset_conf.validate_config_group(cfg.dataset),
48 | model=model_conf.validate_config_group(cfg.model),
49 | optimizer=optimizer_conf.validate_config_group(cfg.optimizer),
50 | trainer=trainer_conf.validate_config_group(cfg.trainer),
51 | scheduler=scheduler_conf.validate_config_group(cfg.scheduler),
52 | logger=logger_conf.validate_config_group(cfg.logger),
53 | callbacks=callbacks_conf.validate_config_group(cfg.callbacks),
54 | load_weights=load_weights_conf.validate_config_group(cfg.load_weights),
55 | )
56 |
57 | return config
58 |
--------------------------------------------------------------------------------
/seg_lapa/config_parse/trainer_conf.py:
--------------------------------------------------------------------------------
1 | from abc import ABC, abstractmethod
2 | from typing import List, Optional
3 |
4 | import pytorch_lightning as pl
5 | from omegaconf import DictConfig
6 | from pydantic.dataclasses import dataclass
7 | from pytorch_lightning.callbacks import Callback
8 | from pytorch_lightning.loggers.base import LightningLoggerBase
9 |
10 | from seg_lapa.config_parse.conf_utils import asdict_filtered, validate_config_group_generic
11 |
12 |
13 | @dataclass(frozen=True)
14 | class TrainerConf(ABC):
15 | name: str
16 |
17 | @abstractmethod
18 | def get_trainer(
19 | self, pl_logger: LightningLoggerBase, callbacks: List[Callback], default_root_dir: str
20 | ) -> pl.Trainer:
21 | pass
22 |
23 |
24 | @dataclass(frozen=True)
25 | class TrainerConfig(TrainerConf):
26 | gpus: int
27 | accelerator: Optional[str]
28 | precision: int
29 | max_epochs: int
30 | resume_from_checkpoint: Optional[str]
31 | log_every_n_steps: int
32 |
33 | benchmark: bool = False
34 | deterministic: bool = False
35 | fast_dev_run: bool = False
36 | overfit_batches: float = 0.0
37 | limit_train_batches: float = 1.0
38 | limit_val_batches: float = 1.0
39 | limit_test_batches: float = 1.0
40 |
41 | def get_trainer(
42 | self, pl_logger: LightningLoggerBase, callbacks: List[Callback], default_root_dir: str
43 | ) -> pl.Trainer:
44 | trainer = pl.Trainer(
45 | logger=pl_logger,
46 | callbacks=callbacks,
47 | default_root_dir=default_root_dir,
48 | **asdict_filtered(self),
49 | )
50 | return trainer
51 |
52 |
53 | valid_names = {"trainer": TrainerConfig}
54 |
55 |
56 | def validate_config_group(cfg_subgroup: DictConfig) -> TrainerConf:
57 | validated_dataclass = validate_config_group_generic(
58 | cfg_subgroup, dataclass_dict=valid_names, config_category="trainer"
59 | )
60 | return validated_dataclass
61 |
--------------------------------------------------------------------------------
/seg_lapa/datasets/lapa.py:
--------------------------------------------------------------------------------
1 | import enum
2 | import random
3 | from pathlib import Path
4 | from typing import List, Optional, Tuple, Union
5 |
6 | import albumentations as A
7 | import cv2
8 | import numpy as np
9 | import pytorch_lightning as pl
10 | import torch
11 | import torchvision
12 | from torch.utils.data import DataLoader, Dataset
13 |
14 | from seg_lapa.utils.path_check import get_path, PathType
15 |
16 |
17 | class DatasetSplit(enum.Enum):
18 | TRAIN = 0
19 | VAL = 1
20 | TEST = 2
21 |
22 |
23 | class LapaDataset(Dataset):
24 | """The Landmark guided face Parsing dataset (LaPa)
25 | Contains pixel-level annotations for face parsing.
26 |
27 | References:
28 | https://github.com/JDAI-CV/lapa-dataset
29 | """
30 |
31 | @enum.unique
32 | class LapaClassId(enum.IntEnum):
33 | # Mapping of the classes within the lapa dataset
34 | BACKGROUND = 0
35 | SKIN = 1
36 | EYEBROW_LEFT = 2
37 | EYEBROW_RIGHT = 3
38 | EYE_LEFT = 4
39 | EYE_RIGHT = 5
40 | NOSE = 6
41 | LIP_UPPER = 7
42 | INNER_MOUTH = 8
43 | LIP_LOWER = 9
44 | HAIR = 10
45 |
46 | SUBDIR_IMAGES = "images"
47 | SUBDIR_LABELS = "labels"
48 |
49 | SUBDIR_SPLIT = {DatasetSplit.TRAIN: "train", DatasetSplit.VAL: "val", DatasetSplit.TEST: "test"}
50 |
51 | def __init__(
52 | self,
53 | root_dir: Union[str, Path],
54 | data_split: DatasetSplit,
55 |         image_ext: Tuple[str, ...] = ("*.jpg",),
56 |         label_ext: Tuple[str, ...] = ("*.png",),
57 | augmentations: Optional[A.Compose] = None,
58 | ):
59 | super().__init__()
60 | self.augmentations = augmentations
61 | self.image_ext = image_ext # The file extensions of input images to search for in input dir
62 | self.label_ext = label_ext # The file extensions of labels to search for in label dir
63 | self.root_dir = self._check_dir(root_dir)
64 |
65 | # Get subdirs for images and labels
66 | self.images_dir = self._check_dir(self.root_dir / self.SUBDIR_SPLIT[data_split] / self.SUBDIR_IMAGES)
67 | self.labels_dir = self._check_dir(self.root_dir / self.SUBDIR_SPLIT[data_split] / self.SUBDIR_LABELS)
68 |
69 | # Create list of filenames
70 |         self._datalist_input = []  # List of the filenames of all input images in the dataset
71 |         self._datalist_label = []  # List of the filenames of all ground-truth labels in the dataset
72 | self._create_lists_filenames()
73 |
74 | def __len__(self):
75 | return len(self._datalist_input)
76 |
77 | def __getitem__(self, index):
78 | # Read input rgb imgs
79 | image_path = self._datalist_input[index]
80 | img = self._read_image(image_path)
81 |
82 | # Read ground truth labels
83 | label_path = self._datalist_label[index]
84 | label = self._read_label(label_path)
85 |
86 | # Apply image augmentations
87 | if self.augmentations is not None:
88 | augmented = self.augmentations(image=img, mask=label)
89 | img = augmented["image"]
90 | label = augmented["mask"]
91 |
92 | # Convert to Tensor. RGB images are normally numpy uint8 array with shape (H, W, 3).
93 | # RGB tensors should be (3, H, W) with dtype float32 in range [0, 1] (may change with normalization applied)
94 | img_tensor = torchvision.transforms.ToTensor()(img)
95 | label_tensor = torch.from_numpy(label)
96 |
97 | # TODO: Return dict
98 | # data = {
99 | # 'image': img_tensor,
100 | # 'label': label_tensor
101 | # }
102 | # return data
103 |
104 | return img_tensor, label_tensor.long()
105 |
106 | @staticmethod
107 | def _check_dir(dir_path: Union[str, Path]) -> Path:
108 | return get_path(dir_path, must_exist=True, path_type=PathType.DIR)
109 |
110 | @staticmethod
111 | def _read_label(label_path: Path) -> np.ndarray:
112 | mask = cv2.imread(str(label_path), cv2.IMREAD_GRAYSCALE | cv2.IMREAD_ANYDEPTH | cv2.IMREAD_ANYCOLOR)
113 |
114 | if len(mask.shape) != 2:
115 | raise RuntimeError(f"The shape of label must be (H, W). Got: {mask.shape}")
116 |
117 | return mask.astype(np.int32)
118 |
119 | @staticmethod
120 | def _read_image(image_path: Path) -> np.ndarray:
121 |         img = cv2.imread(str(image_path), cv2.IMREAD_COLOR)
122 |         img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
123 | 
124 |         if len(img.shape) != 3:
125 |             raise RuntimeError(f"The shape of image must be (H, W, C). Got: {img.shape}")
126 | 
127 |         return img
128 |
129 | def _create_lists_filenames(self):
130 | """Creates a list of filenames of images and labels in dataset"""
131 | self._datalist_input = self._get_matching_files_in_dir(self.images_dir, self.image_ext)
132 | self._datalist_label = self._get_matching_files_in_dir(self.labels_dir, self.label_ext)
133 |
134 | num_images = len(self._datalist_input)
135 | num_labels = len(self._datalist_label)
136 | if num_images != num_labels:
137 | raise ValueError(
138 | f"The number of images ({num_images}) and labels ({num_labels}) do not match."
139 | f"\n Images dir: {self.images_dir}\n Labels dir:{self.labels_dir}"
140 | )
141 |
142 |     def _get_matching_files_in_dir(self, data_dir: Union[str, Path], wildcard_patterns: Tuple[str, ...]) -> List[Path]:
143 | """Get filenames within a dir that match a set of wildcard patterns
144 | Will not search within subdirectories.
145 |
146 | Args:
147 | data_dir: Directory to search within
148 | wildcard_patterns: Tuple of wildcard patterns matching required filenames. Eg: ('*.rgb.png', '*.rgb.jpg')
149 |
150 | Returns:
151 | list[Path]: List of paths to files found
152 | """
153 | data_dir = self._check_dir(data_dir)
154 |
155 | list_matching_files = []
156 | for ext in wildcard_patterns:
157 | list_matching_files += sorted(data_dir.glob(ext))
158 |
159 | if len(list_matching_files) == 0:
160 | raise ValueError(
161 | "No matching files found in given directory."
162 | f"\n Directory: {data_dir}\n Search patterns: {wildcard_patterns}"
163 | )
164 |
165 | return list_matching_files
166 |
167 |
168 | class LaPaDataModule(pl.LightningDataModule):
169 | def __init__(self, data_dir: str, batch_size: int, num_workers: int, resize_h: int, resize_w: int):
170 | super().__init__()
171 | self.batch_size = batch_size
172 | self.num_workers = num_workers
173 | self.data_dir = get_path(data_dir, must_exist=True, path_type=PathType.DIR)
174 | self.resize_h = resize_h
175 | self.resize_w = resize_w
176 |
177 | self.lapa_train = None
178 | self.lapa_val = None
179 | self.lapa_test = None
180 |
181 | def prepare_data(self):
182 | """download dataset, tokenize, etc"""
183 | """
184 | Downloading original data from author's google drive link:
185 | >>> import gdown
186 | >>> url = "https://drive.google.com/uc?export=download&id=1EtyCtiQZt2Y5qrb-0YxRxaVLpVcgCOQV"
187 | >>> output = "lapa-downloaded.tar.gz"
188 | >>> gdown.download(url, output, quiet=False, proxy=False)
189 | """
190 | pass
191 |
192 | def setup(self, stage=None):
193 | # count number of classes, perform train/val/test splits, apply transforms, etc
194 | augs_train = self.get_augs_train()
195 | augs_test = self.get_augs_test()
196 |
197 | self.lapa_train = LapaDataset(root_dir=self.data_dir, data_split=DatasetSplit.TRAIN, augmentations=augs_train)
198 |         self.lapa_val = LapaDataset(root_dir=self.data_dir, data_split=DatasetSplit.VAL, augmentations=augs_test)
199 |         self.lapa_test = LapaDataset(root_dir=self.data_dir, data_split=DatasetSplit.TEST, augmentations=augs_test)
200 |
201 | def train_dataloader(self):
202 | train_loader = DataLoader(
203 | self.lapa_train,
204 | batch_size=self.batch_size,
205 | num_workers=self.num_workers,
206 | pin_memory=True,
207 | drop_last=True,
208 | worker_init_fn=self._dataloader_worker_init,
209 | )
210 | return train_loader
211 |
212 | def val_dataloader(self):
213 | val_loader = DataLoader(
214 | self.lapa_val,
215 | batch_size=self.batch_size,
216 | num_workers=self.num_workers,
217 | pin_memory=True,
218 | drop_last=False,
219 | worker_init_fn=self._dataloader_worker_init,
220 | )
221 | return val_loader
222 |
223 | def test_dataloader(self):
224 | test_loader = DataLoader(
225 | self.lapa_test,
226 | batch_size=self.batch_size,
227 | num_workers=self.num_workers,
228 | pin_memory=True,
229 | drop_last=False,
230 | worker_init_fn=self._dataloader_worker_init,
231 | )
232 | return test_loader
233 |
234 | def get_augs_test(self):
235 | augs_test = A.Compose(
236 | [
237 | # Geometric Augs
238 |                 A.SmallestMaxSize(max_size=self.resize_h, interpolation=0, p=1.0),  # 0 = cv2.INTER_NEAREST
239 | A.CenterCrop(height=self.resize_h, width=self.resize_w, p=1.0),
240 | ]
241 | )
242 | return augs_test
243 |
244 | def get_augs_train(self):
245 | augs_train = A.Compose(
246 | [
247 |                 # Geometric Augs (currently identical to the test-time augs; no randomized augmentation)
248 |                 A.SmallestMaxSize(max_size=self.resize_h, interpolation=0, p=1.0),  # 0 = cv2.INTER_NEAREST
249 | A.CenterCrop(height=self.resize_h, width=self.resize_w, p=1.0),
250 | ]
251 | )
252 | return augs_train
253 |
254 | @staticmethod
255 | def _dataloader_worker_init(*args):
256 | """Seeds the workers within the Dataloader"""
257 | worker_seed = torch.initial_seed() % 2 ** 32
258 | np.random.seed(worker_seed)
259 | random.seed(worker_seed)
260 |
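261 | 
262 | # Usage sketch (requires the LaPa dataset on disk; the path below is a placeholder):
263 | #     dm = LaPaDataModule(data_dir="/path/to/LaPa", batch_size=8, num_workers=4, resize_h=256, resize_w=256)
264 | #     dm.setup()
265 | #     imgs, labels = next(iter(dm.train_dataloader()))
266 | #     print(imgs.shape, labels.shape)  # [8, 3, 256, 256], [8, 256, 256]
267 | 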
--------------------------------------------------------------------------------
/seg_lapa/loss_func.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 |
5 |
6 | class CrossEntropy2D(nn.CrossEntropyLoss):
7 | """Use the torch.nn.CrossEntropyLoss loss to calculate mean loss per image or per pixel.
8 | Deeplab models calculate mean loss per image.
9 |
10 | Inputs:
11 | - inputs (Tensor): Raw output of network (without softmax applied).
12 | - targets (Tensor): Ground truth, containing a int class index in the range :math:`[0, C-1]` as the
13 | `target` for each pixel.
14 |
15 | Shape and dtype:
16 | - Input: [B, C, H, W], where C = num_classes. dtype=float16/32
17 | - Target: [B, H, W]. dtype=int32/64
18 | - Output: scalar.
19 |
20 | Args:
21 |         loss_per_image (bool, optional): If True, return the mean loss per image (sum over pixels / batch size); if False, return the mean loss per pixel. Defaults to True.
22 |         ignore_index (int, optional): Pixels with this label do not contribute to the loss. Defaults to 255.
23 |
24 | References:
25 | https://pytorch.org/docs/stable/generated/torch.nn.CrossEntropyLoss.html#crossentropyloss
26 | """
27 |
28 | def __init__(self, loss_per_image: bool = True, ignore_index: int = 255):
29 | if loss_per_image:
30 | reduction = "sum"
31 | else:
32 | reduction = "mean"
33 |
34 | super().__init__(reduction=reduction, ignore_index=ignore_index)
35 | self.loss_per_image = loss_per_image
36 |
37 | def forward(self, inputs: torch.Tensor, targets: torch.Tensor):
38 | loss = F.cross_entropy(
39 | inputs, targets, weight=self.weight, ignore_index=self.ignore_index, reduction=self.reduction
40 | )
41 |
42 | if self.loss_per_image:
43 | batch_size = inputs.shape[0]
44 | loss = loss / batch_size
45 |
46 | return loss
47 |
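48 | 
49 | if __name__ == "__main__":
50 |     # Sketch of the two reduction modes: the per-image loss sums over all pixels and
51 |     # divides by the batch size; the per-pixel loss averages over every non-ignored pixel.
52 |     logits = torch.randn(2, 11, 4, 4)  # [B, C, H, W]
53 |     target = torch.randint(0, 11, (2, 4, 4))  # [B, H, W]
54 |     print(CrossEntropy2D(loss_per_image=True)(logits, target))
55 |     print(CrossEntropy2D(loss_per_image=False)(logits, target))
56 | 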
--------------------------------------------------------------------------------
/seg_lapa/metrics.py:
--------------------------------------------------------------------------------
1 | from dataclasses import dataclass
2 |
3 | import torch
4 | from pytorch_lightning import metrics
5 |
6 |
7 | @dataclass
8 | class IouMetric:
9 | iou_per_class: torch.Tensor
10 | miou: torch.Tensor # Mean IoU across all classes
11 | accuracy: torch.Tensor
12 | precision: torch.Tensor
13 | recall: torch.Tensor
14 | specificity: torch.Tensor
15 |
16 |
17 | class Iou(metrics.Metric):
18 | def __init__(self, num_classes: int = 11, normalize: bool = False):
19 | """Calculates the metrics iou, true positives and false positives/negatives for multi-class classification
20 | problems such as semantic segmentation.
21 | Because this is an expensive operation, we do not compute or sync the values per step.
22 |
23 | Forward accepts:
24 |
25 | - ``prediction`` (float or long tensor): ``(N, H, W)``
26 | - ``label`` (long tensor): ``(N, H, W)``
27 |
28 | Note:
29 | This metric produces a dataclass as output, so it can not be directly logged.
30 | """
31 | super().__init__(compute_on_step=False, dist_sync_on_step=False)
32 |
33 | self.num_classes = num_classes
34 |         # The metric is accumulated over batches. If True, the final counts (tp, fn, etc.) are averaged per image
35 | self.normalize = normalize
36 |
37 | self.acc_confusion_matrix = None # The accumulated confusion matrix
38 | self.count_samples = None # Number of samples seen
39 | # Use `add_state()` for attr to track their state and synchronize state across processes
40 | self.add_state(
41 | "acc_confusion_matrix", default=torch.zeros((self.num_classes, self.num_classes)), dist_reduce_fx="sum"
42 | )
43 | self.add_state("count_samples", default=torch.tensor(0), dist_reduce_fx="sum")
44 |
45 | def update(self, prediction: torch.Tensor, label: torch.Tensor):
46 | """Calculate the confusion matrix and accumulate it
47 |
48 | Args:
49 | prediction: Predictions of network (after argmax). Shape: [N, H, W]
50 | label: Ground truth. Each pixel has int value denoting class. Shape: [N, H, W]
51 | """
52 | assert prediction.shape == label.shape
53 | assert len(label.shape) == 3
54 |
55 | num_images = int(label.shape[0])
56 |
57 | label = label.view(-1).long()
58 | prediction = prediction.view(-1).long()
59 |
60 | # Calculate confusion matrix
61 | conf_mat = torch.bincount(self.num_classes * label + prediction, minlength=self.num_classes ** 2)
62 | conf_mat = conf_mat.reshape((self.num_classes, self.num_classes))
63 |
64 | # Accumulate values
65 | self.acc_confusion_matrix += conf_mat
66 | self.count_samples += num_images
67 |
68 | def compute(self):
69 | """Compute the final IoU and other metrics across all samples seen"""
70 | # Normalize the accumulated confusion matrix, if needed
71 | conf_mat = self.acc_confusion_matrix
72 | if self.normalize:
73 | conf_mat = conf_mat / self.count_samples # Get average per image
74 |
75 | # Calculate True Positive (TP), False Positive (FP), False Negative (FN) and True Negative (TN)
76 | tp = conf_mat.diagonal()
77 |         fn = conf_mat.sum(dim=1) - tp  # Rows of conf_mat are ground-truth classes: row sum = tp + fn
78 |         fp = conf_mat.sum(dim=0) - tp  # Columns are predicted classes: column sum = tp + fp
79 | total_px = conf_mat.sum()
80 | tn = total_px - (tp + fn + fp)
81 |
82 | # Calculate Intersection over Union (IoU)
83 | eps = 1e-6
84 | iou_per_class = (tp + eps) / (fn + fp + tp + eps) # Use epsilon to avoid zero division errors
85 | iou_per_class[torch.isnan(iou_per_class)] = 0
86 | mean_iou = iou_per_class.mean()
87 |
88 | # Accuracy (what proportion of predictions — both Positive and Negative — were correctly classified?)
89 | accuracy = (tp + tn) / (tp + fp + fn + tn)
90 | accuracy[torch.isnan(accuracy)] = 0
91 |
92 | # Precision (what proportion of predicted Positives is truly Positive?)
93 | precision = tp / (tp + fp)
94 | precision[torch.isnan(precision)] = 0
95 |
96 | # Recall or True Positive Rate (what proportion of actual Positives is correctly classified?)
97 | recall = tp / (tp + fn)
98 | recall[torch.isnan(recall)] = 0
99 |
100 | # Specificity or true negative rate
101 | specificity = tn / (tn + fp)
102 | specificity[torch.isnan(specificity)] = 0
103 |
104 | data_r = IouMetric(
105 | iou_per_class=iou_per_class,
106 | miou=mean_iou,
107 | accuracy=accuracy,
108 | precision=precision,
109 | recall=recall,
110 | specificity=specificity,
111 | )
112 |
113 | return data_r
114 |
115 |
116 | # Tests
117 | def test_iou():
118 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
119 | print(f"Using device: {device}")
120 |
121 | # Create Fake label and prediction
122 | label = torch.zeros((1, 4, 4), dtype=torch.float32, device=device)
123 | pred = torch.zeros((1, 4, 4), dtype=torch.float32, device=device)
124 | label[:, :3, :3] = 1
125 | pred[:, -3:, -3:] = 1
126 | expected_iou = torch.tensor([2.0 / 12, 4.0 / 14], device=device)
127 |
128 | print("Testing IoU metrics", end="")
129 | iou_train = Iou(num_classes=2)
130 | iou_train.to(device)
131 | iou_train(pred, label)
132 | metrics_r = iou_train.compute()
133 | iou_per_class = metrics_r.iou_per_class
134 | assert (iou_per_class - expected_iou).sum() < 1e-6
135 | print(" passed")
136 |
137 |
138 | if __name__ == "__main__":
139 | # Run tests
140 | print("Running tests on metrics module...\n")
141 | test_iou()
142 |
--------------------------------------------------------------------------------
/seg_lapa/networks/deeplab/aspp.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 |
5 | from .sync_batchnorm.batchnorm import SynchronizedBatchNorm2d
6 |
7 |
8 | class _ASPPModule(nn.Module):
9 | def __init__(self, inplanes, planes, kernel_size, padding, dilation, BatchNorm):
10 | super(_ASPPModule, self).__init__()
11 | self.atrous_conv = nn.Conv2d(
12 | inplanes, planes, kernel_size=kernel_size, stride=1, padding=padding, dilation=dilation, bias=False
13 | )
14 | self.bn = BatchNorm(planes)
15 | self.relu = nn.ReLU()
16 |
17 | self._init_weight()
18 |
19 | def forward(self, x):
20 | x = self.atrous_conv(x)
21 | x = self.bn(x)
22 |
23 | return self.relu(x)
24 |
25 | def _init_weight(self):
26 | for m in self.modules():
27 | if isinstance(m, nn.Conv2d):
28 | torch.nn.init.kaiming_normal_(m.weight)
29 | elif isinstance(m, SynchronizedBatchNorm2d):
30 | m.weight.data.fill_(1)
31 | m.bias.data.zero_()
32 | elif isinstance(m, nn.BatchNorm2d):
33 | m.weight.data.fill_(1)
34 | m.bias.data.zero_()
35 |
36 |
37 | class ASPP(nn.Module):
38 | def __init__(self, backbone, output_stride, BatchNorm):
39 | super(ASPP, self).__init__()
40 | if backbone == "drn":
41 | inplanes = 512
42 | elif backbone == "mobilenet":
43 | inplanes = 320
44 | else:
45 | inplanes = 2048
46 | if output_stride == 16:
47 | dilations = [1, 6, 12, 18]
48 | elif output_stride == 8:
49 | dilations = [1, 12, 24, 36]
50 | else:
51 | raise NotImplementedError
52 |
53 | self.aspp1 = _ASPPModule(inplanes, 256, 1, padding=0, dilation=dilations[0], BatchNorm=BatchNorm)
54 | self.aspp2 = _ASPPModule(inplanes, 256, 3, padding=dilations[1], dilation=dilations[1], BatchNorm=BatchNorm)
55 | self.aspp3 = _ASPPModule(inplanes, 256, 3, padding=dilations[2], dilation=dilations[2], BatchNorm=BatchNorm)
56 | self.aspp4 = _ASPPModule(inplanes, 256, 3, padding=dilations[3], dilation=dilations[3], BatchNorm=BatchNorm)
57 |
58 | self.global_avg_pool = nn.Sequential(
59 | nn.AdaptiveAvgPool2d((1, 1)), nn.Conv2d(inplanes, 256, 1, stride=1, bias=False), BatchNorm(256), nn.ReLU()
60 | )
61 | self.conv1 = nn.Conv2d(1280, 256, 1, bias=False)
62 | self.bn1 = BatchNorm(256)
63 | self.relu = nn.ReLU()
64 | self.dropout = nn.Dropout(0.5)
65 | self._init_weight()
66 |
67 | def forward(self, x):
68 | x1 = self.aspp1(x)
69 | x2 = self.aspp2(x)
70 | x3 = self.aspp3(x)
71 | x4 = self.aspp4(x)
72 | x5 = self.global_avg_pool(x)
73 | x5 = F.interpolate(x5, size=x4.size()[2:], mode="bilinear", align_corners=True)
74 | x = torch.cat((x1, x2, x3, x4, x5), dim=1)
75 |
76 | x = self.conv1(x)
77 | x = self.bn1(x)
78 | x = self.relu(x)
79 |
80 | return self.dropout(x)
81 |
82 | def _init_weight(self):
83 | for m in self.modules():
84 | if isinstance(m, nn.Conv2d):
85 | torch.nn.init.kaiming_normal_(m.weight)
86 | elif isinstance(m, SynchronizedBatchNorm2d):
87 | m.weight.data.fill_(1)
88 | m.bias.data.zero_()
89 | elif isinstance(m, nn.BatchNorm2d):
90 | m.weight.data.fill_(1)
91 | m.bias.data.zero_()
92 |
93 |
94 | def build_aspp(backbone, output_stride, BatchNorm):
95 | return ASPP(backbone, output_stride, BatchNorm)
96 |
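97 | 
98 | if __name__ == "__main__":
99 |     # Shape sketch: ASPP fuses four dilated-conv branches plus a global-pooling branch
100 |     # into a 256-channel map at the input's resolution. eval() is needed because the
101 |     # pooled 1x1 branch cannot compute batchnorm statistics from a single sample.
102 |     aspp = build_aspp(backbone="resnet", output_stride=16, BatchNorm=nn.BatchNorm2d).eval()
103 |     x = torch.rand(1, 2048, 32, 32)  # The resnet backbone outputs 2048 channels
104 |     print(aspp(x).shape)  # torch.Size([1, 256, 32, 32])
105 | 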
--------------------------------------------------------------------------------
/seg_lapa/networks/deeplab/backbone/__init__.py:
--------------------------------------------------------------------------------
1 | from . import drn, mobilenet, resnet, xception
2 |
3 |
4 | def build_backbone(backbone, output_stride, BatchNorm):
5 | if backbone == "resnet":
6 | return resnet.ResNet101(output_stride, BatchNorm)
7 | elif backbone == "xception":
8 | return xception.AlignedXception(output_stride, BatchNorm)
9 | elif backbone == "drn":
10 | return drn.drn_d_54(BatchNorm)
11 | elif backbone == "mobilenet":
12 | return mobilenet.MobileNetV2(output_stride, BatchNorm)
13 | else:
14 | raise NotImplementedError
15 |
--------------------------------------------------------------------------------
/seg_lapa/networks/deeplab/backbone/drn.py:
--------------------------------------------------------------------------------
1 | import math
2 |
3 | import torch.nn as nn
4 | import torch.utils.model_zoo as model_zoo
5 |
6 | from ..sync_batchnorm.batchnorm import SynchronizedBatchNorm2d
7 |
8 | webroot = "http://dl.yf.io/drn/"
9 |
10 | model_urls = {
11 | "resnet50": "https://download.pytorch.org/models/resnet50-19c8e357.pth",
12 | "drn-c-26": webroot + "drn_c_26-ddedf421.pth",
13 | "drn-c-42": webroot + "drn_c_42-9d336e8c.pth",
14 | "drn-c-58": webroot + "drn_c_58-0a53a92c.pth",
15 | "drn-d-22": webroot + "drn_d_22-4bd2f8ea.pth",
16 | "drn-d-38": webroot + "drn_d_38-eebb45f0.pth",
17 | "drn-d-54": webroot + "drn_d_54-0e0534ff.pth",
18 | "drn-d-105": webroot + "drn_d_105-12b40979.pth",
19 | }  # Note: no entries exist for "drn-d-24" / "drn-d-40"; drn_d_24()/drn_d_40() below cannot load pretrained weights
20 |
21 |
22 | def conv3x3(in_planes, out_planes, stride=1, padding=1, dilation=1):
23 | return nn.Conv2d(
24 | in_planes, out_planes, kernel_size=3, stride=stride, padding=padding, bias=False, dilation=dilation
25 | )
26 |
27 |
28 | class BasicBlock(nn.Module):
29 | expansion = 1
30 |
31 | def __init__(self, inplanes, planes, stride=1, downsample=None, dilation=(1, 1), residual=True, BatchNorm=None):
32 | super(BasicBlock, self).__init__()
33 | self.conv1 = conv3x3(inplanes, planes, stride, padding=dilation[0], dilation=dilation[0])
34 | self.bn1 = BatchNorm(planes)
35 | self.relu = nn.ReLU(inplace=True)
36 | self.conv2 = conv3x3(planes, planes, padding=dilation[1], dilation=dilation[1])
37 | self.bn2 = BatchNorm(planes)
38 | self.downsample = downsample
39 | self.stride = stride
40 | self.residual = residual
41 |
42 | def forward(self, x):
43 | residual = x
44 |
45 | out = self.conv1(x)
46 | out = self.bn1(out)
47 | out = self.relu(out)
48 |
49 | out = self.conv2(out)
50 | out = self.bn2(out)
51 |
52 | if self.downsample is not None:
53 | residual = self.downsample(x)
54 | if self.residual:
55 | out += residual
56 | out = self.relu(out)
57 |
58 | return out
59 |
60 |
61 | class Bottleneck(nn.Module):
62 | expansion = 4
63 |
64 | def __init__(self, inplanes, planes, stride=1, downsample=None, dilation=(1, 1), residual=True, BatchNorm=None):
65 | super(Bottleneck, self).__init__()
66 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
67 | self.bn1 = BatchNorm(planes)
68 | self.conv2 = nn.Conv2d(
69 | planes, planes, kernel_size=3, stride=stride, padding=dilation[1], bias=False, dilation=dilation[1]
70 | )
71 | self.bn2 = BatchNorm(planes)
72 | self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False)
73 | self.bn3 = BatchNorm(planes * 4)
74 | self.relu = nn.ReLU(inplace=True)
75 | self.downsample = downsample
76 | self.stride = stride
77 |
78 | def forward(self, x):
79 | residual = x
80 |
81 | out = self.conv1(x)
82 | out = self.bn1(out)
83 | out = self.relu(out)
84 |
85 | out = self.conv2(out)
86 | out = self.bn2(out)
87 | out = self.relu(out)
88 |
89 | out = self.conv3(out)
90 | out = self.bn3(out)
91 |
92 | if self.downsample is not None:
93 | residual = self.downsample(x)
94 |
95 | out += residual
96 | out = self.relu(out)
97 |
98 | return out
99 |
100 |
101 | class DRN(nn.Module):
102 | def __init__(self, block, layers, arch="D", channels=(16, 32, 64, 128, 256, 512, 512, 512), BatchNorm=None):
103 | super(DRN, self).__init__()
104 | self.inplanes = channels[0]
105 | self.out_dim = channels[-1]
106 | self.arch = arch
107 |
108 | if arch == "C":
109 | self.conv1 = nn.Conv2d(3, channels[0], kernel_size=7, stride=1, padding=3, bias=False)
110 | self.bn1 = BatchNorm(channels[0])
111 | self.relu = nn.ReLU(inplace=True)
112 |
113 | self.layer1 = self._make_layer(BasicBlock, channels[0], layers[0], stride=1, BatchNorm=BatchNorm)
114 | self.layer2 = self._make_layer(BasicBlock, channels[1], layers[1], stride=2, BatchNorm=BatchNorm)
115 |
116 | elif arch == "D":
117 | self.layer0 = nn.Sequential(
118 | nn.Conv2d(3, channels[0], kernel_size=7, stride=1, padding=3, bias=False),
119 | BatchNorm(channels[0]),
120 | nn.ReLU(inplace=True),
121 | )
122 |
123 | self.layer1 = self._make_conv_layers(channels[0], layers[0], stride=1, BatchNorm=BatchNorm)
124 | self.layer2 = self._make_conv_layers(channels[1], layers[1], stride=2, BatchNorm=BatchNorm)
125 |
126 | self.layer3 = self._make_layer(block, channels[2], layers[2], stride=2, BatchNorm=BatchNorm)
127 | self.layer4 = self._make_layer(block, channels[3], layers[3], stride=2, BatchNorm=BatchNorm)
128 | self.layer5 = self._make_layer(block, channels[4], layers[4], dilation=2, new_level=False, BatchNorm=BatchNorm)
129 | self.layer6 = (
130 | None
131 | if layers[5] == 0
132 | else self._make_layer(block, channels[5], layers[5], dilation=4, new_level=False, BatchNorm=BatchNorm)
133 | )
134 |
135 | if arch == "C":
136 | self.layer7 = (
137 | None
138 | if layers[6] == 0
139 | else self._make_layer(
140 | BasicBlock, channels[6], layers[6], dilation=2, new_level=False, residual=False, BatchNorm=BatchNorm
141 | )
142 | )
143 | self.layer8 = (
144 | None
145 | if layers[7] == 0
146 | else self._make_layer(
147 | BasicBlock, channels[7], layers[7], dilation=1, new_level=False, residual=False, BatchNorm=BatchNorm
148 | )
149 | )
150 | elif arch == "D":
151 | self.layer7 = (
152 | None
153 | if layers[6] == 0
154 | else self._make_conv_layers(channels[6], layers[6], dilation=2, BatchNorm=BatchNorm)
155 | )
156 | self.layer8 = (
157 | None
158 | if layers[7] == 0
159 | else self._make_conv_layers(channels[7], layers[7], dilation=1, BatchNorm=BatchNorm)
160 | )
161 |
162 | self._init_weight()
163 |
164 | def _init_weight(self):
165 | for m in self.modules():
166 | if isinstance(m, nn.Conv2d):
167 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
168 | m.weight.data.normal_(0, math.sqrt(2.0 / n))
169 | elif isinstance(m, SynchronizedBatchNorm2d):
170 | m.weight.data.fill_(1)
171 | m.bias.data.zero_()
172 | elif isinstance(m, nn.BatchNorm2d):
173 | m.weight.data.fill_(1)
174 | m.bias.data.zero_()
175 |
176 | def _make_layer(self, block, planes, blocks, stride=1, dilation=1, new_level=True, residual=True, BatchNorm=None):
177 | assert dilation == 1 or dilation % 2 == 0
178 | downsample = None
179 | if stride != 1 or self.inplanes != planes * block.expansion:
180 | downsample = nn.Sequential(
181 | nn.Conv2d(self.inplanes, planes * block.expansion, kernel_size=1, stride=stride, bias=False),
182 | BatchNorm(planes * block.expansion),
183 | )
184 |
185 | layers = list()
186 | layers.append(
187 | block(
188 | self.inplanes,
189 | planes,
190 | stride,
191 | downsample,
192 | dilation=(1, 1) if dilation == 1 else (dilation // 2 if new_level else dilation, dilation),
193 | residual=residual,
194 | BatchNorm=BatchNorm,
195 | )
196 | )
197 | self.inplanes = planes * block.expansion
198 | for i in range(1, blocks):
199 | layers.append(
200 | block(self.inplanes, planes, residual=residual, dilation=(dilation, dilation), BatchNorm=BatchNorm)
201 | )
202 |
203 | return nn.Sequential(*layers)
204 |
205 | def _make_conv_layers(self, channels, convs, stride=1, dilation=1, BatchNorm=None):
206 | modules = []
207 | for i in range(convs):
208 | modules.extend(
209 | [
210 | nn.Conv2d(
211 | self.inplanes,
212 | channels,
213 | kernel_size=3,
214 | stride=stride if i == 0 else 1,
215 | padding=dilation,
216 | bias=False,
217 | dilation=dilation,
218 | ),
219 | BatchNorm(channels),
220 | nn.ReLU(inplace=True),
221 | ]
222 | )
223 | self.inplanes = channels
224 | return nn.Sequential(*modules)
225 |
226 | def forward(self, x):
227 | if self.arch == "C":
228 | x = self.conv1(x)
229 | x = self.bn1(x)
230 | x = self.relu(x)
231 | elif self.arch == "D":
232 | x = self.layer0(x)
233 |
234 | x = self.layer1(x)
235 | x = self.layer2(x)
236 |
237 | x = self.layer3(x)
238 | low_level_feat = x
239 |
240 | x = self.layer4(x)
241 | x = self.layer5(x)
242 |
243 | if self.layer6 is not None:
244 | x = self.layer6(x)
245 |
246 | if self.layer7 is not None:
247 | x = self.layer7(x)
248 |
249 | if self.layer8 is not None:
250 | x = self.layer8(x)
251 |
252 | return x, low_level_feat
253 |
254 |
255 | class DRN_A(nn.Module):
256 | def __init__(self, block, layers, BatchNorm=None):
257 | self.inplanes = 64
258 | super(DRN_A, self).__init__()
259 | self.out_dim = 512 * block.expansion
260 | self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False)
261 | self.bn1 = BatchNorm(64)
262 | self.relu = nn.ReLU(inplace=True)
263 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
264 | self.layer1 = self._make_layer(block, 64, layers[0], BatchNorm=BatchNorm)
265 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2, BatchNorm=BatchNorm)
266 | self.layer3 = self._make_layer(block, 256, layers[2], stride=1, dilation=2, BatchNorm=BatchNorm)
267 | self.layer4 = self._make_layer(block, 512, layers[3], stride=1, dilation=4, BatchNorm=BatchNorm)
268 |
269 | self._init_weight()
270 |
271 | def _init_weight(self):
272 | for m in self.modules():
273 | if isinstance(m, nn.Conv2d):
274 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
275 | m.weight.data.normal_(0, math.sqrt(2.0 / n))
276 | elif isinstance(m, SynchronizedBatchNorm2d):
277 | m.weight.data.fill_(1)
278 | m.bias.data.zero_()
279 | elif isinstance(m, nn.BatchNorm2d):
280 | m.weight.data.fill_(1)
281 | m.bias.data.zero_()
282 |
283 | def _make_layer(self, block, planes, blocks, stride=1, dilation=1, BatchNorm=None):
284 | downsample = None
285 | if stride != 1 or self.inplanes != planes * block.expansion:
286 | downsample = nn.Sequential(
287 | nn.Conv2d(self.inplanes, planes * block.expansion, kernel_size=1, stride=stride, bias=False),
288 | BatchNorm(planes * block.expansion),
289 | )
290 |
291 | layers = []
292 | layers.append(block(self.inplanes, planes, stride, downsample, BatchNorm=BatchNorm))
293 | self.inplanes = planes * block.expansion
294 | for i in range(1, blocks):
295 | layers.append(
296 | block(
297 | self.inplanes,
298 | planes,
299 | dilation=(
300 | dilation,
301 | dilation,
302 | ),
303 | BatchNorm=BatchNorm,
304 | )
305 | )
306 |
307 | return nn.Sequential(*layers)
308 |
309 | def forward(self, x):
310 | x = self.conv1(x)
311 | x = self.bn1(x)
312 | x = self.relu(x)
313 | x = self.maxpool(x)
314 |
315 | x = self.layer1(x)
316 | x = self.layer2(x)
317 | x = self.layer3(x)
318 | x = self.layer4(x)
319 |
320 | return x
321 |
322 |
323 | def drn_a_50(BatchNorm, pretrained=True):
324 | model = DRN_A(Bottleneck, [3, 4, 6, 3], BatchNorm=BatchNorm)
325 | if pretrained:
326 | model.load_state_dict(model_zoo.load_url(model_urls["resnet50"]))
327 | return model
328 |
329 |
330 | def drn_c_26(BatchNorm, pretrained=True):
331 | model = DRN(BasicBlock, [1, 1, 2, 2, 2, 2, 1, 1], arch="C", BatchNorm=BatchNorm)
332 | if pretrained:
333 | pretrained = model_zoo.load_url(model_urls["drn-c-26"])
334 | del pretrained["fc.weight"]
335 | del pretrained["fc.bias"]
336 | model.load_state_dict(pretrained)
337 | return model
338 |
339 |
340 | def drn_c_42(BatchNorm, pretrained=True):
341 | model = DRN(BasicBlock, [1, 1, 3, 4, 6, 3, 1, 1], arch="C", BatchNorm=BatchNorm)
342 | if pretrained:
343 | pretrained = model_zoo.load_url(model_urls["drn-c-42"])
344 | del pretrained["fc.weight"]
345 | del pretrained["fc.bias"]
346 | model.load_state_dict(pretrained)
347 | return model
348 |
349 |
350 | def drn_c_58(BatchNorm, pretrained=True):
351 | model = DRN(Bottleneck, [1, 1, 3, 4, 6, 3, 1, 1], arch="C", BatchNorm=BatchNorm)
352 | if pretrained:
353 | pretrained = model_zoo.load_url(model_urls["drn-c-58"])
354 | del pretrained["fc.weight"]
355 | del pretrained["fc.bias"]
356 | model.load_state_dict(pretrained)
357 | return model
358 |
359 |
360 | def drn_d_22(BatchNorm, pretrained=True):
361 | model = DRN(BasicBlock, [1, 1, 2, 2, 2, 2, 1, 1], arch="D", BatchNorm=BatchNorm)
362 | if pretrained:
363 | pretrained = model_zoo.load_url(model_urls["drn-d-22"])
364 | del pretrained["fc.weight"]
365 | del pretrained["fc.bias"]
366 | model.load_state_dict(pretrained)
367 | return model
368 |
369 |
370 | def drn_d_24(BatchNorm, pretrained=True):
371 | model = DRN(BasicBlock, [1, 1, 2, 2, 2, 2, 2, 2], arch="D", BatchNorm=BatchNorm)
372 | if pretrained:
373 | pretrained = model_zoo.load_url(model_urls["drn-d-24"])
374 | del pretrained["fc.weight"]
375 | del pretrained["fc.bias"]
376 | model.load_state_dict(pretrained)
377 | return model
378 |
379 |
380 | def drn_d_38(BatchNorm, pretrained=True):
381 | model = DRN(BasicBlock, [1, 1, 3, 4, 6, 3, 1, 1], arch="D", BatchNorm=BatchNorm)
382 | if pretrained:
383 | pretrained = model_zoo.load_url(model_urls["drn-d-38"])
384 | del pretrained["fc.weight"]
385 | del pretrained["fc.bias"]
386 | model.load_state_dict(pretrained)
387 | return model
388 |
389 |
390 | def drn_d_40(BatchNorm, pretrained=True):
391 | model = DRN(BasicBlock, [1, 1, 3, 4, 6, 3, 2, 2], arch="D", BatchNorm=BatchNorm)
392 | if pretrained:
393 | pretrained = model_zoo.load_url(model_urls["drn-d-40"])
394 | del pretrained["fc.weight"]
395 | del pretrained["fc.bias"]
396 | model.load_state_dict(pretrained)
397 | return model
398 |
399 |
400 | def drn_d_54(BatchNorm, pretrained=True):
401 | model = DRN(Bottleneck, [1, 1, 3, 4, 6, 3, 1, 1], arch="D", BatchNorm=BatchNorm)
402 | if pretrained:
403 | pretrained = model_zoo.load_url(model_urls["drn-d-54"])
404 | del pretrained["fc.weight"]
405 | del pretrained["fc.bias"]
406 | model.load_state_dict(pretrained)
407 | return model
408 |
409 |
410 | def drn_d_105(BatchNorm, pretrained=True):
411 | model = DRN(Bottleneck, [1, 1, 3, 4, 23, 3, 1, 1], arch="D", BatchNorm=BatchNorm)
412 | if pretrained:
413 | pretrained = model_zoo.load_url(model_urls["drn-d-105"])
414 | del pretrained["fc.weight"]
415 | del pretrained["fc.bias"]
416 | model.load_state_dict(pretrained)
417 | return model
418 |
419 |
420 | if __name__ == "__main__":
421 | import torch
422 |
423 |     model = drn_d_54(BatchNorm=nn.BatchNorm2d, pretrained=True)  # DRN_A models return only `x`, so use a DRN (arch D) here
424 | input = torch.rand(1, 3, 512, 512)
425 | output, low_level_feat = model(input)
426 | print(output.size())
427 | print(low_level_feat.size())
428 |
--------------------------------------------------------------------------------
/seg_lapa/networks/deeplab/backbone/mobilenet.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | import torch.utils.model_zoo as model_zoo
5 |
6 | from ..sync_batchnorm.batchnorm import SynchronizedBatchNorm2d
7 |
8 |
9 | def conv_bn(inp, oup, stride, BatchNorm):
10 | return nn.Sequential(nn.Conv2d(inp, oup, 3, stride, 1, bias=False), BatchNorm(oup), nn.ReLU6(inplace=True))
11 |
12 |
13 | def fixed_padding(inputs, kernel_size, dilation):
14 | kernel_size_effective = kernel_size + (kernel_size - 1) * (dilation - 1)
15 | pad_total = kernel_size_effective - 1
16 | pad_beg = pad_total // 2
17 | pad_end = pad_total - pad_beg
18 | padded_inputs = F.pad(inputs, (pad_beg, pad_end, pad_beg, pad_end))
19 | return padded_inputs
20 |
21 |
22 | class InvertedResidual(nn.Module):
23 | def __init__(self, inp, oup, stride, dilation, expand_ratio, BatchNorm):
24 | super(InvertedResidual, self).__init__()
25 | self.stride = stride
26 | assert stride in [1, 2]
27 |
28 | hidden_dim = round(inp * expand_ratio)
29 | self.use_res_connect = self.stride == 1 and inp == oup
30 | self.kernel_size = 3
31 | self.dilation = dilation
32 |
33 | if expand_ratio == 1:
34 | self.conv = nn.Sequential(
35 | # dw
36 | nn.Conv2d(hidden_dim, hidden_dim, 3, stride, 0, dilation, groups=hidden_dim, bias=False),
37 | BatchNorm(hidden_dim),
38 | nn.ReLU6(inplace=True),
39 | # pw-linear
40 | nn.Conv2d(hidden_dim, oup, 1, 1, 0, 1, 1, bias=False),
41 | BatchNorm(oup),
42 | )
43 | else:
44 | self.conv = nn.Sequential(
45 | # pw
46 | nn.Conv2d(inp, hidden_dim, 1, 1, 0, 1, bias=False),
47 | BatchNorm(hidden_dim),
48 | nn.ReLU6(inplace=True),
49 | # dw
50 | nn.Conv2d(hidden_dim, hidden_dim, 3, stride, 0, dilation, groups=hidden_dim, bias=False),
51 | BatchNorm(hidden_dim),
52 | nn.ReLU6(inplace=True),
53 | # pw-linear
54 | nn.Conv2d(hidden_dim, oup, 1, 1, 0, 1, bias=False),
55 | BatchNorm(oup),
56 | )
57 |
58 | def forward(self, x):
59 | x_pad = fixed_padding(x, self.kernel_size, dilation=self.dilation)
60 | if self.use_res_connect:
61 | x = x + self.conv(x_pad)
62 | else:
63 | x = self.conv(x_pad)
64 | return x
65 |
66 |
67 | class MobileNetV2(nn.Module):
68 | def __init__(self, output_stride=8, BatchNorm=None, width_mult=1.0, pretrained=True):
69 | super(MobileNetV2, self).__init__()
70 | block = InvertedResidual
71 | input_channel = 32
72 | current_stride = 1
73 | rate = 1
74 |         inverted_residual_setting = [
75 | # t, c, n, s
76 | [1, 16, 1, 1],
77 | [6, 24, 2, 2],
78 | [6, 32, 3, 2],
79 | [6, 64, 4, 2],
80 | [6, 96, 3, 1],
81 | [6, 160, 3, 2],
82 | [6, 320, 1, 1],
83 | ]
84 |
85 | # building first layer
86 | input_channel = int(input_channel * width_mult)
87 | self.features = [conv_bn(3, input_channel, 2, BatchNorm)]
88 | current_stride *= 2
89 | # building inverted residual blocks
90 |         for t, c, n, s in inverted_residual_setting:
91 | if current_stride == output_stride:
92 | stride = 1
93 | dilation = rate
94 | rate *= s
95 | else:
96 | stride = s
97 | dilation = 1
98 | current_stride *= s
99 | output_channel = int(c * width_mult)
100 | for i in range(n):
101 | if i == 0:
102 | self.features.append(block(input_channel, output_channel, stride, dilation, t, BatchNorm))
103 | else:
104 | self.features.append(block(input_channel, output_channel, 1, dilation, t, BatchNorm))
105 | input_channel = output_channel
106 | self.features = nn.Sequential(*self.features)
107 | self._initialize_weights()
108 |
109 | if pretrained:
110 | self._load_pretrained_model()
111 |
112 | self.low_level_features = self.features[0:4]
113 | self.high_level_features = self.features[4:]
114 |
115 | def forward(self, x):
116 | low_level_feat = self.low_level_features(x)
117 | x = self.high_level_features(low_level_feat)
118 | return x, low_level_feat
119 |
120 | def _load_pretrained_model(self):
121 | pretrain_dict = model_zoo.load_url("http://jeff95.me/models/mobilenet_v2-6a65762b.pth")
122 | model_dict = {}
123 | state_dict = self.state_dict()
124 | for k, v in pretrain_dict.items():
125 | if k in state_dict:
126 | model_dict[k] = v
127 | state_dict.update(model_dict)
128 | self.load_state_dict(state_dict)
129 |
130 | def _initialize_weights(self):
131 | for m in self.modules():
132 | if isinstance(m, nn.Conv2d):
133 | # n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
134 | # m.weight.data.normal_(0, math.sqrt(2. / n))
135 | torch.nn.init.kaiming_normal_(m.weight)
136 | elif isinstance(m, SynchronizedBatchNorm2d):
137 | m.weight.data.fill_(1)
138 | m.bias.data.zero_()
139 | elif isinstance(m, nn.BatchNorm2d):
140 | m.weight.data.fill_(1)
141 | m.bias.data.zero_()
142 |
143 |
144 | if __name__ == "__main__":
145 | input = torch.rand(1, 3, 512, 512)
146 | model = MobileNetV2(output_stride=16, BatchNorm=nn.BatchNorm2d)
147 | output, low_level_feat = model(input)
148 | print(output.size())
149 | print(low_level_feat.size())
150 |
--------------------------------------------------------------------------------
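Note on `fixed_padding` above: a dilated convolution's effective kernel size is k + (k - 1)(d - 1), so padding a total of (effective - 1) pixels, split between the leading and trailing edges, preserves the spatial size at stride 1. A minimal standalone sketch of the arithmetic (the helper is inlined so the snippet runs on its own):

    import torch
    import torch.nn as nn
    import torch.nn.functional as F

    def fixed_padding(inputs, kernel_size, dilation):
        # effective kernel size of a dilated conv: k + (k - 1) * (d - 1)
        effective = kernel_size + (kernel_size - 1) * (dilation - 1)
        pad_total = effective - 1
        pad_beg = pad_total // 2
        pad_end = pad_total - pad_beg
        return F.pad(inputs, (pad_beg, pad_end, pad_beg, pad_end))

    x = torch.rand(1, 8, 17, 17)
    conv = nn.Conv2d(8, 8, 3, stride=1, padding=0, dilation=2, groups=8, bias=False)
    y = conv(fixed_padding(x, kernel_size=3, dilation=2))
    print(y.shape)  # torch.Size([1, 8, 17, 17]): spatial size preserved at stride 1
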
/seg_lapa/networks/deeplab/backbone/resnet.py:
--------------------------------------------------------------------------------
1 | import math
2 |
3 | import torch.nn as nn
4 | import torch.utils.model_zoo as model_zoo
5 |
6 | from ..sync_batchnorm.batchnorm import SynchronizedBatchNorm2d
7 |
8 |
9 | class Bottleneck(nn.Module):
10 | expansion = 4
11 |
12 | def __init__(self, inplanes, planes, stride=1, dilation=1, downsample=None, BatchNorm=None):
13 | super(Bottleneck, self).__init__()
14 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
15 | self.bn1 = BatchNorm(planes)
16 | self.conv2 = nn.Conv2d(
17 | planes, planes, kernel_size=3, stride=stride, dilation=dilation, padding=dilation, bias=False
18 | )
19 | self.bn2 = BatchNorm(planes)
20 | self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False)
21 | self.bn3 = BatchNorm(planes * 4)
22 | self.relu = nn.ReLU(inplace=True)
23 | self.downsample = downsample
24 | self.stride = stride
25 | self.dilation = dilation
26 |
27 | def forward(self, x):
28 | residual = x
29 |
30 | out = self.conv1(x)
31 | out = self.bn1(out)
32 | out = self.relu(out)
33 |
34 | out = self.conv2(out)
35 | out = self.bn2(out)
36 | out = self.relu(out)
37 |
38 | out = self.conv3(out)
39 | out = self.bn3(out)
40 |
41 | if self.downsample is not None:
42 | residual = self.downsample(x)
43 |
44 | out += residual
45 | out = self.relu(out)
46 |
47 | return out
48 |
49 |
50 | class ResNet(nn.Module):
51 | def __init__(self, block, layers, output_stride, BatchNorm, pretrained=True):
52 | self.inplanes = 64
53 | super(ResNet, self).__init__()
54 | blocks = [1, 2, 4]
55 | if output_stride == 16:
56 | strides = [1, 2, 2, 1]
57 | dilations = [1, 1, 1, 2]
58 | elif output_stride == 8:
59 | strides = [1, 2, 1, 1]
60 | dilations = [1, 1, 2, 4]
61 | else:
62 | raise NotImplementedError
63 |
64 | # Modules
65 | self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False)
66 | self.bn1 = BatchNorm(64)
67 | self.relu = nn.ReLU(inplace=True)
68 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
69 |
70 | self.layer1 = self._make_layer(
71 | block, 64, layers[0], stride=strides[0], dilation=dilations[0], BatchNorm=BatchNorm
72 | )
73 | self.layer2 = self._make_layer(
74 | block, 128, layers[1], stride=strides[1], dilation=dilations[1], BatchNorm=BatchNorm
75 | )
76 | self.layer3 = self._make_layer(
77 | block, 256, layers[2], stride=strides[2], dilation=dilations[2], BatchNorm=BatchNorm
78 | )
79 | self.layer4 = self._make_MG_unit(
80 | block, 512, blocks=blocks, stride=strides[3], dilation=dilations[3], BatchNorm=BatchNorm
81 | )
82 | # self.layer4 = self._make_layer(block, 512, layers[3], stride=strides[3], dilation=dilations[3], BatchNorm=BatchNorm)
83 | self._init_weight()
84 |
85 | if pretrained:
86 | self._load_pretrained_model()
87 |
88 | def _make_layer(self, block, planes, blocks, stride=1, dilation=1, BatchNorm=None):
89 | downsample = None
90 | if stride != 1 or self.inplanes != planes * block.expansion:
91 | downsample = nn.Sequential(
92 | nn.Conv2d(self.inplanes, planes * block.expansion, kernel_size=1, stride=stride, bias=False),
93 | BatchNorm(planes * block.expansion),
94 | )
95 |
96 | layers = []
97 | layers.append(block(self.inplanes, planes, stride, dilation, downsample, BatchNorm))
98 | self.inplanes = planes * block.expansion
99 | for i in range(1, blocks):
100 | layers.append(block(self.inplanes, planes, dilation=dilation, BatchNorm=BatchNorm))
101 |
102 | return nn.Sequential(*layers)
103 |
104 | def _make_MG_unit(self, block, planes, blocks, stride=1, dilation=1, BatchNorm=None):
105 | downsample = None
106 | if stride != 1 or self.inplanes != planes * block.expansion:
107 | downsample = nn.Sequential(
108 | nn.Conv2d(self.inplanes, planes * block.expansion, kernel_size=1, stride=stride, bias=False),
109 | BatchNorm(planes * block.expansion),
110 | )
111 |
112 | layers = []
113 | layers.append(
114 | block(
115 | self.inplanes, planes, stride, dilation=blocks[0] * dilation, downsample=downsample, BatchNorm=BatchNorm
116 | )
117 | )
118 | self.inplanes = planes * block.expansion
119 | for i in range(1, len(blocks)):
120 | layers.append(block(self.inplanes, planes, stride=1, dilation=blocks[i] * dilation, BatchNorm=BatchNorm))
121 |
122 | return nn.Sequential(*layers)
123 |
124 | def forward(self, input):
125 | x = self.conv1(input)
126 | x = self.bn1(x)
127 | x = self.relu(x)
128 | x = self.maxpool(x)
129 |
130 | x = self.layer1(x)
131 | low_level_feat = x
132 | x = self.layer2(x)
133 | x = self.layer3(x)
134 | x = self.layer4(x)
135 | return x, low_level_feat
136 |
137 | def _init_weight(self):
138 | for m in self.modules():
139 | if isinstance(m, nn.Conv2d):
140 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
141 | m.weight.data.normal_(0, math.sqrt(2.0 / n))
142 | elif isinstance(m, SynchronizedBatchNorm2d):
143 | m.weight.data.fill_(1)
144 | m.bias.data.zero_()
145 | elif isinstance(m, nn.BatchNorm2d):
146 | m.weight.data.fill_(1)
147 | m.bias.data.zero_()
148 |
149 | def _load_pretrained_model(self):
150 | pretrain_dict = model_zoo.load_url("https://download.pytorch.org/models/resnet101-5d3b4d8f.pth")
151 | model_dict = {}
152 | state_dict = self.state_dict()
153 | for k, v in pretrain_dict.items():
154 | if k in state_dict:
155 | model_dict[k] = v
156 | state_dict.update(model_dict)
157 | self.load_state_dict(state_dict)
158 |
159 |
160 | def ResNet101(output_stride, BatchNorm, pretrained=True):
161 | """Constructs a ResNet-101 model.
162 | Args:
163 | pretrained (bool): If True, returns a model pre-trained on ImageNet
164 | """
165 | model = ResNet(Bottleneck, [3, 4, 23, 3], output_stride, BatchNorm, pretrained=pretrained)
166 | return model
167 |
168 |
169 | if __name__ == "__main__":
170 | import torch
171 |
172 | model = ResNet101(BatchNorm=nn.BatchNorm2d, pretrained=True, output_stride=8)
173 | input = torch.rand(1, 3, 512, 512)
174 | output, low_level_feat = model(input)
175 | print(output.size())
176 | print(low_level_feat.size())
177 |
--------------------------------------------------------------------------------
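The `output_stride` argument above controls how far the backbone downsamples: once the running stride reaches the target, later stages switch from stride 2 to dilated convolutions, so the final feature map is input / output_stride while `low_level_feat` (taken after layer1) stays at 1/4 resolution. A quick shape check, assuming `ResNet101` is imported from the module above and using `pretrained=False` to skip the weight download:

    import torch
    import torch.nn as nn

    model = ResNet101(output_stride=16, BatchNorm=nn.BatchNorm2d, pretrained=False)
    out, low_level_feat = model(torch.rand(1, 3, 512, 512))
    print(out.shape)             # torch.Size([1, 2048, 32, 32])   -> 512 / 16
    print(low_level_feat.shape)  # torch.Size([1, 256, 128, 128])  -> 512 / 4
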
/seg_lapa/networks/deeplab/backbone/xception.py:
--------------------------------------------------------------------------------
1 | import math
2 |
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 | import torch.utils.model_zoo as model_zoo
6 |
7 | from ..sync_batchnorm.batchnorm import SynchronizedBatchNorm2d
8 |
9 |
10 | def fixed_padding(inputs, kernel_size, dilation):
11 | kernel_size_effective = kernel_size + (kernel_size - 1) * (dilation - 1)
12 | pad_total = kernel_size_effective - 1
13 | pad_beg = pad_total // 2
14 | pad_end = pad_total - pad_beg
15 | padded_inputs = F.pad(inputs, (pad_beg, pad_end, pad_beg, pad_end))
16 | return padded_inputs
17 |
18 |
19 | class SeparableConv2d(nn.Module):
20 | def __init__(self, inplanes, planes, kernel_size=3, stride=1, dilation=1, bias=False, BatchNorm=None):
21 | super(SeparableConv2d, self).__init__()
22 |
23 | self.conv1 = nn.Conv2d(inplanes, inplanes, kernel_size, stride, 0, dilation, groups=inplanes, bias=bias)
24 | self.bn = BatchNorm(inplanes)
25 | self.pointwise = nn.Conv2d(inplanes, planes, 1, 1, 0, 1, 1, bias=bias)
26 |
27 | def forward(self, x):
28 | x = fixed_padding(x, self.conv1.kernel_size[0], dilation=self.conv1.dilation[0])
29 | x = self.conv1(x)
30 | x = self.bn(x)
31 | x = self.pointwise(x)
32 | return x
33 |
34 |
35 | class Block(nn.Module):
36 | def __init__(
37 | self,
38 | inplanes,
39 | planes,
40 | reps,
41 | stride=1,
42 | dilation=1,
43 | BatchNorm=None,
44 | start_with_relu=True,
45 | grow_first=True,
46 | is_last=False,
47 | ):
48 | super(Block, self).__init__()
49 |
50 | if planes != inplanes or stride != 1:
51 | self.skip = nn.Conv2d(inplanes, planes, 1, stride=stride, bias=False)
52 | self.skipbn = BatchNorm(planes)
53 | else:
54 | self.skip = None
55 |
56 | self.relu = nn.ReLU(inplace=True)
57 | rep = []
58 |
59 | filters = inplanes
60 | if grow_first:
61 | rep.append(self.relu)
62 | rep.append(SeparableConv2d(inplanes, planes, 3, 1, dilation, BatchNorm=BatchNorm))
63 | rep.append(BatchNorm(planes))
64 | filters = planes
65 |
66 | for i in range(reps - 1):
67 | rep.append(self.relu)
68 | rep.append(SeparableConv2d(filters, filters, 3, 1, dilation, BatchNorm=BatchNorm))
69 | rep.append(BatchNorm(filters))
70 |
71 | if not grow_first:
72 | rep.append(self.relu)
73 | rep.append(SeparableConv2d(inplanes, planes, 3, 1, dilation, BatchNorm=BatchNorm))
74 | rep.append(BatchNorm(planes))
75 |
76 | if stride != 1:
77 | rep.append(self.relu)
78 | rep.append(SeparableConv2d(planes, planes, 3, 2, BatchNorm=BatchNorm))
79 | rep.append(BatchNorm(planes))
80 |
81 | if stride == 1 and is_last:
82 | rep.append(self.relu)
83 | rep.append(SeparableConv2d(planes, planes, 3, 1, BatchNorm=BatchNorm))
84 | rep.append(BatchNorm(planes))
85 |
86 | if not start_with_relu:
87 | rep = rep[1:]
88 |
89 | self.rep = nn.Sequential(*rep)
90 |
91 | def forward(self, inp):
92 | x = self.rep(inp)
93 |
94 | if self.skip is not None:
95 | skip = self.skip(inp)
96 | skip = self.skipbn(skip)
97 | else:
98 | skip = inp
99 |
100 | x = x + skip
101 |
102 | return x
103 |
104 |
105 | class AlignedXception(nn.Module):
106 | """
107 |     Modified Aligned Xception
108 | """
109 |
110 | def __init__(self, output_stride, BatchNorm, pretrained=True):
111 | super(AlignedXception, self).__init__()
112 |
113 | if output_stride == 16:
114 | entry_block3_stride = 2
115 | middle_block_dilation = 1
116 | exit_block_dilations = (1, 2)
117 | elif output_stride == 8:
118 | entry_block3_stride = 1
119 | middle_block_dilation = 2
120 | exit_block_dilations = (2, 4)
121 | else:
122 | raise NotImplementedError
123 |
124 | # Entry flow
125 | self.conv1 = nn.Conv2d(3, 32, 3, stride=2, padding=1, bias=False)
126 | self.bn1 = BatchNorm(32)
127 | self.relu = nn.ReLU(inplace=True)
128 |
129 | self.conv2 = nn.Conv2d(32, 64, 3, stride=1, padding=1, bias=False)
130 | self.bn2 = BatchNorm(64)
131 |
132 | self.block1 = Block(64, 128, reps=2, stride=2, BatchNorm=BatchNorm, start_with_relu=False)
133 | self.block2 = Block(128, 256, reps=2, stride=2, BatchNorm=BatchNorm, start_with_relu=False, grow_first=True)
134 | self.block3 = Block(
135 | 256,
136 | 728,
137 | reps=2,
138 | stride=entry_block3_stride,
139 | BatchNorm=BatchNorm,
140 | start_with_relu=True,
141 | grow_first=True,
142 | is_last=True,
143 | )
144 |
145 | # Middle flow
146 | self.block4 = Block(
147 | 728,
148 | 728,
149 | reps=3,
150 | stride=1,
151 | dilation=middle_block_dilation,
152 | BatchNorm=BatchNorm,
153 | start_with_relu=True,
154 | grow_first=True,
155 | )
156 | self.block5 = Block(
157 | 728,
158 | 728,
159 | reps=3,
160 | stride=1,
161 | dilation=middle_block_dilation,
162 | BatchNorm=BatchNorm,
163 | start_with_relu=True,
164 | grow_first=True,
165 | )
166 | self.block6 = Block(
167 | 728,
168 | 728,
169 | reps=3,
170 | stride=1,
171 | dilation=middle_block_dilation,
172 | BatchNorm=BatchNorm,
173 | start_with_relu=True,
174 | grow_first=True,
175 | )
176 | self.block7 = Block(
177 | 728,
178 | 728,
179 | reps=3,
180 | stride=1,
181 | dilation=middle_block_dilation,
182 | BatchNorm=BatchNorm,
183 | start_with_relu=True,
184 | grow_first=True,
185 | )
186 | self.block8 = Block(
187 | 728,
188 | 728,
189 | reps=3,
190 | stride=1,
191 | dilation=middle_block_dilation,
192 | BatchNorm=BatchNorm,
193 | start_with_relu=True,
194 | grow_first=True,
195 | )
196 | self.block9 = Block(
197 | 728,
198 | 728,
199 | reps=3,
200 | stride=1,
201 | dilation=middle_block_dilation,
202 | BatchNorm=BatchNorm,
203 | start_with_relu=True,
204 | grow_first=True,
205 | )
206 | self.block10 = Block(
207 | 728,
208 | 728,
209 | reps=3,
210 | stride=1,
211 | dilation=middle_block_dilation,
212 | BatchNorm=BatchNorm,
213 | start_with_relu=True,
214 | grow_first=True,
215 | )
216 | self.block11 = Block(
217 | 728,
218 | 728,
219 | reps=3,
220 | stride=1,
221 | dilation=middle_block_dilation,
222 | BatchNorm=BatchNorm,
223 | start_with_relu=True,
224 | grow_first=True,
225 | )
226 | self.block12 = Block(
227 | 728,
228 | 728,
229 | reps=3,
230 | stride=1,
231 | dilation=middle_block_dilation,
232 | BatchNorm=BatchNorm,
233 | start_with_relu=True,
234 | grow_first=True,
235 | )
236 | self.block13 = Block(
237 | 728,
238 | 728,
239 | reps=3,
240 | stride=1,
241 | dilation=middle_block_dilation,
242 | BatchNorm=BatchNorm,
243 | start_with_relu=True,
244 | grow_first=True,
245 | )
246 | self.block14 = Block(
247 | 728,
248 | 728,
249 | reps=3,
250 | stride=1,
251 | dilation=middle_block_dilation,
252 | BatchNorm=BatchNorm,
253 | start_with_relu=True,
254 | grow_first=True,
255 | )
256 | self.block15 = Block(
257 | 728,
258 | 728,
259 | reps=3,
260 | stride=1,
261 | dilation=middle_block_dilation,
262 | BatchNorm=BatchNorm,
263 | start_with_relu=True,
264 | grow_first=True,
265 | )
266 | self.block16 = Block(
267 | 728,
268 | 728,
269 | reps=3,
270 | stride=1,
271 | dilation=middle_block_dilation,
272 | BatchNorm=BatchNorm,
273 | start_with_relu=True,
274 | grow_first=True,
275 | )
276 | self.block17 = Block(
277 | 728,
278 | 728,
279 | reps=3,
280 | stride=1,
281 | dilation=middle_block_dilation,
282 | BatchNorm=BatchNorm,
283 | start_with_relu=True,
284 | grow_first=True,
285 | )
286 | self.block18 = Block(
287 | 728,
288 | 728,
289 | reps=3,
290 | stride=1,
291 | dilation=middle_block_dilation,
292 | BatchNorm=BatchNorm,
293 | start_with_relu=True,
294 | grow_first=True,
295 | )
296 | self.block19 = Block(
297 | 728,
298 | 728,
299 | reps=3,
300 | stride=1,
301 | dilation=middle_block_dilation,
302 | BatchNorm=BatchNorm,
303 | start_with_relu=True,
304 | grow_first=True,
305 | )
306 |
307 | # Exit flow
308 | self.block20 = Block(
309 | 728,
310 | 1024,
311 | reps=2,
312 | stride=1,
313 | dilation=exit_block_dilations[0],
314 | BatchNorm=BatchNorm,
315 | start_with_relu=True,
316 | grow_first=False,
317 | is_last=True,
318 | )
319 |
320 | self.conv3 = SeparableConv2d(1024, 1536, 3, stride=1, dilation=exit_block_dilations[1], BatchNorm=BatchNorm)
321 | self.bn3 = BatchNorm(1536)
322 |
323 | self.conv4 = SeparableConv2d(1536, 1536, 3, stride=1, dilation=exit_block_dilations[1], BatchNorm=BatchNorm)
324 | self.bn4 = BatchNorm(1536)
325 |
326 | self.conv5 = SeparableConv2d(1536, 2048, 3, stride=1, dilation=exit_block_dilations[1], BatchNorm=BatchNorm)
327 | self.bn5 = BatchNorm(2048)
328 |
329 | # Init weights
330 | self._init_weight()
331 |
332 | # Load pretrained model
333 | if pretrained:
334 | self._load_pretrained_model()
335 |
336 | def forward(self, x):
337 | # Entry flow
338 | x = self.conv1(x)
339 | x = self.bn1(x)
340 | x = self.relu(x)
341 |
342 | x = self.conv2(x)
343 | x = self.bn2(x)
344 | x = self.relu(x)
345 |
346 | x = self.block1(x)
347 | # add relu here
348 | x = self.relu(x)
349 | low_level_feat = x
350 | x = self.block2(x)
351 | x = self.block3(x)
352 |
353 | # Middle flow
354 | x = self.block4(x)
355 | x = self.block5(x)
356 | x = self.block6(x)
357 | x = self.block7(x)
358 | x = self.block8(x)
359 | x = self.block9(x)
360 | x = self.block10(x)
361 | x = self.block11(x)
362 | x = self.block12(x)
363 | x = self.block13(x)
364 | x = self.block14(x)
365 | x = self.block15(x)
366 | x = self.block16(x)
367 | x = self.block17(x)
368 | x = self.block18(x)
369 | x = self.block19(x)
370 |
371 | # Exit flow
372 | x = self.block20(x)
373 | x = self.relu(x)
374 | x = self.conv3(x)
375 | x = self.bn3(x)
376 | x = self.relu(x)
377 |
378 | x = self.conv4(x)
379 | x = self.bn4(x)
380 | x = self.relu(x)
381 |
382 | x = self.conv5(x)
383 | x = self.bn5(x)
384 | x = self.relu(x)
385 |
386 | return x, low_level_feat
387 |
388 | def _init_weight(self):
389 | for m in self.modules():
390 | if isinstance(m, nn.Conv2d):
391 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
392 | m.weight.data.normal_(0, math.sqrt(2.0 / n))
393 | elif isinstance(m, SynchronizedBatchNorm2d):
394 | m.weight.data.fill_(1)
395 | m.bias.data.zero_()
396 | elif isinstance(m, nn.BatchNorm2d):
397 | m.weight.data.fill_(1)
398 | m.bias.data.zero_()
399 |
400 | def _load_pretrained_model(self):
401 | pretrain_dict = model_zoo.load_url("http://data.lip6.fr/cadene/pretrainedmodels/xception-b5690688.pth")
402 | model_dict = {}
403 | state_dict = self.state_dict()
404 |
405 | for k, v in pretrain_dict.items():
406 | if k in state_dict:
407 | if "pointwise" in k:
408 | v = v.unsqueeze(-1).unsqueeze(-1)
409 | if k.startswith("block11"):
410 | model_dict[k] = v
411 | model_dict[k.replace("block11", "block12")] = v
412 | model_dict[k.replace("block11", "block13")] = v
413 | model_dict[k.replace("block11", "block14")] = v
414 | model_dict[k.replace("block11", "block15")] = v
415 | model_dict[k.replace("block11", "block16")] = v
416 | model_dict[k.replace("block11", "block17")] = v
417 | model_dict[k.replace("block11", "block18")] = v
418 | model_dict[k.replace("block11", "block19")] = v
419 | elif k.startswith("block12"):
420 | model_dict[k.replace("block12", "block20")] = v
421 | elif k.startswith("bn3"):
422 | model_dict[k] = v
423 | model_dict[k.replace("bn3", "bn4")] = v
424 | elif k.startswith("conv4"):
425 | model_dict[k.replace("conv4", "conv5")] = v
426 | elif k.startswith("bn4"):
427 | model_dict[k.replace("bn4", "bn5")] = v
428 | else:
429 | model_dict[k] = v
430 | state_dict.update(model_dict)
431 | self.load_state_dict(state_dict)
432 |
433 |
434 | if __name__ == "__main__":
435 | import torch
436 |
437 | model = AlignedXception(BatchNorm=nn.BatchNorm2d, pretrained=True, output_stride=16)
438 | input = torch.rand(1, 3, 512, 512)
439 | output, low_level_feat = model(input)
440 | print(output.size())
441 | print(low_level_feat.size())
442 |
--------------------------------------------------------------------------------
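The ImageNet Xception checkpoint loaded by `_load_pretrained_model` above has fewer blocks than this aligned variant, so the loader remaps keys: block11's weights are copied into the added middle-flow blocks 12 through 19, the checkpoint's block12 becomes block20 (exit flow), and the 2-D pointwise weights are unsqueezed to 4-D. A toy illustration of the remapping idea, using made-up keys rather than the real checkpoint:

    pretrain_dict = {"block11.rep.1.conv1.weight": 1.0, "block12.rep.0.weight": 2.0}
    model_dict = {}
    for k, v in pretrain_dict.items():
        if k.startswith("block11"):
            model_dict[k] = v
            for i in range(12, 20):  # reuse block11 weights for the extra middle-flow blocks
                model_dict[k.replace("block11", f"block{i}")] = v
        elif k.startswith("block12"):
            model_dict[k.replace("block12", "block20")] = v
    print(sorted(model_dict))  # block11..block19 entries plus the remapped block20 key
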
/seg_lapa/networks/deeplab/decoder.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 |
5 | from .sync_batchnorm.batchnorm import SynchronizedBatchNorm2d
6 |
7 |
8 | class Decoder(nn.Module):
9 | def __init__(self, num_classes, backbone, BatchNorm):
10 | super(Decoder, self).__init__()
11 | if backbone == "resnet" or backbone == "drn":
12 | low_level_inplanes = 256
13 | elif backbone == "xception":
14 | low_level_inplanes = 128
15 | elif backbone == "mobilenet":
16 | low_level_inplanes = 24
17 | else:
18 | raise NotImplementedError
19 |
20 | self.conv1 = nn.Conv2d(low_level_inplanes, 48, 1, bias=False)
21 | self.bn1 = BatchNorm(48)
22 | self.relu = nn.ReLU()
23 | self.last_conv = nn.Sequential(
24 | nn.Conv2d(304, 256, kernel_size=3, stride=1, padding=1, bias=False),
25 | BatchNorm(256),
26 | nn.ReLU(),
27 | nn.Dropout(0.5),
28 | nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1, bias=False),
29 | BatchNorm(256),
30 | nn.ReLU(),
31 | nn.Dropout(0.1),
32 | nn.Conv2d(256, num_classes, kernel_size=1, stride=1),
33 | )
34 | self._init_weight()
35 |
36 | def forward(self, x, low_level_feat):
37 | low_level_feat = self.conv1(low_level_feat)
38 | low_level_feat = self.bn1(low_level_feat)
39 | low_level_feat = self.relu(low_level_feat)
40 |
41 | x = F.interpolate(x, size=low_level_feat.size()[2:], mode="bilinear", align_corners=True)
42 | x = torch.cat((x, low_level_feat), dim=1)
43 | x = self.last_conv(x)
44 |
45 | return x
46 |
47 | def _init_weight(self):
48 | for m in self.modules():
49 | if isinstance(m, nn.Conv2d):
50 | torch.nn.init.kaiming_normal_(m.weight)
51 | elif isinstance(m, SynchronizedBatchNorm2d):
52 | m.weight.data.fill_(1)
53 | m.bias.data.zero_()
54 | elif isinstance(m, nn.BatchNorm2d):
55 | m.weight.data.fill_(1)
56 | m.bias.data.zero_()
57 |
58 |
59 | def build_decoder(num_classes, backbone, BatchNorm):
60 | return Decoder(num_classes, backbone, BatchNorm)
61 |
--------------------------------------------------------------------------------
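The hard-coded 304 input channels of `last_conv` above are the concatenation of the 256-channel ASPP output with the 48-channel projection of the low-level features (256 + 48 = 304). A minimal shape sketch, assuming `build_decoder` is imported from the module above and a resnet backbone (low_level_inplanes = 256):

    import torch
    import torch.nn as nn

    decoder = build_decoder(num_classes=11, backbone="resnet", BatchNorm=nn.BatchNorm2d)
    aspp_out = torch.rand(1, 256, 32, 32)     # ASPP output at 1/16 resolution
    low_level = torch.rand(1, 256, 128, 128)  # backbone layer1 output at 1/4 resolution
    logits = decoder(aspp_out, low_level)
    print(logits.shape)  # torch.Size([1, 11, 128, 128]): per-class logits at 1/4 resolution
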
/seg_lapa/networks/deeplab/decoder_masks.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 |
5 | from .sync_batchnorm.batchnorm import SynchronizedBatchNorm2d
6 |
7 |
8 | class Decoder(nn.Module):
9 | def __init__(self, num_classes, backbone, BatchNorm):
10 | super(Decoder, self).__init__()
11 | if backbone == "resnet" or backbone == "drn":
12 | low_level_inplanes = 256
13 | elif backbone == "xception":
14 | low_level_inplanes = 128
15 | elif backbone == "mobilenet":
16 | low_level_inplanes = 24
17 | else:
18 | raise NotImplementedError
19 |
20 | self.conv1 = nn.Conv2d(low_level_inplanes, 48, 1, bias=False)
21 | self.bn1 = BatchNorm(48)
22 | self.relu = nn.ReLU()
23 | self.last_conv = nn.Sequential(
24 | nn.Conv2d(304, 256, kernel_size=3, stride=1, padding=1, bias=False),
25 | BatchNorm(256),
26 | nn.ReLU(),
27 | nn.Dropout(0.5),
28 | nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1, bias=False),
29 | BatchNorm(256),
30 | nn.ReLU(),
31 | nn.Dropout(0.1),
32 | nn.Conv2d(256, num_classes, kernel_size=1, stride=1),
33 | )
34 | self._init_weight()
35 |
36 | def forward(self, x, low_level_feat):
37 | low_level_feat = self.conv1(low_level_feat)
38 | low_level_feat = self.bn1(low_level_feat)
39 | low_level_feat = self.relu(low_level_feat)
40 |
41 | x = F.interpolate(x, size=low_level_feat.size()[2:], mode="bilinear", align_corners=True)
42 | x = torch.cat((x, low_level_feat), dim=1)
43 | x = self.last_conv(x)
44 |
45 | return x
46 |
47 | def _init_weight(self):
48 | for m in self.modules():
49 | if isinstance(m, nn.Conv2d):
50 | torch.nn.init.kaiming_normal_(m.weight)
51 | elif isinstance(m, SynchronizedBatchNorm2d):
52 | m.weight.data.fill_(1)
53 | m.bias.data.zero_()
54 | elif isinstance(m, nn.BatchNorm2d):
55 | m.weight.data.fill_(1)
56 | m.bias.data.zero_()
57 |
58 |
59 | def build_decoder(num_classes, backbone, BatchNorm):
60 | return Decoder(num_classes, backbone, BatchNorm)
61 |
--------------------------------------------------------------------------------
/seg_lapa/networks/deeplab/deeplab.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | from torch.cuda.amp import autocast
5 |
6 | from .aspp import build_aspp
7 | from .backbone import build_backbone
8 | from .decoder import build_decoder
9 | from .sync_batchnorm.batchnorm import SynchronizedBatchNorm2d
10 |
11 |
12 | class DeepLab(nn.Module):
13 | def __init__(
14 | self, backbone="drn", output_stride=8, num_classes=11, sync_bn=False, freeze_bn=False, enable_amp=False
15 | ):
16 | super(DeepLab, self).__init__()
17 |
18 | if backbone == "drn" and output_stride != 8:
19 | raise ValueError(f'The "drn" backbone only supports output stride = 8. Input: {output_stride}')
20 |
21 | # Ref for sync_bn: https://hangzhang.org/PyTorch-Encoding/tutorials/syncbn.html
22 |         # Sync batchnorm is required when the effective batch size per GPU is small (~4)
23 | if sync_bn is True:
24 | BatchNorm = SynchronizedBatchNorm2d
25 | else:
26 | BatchNorm = nn.BatchNorm2d
27 |
28 | self.backbone = build_backbone(backbone, output_stride, BatchNorm)
29 | self.aspp = build_aspp(backbone, output_stride, BatchNorm)
30 | self.decoder = build_decoder(num_classes, backbone, BatchNorm)
31 |
32 |         self._freeze_bn = freeze_bn  # stored under a private name so it does not shadow the freeze_bn() method below
33 |
34 | self.enable_amp = enable_amp
35 |
36 | def forward(self, inputs):
37 | with autocast(enabled=self.enable_amp):
38 | """Pytorch Automatic Mixed Precision (AMP) Training
39 | Ref: https://pytorch.org/docs/stable/amp.html#torch.cuda.amp.autocast
40 |
41 |             - For use with DataParallel, we must add autocast within the model definition.
42 | """
43 | x, low_level_feat = self.backbone(inputs)
44 | x = self.aspp(x)
45 | x = self.decoder(x, low_level_feat)
46 | x = F.interpolate(x, size=inputs.size()[2:], mode="bilinear", align_corners=True)
47 |
48 | return x
49 |
50 | def freeze_bn(self):
51 | for m in self.modules():
52 | if isinstance(m, SynchronizedBatchNorm2d):
53 | m.eval()
54 | elif isinstance(m, nn.BatchNorm2d):
55 | m.eval()
56 |
57 | def get_1x_lr_params(self):
58 | modules = [self.backbone]
59 | for i in range(len(modules)):
60 | for m in modules[i].named_modules():
61 |                 if self._freeze_bn:
62 | if isinstance(m[1], nn.Conv2d):
63 | for p in m[1].parameters():
64 | if p.requires_grad:
65 | yield p
66 | else:
67 | if (
68 | isinstance(m[1], nn.Conv2d)
69 | or isinstance(m[1], SynchronizedBatchNorm2d)
70 | or isinstance(m[1], nn.BatchNorm2d)
71 | ):
72 | for p in m[1].parameters():
73 | if p.requires_grad:
74 | yield p
75 |
76 | def get_10x_lr_params(self):
77 | modules = [self.aspp, self.decoder]
78 | for i in range(len(modules)):
79 | for m in modules[i].named_modules():
80 |                 if self._freeze_bn:
81 | if isinstance(m[1], nn.Conv2d):
82 | for p in m[1].parameters():
83 | if p.requires_grad:
84 | yield p
85 | else:
86 | if (
87 | isinstance(m[1], nn.Conv2d)
88 | or isinstance(m[1], SynchronizedBatchNorm2d)
89 | or isinstance(m[1], nn.BatchNorm2d)
90 | ):
91 | for p in m[1].parameters():
92 | if p.requires_grad:
93 | yield p
94 |
95 |
96 | if __name__ == "__main__":
97 | model = DeepLab(backbone="drn", output_stride=8)
98 | model.eval()
99 | inputs = torch.rand(1, 3, 512, 512)
100 | output = model(inputs)
101 | print(output.size())
102 |
--------------------------------------------------------------------------------
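`get_1x_lr_params` and `get_10x_lr_params` above split the parameters so the pretrained backbone can be fine-tuned with a smaller learning rate than the randomly initialized ASPP and decoder heads. A sketch of wiring them into an optimizer; the learning-rate values are illustrative, not this repository's configuration (constructing the model may download pretrained backbone weights):

    import torch

    model = DeepLab(backbone="mobilenet", output_stride=16, num_classes=11)
    optimizer = torch.optim.SGD(
        [
            {"params": model.get_1x_lr_params(), "lr": 0.007},  # pretrained backbone
            {"params": model.get_10x_lr_params(), "lr": 0.07},  # ASPP + decoder heads
        ],
        momentum=0.9,
        weight_decay=5e-4,
    )
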
/seg_lapa/networks/deeplab/readme.md:
--------------------------------------------------------------------------------
1 | The code for deeplabv3+ was adapted from: https://github.com/jfzhang95/pytorch-deeplab-xception
2 |
--------------------------------------------------------------------------------
/seg_lapa/networks/deeplab/sync_batchnorm/__init__.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # File : __init__.py
3 | # Author : Jiayuan Mao
4 | # Email : maojiayuan@gmail.com
5 | # Date : 27/01/2018
6 | #
7 | # This file is part of Synchronized-BatchNorm-PyTorch.
8 | # https://github.com/vacancy/Synchronized-BatchNorm-PyTorch
9 | # Distributed under MIT License.
10 |
11 | from .batchnorm import SynchronizedBatchNorm1d, SynchronizedBatchNorm2d, SynchronizedBatchNorm3d
12 | from .replicate import DataParallelWithCallback, patch_replication_callback
13 |
--------------------------------------------------------------------------------
/seg_lapa/networks/deeplab/sync_batchnorm/batchnorm.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # File : batchnorm.py
3 | # Author : Jiayuan Mao
4 | # Email : maojiayuan@gmail.com
5 | # Date : 27/01/2018
6 | #
7 | # This file is part of Synchronized-BatchNorm-PyTorch.
8 | # https://github.com/vacancy/Synchronized-BatchNorm-PyTorch
9 | # Distributed under MIT License.
10 |
11 | import collections
12 |
13 | import torch
14 | import torch.nn.functional as F
15 | from torch.nn.modules.batchnorm import _BatchNorm
16 | from torch.nn.parallel._functions import Broadcast, ReduceAddCoalesced
17 |
18 | from .comm import SyncMaster
19 |
20 | __all__ = ["SynchronizedBatchNorm1d", "SynchronizedBatchNorm2d", "SynchronizedBatchNorm3d"]
21 |
22 |
23 | def _sum_ft(tensor):
24 |     """sum over the first and last dimensions"""
25 | return tensor.sum(dim=0).sum(dim=-1)
26 |
27 |
28 | def _unsqueeze_ft(tensor):
29 |     """add new dimensions at the front and the tail"""
30 | return tensor.unsqueeze(0).unsqueeze(-1)
31 |
32 |
33 | _ChildMessage = collections.namedtuple("_ChildMessage", ["sum", "ssum", "sum_size"])
34 | _MasterMessage = collections.namedtuple("_MasterMessage", ["sum", "inv_std"])
35 |
36 |
37 | class _SynchronizedBatchNorm(_BatchNorm):
38 | def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=True):
39 | super(_SynchronizedBatchNorm, self).__init__(num_features, eps=eps, momentum=momentum, affine=affine)
40 |
41 | self._sync_master = SyncMaster(self._data_parallel_master)
42 |
43 | self._is_parallel = False
44 | self._parallel_id = None
45 | self._slave_pipe = None
46 |
47 | def forward(self, input):
48 | # If it is not parallel computation or is in evaluation mode, use PyTorch's implementation.
49 | if not (self._is_parallel and self.training):
50 | return F.batch_norm(
51 | input,
52 | self.running_mean,
53 | self.running_var,
54 | self.weight,
55 | self.bias,
56 | self.training,
57 | self.momentum,
58 | self.eps,
59 | )
60 |
61 | # Resize the input to (B, C, -1).
62 | input_shape = input.size()
63 | input = input.view(input.size(0), self.num_features, -1)
64 |
65 | # Compute the sum and square-sum.
66 | sum_size = input.size(0) * input.size(2)
67 | input_sum = _sum_ft(input)
68 | input_ssum = _sum_ft(input ** 2)
69 |
70 | # Reduce-and-broadcast the statistics.
71 | if self._parallel_id == 0:
72 | mean, inv_std = self._sync_master.run_master(_ChildMessage(input_sum, input_ssum, sum_size))
73 | else:
74 | mean, inv_std = self._slave_pipe.run_slave(_ChildMessage(input_sum, input_ssum, sum_size))
75 |
76 | # Compute the output.
77 | if self.affine:
78 | # MJY:: Fuse the multiplication for speed.
79 | output = (input - _unsqueeze_ft(mean)) * _unsqueeze_ft(inv_std * self.weight) + _unsqueeze_ft(self.bias)
80 | else:
81 | output = (input - _unsqueeze_ft(mean)) * _unsqueeze_ft(inv_std)
82 |
83 | # Reshape it.
84 | return output.view(input_shape)
85 |
86 | def __data_parallel_replicate__(self, ctx, copy_id):
87 | self._is_parallel = True
88 | self._parallel_id = copy_id
89 |
90 | # parallel_id == 0 means master device.
91 | if self._parallel_id == 0:
92 | ctx.sync_master = self._sync_master
93 | else:
94 | self._slave_pipe = ctx.sync_master.register_slave(copy_id)
95 |
96 | def _data_parallel_master(self, intermediates):
97 | """Reduce the sum and square-sum, compute the statistics, and broadcast it."""
98 |
99 | # Always using same "device order" makes the ReduceAdd operation faster.
100 | # Thanks to:: Tete Xiao (http://tetexiao.com/)
101 | intermediates = sorted(intermediates, key=lambda i: i[1].sum.get_device())
102 |
103 | to_reduce = [i[1][:2] for i in intermediates]
104 | to_reduce = [j for i in to_reduce for j in i] # flatten
105 | target_gpus = [i[1].sum.get_device() for i in intermediates]
106 |
107 | sum_size = sum([i[1].sum_size for i in intermediates])
108 | sum_, ssum = ReduceAddCoalesced.apply(target_gpus[0], 2, *to_reduce)
109 | mean, inv_std = self._compute_mean_std(sum_, ssum, sum_size)
110 |
111 | broadcasted = Broadcast.apply(target_gpus, mean, inv_std)
112 |
113 | outputs = []
114 | for i, rec in enumerate(intermediates):
115 | outputs.append((rec[0], _MasterMessage(*broadcasted[i * 2 : i * 2 + 2])))
116 |
117 | return outputs
118 |
119 | def _compute_mean_std(self, sum_, ssum, size):
120 | """Compute the mean and standard-deviation with sum and square-sum. This method
121 | also maintains the moving average on the master device."""
122 | assert size > 1, "BatchNorm computes unbiased standard-deviation, which requires size > 1."
123 | mean = sum_ / size
124 | sumvar = ssum - sum_ * mean
125 | unbias_var = sumvar / (size - 1)
126 | bias_var = sumvar / size
127 |
128 | self.running_mean = (1 - self.momentum) * self.running_mean + self.momentum * mean.data
129 | self.running_var = (1 - self.momentum) * self.running_var + self.momentum * unbias_var.data
130 |
131 | return mean, bias_var.clamp(self.eps) ** -0.5
132 |
133 |
134 | class SynchronizedBatchNorm1d(_SynchronizedBatchNorm):
135 | r"""Applies Synchronized Batch Normalization over a 2d or 3d input that is seen as a
136 | mini-batch.
137 | .. math::
138 | y = \frac{x - mean[x]}{ \sqrt{Var[x] + \epsilon}} * gamma + beta
139 | This module differs from the built-in PyTorch BatchNorm1d as the mean and
140 | standard-deviation are reduced across all devices during training.
141 | For example, when one uses `nn.DataParallel` to wrap the network during
142 |     training, PyTorch's implementation normalizes the tensor on each device using
143 |     only that device's statistics, which accelerates the computation and
144 |     is also easy to implement, but the statistics might be inaccurate.
145 | Instead, in this synchronized version, the statistics will be computed
146 | over all training samples distributed on multiple devices.
147 |
148 |     Note that, in the one-GPU or CPU-only case, this module behaves exactly the same
149 |     as the built-in PyTorch implementation.
150 | The mean and standard-deviation are calculated per-dimension over
151 | the mini-batches and gamma and beta are learnable parameter vectors
152 | of size C (where C is the input size).
153 | During training, this layer keeps a running estimate of its computed mean
154 | and variance. The running sum is kept with a default momentum of 0.1.
155 | During evaluation, this running mean/variance is used for normalization.
156 | Because the BatchNorm is done over the `C` dimension, computing statistics
157 | on `(N, L)` slices, it's common terminology to call this Temporal BatchNorm
158 | Args:
159 | num_features: num_features from an expected input of size
160 | `batch_size x num_features [x width]`
161 | eps: a value added to the denominator for numerical stability.
162 | Default: 1e-5
163 | momentum: the value used for the running_mean and running_var
164 | computation. Default: 0.1
165 | affine: a boolean value that when set to ``True``, gives the layer learnable
166 | affine parameters. Default: ``True``
167 | Shape:
168 | - Input: :math:`(N, C)` or :math:`(N, C, L)`
169 | - Output: :math:`(N, C)` or :math:`(N, C, L)` (same shape as input)
170 | Examples:
171 | >>> # With Learnable Parameters
172 | >>> m = SynchronizedBatchNorm1d(100)
173 | >>> # Without Learnable Parameters
174 | >>> m = SynchronizedBatchNorm1d(100, affine=False)
175 | >>> input = torch.autograd.Variable(torch.randn(20, 100))
176 | >>> output = m(input)
177 | """
178 |
179 | def _check_input_dim(self, input):
180 | if input.dim() != 2 and input.dim() != 3:
181 | raise ValueError("expected 2D or 3D input (got {}D input)".format(input.dim()))
182 | super(SynchronizedBatchNorm1d, self)._check_input_dim(input)
183 |
184 |
185 | class SynchronizedBatchNorm2d(_SynchronizedBatchNorm):
186 | r"""Applies Batch Normalization over a 4d input that is seen as a mini-batch
187 | of 3d inputs
188 | .. math::
189 | y = \frac{x - mean[x]}{ \sqrt{Var[x] + \epsilon}} * gamma + beta
190 | This module differs from the built-in PyTorch BatchNorm2d as the mean and
191 | standard-deviation are reduced across all devices during training.
192 | For example, when one uses `nn.DataParallel` to wrap the network during
193 |     training, PyTorch's implementation normalizes the tensor on each device using
194 |     only that device's statistics, which accelerates the computation and
195 |     is also easy to implement, but the statistics might be inaccurate.
196 | Instead, in this synchronized version, the statistics will be computed
197 | over all training samples distributed on multiple devices.
198 |
199 |     Note that, in the one-GPU or CPU-only case, this module behaves exactly the same
200 |     as the built-in PyTorch implementation.
201 | The mean and standard-deviation are calculated per-dimension over
202 | the mini-batches and gamma and beta are learnable parameter vectors
203 | of size C (where C is the input size).
204 | During training, this layer keeps a running estimate of its computed mean
205 | and variance. The running sum is kept with a default momentum of 0.1.
206 | During evaluation, this running mean/variance is used for normalization.
207 | Because the BatchNorm is done over the `C` dimension, computing statistics
208 | on `(N, H, W)` slices, it's common terminology to call this Spatial BatchNorm
209 | Args:
210 | num_features: num_features from an expected input of
211 | size batch_size x num_features x height x width
212 | eps: a value added to the denominator for numerical stability.
213 | Default: 1e-5
214 | momentum: the value used for the running_mean and running_var
215 | computation. Default: 0.1
216 | affine: a boolean value that when set to ``True``, gives the layer learnable
217 | affine parameters. Default: ``True``
218 | Shape:
219 | - Input: :math:`(N, C, H, W)`
220 | - Output: :math:`(N, C, H, W)` (same shape as input)
221 | Examples:
222 | >>> # With Learnable Parameters
223 | >>> m = SynchronizedBatchNorm2d(100)
224 | >>> # Without Learnable Parameters
225 | >>> m = SynchronizedBatchNorm2d(100, affine=False)
226 | >>> input = torch.autograd.Variable(torch.randn(20, 100, 35, 45))
227 | >>> output = m(input)
228 | """
229 |
230 | def _check_input_dim(self, input):
231 | if input.dim() != 4:
232 | raise ValueError("expected 4D input (got {}D input)".format(input.dim()))
233 | super(SynchronizedBatchNorm2d, self)._check_input_dim(input)
234 |
235 |
236 | class SynchronizedBatchNorm3d(_SynchronizedBatchNorm):
237 | r"""Applies Batch Normalization over a 5d input that is seen as a mini-batch
238 | of 4d inputs
239 | .. math::
240 | y = \frac{x - mean[x]}{ \sqrt{Var[x] + \epsilon}} * gamma + beta
241 | This module differs from the built-in PyTorch BatchNorm3d as the mean and
242 | standard-deviation are reduced across all devices during training.
243 | For example, when one uses `nn.DataParallel` to wrap the network during
244 |     training, PyTorch's implementation normalizes the tensor on each device using
245 |     only that device's statistics, which accelerates the computation and
246 |     is also easy to implement, but the statistics might be inaccurate.
247 | Instead, in this synchronized version, the statistics will be computed
248 | over all training samples distributed on multiple devices.
249 |
250 |     Note that, in the one-GPU or CPU-only case, this module behaves exactly the same
251 |     as the built-in PyTorch implementation.
252 | The mean and standard-deviation are calculated per-dimension over
253 | the mini-batches and gamma and beta are learnable parameter vectors
254 | of size C (where C is the input size).
255 | During training, this layer keeps a running estimate of its computed mean
256 | and variance. The running sum is kept with a default momentum of 0.1.
257 | During evaluation, this running mean/variance is used for normalization.
258 | Because the BatchNorm is done over the `C` dimension, computing statistics
259 | on `(N, D, H, W)` slices, it's common terminology to call this Volumetric BatchNorm
260 | or Spatio-temporal BatchNorm
261 | Args:
262 | num_features: num_features from an expected input of
263 | size batch_size x num_features x depth x height x width
264 | eps: a value added to the denominator for numerical stability.
265 | Default: 1e-5
266 | momentum: the value used for the running_mean and running_var
267 | computation. Default: 0.1
268 | affine: a boolean value that when set to ``True``, gives the layer learnable
269 | affine parameters. Default: ``True``
270 | Shape:
271 | - Input: :math:`(N, C, D, H, W)`
272 | - Output: :math:`(N, C, D, H, W)` (same shape as input)
273 | Examples:
274 | >>> # With Learnable Parameters
275 | >>> m = SynchronizedBatchNorm3d(100)
276 | >>> # Without Learnable Parameters
277 | >>> m = SynchronizedBatchNorm3d(100, affine=False)
278 | >>> input = torch.autograd.Variable(torch.randn(20, 100, 35, 45, 10))
279 | >>> output = m(input)
280 | """
281 |
282 | def _check_input_dim(self, input):
283 | if input.dim() != 5:
284 | raise ValueError("expected 5D input (got {}D input)".format(input.dim()))
285 | super(SynchronizedBatchNorm3d, self)._check_input_dim(input)
286 |
--------------------------------------------------------------------------------
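`_compute_mean_std` above follows the same convention as PyTorch's built-in batch norm: the running statistics are updated with the unbiased variance (divide by size - 1), while the batch itself is normalized with the biased variance via `bias_var.clamp(eps) ** -0.5`. A small numerical check of the sum/square-sum identity it relies on:

    import torch

    x = torch.rand(20, 5)     # 20 samples, 5 features
    size = x.shape[0]
    s, ssum = x.sum(0), (x ** 2).sum(0)
    mean = s / size
    sumvar = ssum - s * mean  # equals the sum of squared deviations from the mean
    print(torch.allclose(sumvar / size, x.var(0, unbiased=False)))       # True
    print(torch.allclose(sumvar / (size - 1), x.var(0, unbiased=True)))  # True
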
/seg_lapa/networks/deeplab/sync_batchnorm/comm.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # File : comm.py
3 | # Author : Jiayuan Mao
4 | # Email : maojiayuan@gmail.com
5 | # Date : 27/01/2018
6 | #
7 | # This file is part of Synchronized-BatchNorm-PyTorch.
8 | # https://github.com/vacancy/Synchronized-BatchNorm-PyTorch
9 | # Distributed under MIT License.
10 |
11 | import collections
12 | import queue
13 | import threading
14 |
15 | __all__ = ["FutureResult", "SlavePipe", "SyncMaster"]
16 |
17 |
18 | class FutureResult(object):
19 |     """A thread-safe future implementation. Used only as a one-to-one pipe."""
20 |
21 | def __init__(self):
22 | self._result = None
23 | self._lock = threading.Lock()
24 | self._cond = threading.Condition(self._lock)
25 |
26 | def put(self, result):
27 | with self._lock:
28 |             assert self._result is None, "Previous result hasn't been fetched."
29 | self._result = result
30 | self._cond.notify()
31 |
32 | def get(self):
33 | with self._lock:
34 | if self._result is None:
35 | self._cond.wait()
36 |
37 | res = self._result
38 | self._result = None
39 | return res
40 |
41 |
42 | _MasterRegistry = collections.namedtuple("MasterRegistry", ["result"])
43 | _SlavePipeBase = collections.namedtuple("_SlavePipeBase", ["identifier", "queue", "result"])
44 |
45 |
46 | class SlavePipe(_SlavePipeBase):
47 | """Pipe for master-slave communication."""
48 |
49 | def run_slave(self, msg):
50 | self.queue.put((self.identifier, msg))
51 | ret = self.result.get()
52 | self.queue.put(True)
53 | return ret
54 |
55 |
56 | class SyncMaster(object):
57 | """An abstract `SyncMaster` object.
58 |     - During replication, as data parallel will trigger a callback on each module, all slave devices should
59 |     call `register_slave(id)` and obtain a `SlavePipe` to communicate with the master.
60 |     - During the forward pass, the master device invokes `run_master`; all messages from slave devices will be
61 |     collected and passed to a registered callback.
62 |     - After receiving the messages, the master device gathers the information and determines the message to be
63 |     passed back to each slave device.
64 | """
65 |
66 | def __init__(self, master_callback):
67 | """
68 | Args:
69 | master_callback: a callback to be invoked after having collected messages from slave devices.
70 | """
71 | self._master_callback = master_callback
72 | self._queue = queue.Queue()
73 | self._registry = collections.OrderedDict()
74 | self._activated = False
75 |
76 | def __getstate__(self):
77 | return {"master_callback": self._master_callback}
78 |
79 | def __setstate__(self, state):
80 | self.__init__(state["master_callback"])
81 |
82 | def register_slave(self, identifier):
83 | """
84 |         Register a slave device.
85 |         Args:
86 |             identifier: an identifier, usually the device id.
87 | Returns: a `SlavePipe` object which can be used to communicate with the master device.
88 | """
89 | if self._activated:
90 | assert self._queue.empty(), "Queue is not clean before next initialization."
91 | self._activated = False
92 | self._registry.clear()
93 | future = FutureResult()
94 | self._registry[identifier] = _MasterRegistry(future)
95 | return SlavePipe(identifier, self._queue, future)
96 |
97 | def run_master(self, master_msg):
98 | """
99 | Main entry for the master device in each forward pass.
100 |         Messages are first collected from each device (including the master device), and then
101 |         a callback is invoked to compute the message to be sent back to each device
102 |         (including the master device).
103 |         Args:
104 |             master_msg: the message that the master wants to send to itself. This will be placed as the first
105 | message when calling `master_callback`. For detailed usage, see `_SynchronizedBatchNorm` for an example.
106 | Returns: the message to be sent back to the master device.
107 | """
108 | self._activated = True
109 |
110 | intermediates = [(0, master_msg)]
111 | for i in range(self.nr_slaves):
112 | intermediates.append(self._queue.get())
113 |
114 | results = self._master_callback(intermediates)
115 |         assert results[0][0] == 0, "The first result should belong to the master."
116 |
117 | for i, res in results:
118 | if i == 0:
119 | continue
120 | self._registry[i].result.put(res)
121 |
122 | for i in range(self.nr_slaves):
123 | assert self._queue.get() is True
124 |
125 | return results[0][1]
126 |
127 | @property
128 | def nr_slaves(self):
129 | return len(self._registry)
130 |
--------------------------------------------------------------------------------
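`SyncMaster` and `SlavePipe` above implement a barrier-style exchange: every replica submits its partial message, the master reduces them once, and each replica receives the reduced result. A minimal two-thread sketch of the handshake, assuming `SyncMaster` is importable from this module's package path:

    import threading
    from seg_lapa.networks.deeplab.sync_batchnorm.comm import SyncMaster

    def reduce_callback(intermediates):
        # intermediates: [(copy_id, message), ...], with the master's message first
        total = sum(msg for _, msg in intermediates)
        return [(copy_id, total) for copy_id, _ in intermediates]

    master = SyncMaster(reduce_callback)
    pipe = master.register_slave(1)  # one slave replica with id 1

    results = {}
    slave = threading.Thread(target=lambda: results.update(slave=pipe.run_slave(2)))
    slave.start()
    results["master"] = master.run_master(1)  # blocks until the slave submits its message
    slave.join()
    print(results)  # both entries are 3: each replica sees the reduced value
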
/seg_lapa/networks/deeplab/sync_batchnorm/replicate.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # File : replicate.py
3 | # Author : Jiayuan Mao
4 | # Email : maojiayuan@gmail.com
5 | # Date : 27/01/2018
6 | #
7 | # This file is part of Synchronized-BatchNorm-PyTorch.
8 | # https://github.com/vacancy/Synchronized-BatchNorm-PyTorch
9 | # Distributed under MIT License.
10 |
11 | import functools
12 |
13 | from torch.nn.parallel.data_parallel import DataParallel
14 |
15 | __all__ = ["CallbackContext", "execute_replication_callbacks", "DataParallelWithCallback", "patch_replication_callback"]
16 |
17 |
18 | class CallbackContext(object):
19 | pass
20 |
21 |
22 | def execute_replication_callbacks(modules):
23 | """
24 |     Execute a replication callback `__data_parallel_replicate__` on each module created by the original replication.
25 |     The callback will be invoked with arguments `__data_parallel_replicate__(ctx, copy_id)`
26 |     Note that, as all modules are isomorphic, we assign each sub-module a context
27 | (shared among multiple copies of this module on different devices).
28 | Through this context, different copies can share some information.
29 | We guarantee that the callback on the master copy (the first copy) will be called ahead of calling the callback
30 | of any slave copies.
31 | """
32 | master_copy = modules[0]
33 | nr_modules = len(list(master_copy.modules()))
34 | ctxs = [CallbackContext() for _ in range(nr_modules)]
35 |
36 | for i, module in enumerate(modules):
37 | for j, m in enumerate(module.modules()):
38 | if hasattr(m, "__data_parallel_replicate__"):
39 | m.__data_parallel_replicate__(ctxs[j], i)
40 |
41 |
42 | class DataParallelWithCallback(DataParallel):
43 | """
44 | Data Parallel with a replication callback.
45 |     A replication callback `__data_parallel_replicate__` on each module will be invoked after it is created by
46 |     the original `replicate` function.
47 | The callback will be invoked with arguments `__data_parallel_replicate__(ctx, copy_id)`
48 | Examples:
49 | > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False)
50 | > sync_bn = DataParallelWithCallback(sync_bn, device_ids=[0, 1])
51 | # sync_bn.__data_parallel_replicate__ will be invoked.
52 | """
53 |
54 | def replicate(self, module, device_ids):
55 | modules = super(DataParallelWithCallback, self).replicate(module, device_ids)
56 | execute_replication_callbacks(modules)
57 | return modules
58 |
59 |
60 | def patch_replication_callback(data_parallel):
61 | """
62 | Monkey-patch an existing `DataParallel` object. Add the replication callback.
63 | Useful when you have customized `DataParallel` implementation.
64 | Examples:
65 | > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False)
66 | > sync_bn = DataParallel(sync_bn, device_ids=[0, 1])
67 | > patch_replication_callback(sync_bn)
68 | # this is equivalent to
69 | > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False)
70 | > sync_bn = DataParallelWithCallback(sync_bn, device_ids=[0, 1])
71 | """
72 |
73 | assert isinstance(data_parallel, DataParallel)
74 |
75 | old_replicate = data_parallel.replicate
76 |
77 | @functools.wraps(old_replicate)
78 | def new_replicate(module, device_ids):
79 | modules = old_replicate(module, device_ids)
80 | execute_replication_callbacks(modules)
81 | return modules
82 |
83 | data_parallel.replicate = new_replicate
84 |
--------------------------------------------------------------------------------
/seg_lapa/networks/deeplab/sync_batchnorm/unittest.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # File : unittest.py
3 | # Author : Jiayuan Mao
4 | # Email : maojiayuan@gmail.com
5 | # Date : 27/01/2018
6 | #
7 | # This file is part of Synchronized-BatchNorm-PyTorch.
8 | # https://github.com/vacancy/Synchronized-BatchNorm-PyTorch
9 | # Distributed under MIT License.
10 |
11 | import unittest
12 |
13 | import numpy as np
14 | from torch.autograd import Variable
15 |
16 |
17 | def as_numpy(v):
18 | if isinstance(v, Variable):
19 | v = v.data
20 | return v.cpu().numpy()
21 |
22 |
23 | class TorchTestCase(unittest.TestCase):
24 | def assertTensorClose(self, a, b, atol=1e-3, rtol=1e-3):
25 | npa, npb = as_numpy(a), as_numpy(b)
26 | self.assertTrue(
27 |             np.allclose(npa, npb, atol=atol, rtol=rtol),
28 | "Tensor close check failed\n{}\n{}\nadiff={}, rdiff={}".format(
29 | a, b, np.abs(npa - npb).max(), np.abs((npa - npb) / np.fmax(npa, 1e-5)).max()
30 | ),
31 | )
32 |
--------------------------------------------------------------------------------
/seg_lapa/train.py:
--------------------------------------------------------------------------------
1 | from typing import Any, List
2 |
3 | import hydra
4 | import numpy as np
5 | import pytorch_lightning as pl
6 | import wandb
7 | from omegaconf import DictConfig, OmegaConf
8 | from pytorch_lightning import loggers as pl_loggers
9 |
10 | from seg_lapa import metrics
11 | from seg_lapa.callbacks.log_media import LogMediaQueue, Mode
12 | from seg_lapa.config_parse.train_conf import ParseConfig
13 | from seg_lapa.loss_func import CrossEntropy2D
14 | from seg_lapa.utils import utils
15 | from seg_lapa.utils.utils import is_rank_zero
16 |
17 |
18 | class DeeplabV3plus(pl.LightningModule, ParseConfig):
19 | def __init__(self, cfg: DictConfig, log_media_max_batches=1):
20 | super().__init__()
21 | self.save_hyperparameters() # Will save the config to wandb too
22 | # Accessing cfg via hparams allows value to be loaded from checkpoints
23 | self.config = self.parse_config(self.hparams.cfg)
24 |
25 | self.cross_entropy_loss = CrossEntropy2D(loss_per_image=True, ignore_index=255)
26 | self.model = self.config.model.get_model()
27 |
28 | self.iou_train = metrics.Iou(num_classes=self.config.model.num_classes)
29 | self.iou_val = metrics.Iou(num_classes=self.config.model.num_classes)
30 | self.iou_test = metrics.Iou(num_classes=self.config.model.num_classes)
31 |
32 |         # Logging media such as images using `self.log()` is extremely memory-expensive.
33 | # Save predictions to be logged within a circular queue, to be consumed in the LogMedia callback.
34 | self.log_media: LogMediaQueue = LogMediaQueue(log_media_max_batches)
35 |
36 | def forward(self, x):
37 | """In lightning, forward defines the prediction/inference actions.
38 | This method can be called elsewhere in the LightningModule with: `outputs = self(inputs)`.
39 | """
40 | outputs = self.model(x)
41 | return outputs
42 |
43 | def training_step(self, batch, batch_idx):
44 | """Defines the train loop. It is independent of forward().
45 | Don’t use any cuda or .to(device) calls in the code. PL will move the tensors to the correct device.
46 | """
47 | inputs, labels = batch
48 | outputs = self.model(inputs)
49 | predictions = outputs.argmax(dim=1)
50 |
51 | # Calculate Loss
52 | loss = self.cross_entropy_loss(outputs, labels)
53 |
54 | """Log the value on GPU0 per step. Also log average of all steps at epoch_end."""
55 | self.log("Train/loss", loss, on_step=True, on_epoch=True)
56 | """Log the avg. value across all GPUs per step. Also log average of all steps at epoch_end.
57 | Alternately, you can use the ops 'sum' or 'avg'.
58 | Using sync_dist is efficient. It adds extremely minor overhead for scalar values.
59 | """
60 | # self.log("Train/loss", loss, on_step=True, on_epoch=True, sync_dist=True, sync_dist_op="avg")
61 |
62 | # Calculate Metrics
63 | self.iou_train(predictions, labels)
64 |
65 |         # Returning images is expensive - all the batches are accumulated for _epoch_end().
66 |         # Save the latest predictions to be logged in an attr. They will be consumed by the LogMedia callback.
67 | self.log_media.append({"inputs": inputs, "labels": labels, "preds": predictions}, Mode.TRAIN)
68 |
69 | return {"loss": loss}
70 |
71 | def validation_step(self, batch, batch_idx):
72 | inputs, labels = batch
73 | outputs = self.model(inputs)
74 | predictions = outputs.argmax(dim=1)
75 |
76 | # Calculate Loss
77 | loss = self.cross_entropy_loss(outputs, labels)
78 | self.log("Val/loss", loss)
79 |
80 | # Calculate Metrics
81 | self.iou_val(predictions, labels)
82 |
83 | # Save the latest predictions to be logged
84 | self.log_media.append({"inputs": inputs, "labels": labels, "preds": predictions}, Mode.VAL)
85 |
86 | return {"val_loss": loss}
87 |
88 | def test_step(self, batch, batch_idx):
89 | inputs, labels = batch
90 | outputs = self.model(inputs)
91 | predictions = outputs.argmax(dim=1)
92 |
93 | # Calculate Loss
94 | loss = self.cross_entropy_loss(outputs, labels)
95 | self.log("Test/loss", loss)
96 |
97 | # Calculate Metrics
98 | self.iou_test(predictions, labels)
99 |
100 | # Save the latest predictions to be logged
101 | self.log_media.append({"inputs": inputs, "labels": labels, "preds": predictions}, Mode.TEST)
102 |
103 | return {"test_loss": loss}
104 |
105 | def training_epoch_end(self, outputs: List[Any]):
106 | # Compute and log metrics across epoch
107 | metrics_avg = self.iou_train.compute()
108 | self.log("Train/mIoU", metrics_avg.miou)
109 | self.iou_train.reset()
110 |
111 | def validation_epoch_end(self, outputs: List[Any]):
112 | # Compute and log metrics across epoch
113 | metrics_avg = self.iou_val.compute()
114 | self.log("Val/mIoU", metrics_avg.miou)
115 | self.iou_val.reset()
116 |
117 | def test_epoch_end(self, outputs: List[Any]):
118 | # Compute and log metrics across epoch
119 | metrics_avg = self.iou_test.compute()
120 | self.log("Test/mIoU", metrics_avg.miou)
121 | self.log("Test/Accuracy", metrics_avg.accuracy.mean())
122 | self.log("Test/Precision", metrics_avg.precision.mean())
123 | self.log("Test/Recall", metrics_avg.recall.mean())
124 |
125 | # Save test results as a Table (WandB)
126 | self.log_results_table_wandb(metrics_avg)
127 | self.iou_test.reset()
128 |
129 | def log_results_table_wandb(self, metrics_avg: metrics.IouMetric):
130 | if not isinstance(self.logger, pl_loggers.WandbLogger):
131 | return
132 |
133 | results = metrics.IouMetric(
134 | iou_per_class=metrics_avg.iou_per_class.cpu().numpy(),
135 | miou=metrics_avg.miou.cpu().numpy(),
136 | accuracy=metrics_avg.accuracy.cpu().numpy().mean(),
137 | precision=metrics_avg.precision.cpu().numpy().mean(),
138 | recall=metrics_avg.recall.cpu().numpy().mean(),
139 | specificity=metrics_avg.specificity.cpu().numpy().mean(),
140 | )
141 |
142 | data = np.stack(
143 | [results.miou, results.accuracy, results.precision, results.recall, results.specificity], axis=0
144 | )
145 | data_l = [round(x.item(), 4) for x in data]
146 |         table = wandb.Table(data=[data_l], columns=["mIoU", "Accuracy", "Precision", "Recall", "Specificity"])  # wandb.Table expects a list of rows
147 |         self.logger.experiment.log({"Test/Results": table}, commit=False)
148 |
149 | data = np.stack((np.arange(results.iou_per_class.shape[0]), results.iou_per_class)).T
150 | table = wandb.Table(data=data.round(decimals=4).tolist(), columns=["Class ID", "IoU"])
151 |         self.logger.experiment.log({"Test/IoU_per_class": table}, commit=False)
152 |
153 | def configure_optimizers(self):
154 | optimizer = self.config.optimizer.get_optimizer(self.parameters())
155 |
156 | ret_opt = {"optimizer": optimizer}
157 |
158 | sch = self.config.scheduler.get_scheduler(optimizer)
159 | if sch is not None:
160 | scheduler = {
161 | "scheduler": sch, # The LR scheduler instance (required)
162 | "interval": "epoch", # The unit of the scheduler's step size
163 | "frequency": 1, # The frequency of the scheduler
164 | "reduce_on_plateau": False, # For ReduceLROnPlateau scheduler
165 | "monitor": "Val/mIoU", # Metric for ReduceLROnPlateau to monitor
166 | "strict": True, # Whether to crash the training if `monitor` is not found
167 | "name": None, # Custom name for LearningRateMonitor to use
168 | }
169 |
170 | ret_opt.update({"lr_scheduler": scheduler})
171 |
172 | return ret_opt
173 |
174 |
175 | @hydra.main(config_path="config", config_name="train")
176 | def main(cfg: DictConfig):
177 | if is_rank_zero():
178 | print("\nGiven Config:\n", OmegaConf.to_yaml(cfg))
179 |
180 | config = ParseConfig.parse_config(cfg)
181 | if is_rank_zero():
182 | print("\nResolved Dataclass:\n", config, "\n")
183 |
184 | utils.fix_seeds(config.random_seed)
185 | exp_dir = utils.generate_log_dir_path(config)
186 |
187 | wb_logger = config.logger.get_logger(cfg, config.logs_root_dir)
188 | callbacks = config.callbacks.get_callbacks_list(exp_dir, cfg)
189 | dm = config.dataset.get_datamodule()
190 |
191 | # Load weights
192 | if config.load_weights.path is None:
193 | model = DeeplabV3plus(cfg)
194 | else:
195 | model = DeeplabV3plus.load_from_checkpoint(config.load_weights.path, cfg=cfg)
196 |
197 | trainer = config.trainer.get_trainer(wb_logger, callbacks, config.logs_root_dir)
198 |
199 | # Run Training
200 | trainer.fit(model, datamodule=dm)
201 |
202 | # Run Testing
203 |     trainer.test(ckpt_path=None)  # ckpt_path=None evaluates the current (in-memory) weights. Prints the final result.
204 |
205 | wandb.finish()
206 |
207 |
208 | if __name__ == "__main__":
209 | main()
210 |
--------------------------------------------------------------------------------
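For reference, a minimal sketch of a scheduler conf that plugs into configure_optimizers() above. StepConf and its fields are hypothetical stand-ins for the repo's actual scheduler conf classes, shown only to illustrate the get_scheduler() interface:

    from dataclasses import dataclass

    import torch

    @dataclass
    class StepConf:  # hypothetical stand-in; not the repo's actual conf class
        step_size: int = 30
        gamma: float = 0.1

        def get_scheduler(self, optimizer):
            # Stepped once per epoch, matching "interval": "epoch" in configure_optimizers().
            return torch.optim.lr_scheduler.StepLR(optimizer, step_size=self.step_size, gamma=self.gamma)

Since the config groups are composed by Hydra, the optimizer/scheduler pair can be swapped from the command line, e.g. `python seg_lapa/train.py optimizer=sgd scheduler=step`.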
/seg_lapa/utils/path_check.py:
--------------------------------------------------------------------------------
1 | from enum import Enum
2 | from pathlib import Path
3 | from typing import Optional, Union
4 |
5 |
6 | class PathType(Enum):
7 | FILE = 0
8 | DIR = 1
9 | ANY = 2
10 |
11 |
12 | def get_project_root() -> Path:
13 | """Get the root dir of the project (one level above package).
14 | Used to get paths relative to the project, mainly for placing log files
15 |
16 | This function assumes that this module is at project_root/package_dir/utils_dir/module.py
17 | If this structure changes, the func must change.
18 | """
19 | return Path(__file__).parent.parent.parent
20 |
21 |
22 | def get_path(
23 | input_path: Union[str, Path],
24 | must_exist: Optional[bool] = None,
25 | path_type: PathType = PathType.DIR,
26 | force_relative_to_project: bool = False,
27 | ) -> Path:
28 | """Converts a str to a pathlib Path with added checks
29 |
30 | Args:
31 | input_path: The path to be converted
32 | must_exist: If given, ensure that the path either exists or does not exist.
33 | - None: Don't check
34 | - True: It must exist.
35 | - False: It must not exist.
36 | path_type: Whether the path is to a dir or file. Only used if must_exist is not None.
37 | force_relative_to_project: If a relative path is given, convert it to be relative to the project root dir.
38 | Required because the package can be called from anywhere, and the relative path will
39 | be relative to where it's called from, rather than relative to the script location.
40 |
41 | Useful to always place logs in the project root, regardless of where the package
42 | is called from.
43 | """
44 | if not isinstance(path_type, PathType):
45 |         raise ValueError(f"Invalid path type '{path_type}' (type={type(path_type)}). Must be of type {PathType}")
46 |
47 |     input_path = Path(input_path).expanduser()  # expand ~ so the checks below see the real path
48 |
49 |     if not input_path.is_absolute() and force_relative_to_project:
50 | # If the input path is relative, change it to become relative to the project root.
51 | proj_root = get_project_root()
52 | input_path = proj_root / input_path
53 |
54 | if must_exist is not None:
55 | if must_exist:
56 | # check that path exists
57 | if not input_path.exists():
58 | raise ValueError(f"Could not find {path_type.name.lower()}. Does not exist: {input_path}")
59 |
60 | # If required, check that it is the correct type (file vs dir)
61 | if path_type is not PathType.ANY:
62 |                 if path_type == PathType.DIR and not input_path.is_dir():
63 |                     raise ValueError(f"Not a dir: {input_path}")
64 |
65 |                 if path_type == PathType.FILE and not input_path.is_file():
66 |                     raise ValueError(f"Not a file: {input_path}")
67 | else:
68 | # Ensure path doesn't already exist
69 | if input_path.exists():
70 | raise ValueError(f"Path already exists: {input_path}")
71 |
72 | return input_path
73 |
--------------------------------------------------------------------------------
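A short usage sketch for get_path() above (paths are illustrative; "/tmp" assumes a Linux machine):

    from seg_lapa.utils.path_check import PathType, get_path

    # A relative path is re-rooted at the project root, so logs land in the same
    # place regardless of the working directory the package is invoked from.
    logs_dir = get_path("logs", force_relative_to_project=True)  # -> <project_root>/logs

    # must_exist=True with PathType.DIR raises ValueError unless the dir exists.
    tmp_dir = get_path("/tmp", must_exist=True, path_type=PathType.DIR)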
/seg_lapa/utils/segmentation_label2rgb.py:
--------------------------------------------------------------------------------
1 | import enum
2 |
3 | import numpy as np
4 | from PIL import Image
5 |
6 |
7 | class Palette(enum.Enum):
8 | LAPA = 0
9 |
10 |
11 | class LabelToRGB:
12 | # fmt: off
13 | palette_lapa = [
14 | # Color Palette for the LaPa dataset, which has 11 classes
15 | 0, 0, 0, # 0 background
16 | 0, 153, 255, # 1 skin
17 | 102, 255, 153, # 2 left eyebrow
18 | 0, 204, 153, # 3 right eyebrow
19 | 255, 255, 102, # 4 left eye
20 | 255, 255, 204, # 5 right eye
21 | 255, 153, 0, # 6 nose
22 | 255, 102, 255, # 7 upper lip
23 | 102, 0, 51, # 8 inner mouth
24 | 255, 204, 255, # 9 lower lip
25 | 255, 0, 10, # 10 hair
26 | ]
27 | # fmt: on
28 |
29 | def __init__(self):
30 |         """Maps each class index in a label to a unique RGB color using a predefined palette.
31 |         The palette is defined per dataset. If the number of classes is too high, the difference
32 |         between neighboring colors can become hard to distinguish."""
33 | self.color_palettes = {Palette.LAPA: self.palette_lapa}
34 |
35 | def map_color_palette(self, label: np.ndarray, palette: Palette) -> np.ndarray:
36 | """Generates RGB visualization of label by applying a color palette
37 | Label should contain an uint class index per pixel.
38 |
39 | Args:
41 |             label (numpy.ndarray): Each pixel has a uint value corresponding to its class index.
42 |                 Shape: (H, W), dtype: np.uint8
43 | palette (Palette): Which color palette to use.
44 |
45 | Returns:
46 | numpy.ndarray: RGB image, with each class mapped to a unique color.
47 | Shape: (H, W, 3), dtype: np.uint8
48 | """
49 | if len(label.shape) != 2:
50 | raise ValueError(f"Label must have shape: (H, W). Input: {label.shape}")
51 |         if label.dtype != np.uint8:
52 | raise ValueError(f"Label must have dtype np.uint8. Input: {label.dtype}")
53 | if not isinstance(palette, Palette):
54 | raise ValueError(f"palette must be of type {Palette}. Input: {palette}")
55 |
56 | color_palette = self.color_palettes[palette]
57 |
58 |         # Check that the palette has enough colors. It is a flat list of RGB triplets (len // 3 colors).
59 |         if label.max() >= len(color_palette) // 3:
60 |             raise ValueError(
61 |                 f"The chosen color palette has only {len(color_palette) // 3} colors. It does not have"
62 |                 f" enough unique colors to represent all the values in the label (max: {label.max()})"
63 |             )
64 |
65 | # Map grayscale image's pixel values to RGB color palette
66 | _im = Image.fromarray(label)
67 | _im.putpalette(color_palette)
68 | _im = _im.convert(mode="RGB")
69 | im = np.asarray(_im)
70 |
71 | return im
72 |
73 | def colorize_batch_numpy(self, batch_label: np.ndarray) -> np.ndarray:
74 | """Convert a batch of numpy labels to RGB
75 |
76 | Args:
77 | batch_label (numpy.ndarray): Shape: [N, H, W], dtype=np.uint8
78 |
79 | Returns:
80 |             numpy.ndarray: Colorized labels. Shape: [N, H, W, 3], dtype=np.uint8
81 | """
82 | if not isinstance(batch_label, np.ndarray):
83 | raise TypeError(f"`batch_label` expected to be Numpy array. Got: {type(batch_label)}")
84 | if len(batch_label.shape) != 3:
85 | raise ValueError(f"`batch_label` expected shape [N, H, W]. Got: {batch_label.shape}")
86 | if batch_label.dtype != np.uint8:
87 | raise ValueError(f"`batch_label` must be of dtype np.uint8. Got: {batch_label.dtype}")
88 |
89 | batch_label_rgb = [self.map_color_palette(label, Palette.LAPA) for label in batch_label]
90 | batch_label_rgb = np.stack(batch_label_rgb, axis=0)
91 | return batch_label_rgb
92 |
93 |
94 | if __name__ == "__main__":
95 | # Create dummy mask
96 | dummy_mask = np.zeros((512, 512), dtype=np.uint8)
97 | num_classes = 11
98 | cols = dummy_mask.shape[1]
99 |     cols_slice = cols // num_classes
100 |     for class_id in range(num_classes):
101 |         dummy_mask[:, cols_slice * class_id : cols_slice * (class_id + 1)] = class_id
102 |
103 | # Colorize mask
104 | label2rgb = LabelToRGB()
105 | colorized_mask = label2rgb.map_color_palette(dummy_mask, Palette.LAPA)
106 |
107 | # View mask
108 | img = Image.fromarray(colorized_mask)
109 | img.show()
110 |
--------------------------------------------------------------------------------
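A quick sketch exercising colorize_batch_numpy() on a dummy batch (values 0-10 match the 11 LaPa classes):

    import numpy as np

    from seg_lapa.utils.segmentation_label2rgb import LabelToRGB

    batch_label = np.random.randint(0, 11, size=(4, 64, 64), dtype=np.uint8)
    batch_rgb = LabelToRGB().colorize_batch_numpy(batch_label)
    print(batch_rgb.shape, batch_rgb.dtype)  # (4, 64, 64, 3) uint8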
/seg_lapa/utils/utils.py:
--------------------------------------------------------------------------------
1 | import datetime
2 | import os
3 | from pathlib import Path
4 | from typing import Optional
5 |
6 | import pytorch_lightning as pl
7 |
8 | from seg_lapa.config_parse.logger_conf import DisabledLoggerConf, WandbConf
9 | from seg_lapa.config_parse.train_conf import TrainConf
10 | from seg_lapa.utils import path_check
11 |
12 | LOGS_DIR = "log-media"
13 |
14 |
15 | def is_rank_zero():
16 | local_rank = int(os.environ.get("LOCAL_RANK", 0))
17 | node_rank = int(os.environ.get("NODE_RANK", 0))
18 | if local_rank == 0 and node_rank == 0:
19 | return True
20 |
21 | return False
22 |
23 |
24 | def generate_log_dir_path(config: TrainConf) -> Path:
25 | """Generate the path to the log dir for this run.
26 | The directory structure for logs depends on the logger used.
27 |
28 | wandb - Each run's log dir's name will contain the wandb runid for easy identification
29 |
30 | Args:
31 | config: The config dataclass.
32 | """
33 | logs_root_dir = path_check.get_path(config.logs_root_dir, force_relative_to_project=True)
34 |
35 |     # The exp directory structure depends on the logger used
36 | logs_root_dir = logs_root_dir / LOGS_DIR / f"{config.logger.name}-logger"
37 | timestamp = datetime.datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
38 | if isinstance(config.logger, DisabledLoggerConf):
39 | exp_dir = logs_root_dir / f"{timestamp}"
40 | elif isinstance(config.logger, WandbConf):
41 | run_id = config.logger.get_run_id()
42 | exp_dir = logs_root_dir / f"{timestamp}_{run_id}"
43 | else:
44 | raise NotImplementedError(f"Generating log dir not implemented for logger: {config.logger}")
45 |
46 | return exp_dir
47 |
48 |
49 | def fix_seeds(random_seed: Optional[int]) -> None:
50 | """Fix seeds for reproducibility.
51 | Ref:
52 | https://pytorch.org/docs/stable/notes/randomness.html
53 |
54 | Args:
55 |         random_seed: If None, seeds are not set. If an int, it is used as the seed.
56 | """
57 | if random_seed is not None:
58 | pl.seed_everything(random_seed)
59 |
--------------------------------------------------------------------------------
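A small sketch of is_rank_zero(), which reads the LOCAL_RANK/NODE_RANK env vars that PyTorch Lightning sets on DDP worker processes (set manually here only to simulate a worker):

    import os

    from seg_lapa.utils.utils import is_rank_zero

    os.environ["LOCAL_RANK"] = "1"  # simulate a non-zero-rank DDP process
    print(is_rank_zero())           # False
    del os.environ["LOCAL_RANK"]
    print(is_rank_zero())           # True

For generate_log_dir_path(), the resulting exp_dir follows the pattern <logs_root_dir>/log-media/<logger_name>-logger/<YYYY-mm-dd_HH-MM-SS>_<run_id> (the run id suffix only with the wandb logger).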
/setup.cfg:
--------------------------------------------------------------------------------
1 | [tool:pytest]
2 | norecursedirs =
3 | .git
4 | dist
5 | build
6 | addopts =
7 | --strict
8 | --doctest-modules
9 | --durations=0
10 |
11 | [coverage:report]
12 | exclude_lines =
13 | pragma: no-cover
14 | pass
15 |
16 | [flake8]
17 | max-line-length = 120
18 | exclude = .tox,*.egg,build,temp
19 | select = E,W,F
20 | doctests = True
21 | verbose = 2
22 | # https://pep8.readthedocs.io/en/latest/intro.html#error-codes
23 | format = pylint
24 | # see: https://www.flake8rules.com/
25 | ignore =
26 | E731 # Do not assign a lambda expression, use a def
27 | W504 # Line break occurred after a binary operator
28 | F401 # Module imported but unused
29 | F841 # Local variable name is assigned to but never used
30 | W605 # Invalid escape sequence 'x'
31 |
32 | # setup.cfg or tox.ini
33 | [check-manifest]
34 | ignore =
35 | *.yml
36 | .github
37 | .github/*
38 |
39 | [metadata]
40 | license_file = LICENSE
41 | description-file = README.md
42 | # long_description = file:README.md
43 | # long_description_content_type = text/markdown
44 |
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 |
3 | from setuptools import setup, find_packages
4 |
5 | setup(
6 | name="seg_lapa",
7 | version="0.2.0",
8 | description="Semantic Segmentation on the LaPa dataset using Pytorch Lightning",
9 | author="john doe",
10 | author_email="example@gmail.com",
11 | # REPLACE WITH YOUR OWN GITHUB PROJECT LINK
12 | url="https://github.com/Shreeyak/pytorch-lightning-segmentation-template",
13 | python_requires=">=3.7.7",
14 | install_requires=[
15 | # Install torch first, depending on cuda version
16 | # "torch==1.7.1",
17 | # "torchvision==0.8.2",
18 | "pytorch-lightning==1.1.2",
19 | "gdown==3.12.2",
20 | "albumentations==0.5.2",
21 | "opencv-python==4.4.0.44",
22 | "hydra-core==1.0.4",
23 | "wandb==0.10.12",
24 | "pydantic==1.7.3",
25 | ],
26 | packages=find_packages(),
27 | )
28 |
--------------------------------------------------------------------------------