├── .autoenv.zsh ├── .github └── workflows │ └── ci-testing.yml ├── .gitignore ├── LICENSE ├── README.md ├── configs ├── data │ ├── cifar.yaml │ └── mnist.yaml ├── defaults.yaml ├── model │ ├── autoencoder.yaml │ └── classifier.yaml └── optim │ ├── adam.yaml │ └── sgd.yaml ├── environment.yaml ├── main.py ├── project ├── __init__.py ├── data │ ├── __init__.py │ ├── cifar.py │ └── mnist.py └── model │ ├── __init__.py │ ├── autoencoder.py │ ├── classifier.py │ └── lit_image_classifier.py ├── setup.cfg ├── setup.py └── tests ├── __init__.py ├── requirements.txt └── test_classifier.py /.autoenv.zsh: -------------------------------------------------------------------------------- 1 | # Install https://github.com/Tarrasch/zsh-autoenv 2 | # to automatically execute these command when cd into this project 3 | autostash HYDRA_FULL_ERROR=1 4 | conda activate project 5 | -------------------------------------------------------------------------------- /.github/workflows/ci-testing.yml: -------------------------------------------------------------------------------- 1 | name: CI testing 2 | 3 | # see: https://help.github.com/en/actions/reference/events-that-trigger-workflows 4 | on: 5 | # Trigger the workflow on push or pull request, but only for the master branch 6 | push: 7 | branches: 8 | - master 9 | pull_request: 10 | branches: 11 | - master 12 | 13 | jobs: 14 | pytest: 15 | 16 | runs-on: ${{ matrix.os }} 17 | strategy: 18 | fail-fast: false 19 | matrix: 20 | os: [ubuntu-20.04, macOS-10.15, windows-2019] 21 | python-version: [3.7] 22 | 23 | # Timeout: https://stackoverflow.com/a/59076067/4521646 24 | timeout-minutes: 35 25 | 26 | steps: 27 | - uses: actions/checkout@v2 28 | - name: Set up Python ${{ matrix.python-version }} 29 | uses: actions/setup-python@v2 30 | with: 31 | python-version: ${{ matrix.python-version }} 32 | 33 | # Github Actions: Run step on specific OS: https://stackoverflow.com/a/57948488/4521646 34 | - name: Setup macOS 35 | if: runner.os == 'macOS' 36 | 
run: | 37 | brew install libomp # https://github.com/pytorch/pytorch/issues/20030 38 | 39 | # Note: This uses an internal pip API and may not always work 40 | # https://github.com/actions/cache/blob/master/examples.md#multiple-oss-in-a-workflow 41 | - name: Get pip cache 42 | id: pip-cache 43 | run: | 44 | python -c "from pip._internal.locations import USER_CACHE_DIR; print('::set-output name=dir::' + USER_CACHE_DIR)" 45 | 46 | - name: Cache pip 47 | uses: actions/cache@v2 48 | with: 49 | path: ${{ steps.pip-cache.outputs.dir }} 50 | key: ${{ runner.os }}-py${{ matrix.python-version }}-${{ hashFiles('requirements.txt') }} 51 | restore-keys: | 52 | ${{ runner.os }}-py${{ matrix.python-version }}- 53 | 54 | - name: Install dependencies 55 | run: | 56 | pip install --requirement requirements.txt --upgrade --quiet --find-links https://download.pytorch.org/whl/cpu/torch_stable.html --use-feature=2020-resolver 57 | pip install --requirement tests/requirements.txt --quiet --use-feature=2020-resolver 58 | python --version 59 | pip --version 60 | pip list 61 | shell: bash 62 | 63 | - name: Tests 64 | run: | 65 | coverage run --source project -m py.test project tests -v --junitxml=junit/test-results-${{ runner.os }}-${{ matrix.python-version }}.xml 66 | 67 | - name: Statistics 68 | if: success() 69 | run: | 70 | coverage report 71 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | *.egg-info/ 24 | .installed.cfg 25 | *.egg 26 | MANIFEST 27 | 28 | # PyInstaller 29 | # Usually these files are written by a python script from a template 
30 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 31 | *.manifest 32 | *.spec 33 | 34 | # Installer logs 35 | pip-log.txt 36 | pip-delete-this-directory.txt 37 | 38 | # Unit test / coverage reports 39 | htmlcov/ 40 | .tox/ 41 | .coverage 42 | .coverage.* 43 | .cache 44 | nosetests.xml 45 | coverage.xml 46 | *.cover 47 | .hypothesis/ 48 | .pytest_cache/ 49 | 50 | # Translations 51 | *.mo 52 | *.pot 53 | 54 | # Django stuff: 55 | *.log 56 | local_settings.py 57 | db.sqlite3 58 | 59 | # Flask stuff: 60 | instance/ 61 | .webassets-cache 62 | 63 | # Scrapy stuff: 64 | .scrapy 65 | 66 | # Sphinx documentation 67 | docs/_build/ 68 | 69 | # PyBuilder 70 | target/ 71 | 72 | # Jupyter Notebook 73 | .ipynb_checkpoints 74 | 75 | # pyenv 76 | .python-version 77 | 78 | # celery beat schedule file 79 | celerybeat-schedule 80 | 81 | # SageMath parsed files 82 | *.sage.py 83 | 84 | # Environments 85 | .env 86 | .venv 87 | env/ 88 | venv/ 89 | ENV/ 90 | env.bak/ 91 | venv.bak/ 92 | 93 | # Spyder project settings 94 | .spyderproject 95 | .spyproject 96 | 97 | # Rope project settings 98 | .ropeproject 99 | 100 | # mkdocs documentation 101 | /site 102 | 103 | # mypy 104 | .mypy_cache/ 105 | 106 | # IDEs 107 | .idea 108 | .vscode 109 | 110 | # Misc 111 | .DS_Store 112 | .tags 113 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 
14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 
47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. 
Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 
122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. 
In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. 
We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Deep Learning project template 2 | Use this template to rapidly bootstrap a DL project: 3 | 4 | - Write code in [Pytorch Lightning](https://www.pytorchlightning.ai/)'s `LightningModule` and `LightningDataModule`. 5 | - Run code from composable `yaml` configurations with [Hydra](https://hydra.cc/). 6 | - Manage packages in `environment.yaml` with [conda](https://docs.conda.io/projects/conda/en/latest/glossary.html#miniconda-glossary). 7 | - Log and visualize metrics + hyperparameters with [Tensorboard](https://tensorboard.dev/). 8 | - Sane defaults with best/good practices only where it makes sense for a small-scale research-style project. 9 | 10 | Have an issue, found a bug, know a better practice? Feel free to open an issue, pull request or discussion thread. All contributions are welcome. 11 | 12 | I hope to maintain this repo with better deep learning engineering practices as they evolve. 13 | 14 | ## Quick start 15 | 16 |
Click to expand/collapse 17 |

18 | 19 | ### 0. Clone this template 20 | ```bash 21 | # clone project or create a new one from GitHub's template 22 | git clone https://github.com/lkhphuc/lightning-hydra-template new-project 23 | cd new-project 24 | rm -rf .git 25 | git init # Start of a new git history 26 | ``` 27 | 28 | ### 1. Add project's info 29 | - Edit [`setup.py`](setup.py) and add relevant information. 30 | - Rename the directory `project/` to the your project name. 31 | 32 | ### 2. Create environment and install dependencies 33 | - Name your environment and add packages in [`environment.yaml`](environment.yaml), then create/update the environment with: 34 | ```bash 35 | # Run this command every time you update environment.yaml 36 | conda env update -f environment.yaml 37 | ``` 38 | 39 | ### 3. Create Pytorch Lightning modules 40 | - `LightningModule`s are organized under [`project/model/`](project/model/). 41 | - `LightningDataModule`s are organized under [`project/data/`](project/data/). 42 | 43 | Each Lightning module should be in one separate file, while each file can contain all the relevant `nn.Module`s for that model. 44 | 45 | ### 4. Create Hydra configs 46 | Each `.py` file has its own corresponding `.yaml` file, such as `project/model/autoencoder.py` and `configs/model/autoencoder.yaml`. 47 | 48 | All `yaml` files are stored under `configs/` and the structure of this folder should be identical to the structure of the `project/`. 49 | ```bash 50 | $ tree project $ tree configs 51 | project configs 52 | ├── __init__.py ├── defaults.yaml 53 | ├── data ├── data 54 | │   ├── cifar.py │   ├── cifar.yaml 55 | │   └── mnist.py │   └── mnist.yaml 56 | └── model ├── model 57 | ├── autoencoder.py │   ├── autoencoder.yaml 58 | ├── classifier.py │   └── classifier.yaml 59 | └── optim 60 | ├── adam.yaml 61 | └── sgd.yaml 62 | ``` 63 | [`configs/defaults.yaml`](configs/defaults.yaml) contains all the defaults modules and arguments, including that for the `Trainer()`. 64 | 65 | 66 | ### 5. 
Run 67 | ```bash 68 | # This will run with all the default arguments 69 | python main.py 70 | # Override defaults from command line 71 | python main.py model=autoencoder data=cifar trainer.gpus=8 72 | ``` 73 |

74 |
75 | 76 | ## How it works 77 | This section will provide a brief introduction on how these components all come together. 78 | Please refer to the original documents of [Pytorch Lightning](pytorchlightning.ai/), [Hydra](hydra.cc/) and [TensorBoard](tensorboard.dev) for details. 79 | 80 |
Click to expand/collapse 81 |

82 | 83 | ### Entry points 84 | The launching point of the project is [`main.py`](main.py) located in the root directory. 85 | The `main()` function takes in a `DictConfig` object, which is prepared by `hydra` based on the `yaml` files and command line arguments provided at runtime. 86 | 87 | This is achieved by decorating the script `main()` function with `hydra.main()`, which requires a path to all the configs and a default `.yaml` file as follows: 88 | ```python 89 | @hydra.main(config_path="configs", config_name="defaults") 90 | def main(cfg: DictConfig) -> None: ... 91 | ``` 92 | This allows us to define multiple entry points for different functionalities with different defaults, such as `train.py`, `ensemble.py`, `test.py`, etc. 93 | 94 | 95 | ### Dynamically instantiate modules 96 | We will [use Hydra to instantiate objects](https://hydra.cc/docs/patterns/instantiate_objects/overview). 97 | This allows us to use the same entry point (`main.py` above) to dynamically combine different models and data modules. 98 | Given a [`configs/defaults.yaml`](configs/defaults.yaml) file that contains: 99 | ```yaml 100 | defaults: 101 | - data: mnist # Path to sub-config, can also omit the .yaml extension 102 | - model: classifier.yaml # full path for ease of navigation (e.g vim cursor in path, press gf) 103 | ``` 104 | 105 | Different modules can be instantiated for each run by supplying a different set of configurations: 106 | ```bash 107 | # Using default 108 | $ python main.py 109 | 110 | # The default is equivalent to 111 | $ python main.py model=classifier data=mnist 112 | 113 | # Override a default module 114 | $ python main.py model=autoencoder 115 | $ python main.py data=cifar 116 | 117 | # Override multiple default modules and arguments 118 | $ python main.py model=autoencoder data=cifar trainer.gpus=4 119 | ``` 120 | 121 | In python, the module will be instantiated by a line, for example `data_module = hydra.utils.instantiate(cfg.data)`. 
122 | 123 | `cfg.data` is a `DictConfig` object created by `hydra` at runtime, and is stored in a config file, for example [`configs/data/mnist.yaml`](configs/data/mnist.yaml): 124 | ```yaml 125 | name: mnist 126 | 127 | # _target_ class to instantiate 128 | _target_: project.data.MNISTDataModule 129 | # Argument to feed into __init__() of target module 130 | data_dir: ~/datasets/MNIST/ # Use absolute path 131 | batch_size: 4 132 | num_workers: 2 133 | 134 | # Can also define arbitrary info specific to this module 135 | input_dim: 784 136 | output_dim: 10 137 | ``` 138 | and the _target_: `project.data.MNISTDataModule` to be instantiated is: 139 | ```python 140 | class MNISTDataModule(pl.LightningDataModule): 141 | def __init__(self, data_dir: str = "", 142 | batch_size: int = 32, 143 | num_workers: int = 8, 144 | **kwargs): ... 145 | # kwargs is used to handle arguments in the DictConfig but not used for init 146 | ``` 147 | 148 | ### Directory management 149 | Since `hydra` manages our entry point and command line arguments, it also manages the output directory of each run. 150 | We can easily customize the output directory to suit our project via [`defaults.yaml`](configs/defaults.yaml) 151 | ```yaml 152 | hydra: 153 | run: 154 | # Configure output dir of each experiment programmatically from the arguments 155 | # Example "outputs/mnist/classifier/baseline/2021-03-10-141516" 156 | dir: outputs/${data.name}/${model.name}/${experiment}/${now:%Y-%m-%d_%H%M%S} 157 | ``` 158 | and tell `TensorBoardLogger()` to use the current working directory without adding anything: 159 | ```python 160 | tensorboard = pl.loggers.TensorBoardLogger(".", "", "") 161 | ``` 162 | 163 |

164 |
165 | 166 | ## Best practices 167 | 168 |
Click to expand/collapse 169 |

170 | 171 | ### `LightningModule` and `LightningDataModule` 172 | #### Be explicit about input arguments 173 | Each module should be self-contained and self-explanatory, to maximize reusability, even across projects. 174 | - **Don't** do this: 175 | ```python 176 | class LitAutoEncoder(pl.LightningModule): 177 | def __init__(self, cfg, **kwargs): 178 | super().__init__() 179 | self.cfg = cfg 180 | ``` 181 | You will not like having to track down the config file every time just to remember what the input arguments, their types and default values are. 182 | 183 | - Do this instead: 184 | ```python 185 | class LitAutoEncoder(pl.LightningModule): 186 | def __init__(self, 187 | input_dim: int, output_dim: int, hidden_dim: int = 64, 188 | optim_encoder=None, optim_decoder=None, 189 | **kwargs): 190 | super().__init__() 191 | self.save_hyperparameters() 192 | # Later all input arguments can be accessed anywhere by 193 | self.hparams.input_dim 194 | # Use this to avoid boilerplate code such as 195 | self.input_dim = input_dim 196 | self.output_dim = output_dim 197 | ``` 198 | 199 | 200 | Also see Pytorch Lightning's [official style guide](https://pytorch-lightning.readthedocs.io/en/latest/starter/style_guide.html). 201 | 202 | ### Tensorboard 203 | - Use forward slash `/` in naming metrics to group them together. 
204 | - Don't: `loss_val`, `loss_train` 205 | - Do: `loss/val`, `loss/train` 206 | - Group metrics by type, not on what data it was evaluated with: 207 | - Don't: `val/loss`, `val/accuracy`, `train/loss`, `train/acc` 208 | - Do: `loss/val`, `loss/train`, `accuracy/val`, `accuracy/train` 209 | ![Metric grouping](https://pytorch.org/docs/stable/_images/hier_tags.png) 210 | - Log computation graph of `LightningModule` by: 211 | - Define `self.example_input_array` in your module's `__init__()` 212 | - Enable in TensorBoard with `TensorBoard(log_graph=True)` 213 | ![Compute Graph](https://raw.githubusercontent.com/tensorflow/tensorboard/master/docs/images/graphs_conceptual.png) 214 | - [Proper logging](https://pytorch-lightning.readthedocs.io/en/latest/extensions/logging.html#logging-hyperparameters) of hyper-parameters and metrics 215 | ![Tensorboard Parallel Coordinate](https://www.tensorflow.org/tensorboard/images/hparams_parallel_coordinates.png) 216 | 217 | 218 | ### Hydra 219 | 220 | #### Script is for one run, launcher is for multiple runs 221 | Hydra serves two intertwined purposes, configuration management and script launcher. 222 | These two purposes are dealt with jointly because each run can potentially have a different set of configs. 223 | 224 | This provides a nice separation of concerns, in which the python scripts only focus on the functionalities of an individual run, while the `hydra` command line will orchestrate multiple runs. 225 | With this separation, it's easy to use Hydra's [sweeper](https://hydra.cc/docs/plugins/ax_sweeper) to do hyperparameters search, or [launcher](https://hydra.cc/docs/plugins/submitit_launcher) to run experiments on SLURM cluster or cloud. 226 | 227 | #### Provide absolute path in config 228 | To provide a path into the program, it's best to provide an absolute path for both local and cloud storage (start with `~`, `/`, `s3://`). 
229 | 230 | That way you don't have to litter your code with `hydra.utils.get_original_cwd()` to convert relative paths, and therefore retain the flexibility to use your module outside of `hydra`-managed entry points. 231 | 232 | #### Naming experiments 233 | Use `hydra` to create a hierarchical structure for experiment outputs based on configurations of each run, by setting the `configs/defaults.yaml` with 234 | ``` 235 | dir: outputs/${data.name}/${model.name}/${experiment}/${now:%Y-%m-%d_%H%M%S} 236 | ``` 237 | 238 | - `${data.name}/${model.name}` will be dynamically determined from the config object. They are preferably nested by the order of least frequently changed. 239 | - `${experiment}` is a string that briefly describes the purpose of the experiment 240 | - `${now:%Y-%m-%d_%H%M%S}` will insert the time of the run, serving as a unique identifier for runs that differ only in minor hyperparameters such as learning rate. 241 | 242 | Example output:`outputs/mnist/classifier/baseline/2021-03-10-141516`. 243 | 244 | 245 |

246 |
247 | 248 | ## Tips and tricks 249 | 250 |
Click to expand/collapse 251 |

252 | 253 | ### Debug 254 | 255 | - Drop into a debugger anywhere in your code with a single line `import pdb; pdb.set_trace()`. 256 | - Use `ipdb` or [pudb](github.com/inducer/pudb) for a nicer debugging experience, for example `import pudb; pudb.set_trace()` 257 | - Or just use `breakpoint()` for Python 3.7 or above. Set the `PYTHONBREAKPOINT` environment variable to make `breakpoint()` use `ipdb` or `pudb`, for example `PYTHONBREAKPOINT=pudb.set_trace`. 258 | - Post-mortem debugging by running the script with `ipython --pdb`. It opens a debugger and drops you right in when and where an Exception is raised. 259 | ```bash 260 | $ ipython --pdb main.py -- model=autoencoder 261 | ``` 262 | This is super helpful to inspect the variable values when it fails, without having to put a breakpoint and then run the script again, which can take a long time to start for a deep learning model. 263 | - Use `fast_dev_run` of PytorchLightning, and check out the entire [debugging tutorial](https://pytorch-lightning.readthedocs.io/en/stable/common/debugging.html). 264 | 265 | ### Colored Logs 266 | 267 | It's 2021 already, don't squint at your 4K HDR Quantum dot monitor to find a line from the black & white log. 268 | `pip install hydra-colorlog` and edit `defaults.yaml` to colorize your log file: 269 | ```yaml 270 | defaults: 271 | - override hydra/job_logging: colorlog 272 | - override hydra/hydra_logging: colorlog 273 | ``` 274 | This will colorize any python logger you created anywhere with: 275 | ```python 276 | import logging 277 | logger = logging.getLogger(__name__) 278 | logger.info("My log") 279 | ``` 280 | 281 | Alternative: [loguru](https://github.com/Delgan/loguru), [coloredlogs](https://github.com/xolox/python-coloredlogs). 282 | 283 | ### Auto activate conda environment and export variables 284 | 285 | [Zsh-autoenv](https://github.com/Tarrasch/zsh-autoenv) will auto source the content of `.autoenv.zsh` when you `cd` into a folder that contains that file. 
286 | Say goodbye to activate conda or export a bunch of variables for every new terminal: 287 | ```bash 288 | conda activate project 289 | HYDRA_FULL_ERROR=1 290 | PYTHON_BREAKPOINT=pudb.set_trace 291 | ``` 292 | 293 | Alternative: https://github.com/direnv/direnv, https://github.com/cxreg/smartcd, https://github.com/kennethreitz/autoenv 294 | 295 |

296 |
297 | 298 | 299 | ## TODO 300 | - [ ] Pre-commit hook for python `black`, `isort`. 301 | - [ ] [Experiments](https://hydra.cc/docs/next/patterns/configuring_experiments) 302 | - [ ] Configure trainer's callbacks from configs as well. 303 | - [ ] [Structured Configs](https://hydra.cc/docs/next/tutorials/structured_config/intro/#internaldocs-banner) 304 | - [ ] [Hydra Torch](https://github.com/pytorch/hydra-torch) and [Hydra Lightning](https://github.com/romesco/hydra-lightning) 305 | - [ ] [Keepsake](https://keepsake.ai/) version control 306 | - [ ] (Maybe) Unit test (only where it makes sense). 307 | 308 | 309 | # DELETE EVERYTHING ABOVE FOR YOUR PROJECT 310 | 311 | --- 312 | 313 |
314 | 315 | # ConSelfSTransDRLIB: 316 | ## Contrastive Self-supervised Transformers for Disentangled Representation Learning with Inductive Biases is All you need, and where to find them. 317 | 318 | [![Paper](http://img.shields.io/badge/paper-arxiv.1001.2234-B31B1B.svg)](https://www.nature.com/articles/nature14539) 319 | [![Conference](http://img.shields.io/badge/NeurIPS-2019-4b44ce.svg)](https://papers.nips.cc/book/advances-in-neural-information-processing-systems-31-2018) 320 | [![Conference](http://img.shields.io/badge/ICLR-2019-4b44ce.svg)](https://papers.nips.cc/book/advances-in-neural-information-processing-systems-31-2018) 321 | [![Conference](http://img.shields.io/badge/AnyConference-year-4b44ce.svg)](https://papers.nips.cc/book/advances-in-neural-information-processing-systems-31-2018) 322 | ![CI testing](https://github.com/PyTorchLightning/deep-learning-project-template/workflows/CI%20testing/badge.svg?branch=master&event=push) 323 | 324 |
325 | 326 | ## Description 327 | Code for paper paper. 328 | 329 | ## How to run 330 | ```bash 331 | python main.py 332 | ``` 333 | 334 | 335 | ### Citation 336 | ``` 337 | @article{YourName, 338 | title={Your Title}, 339 | author={Your team}, 340 | journal={Location}, 341 | year={Year} 342 | } 343 | ``` 344 | -------------------------------------------------------------------------------- /configs/data/cifar.yaml: -------------------------------------------------------------------------------- 1 | name: cifar10 2 | 3 | _target_: project.data.CIFARDataModule 4 | # Use absolute path for input to avoid messing with hydra output paths 5 | data_dir: ~/datasets/CIFAR/ 6 | batch_size: 4 7 | num_workers: 2 8 | 9 | input_dim: 3072 # 3*32*32 10 | output_dim: 10 11 | -------------------------------------------------------------------------------- /configs/data/mnist.yaml: -------------------------------------------------------------------------------- 1 | name: mnist 2 | 3 | _target_: project.data.MNISTDataModule 4 | data_dir: ~/datasets/MNIST/ # Use absolute path 5 | batch_size: 4 6 | num_workers: 2 7 | 8 | input_dim: 784 9 | output_dim: 10 10 | -------------------------------------------------------------------------------- /configs/defaults.yaml: -------------------------------------------------------------------------------- 1 | hydra: 2 | run: 3 | # Configure output dir of each experiment programmatically from the arguments 4 | # Example "outputs/mnist/classifier/baseline/2021-03-10-141516" 5 | dir: outputs/${data.name}/${model.name}/${experiment}/${now:%Y-%m-%d_%H%M%S} 6 | 7 | # Global configurations shared between different modules 8 | experiment: baseline 9 | 10 | # Composing nested config with default 11 | defaults: 12 | - data: mnist # Path to sub-config, can also omit the .yaml extension 13 | - model: classifier.yaml # I add full path for easy navigation in vim (cursor in path, press gf) 14 | 15 | - override hydra/job_logging: colorlog 16 | - override 
hydra/hydra_logging: colorlog 17 | 18 | # Pytorch lightning trainer's argument 19 | # default flags are commented to avoid clustering the hyperparameters 20 | trainer: 21 | # accelerator: None 22 | # accumulate_grad_batches: 1 23 | # amp_backend: native 24 | # amp_level: O2 25 | # auto_lr_find: False 26 | # auto_scale_batch_size: False 27 | # auto_select_gpus: False 28 | benchmark: True 29 | # check_val_every_n_epoch: 1 30 | # checkpoint_callback: True 31 | # default_root_dir: 32 | # deterministic: False 33 | # fast_dev_run: False 34 | # flush_logs_every_n_steps: 100 35 | # gpus: 36 | # gradient_clip_val: 0 37 | # limit_predict_batches: 1.0 38 | # limit_test_batches: 1.0 39 | # limit_train_batches: 1.0 40 | # limit_val_batches: 1.0 41 | # log_every_n_steps: 50 42 | # log_gpu_memory: False 43 | # logger: True 44 | # max_epochs: None 45 | # max_steps: None 46 | # min_epochs: None 47 | # min_steps: None 48 | # move_metrics_to_cpu: False 49 | # multiple_trainloader_mode: max_size_cycle 50 | # num_nodes: 1 51 | # num_processes: 1 52 | # num_sanity_val_steps: 2 53 | # overfit_batches: 0.0 54 | # plugins: None 55 | # precision: 16 56 | # prepare_data_per_node: True 57 | # process_position: 0 58 | # profiler: None 59 | # progress_bar_refresh_rate: None 60 | # reload_dataloaders_every_epoch: False 61 | # replace_sampler_ddp: True 62 | # resume_from_checkpoint: None 63 | # stochastic_weight_avg: False 64 | # sync_batchnorm: False 65 | terminate_on_nan: True 66 | # track_grad_norm: -1 67 | # truncated_bptt_steps: None 68 | # val_check_interval: 1.0 69 | # weights_save_path: None 70 | # weights_summary: top 71 | -------------------------------------------------------------------------------- /configs/model/autoencoder.yaml: -------------------------------------------------------------------------------- 1 | name: LitAutoEncoder 2 | 3 | _target_: project.model.autoencoder.LitAutoEncoder 4 | hidden_dim: 128 5 | 6 | defaults: 7 | - /optim/sgd.yaml@optim_encoder 8 | - 
/optim/adam.yaml@optim_decoder 9 | -------------------------------------------------------------------------------- /configs/model/classifier.yaml: -------------------------------------------------------------------------------- 1 | name: LitClassifier 2 | 3 | _target_: project.model.classifier.LitClassifier 4 | hidden_dim: 128 5 | 6 | defaults: 7 | - /optim/adam 8 | 9 | -------------------------------------------------------------------------------- /configs/optim/adam.yaml: -------------------------------------------------------------------------------- 1 | _target_: torch.optim.Adam 2 | lr: 1e-3 3 | 4 | -------------------------------------------------------------------------------- /configs/optim/sgd.yaml: -------------------------------------------------------------------------------- 1 | _target_: torch.optim.SGD 2 | lr: 1e-3 3 | momentum: 0.99 4 | 5 | -------------------------------------------------------------------------------- /environment.yaml: -------------------------------------------------------------------------------- 1 | name: project 2 | channels: 3 | - pytorch 4 | - defaults 5 | dependencies: 6 | - python=3.8 7 | - ipython 8 | - numpy 9 | - cudatoolkit=11.0 10 | - pytorch>=1.8 11 | - torchvision>=0.8 12 | - pip 13 | - pip: 14 | - -e . 
import logging

import hydra
import pytorch_lightning as pl
from omegaconf import DictConfig, OmegaConf

log = logging.getLogger(__name__)


@hydra.main(config_path="configs", config_name="defaults")
def main(cfg: DictConfig) -> None:
    """Entry point: build the model and datamodule from the hydra config, then train."""
    pl.seed_everything(1234)
    log.info("\n" + OmegaConf.to_yaml(cfg))

    # Build the LightningModule. Arguments that depend on the chosen dataset are
    # overridden at runtime; optimizer sub-configs are deliberately left
    # un-instantiated (_recursive_=False) so configure_optimizers() builds them.
    model = hydra.utils.instantiate(
        cfg.model,
        input_dim=cfg.data.input_dim,
        output_dim=cfg.data.output_dim,
        _recursive_=False,
    )
    datamodule = hydra.utils.instantiate(cfg.data)

    # Hydra already chdirs into a per-run output directory, so relative paths
    # keep every artifact (TB logs, checkpoints) inside that directory.
    tb_logger = pl.loggers.TensorBoardLogger(".", "", "", log_graph=True, default_hp_metric=False)
    checkpoint_cb = pl.callbacks.ModelCheckpoint(monitor='loss/val')
    early_stop_cb = pl.callbacks.EarlyStopping(monitor='loss/val', patience=50)

    trainer = pl.Trainer(
        **OmegaConf.to_container(cfg.trainer),
        logger=tb_logger,
        callbacks=[checkpoint_cb, early_stop_cb],
    )

    trainer.fit(model, datamodule=datamodule)
    trainer.test(model, datamodule=datamodule)  # Optional


if __name__ == '__main__':
    main()
class CIFARDataModule(pl.LightningDataModule):
    """DataModule wrapping torchvision's CIFAR10 with an 80/20 train/val split.

    Args:
        data_dir: directory where CIFAR10 is stored / will be downloaded.
        batch_size: batch size shared by all dataloaders.
        num_workers: worker count for the train loader (doubled for val/test).
        **kwargs: extra config keys (e.g. input_dim/output_dim) are ignored here.
    """

    def __init__(self, data_dir: str = "", batch_size: int = 32, num_workers: int = 8, **kwargs):
        super().__init__()
        self.data_dir = data_dir
        self.batch_size = batch_size
        self.num_workers = num_workers
        logger.info("Initialize CIFAR DataModule")

    def setup(self, stage=None):
        # BUG FIX: these must be two independent `if`s. With the original
        # `elif`, stage=None only ever prepared the fit split, so
        # test_dataloader() crashed with AttributeError.
        if stage == 'fit' or stage is None:
            cifar_full = CIFAR10(self.data_dir, train=True, download=True, transform=ToTensor())
            # Derive the val size from the remainder so the two sizes always
            # sum to len(cifar_full), even if it is not a multiple of 5.
            train_len = int(0.8 * len(cifar_full))
            self.cifar_train, self.cifar_val = random_split(cifar_full, [train_len, len(cifar_full) - train_len])
            self.dims = tuple(self.cifar_train[0][0].shape)
        if stage == 'test' or stage is None:
            self.cifar_test = CIFAR10(self.data_dir, train=False, download=True, transform=ToTensor())

    def train_dataloader(self):
        return DataLoader(self.cifar_train, batch_size=self.batch_size, num_workers=self.num_workers)

    # Double workers for val and test loaders since there is no backward pass
    # and GPU computation is faster.
    def val_dataloader(self):
        return DataLoader(self.cifar_val, batch_size=self.batch_size, num_workers=self.num_workers * 2)

    def test_dataloader(self):
        return DataLoader(self.cifar_test, batch_size=self.batch_size, num_workers=self.num_workers * 2)
class MNISTDataModule(pl.LightningDataModule):
    """DataModule wrapping torchvision's MNIST with a 55k/5k train/val split.

    Args:
        data_dir: directory where MNIST is stored / will be downloaded.
        batch_size: batch size shared by all dataloaders.
        num_workers: worker count for the train loader (doubled for val/test).
        **kwargs: extra config keys (e.g. input_dim/output_dim) are ignored here.
    """

    def __init__(self, data_dir: str = "", batch_size: int = 32, num_workers: int = 8, **kwargs):
        super().__init__()
        # TODO after merge https://github.com/PyTorchLightning/pytorch-lightning/pull/3792
        # self.save_hyperparameters()
        self.data_dir = data_dir
        self.batch_size = batch_size
        self.num_workers = num_workers

    def setup(self, stage=None):
        # BUG FIX: these must be two independent `if`s. With the original
        # `elif`, stage=None never reached the test branch, so
        # test_dataloader() crashed with AttributeError on self.mnist_test.
        if stage == 'fit' or stage is None:
            mnist_full = MNIST(self.data_dir, train=True, download=True, transform=ToTensor())
            self.mnist_train, self.mnist_val = random_split(mnist_full, [55000, 5000])
            self.dims = tuple(self.mnist_train[0][0].shape)
        if stage == 'test' or stage is None:
            self.mnist_test = MNIST(self.data_dir, train=False, download=True, transform=ToTensor())

    def train_dataloader(self):
        return DataLoader(self.mnist_train, batch_size=self.batch_size, num_workers=self.num_workers)

    # Double workers for val and test loaders since there is no backward pass
    # and GPU computation is faster.
    def val_dataloader(self):
        return DataLoader(self.mnist_val, batch_size=self.batch_size, num_workers=self.num_workers * 2)

    def test_dataloader(self):
        return DataLoader(self.mnist_test, batch_size=self.batch_size, num_workers=self.num_workers * 2)
class LitAutoEncoder(pl.LightningModule):
    """Two-layer MLP autoencoder; encoder and decoder each get their own optimizer.

    The optimizer sub-configs (``optim_encoder``/``optim_decoder``) are stored
    via ``save_hyperparameters()`` and instantiated lazily in
    ``configure_optimizers()``.
    """

    def __init__(self, input_dim, output_dim, hidden_dim=64, optim_encoder=None, optim_decoder=None, **kwargs):
        super().__init__()
        self.save_hyperparameters()

        self.encoder = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, output_dim),
        )
        self.decoder = nn.Sequential(
            nn.Linear(output_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, input_dim),
        )

    def forward(self, x):
        """Inference path: return only the latent embedding."""
        return self.encoder(x)

    def training_step(self, batch, batch_idx, optimizer_idx):
        # Labels are unused; the target is the (flattened) input itself.
        inputs, _ = batch
        flat = inputs.view(inputs.size(0), -1)
        reconstruction = self.decoder(self.encoder(flat))
        return F.mse_loss(reconstruction, flat)

    def configure_optimizers(self):
        # One optimizer per sub-network, each built from its hydra sub-config.
        enc_opt = hydra.utils.instantiate(self.hparams.optim_encoder, params=self.encoder.parameters())
        dec_opt = hydra.utils.instantiate(self.hparams.optim_decoder, params=self.decoder.parameters())
        return [enc_opt, dec_opt], []

    def on_train_start(self):
        # Register hparams with initial metric values so TensorBoard's HPARAMS
        # tab links runs to their metrics properly.
        self.logger.log_hyperparams(self.hparams, {"loss/val": 0, "accuracy/val": 0, "accuracy/test": 0})
