├── .github └── workflows │ └── run-pre-commit.yml ├── .gitignore ├── .pre-commit-config.yaml ├── LICENSE ├── README.md ├── pyproject.toml ├── requirements-cuda11.txt ├── requirements-dev.txt ├── scripts ├── kill_hanging_processes.sh └── sample-wandb-logging.png ├── seg_lapa ├── __init__.py ├── callbacks │ └── log_media.py ├── config │ ├── callbacks │ │ ├── disabled.yaml │ │ ├── standard.yaml │ │ └── with_lr_monitor.yaml │ ├── dataset │ │ └── lapa.yaml │ ├── load_weights │ │ ├── disabled.yaml │ │ └── pretrain.yaml │ ├── logger │ │ ├── disabled.yaml │ │ ├── wandb.yaml │ │ └── wandb_debug.yaml │ ├── model │ │ └── deeplabv3.yaml │ ├── optimizer │ │ ├── adam.yaml │ │ └── sgd.yaml │ ├── scheduler │ │ ├── cyclic.yaml │ │ ├── disabled.yaml │ │ ├── plateau.yaml │ │ ├── poly.yaml │ │ └── step.yaml │ ├── train.yaml │ └── trainer │ │ ├── debug.yaml │ │ └── standard.yaml ├── config_parse │ ├── callbacks_available.py │ ├── callbacks_conf.py │ ├── conf_utils.py │ ├── dataset_conf.py │ ├── load_weights_conf.py │ ├── logger_conf.py │ ├── model_conf.py │ ├── optimizer_conf.py │ ├── scheduler_conf.py │ ├── train_conf.py │ └── trainer_conf.py ├── datasets │ └── lapa.py ├── loss_func.py ├── metrics.py ├── networks │ └── deeplab │ │ ├── aspp.py │ │ ├── backbone │ │ ├── __init__.py │ │ ├── drn.py │ │ ├── mobilenet.py │ │ ├── resnet.py │ │ └── xception.py │ │ ├── decoder.py │ │ ├── decoder_masks.py │ │ ├── deeplab.py │ │ ├── readme.md │ │ └── sync_batchnorm │ │ ├── __init__.py │ │ ├── batchnorm.py │ │ ├── comm.py │ │ ├── replicate.py │ │ └── unittest.py ├── train.py └── utils │ ├── path_check.py │ ├── segmentation_label2rgb.py │ └── utils.py ├── setup.cfg └── setup.py /.github/workflows/run-pre-commit.yml: -------------------------------------------------------------------------------- 1 | name: run-pre-commit 2 | 3 | on: [pull_request] 4 | 5 | jobs: 6 | run-pre-commit: 7 | runs-on: ubuntu-latest 8 | steps: 9 | - uses: actions/checkout@v2 10 | with: 11 | fetch-depth: 0 12 | - uses: actions/setup-python@v2 13 | with: 14 | python-version: '3.8' 15 | - uses: pre-commit/action@v2.0.0 16 | with: 17 | token: ${{ secrets.GITHUB_TOKEN }} 18 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Custom 2 | lightning_logs/ 3 | wandb/ 4 | checkpoints/ 5 | log-media/ 6 | 7 | .DS_Store 8 | *.tar.gz 9 | 10 | # Byte-compiled / optimized / DLL files 11 | __pycache__/ 12 | *.py[cod] 13 | *$py.class 14 | *.pyc 15 | 16 | # C extensions 17 | *.so 18 | 19 | # Distribution / packaging 20 | .Python 21 | build/ 22 | develop-eggs/ 23 | dist/ 24 | downloads/ 25 | eggs/ 26 | .eggs/ 27 | lib/ 28 | lib64/ 29 | parts/ 30 | sdist/ 31 | var/ 32 | wheels/ 33 | *.egg-info/ 34 | .installed.cfg 35 | *.egg 36 | MANIFEST 37 | 38 | # Lightning /research 39 | test_tube_exp/ 40 | tests/tests_tt_dir/ 41 | tests/save_dir 42 | default/ 43 | data/ 44 | test_tube_logs/ 45 | test_tube_data/ 46 | #datasets/ # dir in code called datasets. 47 | model_weights/ 48 | tests/save_dir 49 | tests/tests_tt_dir/ 50 | processed/ 51 | raw/ 52 | 53 | # PyInstaller 54 | # Usually these files are written by a python script from a template 55 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
56 | *.manifest 57 | *.spec 58 | 59 | # Installer logs 60 | pip-log.txt 61 | pip-delete-this-directory.txt 62 | 63 | # Unit test / coverage reports 64 | htmlcov/ 65 | .tox/ 66 | .coverage 67 | .coverage.* 68 | .cache 69 | nosetests.xml 70 | coverage.xml 71 | *.cover 72 | .hypothesis/ 73 | .pytest_cache/ 74 | 75 | # Translations 76 | *.mo 77 | *.pot 78 | 79 | # Django stuff: 80 | *.log 81 | local_settings.py 82 | db.sqlite3 83 | 84 | # Flask stuff: 85 | instance/ 86 | .webassets-cache 87 | 88 | # Scrapy stuff: 89 | .scrapy 90 | 91 | # Sphinx documentation 92 | docs/_build/ 93 | 94 | # PyBuilder 95 | target/ 96 | 97 | # Jupyter Notebook 98 | .ipynb_checkpoints 99 | 100 | # pyenv 101 | .python-version 102 | 103 | # celery beat schedule file 104 | celerybeat-schedule 105 | 106 | # SageMath parsed files 107 | *.sage.py 108 | 109 | # Environments 110 | .env 111 | .venv 112 | env/ 113 | venv/ 114 | ENV/ 115 | env.bak/ 116 | venv.bak/ 117 | 118 | # Spyder project settings 119 | .spyderproject 120 | .spyproject 121 | 122 | # Rope project settings 123 | .ropeproject 124 | 125 | # mkdocs documentation 126 | /site 127 | 128 | # mypy 129 | .mypy_cache/ 130 | 131 | # IDEs 132 | .idea 133 | .vscode 134 | -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | repos: 2 | - repo: https://github.com/ambv/black 3 | rev: stable 4 | hooks: 5 | - id: black 6 | language_version: python3.8 7 | - repo: https://github.com/pre-commit/mirrors-prettier 8 | rev: v2.0.0 9 | hooks: 10 | - id: prettier 11 | files: \.(yaml|md)$ # The filename extensions this formatter edits 12 | - repo: https://github.com/pre-commit/pre-commit-hooks 13 | rev: v3.4.0 14 | hooks: 15 | - id: end-of-file-fixer 16 | - id: requirements-txt-fixer 17 | - id: trailing-whitespace 18 | args: [--markdown-linebreak-ext=md] # preserve Markdown hard linebreaks 19 | - id: check-merge-conflict 20 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 
29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. 
If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. 
Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 
202 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | --- 2 | 3 |
4 |
5 | # Segmentation LaPa
6 |
7 | [![Paper](http://img.shields.io/badge/paper-arxiv.1001.2234-B31B1B.svg)](https://www.nature.com/articles/nature14539)
8 | [![Conference](http://img.shields.io/badge/AnyConference-year-4b44ce.svg)](https://papers.nips.cc/book/advances-in-neural-information-processing-systems-31-2018)
9 | [![code style: black](https://img.shields.io/badge/code%20style-black-000000.svg)](https://github.com/psf/black)
10 | [![code style: prettier](https://img.shields.io/badge/code_style-prettier-ff69b4.svg?style=flat-square)](https://github.com/prettier/prettier)
11 |
12 |
13 | 14 | ## Description 15 | 16 | This an example project showcasing [Pytorch Lightning](https://www.pytorchlightning.ai/) for training, 17 | [hydra](https://hydra.cc/) for the configuration system and [wandb](https://wandb.ai/) (Weights and Biases) for logging. 18 | 19 | The project tackles a more realistic setting than MNIST by demonstrating segmentation of facial regions on the 20 | [LaPa dataset](https://github.com/JDAI-CV/lapa-dataset) with Deeplabv3+. 21 | 22 | ![logging](scripts/sample-wandb-logging.png) 23 | 24 | 25 | ## Install 26 | 27 | If using Ampere GPUs (RTX 3090), then CUDA 11.1 is required. Some libraries throw error on trying to 28 | install with CUDA 11.0 (unsupported gpu architecture 'compute_86'). This error is solved by moving to 29 | CUDA 11.1. 30 | 31 | ### Install Pytorch as per CUDA (Feb 2021) 32 | 33 | Install Pytorch (`torch` and `torchvision`) before installing the other dependencies. 34 | 35 | #### CUDA 11.1 36 | 37 | `pytorch` and `torchvision` need to be installed from source. Check: 38 | 39 | - https://github.com/pytorch/pytorch#installation 40 | - https://github.com/pytorch/vision 41 | 42 | For torchvision, install system dependencies: 43 | 44 | ```shell script 45 | sudo apt install libturbojpeg libpng-dev libjpeg-dev 46 | ``` 47 | 48 | #### CUDA 11.0 49 | 50 | Systems with Cuda 11.0 (such as those with Ampere GPUs): 51 | 52 | ```shell script 53 | pip install -r requirements-cuda11_0.txt 54 | ``` 55 | 56 | #### CUDA 10.x 57 | 58 | System with CUDA 10.x: 59 | 60 | ```shell script 61 | pip install -r requirements-cuda10.txt 62 | ``` 63 | 64 | ### Install project package and dependencies 65 | 66 | ```shell script 67 | # clone project 68 | git clone git@github.com:Shreeyak/pytorch-lightning-segmentation-lapa.git 69 | 70 | # install project in development mode 71 | cd pytorch-lightning-segmentation-lapa 72 | pip install -e . 73 | 74 | # Setup git precommits 75 | pip install -r requirements-dev.txt 76 | pre-commit install 77 | ``` 78 | 79 | #### Developer dependencies 80 | 81 | This repository uses git pre-commit hooks to auto-format code. 82 | These developer dependencies are in requirements-dev.txt. 83 | The other files describing pre-commit hooks are: `pyproject.toml`, `.pre-commit-config.yaml` 84 | 85 | ## Usage 86 | 87 | Download the Lapa dataset from https://github.com/JDAI-CV/lapa-dataset 88 | It can be placed at `seg_lapa/data`. 89 | 90 | Run training. See the [hydra documentation](https://hydra.cc/docs/advanced/override_grammar/basic) 91 | on how to override the config values from the CLI. 92 | 93 | ```bash 94 | # Run training 95 | python -m seg_lapa.train dataset.data_dir= 96 | 97 | # Run on multiple gpus 98 | python -m seg_lapa.train dataset.data_dir= train.gpus=\"0,1\" 99 | ``` 100 | 101 | ## Using this template for your own project 102 | 103 | To use this template for your own project: 104 | 105 | 1. Search and replace `seg_lapa` with your project name 106 | 2. Edit setup.py with new package name, requirements and other details 107 | 3. Replace the model, dataloaders, loss function, metric with your own! 108 | 4. Update the readme! Add your own links to your paper at the top, add citation info at bottom. 109 | 110 | This template was based on the Pytorch-Lightning 111 | [seed project](https://github.com/PyTorchLightning/deep-learning-project-template). 112 | 113 | ### Callbacks 114 | 115 | The callbacks can be configured from the config files or 116 | [command line overrides](https://hydra.cc/docs/next/advanced/override_grammar/basic/). 
100 |
101 | ## Using this template for your own project
102 |
103 | To use this template for your own project:
104 |
105 | 1. Search and replace `seg_lapa` with your project name
106 | 2. Edit setup.py with the new package name, requirements and other details
107 | 3. Replace the model, dataloaders, loss function and metric with your own!
108 | 4. Update the readme! Add your own links to your paper at the top, add citation info at the bottom.
109 |
110 | This template was based on the Pytorch-Lightning
111 | [seed project](https://github.com/PyTorchLightning/deep-learning-project-template).
112 |
113 | ### Callbacks
114 |
115 | The callbacks can be configured from the config files or
116 | [command line overrides](https://hydra.cc/docs/next/advanced/override_grammar/basic/).
117 | To disable a callback, simply remove its entry from the config. More callbacks can easily be added to the config system
118 | as needed. The following callbacks are currently included:
119 |
120 | - [Early Stopping](https://pytorch-lightning.readthedocs.io/en/latest/generated/pytorch_lightning.callbacks.EarlyStopping.html#pytorch_lightning.callbacks.EarlyStopping)
121 | - [Model Checkpoint](https://pytorch-lightning.readthedocs.io/en/latest/generated/pytorch_lightning.callbacks.ModelCheckpoint.html#pytorch_lightning.callbacks.ModelCheckpoint)
122 | - [Log Media](#logmedia)
123 |
124 | CLI override examples:
125 |
126 | ```shell script
127 | # Disable the LogMedia callback.
128 | python -m seg_lapa.train "~callbacks.log_media"
129 |
130 | # Set the EarlyStopping callback to wait for 20 epochs before terminating.
131 | python -m seg_lapa.train callbacks.early_stopping.patience=20
132 | ```
133 |
134 | #### LogMedia
135 |
136 | The LogMedia callback is used to log media, such as images and point clouds, to the logger and to local disk.
137 | It is also used to save the config files for each run. The `LightningModule` adds data to a queue, which is
138 | fetched within the `LogMedia` callback and logged to the logger and/or disk.
139 |
140 | To customize this callback for your application, override or modify the following methods (a sketch of the queue usage follows the list):
141 |
142 | - `LogMedia._get_preds_from_lightningmodule()`
143 | - `LogMedia._save_media_to_disk()`
144 | - `LogMedia._save_media_to_logger()`
145 | - The LightningModule should have an attribute of type `LogMediaQueue` called `self.log_media`.
146 |   Change the data that you push into the queue in the train/val/test steps as per your requirements.
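
For reference, here is a minimal sketch of the `LightningModule` side of this contract. The `self.model` and `self.loss_fn` attributes are illustrative stand-ins for your own modules; the dict keys and the `Mode` enum are what `LogMedia._get_preds_from_lightningmodule()` actually reads back:

```python
import pytorch_lightning as pl

from seg_lapa.callbacks.log_media import LogMediaQueue, Mode


class MyModel(pl.LightningModule):
    def __init__(self):
        super().__init__()
        # Required by the LogMedia callback: a circular queue holding the latest N batches
        self.log_media = LogMediaQueue(max_len=3)

    def training_step(self, batch, batch_idx):
        inputs, labels = batch
        logits = self.model(inputs)  # `self.model` / `self.loss_fn` are hypothetical attributes
        preds = logits.argmax(dim=1)
        # LogMedia reads back exactly these keys when building its collages and wandb masks
        self.log_media.append({"inputs": inputs, "labels": labels, "preds": preds}, Mode.TRAIN)
        return self.loss_fn(logits, labels)
```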
147 |
148 | ##### Notes:
149 |
150 | - LogMedia currently supports the Weights and Biases logger only.
151 | - By default, LogMedia only saves the latest samples to disk. To save the results from each step/epoch, pass
152 |   `save_latest_only=False`.
153 |
154 | #### EarlyStopping
155 |
156 | This is Lightning's built-in callback. Here are some tips on how to configure early stopping:
157 |
158 | ```
159 | Args:
160 |     monitor: Monitor a key validation metric (e.g. mIoU). Monitoring the loss is not a good idea, as it is an
161 |         unreliable indicator of model performance. Two models might have the same loss but different performance,
162 |         or the loss might start increasing even though performance does not decrease.
163 |
164 |     min_delta: Project-dependent - choose a value for your metric below which you'd consider the improvement
165 |         negligible.
166 |         Example: For segmentation, I do not care for improvements of less than 0.05% IoU in general.
167 |         But in Kaggle competitions, even 0.01% would matter.
168 |
169 |     patience: The number of validation epochs to wait for an improvement. It is affected by the
170 |         ``check_val_every_n_epoch`` and ``val_check_interval`` params to the PL Trainer.
171 |
172 |         It takes experimentation to figure out an appropriate patience for your project. Train the model
173 |         once without early stopping and see how long it takes to converge on a given dataset.
174 |         Choose a number of epochs between when you feel the model has started to converge and when you're
175 |         sure it has converged. Reduce the patience if you see the model continue to train for too long.
176 | ```
177 |
178 | #### ModelCheckpoint
179 |
180 | This is also Lightning's built-in callback, used to save checkpoints. It can monitor a logged value and save the best
181 | checkpoints, save the latest checkpoint, or save checkpoints every N steps/epochs.
182 | We save checkpoints in our own logs directory structure, which is different from Lightning's default.
183 |
184 | ### Loggers
185 |
186 | At this point, this project only supports the WandB logger (Weights and Biases). Other loggers can easily be added.
187 | Modify these methods after adding your logger to the config system:
188 |
189 | - `utils.generate_log_dir_path()` - Generates the dir structure to save logs
190 | - `LogMedia._log_media_to_wandb()` - If logging media such as images
191 |
192 | ### Notes
193 |
194 | #### Pre-commit workflow
195 |
196 | The project uses pre-commit hooks for `black` and `prettier`, which are auto-format tools for `.py` and
197 | `.yaml | .md` files respectively. Once `pre-commit install` has been run, the hooks format each commit
198 | automatically, and a GitHub Action runs the same hooks on every pull request. The hooks can also be run manually, as shown below.
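
To apply all the hooks across the entire repository in one go (useful after updating hook versions):

```shell script
pre-commit run --all-files
```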
199 |
200 | #### Absolute imports
201 |
202 | This project is set up as a package. One of the advantages of setting it up as a
203 | package is that it is easy to import modules from anywhere.
204 | To avoid errors with pytorch-lightning, always use absolute imports:
205 |
206 | ```python
207 | from seg_lapa.loss_func import CrossEntropy2D
208 | from seg_lapa import metrics
209 | import seg_lapa.metrics as metrics
210 | ```
211 |
212 | ### Citation
213 |
214 | ```
215 | @article{YourName,
216 |    title={Your Title},
217 |    author={Your team},
218 |    journal={Location},
219 |    year={Year}
220 | }
221 | ```
-------------------------------------------------------------------------------- /pyproject.toml: --------------------------------------------------------------------------------
1 | [tool.black]
2 | line-length = 120
3 | target-version = ["py38"]
4 | exclude = "(.eggs|.git|.hg|.mypy_cache|.nox|.tox|.venv|.svn|_build|buck-out|build|dist)"
-------------------------------------------------------------------------------- /requirements-cuda11.txt: --------------------------------------------------------------------------------
1 | # This file is for installing pytorch 1.7 for cuda 11, as it requires different syntax
2 | --find-links https://download.pytorch.org/whl/torch_stable.html
3 | torch==1.7.1+cu110
4 | torchvision==0.8.2+cu110
-------------------------------------------------------------------------------- /requirements-dev.txt: --------------------------------------------------------------------------------
1 | # Developer utilities
2 | pre-commit>=2.9.3
-------------------------------------------------------------------------------- /scripts/kill_hanging_processes.sh: --------------------------------------------------------------------------------
1 | # When using DDP, processes don't always die cleanly. This snippet kills all processes of the seg_lapa package.
2 | kill $(ps aux | grep seg_lapa | grep -v grep | awk '{print $2}')
-------------------------------------------------------------------------------- /scripts/sample-wandb-logging.png: --------------------------------------------------------------------------------
https://raw.githubusercontent.com/Shreeyak/pytorch-lightning-segmentation-template/851df18c61a6b7354304e3e54b939136378aa142/scripts/sample-wandb-logging.png
-------------------------------------------------------------------------------- /seg_lapa/__init__.py: --------------------------------------------------------------------------------
https://raw.githubusercontent.com/Shreeyak/pytorch-lightning-segmentation-template/851df18c61a6b7354304e3e54b939136378aa142/seg_lapa/__init__.py
-------------------------------------------------------------------------------- /seg_lapa/callbacks/log_media.py: --------------------------------------------------------------------------------
1 | from collections import deque
2 | from dataclasses import dataclass
3 | from enum import Enum
4 | from typing import Any, List, Optional
5 |
6 | import numpy as np
7 | import torch
8 | import wandb
9 | from omegaconf import DictConfig, OmegaConf
10 | from pytorch_lightning import loggers as pl_loggers
11 | from pytorch_lightning.callbacks import Callback
12 | from pytorch_lightning.utilities.distributed import rank_zero_only, rank_zero_warn
13 |
14 | # Project-specific imports for logging media to disk
15 | import cv2
16 | import math
17 | from pathlib import Path
18 | from seg_lapa.utils.segmentation_label2rgb import LabelToRGB, Palette
19 |
20 | CONFIG_FNAME = "config.yaml"
21 |
22 |
23 | class Mode(Enum):
24 |     TRAIN = "Train"
25 |     VAL = "Val"
26 |     TEST = "Test"
27 |
28 |
29 | @dataclass
30 | class PredData:
31 |     """Holds the data read and converted from the LightningModule's LogMediaQueue"""
32 |
33 |     inputs: np.ndarray
34 |     labels: np.ndarray
35 |     preds: np.ndarray
36 |
37 |
38 | class LogMediaQueue:
39 |     """Holds a circular queue for each of the train/val/test modes, each of which contains the latest N batches of data"""
40 |
41 |     def __init__(self, max_len: int = 3):
42 |         if max_len < 1:
43 |             raise ValueError(f"Queue must be length >= 1. Given: {max_len}")
44 |
45 |         self.max_len = max_len
46 |         self.log_media = {
47 |             Mode.TRAIN: deque(maxlen=self.max_len),
48 |             Mode.VAL: deque(maxlen=self.max_len),
49 |             Mode.TEST: deque(maxlen=self.max_len),
50 |         }
51 |
52 |     def clear(self):
53 |         """Clear all queues"""
54 |         for mode, queue in self.log_media.items():
55 |             queue.clear()
56 |
57 |     @rank_zero_only
58 |     def append(self, data: Any, mode: Mode):
59 |         """Add a batch of data to a queue. Mode selects the train/val/test queue"""
60 |         self.log_media[mode].append(data)
61 |
62 |     @rank_zero_only
63 |     def fetch(self, mode: Mode) -> List[Any]:
64 |         """Fetch all the batches available in a queue. Empties the selected queue"""
65 |         data_r = []
66 |         while len(self.log_media[mode]) > 0:
67 |             data_r.append(self.log_media[mode].popleft())
68 |
69 |         return data_r
70 |
71 |     def len(self, mode: Mode) -> int:
72 |         """Get the number of elements in a queue"""
73 |         return len(self.log_media[mode])
74 |
75 |
76 | class LogMedia(Callback):
77 |     r"""Logs model output images and other media to Weights and Biases
78 |
79 |     This callback requires adding an attribute to the LightningModule called ``self.log_media``. This is a circular
80 |     queue that holds the latest N batches. This callback fetches the latest data from the queue for logging. 
81 |
82 |     Usage:
83 |         import pytorch_lightning as pl
84 |
85 |         class MyModel(pl.LightningModule):
86 |             self.log_media: LogMediaQueue = LogMediaQueue(max_len)
87 |
88 |         trainer = pl.Trainer(callbacks=[LogMedia()])
89 |
90 |     Args:
91 |         cfg (omegaconf.DictConfig, optional): The hydra cfg file given to the run. If passed, it will be saved to the
92 |             logs dir.
93 |         period_epoch (int): If > 0, log every N epochs
94 |         period_step (int): If > 0, log every N steps (i.e. batches)
95 |         max_samples (int): Max number of data samples to log
96 |         save_to_disk (bool): If True, save results to disk
97 |         save_latest_only (bool): If True, will overwrite prev results at each period.
98 |         exp_dir (str or Path): Path to directory where results will be saved
99 |         verbose (bool): verbosity mode. Default: ``True``.
100 |     """
101 |
102 |     SUPPORTED_LOGGERS = [pl_loggers.WandbLogger]
103 |
104 |     def __init__(
105 |         self,
106 |         cfg: DictConfig = None,
107 |         max_samples: int = 10,
108 |         period_epoch: int = 1,
109 |         period_step: int = 0,
110 |         save_to_disk: bool = True,
111 |         save_latest_only: bool = True,
112 |         exp_dir: Optional[str] = None,
113 |         verbose: bool = True,
114 |     ):
115 |         super().__init__()
116 |         self.cfg = cfg
117 |         self.max_samples = max_samples
118 |         self.period_epoch = period_epoch
119 |         self.period_step = period_step
120 |         self.save_to_disk = save_to_disk
121 |         self.save_latest_only = save_latest_only
122 |         self.verbose = verbose
123 |         self.valid_logger = False
124 |
125 |         try:
126 |             self.exp_dir = Path(exp_dir) if self.save_to_disk else None
127 |         except TypeError as e:
128 |             raise ValueError(f"Invalid exp_dir: {exp_dir}. \n{e}")
129 |
130 |         if not OmegaConf.is_config(self.cfg):
131 |             raise ValueError(f"Config file not of type {DictConfig}. Given: {type(cfg)}")
132 |
133 |         # Project-specific fields
134 |         self.class_labels_lapa = {
135 |             0: "background",
136 |             1: "skin",
137 |             2: "eyebrow_left",
138 |             3: "eyebrow_right",
139 |             4: "eye_left",
140 |             5: "eye_right",
141 |             6: "nose",
142 |             7: "lip_upper",
143 |             8: "inner_mouth",
144 |             9: "lip_lower",
145 |             10: "hair",
146 |         }
147 |
148 |     def setup(self, trainer, pl_module, stage: str):
149 |         # This callback requires a ``.log_media`` attribute in LightningModule
150 |         req_attr = "log_media"
151 |         if not hasattr(pl_module, req_attr):
152 |             raise AttributeError(
153 |                 f"{pl_module.__class__.__name__}.{req_attr} not found. The {LogMedia.__name__} "
154 |                 f"callback requires the LightningModule to have the {req_attr} attribute."
155 |             )
156 |         if not isinstance(pl_module.log_media, LogMediaQueue):
157 |             raise AttributeError(f"{pl_module.__class__.__name__}.{req_attr} must be of type {LogMediaQueue.__name__}")
158 |
159 |         if self.verbose:
160 |             pl_module.print(
161 |                 f"Initializing Callback {LogMedia.__name__}. 
" 162 | f"Logging to disk: {self.exp_dir if self.save_to_disk else False}" 163 | ) 164 | 165 | self._create_log_dir() 166 | self.valid_logger = True if self._logger_is_supported(trainer) else False 167 | 168 | # Save copy of config to logger 169 | if self.valid_logger: 170 | if isinstance(trainer.logger, pl_loggers.WandbLogger): 171 | OmegaConf.save(self.cfg, Path(trainer.logger.experiment.dir) / "train.yaml") 172 | trainer.logger.experiment.save("*.yaml") 173 | else: 174 | raise NotImplementedError 175 | 176 | def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx): 177 | if self._should_log_step(trainer, batch_idx): 178 | self._log_results(trainer, pl_module, Mode.TRAIN, batch_idx) 179 | 180 | def on_validation_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx): 181 | if self._should_log_step(trainer, batch_idx): 182 | self._log_results(trainer, pl_module, Mode.VAL, batch_idx) 183 | 184 | def on_test_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx): 185 | if self._should_log_step(trainer, batch_idx): 186 | self._log_results(trainer, pl_module, Mode.TEST, batch_idx) 187 | 188 | def on_train_epoch_end(self, trainer, pl_module, outputs): 189 | if self._should_log_epoch(trainer): 190 | self._log_results(trainer, pl_module, Mode.TRAIN) 191 | 192 | def on_validation_epoch_end(self, trainer, pl_module): 193 | if self._should_log_epoch(trainer): 194 | self._log_results(trainer, pl_module, Mode.VAL) 195 | 196 | def on_test_epoch_end(self, trainer, pl_module): 197 | if self._should_log_epoch(trainer): 198 | self._log_results(trainer, pl_module, Mode.TEST) 199 | 200 | def _log_results(self, trainer, pl_module, mode: Mode, batch_idx: Optional[int] = None): 201 | pred_data = self._get_preds_from_lightningmodule(pl_module, mode) 202 | self._save_media_to_logger(trainer, pred_data, mode) 203 | self._save_media_to_disk(trainer, pred_data, mode, batch_idx) 204 | 205 | def _should_log_epoch(self, trainer): 206 | if trainer.running_sanity_check: 207 | return False 208 | if self.period_epoch < 1 or ((trainer.current_epoch + 1) % self.period_epoch != 0): 209 | return False 210 | return True 211 | 212 | def _should_log_step(self, trainer, batch_idx): 213 | if trainer.running_sanity_check: 214 | return False 215 | if self.period_step < 1 or ((batch_idx + 1) % self.period_step != 0): 216 | return False 217 | return True 218 | 219 | @rank_zero_only 220 | def _create_log_dir(self): 221 | if not self.save_to_disk: 222 | return 223 | 224 | self.exp_dir.mkdir(parents=True, exist_ok=True) 225 | if self.cfg is not None: 226 | fname = self.exp_dir / CONFIG_FNAME 227 | OmegaConf.save(config=self.cfg, f=fname, resolve=True) 228 | 229 | @rank_zero_only 230 | def _logger_is_supported(self, trainer): 231 | """This callback only works with wandb logger""" 232 | for logger_type in self.SUPPORTED_LOGGERS: 233 | if isinstance(trainer.logger, logger_type): 234 | return True 235 | 236 | rank_zero_warn( 237 | f"Unsupported logger: '{trainer.logger}', will not log any media to logger this run." 238 | f" Supported loggers: {[sup_log.__name__ for sup_log in self.SUPPORTED_LOGGERS]}." 239 | ) 240 | return False 241 | 242 | @rank_zero_only 243 | def _get_preds_from_lightningmodule(self, pl_module, mode: Mode) -> Optional[PredData]: 244 | """Fetch latest N batches from the data queue in LightningModule. 
245 |         Process the tensors as required (e.g. convert to numpy arrays and scale)
246 |         """
247 |         if pl_module.log_media.len(mode) == 0:  # Empty queue
248 |             rank_zero_warn(f"\nEmpty LogMediaQueue! Mode: {mode}. Epoch: {pl_module.trainer.current_epoch}")
249 |             return None
250 |
251 |         media_data = pl_module.log_media.fetch(mode)
252 |
253 |         inputs = torch.cat([x["inputs"] for x in media_data], dim=0)
254 |         labels = torch.cat([x["labels"] for x in media_data], dim=0)
255 |         preds = torch.cat([x["preds"] for x in media_data], dim=0)
256 |
257 |         # Limit the num of samples and convert to numpy
258 |         inputs = inputs[: self.max_samples].detach().cpu().numpy().transpose((0, 2, 3, 1))
259 |         inputs = (inputs * 255).astype(np.uint8)
260 |         labels = labels[: self.max_samples].detach().cpu().numpy().astype(np.uint8)
261 |         preds = preds[: self.max_samples].detach().cpu().numpy().astype(np.uint8)
262 |
263 |         out = PredData(inputs=inputs, labels=labels, preds=preds)
264 |
265 |         return out
266 |
267 |     @rank_zero_only
268 |     def _save_media_to_disk(self, trainer, pred_data: Optional[PredData], mode: Mode, batch_idx: Optional[int] = None):
269 |         """For a given mode (train/val/test), save the results to disk"""
270 |         if not self.save_to_disk:
271 |             return
272 |         if pred_data is None:  # Empty queue
273 |             rank_zero_warn(f"Empty queue! Mode: {mode}")
274 |             return
275 |
276 |         # Create output filename
277 |         if self.save_latest_only:
278 |             output_filename = f"results.{mode.name.lower()}.png"
279 |         else:
280 |             if batch_idx is None:
281 |                 output_filename = f"results-epoch{trainer.current_epoch}.{mode.name.lower()}.png"
282 |             else:
283 |                 output_filename = f"results-epoch{trainer.current_epoch}-step{batch_idx}.{mode.name.lower()}.png"
284 |         output_filename = self.exp_dir / output_filename
285 |
286 |         # Get the latest batches from the data queue in LightningModule
287 |         inputs, labels, preds = pred_data.inputs, pred_data.labels, pred_data.preds
288 |
289 |         # Colorize labels and predictions
290 |         label2rgb = LabelToRGB()
291 |         labels_rgb = [label2rgb.map_color_palette(lbl, Palette.LAPA) for lbl in labels]
292 |         preds_rgb = [label2rgb.map_color_palette(pred, Palette.LAPA) for pred in preds]
293 |         inputs_l = [ipt for ipt in inputs]
294 |
295 |         # Create collage of results
296 |         results_l = []
297 |         # Combine each inp/lbl/pred triplet into a single image
298 |         for inp, lbl, pred in zip(inputs_l, labels_rgb, preds_rgb):
299 |             res_combined = np.concatenate((inp, lbl, pred), axis=1)
300 |             results_l.append(res_combined)
301 |         # Create grid from combined imgs
302 |         n_imgs = len(results_l)
303 |         n_cols = 4  # Fix num of columns
304 |         n_rows = int(math.ceil(n_imgs / n_cols))
305 |         img_h, img_w, _ = results_l[0].shape
306 |         grid_results = np.zeros((img_h * n_rows, img_w * n_cols, 3), dtype=np.uint8)
307 |         for img_idx, res in enumerate(results_l):  # Row-major placement; handles a partially-filled last row
308 |             idy, idx = divmod(img_idx, n_cols)
309 |             grid_results[idy * img_h : (idy + 1) * img_h, idx * img_w : (idx + 1) * img_w, :] = res
310 |
311 |         # Save collage
312 |         if not cv2.imwrite(str(output_filename), cv2.cvtColor(grid_results, cv2.COLOR_RGB2BGR)):
313 |             rank_zero_warn(f"Error in writing image: {output_filename}")
314 |
315 |     @rank_zero_only
316 |     def _save_media_to_logger(self, trainer, pred_data: Optional[PredData], mode: Mode):
317 |         """Log images to wandb at the end of a batch. 
Steps are common for train/val/test"""
318 |         if not self.valid_logger:
319 |             return
320 |         if pred_data is None:  # Empty queue
321 |             return
322 |
323 |         if isinstance(trainer.logger, pl_loggers.WandbLogger):
324 |             self._log_media_to_wandb(trainer, pred_data, mode)
325 |         else:
326 |             raise NotImplementedError(f"No method to log media to logger: {trainer.logger}")
327 |
328 |     def _log_media_to_wandb(self, trainer, pred_data: Optional[PredData], mode: Mode):
329 |         # Get the latest batches from the data queue in LightningModule
330 |         inputs, labels, preds = pred_data.inputs, pred_data.labels, pred_data.preds
331 |
332 |         # Create wandb Images for logging
333 |         mask_list = []
334 |         for img, lbl, pred in zip(inputs, labels, preds):
335 |             mask_img = wandb.Image(
336 |                 img,
337 |                 masks={
338 |                     "predictions": {"mask_data": pred, "class_labels": self.class_labels_lapa},
339 |                     "ground_truth": {"mask_data": lbl, "class_labels": self.class_labels_lapa},
340 |                 },
341 |             )
342 |             mask_list.append(mask_img)
343 |
344 |         wandb_log_label = f"{mode.name.title()}/Predictions"
345 |         trainer.logger.experiment.log({wandb_log_label: mask_list}, commit=False)
346 |
-------------------------------------------------------------------------------- /seg_lapa/config/callbacks/disabled.yaml: --------------------------------------------------------------------------------
1 | # @package _group_
2 | name: disabled
-------------------------------------------------------------------------------- /seg_lapa/config/callbacks/standard.yaml: --------------------------------------------------------------------------------
1 | # @package _group_
2 | name: standard
3 |
4 | early_stopping:
5 |   monitor: "Val/mIoU"
6 |   min_delta: 0.0005
7 |   patience: 10
8 |   mode: "max"
9 |   verbose: false
10 |
11 | checkpoints:
12 |   filename: "best" # PL Default: "{epoch}-{step}". `=` in filename can cause errors when parsing cli overrides.
13 |   save_last: true
14 |   save_top_k: 1
15 |   monitor: "Val/mIoU"
16 |   mode: "max"
17 |   period: 10
18 |   verbose: false
19 |
20 | log_media:
21 |   max_samples: 10
22 |   period_epoch: 1
23 |   period_step: 0
24 |   save_to_disk: true
25 |   save_latest_only: true
26 |   verbose: true
-------------------------------------------------------------------------------- /seg_lapa/config/callbacks/with_lr_monitor.yaml: --------------------------------------------------------------------------------
1 | # @package _group_
2 | name: standard
3 |
4 | early_stopping:
5 |   monitor: "Val/mIoU"
6 |   min_delta: 0.0005
7 |   patience: 10
8 |   mode: "max"
9 |   verbose: false
10 |
11 | checkpoints:
12 |   filename: "best" # PL Default: "{epoch}-{step}". `=` in filename can cause errors when parsing cli overrides.
13 |   save_last: true
14 |   save_top_k: 1
15 |   monitor: "Val/mIoU"
16 |   mode: "max"
17 |   period: 10
18 |   verbose: false
19 |
20 | log_media:
21 |   max_samples: 10
22 |   period_epoch: 1
23 |   period_step: 0
24 |   save_to_disk: true
25 |   save_latest_only: true
26 |   verbose: true
27 |
28 | lr_monitor:
29 |   # Set to "epoch" or "step" to log the lr of all optimizers at the same interval; set to null to log each
30 |   # at the interval given by its scheduler's interval key. 
31 | logging_interval: "step" 32 | log_momentum: false # If true, log the momentum values of the optimizer, if present 33 | -------------------------------------------------------------------------------- /seg_lapa/config/dataset/lapa.yaml: -------------------------------------------------------------------------------- 1 | # @package _group_ 2 | name: lapa 3 | data_dir: "/home/gaini/lightning/LaPa" 4 | batch_size: 16 5 | num_workers: 4 6 | resize_h: 256 7 | resize_w: 256 8 | -------------------------------------------------------------------------------- /seg_lapa/config/load_weights/disabled.yaml: -------------------------------------------------------------------------------- 1 | # @package _group_ 2 | name: disabled 3 | -------------------------------------------------------------------------------- /seg_lapa/config/load_weights/pretrain.yaml: -------------------------------------------------------------------------------- 1 | # @package _group_ 2 | name: load_weights 3 | 4 | # Use this to load pre-train weights (such as model pretrained on pascal dataset) 5 | path: null 6 | -------------------------------------------------------------------------------- /seg_lapa/config/logger/disabled.yaml: -------------------------------------------------------------------------------- 1 | # @package _group_ 2 | name: disabled 3 | -------------------------------------------------------------------------------- /seg_lapa/config/logger/wandb.yaml: -------------------------------------------------------------------------------- 1 | # @package _group_ 2 | name: wandb 3 | entity: cleargrasp2 4 | project: segmentation 5 | run_name: null # If None, will generate random name 6 | run_id: null # Pass run_id to resume logging to that run 7 | -------------------------------------------------------------------------------- /seg_lapa/config/logger/wandb_debug.yaml: -------------------------------------------------------------------------------- 1 | # @package _group_ 2 | name: wandb 3 | entity: cleargrasp2 4 | project: shrek_debug 5 | run_name: null # If None, will generate random name 6 | run_id: null # Pass run_id to resume logging to that run 7 | -------------------------------------------------------------------------------- /seg_lapa/config/model/deeplabv3.yaml: -------------------------------------------------------------------------------- 1 | # @package _group_ 2 | name: deeplabv3 3 | backbone: "drn" # Possible values for backbone: ['xception', 'resnet', 'drn'] 4 | num_classes: 11 5 | output_stride: 8 # Not used with 'drn' backbone 6 | sync_bn: False # This enables custom batchnorm code that syncs across gpus. 
7 | #enable_amp: False # Should always be false, since PL takes care of 16-bit training
-------------------------------------------------------------------------------- /seg_lapa/config/optimizer/adam.yaml: --------------------------------------------------------------------------------
1 | # @package _group_
2 | name: adam
3 | lr: 1e-3
4 | weight_decay: 0.0005
-------------------------------------------------------------------------------- /seg_lapa/config/optimizer/sgd.yaml: --------------------------------------------------------------------------------
1 | # @package _group_
2 | name: sgd
3 | lr: 1e-6
4 | momentum: 0.9
5 | weight_decay: 0.0
6 | nesterov: False
-------------------------------------------------------------------------------- /seg_lapa/config/scheduler/cyclic.yaml: --------------------------------------------------------------------------------
1 | # @package _group_
2 | # Sets the learning rate of each parameter group according to the cyclical learning rate policy (CLR).
3 | # The policy cycles the learning rate between two boundaries with a constant frequency.
4 | name: cyclic
5 |
6 | # Initial learning rate, which is the lower boundary in the cycle for each parameter group.
7 | base_lr: 0.0001
8 |
9 | # Upper learning rate boundaries in the cycle for each parameter group.
10 | max_lr: 0.01
11 |
12 | # Number of training iterations in the increasing half of a cycle. Should be set to 2-10 times the number of iterations in an epoch.
13 | step_size_up: 20
14 |
15 | # Number of training iterations in the decreasing half of a cycle. If step_size_down is None, it is set to step_size_up.
16 | step_size_down: null
-------------------------------------------------------------------------------- /seg_lapa/config/scheduler/disabled.yaml: --------------------------------------------------------------------------------
1 | # @package _group_
2 | # Disabled learning rate scheduler.
3 | name: disabled
-------------------------------------------------------------------------------- /seg_lapa/config/scheduler/plateau.yaml: --------------------------------------------------------------------------------
1 | # @package _group_
2 | # Reduce the learning rate when a metric has shown no improvement for a 'patience' number of epochs.
3 | # Models often benefit from reducing the learning rate by a factor of 2-10 once learning stagnates.
4 | name: plateau
5 |
6 | # Factor by which the learning rate will be reduced. new_lr = lr * factor.
7 | factor: 0.8
8 |
9 | # Number of epochs with no improvement after which the learning rate will be reduced.
10 | patience: 25
11 |
12 | # A lower bound on the learning rate of all param groups or each group respectively.
13 | min_lr: 1e-7
14 |
15 | # One of min, max. In min mode, the lr is reduced when the metric stops decreasing; in max mode, when it stops increasing.
16 | mode: "max"
17 |
18 | # Threshold for measuring the new optimum, to only focus on significant changes.
19 | threshold: 1e-4
20 |
21 | # Number of epochs to wait before resuming normal operation after the lr has been reduced.
22 | cooldown: 0
23 |
24 | # Minimal decay applied to the lr. If the difference between the new and old lr is smaller than eps,
25 | # the update is ignored.
26 | eps: 1e-8
27 |
28 | # If True, prints a message to stdout for each update. 
29 | verbose: False
-------------------------------------------------------------------------------- /seg_lapa/config/scheduler/poly.yaml: --------------------------------------------------------------------------------
1 | # @package _group_
2 | # The base lr decreases smoothly up to max_iter as a polynomial function:
3 | # lr = base_lr * (1 - current_iter / max_iter) ^ pow_factor
4 | name: poly
5 |
6 | # Maximum number of iterations for which training will run
7 | max_iter: 500
8 |
9 | # The polynomial factor in the formula above
10 | pow_factor: 0.9
-------------------------------------------------------------------------------- /seg_lapa/config/scheduler/step.yaml: --------------------------------------------------------------------------------
1 | # @package _group_
2 | # Decays the learning rate of each parameter group by gamma every step_size epochs.
3 | name: step
4 |
5 | # Period of learning rate decay.
6 | step_size: 30
7 |
8 | # Multiplicative factor of learning rate decay.
9 | gamma: 0.1
-------------------------------------------------------------------------------- /seg_lapa/config/train.yaml: --------------------------------------------------------------------------------
1 | random_seed: 0 # If None, seeds are not set. If int, uses the value to seed.
2 | logs_root_dir: "./" # Where to save logs and checkpoints.
3 |
4 | defaults:
5 |   - dataset: lapa
6 |   - model: deeplabv3
7 |   - optimizer: adam
8 |   - trainer: standard
9 |   - scheduler: disabled
10 |   - logger: wandb
11 |   - callbacks: standard
12 |   - load_weights: disabled
13 |
14 |   # To disable any .log files
15 |   - hydra/job_logging: disabled
16 |   - hydra/hydra_logging: disabled
17 |
18 | hydra:
19 |   output_subdir: null # Disable saving of config files. We'll do that ourselves.
20 |   run:
21 |     dir: . # Set working dir to current directory
-------------------------------------------------------------------------------- /seg_lapa/config/trainer/debug.yaml: --------------------------------------------------------------------------------
1 | # @package _group_
2 | name: trainer
3 |
4 | gpus: 1
5 | accelerator: "ddp"
6 | precision: 16
7 |
8 | max_epochs: 3
9 | resume_from_checkpoint: null
10 | log_every_n_steps: 1
11 |
12 | # For deterministic runs
13 | benchmark: False
14 | deterministic: True
15 |
16 | # Limit batches for debugging
17 | fast_dev_run: False
18 | overfit_batches: 0.0
19 | limit_train_batches: 2
20 | limit_val_batches: 2
21 | limit_test_batches: 2
-------------------------------------------------------------------------------- /seg_lapa/config/trainer/standard.yaml: --------------------------------------------------------------------------------
1 | # @package _group_
2 | name: trainer
3 |
4 | gpus: 1 # Denotes the number of gpus to use. Set CUDA_VISIBLE_DEVICES env var to control which gpus are used.
5 | accelerator: "ddp"
6 | precision: 16
7 |
8 | max_epochs: 100
9 | resume_from_checkpoint: null
10 | log_every_n_steps: 1
11 |
12 | # For deterministic runs
13 | benchmark: False # If true enables cudnn.benchmark.
14 | deterministic: True # If true enables cudnn.deterministic.
15 |
16 | # Limit batches for debugging
17 | fast_dev_run: False # If True, runs 1 batch of train, val and test to find any bugs (ie: a sort of unit test).
18 | overfit_batches: 0.0 # Overfit on subset of training data. Use the same as val/test set. (floats = percent, int = num_batches). Warn: 1 will be cast to 1.0.
19 | limit_train_batches: 1.0 # How much of training dataset to check (floats = percent, int = num_batches). 
Warn: 1 will be cast to 1.0. 20 | limit_val_batches: 1.0 21 | limit_test_batches: 1.0 22 | -------------------------------------------------------------------------------- /seg_lapa/config_parse/callbacks_available.py: -------------------------------------------------------------------------------- 1 | """Dataclasses just to initialize and return Callback objects""" 2 | from typing import Optional 3 | 4 | from omegaconf import DictConfig 5 | from pydantic.dataclasses import dataclass 6 | from pytorch_lightning.callbacks import Callback, EarlyStopping, LearningRateMonitor, ModelCheckpoint 7 | 8 | from seg_lapa.callbacks.log_media import LogMedia 9 | from seg_lapa.config_parse.conf_utils import asdict_filtered 10 | 11 | 12 | @dataclass(frozen=True) 13 | class EarlyStopConf: 14 | monitor: str 15 | min_delta: float 16 | patience: int 17 | mode: str 18 | verbose: bool = False 19 | 20 | def get_callback(self) -> Callback: 21 | return EarlyStopping(**asdict_filtered(self)) 22 | 23 | 24 | @dataclass(frozen=True) 25 | class CheckpointConf: 26 | filename: Optional[str] 27 | monitor: Optional[str] 28 | mode: str 29 | save_last: Optional[bool] 30 | period: int 31 | save_top_k: Optional[int] 32 | verbose: bool = False 33 | 34 | def get_callback(self, logs_dir) -> Callback: 35 | return ModelCheckpoint(dirpath=logs_dir, **asdict_filtered(self)) 36 | 37 | 38 | @dataclass(frozen=True) 39 | class LogMediaConf: 40 | max_samples: int 41 | period_epoch: int 42 | period_step: int 43 | save_to_disk: bool 44 | save_latest_only: bool 45 | verbose: bool = False 46 | 47 | def get_callback(self, exp_dir: str, cfg: DictConfig) -> Callback: 48 | return LogMedia(exp_dir=exp_dir, cfg=cfg, **asdict_filtered(self)) 49 | 50 | 51 | @dataclass(frozen=True) 52 | class LearningRateMonitorConf: 53 | logging_interval: Optional[str] 54 | log_momentum: bool = False 55 | 56 | def get_callback(self) -> Callback: 57 | return LearningRateMonitor(**asdict_filtered(self)) 58 | -------------------------------------------------------------------------------- /seg_lapa/config_parse/callbacks_conf.py: -------------------------------------------------------------------------------- 1 | from abc import ABC, abstractmethod 2 | from typing import Dict, List, Optional 3 | 4 | from omegaconf import DictConfig 5 | from pydantic.dataclasses import dataclass 6 | from pytorch_lightning.callbacks import Callback 7 | 8 | from seg_lapa.config_parse.callbacks_available import ( 9 | CheckpointConf, 10 | EarlyStopConf, 11 | LearningRateMonitorConf, 12 | LogMediaConf, 13 | ) 14 | from seg_lapa.config_parse.conf_utils import validate_config_group_generic 15 | 16 | 17 | # The Callbacks config cannot be directly initialized because it contains sub-entries for each callback, each 18 | # of which describes a separate class. 
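# (For example, the `early_stopping` sub-entry maps to EarlyStopConf and `checkpoints` to CheckpointConf below.)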
19 | # For each of the callbacks, we define a dataclass and use them to init the list of callbacks
20 |
21 |
22 | @dataclass(frozen=True)
23 | class CallbacksConf(ABC):
24 |     name: str
25 |
26 |     @abstractmethod
27 |     def get_callbacks_list(self, *args) -> List:
28 |         return []
29 |
30 |
31 | @dataclass(frozen=True)
32 | class DisabledCallbacksConf(CallbacksConf):
33 |     def get_callbacks_list(self, *args) -> List:
34 |         return []
35 |
36 |
37 | @dataclass(frozen=True)
38 | class StandardCallbacksConf(CallbacksConf):
39 |     """Holds the config of each available callback as an optional sub-entry."""
40 |
41 |     early_stopping: Optional[Dict] = None
42 |     checkpoints: Optional[Dict] = None
43 |     log_media: Optional[Dict] = None
44 |     lr_monitor: Optional[Dict] = None
45 |
46 |     def get_callbacks_list(self, exp_dir: str, cfg: DictConfig) -> List[Callback]:
47 |         """Instantiate all the configured callbacks and return the Callback objects as a list.
48 |         If a callback's entry is not present in the config file, it will not be included in the list.
49 |         """
50 |         callbacks_list = []
51 |         if self.early_stopping is not None:
52 |             early_stop = EarlyStopConf(**self.early_stopping).get_callback()
53 |             callbacks_list.append(early_stop)
54 |
55 |         if self.checkpoints is not None:
56 |             checkpoint = CheckpointConf(**self.checkpoints).get_callback(exp_dir)
57 |             callbacks_list.append(checkpoint)
58 |
59 |         if self.log_media is not None:
60 |             log_media = LogMediaConf(**self.log_media).get_callback(exp_dir, cfg)
61 |             callbacks_list.append(log_media)
62 |
63 |         if self.lr_monitor is not None:
64 |             lr_monitor = LearningRateMonitorConf(**self.lr_monitor).get_callback()
65 |             callbacks_list.append(lr_monitor)
66 |
67 |         return callbacks_list
68 |
69 |
70 | valid_names = {
71 |     "disabled": DisabledCallbacksConf,
72 |     "standard": StandardCallbacksConf,
73 | }
74 |
75 |
76 | def validate_config_group(cfg_subgroup: DictConfig) -> CallbacksConf:
77 |     validated_dataclass = validate_config_group_generic(
78 |         cfg_subgroup, dataclass_dict=valid_names, config_category="callback"
79 |     )
80 |     return validated_dataclass
-------------------------------------------------------------------------------- /seg_lapa/config_parse/conf_utils.py: --------------------------------------------------------------------------------
1 | import dataclasses
2 | from typing import Dict, Optional, Sequence
3 |
4 | from omegaconf import DictConfig, OmegaConf
5 |
6 |
7 | def asdict_filtered(obj, remove_keys: Optional[Sequence[str]] = None) -> Dict:
8 |     """Returns the attributes of a dataclass in the form of a dict, with unwanted attributes removed.
9 |     Each config group has the term 'name', which is helpful in identifying the node that was chosen
10 |     in the config group (e.g. config group = optimizer, nodes = adam, sgd).
11 |     However, the 'name' parameter is not required for initializing any dataclasses. Hence it needs to be removed.
12 |
13 |     Args:
14 |         obj: The dataclass whose attributes will be converted to a dict
15 |         remove_keys: The keys to remove from the dict. The default is ['name'].
16 |     """
17 |     if not dataclasses.is_dataclass(obj):
18 |         raise ValueError("Not a dataclass/dataclass instance")
19 |
20 |     if remove_keys is None:
21 |         remove_keys = ["name"]
22 |
23 |     # Clean the arguments
24 |     args = dataclasses.asdict(obj)
25 |     for key in remove_keys:
26 |         if key in args:
27 |             args.pop(key)
28 |
29 |     return args
30 |
31 |
32 | def validate_config_group_generic(cfg_group: DictConfig, dataclass_dict: Dict, config_category: str = "option"):
33 |     """Use a hydra config group to initialize a pydantic dataclass. 
Initializing it validates the data. 34 | Each of our config groups has a name parameter, which is used to map to valid dataclasses for validation. 35 | 36 | Pydantic will coerce the parameters to the desired datatype and will raise errors if the config 37 | cannot be cast to the dataclass members. 38 | 39 | Args: 40 | cfg_group: The config group extracted from the hydra config. 41 | dataclass_dict: A dict containing the mapping from 'name' entry in config files to matching 42 | pydantic dataclasses for validation. 43 | config_category: Name of the config group, used to pretty-print error messages. 44 | 45 | Raises: 46 | ValueError: If the name parameter does not match any of the valid options 47 | """ 48 | if not OmegaConf.is_config(cfg_group): 49 | raise ValueError(f"Given config not an OmegaConf config. Got: {type(cfg_group)}") 50 | 51 | # Get the "name" entry in config 52 | name = cfg_group.name 53 | if name is None: 54 | raise KeyError( 55 | f"The given config does not contain a 'name' entry. Cannot map to a dataclass.\n" 56 | f" Config:\n {OmegaConf.to_yaml(cfg_group)}" 57 | ) 58 | 59 | # Convert hydra config to dict - This dict contains the arguments to init dataclass 60 | cfg_asdict = OmegaConf.to_container(cfg_group, resolve=True) 61 | 62 | # Get the dataclass to init from the mapping. Init the dataclass using hydra config 63 | try: 64 | dataclass_obj = dataclass_dict[name](**cfg_asdict) 65 | except KeyError: 66 | raise ValueError( 67 | f"Invalid Config: '{cfg_group.name}' is not a valid {config_category}. " 68 | f"Valid Options: {list(dataclass_dict.keys())}" 69 | ) 70 | 71 | return dataclass_obj 72 | -------------------------------------------------------------------------------- /seg_lapa/config_parse/dataset_conf.py: -------------------------------------------------------------------------------- 1 | from abc import ABC, abstractmethod 2 | 3 | import pytorch_lightning as pl 4 | from omegaconf import DictConfig 5 | from pydantic.dataclasses import dataclass 6 | 7 | from seg_lapa.config_parse.conf_utils import asdict_filtered, validate_config_group_generic 8 | from seg_lapa.datasets.lapa import LaPaDataModule 9 | 10 | 11 | @dataclass(frozen=True) 12 | class DatasetConf(ABC): 13 | name: str 14 | 15 | @abstractmethod 16 | def get_datamodule(self) -> pl.LightningDataModule: 17 | pass 18 | 19 | 20 | @dataclass(frozen=True) 21 | class LapaConf(DatasetConf): 22 | data_dir: str 23 | batch_size: int 24 | num_workers: int 25 | resize_h: int 26 | resize_w: int 27 | 28 | def get_datamodule(self) -> LaPaDataModule: 29 | return LaPaDataModule(**asdict_filtered(self)) 30 | 31 | 32 | valid_names = {"lapa": LapaConf} 33 | 34 | 35 | def validate_config_group(cfg_subgroup: DictConfig) -> DatasetConf: 36 | validated_dataclass = validate_config_group_generic( 37 | cfg_subgroup, dataclass_dict=valid_names, config_category="dataset" 38 | ) 39 | return validated_dataclass 40 | -------------------------------------------------------------------------------- /seg_lapa/config_parse/load_weights_conf.py: -------------------------------------------------------------------------------- 1 | from abc import ABC 2 | from typing import Optional 3 | 4 | from omegaconf import DictConfig 5 | from pydantic.dataclasses import dataclass 6 | 7 | from seg_lapa.config_parse.conf_utils import validate_config_group_generic 8 | 9 | 10 | @dataclass(frozen=True) 11 | class LoadWeightsConf(ABC): 12 | name: str 13 | path: Optional[str] = None 14 | 15 | 16 | valid_names = { 17 | "disabled": 
LoadWeightsConf, 18 | "load_weights": LoadWeightsConf, 19 | } 20 | 21 | 22 | def validate_config_group(cfg_subgroup: DictConfig) -> LoadWeightsConf: 23 | validated_dataclass = validate_config_group_generic( 24 | cfg_subgroup, dataclass_dict=valid_names, config_category="load_weights" 25 | ) 26 | return validated_dataclass 27 | -------------------------------------------------------------------------------- /seg_lapa/config_parse/logger_conf.py: -------------------------------------------------------------------------------- 1 | from abc import ABC, abstractmethod 2 | from pathlib import Path 3 | from typing import Optional 4 | 5 | import wandb 6 | from omegaconf import DictConfig, OmegaConf 7 | from pydantic.dataclasses import dataclass 8 | from pytorch_lightning import loggers as pl_loggers 9 | 10 | from seg_lapa.config_parse.conf_utils import asdict_filtered, validate_config_group_generic 11 | 12 | 13 | @dataclass 14 | class LoggerConf(ABC): 15 | name: str 16 | 17 | @abstractmethod 18 | def get_logger(self, *args): 19 | pass 20 | 21 | @abstractmethod 22 | def get_run_id(self, *args): 23 | """Loggers such as WandB generate a unique run id that can be used to resume runs""" 24 | pass 25 | 26 | 27 | @dataclass 28 | class DisabledLoggerConf(LoggerConf): 29 | @staticmethod 30 | def get_logger(*args): 31 | return False 32 | 33 | @staticmethod 34 | def get_run_id(): 35 | return None 36 | 37 | 38 | @dataclass 39 | class WandbConf(LoggerConf): 40 | """Weights and Biases. Ref: wandb.com""" 41 | 42 | entity: str 43 | project: str 44 | run_name: Optional[str] 45 | run_id: Optional[str] = None # Pass run_id to resume logging to that run. 46 | 47 | def get_logger(self, cfg: DictConfig, save_dir: Path) -> pl_loggers.WandbLogger: 48 | """Returns the Weights and Biases (wandb) logger object for PyTorch Lightning. 49 | A run corresponds to a single execution of the script; the underlying run object is created by `wandb.init()`. 50 | If the `run_id` attribute is set, logging will resume to that existing run. 51 | 52 | Args: 53 | cfg: The entire config read from hydra, for purposes of logging the config of each run in wandb. 54 | save_dir: Root dir to save wandb log files 55 | 56 | Returns: 57 | pl_loggers.WandbLogger: wandb logger object. Can be used for logging. 58 | """ 59 | # Some argument names to wandb are different from the attribute names of the class. 60 | # Pop the offending attributes before passing to init func. 61 | args_dict = asdict_filtered(self) 62 | run_name = args_dict.pop("run_name") 63 | run_id = args_dict.pop("run_id") 64 | 65 | # If `self.save_hyperparams()` is called in LightningModule, it will save the cfg passed as argument 66 | # cfg_dict = OmegaConf.to_container(cfg, resolve=True) 67 | 68 | wb_logger = pl_loggers.WandbLogger(name=run_name, id=run_id, save_dir=str(save_dir), **args_dict) 69 | 70 | return wb_logger 71 | 72 | def get_run_id(self): 73 | """If a run_id has been provided by the user, resume logging to that run. 
74 | Otherwise, a random run id will be generated 75 | """ 76 | if self.run_id is None: 77 | self.run_id = wandb.util.generate_id() 78 | 79 | return self.run_id 80 | 81 | 82 | valid_names = { 83 | "wandb": WandbConf, 84 | "disabled": DisabledLoggerConf, 85 | } 86 | 87 | 88 | def validate_config_group(cfg_subgroup: DictConfig) -> LoggerConf: 89 | validated_dataclass = validate_config_group_generic( 90 | cfg_subgroup, dataclass_dict=valid_names, config_category="logger" 91 | ) 92 | return validated_dataclass 93 | -------------------------------------------------------------------------------- /seg_lapa/config_parse/model_conf.py: -------------------------------------------------------------------------------- 1 | from abc import ABC, abstractmethod 2 | 3 | import torch 4 | from omegaconf import DictConfig 5 | from pydantic.dataclasses import dataclass 6 | 7 | from seg_lapa.config_parse.conf_utils import asdict_filtered, validate_config_group_generic 8 | from seg_lapa.networks.deeplab.deeplab import DeepLab 9 | 10 | 11 | @dataclass(frozen=True) 12 | class ModelConf(ABC): 13 | name: str 14 | num_classes: int 15 | 16 | @abstractmethod 17 | def get_model(self): 18 | pass 19 | 20 | 21 | @dataclass(frozen=True) 22 | class Deeplabv3Conf(ModelConf): 23 | backbone: str 24 | output_stride: int 25 | sync_bn: bool # If True, uses the custom synchronized-batchnorm code (PL can also sync batchnorm). 26 | enable_amp: bool = False # Should always be False, since PL takes care of 16-bit training 27 | 28 | def get_model(self) -> torch.nn.Module: 29 | return DeepLab(**asdict_filtered(self)) 30 | 31 | 32 | valid_names = { 33 | "deeplabv3": Deeplabv3Conf, 34 | } 35 | 36 | 37 | def validate_config_group(cfg_subgroup: DictConfig) -> ModelConf: 38 | validated_dataclass = validate_config_group_generic( 39 | cfg_subgroup, dataclass_dict=valid_names, config_category="model" 40 | ) 41 | return validated_dataclass 42 | -------------------------------------------------------------------------------- /seg_lapa/config_parse/optimizer_conf.py: -------------------------------------------------------------------------------- 1 | from abc import ABC, abstractmethod 2 | 3 | import torch 4 | from omegaconf import DictConfig 5 | from pydantic.dataclasses import dataclass 6 | 7 | from seg_lapa.config_parse.conf_utils import asdict_filtered, validate_config_group_generic 8 | 9 | 10 | @dataclass(frozen=True) 11 | class OptimConf(ABC): 12 | name: str 13 | 14 | @abstractmethod 15 | def get_optimizer(self, model_params) -> torch.optim.Optimizer: 16 | pass 17 | 18 | 19 | @dataclass(frozen=True) 20 | class AdamConf(OptimConf): 21 | lr: float 22 | weight_decay: float 23 | 24 | def get_optimizer(self, model_params) -> torch.optim.Optimizer: 25 | return torch.optim.Adam(params=model_params, **asdict_filtered(self)) 26 | 27 | 28 | @dataclass(frozen=True) 29 | class SgdConf(OptimConf): 30 | lr: float 31 | momentum: float 32 | weight_decay: float 33 | nesterov: bool 34 | 35 | def get_optimizer(self, model_params) -> torch.optim.Optimizer: 36 | return torch.optim.SGD(params=model_params, **asdict_filtered(self)) 37 | 38 | 39 | valid_names = {"adam": AdamConf, "sgd": SgdConf} 40 | 41 | 42 | def validate_config_group(cfg_subgroup: DictConfig) -> OptimConf: 43 | validated_dataclass = validate_config_group_generic( 44 | cfg_subgroup, dataclass_dict=valid_names, config_category="optimizer" 45 | ) 46 | return validated_dataclass 47 | -------------------------------------------------------------------------------- /seg_lapa/config_parse/scheduler_conf.py: 
-------------------------------------------------------------------------------- 1 | from abc import ABC, abstractmethod 2 | from typing import List, Optional, Union 3 | 4 | import torch 5 | from omegaconf import DictConfig 6 | from pydantic.dataclasses import dataclass 7 | from torch.optim.optimizer import Optimizer 8 | 9 | from seg_lapa.config_parse.conf_utils import asdict_filtered, validate_config_group_generic 10 | 11 | 12 | @dataclass(frozen=True) 13 | class SchedulerConf(ABC): 14 | name: str 15 | 16 | @abstractmethod 17 | def get_scheduler(self, optimizer: Optimizer) -> Optional[torch.optim.lr_scheduler._LRScheduler]: 18 | pass 19 | 20 | 21 | @dataclass(frozen=True) 22 | class DisabledConfig(SchedulerConf): 23 | def get_scheduler(self, optimizer: Optimizer) -> None: 24 | return None 25 | 26 | 27 | @dataclass(frozen=True) 28 | class CyclicConfig(SchedulerConf): 29 | base_lr: float 30 | max_lr: float 31 | step_size_up: int 32 | step_size_down: Optional[int] 33 | 34 | def get_scheduler(self, optimizer: Optimizer) -> torch.optim.lr_scheduler.CyclicLR: 35 | return torch.optim.lr_scheduler.CyclicLR(optimizer, cycle_momentum=False, **asdict_filtered(self)) 36 | 37 | 38 | @dataclass(frozen=True) 39 | class PolyConfig(SchedulerConf): 40 | max_iter: int 41 | pow_factor: float 42 | 43 | def get_scheduler(self, optimizer: Optimizer) -> torch.optim.lr_scheduler.LambdaLR: 44 | max_iter = float(self.max_iter) 45 | pow_factor = float(self.pow_factor) 46 | 47 | def poly_schedule(n_iter: int) -> float: 48 | return (1 - n_iter / max_iter) ** pow_factor 49 | 50 | return torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=poly_schedule) 51 | 52 | 53 | @dataclass(frozen=True) 54 | class StepConfig(SchedulerConf): 55 | step_size: int 56 | gamma: float 57 | 58 | def get_scheduler(self, optimizer: Optimizer) -> torch.optim.lr_scheduler.StepLR: 59 | return torch.optim.lr_scheduler.StepLR(optimizer, **asdict_filtered(self)) 60 | 61 | 62 | @dataclass(frozen=True) 63 | class PlateauConfig(SchedulerConf): 64 | factor: float 65 | patience: int 66 | min_lr: Union[float, List[float]] 67 | mode: str 68 | threshold: float 69 | cooldown: int 70 | eps: float 71 | verbose: bool 72 | 73 | def get_scheduler(self, optimizer: Optimizer) -> torch.optim.lr_scheduler.ReduceLROnPlateau: 74 | return torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, **asdict_filtered(self)) 75 | 76 | 77 | valid_names = { 78 | "disabled": DisabledConfig, 79 | "cyclic": CyclicConfig, 80 | "plateau": PlateauConfig, 81 | "poly": PolyConfig, 82 | "step": StepConfig, 83 | } 84 | 85 | 86 | def validate_config_group(cfg_subgroup: DictConfig) -> SchedulerConf: 87 | validated_dataclass = validate_config_group_generic( 88 | cfg_subgroup, dataclass_dict=valid_names, config_category="scheduler" 89 | ) 90 | return validated_dataclass 91 | -------------------------------------------------------------------------------- /seg_lapa/config_parse/train_conf.py: -------------------------------------------------------------------------------- 1 | from typing import Optional 2 | 3 | from omegaconf import DictConfig 4 | from pydantic.dataclasses import dataclass 5 | 6 | from seg_lapa.config_parse import ( 7 | callbacks_conf, 8 | dataset_conf, 9 | load_weights_conf, 10 | logger_conf, 11 | model_conf, 12 | optimizer_conf, 13 | scheduler_conf, 14 | trainer_conf, 15 | ) 16 | from seg_lapa.config_parse.callbacks_conf import CallbacksConf 17 | from seg_lapa.config_parse.dataset_conf import DatasetConf 18 | from seg_lapa.config_parse.load_weights_conf import 
LoadWeightsConf 19 | from seg_lapa.config_parse.logger_conf import LoggerConf 20 | from seg_lapa.config_parse.model_conf import ModelConf 21 | from seg_lapa.config_parse.optimizer_conf import OptimConf 22 | from seg_lapa.config_parse.scheduler_conf import SchedulerConf 23 | from seg_lapa.config_parse.trainer_conf import TrainerConf 24 | 25 | 26 | @dataclass(frozen=True) 27 | class TrainConf: 28 | random_seed: Optional[int] 29 | logs_root_dir: str 30 | dataset: DatasetConf 31 | optimizer: OptimConf 32 | model: ModelConf 33 | trainer: TrainerConf 34 | scheduler: SchedulerConf 35 | logger: LoggerConf 36 | callbacks: CallbacksConf 37 | load_weights: LoadWeightsConf 38 | 39 | 40 | class ParseConfig: 41 | @classmethod 42 | def parse_config(cls, cfg: DictConfig) -> TrainConf: 43 | """Parses the config read from hydra to populate the TrainConf dataclass""" 44 | config = TrainConf( 45 | random_seed=cfg.random_seed, 46 | logs_root_dir=cfg.logs_root_dir, 47 | dataset=dataset_conf.validate_config_group(cfg.dataset), 48 | model=model_conf.validate_config_group(cfg.model), 49 | optimizer=optimizer_conf.validate_config_group(cfg.optimizer), 50 | trainer=trainer_conf.validate_config_group(cfg.trainer), 51 | scheduler=scheduler_conf.validate_config_group(cfg.scheduler), 52 | logger=logger_conf.validate_config_group(cfg.logger), 53 | callbacks=callbacks_conf.validate_config_group(cfg.callbacks), 54 | load_weights=load_weights_conf.validate_config_group(cfg.load_weights), 55 | ) 56 | 57 | return config 58 | -------------------------------------------------------------------------------- /seg_lapa/config_parse/trainer_conf.py: -------------------------------------------------------------------------------- 1 | from abc import ABC, abstractmethod 2 | from typing import List, Optional 3 | 4 | import pytorch_lightning as pl 5 | from omegaconf import DictConfig 6 | from pydantic.dataclasses import dataclass 7 | from pytorch_lightning.callbacks import Callback 8 | from pytorch_lightning.loggers.base import LightningLoggerBase 9 | 10 | from seg_lapa.config_parse.conf_utils import asdict_filtered, validate_config_group_generic 11 | 12 | 13 | @dataclass(frozen=True) 14 | class TrainerConf(ABC): 15 | name: str 16 | 17 | @abstractmethod 18 | def get_trainer( 19 | self, pl_logger: LightningLoggerBase, callbacks: List[Callback], default_root_dir: str 20 | ) -> pl.Trainer: 21 | pass 22 | 23 | 24 | @dataclass(frozen=True) 25 | class TrainerConfig(TrainerConf): 26 | gpus: int 27 | accelerator: Optional[str] 28 | precision: int 29 | max_epochs: int 30 | resume_from_checkpoint: Optional[str] 31 | log_every_n_steps: int 32 | 33 | benchmark: bool = False 34 | deterministic: bool = False 35 | fast_dev_run: bool = False 36 | overfit_batches: float = 0.0 37 | limit_train_batches: float = 1.0 38 | limit_val_batches: float = 1.0 39 | limit_test_batches: float = 1.0 40 | 41 | def get_trainer( 42 | self, pl_logger: LightningLoggerBase, callbacks: List[Callback], default_root_dir: str 43 | ) -> pl.Trainer: 44 | trainer = pl.Trainer( 45 | logger=pl_logger, 46 | callbacks=callbacks, 47 | default_root_dir=default_root_dir, 48 | **asdict_filtered(self), 49 | ) 50 | return trainer 51 | 52 | 53 | valid_names = {"trainer": TrainerConfig} 54 | 55 | 56 | def validate_config_group(cfg_subgroup: DictConfig) -> TrainerConf: 57 | validated_dataclass = validate_config_group_generic( 58 | cfg_subgroup, dataclass_dict=valid_names, config_category="trainer" 59 | ) 60 | return validated_dataclass 61 | 
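A minimal sketch of how these config_parse pieces fit together, using hypothetical values (the real nodes live in the yaml files under seg_lapa/config): the 'name' entry of a hydra config group selects a pydantic dataclass from the module's `valid_names` dict, pydantic coerces and validates the remaining fields, and the frozen dataclass then constructs the actual object, with `asdict_filtered` stripping the 'name' key before the constructor call.

from omegaconf import OmegaConf

from seg_lapa.config_parse import optimizer_conf

# Hypothetical node mirroring config/optimizer/adam.yaml (values are made up for illustration)
cfg_opt = OmegaConf.create({"name": "adam", "lr": "1e-4", "weight_decay": 0})

# 'name' selects AdamConf from optimizer_conf.valid_names; pydantic coerces "1e-4" to float
optim_conf = optimizer_conf.validate_config_group(cfg_opt)
assert optim_conf.lr == 1e-4

# The validated dataclass builds the torch optimizer ('name' is removed by asdict_filtered):
# optimizer = optim_conf.get_optimizer(model.parameters())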
-------------------------------------------------------------------------------- /seg_lapa/datasets/lapa.py: -------------------------------------------------------------------------------- 1 | import enum 2 | import random 3 | from pathlib import Path 4 | from typing import List, Optional, Tuple, Union 5 | 6 | import albumentations as A 7 | import cv2 8 | import numpy as np 9 | import pytorch_lightning as pl 10 | import torch 11 | import torchvision 12 | from torch.utils.data import DataLoader, Dataset 13 | 14 | from seg_lapa.utils.path_check import get_path, PathType 15 | 16 | 17 | class DatasetSplit(enum.Enum): 18 | TRAIN = 0 19 | VAL = 1 20 | TEST = 2 21 | 22 | 23 | class LapaDataset(Dataset): 24 | """The Landmark guided face Parsing dataset (LaPa) 25 | Contains pixel-level annotations for face parsing. 26 | 27 | References: 28 | https://github.com/JDAI-CV/lapa-dataset 29 | """ 30 | 31 | @enum.unique 32 | class LapaClassId(enum.IntEnum): 33 | # Mapping of the classes within the lapa dataset 34 | BACKGROUND = 0 35 | SKIN = 1 36 | EYEBROW_LEFT = 2 37 | EYEBROW_RIGHT = 3 38 | EYE_LEFT = 4 39 | EYE_RIGHT = 5 40 | NOSE = 6 41 | LIP_UPPER = 7 42 | INNER_MOUTH = 8 43 | LIP_LOWER = 9 44 | HAIR = 10 45 | 46 | SUBDIR_IMAGES = "images" 47 | SUBDIR_LABELS = "labels" 48 | 49 | SUBDIR_SPLIT = {DatasetSplit.TRAIN: "train", DatasetSplit.VAL: "val", DatasetSplit.TEST: "test"} 50 | 51 | def __init__( 52 | self, 53 | root_dir: Union[str, Path], 54 | data_split: DatasetSplit, 55 | image_ext: Tuple[str] = ("*.jpg",), 56 | label_ext: Tuple[str] = ("*.png",), 57 | augmentations: Optional[A.Compose] = None, 58 | ): 59 | super().__init__() 60 | self.augmentations = augmentations 61 | self.image_ext = image_ext # The file extensions of input images to search for in input dir 62 | self.label_ext = label_ext # The file extensions of labels to search for in label dir 63 | self.root_dir = self._check_dir(root_dir) 64 | 65 | # Get subdirs for images and labels 66 | self.images_dir = self._check_dir(self.root_dir / self.SUBDIR_SPLIT[data_split] / self.SUBDIR_IMAGES) 67 | self.labels_dir = self._check_dir(self.root_dir / self.SUBDIR_SPLIT[data_split] / self.SUBDIR_LABELS) 68 | 69 | # Create list of filenames 70 | self._datalist_input = [] # List of all input image filenames in the dataset 71 | self._datalist_label = [] # List of all ground truth label filenames in the dataset 72 | self._create_lists_filenames() 73 | 74 | def __len__(self): 75 | return len(self._datalist_input) 76 | 77 | def __getitem__(self, index): 78 | # Read input RGB image 79 | image_path = self._datalist_input[index] 80 | img = self._read_image(image_path) 81 | 82 | # Read ground truth labels 83 | label_path = self._datalist_label[index] 84 | label = self._read_label(label_path) 85 | 86 | # Apply image augmentations 87 | if self.augmentations is not None: 88 | augmented = self.augmentations(image=img, mask=label) 89 | img = augmented["image"] 90 | label = augmented["mask"] 91 | 92 | # Convert to Tensor. RGB images are normally numpy uint8 arrays with shape (H, W, 3). 
93 | # RGB tensors should be (3, H, W) with dtype float32 in range [0, 1] (may change with normalization applied) 94 | img_tensor = torchvision.transforms.ToTensor()(img) 95 | label_tensor = torch.from_numpy(label) 96 | 97 | # TODO: Return dict 98 | # data = { 99 | # 'image': img_tensor, 100 | # 'label': label_tensor 101 | # } 102 | # return data 103 | 104 | return img_tensor, label_tensor.long() 105 | 106 | @staticmethod 107 | def _check_dir(dir_path: Union[str, Path]) -> Path: 108 | return get_path(dir_path, must_exist=True, path_type=PathType.DIR) 109 | 110 | @staticmethod 111 | def _read_label(label_path: Path) -> np.ndarray: 112 | mask = cv2.imread(str(label_path), cv2.IMREAD_GRAYSCALE | cv2.IMREAD_ANYDEPTH | cv2.IMREAD_ANYCOLOR) 113 | 114 | if len(mask.shape) != 2: 115 | raise RuntimeError(f"The shape of label must be (H, W). Got: {mask.shape}") 116 | 117 | return mask.astype(np.int32) 118 | 119 | @staticmethod 120 | def _read_image(image_path: Path) -> np.ndarray: 121 | img = cv2.imread(str(image_path), cv2.IMREAD_COLOR) 122 | img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) 123 | 124 | if len(img.shape) != 3: 125 | raise RuntimeError(f"The shape of image must be (H, W, C). Got: {img.shape}") 126 | 127 | return img 128 | 129 | def _create_lists_filenames(self): 130 | """Creates a list of filenames of images and labels in dataset""" 131 | self._datalist_input = self._get_matching_files_in_dir(self.images_dir, self.image_ext) 132 | self._datalist_label = self._get_matching_files_in_dir(self.labels_dir, self.label_ext) 133 | 134 | num_images = len(self._datalist_input) 135 | num_labels = len(self._datalist_label) 136 | if num_images != num_labels: 137 | raise ValueError( 138 | f"The number of images ({num_images}) and labels ({num_labels}) do not match." 139 | f"\n Images dir: {self.images_dir}\n Labels dir: {self.labels_dir}" 140 | ) 141 | 142 | def _get_matching_files_in_dir(self, data_dir: Union[str, Path], wildcard_patterns: Tuple[str]) -> List[Path]: 143 | """Get filenames within a dir that match a set of wildcard patterns 144 | Will not search within subdirectories. 145 | 146 | Args: 147 | data_dir: Directory to search within 148 | wildcard_patterns: Tuple of wildcard patterns matching required filenames. Eg: ('*.rgb.png', '*.rgb.jpg') 149 | 150 | Returns: 151 | list[Path]: List of paths to files found 152 | """ 153 | data_dir = self._check_dir(data_dir) 154 | 155 | list_matching_files = [] 156 | for ext in wildcard_patterns: 157 | list_matching_files += sorted(data_dir.glob(ext)) 158 | 159 | if len(list_matching_files) == 0: 160 | raise ValueError( 161 | "No matching files found in given directory." 
162 | f"\n Directory: {data_dir}\n Search patterns: {wildcard_patterns}" 163 | ) 164 | 165 | return list_matching_files 166 | 167 | 168 | class LaPaDataModule(pl.LightningDataModule): 169 | def __init__(self, data_dir: str, batch_size: int, num_workers: int, resize_h: int, resize_w: int): 170 | super().__init__() 171 | self.batch_size = batch_size 172 | self.num_workers = num_workers 173 | self.data_dir = get_path(data_dir, must_exist=True, path_type=PathType.DIR) 174 | self.resize_h = resize_h 175 | self.resize_w = resize_w 176 | 177 | self.lapa_train = None 178 | self.lapa_val = None 179 | self.lapa_test = None 180 | 181 | def prepare_data(self): 182 | """Download dataset, tokenize, etc. 183 | 184 | Downloading original data from the author's google drive link: 185 | >>> import gdown 186 | >>> url = "https://drive.google.com/uc?export=download&id=1EtyCtiQZt2Y5qrb-0YxRxaVLpVcgCOQV" 187 | >>> output = "lapa-downloaded.tar.gz" 188 | >>> gdown.download(url, output, quiet=False, proxy=False) 189 | """ 190 | pass 191 | 192 | def setup(self, stage=None): 193 | # count number of classes, perform train/val/test splits, apply transforms, etc 194 | augs_train = self.get_augs_train() 195 | augs_test = self.get_augs_test() 196 | 197 | self.lapa_train = LapaDataset(root_dir=self.data_dir, data_split=DatasetSplit.TRAIN, augmentations=augs_train) 198 | self.lapa_val = LapaDataset(root_dir=self.data_dir, data_split=DatasetSplit.VAL, augmentations=augs_test) 199 | self.lapa_test = LapaDataset(root_dir=self.data_dir, data_split=DatasetSplit.TEST, augmentations=augs_test) 200 | 201 | def train_dataloader(self): 202 | train_loader = DataLoader( 203 | self.lapa_train, 204 | batch_size=self.batch_size, 205 | num_workers=self.num_workers, 206 | pin_memory=True, 207 | drop_last=True, 208 | worker_init_fn=self._dataloader_worker_init, 209 | ) 210 | return train_loader 211 | 212 | def val_dataloader(self): 213 | val_loader = DataLoader( 214 | self.lapa_val, 215 | batch_size=self.batch_size, 216 | num_workers=self.num_workers, 217 | pin_memory=True, 218 | drop_last=False, 219 | worker_init_fn=self._dataloader_worker_init, 220 | ) 221 | return val_loader 222 | 223 | def test_dataloader(self): 224 | test_loader = DataLoader( 225 | self.lapa_test, 226 | batch_size=self.batch_size, 227 | num_workers=self.num_workers, 228 | pin_memory=True, 229 | drop_last=False, 230 | worker_init_fn=self._dataloader_worker_init, 231 | ) 232 | return test_loader 233 | 234 | def get_augs_test(self): 235 | augs_test = A.Compose( 236 | [ 237 | # Geometric Augs 238 | A.SmallestMaxSize(max_size=self.resize_h, interpolation=0, p=1.0), 239 | A.CenterCrop(height=self.resize_h, width=self.resize_w, p=1.0), 240 | ] 241 | ) 242 | return augs_test 243 | 244 | def get_augs_train(self): 245 | augs_train = A.Compose( 246 | [ 247 | # Geometric Augs 248 | A.SmallestMaxSize(max_size=self.resize_h, interpolation=0, p=1.0), 249 | A.CenterCrop(height=self.resize_h, width=self.resize_w, p=1.0), 250 | ] 251 | ) 252 | return augs_train 253 | 254 | @staticmethod 255 | def _dataloader_worker_init(*args): 256 | """Seeds the workers within the Dataloader""" 257 | worker_seed = torch.initial_seed() % 2 ** 32 258 | np.random.seed(worker_seed) 259 | random.seed(worker_seed) 260 | -------------------------------------------------------------------------------- /seg_lapa/loss_func.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | 6 | class 
CrossEntropy2D(nn.CrossEntropyLoss): 7 | """Use the torch.nn.CrossEntropyLoss loss to calculate mean loss per image or per pixel. 8 | Deeplab models calculate mean loss per image. 9 | 10 | Inputs: 11 | - inputs (Tensor): Raw output of network (without softmax applied). 12 | - targets (Tensor): Ground truth, containing an int class index in the range :math:`[0, C-1]` as the 13 | `target` for each pixel. 14 | 15 | Shape and dtype: 16 | - Input: [B, C, H, W], where C = num_classes. dtype=float16/32 17 | - Target: [B, H, W]. dtype=int32/64 18 | - Output: scalar. 19 | 20 | Args: 21 | loss_per_image (bool, optional): If True, returns the mean loss per image (loss summed over pixels, divided by batch size). If False, returns the mean loss per pixel. 22 | ignore_index (int, optional): Defaults to 255. Pixels with this label do not contribute to the loss 23 | 24 | References: 25 | https://pytorch.org/docs/stable/generated/torch.nn.CrossEntropyLoss.html#crossentropyloss 26 | """ 27 | 28 | def __init__(self, loss_per_image: bool = True, ignore_index: int = 255): 29 | if loss_per_image: 30 | reduction = "sum" 31 | else: 32 | reduction = "mean" 33 | 34 | super().__init__(reduction=reduction, ignore_index=ignore_index) 35 | self.loss_per_image = loss_per_image 36 | 37 | def forward(self, inputs: torch.Tensor, targets: torch.Tensor): 38 | loss = F.cross_entropy( 39 | inputs, targets, weight=self.weight, ignore_index=self.ignore_index, reduction=self.reduction 40 | ) 41 | 42 | if self.loss_per_image: 43 | batch_size = inputs.shape[0] 44 | loss = loss / batch_size 45 | 46 | return loss 47 | -------------------------------------------------------------------------------- /seg_lapa/metrics.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | 3 | import torch 4 | from pytorch_lightning import metrics 5 | 6 | 7 | @dataclass 8 | class IouMetric: 9 | iou_per_class: torch.Tensor 10 | miou: torch.Tensor # Mean IoU across all classes 11 | accuracy: torch.Tensor 12 | precision: torch.Tensor 13 | recall: torch.Tensor 14 | specificity: torch.Tensor 15 | 16 | 17 | class Iou(metrics.Metric): 18 | def __init__(self, num_classes: int = 11, normalize: bool = False): 19 | """Calculates the IoU metric, along with accuracy, precision, recall and specificity, for multi-class 20 | classification problems such as semantic segmentation. 21 | Because this is an expensive operation, we do not compute or sync the values per step. 22 | 23 | Forward accepts: 24 | 25 | - ``prediction`` (float or long tensor): ``(N, H, W)`` 26 | - ``label`` (long tensor): ``(N, H, W)`` 27 | 28 | Note: 29 | This metric produces a dataclass as output, so it cannot be directly logged. 30 | """ 31 | super().__init__(compute_on_step=False, dist_sync_on_step=False) 32 | 33 | self.num_classes = num_classes 34 | # The metric is normally accumulated over batches. If True, the confusion matrix is normalized by the sample count, so final metrics (tp, fn, etc) reflect average values per image 35 | self.normalize = normalize 36 | 37 | self.acc_confusion_matrix = None # The accumulated confusion matrix 38 | self.count_samples = None # Number of samples seen 39 | # Use `add_state()` so these attributes are tracked and synchronized across processes 40 | self.add_state( 41 | "acc_confusion_matrix", default=torch.zeros((self.num_classes, self.num_classes)), dist_reduce_fx="sum" 42 | ) 43 | self.add_state("count_samples", default=torch.tensor(0), dist_reduce_fx="sum") 44 | 45 | def update(self, prediction: torch.Tensor, label: torch.Tensor): 46 | """Calculate the confusion matrix and accumulate it 47 | 48 | Args: 49 | prediction: Predictions of network (after argmax). 
Shape: [N, H, W] 50 | label: Ground truth. Each pixel has int value denoting class. Shape: [N, H, W] 51 | """ 52 | assert prediction.shape == label.shape 53 | assert len(label.shape) == 3 54 | 55 | num_images = int(label.shape[0]) 56 | 57 | label = label.view(-1).long() 58 | prediction = prediction.view(-1).long() 59 | 60 | # Calculate confusion matrix 61 | conf_mat = torch.bincount(self.num_classes * label + prediction, minlength=self.num_classes ** 2) 62 | conf_mat = conf_mat.reshape((self.num_classes, self.num_classes)) 63 | 64 | # Accumulate values 65 | self.acc_confusion_matrix += conf_mat 66 | self.count_samples += num_images 67 | 68 | def compute(self): 69 | """Compute the final IoU and other metrics across all samples seen""" 70 | # Normalize the accumulated confusion matrix, if needed 71 | conf_mat = self.acc_confusion_matrix 72 | if self.normalize: 73 | conf_mat = conf_mat / self.count_samples # Get average per image 74 | 75 | # Calculate True Positive (TP), False Positive (FP), False Negative (FN) and True Negative (TN) 76 | tp = conf_mat.diagonal() 77 | fn = conf_mat.sum(dim=1) - tp # Row sums are the actual totals per class (rows = labels) 78 | fp = conf_mat.sum(dim=0) - tp # Column sums are the predicted totals per class (columns = predictions) 79 | total_px = conf_mat.sum() 80 | tn = total_px - (tp + fn + fp) 81 | 82 | # Calculate Intersection over Union (IoU) 83 | eps = 1e-6 84 | iou_per_class = (tp + eps) / (fn + fp + tp + eps) # Use epsilon to avoid zero division errors 85 | iou_per_class[torch.isnan(iou_per_class)] = 0 86 | mean_iou = iou_per_class.mean() 87 | 88 | # Accuracy (what proportion of predictions — both Positive and Negative — were correctly classified?) 89 | accuracy = (tp + tn) / (tp + fp + fn + tn) 90 | accuracy[torch.isnan(accuracy)] = 0 91 | 92 | # Precision (what proportion of predicted Positives is truly Positive?) 93 | precision = tp / (tp + fp) 94 | precision[torch.isnan(precision)] = 0 95 | 96 | # Recall or True Positive Rate (what proportion of actual Positives is correctly classified?) 
97 | recall = tp / (tp + fn) 98 | recall[torch.isnan(recall)] = 0 99 | 100 | # Specificity or true negative rate 101 | specificity = tn / (tn + fp) 102 | specificity[torch.isnan(specificity)] = 0 103 | 104 | data_r = IouMetric( 105 | iou_per_class=iou_per_class, 106 | miou=mean_iou, 107 | accuracy=accuracy, 108 | precision=precision, 109 | recall=recall, 110 | specificity=specificity, 111 | ) 112 | 113 | return data_r 114 | 115 | 116 | # Tests 117 | def test_iou(): 118 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 119 | print(f"Using device: {device}") 120 | 121 | # Create Fake label and prediction 122 | label = torch.zeros((1, 4, 4), dtype=torch.float32, device=device) 123 | pred = torch.zeros((1, 4, 4), dtype=torch.float32, device=device) 124 | label[:, :3, :3] = 1 125 | pred[:, -3:, -3:] = 1 126 | expected_iou = torch.tensor([2.0 / 12, 4.0 / 14], device=device) 127 | 128 | print("Testing IoU metrics", end="") 129 | iou_train = Iou(num_classes=2) 130 | iou_train.to(device) 131 | iou_train(pred, label) 132 | metrics_r = iou_train.compute() 133 | iou_per_class = metrics_r.iou_per_class 134 | assert (iou_per_class - expected_iou).abs().sum() < 1e-6 135 | print(" passed") 136 | 137 | 138 | if __name__ == "__main__": 139 | # Run tests 140 | print("Running tests on metrics module...\n") 141 | test_iou() 142 | -------------------------------------------------------------------------------- /seg_lapa/networks/deeplab/aspp.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | from .sync_batchnorm.batchnorm import SynchronizedBatchNorm2d 6 | 7 | 8 | class _ASPPModule(nn.Module): 9 | def __init__(self, inplanes, planes, kernel_size, padding, dilation, BatchNorm): 10 | super(_ASPPModule, self).__init__() 11 | self.atrous_conv = nn.Conv2d( 12 | inplanes, planes, kernel_size=kernel_size, stride=1, padding=padding, dilation=dilation, bias=False 13 | ) 14 | self.bn = BatchNorm(planes) 15 | self.relu = nn.ReLU() 16 | 17 | self._init_weight() 18 | 19 | def forward(self, x): 20 | x = self.atrous_conv(x) 21 | x = self.bn(x) 22 | 23 | return self.relu(x) 24 | 25 | def _init_weight(self): 26 | for m in self.modules(): 27 | if isinstance(m, nn.Conv2d): 28 | torch.nn.init.kaiming_normal_(m.weight) 29 | elif isinstance(m, SynchronizedBatchNorm2d): 30 | m.weight.data.fill_(1) 31 | m.bias.data.zero_() 32 | elif isinstance(m, nn.BatchNorm2d): 33 | m.weight.data.fill_(1) 34 | m.bias.data.zero_() 35 | 36 | 37 | class ASPP(nn.Module): 38 | def __init__(self, backbone, output_stride, BatchNorm): 39 | super(ASPP, self).__init__() 40 | if backbone == "drn": 41 | inplanes = 512 42 | elif backbone == "mobilenet": 43 | inplanes = 320 44 | else: 45 | inplanes = 2048 46 | if output_stride == 16: 47 | dilations = [1, 6, 12, 18] 48 | elif output_stride == 8: 49 | dilations = [1, 12, 24, 36] 50 | else: 51 | raise NotImplementedError 52 | 53 | self.aspp1 = _ASPPModule(inplanes, 256, 1, padding=0, dilation=dilations[0], BatchNorm=BatchNorm) 54 | self.aspp2 = _ASPPModule(inplanes, 256, 3, padding=dilations[1], dilation=dilations[1], BatchNorm=BatchNorm) 55 | self.aspp3 = _ASPPModule(inplanes, 256, 3, padding=dilations[2], dilation=dilations[2], BatchNorm=BatchNorm) 56 | self.aspp4 = _ASPPModule(inplanes, 256, 3, padding=dilations[3], dilation=dilations[3], BatchNorm=BatchNorm) 57 | 58 | self.global_avg_pool = nn.Sequential( 59 | nn.AdaptiveAvgPool2d((1, 1)), nn.Conv2d(inplanes, 256, 1, stride=1, 
bias=False), BatchNorm(256), nn.ReLU() 60 | ) 61 | self.conv1 = nn.Conv2d(1280, 256, 1, bias=False) 62 | self.bn1 = BatchNorm(256) 63 | self.relu = nn.ReLU() 64 | self.dropout = nn.Dropout(0.5) 65 | self._init_weight() 66 | 67 | def forward(self, x): 68 | x1 = self.aspp1(x) 69 | x2 = self.aspp2(x) 70 | x3 = self.aspp3(x) 71 | x4 = self.aspp4(x) 72 | x5 = self.global_avg_pool(x) 73 | x5 = F.interpolate(x5, size=x4.size()[2:], mode="bilinear", align_corners=True) 74 | x = torch.cat((x1, x2, x3, x4, x5), dim=1) 75 | 76 | x = self.conv1(x) 77 | x = self.bn1(x) 78 | x = self.relu(x) 79 | 80 | return self.dropout(x) 81 | 82 | def _init_weight(self): 83 | for m in self.modules(): 84 | if isinstance(m, nn.Conv2d): 85 | torch.nn.init.kaiming_normal_(m.weight) 86 | elif isinstance(m, SynchronizedBatchNorm2d): 87 | m.weight.data.fill_(1) 88 | m.bias.data.zero_() 89 | elif isinstance(m, nn.BatchNorm2d): 90 | m.weight.data.fill_(1) 91 | m.bias.data.zero_() 92 | 93 | 94 | def build_aspp(backbone, output_stride, BatchNorm): 95 | return ASPP(backbone, output_stride, BatchNorm) 96 | -------------------------------------------------------------------------------- /seg_lapa/networks/deeplab/backbone/__init__.py: -------------------------------------------------------------------------------- 1 | from . import drn, mobilenet, resnet, xception 2 | 3 | 4 | def build_backbone(backbone, output_stride, BatchNorm): 5 | if backbone == "resnet": 6 | return resnet.ResNet101(output_stride, BatchNorm) 7 | elif backbone == "xception": 8 | return xception.AlignedXception(output_stride, BatchNorm) 9 | elif backbone == "drn": 10 | return drn.drn_d_54(BatchNorm) 11 | elif backbone == "mobilenet": 12 | return mobilenet.MobileNetV2(output_stride, BatchNorm) 13 | else: 14 | raise NotImplementedError 15 | -------------------------------------------------------------------------------- /seg_lapa/networks/deeplab/backbone/drn.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | import torch.nn as nn 4 | import torch.utils.model_zoo as model_zoo 5 | 6 | from ..sync_batchnorm.batchnorm import SynchronizedBatchNorm2d 7 | 8 | webroot = "http://dl.yf.io/drn/" 9 | 10 | model_urls = { 11 | "resnet50": "https://download.pytorch.org/models/resnet50-19c8e357.pth", 12 | "drn-c-26": webroot + "drn_c_26-ddedf421.pth", 13 | "drn-c-42": webroot + "drn_c_42-9d336e8c.pth", 14 | "drn-c-58": webroot + "drn_c_58-0a53a92c.pth", 15 | "drn-d-22": webroot + "drn_d_22-4bd2f8ea.pth", 16 | "drn-d-38": webroot + "drn_d_38-eebb45f0.pth", 17 | "drn-d-54": webroot + "drn_d_54-0e0534ff.pth", 18 | "drn-d-105": webroot + "drn_d_105-12b40979.pth", 19 | } 20 | 21 | 22 | def conv3x3(in_planes, out_planes, stride=1, padding=1, dilation=1): 23 | return nn.Conv2d( 24 | in_planes, out_planes, kernel_size=3, stride=stride, padding=padding, bias=False, dilation=dilation 25 | ) 26 | 27 | 28 | class BasicBlock(nn.Module): 29 | expansion = 1 30 | 31 | def __init__(self, inplanes, planes, stride=1, downsample=None, dilation=(1, 1), residual=True, BatchNorm=None): 32 | super(BasicBlock, self).__init__() 33 | self.conv1 = conv3x3(inplanes, planes, stride, padding=dilation[0], dilation=dilation[0]) 34 | self.bn1 = BatchNorm(planes) 35 | self.relu = nn.ReLU(inplace=True) 36 | self.conv2 = conv3x3(planes, planes, padding=dilation[1], dilation=dilation[1]) 37 | self.bn2 = BatchNorm(planes) 38 | self.downsample = downsample 39 | self.stride = stride 40 | self.residual = residual 41 | 42 | def forward(self, x): 43 | residual = x 
44 | 45 | out = self.conv1(x) 46 | out = self.bn1(out) 47 | out = self.relu(out) 48 | 49 | out = self.conv2(out) 50 | out = self.bn2(out) 51 | 52 | if self.downsample is not None: 53 | residual = self.downsample(x) 54 | if self.residual: 55 | out += residual 56 | out = self.relu(out) 57 | 58 | return out 59 | 60 | 61 | class Bottleneck(nn.Module): 62 | expansion = 4 63 | 64 | def __init__(self, inplanes, planes, stride=1, downsample=None, dilation=(1, 1), residual=True, BatchNorm=None): 65 | super(Bottleneck, self).__init__() 66 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False) 67 | self.bn1 = BatchNorm(planes) 68 | self.conv2 = nn.Conv2d( 69 | planes, planes, kernel_size=3, stride=stride, padding=dilation[1], bias=False, dilation=dilation[1] 70 | ) 71 | self.bn2 = BatchNorm(planes) 72 | self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False) 73 | self.bn3 = BatchNorm(planes * 4) 74 | self.relu = nn.ReLU(inplace=True) 75 | self.downsample = downsample 76 | self.stride = stride 77 | 78 | def forward(self, x): 79 | residual = x 80 | 81 | out = self.conv1(x) 82 | out = self.bn1(out) 83 | out = self.relu(out) 84 | 85 | out = self.conv2(out) 86 | out = self.bn2(out) 87 | out = self.relu(out) 88 | 89 | out = self.conv3(out) 90 | out = self.bn3(out) 91 | 92 | if self.downsample is not None: 93 | residual = self.downsample(x) 94 | 95 | out += residual 96 | out = self.relu(out) 97 | 98 | return out 99 | 100 | 101 | class DRN(nn.Module): 102 | def __init__(self, block, layers, arch="D", channels=(16, 32, 64, 128, 256, 512, 512, 512), BatchNorm=None): 103 | super(DRN, self).__init__() 104 | self.inplanes = channels[0] 105 | self.out_dim = channels[-1] 106 | self.arch = arch 107 | 108 | if arch == "C": 109 | self.conv1 = nn.Conv2d(3, channels[0], kernel_size=7, stride=1, padding=3, bias=False) 110 | self.bn1 = BatchNorm(channels[0]) 111 | self.relu = nn.ReLU(inplace=True) 112 | 113 | self.layer1 = self._make_layer(BasicBlock, channels[0], layers[0], stride=1, BatchNorm=BatchNorm) 114 | self.layer2 = self._make_layer(BasicBlock, channels[1], layers[1], stride=2, BatchNorm=BatchNorm) 115 | 116 | elif arch == "D": 117 | self.layer0 = nn.Sequential( 118 | nn.Conv2d(3, channels[0], kernel_size=7, stride=1, padding=3, bias=False), 119 | BatchNorm(channels[0]), 120 | nn.ReLU(inplace=True), 121 | ) 122 | 123 | self.layer1 = self._make_conv_layers(channels[0], layers[0], stride=1, BatchNorm=BatchNorm) 124 | self.layer2 = self._make_conv_layers(channels[1], layers[1], stride=2, BatchNorm=BatchNorm) 125 | 126 | self.layer3 = self._make_layer(block, channels[2], layers[2], stride=2, BatchNorm=BatchNorm) 127 | self.layer4 = self._make_layer(block, channels[3], layers[3], stride=2, BatchNorm=BatchNorm) 128 | self.layer5 = self._make_layer(block, channels[4], layers[4], dilation=2, new_level=False, BatchNorm=BatchNorm) 129 | self.layer6 = ( 130 | None 131 | if layers[5] == 0 132 | else self._make_layer(block, channels[5], layers[5], dilation=4, new_level=False, BatchNorm=BatchNorm) 133 | ) 134 | 135 | if arch == "C": 136 | self.layer7 = ( 137 | None 138 | if layers[6] == 0 139 | else self._make_layer( 140 | BasicBlock, channels[6], layers[6], dilation=2, new_level=False, residual=False, BatchNorm=BatchNorm 141 | ) 142 | ) 143 | self.layer8 = ( 144 | None 145 | if layers[7] == 0 146 | else self._make_layer( 147 | BasicBlock, channels[7], layers[7], dilation=1, new_level=False, residual=False, BatchNorm=BatchNorm 148 | ) 149 | ) 150 | elif arch == "D": 151 | self.layer7 = ( 152 | None 
153 | if layers[6] == 0 154 | else self._make_conv_layers(channels[6], layers[6], dilation=2, BatchNorm=BatchNorm) 155 | ) 156 | self.layer8 = ( 157 | None 158 | if layers[7] == 0 159 | else self._make_conv_layers(channels[7], layers[7], dilation=1, BatchNorm=BatchNorm) 160 | ) 161 | 162 | self._init_weight() 163 | 164 | def _init_weight(self): 165 | for m in self.modules(): 166 | if isinstance(m, nn.Conv2d): 167 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 168 | m.weight.data.normal_(0, math.sqrt(2.0 / n)) 169 | elif isinstance(m, SynchronizedBatchNorm2d): 170 | m.weight.data.fill_(1) 171 | m.bias.data.zero_() 172 | elif isinstance(m, nn.BatchNorm2d): 173 | m.weight.data.fill_(1) 174 | m.bias.data.zero_() 175 | 176 | def _make_layer(self, block, planes, blocks, stride=1, dilation=1, new_level=True, residual=True, BatchNorm=None): 177 | assert dilation == 1 or dilation % 2 == 0 178 | downsample = None 179 | if stride != 1 or self.inplanes != planes * block.expansion: 180 | downsample = nn.Sequential( 181 | nn.Conv2d(self.inplanes, planes * block.expansion, kernel_size=1, stride=stride, bias=False), 182 | BatchNorm(planes * block.expansion), 183 | ) 184 | 185 | layers = list() 186 | layers.append( 187 | block( 188 | self.inplanes, 189 | planes, 190 | stride, 191 | downsample, 192 | dilation=(1, 1) if dilation == 1 else (dilation // 2 if new_level else dilation, dilation), 193 | residual=residual, 194 | BatchNorm=BatchNorm, 195 | ) 196 | ) 197 | self.inplanes = planes * block.expansion 198 | for i in range(1, blocks): 199 | layers.append( 200 | block(self.inplanes, planes, residual=residual, dilation=(dilation, dilation), BatchNorm=BatchNorm) 201 | ) 202 | 203 | return nn.Sequential(*layers) 204 | 205 | def _make_conv_layers(self, channels, convs, stride=1, dilation=1, BatchNorm=None): 206 | modules = [] 207 | for i in range(convs): 208 | modules.extend( 209 | [ 210 | nn.Conv2d( 211 | self.inplanes, 212 | channels, 213 | kernel_size=3, 214 | stride=stride if i == 0 else 1, 215 | padding=dilation, 216 | bias=False, 217 | dilation=dilation, 218 | ), 219 | BatchNorm(channels), 220 | nn.ReLU(inplace=True), 221 | ] 222 | ) 223 | self.inplanes = channels 224 | return nn.Sequential(*modules) 225 | 226 | def forward(self, x): 227 | if self.arch == "C": 228 | x = self.conv1(x) 229 | x = self.bn1(x) 230 | x = self.relu(x) 231 | elif self.arch == "D": 232 | x = self.layer0(x) 233 | 234 | x = self.layer1(x) 235 | x = self.layer2(x) 236 | 237 | x = self.layer3(x) 238 | low_level_feat = x 239 | 240 | x = self.layer4(x) 241 | x = self.layer5(x) 242 | 243 | if self.layer6 is not None: 244 | x = self.layer6(x) 245 | 246 | if self.layer7 is not None: 247 | x = self.layer7(x) 248 | 249 | if self.layer8 is not None: 250 | x = self.layer8(x) 251 | 252 | return x, low_level_feat 253 | 254 | 255 | class DRN_A(nn.Module): 256 | def __init__(self, block, layers, BatchNorm=None): 257 | self.inplanes = 64 258 | super(DRN_A, self).__init__() 259 | self.out_dim = 512 * block.expansion 260 | self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False) 261 | self.bn1 = BatchNorm(64) 262 | self.relu = nn.ReLU(inplace=True) 263 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 264 | self.layer1 = self._make_layer(block, 64, layers[0], BatchNorm=BatchNorm) 265 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2, BatchNorm=BatchNorm) 266 | self.layer3 = self._make_layer(block, 256, layers[2], stride=1, dilation=2, BatchNorm=BatchNorm) 267 | self.layer4 = 
self._make_layer(block, 512, layers[3], stride=1, dilation=4, BatchNorm=BatchNorm) 268 | 269 | self._init_weight() 270 | 271 | def _init_weight(self): 272 | for m in self.modules(): 273 | if isinstance(m, nn.Conv2d): 274 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 275 | m.weight.data.normal_(0, math.sqrt(2.0 / n)) 276 | elif isinstance(m, SynchronizedBatchNorm2d): 277 | m.weight.data.fill_(1) 278 | m.bias.data.zero_() 279 | elif isinstance(m, nn.BatchNorm2d): 280 | m.weight.data.fill_(1) 281 | m.bias.data.zero_() 282 | 283 | def _make_layer(self, block, planes, blocks, stride=1, dilation=1, BatchNorm=None): 284 | downsample = None 285 | if stride != 1 or self.inplanes != planes * block.expansion: 286 | downsample = nn.Sequential( 287 | nn.Conv2d(self.inplanes, planes * block.expansion, kernel_size=1, stride=stride, bias=False), 288 | BatchNorm(planes * block.expansion), 289 | ) 290 | 291 | layers = [] 292 | layers.append(block(self.inplanes, planes, stride, downsample, BatchNorm=BatchNorm)) 293 | self.inplanes = planes * block.expansion 294 | for i in range(1, blocks): 295 | layers.append( 296 | block( 297 | self.inplanes, 298 | planes, 299 | dilation=( 300 | dilation, 301 | dilation, 302 | ), 303 | BatchNorm=BatchNorm, 304 | ) 305 | ) 306 | 307 | return nn.Sequential(*layers) 308 | 309 | def forward(self, x): 310 | x = self.conv1(x) 311 | x = self.bn1(x) 312 | x = self.relu(x) 313 | x = self.maxpool(x) 314 | 315 | x = self.layer1(x) 316 | x = self.layer2(x) 317 | x = self.layer3(x) 318 | x = self.layer4(x) 319 | 320 | return x 321 | 322 | 323 | def drn_a_50(BatchNorm, pretrained=True): 324 | model = DRN_A(Bottleneck, [3, 4, 6, 3], BatchNorm=BatchNorm) 325 | if pretrained: 326 | model.load_state_dict(model_zoo.load_url(model_urls["resnet50"])) 327 | return model 328 | 329 | 330 | def drn_c_26(BatchNorm, pretrained=True): 331 | model = DRN(BasicBlock, [1, 1, 2, 2, 2, 2, 1, 1], arch="C", BatchNorm=BatchNorm) 332 | if pretrained: 333 | pretrained = model_zoo.load_url(model_urls["drn-c-26"]) 334 | del pretrained["fc.weight"] 335 | del pretrained["fc.bias"] 336 | model.load_state_dict(pretrained) 337 | return model 338 | 339 | 340 | def drn_c_42(BatchNorm, pretrained=True): 341 | model = DRN(BasicBlock, [1, 1, 3, 4, 6, 3, 1, 1], arch="C", BatchNorm=BatchNorm) 342 | if pretrained: 343 | pretrained = model_zoo.load_url(model_urls["drn-c-42"]) 344 | del pretrained["fc.weight"] 345 | del pretrained["fc.bias"] 346 | model.load_state_dict(pretrained) 347 | return model 348 | 349 | 350 | def drn_c_58(BatchNorm, pretrained=True): 351 | model = DRN(Bottleneck, [1, 1, 3, 4, 6, 3, 1, 1], arch="C", BatchNorm=BatchNorm) 352 | if pretrained: 353 | pretrained = model_zoo.load_url(model_urls["drn-c-58"]) 354 | del pretrained["fc.weight"] 355 | del pretrained["fc.bias"] 356 | model.load_state_dict(pretrained) 357 | return model 358 | 359 | 360 | def drn_d_22(BatchNorm, pretrained=True): 361 | model = DRN(BasicBlock, [1, 1, 2, 2, 2, 2, 1, 1], arch="D", BatchNorm=BatchNorm) 362 | if pretrained: 363 | pretrained = model_zoo.load_url(model_urls["drn-d-22"]) 364 | del pretrained["fc.weight"] 365 | del pretrained["fc.bias"] 366 | model.load_state_dict(pretrained) 367 | return model 368 | 369 | 370 | def drn_d_24(BatchNorm, pretrained=True): 371 | model = DRN(BasicBlock, [1, 1, 2, 2, 2, 2, 2, 2], arch="D", BatchNorm=BatchNorm) 372 | if pretrained: 373 | pretrained = model_zoo.load_url(model_urls["drn-d-24"]) 374 | del pretrained["fc.weight"] 375 | del pretrained["fc.bias"] 376 | 
model.load_state_dict(pretrained) 377 | return model 378 | 379 | 380 | def drn_d_38(BatchNorm, pretrained=True): 381 | model = DRN(BasicBlock, [1, 1, 3, 4, 6, 3, 1, 1], arch="D", BatchNorm=BatchNorm) 382 | if pretrained: 383 | pretrained = model_zoo.load_url(model_urls["drn-d-38"]) 384 | del pretrained["fc.weight"] 385 | del pretrained["fc.bias"] 386 | model.load_state_dict(pretrained) 387 | return model 388 | 389 | 390 | def drn_d_40(BatchNorm, pretrained=True): 391 | model = DRN(BasicBlock, [1, 1, 3, 4, 6, 3, 2, 2], arch="D", BatchNorm=BatchNorm) 392 | if pretrained: 393 | pretrained = model_zoo.load_url(model_urls["drn-d-40"]) 394 | del pretrained["fc.weight"] 395 | del pretrained["fc.bias"] 396 | model.load_state_dict(pretrained) 397 | return model 398 | 399 | 400 | def drn_d_54(BatchNorm, pretrained=True): 401 | model = DRN(Bottleneck, [1, 1, 3, 4, 6, 3, 1, 1], arch="D", BatchNorm=BatchNorm) 402 | if pretrained: 403 | pretrained = model_zoo.load_url(model_urls["drn-d-54"]) 404 | del pretrained["fc.weight"] 405 | del pretrained["fc.bias"] 406 | model.load_state_dict(pretrained) 407 | return model 408 | 409 | 410 | def drn_d_105(BatchNorm, pretrained=True): 411 | model = DRN(Bottleneck, [1, 1, 3, 4, 23, 3, 1, 1], arch="D", BatchNorm=BatchNorm) 412 | if pretrained: 413 | pretrained = model_zoo.load_url(model_urls["drn-d-105"]) 414 | del pretrained["fc.weight"] 415 | del pretrained["fc.bias"] 416 | model.load_state_dict(pretrained) 417 | return model 418 | 419 | 420 | if __name__ == "__main__": 421 | import torch 422 | 423 | model = drn_a_50(BatchNorm=nn.BatchNorm2d, pretrained=True) 424 | input = torch.rand(1, 3, 512, 512) 425 | output, low_level_feat = model(input) 426 | print(output.size()) 427 | print(low_level_feat.size()) 428 | -------------------------------------------------------------------------------- /seg_lapa/networks/deeplab/backbone/mobilenet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import torch.utils.model_zoo as model_zoo 5 | 6 | from ..sync_batchnorm.batchnorm import SynchronizedBatchNorm2d 7 | 8 | 9 | def conv_bn(inp, oup, stride, BatchNorm): 10 | return nn.Sequential(nn.Conv2d(inp, oup, 3, stride, 1, bias=False), BatchNorm(oup), nn.ReLU6(inplace=True)) 11 | 12 | 13 | def fixed_padding(inputs, kernel_size, dilation): 14 | kernel_size_effective = kernel_size + (kernel_size - 1) * (dilation - 1) 15 | pad_total = kernel_size_effective - 1 16 | pad_beg = pad_total // 2 17 | pad_end = pad_total - pad_beg 18 | padded_inputs = F.pad(inputs, (pad_beg, pad_end, pad_beg, pad_end)) 19 | return padded_inputs 20 | 21 | 22 | class InvertedResidual(nn.Module): 23 | def __init__(self, inp, oup, stride, dilation, expand_ratio, BatchNorm): 24 | super(InvertedResidual, self).__init__() 25 | self.stride = stride 26 | assert stride in [1, 2] 27 | 28 | hidden_dim = round(inp * expand_ratio) 29 | self.use_res_connect = self.stride == 1 and inp == oup 30 | self.kernel_size = 3 31 | self.dilation = dilation 32 | 33 | if expand_ratio == 1: 34 | self.conv = nn.Sequential( 35 | # dw 36 | nn.Conv2d(hidden_dim, hidden_dim, 3, stride, 0, dilation, groups=hidden_dim, bias=False), 37 | BatchNorm(hidden_dim), 38 | nn.ReLU6(inplace=True), 39 | # pw-linear 40 | nn.Conv2d(hidden_dim, oup, 1, 1, 0, 1, 1, bias=False), 41 | BatchNorm(oup), 42 | ) 43 | else: 44 | self.conv = nn.Sequential( 45 | # pw 46 | nn.Conv2d(inp, hidden_dim, 1, 1, 0, 1, bias=False), 47 | BatchNorm(hidden_dim), 
48 | nn.ReLU6(inplace=True), 49 | # dw 50 | nn.Conv2d(hidden_dim, hidden_dim, 3, stride, 0, dilation, groups=hidden_dim, bias=False), 51 | BatchNorm(hidden_dim), 52 | nn.ReLU6(inplace=True), 53 | # pw-linear 54 | nn.Conv2d(hidden_dim, oup, 1, 1, 0, 1, bias=False), 55 | BatchNorm(oup), 56 | ) 57 | 58 | def forward(self, x): 59 | x_pad = fixed_padding(x, self.kernel_size, dilation=self.dilation) 60 | if self.use_res_connect: 61 | x = x + self.conv(x_pad) 62 | else: 63 | x = self.conv(x_pad) 64 | return x 65 | 66 | 67 | class MobileNetV2(nn.Module): 68 | def __init__(self, output_stride=8, BatchNorm=None, width_mult=1.0, pretrained=True): 69 | super(MobileNetV2, self).__init__() 70 | block = InvertedResidual 71 | input_channel = 32 72 | current_stride = 1 73 | rate = 1 74 | interverted_residual_setting = [ 75 | # t, c, n, s 76 | [1, 16, 1, 1], 77 | [6, 24, 2, 2], 78 | [6, 32, 3, 2], 79 | [6, 64, 4, 2], 80 | [6, 96, 3, 1], 81 | [6, 160, 3, 2], 82 | [6, 320, 1, 1], 83 | ] 84 | 85 | # building first layer 86 | input_channel = int(input_channel * width_mult) 87 | self.features = [conv_bn(3, input_channel, 2, BatchNorm)] 88 | current_stride *= 2 89 | # building inverted residual blocks 90 | for t, c, n, s in interverted_residual_setting: 91 | if current_stride == output_stride: 92 | stride = 1 93 | dilation = rate 94 | rate *= s 95 | else: 96 | stride = s 97 | dilation = 1 98 | current_stride *= s 99 | output_channel = int(c * width_mult) 100 | for i in range(n): 101 | if i == 0: 102 | self.features.append(block(input_channel, output_channel, stride, dilation, t, BatchNorm)) 103 | else: 104 | self.features.append(block(input_channel, output_channel, 1, dilation, t, BatchNorm)) 105 | input_channel = output_channel 106 | self.features = nn.Sequential(*self.features) 107 | self._initialize_weights() 108 | 109 | if pretrained: 110 | self._load_pretrained_model() 111 | 112 | self.low_level_features = self.features[0:4] 113 | self.high_level_features = self.features[4:] 114 | 115 | def forward(self, x): 116 | low_level_feat = self.low_level_features(x) 117 | x = self.high_level_features(low_level_feat) 118 | return x, low_level_feat 119 | 120 | def _load_pretrained_model(self): 121 | pretrain_dict = model_zoo.load_url("http://jeff95.me/models/mobilenet_v2-6a65762b.pth") 122 | model_dict = {} 123 | state_dict = self.state_dict() 124 | for k, v in pretrain_dict.items(): 125 | if k in state_dict: 126 | model_dict[k] = v 127 | state_dict.update(model_dict) 128 | self.load_state_dict(state_dict) 129 | 130 | def _initialize_weights(self): 131 | for m in self.modules(): 132 | if isinstance(m, nn.Conv2d): 133 | # n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 134 | # m.weight.data.normal_(0, math.sqrt(2. 
/ n)) 135 | torch.nn.init.kaiming_normal_(m.weight) 136 | elif isinstance(m, SynchronizedBatchNorm2d): 137 | m.weight.data.fill_(1) 138 | m.bias.data.zero_() 139 | elif isinstance(m, nn.BatchNorm2d): 140 | m.weight.data.fill_(1) 141 | m.bias.data.zero_() 142 | 143 | 144 | if __name__ == "__main__": 145 | input = torch.rand(1, 3, 512, 512) 146 | model = MobileNetV2(output_stride=16, BatchNorm=nn.BatchNorm2d) 147 | output, low_level_feat = model(input) 148 | print(output.size()) 149 | print(low_level_feat.size()) 150 | -------------------------------------------------------------------------------- /seg_lapa/networks/deeplab/backbone/resnet.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | import torch.nn as nn 4 | import torch.utils.model_zoo as model_zoo 5 | 6 | from ..sync_batchnorm.batchnorm import SynchronizedBatchNorm2d 7 | 8 | 9 | class Bottleneck(nn.Module): 10 | expansion = 4 11 | 12 | def __init__(self, inplanes, planes, stride=1, dilation=1, downsample=None, BatchNorm=None): 13 | super(Bottleneck, self).__init__() 14 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False) 15 | self.bn1 = BatchNorm(planes) 16 | self.conv2 = nn.Conv2d( 17 | planes, planes, kernel_size=3, stride=stride, dilation=dilation, padding=dilation, bias=False 18 | ) 19 | self.bn2 = BatchNorm(planes) 20 | self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False) 21 | self.bn3 = BatchNorm(planes * 4) 22 | self.relu = nn.ReLU(inplace=True) 23 | self.downsample = downsample 24 | self.stride = stride 25 | self.dilation = dilation 26 | 27 | def forward(self, x): 28 | residual = x 29 | 30 | out = self.conv1(x) 31 | out = self.bn1(out) 32 | out = self.relu(out) 33 | 34 | out = self.conv2(out) 35 | out = self.bn2(out) 36 | out = self.relu(out) 37 | 38 | out = self.conv3(out) 39 | out = self.bn3(out) 40 | 41 | if self.downsample is not None: 42 | residual = self.downsample(x) 43 | 44 | out += residual 45 | out = self.relu(out) 46 | 47 | return out 48 | 49 | 50 | class ResNet(nn.Module): 51 | def __init__(self, block, layers, output_stride, BatchNorm, pretrained=True): 52 | self.inplanes = 64 53 | super(ResNet, self).__init__() 54 | blocks = [1, 2, 4] 55 | if output_stride == 16: 56 | strides = [1, 2, 2, 1] 57 | dilations = [1, 1, 1, 2] 58 | elif output_stride == 8: 59 | strides = [1, 2, 1, 1] 60 | dilations = [1, 1, 2, 4] 61 | else: 62 | raise NotImplementedError 63 | 64 | # Modules 65 | self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False) 66 | self.bn1 = BatchNorm(64) 67 | self.relu = nn.ReLU(inplace=True) 68 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 69 | 70 | self.layer1 = self._make_layer( 71 | block, 64, layers[0], stride=strides[0], dilation=dilations[0], BatchNorm=BatchNorm 72 | ) 73 | self.layer2 = self._make_layer( 74 | block, 128, layers[1], stride=strides[1], dilation=dilations[1], BatchNorm=BatchNorm 75 | ) 76 | self.layer3 = self._make_layer( 77 | block, 256, layers[2], stride=strides[2], dilation=dilations[2], BatchNorm=BatchNorm 78 | ) 79 | self.layer4 = self._make_MG_unit( 80 | block, 512, blocks=blocks, stride=strides[3], dilation=dilations[3], BatchNorm=BatchNorm 81 | ) 82 | # self.layer4 = self._make_layer(block, 512, layers[3], stride=strides[3], dilation=dilations[3], BatchNorm=BatchNorm) 83 | self._init_weight() 84 | 85 | if pretrained: 86 | self._load_pretrained_model() 87 | 88 | def _make_layer(self, block, planes, blocks, stride=1, dilation=1, BatchNorm=None): 89 
| downsample = None 90 | if stride != 1 or self.inplanes != planes * block.expansion: 91 | downsample = nn.Sequential( 92 | nn.Conv2d(self.inplanes, planes * block.expansion, kernel_size=1, stride=stride, bias=False), 93 | BatchNorm(planes * block.expansion), 94 | ) 95 | 96 | layers = [] 97 | layers.append(block(self.inplanes, planes, stride, dilation, downsample, BatchNorm)) 98 | self.inplanes = planes * block.expansion 99 | for i in range(1, blocks): 100 | layers.append(block(self.inplanes, planes, dilation=dilation, BatchNorm=BatchNorm)) 101 | 102 | return nn.Sequential(*layers) 103 | 104 | def _make_MG_unit(self, block, planes, blocks, stride=1, dilation=1, BatchNorm=None): 105 | downsample = None 106 | if stride != 1 or self.inplanes != planes * block.expansion: 107 | downsample = nn.Sequential( 108 | nn.Conv2d(self.inplanes, planes * block.expansion, kernel_size=1, stride=stride, bias=False), 109 | BatchNorm(planes * block.expansion), 110 | ) 111 | 112 | layers = [] 113 | layers.append( 114 | block( 115 | self.inplanes, planes, stride, dilation=blocks[0] * dilation, downsample=downsample, BatchNorm=BatchNorm 116 | ) 117 | ) 118 | self.inplanes = planes * block.expansion 119 | for i in range(1, len(blocks)): 120 | layers.append(block(self.inplanes, planes, stride=1, dilation=blocks[i] * dilation, BatchNorm=BatchNorm)) 121 | 122 | return nn.Sequential(*layers) 123 | 124 | def forward(self, input): 125 | x = self.conv1(input) 126 | x = self.bn1(x) 127 | x = self.relu(x) 128 | x = self.maxpool(x) 129 | 130 | x = self.layer1(x) 131 | low_level_feat = x 132 | x = self.layer2(x) 133 | x = self.layer3(x) 134 | x = self.layer4(x) 135 | return x, low_level_feat 136 | 137 | def _init_weight(self): 138 | for m in self.modules(): 139 | if isinstance(m, nn.Conv2d): 140 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 141 | m.weight.data.normal_(0, math.sqrt(2.0 / n)) 142 | elif isinstance(m, SynchronizedBatchNorm2d): 143 | m.weight.data.fill_(1) 144 | m.bias.data.zero_() 145 | elif isinstance(m, nn.BatchNorm2d): 146 | m.weight.data.fill_(1) 147 | m.bias.data.zero_() 148 | 149 | def _load_pretrained_model(self): 150 | pretrain_dict = model_zoo.load_url("https://download.pytorch.org/models/resnet101-5d3b4d8f.pth") 151 | model_dict = {} 152 | state_dict = self.state_dict() 153 | for k, v in pretrain_dict.items(): 154 | if k in state_dict: 155 | model_dict[k] = v 156 | state_dict.update(model_dict) 157 | self.load_state_dict(state_dict) 158 | 159 | 160 | def ResNet101(output_stride, BatchNorm, pretrained=True): 161 | """Constructs a ResNet-101 model. 
162 | Args: 163 | pretrained (bool): If True, returns a model pre-trained on ImageNet 164 | """ 165 | model = ResNet(Bottleneck, [3, 4, 23, 3], output_stride, BatchNorm, pretrained=pretrained) 166 | return model 167 | 168 | 169 | if __name__ == "__main__": 170 | import torch 171 | 172 | model = ResNet101(BatchNorm=nn.BatchNorm2d, pretrained=True, output_stride=8) 173 | input = torch.rand(1, 3, 512, 512) 174 | output, low_level_feat = model(input) 175 | print(output.size()) 176 | print(low_level_feat.size()) 177 | -------------------------------------------------------------------------------- /seg_lapa/networks/deeplab/backbone/xception.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | import torch.utils.model_zoo as model_zoo 6 | 7 | from ..sync_batchnorm.batchnorm import SynchronizedBatchNorm2d 8 | 9 | 10 | def fixed_padding(inputs, kernel_size, dilation): 11 | kernel_size_effective = kernel_size + (kernel_size - 1) * (dilation - 1) 12 | pad_total = kernel_size_effective - 1 13 | pad_beg = pad_total // 2 14 | pad_end = pad_total - pad_beg 15 | padded_inputs = F.pad(inputs, (pad_beg, pad_end, pad_beg, pad_end)) 16 | return padded_inputs 17 | 18 | 19 | class SeparableConv2d(nn.Module): 20 | def __init__(self, inplanes, planes, kernel_size=3, stride=1, dilation=1, bias=False, BatchNorm=None): 21 | super(SeparableConv2d, self).__init__() 22 | 23 | self.conv1 = nn.Conv2d(inplanes, inplanes, kernel_size, stride, 0, dilation, groups=inplanes, bias=bias) 24 | self.bn = BatchNorm(inplanes) 25 | self.pointwise = nn.Conv2d(inplanes, planes, 1, 1, 0, 1, 1, bias=bias) 26 | 27 | def forward(self, x): 28 | x = fixed_padding(x, self.conv1.kernel_size[0], dilation=self.conv1.dilation[0]) 29 | x = self.conv1(x) 30 | x = self.bn(x) 31 | x = self.pointwise(x) 32 | return x 33 | 34 | 35 | class Block(nn.Module): 36 | def __init__( 37 | self, 38 | inplanes, 39 | planes, 40 | reps, 41 | stride=1, 42 | dilation=1, 43 | BatchNorm=None, 44 | start_with_relu=True, 45 | grow_first=True, 46 | is_last=False, 47 | ): 48 | super(Block, self).__init__() 49 | 50 | if planes != inplanes or stride != 1: 51 | self.skip = nn.Conv2d(inplanes, planes, 1, stride=stride, bias=False) 52 | self.skipbn = BatchNorm(planes) 53 | else: 54 | self.skip = None 55 | 56 | self.relu = nn.ReLU(inplace=True) 57 | rep = [] 58 | 59 | filters = inplanes 60 | if grow_first: 61 | rep.append(self.relu) 62 | rep.append(SeparableConv2d(inplanes, planes, 3, 1, dilation, BatchNorm=BatchNorm)) 63 | rep.append(BatchNorm(planes)) 64 | filters = planes 65 | 66 | for i in range(reps - 1): 67 | rep.append(self.relu) 68 | rep.append(SeparableConv2d(filters, filters, 3, 1, dilation, BatchNorm=BatchNorm)) 69 | rep.append(BatchNorm(filters)) 70 | 71 | if not grow_first: 72 | rep.append(self.relu) 73 | rep.append(SeparableConv2d(inplanes, planes, 3, 1, dilation, BatchNorm=BatchNorm)) 74 | rep.append(BatchNorm(planes)) 75 | 76 | if stride != 1: 77 | rep.append(self.relu) 78 | rep.append(SeparableConv2d(planes, planes, 3, 2, BatchNorm=BatchNorm)) 79 | rep.append(BatchNorm(planes)) 80 | 81 | if stride == 1 and is_last: 82 | rep.append(self.relu) 83 | rep.append(SeparableConv2d(planes, planes, 3, 1, BatchNorm=BatchNorm)) 84 | rep.append(BatchNorm(planes)) 85 | 86 | if not start_with_relu: 87 | rep = rep[1:] 88 | 89 | self.rep = nn.Sequential(*rep) 90 | 91 | def forward(self, inp): 92 | x = self.rep(inp) 93 | 94 | if self.skip is not None: 95 | skip 
= self.skip(inp)
96 | skip = self.skipbn(skip)
97 | else:
98 | skip = inp
99 |
100 | x = x + skip
101 |
102 | return x
103 |
104 |
105 | class AlignedXception(nn.Module):
106 | """
107 | Modified Aligned Xception
108 | """
109 |
110 | def __init__(self, output_stride, BatchNorm, pretrained=True):
111 | super(AlignedXception, self).__init__()
112 |
113 | if output_stride == 16:
114 | entry_block3_stride = 2
115 | middle_block_dilation = 1
116 | exit_block_dilations = (1, 2)
117 | elif output_stride == 8:
118 | entry_block3_stride = 1
119 | middle_block_dilation = 2
120 | exit_block_dilations = (2, 4)
121 | else:
122 | raise NotImplementedError
123 |
124 | # Entry flow
125 | self.conv1 = nn.Conv2d(3, 32, 3, stride=2, padding=1, bias=False)
126 | self.bn1 = BatchNorm(32)
127 | self.relu = nn.ReLU(inplace=True)
128 |
129 | self.conv2 = nn.Conv2d(32, 64, 3, stride=1, padding=1, bias=False)
130 | self.bn2 = BatchNorm(64)
131 |
132 | self.block1 = Block(64, 128, reps=2, stride=2, BatchNorm=BatchNorm, start_with_relu=False)
133 | self.block2 = Block(128, 256, reps=2, stride=2, BatchNorm=BatchNorm, start_with_relu=False, grow_first=True)
134 | self.block3 = Block(
135 | 256,
136 | 728,
137 | reps=2,
138 | stride=entry_block3_stride,
139 | BatchNorm=BatchNorm,
140 | start_with_relu=True,
141 | grow_first=True,
142 | is_last=True,
143 | )
144 |
145 | # Middle flow
146 | self.block4 = Block(
147 | 728,
148 | 728,
149 | reps=3,
150 | stride=1,
151 | dilation=middle_block_dilation,
152 | BatchNorm=BatchNorm,
153 | start_with_relu=True,
154 | grow_first=True,
155 | )
156 | self.block5 = Block(
157 | 728,
158 | 728,
159 | reps=3,
160 | stride=1,
161 | dilation=middle_block_dilation,
162 | BatchNorm=BatchNorm,
163 | start_with_relu=True,
164 | grow_first=True,
165 | )
166 | self.block6 = Block(
167 | 728,
168 | 728,
169 | reps=3,
170 | stride=1,
171 | dilation=middle_block_dilation,
172 | BatchNorm=BatchNorm,
173 | start_with_relu=True,
174 | grow_first=True,
175 | )
176 | self.block7 = Block(
177 | 728,
178 | 728,
179 | reps=3,
180 | stride=1,
181 | dilation=middle_block_dilation,
182 | BatchNorm=BatchNorm,
183 | start_with_relu=True,
184 | grow_first=True,
185 | )
186 | self.block8 = Block(
187 | 728,
188 | 728,
189 | reps=3,
190 | stride=1,
191 | dilation=middle_block_dilation,
192 | BatchNorm=BatchNorm,
193 | start_with_relu=True,
194 | grow_first=True,
195 | )
196 | self.block9 = Block(
197 | 728,
198 | 728,
199 | reps=3,
200 | stride=1,
201 | dilation=middle_block_dilation,
202 | BatchNorm=BatchNorm,
203 | start_with_relu=True,
204 | grow_first=True,
205 | )
206 | self.block10 = Block(
207 | 728,
208 | 728,
209 | reps=3,
210 | stride=1,
211 | dilation=middle_block_dilation,
212 | BatchNorm=BatchNorm,
213 | start_with_relu=True,
214 | grow_first=True,
215 | )
216 | self.block11 = Block(
217 | 728,
218 | 728,
219 | reps=3,
220 | stride=1,
221 | dilation=middle_block_dilation,
222 | BatchNorm=BatchNorm,
223 | start_with_relu=True,
224 | grow_first=True,
225 | )
226 | self.block12 = Block(
227 | 728,
228 | 728,
229 | reps=3,
230 | stride=1,
231 | dilation=middle_block_dilation,
232 | BatchNorm=BatchNorm,
233 | start_with_relu=True,
234 | grow_first=True,
235 | )
236 | self.block13 = Block(
237 | 728,
238 | 728,
239 | reps=3,
240 | stride=1,
241 | dilation=middle_block_dilation,
242 | BatchNorm=BatchNorm,
243 | start_with_relu=True,
244 | grow_first=True,
245 | )
246 | self.block14 = Block(
247 | 728,
248 | 728,
249 | reps=3,
250 | stride=1,
251 | dilation=middle_block_dilation,
252 | BatchNorm=BatchNorm,
253 | start_with_relu=True, 254 | grow_first=True, 255 | ) 256 | self.block15 = Block( 257 | 728, 258 | 728, 259 | reps=3, 260 | stride=1, 261 | dilation=middle_block_dilation, 262 | BatchNorm=BatchNorm, 263 | start_with_relu=True, 264 | grow_first=True, 265 | ) 266 | self.block16 = Block( 267 | 728, 268 | 728, 269 | reps=3, 270 | stride=1, 271 | dilation=middle_block_dilation, 272 | BatchNorm=BatchNorm, 273 | start_with_relu=True, 274 | grow_first=True, 275 | ) 276 | self.block17 = Block( 277 | 728, 278 | 728, 279 | reps=3, 280 | stride=1, 281 | dilation=middle_block_dilation, 282 | BatchNorm=BatchNorm, 283 | start_with_relu=True, 284 | grow_first=True, 285 | ) 286 | self.block18 = Block( 287 | 728, 288 | 728, 289 | reps=3, 290 | stride=1, 291 | dilation=middle_block_dilation, 292 | BatchNorm=BatchNorm, 293 | start_with_relu=True, 294 | grow_first=True, 295 | ) 296 | self.block19 = Block( 297 | 728, 298 | 728, 299 | reps=3, 300 | stride=1, 301 | dilation=middle_block_dilation, 302 | BatchNorm=BatchNorm, 303 | start_with_relu=True, 304 | grow_first=True, 305 | ) 306 | 307 | # Exit flow 308 | self.block20 = Block( 309 | 728, 310 | 1024, 311 | reps=2, 312 | stride=1, 313 | dilation=exit_block_dilations[0], 314 | BatchNorm=BatchNorm, 315 | start_with_relu=True, 316 | grow_first=False, 317 | is_last=True, 318 | ) 319 | 320 | self.conv3 = SeparableConv2d(1024, 1536, 3, stride=1, dilation=exit_block_dilations[1], BatchNorm=BatchNorm) 321 | self.bn3 = BatchNorm(1536) 322 | 323 | self.conv4 = SeparableConv2d(1536, 1536, 3, stride=1, dilation=exit_block_dilations[1], BatchNorm=BatchNorm) 324 | self.bn4 = BatchNorm(1536) 325 | 326 | self.conv5 = SeparableConv2d(1536, 2048, 3, stride=1, dilation=exit_block_dilations[1], BatchNorm=BatchNorm) 327 | self.bn5 = BatchNorm(2048) 328 | 329 | # Init weights 330 | self._init_weight() 331 | 332 | # Load pretrained model 333 | if pretrained: 334 | self._load_pretrained_model() 335 | 336 | def forward(self, x): 337 | # Entry flow 338 | x = self.conv1(x) 339 | x = self.bn1(x) 340 | x = self.relu(x) 341 | 342 | x = self.conv2(x) 343 | x = self.bn2(x) 344 | x = self.relu(x) 345 | 346 | x = self.block1(x) 347 | # add relu here 348 | x = self.relu(x) 349 | low_level_feat = x 350 | x = self.block2(x) 351 | x = self.block3(x) 352 | 353 | # Middle flow 354 | x = self.block4(x) 355 | x = self.block5(x) 356 | x = self.block6(x) 357 | x = self.block7(x) 358 | x = self.block8(x) 359 | x = self.block9(x) 360 | x = self.block10(x) 361 | x = self.block11(x) 362 | x = self.block12(x) 363 | x = self.block13(x) 364 | x = self.block14(x) 365 | x = self.block15(x) 366 | x = self.block16(x) 367 | x = self.block17(x) 368 | x = self.block18(x) 369 | x = self.block19(x) 370 | 371 | # Exit flow 372 | x = self.block20(x) 373 | x = self.relu(x) 374 | x = self.conv3(x) 375 | x = self.bn3(x) 376 | x = self.relu(x) 377 | 378 | x = self.conv4(x) 379 | x = self.bn4(x) 380 | x = self.relu(x) 381 | 382 | x = self.conv5(x) 383 | x = self.bn5(x) 384 | x = self.relu(x) 385 | 386 | return x, low_level_feat 387 | 388 | def _init_weight(self): 389 | for m in self.modules(): 390 | if isinstance(m, nn.Conv2d): 391 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 392 | m.weight.data.normal_(0, math.sqrt(2.0 / n)) 393 | elif isinstance(m, SynchronizedBatchNorm2d): 394 | m.weight.data.fill_(1) 395 | m.bias.data.zero_() 396 | elif isinstance(m, nn.BatchNorm2d): 397 | m.weight.data.fill_(1) 398 | m.bias.data.zero_() 399 | 400 | def _load_pretrained_model(self): 401 | pretrain_dict = 
model_zoo.load_url("http://data.lip6.fr/cadene/pretrainedmodels/xception-b5690688.pth") 402 | model_dict = {} 403 | state_dict = self.state_dict() 404 | 405 | for k, v in pretrain_dict.items(): 406 | if k in state_dict: 407 | if "pointwise" in k: 408 | v = v.unsqueeze(-1).unsqueeze(-1) 409 | if k.startswith("block11"): 410 | model_dict[k] = v 411 | model_dict[k.replace("block11", "block12")] = v 412 | model_dict[k.replace("block11", "block13")] = v 413 | model_dict[k.replace("block11", "block14")] = v 414 | model_dict[k.replace("block11", "block15")] = v 415 | model_dict[k.replace("block11", "block16")] = v 416 | model_dict[k.replace("block11", "block17")] = v 417 | model_dict[k.replace("block11", "block18")] = v 418 | model_dict[k.replace("block11", "block19")] = v 419 | elif k.startswith("block12"): 420 | model_dict[k.replace("block12", "block20")] = v 421 | elif k.startswith("bn3"): 422 | model_dict[k] = v 423 | model_dict[k.replace("bn3", "bn4")] = v 424 | elif k.startswith("conv4"): 425 | model_dict[k.replace("conv4", "conv5")] = v 426 | elif k.startswith("bn4"): 427 | model_dict[k.replace("bn4", "bn5")] = v 428 | else: 429 | model_dict[k] = v 430 | state_dict.update(model_dict) 431 | self.load_state_dict(state_dict) 432 | 433 | 434 | if __name__ == "__main__": 435 | import torch 436 | 437 | model = AlignedXception(BatchNorm=nn.BatchNorm2d, pretrained=True, output_stride=16) 438 | input = torch.rand(1, 3, 512, 512) 439 | output, low_level_feat = model(input) 440 | print(output.size()) 441 | print(low_level_feat.size()) 442 | -------------------------------------------------------------------------------- /seg_lapa/networks/deeplab/decoder.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | from .sync_batchnorm.batchnorm import SynchronizedBatchNorm2d 6 | 7 | 8 | class Decoder(nn.Module): 9 | def __init__(self, num_classes, backbone, BatchNorm): 10 | super(Decoder, self).__init__() 11 | if backbone == "resnet" or backbone == "drn": 12 | low_level_inplanes = 256 13 | elif backbone == "xception": 14 | low_level_inplanes = 128 15 | elif backbone == "mobilenet": 16 | low_level_inplanes = 24 17 | else: 18 | raise NotImplementedError 19 | 20 | self.conv1 = nn.Conv2d(low_level_inplanes, 48, 1, bias=False) 21 | self.bn1 = BatchNorm(48) 22 | self.relu = nn.ReLU() 23 | self.last_conv = nn.Sequential( 24 | nn.Conv2d(304, 256, kernel_size=3, stride=1, padding=1, bias=False), 25 | BatchNorm(256), 26 | nn.ReLU(), 27 | nn.Dropout(0.5), 28 | nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1, bias=False), 29 | BatchNorm(256), 30 | nn.ReLU(), 31 | nn.Dropout(0.1), 32 | nn.Conv2d(256, num_classes, kernel_size=1, stride=1), 33 | ) 34 | self._init_weight() 35 | 36 | def forward(self, x, low_level_feat): 37 | low_level_feat = self.conv1(low_level_feat) 38 | low_level_feat = self.bn1(low_level_feat) 39 | low_level_feat = self.relu(low_level_feat) 40 | 41 | x = F.interpolate(x, size=low_level_feat.size()[2:], mode="bilinear", align_corners=True) 42 | x = torch.cat((x, low_level_feat), dim=1) 43 | x = self.last_conv(x) 44 | 45 | return x 46 | 47 | def _init_weight(self): 48 | for m in self.modules(): 49 | if isinstance(m, nn.Conv2d): 50 | torch.nn.init.kaiming_normal_(m.weight) 51 | elif isinstance(m, SynchronizedBatchNorm2d): 52 | m.weight.data.fill_(1) 53 | m.bias.data.zero_() 54 | elif isinstance(m, nn.BatchNorm2d): 55 | m.weight.data.fill_(1) 56 | m.bias.data.zero_() 57 | 58 | 59 | 
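# Shape sketch (illustrative, not part of the original file; assumes the resnet
# backbone at output_stride=16 on a 512x512 input): the ASPP output `x` is
# (N, 256, 32, 32) and `low_level_feat` from the backbone's layer1 is
# (N, 256, 128, 128). conv1/bn1 reduce the low-level features to 48 channels,
# `x` is bilinearly upsampled to 128x128, and the concatenation yields the
# 304 (= 256 + 48) input channels expected by `last_conv`:
#
#   >>> decoder = build_decoder(num_classes=11, backbone="resnet", BatchNorm=nn.BatchNorm2d)
#   >>> out = decoder(torch.rand(1, 256, 32, 32), torch.rand(1, 256, 128, 128))
#   >>> out.shape
#   torch.Size([1, 11, 128, 128])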
def build_decoder(num_classes, backbone, BatchNorm): 60 | return Decoder(num_classes, backbone, BatchNorm) 61 | -------------------------------------------------------------------------------- /seg_lapa/networks/deeplab/decoder_masks.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | from .sync_batchnorm.batchnorm import SynchronizedBatchNorm2d 6 | 7 | 8 | class Decoder(nn.Module): 9 | def __init__(self, num_classes, backbone, BatchNorm): 10 | super(Decoder, self).__init__() 11 | if backbone == "resnet" or backbone == "drn": 12 | low_level_inplanes = 256 13 | elif backbone == "xception": 14 | low_level_inplanes = 128 15 | elif backbone == "mobilenet": 16 | low_level_inplanes = 24 17 | else: 18 | raise NotImplementedError 19 | 20 | self.conv1 = nn.Conv2d(low_level_inplanes, 48, 1, bias=False) 21 | self.bn1 = BatchNorm(48) 22 | self.relu = nn.ReLU() 23 | self.last_conv = nn.Sequential( 24 | nn.Conv2d(304, 256, kernel_size=3, stride=1, padding=1, bias=False), 25 | BatchNorm(256), 26 | nn.ReLU(), 27 | nn.Dropout(0.5), 28 | nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1, bias=False), 29 | BatchNorm(256), 30 | nn.ReLU(), 31 | nn.Dropout(0.1), 32 | nn.Conv2d(256, num_classes, kernel_size=1, stride=1), 33 | ) 34 | self._init_weight() 35 | 36 | def forward(self, x, low_level_feat): 37 | low_level_feat = self.conv1(low_level_feat) 38 | low_level_feat = self.bn1(low_level_feat) 39 | low_level_feat = self.relu(low_level_feat) 40 | 41 | x = F.interpolate(x, size=low_level_feat.size()[2:], mode="bilinear", align_corners=True) 42 | x = torch.cat((x, low_level_feat), dim=1) 43 | x = self.last_conv(x) 44 | 45 | return x 46 | 47 | def _init_weight(self): 48 | for m in self.modules(): 49 | if isinstance(m, nn.Conv2d): 50 | torch.nn.init.kaiming_normal_(m.weight) 51 | elif isinstance(m, SynchronizedBatchNorm2d): 52 | m.weight.data.fill_(1) 53 | m.bias.data.zero_() 54 | elif isinstance(m, nn.BatchNorm2d): 55 | m.weight.data.fill_(1) 56 | m.bias.data.zero_() 57 | 58 | 59 | def build_decoder(num_classes, backbone, BatchNorm): 60 | return Decoder(num_classes, backbone, BatchNorm) 61 | -------------------------------------------------------------------------------- /seg_lapa/networks/deeplab/deeplab.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from torch.cuda.amp import autocast 5 | 6 | from .aspp import build_aspp 7 | from .backbone import build_backbone 8 | from .decoder import build_decoder 9 | from .sync_batchnorm.batchnorm import SynchronizedBatchNorm2d 10 | 11 | 12 | class DeepLab(nn.Module): 13 | def __init__( 14 | self, backbone="drn", output_stride=8, num_classes=11, sync_bn=False, freeze_bn=False, enable_amp=False 15 | ): 16 | super(DeepLab, self).__init__() 17 | 18 | if backbone == "drn" and output_stride != 8: 19 | raise ValueError(f'The "drn" backbone only supports output stride = 8. 
Input: {output_stride}')
20 |
21 | # Ref for sync_bn: https://hangzhang.org/PyTorch-Encoding/tutorials/syncbn.html
22 | # Sync batchnorm is required when the effective batchsize per GPU is small (~4)
23 | if sync_bn is True:
24 | BatchNorm = SynchronizedBatchNorm2d
25 | else:
26 | BatchNorm = nn.BatchNorm2d
27 |
28 | self.backbone = build_backbone(backbone, output_stride, BatchNorm)
29 | self.aspp = build_aspp(backbone, output_stride, BatchNorm)
30 | self.decoder = build_decoder(num_classes, backbone, BatchNorm)
31 |
32 | self.freeze_bn = freeze_bn
33 |
34 | self.enable_amp = enable_amp
35 |
36 | def forward(self, inputs):
37 | with autocast(enabled=self.enable_amp):
38 | """PyTorch Automatic Mixed Precision (AMP) Training
39 | Ref: https://pytorch.org/docs/stable/amp.html#torch.cuda.amp.autocast
40 |
41 | - For use with DataParallel, we must add autocast within model definition.
42 | """
43 | x, low_level_feat = self.backbone(inputs)
44 | x = self.aspp(x)
45 | x = self.decoder(x, low_level_feat)
46 | x = F.interpolate(x, size=inputs.size()[2:], mode="bilinear", align_corners=True)
47 |
48 | return x
49 |
50 | def freeze_batchnorm(self):  # Renamed from `freeze_bn`: the bool attribute set in __init__ shadowed the method of the same name, making it uncallable.
51 | for m in self.modules():
52 | if isinstance(m, SynchronizedBatchNorm2d):
53 | m.eval()
54 | elif isinstance(m, nn.BatchNorm2d):
55 | m.eval()
56 |
57 | def get_1x_lr_params(self):
58 | modules = [self.backbone]
59 | for i in range(len(modules)):
60 | for m in modules[i].named_modules():
61 | if self.freeze_bn:
62 | if isinstance(m[1], nn.Conv2d):
63 | for p in m[1].parameters():
64 | if p.requires_grad:
65 | yield p
66 | else:
67 | if (
68 | isinstance(m[1], nn.Conv2d)
69 | or isinstance(m[1], SynchronizedBatchNorm2d)
70 | or isinstance(m[1], nn.BatchNorm2d)
71 | ):
72 | for p in m[1].parameters():
73 | if p.requires_grad:
74 | yield p
75 |
76 | def get_10x_lr_params(self):
77 | modules = [self.aspp, self.decoder]
78 | for i in range(len(modules)):
79 | for m in modules[i].named_modules():
80 | if self.freeze_bn:
81 | if isinstance(m[1], nn.Conv2d):
82 | for p in m[1].parameters():
83 | if p.requires_grad:
84 | yield p
85 | else:
86 | if (
87 | isinstance(m[1], nn.Conv2d)
88 | or isinstance(m[1], SynchronizedBatchNorm2d)
89 | or isinstance(m[1], nn.BatchNorm2d)
90 | ):
91 | for p in m[1].parameters():
92 | if p.requires_grad:
93 | yield p
94 |
95 |
96 | if __name__ == "__main__":
97 | model = DeepLab(backbone="drn", output_stride=8)
98 | model.eval()
99 | inputs = torch.rand(1, 3, 512, 512)
100 | output = model(inputs)
101 | print(output.size())
102 | -------------------------------------------------------------------------------- /seg_lapa/networks/deeplab/readme.md: --------------------------------------------------------------------------------
1 | The code for deeplabv3+ was adapted from: https://github.com/jfzhang95/pytorch-deeplab-xception
2 | -------------------------------------------------------------------------------- /seg_lapa/networks/deeplab/sync_batchnorm/__init__.py: --------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # File : __init__.py
3 | # Author : Jiayuan Mao
4 | # Email : maojiayuan@gmail.com
5 | # Date : 27/01/2018
6 | #
7 | # This file is part of Synchronized-BatchNorm-PyTorch.
8 | # https://github.com/vacancy/Synchronized-BatchNorm-PyTorch
9 | # Distributed under MIT License.
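#
# Usage sketch (illustrative, not part of the original file): the synchronized
# layers only actually share statistics when the model is replicated through
# `DataParallelWithCallback` (or after `patch_replication_callback` is applied
# to an existing `nn.DataParallel`), e.g.:
#
#   >>> model = DeepLab(sync_bn=True)  # builds with SynchronizedBatchNorm2d
#   >>> model = DataParallelWithCallback(model.cuda(), device_ids=[0, 1])
#   >>> out = model(torch.rand(8, 3, 512, 512).cuda())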
10 |
11 | from .batchnorm import SynchronizedBatchNorm1d, SynchronizedBatchNorm2d, SynchronizedBatchNorm3d
12 | from .replicate import DataParallelWithCallback, patch_replication_callback
13 | -------------------------------------------------------------------------------- /seg_lapa/networks/deeplab/sync_batchnorm/batchnorm.py: --------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # File : batchnorm.py
3 | # Author : Jiayuan Mao
4 | # Email : maojiayuan@gmail.com
5 | # Date : 27/01/2018
6 | #
7 | # This file is part of Synchronized-BatchNorm-PyTorch.
8 | # https://github.com/vacancy/Synchronized-BatchNorm-PyTorch
9 | # Distributed under MIT License.
10 |
11 | import collections
12 |
13 | import torch
14 | import torch.nn.functional as F
15 | from torch.nn.modules.batchnorm import _BatchNorm
16 | from torch.nn.parallel._functions import Broadcast, ReduceAddCoalesced
17 |
18 | from .comm import SyncMaster
19 |
20 | __all__ = ["SynchronizedBatchNorm1d", "SynchronizedBatchNorm2d", "SynchronizedBatchNorm3d"]
21 |
22 |
23 | def _sum_ft(tensor):
24 | """sum over the first and last dimension"""
25 | return tensor.sum(dim=0).sum(dim=-1)
26 |
27 |
28 | def _unsqueeze_ft(tensor):
29 | """add new dimensions at the front and the tail"""
30 | return tensor.unsqueeze(0).unsqueeze(-1)
31 |
32 |
33 | _ChildMessage = collections.namedtuple("_ChildMessage", ["sum", "ssum", "sum_size"])
34 | _MasterMessage = collections.namedtuple("_MasterMessage", ["sum", "inv_std"])
35 |
36 |
37 | class _SynchronizedBatchNorm(_BatchNorm):
38 | def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=True):
39 | super(_SynchronizedBatchNorm, self).__init__(num_features, eps=eps, momentum=momentum, affine=affine)
40 |
41 | self._sync_master = SyncMaster(self._data_parallel_master)
42 |
43 | self._is_parallel = False
44 | self._parallel_id = None
45 | self._slave_pipe = None
46 |
47 | def forward(self, input):
48 | # If it is not parallel computation or is in evaluation mode, use PyTorch's implementation.
49 | if not (self._is_parallel and self.training):
50 | return F.batch_norm(
51 | input,
52 | self.running_mean,
53 | self.running_var,
54 | self.weight,
55 | self.bias,
56 | self.training,
57 | self.momentum,
58 | self.eps,
59 | )
60 |
61 | # Resize the input to (B, C, -1).
62 | input_shape = input.size()
63 | input = input.view(input.size(0), self.num_features, -1)
64 |
65 | # Compute the sum and square-sum.
66 | sum_size = input.size(0) * input.size(2)
67 | input_sum = _sum_ft(input)
68 | input_ssum = _sum_ft(input ** 2)
69 |
70 | # Reduce-and-broadcast the statistics.
71 | if self._parallel_id == 0:
72 | mean, inv_std = self._sync_master.run_master(_ChildMessage(input_sum, input_ssum, sum_size))
73 | else:
74 | mean, inv_std = self._slave_pipe.run_slave(_ChildMessage(input_sum, input_ssum, sum_size))
75 |
76 | # Compute the output.
77 | if self.affine:
78 | # MJY:: Fuse the multiplication for speed.
79 | output = (input - _unsqueeze_ft(mean)) * _unsqueeze_ft(inv_std * self.weight) + _unsqueeze_ft(self.bias)
80 | else:
81 | output = (input - _unsqueeze_ft(mean)) * _unsqueeze_ft(inv_std)
82 |
83 | # Reshape it.
84 | return output.view(input_shape)
85 |
86 | def __data_parallel_replicate__(self, ctx, copy_id):
87 | self._is_parallel = True
88 | self._parallel_id = copy_id
89 |
90 | # parallel_id == 0 means master device.
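# The `ctx` object is shared by all replicas of this module (see
# execute_replication_callbacks in replicate.py): the master copy (copy_id 0)
# publishes its SyncMaster through it, and every later replica registers as a
# slave, receiving the SlavePipe it uses to exchange batch statistics in forward().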
91 | if self._parallel_id == 0:
92 | ctx.sync_master = self._sync_master
93 | else:
94 | self._slave_pipe = ctx.sync_master.register_slave(copy_id)
95 |
96 | def _data_parallel_master(self, intermediates):
97 | """Reduce the sum and square-sum, compute the statistics, and broadcast it."""
98 |
99 | # Always using the same "device order" makes the ReduceAdd operation faster.
100 | # Thanks to:: Tete Xiao (http://tetexiao.com/)
101 | intermediates = sorted(intermediates, key=lambda i: i[1].sum.get_device())
102 |
103 | to_reduce = [i[1][:2] for i in intermediates]
104 | to_reduce = [j for i in to_reduce for j in i]  # flatten
105 | target_gpus = [i[1].sum.get_device() for i in intermediates]
106 |
107 | sum_size = sum([i[1].sum_size for i in intermediates])
108 | sum_, ssum = ReduceAddCoalesced.apply(target_gpus[0], 2, *to_reduce)
109 | mean, inv_std = self._compute_mean_std(sum_, ssum, sum_size)
110 |
111 | broadcasted = Broadcast.apply(target_gpus, mean, inv_std)
112 |
113 | outputs = []
114 | for i, rec in enumerate(intermediates):
115 | outputs.append((rec[0], _MasterMessage(*broadcasted[i * 2 : i * 2 + 2])))
116 |
117 | return outputs
118 |
119 | def _compute_mean_std(self, sum_, ssum, size):
120 | """Compute the mean and standard-deviation with sum and square-sum. This method
121 | also maintains the moving average on the master device."""
122 | assert size > 1, "BatchNorm computes unbiased standard-deviation, which requires size > 1."
123 | mean = sum_ / size
124 | sumvar = ssum - sum_ * mean
125 | unbias_var = sumvar / (size - 1)
126 | bias_var = sumvar / size
127 |
128 | self.running_mean = (1 - self.momentum) * self.running_mean + self.momentum * mean.data
129 | self.running_var = (1 - self.momentum) * self.running_var + self.momentum * unbias_var.data
130 |
131 | return mean, bias_var.clamp(self.eps) ** -0.5
132 |
133 |
134 | class SynchronizedBatchNorm1d(_SynchronizedBatchNorm):
135 | r"""Applies Synchronized Batch Normalization over a 2d or 3d input that is seen as a
136 | mini-batch.
137 | .. math::
138 | y = \frac{x - mean[x]}{ \sqrt{Var[x] + \epsilon}} * gamma + beta
139 | This module differs from the built-in PyTorch BatchNorm1d as the mean and
140 | standard-deviation are reduced across all devices during training.
141 | For example, when one uses `nn.DataParallel` to wrap the network during
142 | training, PyTorch's implementation normalizes the tensor on each device using
143 | the statistics only on that device, which accelerates the computation and
144 | is easy to implement, but the statistics might be inaccurate.
145 | Instead, in this synchronized version, the statistics will be computed
146 | over all training samples distributed on multiple devices.
147 |
148 | Note that, for the one-GPU or CPU-only case, this module behaves exactly the same
149 | as the built-in PyTorch implementation.
150 | The mean and standard-deviation are calculated per-dimension over
151 | the mini-batches and gamma and beta are learnable parameter vectors
152 | of size C (where C is the input size).
153 | During training, this layer keeps a running estimate of its computed mean
154 | and variance. The running sum is kept with a default momentum of 0.1.
155 | During evaluation, this running mean/variance is used for normalization.
156 | Because the BatchNorm is done over the `C` dimension, computing statistics
157 | on `(N, L)` slices, it's common terminology to call this Temporal BatchNorm
158 | Args:
159 | num_features: num_features from an expected input of size
160 | `batch_size x num_features [x width]`
161 | eps: a value added to the denominator for numerical stability.
162 | Default: 1e-5
163 | momentum: the value used for the running_mean and running_var
164 | computation. Default: 0.1
165 | affine: a boolean value that when set to ``True``, gives the layer learnable
166 | affine parameters. Default: ``True``
167 | Shape:
168 | - Input: :math:`(N, C)` or :math:`(N, C, L)`
169 | - Output: :math:`(N, C)` or :math:`(N, C, L)` (same shape as input)
170 | Examples:
171 | >>> # With Learnable Parameters
172 | >>> m = SynchronizedBatchNorm1d(100)
173 | >>> # Without Learnable Parameters
174 | >>> m = SynchronizedBatchNorm1d(100, affine=False)
175 | >>> input = torch.autograd.Variable(torch.randn(20, 100))
176 | >>> output = m(input)
177 | """
178 |
179 | def _check_input_dim(self, input):
180 | if input.dim() != 2 and input.dim() != 3:
181 | raise ValueError("expected 2D or 3D input (got {}D input)".format(input.dim()))
182 | super(SynchronizedBatchNorm1d, self)._check_input_dim(input)
183 |
184 |
185 | class SynchronizedBatchNorm2d(_SynchronizedBatchNorm):
186 | r"""Applies Batch Normalization over a 4d input that is seen as a mini-batch
187 | of 3d inputs
188 | .. math::
189 | y = \frac{x - mean[x]}{ \sqrt{Var[x] + \epsilon}} * gamma + beta
190 | This module differs from the built-in PyTorch BatchNorm2d as the mean and
191 | standard-deviation are reduced across all devices during training.
192 | For example, when one uses `nn.DataParallel` to wrap the network during
193 | training, PyTorch's implementation normalizes the tensor on each device using
194 | the statistics only on that device, which accelerates the computation and
195 | is easy to implement, but the statistics might be inaccurate.
196 | Instead, in this synchronized version, the statistics will be computed
197 | over all training samples distributed on multiple devices.
198 |
199 | Note that, for the one-GPU or CPU-only case, this module behaves exactly the same
200 | as the built-in PyTorch implementation.
201 | The mean and standard-deviation are calculated per-dimension over
202 | the mini-batches and gamma and beta are learnable parameter vectors
203 | of size C (where C is the input size).
204 | During training, this layer keeps a running estimate of its computed mean
205 | and variance. The running sum is kept with a default momentum of 0.1.
206 | During evaluation, this running mean/variance is used for normalization.
207 | Because the BatchNorm is done over the `C` dimension, computing statistics
208 | on `(N, H, W)` slices, it's common terminology to call this Spatial BatchNorm
209 | Args:
210 | num_features: num_features from an expected input of
211 | size batch_size x num_features x height x width
212 | eps: a value added to the denominator for numerical stability.
213 | Default: 1e-5
214 | momentum: the value used for the running_mean and running_var
215 | computation. Default: 0.1
216 | affine: a boolean value that when set to ``True``, gives the layer learnable
217 | affine parameters. Default: ``True``
218 | Shape:
219 | - Input: :math:`(N, C, H, W)`
220 | - Output: :math:`(N, C, H, W)` (same shape as input)
221 | Examples:
222 | >>> # With Learnable Parameters
223 | >>> m = SynchronizedBatchNorm2d(100)
224 | >>> # Without Learnable Parameters
225 | >>> m = SynchronizedBatchNorm2d(100, affine=False)
226 | >>> input = torch.autograd.Variable(torch.randn(20, 100, 35, 45))
227 | >>> output = m(input)
228 | """
229 |
230 | def _check_input_dim(self, input):
231 | if input.dim() != 4:
232 | raise ValueError("expected 4D input (got {}D input)".format(input.dim()))
233 | super(SynchronizedBatchNorm2d, self)._check_input_dim(input)
234 |
235 |
236 | class SynchronizedBatchNorm3d(_SynchronizedBatchNorm):
237 | r"""Applies Batch Normalization over a 5d input that is seen as a mini-batch
238 | of 4d inputs
239 | .. math::
240 | y = \frac{x - mean[x]}{ \sqrt{Var[x] + \epsilon}} * gamma + beta
241 | This module differs from the built-in PyTorch BatchNorm3d as the mean and
242 | standard-deviation are reduced across all devices during training.
243 | For example, when one uses `nn.DataParallel` to wrap the network during
244 | training, PyTorch's implementation normalizes the tensor on each device using
245 | the statistics only on that device, which accelerates the computation and
246 | is easy to implement, but the statistics might be inaccurate.
247 | Instead, in this synchronized version, the statistics will be computed
248 | over all training samples distributed on multiple devices.
249 |
250 | Note that, for the one-GPU or CPU-only case, this module behaves exactly the same
251 | as the built-in PyTorch implementation.
252 | The mean and standard-deviation are calculated per-dimension over
253 | the mini-batches and gamma and beta are learnable parameter vectors
254 | of size C (where C is the input size).
255 | During training, this layer keeps a running estimate of its computed mean
256 | and variance. The running sum is kept with a default momentum of 0.1.
257 | During evaluation, this running mean/variance is used for normalization.
258 | Because the BatchNorm is done over the `C` dimension, computing statistics
259 | on `(N, D, H, W)` slices, it's common terminology to call this Volumetric BatchNorm
260 | or Spatio-temporal BatchNorm
261 | Args:
262 | num_features: num_features from an expected input of
263 | size batch_size x num_features x depth x height x width
264 | eps: a value added to the denominator for numerical stability.
265 | Default: 1e-5
266 | momentum: the value used for the running_mean and running_var
267 | computation. Default: 0.1
268 | affine: a boolean value that when set to ``True``, gives the layer learnable
269 | affine parameters. Default: ``True``
270 | Shape:
271 | - Input: :math:`(N, C, D, H, W)`
272 | - Output: :math:`(N, C, D, H, W)` (same shape as input)
273 | Examples:
274 | >>> # With Learnable Parameters
275 | >>> m = SynchronizedBatchNorm3d(100)
276 | >>> # Without Learnable Parameters
277 | >>> m = SynchronizedBatchNorm3d(100, affine=False)
278 | >>> input = torch.autograd.Variable(torch.randn(20, 100, 35, 45, 10))
279 | >>> output = m(input)
280 | """
281 |
282 | def _check_input_dim(self, input):
283 | if input.dim() != 5:
284 | raise ValueError("expected 5D input (got {}D input)".format(input.dim()))
285 | super(SynchronizedBatchNorm3d, self)._check_input_dim(input)
286 | -------------------------------------------------------------------------------- /seg_lapa/networks/deeplab/sync_batchnorm/comm.py: --------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # File : comm.py
3 | # Author : Jiayuan Mao
4 | # Email : maojiayuan@gmail.com
5 | # Date : 27/01/2018
6 | #
7 | # This file is part of Synchronized-BatchNorm-PyTorch.
8 | # https://github.com/vacancy/Synchronized-BatchNorm-PyTorch
9 | # Distributed under MIT License.
10 |
11 | import collections
12 | import queue
13 | import threading
14 |
15 | __all__ = ["FutureResult", "SlavePipe", "SyncMaster"]
16 |
17 |
18 | class FutureResult(object):
19 | """A thread-safe future implementation. Used only as one-to-one pipe."""
20 |
21 | def __init__(self):
22 | self._result = None
23 | self._lock = threading.Lock()
24 | self._cond = threading.Condition(self._lock)
25 |
26 | def put(self, result):
27 | with self._lock:
28 | assert self._result is None, "Previous result hasn't been fetched."
29 | self._result = result
30 | self._cond.notify()
31 |
32 | def get(self):
33 | with self._lock:
34 | if self._result is None:
35 | self._cond.wait()
36 |
37 | res = self._result
38 | self._result = None
39 | return res
40 |
41 |
42 | _MasterRegistry = collections.namedtuple("MasterRegistry", ["result"])
43 | _SlavePipeBase = collections.namedtuple("_SlavePipeBase", ["identifier", "queue", "result"])
44 |
45 |
46 | class SlavePipe(_SlavePipeBase):
47 | """Pipe for master-slave communication."""
48 |
49 | def run_slave(self, msg):
50 | self.queue.put((self.identifier, msg))
51 | ret = self.result.get()
52 | self.queue.put(True)
53 | return ret
54 |
55 |
56 | class SyncMaster(object):
57 | """An abstract `SyncMaster` object.
58 | - During the replication, as the data parallel will trigger a callback on each module, all slave devices should
59 | call `register(id)` and obtain a `SlavePipe` to communicate with the master.
60 | - During the forward pass, the master device invokes `run_master`, all messages from slave devices will be collected,
61 | and passed to a registered callback.
62 | - After receiving the messages, the master device should gather the information and determine the message passed
63 | back to each slave device.
64 | """
65 |
66 | def __init__(self, master_callback):
67 | """
68 | Args:
69 | master_callback: a callback to be invoked after having collected messages from slave devices.
70 | """
71 | self._master_callback = master_callback
72 | self._queue = queue.Queue()
73 | self._registry = collections.OrderedDict()
74 | self._activated = False
75 |
76 | def __getstate__(self):
77 | return {"master_callback": self._master_callback}
78 |
79 | def __setstate__(self, state):
80 | self.__init__(state["master_callback"])
81 |
82 | def register_slave(self, identifier):
83 | """
84 | Register a slave device.
85 | Args:
86 | identifier: an identifier, usually the device id.
87 | Returns: a `SlavePipe` object which can be used to communicate with the master device.
88 | """
89 | if self._activated:
90 | assert self._queue.empty(), "Queue is not clean before next initialization."
91 | self._activated = False
92 | self._registry.clear()
93 | future = FutureResult()
94 | self._registry[identifier] = _MasterRegistry(future)
95 | return SlavePipe(identifier, self._queue, future)
96 |
97 | def run_master(self, master_msg):
98 | """
99 | Main entry for the master device in each forward pass.
100 | The messages were first collected from each device (including the master device), and then
101 | a callback will be invoked to compute the message to be sent back to each device
102 | (including the master device).
103 | Args:
104 | master_msg: the message that the master wants to send to itself. This will be placed as the first
105 | message when calling `master_callback`. For detailed usage, see `_SynchronizedBatchNorm` for an example.
106 | Returns: the message to be sent back to the master device.
107 | """
108 | self._activated = True
109 |
110 | intermediates = [(0, master_msg)]
111 | for i in range(self.nr_slaves):
112 | intermediates.append(self._queue.get())
113 |
114 | results = self._master_callback(intermediates)
115 | assert results[0][0] == 0, "The first result should belong to the master."
116 |
117 | for i, res in results:
118 | if i == 0:
119 | continue
120 | self._registry[i].result.put(res)
121 |
122 | for i in range(self.nr_slaves):
123 | assert self._queue.get() is True
124 |
125 | return results[0][1]
126 |
127 | @property
128 | def nr_slaves(self):
129 | return len(self._registry)
130 | -------------------------------------------------------------------------------- /seg_lapa/networks/deeplab/sync_batchnorm/replicate.py: --------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # File : replicate.py
3 | # Author : Jiayuan Mao
4 | # Email : maojiayuan@gmail.com
5 | # Date : 27/01/2018
6 | #
7 | # This file is part of Synchronized-BatchNorm-PyTorch.
8 | # https://github.com/vacancy/Synchronized-BatchNorm-PyTorch
9 | # Distributed under MIT License.
10 |
11 | import functools
12 |
13 | from torch.nn.parallel.data_parallel import DataParallel
14 |
15 | __all__ = ["CallbackContext", "execute_replication_callbacks", "DataParallelWithCallback", "patch_replication_callback"]
16 |
17 |
18 | class CallbackContext(object):
19 | pass
20 |
21 |
22 | def execute_replication_callbacks(modules):
23 | """
24 | Execute a replication callback `__data_parallel_replicate__` on each module created by original replication.
25 | The callback will be invoked with arguments `__data_parallel_replicate__(ctx, copy_id)`
26 | Note that, as all modules are isomorphic, we assign each sub-module a context
27 | (shared among multiple copies of this module on different devices).
28 | Through this context, different copies can share some information.
29 | We guarantee that the callback on the master copy (the first copy) will be called ahead of calling the callback
30 | of any slave copies.
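For SynchronizedBatchNorm, this ordering is what lets copy 0 publish its SyncMaster into the shared
context before the remaining copies register themselves as slaves.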
31 | """ 32 | master_copy = modules[0] 33 | nr_modules = len(list(master_copy.modules())) 34 | ctxs = [CallbackContext() for _ in range(nr_modules)] 35 | 36 | for i, module in enumerate(modules): 37 | for j, m in enumerate(module.modules()): 38 | if hasattr(m, "__data_parallel_replicate__"): 39 | m.__data_parallel_replicate__(ctxs[j], i) 40 | 41 | 42 | class DataParallelWithCallback(DataParallel): 43 | """ 44 | Data Parallel with a replication callback. 45 | An replication callback `__data_parallel_replicate__` of each module will be invoked after being created by 46 | original `replicate` function. 47 | The callback will be invoked with arguments `__data_parallel_replicate__(ctx, copy_id)` 48 | Examples: 49 | > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False) 50 | > sync_bn = DataParallelWithCallback(sync_bn, device_ids=[0, 1]) 51 | # sync_bn.__data_parallel_replicate__ will be invoked. 52 | """ 53 | 54 | def replicate(self, module, device_ids): 55 | modules = super(DataParallelWithCallback, self).replicate(module, device_ids) 56 | execute_replication_callbacks(modules) 57 | return modules 58 | 59 | 60 | def patch_replication_callback(data_parallel): 61 | """ 62 | Monkey-patch an existing `DataParallel` object. Add the replication callback. 63 | Useful when you have customized `DataParallel` implementation. 64 | Examples: 65 | > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False) 66 | > sync_bn = DataParallel(sync_bn, device_ids=[0, 1]) 67 | > patch_replication_callback(sync_bn) 68 | # this is equivalent to 69 | > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False) 70 | > sync_bn = DataParallelWithCallback(sync_bn, device_ids=[0, 1]) 71 | """ 72 | 73 | assert isinstance(data_parallel, DataParallel) 74 | 75 | old_replicate = data_parallel.replicate 76 | 77 | @functools.wraps(old_replicate) 78 | def new_replicate(module, device_ids): 79 | modules = old_replicate(module, device_ids) 80 | execute_replication_callbacks(modules) 81 | return modules 82 | 83 | data_parallel.replicate = new_replicate 84 | -------------------------------------------------------------------------------- /seg_lapa/networks/deeplab/sync_batchnorm/unittest.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # File : unittest.py 3 | # Author : Jiayuan Mao 4 | # Email : maojiayuan@gmail.com 5 | # Date : 27/01/2018 6 | # 7 | # This file is part of Synchronized-BatchNorm-PyTorch. 8 | # https://github.com/vacancy/Synchronized-BatchNorm-PyTorch 9 | # Distributed under MIT License. 
10 |
11 | import unittest
12 |
13 | import numpy as np
14 | from torch.autograd import Variable
15 |
16 |
17 | def as_numpy(v):
18 | if isinstance(v, Variable):
19 | v = v.data
20 | return v.cpu().numpy()
21 |
22 |
23 | class TorchTestCase(unittest.TestCase):
24 | def assertTensorClose(self, a, b, atol=1e-3, rtol=1e-3):
25 | npa, npb = as_numpy(a), as_numpy(b)
26 | self.assertTrue(
27 | np.allclose(npa, npb, atol=atol),
28 | "Tensor close check failed\n{}\n{}\nadiff={}, rdiff={}".format(
29 | a, b, np.abs(npa - npb).max(), np.abs((npa - npb) / np.fmax(npa, 1e-5)).max()
30 | ),
31 | )
32 | -------------------------------------------------------------------------------- /seg_lapa/train.py: --------------------------------------------------------------------------------
1 | from typing import Any, List
2 |
3 | import hydra
4 | import numpy as np
5 | import pytorch_lightning as pl
6 | import wandb
7 | from omegaconf import DictConfig, OmegaConf
8 | from pytorch_lightning import loggers as pl_loggers
9 |
10 | from seg_lapa import metrics
11 | from seg_lapa.callbacks.log_media import LogMediaQueue, Mode
12 | from seg_lapa.config_parse.train_conf import ParseConfig
13 | from seg_lapa.loss_func import CrossEntropy2D
14 | from seg_lapa.utils import utils
15 | from seg_lapa.utils.utils import is_rank_zero
16 |
17 |
18 | class DeeplabV3plus(pl.LightningModule, ParseConfig):
19 | def __init__(self, cfg: DictConfig, log_media_max_batches=1):
20 | super().__init__()
21 | self.save_hyperparameters()  # Will save the config to wandb too
22 | # Accessing cfg via hparams allows value to be loaded from checkpoints
23 | self.config = self.parse_config(self.hparams.cfg)
24 |
25 | self.cross_entropy_loss = CrossEntropy2D(loss_per_image=True, ignore_index=255)
26 | self.model = self.config.model.get_model()
27 |
28 | self.iou_train = metrics.Iou(num_classes=self.config.model.num_classes)
29 | self.iou_val = metrics.Iou(num_classes=self.config.model.num_classes)
30 | self.iou_test = metrics.Iou(num_classes=self.config.model.num_classes)
31 |
32 | # Logging media such as images using `self.log()` is extremely memory-expensive.
33 | # Save predictions to be logged within a circular queue, to be consumed in the LogMedia callback.
34 | self.log_media: LogMediaQueue = LogMediaQueue(log_media_max_batches)
35 |
36 | def forward(self, x):
37 | """In lightning, forward defines the prediction/inference actions.
38 | This method can be called elsewhere in the LightningModule with: `outputs = self(inputs)`.
39 | """
40 | outputs = self.model(x)
41 | return outputs
42 |
43 | def training_step(self, batch, batch_idx):
44 | """Defines the train loop. It is independent of forward().
45 | Don’t use any cuda or .to(device) calls in the code. PL will move the tensors to the correct device.
46 | """
47 | inputs, labels = batch
48 | outputs = self.model(inputs)
49 | predictions = outputs.argmax(dim=1)
50 |
51 | # Calculate Loss
52 | loss = self.cross_entropy_loss(outputs, labels)
53 |
54 | """Log the value on GPU0 per step. Also log average of all steps at epoch_end."""
55 | self.log("Train/loss", loss, on_step=True, on_epoch=True)
56 | """Log the avg. value across all GPUs per step. Also log average of all steps at epoch_end.
57 | Alternately, you can use the ops 'sum' or 'avg'.
58 | Using sync_dist is efficient. It adds extremely minor overhead for scalar values.
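For example, with 4 GPUs and sync_dist_op="avg", the logged scalar for a step is the mean of the
four per-GPU loss values, as in the commented-out call below.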
59 | """ 60 | # self.log("Train/loss", loss, on_step=True, on_epoch=True, sync_dist=True, sync_dist_op="avg") 61 | 62 | # Calculate Metrics 63 | self.iou_train(predictions, labels) 64 | 65 | # Returning images is expensive - All the batches are accumulated for _epoch_end(). 66 | # Save the latst predictions to be logged in an attr. They will be consumed by the LogMedia callback. 67 | self.log_media.append({"inputs": inputs, "labels": labels, "preds": predictions}, Mode.TRAIN) 68 | 69 | return {"loss": loss} 70 | 71 | def validation_step(self, batch, batch_idx): 72 | inputs, labels = batch 73 | outputs = self.model(inputs) 74 | predictions = outputs.argmax(dim=1) 75 | 76 | # Calculate Loss 77 | loss = self.cross_entropy_loss(outputs, labels) 78 | self.log("Val/loss", loss) 79 | 80 | # Calculate Metrics 81 | self.iou_val(predictions, labels) 82 | 83 | # Save the latest predictions to be logged 84 | self.log_media.append({"inputs": inputs, "labels": labels, "preds": predictions}, Mode.VAL) 85 | 86 | return {"val_loss": loss} 87 | 88 | def test_step(self, batch, batch_idx): 89 | inputs, labels = batch 90 | outputs = self.model(inputs) 91 | predictions = outputs.argmax(dim=1) 92 | 93 | # Calculate Loss 94 | loss = self.cross_entropy_loss(outputs, labels) 95 | self.log("Test/loss", loss) 96 | 97 | # Calculate Metrics 98 | self.iou_test(predictions, labels) 99 | 100 | # Save the latest predictions to be logged 101 | self.log_media.append({"inputs": inputs, "labels": labels, "preds": predictions}, Mode.TEST) 102 | 103 | return {"test_loss": loss} 104 | 105 | def training_epoch_end(self, outputs: List[Any]): 106 | # Compute and log metrics across epoch 107 | metrics_avg = self.iou_train.compute() 108 | self.log("Train/mIoU", metrics_avg.miou) 109 | self.iou_train.reset() 110 | 111 | def validation_epoch_end(self, outputs: List[Any]): 112 | # Compute and log metrics across epoch 113 | metrics_avg = self.iou_val.compute() 114 | self.log("Val/mIoU", metrics_avg.miou) 115 | self.iou_val.reset() 116 | 117 | def test_epoch_end(self, outputs: List[Any]): 118 | # Compute and log metrics across epoch 119 | metrics_avg = self.iou_test.compute() 120 | self.log("Test/mIoU", metrics_avg.miou) 121 | self.log("Test/Accuracy", metrics_avg.accuracy.mean()) 122 | self.log("Test/Precision", metrics_avg.precision.mean()) 123 | self.log("Test/Recall", metrics_avg.recall.mean()) 124 | 125 | # Save test results as a Table (WandB) 126 | self.log_results_table_wandb(metrics_avg) 127 | self.iou_test.reset() 128 | 129 | def log_results_table_wandb(self, metrics_avg: metrics.IouMetric): 130 | if not isinstance(self.logger, pl_loggers.WandbLogger): 131 | return 132 | 133 | results = metrics.IouMetric( 134 | iou_per_class=metrics_avg.iou_per_class.cpu().numpy(), 135 | miou=metrics_avg.miou.cpu().numpy(), 136 | accuracy=metrics_avg.accuracy.cpu().numpy().mean(), 137 | precision=metrics_avg.precision.cpu().numpy().mean(), 138 | recall=metrics_avg.recall.cpu().numpy().mean(), 139 | specificity=metrics_avg.specificity.cpu().numpy().mean(), 140 | ) 141 | 142 | data = np.stack( 143 | [results.miou, results.accuracy, results.precision, results.recall, results.specificity], axis=0 144 | ) 145 | data_l = [round(x.item(), 4) for x in data] 146 | table = wandb.Table(data=data_l, columns=["mIoU", "Accuracy", "Precision", "Recall", "Specificity"]) 147 | self.logger.experiment.log({f"Test/Results": table}, commit=False) 148 | 149 | data = np.stack((np.arange(results.iou_per_class.shape[0]), results.iou_per_class)).T 150 | table = 
wandb.Table(data=data.round(decimals=4).tolist(), columns=["Class ID", "IoU"]) 151 | self.logger.experiment.log({f"Test/IoU_per_class": table}, commit=False) 152 | 153 | def configure_optimizers(self): 154 | optimizer = self.config.optimizer.get_optimizer(self.parameters()) 155 | 156 | ret_opt = {"optimizer": optimizer} 157 | 158 | sch = self.config.scheduler.get_scheduler(optimizer) 159 | if sch is not None: 160 | scheduler = { 161 | "scheduler": sch, # The LR scheduler instance (required) 162 | "interval": "epoch", # The unit of the scheduler's step size 163 | "frequency": 1, # The frequency of the scheduler 164 | "reduce_on_plateau": False, # For ReduceLROnPlateau scheduler 165 | "monitor": "Val/mIoU", # Metric for ReduceLROnPlateau to monitor 166 | "strict": True, # Whether to crash the training if `monitor` is not found 167 | "name": None, # Custom name for LearningRateMonitor to use 168 | } 169 | 170 | ret_opt.update({"lr_scheduler": scheduler}) 171 | 172 | return ret_opt 173 | 174 | 175 | @hydra.main(config_path="config", config_name="train") 176 | def main(cfg: DictConfig): 177 | if is_rank_zero(): 178 | print("\nGiven Config:\n", OmegaConf.to_yaml(cfg)) 179 | 180 | config = ParseConfig.parse_config(cfg) 181 | if is_rank_zero(): 182 | print("\nResolved Dataclass:\n", config, "\n") 183 | 184 | utils.fix_seeds(config.random_seed) 185 | exp_dir = utils.generate_log_dir_path(config) 186 | 187 | wb_logger = config.logger.get_logger(cfg, config.logs_root_dir) 188 | callbacks = config.callbacks.get_callbacks_list(exp_dir, cfg) 189 | dm = config.dataset.get_datamodule() 190 | 191 | # Load weights 192 | if config.load_weights.path is None: 193 | model = DeeplabV3plus(cfg) 194 | else: 195 | model = DeeplabV3plus.load_from_checkpoint(config.load_weights.path, cfg=cfg) 196 | 197 | trainer = config.trainer.get_trainer(wb_logger, callbacks, config.logs_root_dir) 198 | 199 | # Run Training 200 | trainer.fit(model, datamodule=dm) 201 | 202 | # Run Testing 203 | result = trainer.test(ckpt_path=None) # Prints the final result 204 | 205 | wandb.finish() 206 | 207 | 208 | if __name__ == "__main__": 209 | main() 210 | -------------------------------------------------------------------------------- /seg_lapa/utils/path_check.py: -------------------------------------------------------------------------------- 1 | from enum import Enum 2 | from pathlib import Path 3 | from typing import Optional, Union 4 | 5 | 6 | class PathType(Enum): 7 | FILE = 0 8 | DIR = 1 9 | ANY = 2 10 | 11 | 12 | def get_project_root() -> Path: 13 | """Get the root dir of the project (one level above package). 14 | Used to get paths relative to the project, mainly for placing log files 15 | 16 | This function assumes that this module is at project_root/package_dir/utils_dir/module.py 17 | If this structure changes, the func must change. 18 | """ 19 | return Path(__file__).parent.parent.parent 20 | 21 | 22 | def get_path( 23 | input_path: Union[str, Path], 24 | must_exist: Optional[bool] = None, 25 | path_type: PathType = PathType.DIR, 26 | force_relative_to_project: bool = False, 27 | ) -> Path: 28 | """Converts a str to a pathlib Path with added checks 29 | 30 | Args: 31 | input_path: The path to be converted 32 | must_exist: If given, ensure that the path either exists or does not exist. 33 | - None: Don't check 34 | - True: It must exist. 35 | - False: It must not exist. 36 | path_type: Whether the path is to a dir or file. Only used if must_exist is not None. 
37 | force_relative_to_project: If a relative path is given, convert it to be relative to the project root dir.
38 | Required because the package can be called from anywhere, and the relative path will
39 | be relative to where it's called from, rather than relative to the script location.
40 |
41 | Useful to always place logs in the project root, regardless of where the package
42 | is called from.
43 | """
44 | if not isinstance(path_type, PathType):
45 | raise ValueError(f"Invalid path type '{path_type} (type={type(path_type)})'. Must be of type Enum {PathType}")
46 |
47 | input_path = Path(input_path)
48 |
49 | if not input_path.expanduser().is_absolute() and force_relative_to_project:
50 | # If the input path is relative, change it to become relative to the project root.
51 | proj_root = get_project_root()
52 | input_path = proj_root / input_path
53 |
54 | if must_exist is not None:
55 | if must_exist:
56 | # check that path exists
57 | if not input_path.exists():
58 | raise ValueError(f"Could not find {path_type.name.lower()}. Does not exist: {input_path}")
59 |
60 | # If required, check that it is the correct type (file vs dir)
61 | if path_type is not PathType.ANY:
62 | if not input_path.is_dir() and path_type == PathType.DIR:
63 | raise ValueError(f"Not a dir: {input_path}")
64 |
65 | if not input_path.is_file() and path_type == PathType.FILE:
66 | raise ValueError(f"Not a file: {input_path}")
67 | else:
68 | # Ensure path doesn't already exist
69 | if input_path.exists():
70 | raise ValueError(f"Path already exists: {input_path}")
71 |
72 | return input_path
73 | -------------------------------------------------------------------------------- /seg_lapa/utils/segmentation_label2rgb.py: --------------------------------------------------------------------------------
1 | import enum
2 |
3 | import numpy as np
4 | from PIL import Image
5 |
6 |
7 | class Palette(enum.Enum):
8 | LAPA = 0
9 |
10 |
11 | class LabelToRGB:
12 | # fmt: off
13 | palette_lapa = [
14 | # Color Palette for the LaPa dataset, which has 11 classes
15 | 0, 0, 0, # 0 background
16 | 0, 153, 255, # 1 skin
17 | 102, 255, 153, # 2 left eyebrow
18 | 0, 204, 153, # 3 right eyebrow
19 | 255, 255, 102, # 4 left eye
20 | 255, 255, 204, # 5 right eye
21 | 255, 153, 0, # 6 nose
22 | 255, 102, 255, # 7 upper lip
23 | 102, 0, 51, # 8 inner mouth
24 | 255, 204, 255, # 9 lower lip
25 | 255, 0, 10, # 10 hair
26 | ]
27 | # fmt: on
28 |
29 | def __init__(self):
30 | """Generates a color map with a unique hue for each class in label.
31 | The hues are uniformly sampled from the range [0, 1). If the num of classes is too high, then the difference
32 | between neighboring hues will become indistinguishable"""
33 | self.color_palettes = {Palette.LAPA: self.palette_lapa}
34 |
35 | def map_color_palette(self, label: np.ndarray, palette: Palette) -> np.ndarray:
36 | """Generates an RGB visualization of a label by applying a color palette.
37 | The label should contain a uint class index per pixel.
38 |
39 | Args:
40 |
41 | label (numpy.ndarray): Each pixel has a uint value corresponding to its class index
42 | Shape: (H, W), dtype: np.uint8
43 | palette (Palette): Which color palette to use.
44 |
45 | Returns:
46 | numpy.ndarray: RGB image, with each class mapped to a unique color.
47 | Shape: (H, W, 3), dtype: np.uint8
48 | """
49 | if len(label.shape) != 2:
50 | raise ValueError(f"Label must have shape: (H, W). Input: {label.shape}")
51 | if not (label.dtype == np.uint8):
52 | raise ValueError(f"Label must have dtype np.uint8. Input: {label.dtype}")
Input: {label.dtype}") 53 | if not isinstance(palette, Palette): 54 | raise ValueError(f"palette must be of type {Palette}. Input: {palette}") 55 | 56 | color_palette = self.color_palettes[palette] 57 | 58 | # Check that the pallete has enough colors 59 | if len(color_palette) < label.max(): 60 | raise ValueError( 61 | f"The chosen color palette has only {len(color_palette)} values. It does not have" 62 | f" enough unique colors to represent all the values in the label ({label.max()})" 63 | ) 64 | 65 | # Map grayscale image's pixel values to RGB color palette 66 | _im = Image.fromarray(label) 67 | _im.putpalette(color_palette) 68 | _im = _im.convert(mode="RGB") 69 | im = np.asarray(_im) 70 | 71 | return im 72 | 73 | def colorize_batch_numpy(self, batch_label: np.ndarray) -> np.ndarray: 74 | """Convert a batch of numpy labels to RGB 75 | 76 | Args: 77 | batch_label (numpy.ndarray): Shape: [N, H, W], dtype=np.uint8 78 | 79 | Returns: 80 | numpy.ndarray: Colorize labels. Shape: [N, H, W, 3], dtype=np.uint8 81 | """ 82 | if not isinstance(batch_label, np.ndarray): 83 | raise TypeError(f"`batch_label` expected to be Numpy array. Got: {type(batch_label)}") 84 | if len(batch_label.shape) != 3: 85 | raise ValueError(f"`batch_label` expected shape [N, H, W]. Got: {batch_label.shape}") 86 | if batch_label.dtype != np.uint8: 87 | raise ValueError(f"`batch_label` must be of dtype np.uint8. Got: {batch_label.dtype}") 88 | 89 | batch_label_rgb = [self.map_color_palette(label, Palette.LAPA) for label in batch_label] 90 | batch_label_rgb = np.stack(batch_label_rgb, axis=0) 91 | return batch_label_rgb 92 | 93 | 94 | if __name__ == "__main__": 95 | # Create dummy mask 96 | dummy_mask = np.zeros((512, 512), dtype=np.uint8) 97 | num_classes = 11 98 | cols = dummy_mask.shape[1] 99 | for idx, class_id in enumerate(range(num_classes)): 100 | cols_slice = cols // num_classes 101 | dummy_mask[:, cols_slice * idx : cols_slice * (idx + 1)] = class_id 102 | 103 | # Colorize mask 104 | label2rgb = LabelToRGB() 105 | colorized_mask = label2rgb.map_color_palette(dummy_mask, Palette.LAPA) 106 | 107 | # View mask 108 | img = Image.fromarray(colorized_mask) 109 | img.show() 110 | -------------------------------------------------------------------------------- /seg_lapa/utils/utils.py: -------------------------------------------------------------------------------- 1 | import datetime 2 | import os 3 | from pathlib import Path 4 | from typing import Optional 5 | 6 | import pytorch_lightning as pl 7 | 8 | from seg_lapa.config_parse.logger_conf import DisabledLoggerConf, WandbConf 9 | from seg_lapa.config_parse.train_conf import TrainConf 10 | from seg_lapa.utils import path_check 11 | 12 | LOGS_DIR = "log-media" 13 | 14 | 15 | def is_rank_zero(): 16 | local_rank = int(os.environ.get("LOCAL_RANK", 0)) 17 | node_rank = int(os.environ.get("NODE_RANK", 0)) 18 | if local_rank == 0 and node_rank == 0: 19 | return True 20 | 21 | return False 22 | 23 | 24 | def generate_log_dir_path(config: TrainConf) -> Path: 25 | """Generate the path to the log dir for this run. 26 | The directory structure for logs depends on the logger used. 27 | 28 | wandb - Each run's log dir's name will contain the wandb runid for easy identification 29 | 30 | Args: 31 | config: The config dataclass. 
32 | """ 33 | logs_root_dir = path_check.get_path(config.logs_root_dir, force_relative_to_project=True) 34 | 35 | # Exp directory structure would depend on the logger used 36 | logs_root_dir = logs_root_dir / LOGS_DIR / f"{config.logger.name}-logger" 37 | timestamp = datetime.datetime.now().strftime("%Y-%m-%d_%H-%M-%S") 38 | if isinstance(config.logger, DisabledLoggerConf): 39 | exp_dir = logs_root_dir / f"{timestamp}" 40 | elif isinstance(config.logger, WandbConf): 41 | run_id = config.logger.get_run_id() 42 | exp_dir = logs_root_dir / f"{timestamp}_{run_id}" 43 | else: 44 | raise NotImplementedError(f"Generating log dir not implemented for logger: {config.logger}") 45 | 46 | return exp_dir 47 | 48 | 49 | def fix_seeds(random_seed: Optional[int]) -> None: 50 | """Fix seeds for reproducibility. 51 | Ref: 52 | https://pytorch.org/docs/stable/notes/randomness.html 53 | 54 | Args: 55 | random_seed: If None, seeds not set. If int, uses value to seed. 56 | """ 57 | if random_seed is not None: 58 | pl.seed_everything(random_seed) 59 | -------------------------------------------------------------------------------- /setup.cfg: -------------------------------------------------------------------------------- 1 | [tool:pytest] 2 | norecursedirs = 3 | .git 4 | dist 5 | build 6 | addopts = 7 | --strict 8 | --doctest-modules 9 | --durations=0 10 | 11 | [coverage:report] 12 | exclude_lines = 13 | pragma: no-cover 14 | pass 15 | 16 | [flake8] 17 | max-line-length = 120 18 | exclude = .tox,*.egg,build,temp 19 | select = E,W,F 20 | doctests = True 21 | verbose = 2 22 | # https://pep8.readthedocs.io/en/latest/intro.html#error-codes 23 | format = pylint 24 | # see: https://www.flake8rules.com/ 25 | ignore = 26 | E731 # Do not assign a lambda expression, use a def 27 | W504 # Line break occurred after a binary operator 28 | F401 # Module imported but unused 29 | F841 # Local variable name is assigned to but never used 30 | W605 # Invalid escape sequence 'x' 31 | 32 | # setup.cfg or tox.ini 33 | [check-manifest] 34 | ignore = 35 | *.yml 36 | .github 37 | .github/* 38 | 39 | [metadata] 40 | license_file = LICENSE 41 | description-file = README.md 42 | # long_description = file:README.md 43 | # long_description_content_type = text/markdown 44 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | from setuptools import setup, find_packages 4 | 5 | setup( 6 | name="seg_lapa", 7 | version="0.2.0", 8 | description="Semantic Segmentation on the LaPa dataset using Pytorch Lightning", 9 | author="john doe", 10 | author_email="example@gmail.com", 11 | # REPLACE WITH YOUR OWN GITHUB PROJECT LINK 12 | url="https://github.com/Shreeyak/pytorch-lightning-segmentation-template", 13 | python_requires=">=3.7.7", 14 | install_requires=[ 15 | # Install torch first, depending on cuda version 16 | # "torch==1.7.1", 17 | # "torchvision==0.8.2", 18 | "pytorch-lightning==1.1.2", 19 | "gdown==3.12.2", 20 | "albumentations==0.5.2", 21 | "opencv-python==4.4.0.44", 22 | "hydra-core==1.0.4", 23 | "wandb==0.10.12", 24 | "pydantic==1.7.3", 25 | ], 26 | packages=find_packages(), 27 | ) 28 | --------------------------------------------------------------------------------