class LitClassifier(pl.LightningModule):
    """Minimal two-layer MLP classifier trained with cross-entropy.

    Logs ``loss/train``, ``loss/val``, ``accuracy/val``, ``loss/test`` and
    ``accuracy/test``; the optimizer comes from the hydra sub-config ``optim``.
    """

    def __init__(self, optim, input_dim, output_dim, hidden_dim=128, **kwargs):
        super().__init__()
        self.save_hyperparameters()

        self.l1 = torch.nn.Linear(input_dim, hidden_dim)
        self.l2 = torch.nn.Linear(hidden_dim, output_dim)

        self.val_accuracy = pl.metrics.Accuracy()
        self.test_accuracy = pl.metrics.Accuracy()

        # Example batch used by the logger to trace/log the computation graph.
        self.example_input_array = torch.randn([input_dim, input_dim])

    def forward(self, x):
        hidden = torch.relu(self.l1(x.view(x.size(0), -1)))
        return torch.relu(self.l2(hidden))

    def training_step(self, batch, batch_idx):
        inputs, targets = batch
        loss = F.cross_entropy(self(inputs), targets)
        self.log("loss/train", loss)
        return loss

    def validation_step(self, batch, batch_idx):
        inputs, targets = batch
        logits = self(inputs)
        self.log('loss/val', F.cross_entropy(logits, targets))
        self.log('accuracy/val', self.val_accuracy(torch.softmax(logits, dim=1), targets))

    def validation_epoch_end(self, outputs):
        # Epoch-level aggregate from the running metric state.
        self.log("accuracy/val", self.val_accuracy.compute())

    def test_step(self, batch, batch_idx):
        inputs, targets = batch
        logits = self(inputs)
        self.log('loss/test', F.cross_entropy(logits, targets))
        return self.test_accuracy(torch.softmax(logits, dim=1), targets)

    def test_epoch_end(self, outputs):
        self.log('accuracy/test', self.test_accuracy.compute())

    def configure_optimizers(self):
        # Instantiated lazily so hydra does not need model parameters at
        # config-composition time.
        return hydra.utils.instantiate(self.hparams.optim, params=self.parameters())

    def on_train_start(self):
        # Proper logging of hyperparams and metrics in TB
        self.logger.log_hyperparams(self.hparams, {"loss/val": 0, "accuracy/val": 0, "accuracy/test": 0})
class Backbone(torch.nn.Module):
    """Two-layer MLP over flattened 28x28 images, producing 10 class scores."""

    def __init__(self, hidden_dim=128):
        super().__init__()
        self.l1 = torch.nn.Linear(28 * 28, hidden_dim)
        self.l2 = torch.nn.Linear(hidden_dim, 10)

    def forward(self, x):
        x = x.view(x.size(0), -1)
        x = torch.relu(self.l1(x))
        x = torch.relu(self.l2(x))
        return x


class LitClassifier(pl.LightningModule):
    """LightningModule that wraps a Backbone and trains it with cross-entropy.

    Args:
        backbone: the torch.nn.Module producing class scores.
        optim: hydra sub-config for the optimizer, instantiated lazily.
    """

    def __init__(self, backbone, optim, **kwargs):
        super().__init__()
        self.save_hyperparameters()
        self.backbone = backbone

    def forward(self, x):
        # use forward for inference/predictions
        return self.backbone(x)

    def training_step(self, batch, batch_idx):
        x, y = batch
        loss = F.cross_entropy(self.backbone(x), y)
        self.log('train_loss', loss, on_epoch=True)
        return loss

    def validation_step(self, batch, batch_idx):
        x, y = batch
        loss = F.cross_entropy(self.backbone(x), y)
        self.log('valid_loss', loss, on_step=True)

    def test_step(self, batch, batch_idx):
        x, y = batch
        loss = F.cross_entropy(self.backbone(x), y)
        self.log('test_loss', loss)

    def configure_optimizers(self):
        # BUG FIX: `hydra.utils.initiate` does not exist (AttributeError at
        # runtime) -- the correct API is `instantiate`. Also, `self.parameters`
        # is a bound method; it must be *called* so the optimizer receives an
        # iterable of tensors rather than the method object.
        return hydra.utils.instantiate(self.hparams.optim, params=self.parameters())

    def on_train_start(self):
        # Proper logging of hyperparams and metrics in TB
        self.logger.log_hyperparams(self.hparams, {"loss/val": 0, "accuracy/val": 0, "accuracy/test": 0})
from pytorch_lightning import Trainer, seed_everything

from project.data import MNISTDataModule
from project.model.classifier import LitClassifier


def test_lit_classifier():
    """Smoke-train the classifier for two short epochs and check test accuracy."""
    seed_everything(1234)

    # BUG FIX: LitClassifier has no defaults for optim/input_dim/output_dim,
    # so the previous bare LitClassifier() call raised TypeError. Supply the
    # same values the configs use (MNIST: 784 inputs, 10 classes, Adam).
    optim_cfg = {"_target_": "torch.optim.Adam", "lr": 1e-3}
    model = LitClassifier(optim=optim_cfg, input_dim=784, output_dim=10)
    dm = MNISTDataModule()
    trainer = Trainer(limit_train_batches=50, limit_val_batches=20, max_epochs=2)
    trainer.fit(model, dm)

    results = trainer.test(datamodule=dm)
    # BUG FIX: the model logs its metric under 'accuracy/test', not 'test_acc'.
    assert results[0]['accuracy/test'] > 0.7