├── .github └── workflows │ ├── publish_docs.yml │ └── test_suite.yml ├── .gitignore ├── LICENSE ├── README.md ├── cookiecutter.json ├── docs ├── changelog │ ├── index.md │ └── upgrade.md ├── features │ ├── bestpractices.md │ ├── cicd.md │ ├── conda.md │ ├── determinism.md │ ├── docs.md │ ├── envvars.md │ ├── fastdevrun.md │ ├── metadata.md │ ├── nncore.md │ ├── restore.md │ ├── storage.md │ ├── tags.md │ └── tests.md ├── getting-started │ ├── generation.md │ └── index.md ├── index.md ├── integrations │ ├── dvc.md │ ├── githubactions.md │ ├── hydra.md │ ├── lightning.md │ ├── mkdocs.md │ ├── streamlit.md │ └── wandb.md ├── overrides │ └── main.html ├── papers.md └── project-structure │ ├── conf.md │ ├── index.md │ └── structure.md ├── hooks ├── post_gen_project.py └── pre_gen_project.py ├── mkdocs.yml └── {{ cookiecutter.repository_name }} ├── .editorconfig ├── .env.template ├── .flake8 ├── .github └── workflows │ ├── publish.yml │ └── test_suite.yml ├── .gitignore ├── .pre-commit-config.yaml ├── LICENSE ├── README.md ├── conf ├── default.yaml ├── hydra │ └── default.yaml ├── nn │ ├── data │ │ ├── dataset │ │ │ └── vision │ │ │ │ └── mnist.yaml │ │ └── default.yaml │ ├── default.yaml │ └── module │ │ ├── default.yaml │ │ └── model │ │ └── cnn.yaml └── train │ └── default.yaml ├── data └── .gitignore ├── docs ├── index.md └── overrides │ └── main.html ├── env.yaml ├── mkdocs.yml ├── pyproject.toml ├── setup.cfg ├── setup.py ├── src └── {{ cookiecutter.package_name }} │ ├── __init__.py │ ├── data │ ├── __init__.py │ ├── datamodule.py │ └── dataset.py │ ├── modules │ ├── __init__.py │ └── module.py │ ├── pl_modules │ ├── __init__.py │ └── pl_module.py │ ├── run.py │ ├── ui │ ├── __init__.py │ └── run.py │ └── utils │ └── hf_io.py └── tests ├── __init__.py ├── conftest.py ├── test_checkpoint.py ├── test_configuration.py ├── test_nn_core_integration.py ├── test_resume.py ├── test_seeding.py ├── test_storage.py └── test_training.py /.github/workflows/publish_docs.yml: -------------------------------------------------------------------------------- 1 | name: Publish docs 2 | 3 | on: 4 | release: 5 | types: 6 | - created 7 | 8 | jobs: 9 | build: 10 | strategy: 11 | fail-fast: false 12 | matrix: 13 | python-version: ['3.9'] 14 | include: 15 | - os: ubuntu-20.04 16 | label: linux-64 17 | prefix: /usr/share/miniconda3/envs/ 18 | 19 | name: ${{ matrix.label }}-py${{ matrix.python-version }} 20 | runs-on: ${{ matrix.os }} 21 | 22 | steps: 23 | - uses: actions/checkout@v2 24 | with: 25 | fetch-depth: 0 26 | 27 | - name: Set up Python 28 | uses: actions/setup-python@v1 29 | 30 | - name: Install Dependencies 31 | shell: bash -l {0} 32 | run: | 33 | pip install cookiecutter mkdocs mkdocs-material mike 34 | 35 | # extract the first two digits from the release note 36 | - name: Set release notes tag 37 | run: | 38 | export RELEASE_TAG_VERSION=${{ github.event.release.tag_name }} 39 | echo "RELEASE_TAG_VERSION=${RELEASE_TAG_VERSION%.*}">> $GITHUB_ENV 40 | 41 | - name: Echo release notes tag 42 | run: | 43 | echo "${RELEASE_TAG_VERSION}" 44 | 45 | - name: Build docs website 46 | shell: bash -l {0} 47 | run: | 48 | git config user.name ci-bot 49 | git config user.email ci-bot@ci.com 50 | mike deploy --push --rebase --update-aliases ${RELEASE_TAG_VERSION} latest 51 | -------------------------------------------------------------------------------- /.github/workflows/test_suite.yml: -------------------------------------------------------------------------------- 1 | name: Test Suite 2 | 3 | on: 4 | push: 5 | 
branches: 6 | - main 7 | - develop 8 | 9 | pull_request: 10 | types: 11 | - opened 12 | - reopened 13 | - synchronize 14 | 15 | env: 16 | CACHE_NUMBER: 4 # increase to reset cache manually 17 | CONDA_ENV_FILE: 'env.yaml' 18 | CONDA_ENV_NAME: 'project-test' 19 | COOKIECUTTER_PROJECT_NAME: 'project-test' 20 | HUGGING_FACE_HUB_TOKEN: ${{secrets.HUGGING_FACE_HUB_TOKEN}} 21 | 22 | jobs: 23 | build: 24 | 25 | strategy: 26 | fail-fast: false 27 | matrix: 28 | python-version: ['3.11'] 29 | include: 30 | - os: ubuntu-20.04 31 | label: linux-64 32 | prefix: /usr/share/miniconda3/envs/ 33 | 34 | # - os: macos-latest 35 | # label: osx-64 36 | # prefix: /Users/runner/miniconda3/envs/$CONDA_ENV_NAME 37 | 38 | # - os: windows-latest 39 | # label: win-64 40 | # prefix: C:\Miniconda3\envs\$CONDA_ENV_NAME 41 | 42 | name: ${{ matrix.label }}-py${{ matrix.python-version }} 43 | runs-on: ${{ matrix.os }} 44 | 45 | steps: 46 | - name: Parametrize conda env name 47 | run: echo "PY_CONDA_ENV_NAME=${{ env.CONDA_ENV_NAME }}-${{ matrix.python-version }}" >> $GITHUB_ENV 48 | - name: echo conda env name 49 | run: echo ${{ env.PY_CONDA_ENV_NAME }} 50 | 51 | - name: Parametrize conda prefix 52 | run: echo "PY_PREFIX=${{ matrix.prefix }}${{ env.PY_CONDA_ENV_NAME }}" >> $GITHUB_ENV 53 | - name: echo conda prefix 54 | run: echo ${{ env.PY_PREFIX }} 55 | 56 | - name: Define generated project files paths 57 | run: | 58 | echo "PROJECT_SETUPCFG_FILE=${{ env.COOKIECUTTER_PROJECT_NAME }}/setup.cfg" >> $GITHUB_ENV 59 | echo "PROJECT_CONDAENV_FILE=${{ env.COOKIECUTTER_PROJECT_NAME }}/${{ env.CONDA_ENV_FILE }}" >> $GITHUB_ENV 60 | echo "PROJECT_PRECOMMIT_FILE=${{ env.COOKIECUTTER_PROJECT_NAME }}/.pre-commit-config.yaml" >> $GITHUB_ENV 61 | 62 | - uses: actions/checkout@v2 63 | 64 | # COOKIECUTTER GENERATION 65 | - name: Set up Python 66 | uses: actions/setup-python@v1 67 | 68 | - name: Install Dependencies 69 | shell: bash -l {0} 70 | run: | 71 | pip install cookiecutter 72 | 73 | - name: Generate Repo 74 | shell: bash -l {0} 75 | run: | 76 | echo -e 'n\nn\nn\n' | cookiecutter . 
--no-input project_name=${{ env.COOKIECUTTER_PROJECT_NAME }} 77 | 78 | - name: Init git into generated repo 79 | shell: bash -l {0} 80 | run: | 81 | git config --global user.name ci-bot 82 | git config --global user.email ci-bot@ci.com 83 | git init 84 | git add --all 85 | git commit -m "Initial commit" 86 | working-directory: ${{ env.COOKIECUTTER_PROJECT_NAME }} 87 | 88 | - name: Define cache key postfix 89 | run: | 90 | echo "CACHE_KEY_POSTFIX=${{ matrix.label }}-${{ matrix.python-version }}-${{ env.CACHE_NUMBER }}-${{ env.PY_CONDA_ENV_NAME }}-${{ hashFiles(env.PROJECT_CONDAENV_FILE) }}-${{ hashFiles(env.PROJECT_SETUPCFG_FILE) }}" >> $GITHUB_ENV 91 | 92 | - name: Echo cache keys 93 | run: | 94 | echo ${{ env.PROJECT_SETUPCFG_FILE }} 95 | echo ${{ env.PROJECT_CONDAENV_FILE }} 96 | echo ${{ env.PROJECT_PRECOMMIT_FILE }} 97 | echo ${{ env.CACHE_KEY_POSTFIX }} 98 | 99 | # GENERATED PROJECT CI/CD 100 | 101 | # Remove the python version pin from the env.yml which could be inconsistent 102 | - name: Remove explicit python version from the environment 103 | shell: bash -l {0} 104 | run: | 105 | sed -Ei '/^\s*-?\s*python\s*([#=].*)?$/d' ${{ env.CONDA_ENV_FILE }} 106 | cat ${{ env.CONDA_ENV_FILE }} 107 | working-directory: ${{ env.COOKIECUTTER_PROJECT_NAME }} 108 | 109 | # Install torch cpu-only 110 | - name: Install torch cpu only 111 | shell: bash -l {0} 112 | run: | 113 | sed -i '/nvidia\|cuda/d' ${{ env.CONDA_ENV_FILE }} 114 | cat ${{ env.CONDA_ENV_FILE }} 115 | working-directory: ${{ env.COOKIECUTTER_PROJECT_NAME }} 116 | 117 | - name: Setup Mambaforge 118 | uses: conda-incubator/setup-miniconda@v2 119 | with: 120 | miniforge-variant: Mambaforge 121 | miniforge-version: latest 122 | activate-environment: ${{ env.PY_CONDA_ENV_NAME }} 123 | python-version: ${{ matrix.python-version }} 124 | use-mamba: true 125 | 126 | - uses: actions/cache@v2 127 | name: Conda cache 128 | with: 129 | path: ${{ env.PY_PREFIX }} 130 | key: conda-${{ env.CACHE_KEY_POSTFIX }} 131 | id: conda_cache 132 | 133 | - uses: actions/cache@v2 134 | name: Pip cache 135 | with: 136 | path: ~/.cache/pip 137 | key: pip-${{ env.CACHE_KEY_POSTFIX }} 138 | 139 | - uses: actions/cache@v2 140 | name: Pre-commit cache 141 | with: 142 | path: ~/.cache/pre-commit 143 | key: pre-commit-${{ hashFiles(env.PROJECT_PRECOMMIT_FILE) }}-conda-${{ env.CACHE_KEY_POSTFIX }} 144 | 145 | # Ensure the hack for the python version worked 146 | - name: Ensure we have the right Python 147 | shell: bash -l {0} 148 | run: | 149 | echo "Installed Python: $(python --version)" 150 | echo "Expected: ${{ matrix.python-version }}" 151 | python --version | grep "Python ${{ matrix.python-version }}" 152 | working-directory: ${{ env.COOKIECUTTER_PROJECT_NAME }} 153 | 154 | # https://stackoverflow.com/questions/70520120/attributeerror-module-setuptools-distutils-has-no-attribute-version 155 | # https://github.com/pytorch/pytorch/pull/69904 156 | - name: Downgrade setuptools due to a but in PyTorch 1.10.1 157 | shell: bash -l {0} 158 | run: | 159 | pip install setuptools==59.5.0 --upgrade 160 | working-directory: ${{ env.COOKIECUTTER_PROJECT_NAME }} 161 | 162 | - name: Update conda environment 163 | run: mamba env update -n ${{ env.PY_CONDA_ENV_NAME }} -f ${{ env.CONDA_ENV_FILE }} 164 | if: steps.conda_cache.outputs.cache-hit != 'true' 165 | working-directory: ${{ env.COOKIECUTTER_PROJECT_NAME }} 166 | 167 | # Update pip env whether or not there was a conda cache hit 168 | - name: Update pip environment 169 | shell: bash -l {0} 170 | run: pip install -e ".[dev]" 171 
| if: steps.conda_cache.outputs.cache-hit == 'true' 172 | working-directory: ${{ env.COOKIECUTTER_PROJECT_NAME }} 173 | 174 | - run: pip3 list 175 | shell: bash -l {0} 176 | working-directory: ${{ env.COOKIECUTTER_PROJECT_NAME }} 177 | 178 | - run: mamba info 179 | working-directory: ${{ env.COOKIECUTTER_PROJECT_NAME }} 180 | 181 | - run: mamba list 182 | working-directory: ${{ env.COOKIECUTTER_PROJECT_NAME }} 183 | 184 | # Ensure the hack for the python version worked 185 | - name: Ensure we have the right Python 186 | shell: bash -l {0} 187 | run: | 188 | echo "Installed Python: $(python --version)" 189 | echo "Expected: ${{ matrix.python-version }}" 190 | python --version | grep "Python ${{ matrix.python-version }}" 191 | working-directory: ${{ env.COOKIECUTTER_PROJECT_NAME }} 192 | 193 | - name: Run pre-commits 194 | shell: bash -l {0} 195 | run: | 196 | pre-commit install 197 | pre-commit run -v --all-files --show-diff-on-failure 198 | working-directory: ${{ env.COOKIECUTTER_PROJECT_NAME }} 199 | 200 | - name: Test with pytest 201 | shell: bash -l {0} 202 | run: | 203 | pytest -v 204 | working-directory: ${{ env.COOKIECUTTER_PROJECT_NAME }} 205 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | wandb 2 | multirun.yaml 3 | storage 4 | 5 | # ignore the _version.py file 6 | _version.py 7 | 8 | # .gitignore defaults for python and pycharm 9 | .idea 10 | 11 | # Byte-compiled / optimized / DLL files 12 | __pycache__/ 13 | *.py[cod] 14 | *$py.class 15 | 16 | # C extensions 17 | *.so 18 | 19 | # Distribution / packaging 20 | .Python 21 | build/ 22 | develop-eggs/ 23 | dist/ 24 | downloads/ 25 | eggs/ 26 | .eggs/ 27 | lib/ 28 | lib64/ 29 | parts/ 30 | sdist/ 31 | var/ 32 | wheels/ 33 | share/python-wheels/ 34 | *.egg-info/ 35 | .installed.cfg 36 | *.egg 37 | MANIFEST 38 | 39 | # PyInstaller 40 | # Usually these files are written by a python script from a template 41 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 42 | *.manifest 43 | *.spec 44 | 45 | # Installer logs 46 | pip-log.txt 47 | pip-delete-this-directory.txt 48 | 49 | # Unit test / coverage reports 50 | htmlcov/ 51 | .tox/ 52 | .nox/ 53 | .coverage 54 | .coverage.* 55 | .cache 56 | nosetests.xml 57 | coverage.xml 58 | *.cover 59 | *.py,cover 60 | .hypothesis/ 61 | .pytest_cache/ 62 | cover/ 63 | 64 | # Translations 65 | *.mo 66 | *.pot 67 | 68 | # Django stuff: 69 | *.log 70 | local_settings.py 71 | db.sqlite3 72 | db.sqlite3-journal 73 | 74 | # Flask stuff: 75 | instance/ 76 | .webassets-cache 77 | 78 | # Scrapy stuff: 79 | .scrapy 80 | 81 | # Sphinx documentation 82 | docs/_build/ 83 | 84 | # PyBuilder 85 | .pybuilder/ 86 | target/ 87 | 88 | # Jupyter Notebook 89 | .ipynb_checkpoints 90 | 91 | # IPython 92 | profile_default/ 93 | ipython_config.py 94 | 95 | # pyenv 96 | # For a library or package, you might want to ignore these files since the code is 97 | # intended to run in multiple environments; otherwise, check them in: 98 | # .python-version 99 | 100 | # pipenv 101 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 102 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 103 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 104 | # install all needed dependencies. 105 | #Pipfile.lock 106 | 107 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 108 | __pypackages__/ 109 | 110 | # Celery stuff 111 | celerybeat-schedule 112 | celerybeat.pid 113 | 114 | # SageMath parsed files 115 | *.sage.py 116 | 117 | # Environments 118 | .env 119 | .venv 120 | env/ 121 | venv/ 122 | ENV/ 123 | env.bak/ 124 | venv.bak/ 125 | 126 | # Spyder project settings 127 | .spyderproject 128 | .spyproject 129 | 130 | # Rope project settings 131 | .ropeproject 132 | 133 | # mkdocs documentation 134 | /site 135 | 136 | # mypy 137 | .mypy_cache/ 138 | .dmypy.json 139 | dmypy.json 140 | 141 | # Pyre type checker 142 | .pyre/ 143 | 144 | # pytype static type analyzer 145 | .pytype/ 146 | 147 | # Cython debug symbols 148 | cython_debug/ 149 | 150 | # Covers JetBrains IDEs: IntelliJ, RubyMine, PhpStorm, AppCode, PyCharm, CLion, Android Studio, WebStorm and Rider 151 | # Reference: https://intellij-support.jetbrains.com/hc/en-us/articles/206544839 152 | 153 | # User-specific stuff 154 | .idea/**/workspace.xml 155 | .idea/**/tasks.xml 156 | .idea/**/usage.statistics.xml 157 | .idea/**/dictionaries 158 | .idea/**/shelf 159 | 160 | # Generated files 161 | .idea/**/contentModel.xml 162 | 163 | # Sensitive or high-churn files 164 | .idea/**/dataSources/ 165 | .idea/**/dataSources.ids 166 | .idea/**/dataSources.local.xml 167 | .idea/**/sqlDataSources.xml 168 | .idea/**/dynamic.xml 169 | .idea/**/uiDesigner.xml 170 | .idea/**/dbnavigator.xml 171 | 172 | # Gradle 173 | .idea/**/gradle.xml 174 | .idea/**/libraries 175 | 176 | # Gradle and Maven with auto-import 177 | # When using Gradle or Maven with auto-import, you should exclude module files, 178 | # since they will be recreated, and may cause churn. Uncomment if using 179 | # auto-import. 180 | # .idea/artifacts 181 | # .idea/compiler.xml 182 | # .idea/jarRepositories.xml 183 | # .idea/modules.xml 184 | # .idea/*.iml 185 | # .idea/modules 186 | # *.iml 187 | # *.ipr 188 | 189 | # CMake 190 | cmake-build-*/ 191 | 192 | # Mongo Explorer plugin 193 | .idea/**/mongoSettings.xml 194 | 195 | # File-based project format 196 | *.iws 197 | 198 | # IntelliJ 199 | out/ 200 | 201 | # mpeltonen/sbt-idea plugin 202 | .idea_modules/ 203 | 204 | # JIRA plugin 205 | atlassian-ide-plugin.xml 206 | 207 | # Cursive Clojure plugin 208 | .idea/replstate.xml 209 | 210 | # Crashlytics plugin (for Android Studio and IntelliJ) 211 | com_crashlytics_export_strings.xml 212 | crashlytics.properties 213 | crashlytics-build.properties 214 | fabric.properties 215 | 216 | # Editor-based Rest Client 217 | .idea/httpRequests 218 | 219 | # Android studio 3.1+ serialized cache file 220 | .idea/caches/build_file_checksums.ser 221 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2021 Valentino Maiorca, Luca Moschella 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 
14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # NN Template 2 | 3 |

4 | CI 5 | CI 6 | Docs 7 | Release 8 | Code style: black 9 |

10 | 11 | [comment]: <> (

) 12 | 13 | [comment]: <> ( PyTorch) 14 | 15 | [comment]: <> ( Lightning) 16 | 17 | [comment]: <> ( Conf: hydra) 18 | 19 | [comment]: <> ( Logging: wandb) 20 | 21 | [comment]: <> ( Conf: hydra) 22 | 23 | [comment]: <> ( UI: streamlit) 24 | 25 | [comment]: <> (

) 26 | 27 |

28 | 29 | "We demand rigidly defined areas of doubt and uncertainty." 30 | 31 |

32 | 33 | 34 | Generic template to bootstrap your [PyTorch](https://pytorch.org/get-started/locally/) project, 35 | read more in the [documentation](https://grok-ai.github.io/nn-template). 36 | 37 | 38 | [![asciicast](https://asciinema.org/a/475623.svg)](https://asciinema.org/a/475623) 39 | 40 | ## Get started 41 | 42 | If you already know [cookiecutter](https://github.com/cookiecutter/cookiecutter), just generate your project with: 43 | 44 | ```bash 45 | cookiecutter https://github.com/grok-ai/nn-template 46 | ``` 47 | 48 |
49 | Otherwise 50 | Cookiecutter manages the setup stages and delivers a personalized, ready-to-run project to you. 51 | 52 | Install it with: 53 |
pip install cookiecutter
54 | 
55 |
56 | 57 | More details in the [documentation](https://grok-ai.github.io/nn-template/latest/getting-started/generation/). 58 | 59 | ## Strengths 60 | 61 | - **Actually works for [research](https://grok-ai.github.io/nn-template/latest/papers/)**! 62 | - Guided setup to customize project bootstrapping; 63 | - Fast prototyping of new ideas, no need to build a new code base from scratch; 64 | - Less boilerplate with no impact on the learning curve (as long as you know the integrated tools); 65 | - Ensure experiments reproducibility; 66 | - Automatize via GitHub actions: testing, stylish documentation deploy, PyPi upload; 67 | - Enforce Python [best practices](https://grok-ai.github.io/nn-template/latest/features/bestpractices/); 68 | - Many more in the [documentation](https://grok-ai.github.io/nn-template/latest/features/nncore/); 69 | 70 | ## Integrations 71 | 72 | Avoid writing boilerplate code to integrate: 73 | 74 | - [PyTorch Lightning](https://github.com/PyTorchLightning/pytorch-lightning), lightweight PyTorch wrapper for high-performance AI research. 75 | - [Hydra](https://github.com/facebookresearch/hydra), a framework for elegantly configuring complex applications. 76 | - [Hugging Face Datasets](https://huggingface.co/docs/datasets/index),a library for easily accessing and sharing datasets. 77 | - [Weights and Biases](https://wandb.ai/home), organize and analyze machine learning experiments. *(educational account available)* 78 | - [Streamlit](https://streamlit.io/), turns data scripts into shareable web apps in minutes. 79 | - [MkDocs](https://www.mkdocs.org/) and [Material for MkDocs](https://squidfunk.github.io/mkdocs-material/), a fast, simple and downright gorgeous static site generator. 80 | - [DVC](https://dvc.org/doc/start/data-versioning), track large files, directories, or ML models. Think "Git for data". 81 | - [GitHub Actions](https://github.com/features/actions), to run the tests, publish the documentation and to PyPI automatically. 82 | - Python best practices for developing and publishing research projects. 83 | 84 | ## Maintainers 85 | 86 | - Valentino Maiorca [@Flegyas](https://github.com/Flegyas) 87 | - Luca Moschella [@lucmos](https://github.com/lucmos) 88 | -------------------------------------------------------------------------------- /cookiecutter.json: -------------------------------------------------------------------------------- 1 | { 2 | "author": "Paul Erdős", 3 | "author_email": "paul@erdos.com", 4 | "github_user": "erdos", 5 | "project_name": "Awesome Project", 6 | "project_description": "A new awesome project.", 7 | "repository_name": "{{ cookiecutter.project_name.strip().lower().replace(' ', '-') }}", 8 | "package_name": "{{ cookiecutter.project_name.strip().lower().replace(' ', '-').replace('-', '_') }}", 9 | "repository_url": "https://github.com/{{ cookiecutter.github_user }}/{{ cookiecutter.project_name.strip().lower().replace(' ', '-') }}", 10 | "conda_env_name": "{{ cookiecutter.project_name.strip().lower().replace(' ', '-') }}", 11 | "python_version": "3.11", 12 | "__version": "0.4.0" 13 | } 14 | -------------------------------------------------------------------------------- /docs/changelog/index.md: -------------------------------------------------------------------------------- 1 | # Changelog 2 | 3 | See the changelog in the [releases](https://github.com/grok-ai/nn-template/releases) page. 
4 | -------------------------------------------------------------------------------- /docs/changelog/upgrade.md: -------------------------------------------------------------------------------- 1 | The need for upgrading the template itself is lessened thanks to the `nn-template-core` library decoupling. 2 | 3 | !!! info 4 | 5 | Update the `nn-template-core` library changing the version constraint in the `setup.cfg`. 6 | 7 | --- 8 | 9 | However, you can use [cruft](https://github.com/cruft/cruft) to automate also the template updates! 10 | [![](https://camo.githubusercontent.com/01c5aa4ff2ddfc69282c75da05dc6bdadc8cc0a98119846a8630a981c86115f8/68747470733a2f2f7261772e6769746875622e636f6d2f63727566742f63727566742f6d61737465722f6172742f6c6f676f5f6c617267652e706e67)](https://github.com/cruft/cruft) 11 | -------------------------------------------------------------------------------- /docs/features/bestpractices.md: -------------------------------------------------------------------------------- 1 | # Tooling 2 | 3 | The template configures are the tooling necessary for a modern Python project. 4 | 5 | These include: 6 | 7 | - [**EditorConfig**](https://editorconfig.org/) maintain consistent coding styles for multiple developers. 8 | - [**Black**](https://black.readthedocs.io/en/stable/index.html) the uncompromising code formatter. 9 | - [**isort**](https://github.com/PyCQA/isort) sort imports alphabetically, and automatically separated into sections and by type. 10 | - [**flake8**](https://flake8.pycqa.org/en/latest/) check coding style (PEP8), programming errors and cyclomatic complexity. 11 | - [**pydocstyle**](http://www.pydocstyle.org/en/stable/) static analysis tool for checking compliance with Python docstring conventions. 12 | - [**MyPy**](http://mypy-lang.org/) static type checker for Python. 13 | - [**Coverage**](https://coverage.readthedocs.io/en/6.2/) measure code coverage of Python programs. 14 | - [**bandit**](https://github.com/PyCQA/bandit) security linter from PyCQA. 15 | - [**pre-commit**](https://pre-commit.com/) framework for managing and maintaining pre-commit hooks. 16 | 17 | ## Pre commits 18 | 19 | The pre-commits configuration is defined in `.pre-commit-config.yaml`, 20 | and includes the most important checks and auto-fix to perform. 21 | 22 | If one of the pre-commits fails, **the commit is aborted** avoiding distraction errors. 23 | 24 | !!! info 25 | 26 | The pre-commits are also run in the CI/CD as part of the Test Suite. This helps 27 | guaranteeing the all the contributors are respecting the code conventions in the 28 | pull requests. 29 | -------------------------------------------------------------------------------- /docs/features/cicd.md: -------------------------------------------------------------------------------- 1 | # CI/CD 2 | 3 | The generated project contains two [GiHub Actions](https://github.com/features/actions) workflow to run the Test Suite and to publish you project. 4 | 5 | !!! note 6 | You need to enable the GitHub Actions from the settings in your repository. 7 | 8 | !!! important 9 | All the workflow already implement the logic needed to **cache** the conda and pip environment 10 | between workflow runs. 11 | 12 | !!! warning 13 | The annotated tags in the git repository to manage releases should follow the [semantic versioning](https://semver.org/) 14 | conventions: `..` 15 | 16 | 17 | 18 | ## Test Suite 19 | 20 | The Test Suite runs automatically for each commit in a Pull Request. 
21 | It is successful if: 22 | 23 | - The pre-commits do not raise any errors 24 | - All the tests pass 25 | 26 | After that, the PR are marked with ✔️ or ❌ depending on the test suite results. 27 | 28 | ## Publish docs 29 | 30 | The first time you should use `mike` to: create the `gh-pages` branch and 31 | specify the default docs version. 32 | 33 | ```bash 34 | mike deploy 0.1 latest --push 35 | mike set-default latest 36 | ``` 37 | 38 | !!! warning 39 | 40 | You do not need to execute these commands if you accepted the optional cookiecutter setup step. 41 | 42 | !!! info 43 | 44 | Remember to enable the GitHub Pages from the repository settings. 45 | 46 | 47 | After that, the docs are built and automatically published `on release` on GitHub Pages. 48 | This means that every time you publish a new release in your project an associated version of the documentation is published. 49 | 50 | !!! important 51 | 52 | The documentation version utilizes only the `.` version of the release tag, discarding the patch version. 53 | 54 | ## Publish PyPi 55 | 56 | To publish your package on PyPi it is enough to configure 57 | the PyPi token in the GitHub repository `secrets` and de-comment the following in the 58 | `publish.yaml` workflow: 59 | 60 | ```yaml 61 | - name: Build SDist and wheel 62 | run: pipx run build 63 | 64 | - name: Check metadata 65 | run: pipx run twine check dist/* 66 | 67 | - name: Publish distribution 📦 to PyPI 68 | uses: pypa/gh-action-pypi-publish@release/v1 69 | with: 70 | user: __token__ 71 | password: ${{ secrets.PYPI_API_TOKEN }} 72 | ``` 73 | 74 | In this way, on each GitHub release the package gets published on PyPi and the associated documentation 75 | is published on GitHub Pages. 76 | -------------------------------------------------------------------------------- /docs/features/conda.md: -------------------------------------------------------------------------------- 1 | # Python Environment 2 | 3 | The generated project is a Python Package, whose dependencies are defined in the `setup.cfg` 4 | The development setup comprises a `conda` environment which installs the package itself in edit mode. 5 | 6 | ## Dependencies 7 | 8 | All the project dependencies should be defined in the `setup.cfg` as `pip` dependencies. 9 | In rare cases, it is useful to specify conda dependencies --- they will not be resolved when installing the package 10 | from PyPi. 11 | 12 | This division is useful when installing particular or optimized packages such a `PyTorch` and PyTorch Geometric. 13 | 14 | !!! hint 15 | 16 | It is possible to manage the Python version to use in the conda `env.yaml`. 17 | 18 | !!! info 19 | 20 | This organization allows for `conda` and `pip` dependencies to co-exhist, which in practice happens a lot in 21 | research projects. 22 | 23 | ## Update 24 | 25 | In order to update the `pip` dependencies after changing the `setup.cfg` it is enough to run: 26 | 27 | ```bash 28 | pip install -e '.[dev]' 29 | ``` 30 | -------------------------------------------------------------------------------- /docs/features/determinism.md: -------------------------------------------------------------------------------- 1 | # Determinism 2 | 3 | The template always logs the seed utilized in order to guarantee **reproducibility**. 
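Conceptually, the `seed_index` mechanism described below amounts to indexing a list of seeds that is itself generated from a fixed master seed; a minimal sketch of the idea (hypothetical helper names, the actual implementation lives in `nn-template-core`):

```python
import random


def resolve_seed(seed_index: int, num_seeds: int = 10, master_seed: int = 0) -> int:
    """Pick the seed at position `seed_index` from a reproducible list of seeds."""
    rng = random.Random(master_seed)  # fixed master seed -> identical list on every run
    seeds = [rng.randint(0, 2**31 - 1) for _ in range(num_seeds)]
    return seeds[seed_index]


# seeds[1] is the same value on every machine, so logging it is enough to reproduce the run
print(f"Setting seed {resolve_seed(1)} from seeds[1]")
```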
4 | 5 | The user specifies a `seed_index` value in the configuration `train/default.yaml`: 6 | 7 | ```yaml 8 | seed_index: 1 9 | deterministic: False 10 | ``` 11 | 12 | This value indexes an array of deterministic but randomly generated seeds, e.g.: 13 | 14 | ```bash 15 | Setting seed 1273642419 from seeds[1] 16 | ``` 17 | 18 | 19 | !!! hint 20 | 21 | This setup allows to easily run the same experiment with different seeds in a reproducible way. 22 | It is enough to run a Hydra multi-run over the `seed_index`. 23 | 24 | The following would run the same experiment with five different seeds, which can be analyzed 25 | in the logger dashboard: 26 | 27 | ```bash 28 | python src/project/run.py -m train.seed_index=1,2,3,4 29 | ``` 30 | 31 | !!! info 32 | 33 | The deterministic option `deterministic: False` controls the use of deterministic algorithms 34 | in PyTorch, it is forwarded to the Lightning Trainer. 35 | -------------------------------------------------------------------------------- /docs/features/docs.md: -------------------------------------------------------------------------------- 1 | # Documentation 2 | 3 | `MkDocs` and `Material for MkDocs` is already configured in the generated project. 4 | 5 | In order to create your docs it is enough to: 6 | 7 | 1. Modify the `nav` index in the `mkdocs.yaml`, which describes how to organize the pages. 8 | An example of the `nav` is the following: 9 | 10 | ```yaml 11 | nav: 12 | - Home: index.md 13 | - Getting started: 14 | - Generating your project: getting-started/generation.md 15 | - Strucure: getting-started/structure.md 16 | ``` 17 | 18 | 2. Create all the files referenced in the `nav` relative to the `docs/` folder. 19 | 20 | ```bash 21 | ❯ tree docs 22 | docs 23 | ├── getting-started 24 | │   ├── generation.md 25 | │   └── structure.md 26 | └── index.md 27 | ``` 28 | 29 | 3. To preview your documentation it is enough to run `mkdocs serve`. To manually deploy the documentation 30 | see [`mike`](https://github.com/jimporter/mike), or see the integrated GitHub Action to [publish the docs on release](https://grok-ai.github.io/nn-template/latest/features/cicd/#publish-docs). 31 | -------------------------------------------------------------------------------- /docs/features/envvars.md: -------------------------------------------------------------------------------- 1 | 2 | # Environment Variables 3 | 4 | System specific variables (e.g. absolute paths) should not be under version control, otherwise there will be conflicts between different users. 5 | 6 | The best way to handle system specific variables is through environment variables. 7 | 8 | You can define new environment variables in a `.env` file in the project root. A copy of this file (e.g. `.env.template`) can be under version control to ease new project configurations. 9 | 10 | To define a new variable write inside `.env`: 11 | 12 | ```bash 13 | export MY_VAR=/home/user/my_system_path 14 | ``` 15 | 16 | 17 | You can dynamically resolve the variable name from everywhere 18 | 19 | === "python" 20 | 21 | In Python code use: 22 | 23 | ```python 24 | get_env("MY_VAR") 25 | ``` 26 | 27 | === "yaml" 28 | 29 | In the Hydra `yaml` configurations: 30 | 31 | ```yaml 32 | ${oc.env:MY_VAR} 33 | ``` 34 | 35 | === "posix" 36 | 37 | In posix shells: 38 | 39 | ```bash 40 | . 
.env 41 | echo $MY_VAR 42 | ``` 43 | -------------------------------------------------------------------------------- /docs/features/fastdevrun.md: -------------------------------------------------------------------------------- 1 | # Fast Dev Run 2 | 3 | The template expands the Lightning `fast_dev_run` mode to be more debugging friendly. 4 | 5 | It will also: 6 | 7 | - Disable multiple workers in the dataloaders 8 | - Use the CPU and not the GPU 9 | 10 | !!! info 11 | 12 | It is possible to modify this behaviour by simply modifying the `run.py` file. 13 | -------------------------------------------------------------------------------- /docs/features/metadata.md: -------------------------------------------------------------------------------- 1 | # MetaData 2 | 3 | The *bridge* between the Lightning DataModule and the Lightning Module. 4 | 5 | It is responsible for collecting data information to be fed to the module. 6 | The Lightning Module will receive an instance of MetaData when instantiated, 7 | both in the train loop or when restored from a checkpoint. 8 | 9 | !!! warning 10 | 11 | MetaData exposes `save` and `load`. Those are two user-defined methods that specify how to serialize and de-serialize the information contained in its attributes. 12 | This is needed for the checkpointing restore to work properly and **must be 13 | always implemented**, where the metadata is needed. 14 | 15 | This decoupling allows the architecture to be parametric (e.g. in the number of classes) and 16 | DataModule/Trainer independent (useful in prediction scenarios). 17 | Examples are the class names in a classification task or the vocabulary in NLP tasks. 18 | 19 | -------------------------------------------------------------------------------- /docs/features/nncore.md: -------------------------------------------------------------------------------- 1 | # NN Template core 2 | 3 | Most of the logic is abstracted from the template into an accompanying library: [`nn-template-core`](https://pypi.org/project/nn-template-core/). 4 | 5 | This library contains the logic necessary for the restore, logging, and many other functionalities implemented in the template. 6 | 7 | !!! info 8 | 9 | This decoupling eases the updating of the template, reaching a desirable tradeoff: 10 | 11 | - `template`: easy to use and customize, hard to update 12 | - `library`: hard to customize, easy to update 13 | 14 | With our approach updating most of the functions is extremely easy, it is just a Python 15 | dependency, while maintaing the flexibility of a template. 16 | 17 | 18 | !!! warning 19 | 20 | It is important to **not** remove the `NNTemplateCore` callback from the instantiated callbacks 21 | in the template. It is used to inject personalized behaviour in the training loop. 22 | -------------------------------------------------------------------------------- /docs/features/restore.md: -------------------------------------------------------------------------------- 1 | # Restore 2 | 3 | The template offers a way to restore a previous run from the configuration. 4 | The relevant configuration block is in `conf/train/default.yml`: 5 | 6 | ```yaml 7 | restore: 8 | ckpt_or_run_path: null 9 | mode: null # null, finetune, hotstart, continue 10 | ``` 11 | 12 | ## ckpt_or_run_path 13 | 14 | The `ckpt_or_run_path` can be a path towards a Lightning Checkpoint or the run identifiers w.r.t. the logger. 15 | In case of W&B as a logger, they are called `run_path` and are in the form of `entity/project/run_id`. 16 | 17 | !!! 
warning 18 | 19 | If `ckpt_or_run_path` points to a checkpoint, that checkpoint must have been saved with 20 | this template, because additional information are attached to the checkpoint to guarantee 21 | a correct restore. These include the `run_path` itself and the whole configuration used. 22 | 23 | ## mode 24 | 25 | We support 4 different modes for restoring an experiment: 26 | 27 | === "null" 28 | 29 | ```yaml 30 | restore: 31 | mode: null 32 | ``` 33 | In this `mode` no restore happens, and `ckpt_or_run_path` is ignored. 34 | 35 | 36 | !!! example "Use Case" 37 | 38 | This is the default option and allows the user to train the model from 39 | scratch logging into a new run. 40 | 41 | 42 | === "finetune" 43 | 44 | ```yaml 45 | restore: 46 | mode: finetune 47 | ``` 48 | In this `mode` only the model weights are restored, both the `Trainer` state and the logger run 49 | are *not restored*. 50 | 51 | 52 | !!! example "Use Case" 53 | 54 | As the name suggest, one of the most common use case is when fine 55 | tuning a trained model logging into a new run with a novel training 56 | regimen. 57 | 58 | === "hotstart" 59 | 60 | ```yaml 61 | restore: 62 | mode: hotstart 63 | ``` 64 | In this `mode` the training continues from the checkpoint restoring the `Trainer` state **but** the logging does not. 65 | A new run is created on the logger dashboard. 66 | 67 | 68 | !!! example "Use Case" 69 | 70 | Perform different tests in separate logging runs branching from the same trained 71 | model. 72 | 73 | 74 | === "continue" 75 | 76 | ```yaml 77 | restore: 78 | mode: continue 79 | ``` 80 | In this `mode` the training continues from the checkpoint **and** the logging continues 81 | in the previous run. No new run is created on the logger dashboard. 82 | 83 | 84 | !!! example "Use Case" 85 | 86 | The training execution was interrupted and the user wants to continue it. 87 | 88 | 89 | !!! tldr "Restore summary" 90 | 91 | | | null | finetune | hotstart | continue | 92 | |---------------|------|--------------------|--------------------|--------------------| 93 | | **Model weights** | :x: | :white_check_mark: | :white_check_mark: | :white_check_mark: | 94 | | **Trainer state** | :x: | :x: | :white_check_mark: | :white_check_mark: | 95 | | **Logging run** | :x: | :x: | :x: | :white_check_mark: | 96 | -------------------------------------------------------------------------------- /docs/features/storage.md: -------------------------------------------------------------------------------- 1 | # Storage 2 | 3 | The checkpoints and other data produces by the experiment is stored in 4 | a logger agnostic folder defined in the configuration `core.storage_dir` 5 | 6 | This is the organization of the `storage_dir`: 7 | 8 | ```bash 9 | storage 10 | └── 11 | └── 12 | ├── checkpoints 13 | │ └── .ckpt.zip 14 | └── config.yaml 15 | ``` 16 | 17 | In the configuration it is possible to specify whether the run files 18 | stored inside the `storage_dir` should be uploaded to the cloud: 19 | 20 | ```yaml 21 | logging: 22 | upload: 23 | run_files: true 24 | source: true 25 | ``` 26 | -------------------------------------------------------------------------------- /docs/features/tags.md: -------------------------------------------------------------------------------- 1 | # Tags 2 | 3 | Each run should be `tagged` in order to easily filter them from the logged dashboard. 4 | Unfortunately, it is easy to forget to tag correctly each run. 
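If the tags are known in advance, they can be pinned in the configuration or passed as a Hydra override so that no interactive prompt is needed; a minimal sketch (the `core.tags` key and the run script path follow the template defaults used elsewhere in these docs):

```bash
# provide tags non-interactively via a Hydra override
python src/project/run.py "core.tags=[develop,baseline]"
```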
5 | 6 | We ask interactively for a list of comma separated tags, if those are not already defined in the configuration: 7 | ``` 8 | WARNING No tags provided, asking for tags... 9 | Enter a list of comma separated tags (develop): 10 | ``` 11 | 12 | !!! info 13 | 14 | If the current experiment is a sweep comprised of multiple runs and there are not any tags defined, 15 | an error is raised instead: 16 | ``` 17 | ERROR You need to specify 'core.tags' in a multi-run setting! 18 | ``` 19 | -------------------------------------------------------------------------------- /docs/features/tests.md: -------------------------------------------------------------------------------- 1 | # Tests 2 | 3 | The generated project includes automated tests that use the **current configuration** defined in your project. 4 | 5 | You should write additional tests specific to each project, but running the tests should give an 6 | idea at least if the code and fundamental operations work as expected. 7 | 8 | 9 | !!! info 10 | 11 | You can execute the tests with: 12 | 13 | ```bash 14 | pytest -v 15 | ``` 16 | -------------------------------------------------------------------------------- /docs/getting-started/generation.md: -------------------------------------------------------------------------------- 1 | # Initial Setup 2 | 3 | ## Cookiecutter 4 | 5 | `nn-template` is, by definition, a template to generate projects. It's a robust **starting point** 6 | for your projects, something that lets you skip the initial boilerplate in configuring the environment, tests and such. 7 | Since it is a blueprint to build upon, it has no utility in being installed via pip or similar tools. 8 | 9 | Instead, we rely on [cookiecutter](https://cookiecutter.readthedocs.io) to manage the setup stages and deliver to you a 10 | ready-to-run project. It is a general-purpose tool that enables users to add their water of choice (variable 11 | configurations) to their particular Cup-a-Soup (the template to be setup). 12 | 13 | !!! hint "Installing cookiecutter" 14 | 15 | `cookiecutter` can be installed via pip in any Python-enabled environment (it won't be the same used by the project once 16 | instantiated). Our advice is to install `cookiecutter` as a system utility via [pipx](https://github.com/pypa/pipx): 17 | 18 | ```shell 19 | pipx install cookiecutter 20 | ``` 21 | 22 | Then, we need to tell cookiecutter which template to work on: 23 | 24 | ```shell 25 | cookiecutter https://github.com/grok-ai/nn-template.git 26 | ``` 27 | 28 | It will clone the nn-template repository in the background, call its interactive setup, and build your project's folder 29 | according to the given parametrization. 30 | 31 | The parametrized setup will take care of: 32 | 33 | - Set up the development of a Python package 34 | - Initializing a clean Git repository and add the GitHub remote of choice 35 | - Create a **new Conda environment** to execute your code in 36 | 37 | This extra step via cookiecutter is done to avoid a lot of manual parametrization, unavoidable when cloning a template 38 | repository from scratch. **Trust us, it is totally worth the bother!** 39 | 40 | ## Building Blocks 41 | The generated project already contains a minimal working example. You are **free to modify anything** you want except for 42 | a few essential and high-level things that keep everything working. **(again, this is not a framework!)**. 
43 | In particular mantain: 44 | 45 | - Any `LightningLogger` you may want to use wrapped in a `NNLogger` 46 | - The `NNTemplateCore` Lightning callback 47 | 48 | !!! hint nn-template main components 49 | 50 | The template bootstraps the project with most of the needed boilerplate. 51 | The remaining components to implement for your project are the following: 52 | 53 | 1. Implement data pipeline 54 | 1. Dataset 55 | 2. Pytorch Lightning DataModule 56 | 2. Implement neural modules 57 | 1. Model 58 | 2. Pytorch Lightning Module 59 | 60 | ## FAQs 61 | 62 | ??? question "What is The Answer to the Ultimate Question of Life, the Universe, and Everything?" 63 | 64 | 42 65 | 66 | ??? question "Why are the logs badly formatted in PyCharm?" 67 | 68 | This is due to the fact that we are using [Rich](https://rich.readthedocs.io/en/stable/introduction.html) to handle 69 | the logging, and Rich is not compatible with customized terminals. As its documentation says: 70 | 71 | "*PyCharm users will need to enable “emulate terminal” in output console option in run/debug configuration to see styled output.*" 72 | 73 | ??? question "Why are file paths not interactive in the terminal's output?" 74 | 75 | [We would like to know, too.](https://youtrack.jetbrains.com/issue/PY-46305) 76 | 77 | ??? question "How can I exclude specific file paths from pre-commit checks (e.g. pydocstyle)?" 78 | 79 | While we encourage everyone to keep best-practices and standards enforced via the pre-commit utility, we also take 80 | into account situations where you just copy/paste code from the Internet and fixing it would be tedious. 81 | In those cases, the file `.pre-commit-config.yaml` has you covered. Each hook can receive an additional property, 82 | namely `exclude` where you can specify single files or patterns to be excluded when running that hook. 83 | 84 | For example, if you want to exclude a file named `ugly_but_working_code.py` from an annoying hook `annoying_hook` (most likely `pydocstyle`): 85 | ```yaml 86 | - repo: https://github.com/slow_coding/annoying_hook.git 87 | hooks: 88 | - id: annoying_hook 89 | exclude: ugly_but_working_code.py 90 | ``` 91 | 92 | ## Future Features 93 | 94 | - [ ] Optuna support 95 | - [ ] Support different loggers other than WandB 96 | -------------------------------------------------------------------------------- /docs/getting-started/index.md: -------------------------------------------------------------------------------- 1 | # Principles behind nn-template 2 | 3 | When developing neural models ourselves, we often struggled with: 4 | 5 | - **Reproducibility**. We strongly believe in the reproducibility requirement of scientific work. 6 | - **Framework Learning**. Even when you find (or code yourself) the best framework to fit your needs, you still end up 7 | in messy situations when collaborating since others have to learn to use it; 8 | - **Avoiding boilerplate**. We were bored to write the same code over and over in 9 | every project to handle the typical ML pipeline. 10 | 11 | Over the course of the years, we fine-tuned our toolbox to reach this local minimum with respect to our self-imposed 12 | requirements. After many epochs of training, the result is **nn-template**. 13 | 14 | 15 | !!! warning "nn-template is not a framework" 16 | 17 | - It does not aim to sidestep the need to write code. 18 | - It does not constrain your workflow more than PyTorch Lightning does. 
19 | 20 | -------------------------------------------------------------------------------- /docs/index.md: -------------------------------------------------------------------------------- 1 | --- 2 | hide: 3 | - navigation 4 | - toc 5 | --- 6 | 7 | # NN Template 8 | 9 |

10 | CI 11 | CI 12 | Release 13 | Code style: black 14 |

15 | 16 |

17 | 18 | "We demand rigidly defined areas of doubt and uncertainty." 19 | 20 |

21 | 22 | --- 23 | 24 | ```bash 25 | cookiecutter https://github.com/grok-ai/nn-template 26 | ``` 27 | 28 | --- 29 | 30 | [![asciicast](https://asciinema.org/a/475623.svg)](https://asciinema.org/a/475623) 31 | 32 | --- 33 | 34 | Generic cookiecutter template to bootstrap [PyTorch](https://pytorch.org/get-started/locally/) projects 35 | and to avoid writing boilerplate code to integrate: 36 | 37 | - [PyTorch Lightning](https://github.com/PyTorchLightning/pytorch-lightning), lightweight PyTorch wrapper for high-performance AI research. 38 | - [Hydra](https://github.com/facebookresearch/hydra), a framework for elegantly configuring complex applications. 39 | - [Hugging Face Datasets](https://huggingface.co/docs/datasets/index),a library for easily accessing and sharing datasets. 40 | - [Weights and Biases](https://wandb.ai/home), organize and analyze machine learning experiments. *(educational account available)* 41 | - [Streamlit](https://streamlit.io/), turns data scripts into shareable web apps in minutes. 42 | - [MkDocs](https://www.mkdocs.org/) and [Material for MkDocs](https://squidfunk.github.io/mkdocs-material/), a fast, simple and downright gorgeous static site generator. 43 | - [DVC](https://dvc.org/doc/start/data-versioning), track large files, directories, or ML models. Think "Git for data". 44 | - [GitHub Actions](https://github.com/features/actions), to run the tests, publish the documentation and to PyPI automatically. 45 | - Python best practices for developing and publishing research projects. 46 | 47 | 48 | !!! help "cookiecutter" 49 | 50 | This is a *parametrized* template that uses [cookiecutter](https://github.com/cookiecutter/cookiecutter). 51 | Install cookiecutter with: 52 | 53 | ```pip install cookiecutter``` 54 | -------------------------------------------------------------------------------- /docs/integrations/dvc.md: -------------------------------------------------------------------------------- 1 | 2 | # Data Version Control 3 | 4 | DVC runs alongside `git` and uses the current commit hash to version control the data. 5 | 6 | Initialize the `dvc` repository: 7 | 8 | ```bash 9 | $ dvc init 10 | ``` 11 | 12 | To start tracking a file or directory, use `dvc add`: 13 | 14 | ```bash 15 | $ dvc add data/ImageNet 16 | ``` 17 | 18 | DVC stores information about the added file (or a directory) in a special `.dvc` file named `data/ImageNet.dvc`, a small text file with a human-readable format. 19 | This file can be easily versioned like source code with Git, as a placeholder for the original data (which gets listed in `.gitignore`): 20 | 21 | ```bash 22 | git add data/ImageNet.dvc data/.gitignore 23 | git commit -m "Add raw data" 24 | ``` 25 | 26 | ## Making changes 27 | 28 | When you make a change to a file or directory, run `dvc add` again to track the latest version: 29 | 30 | ```bash 31 | $ dvc add data/ImageNet 32 | ``` 33 | 34 | ## Switching between versions 35 | 36 | The regular workflow is to use `git checkout` first to switch a branch, checkout a commit, or a revision of a `.dvc` file, and then run `dvc checkout` to sync data: 37 | 38 | ```bash 39 | $ git checkout <...> 40 | $ dvc checkout 41 | ``` 42 | 43 | !!! info 44 | 45 | Read more in the DVC [docs](https://dvc.org/doc/start/data-versioning)! 
46 | -------------------------------------------------------------------------------- /docs/integrations/githubactions.md: -------------------------------------------------------------------------------- 1 | # GitHub Actions 2 | 3 | Automate, customize, and execute your software development workflows right in your repository with GitHub Actions. 4 | 5 | !!! info 6 | 7 | The template offers workflows to automatically run tests and pre-commits on pull requests, publish on PyPi and the docs on GitHub Pages on release. 8 | -------------------------------------------------------------------------------- /docs/integrations/hydra.md: -------------------------------------------------------------------------------- 1 | 2 | # Hydra 3 | 4 | Hydra is an open-source Python framework that simplifies the development of research and other complex applications. The key feature is the ability to dynamically create a hierarchical configuration by composition and override it through config files and the command line. The name Hydra comes from its ability to run multiple similar jobs - much like a Hydra with multiple heads. 5 | 6 | The basic functionalities are intuitive: it is enough to change the configuration files in `conf/*` accordingly to your preferences. Everything will be logged in `wandb` automatically. 7 | 8 | Consider creating new root configurations `conf/myawesomeexp.yaml` instead of always using the default `conf/default.yaml`. 9 | 10 | 11 | ## Multi-run 12 | 13 | You can easily perform hyperparameters [sweeps](https://hydra.cc/docs/advanced/override_grammar/extended), which override the configuration defined in `/conf/*`. 14 | 15 | The easiest one is the grid-search. It executes the code with every possible combinations of the specified hyperparameters: 16 | 17 | ```bash 18 | python src/run.py -m optim.optimizer.lr=0.02,0.002,0.0002 optim.lr_scheduler.T_mult=1,2 optim.optimizer.weight_decay=0,1e-5 19 | ``` 20 | 21 | You can explore aggregate statistics or compare and analyze each run in the W&B dashboard. 22 | 23 | 24 | !!! info 25 | 26 | We recommend to go through at least the [Basic Tutorial](https://hydra.cc/docs/tutorials/basic/your_first_app/simple_cli), keep in mind that Hydra builds on top of [OmegaConf](https://omegaconf.readthedocs.io/en/latest/index.html). 27 | -------------------------------------------------------------------------------- /docs/integrations/lightning.md: -------------------------------------------------------------------------------- 1 | 2 | # PyTorch Lightning 3 | 4 | Lightning makes coding complex networks simple. 5 | It is not a high level framework like `keras`, but forces a neat code organization and encapsulation. 6 | 7 | You should be somewhat familiar with PyTorch and [PyTorch Lightning](https://pytorch-lightning.readthedocs.io/en/stable/index.html) before using this template. 8 | 9 | ![PT to PL](https://github.com/PyTorchLightning/pytorch-lightning/blob/master/docs/source/_static/images/general/pl_quick_start_full_compressed.gif?raw=true) 10 | -------------------------------------------------------------------------------- /docs/integrations/mkdocs.md: -------------------------------------------------------------------------------- 1 | # MkDocs 2 | 3 | MkDocs is a fast, simple and downright gorgeous static site generator that's geared towards building project documentation. 4 | 5 | Documentation source files are written in Markdown, and configured with a single YAML configuration file. 
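As an illustration, a minimal sketch of such a configuration file (the site name and nav entries are placeholders, not the template's actual `mkdocs.yml`):

```yaml
# mkdocs.yml -- minimal illustrative configuration
site_name: Awesome Project
theme:
  name: material
nav:
  - Home: index.md
  - Getting started: getting-started/index.md
```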
6 | 7 | ## Material for MkDocs 8 | 9 | [Material for MkDocs](https://squidfunk.github.io/mkdocs-material/) is a theme for MkDocs, a static site generator geared towards (technical) project documentation. 10 | 11 | !!! hint 12 | 13 | The template comes with Material for MkDocs already configured, 14 | to create your documentation you only need to write markdown files and define the `nav`. 15 | 16 | See the [Documentation](https://grok-ai.github.io/nn-template/latest/features/docs/) page to get started! 17 | -------------------------------------------------------------------------------- /docs/integrations/streamlit.md: -------------------------------------------------------------------------------- 1 | 2 | # Streamlit 3 | [Streamlit](https://docs.streamlit.io/) is an open-source Python library that makes 4 | it easy to create and share beautiful, custom web apps for machine learning and data science. 5 | 6 | In just a few minutes, you can build and deploy powerful data apps to: 7 | 8 | - **Explore** your data 9 | - **Interact** with your model 10 | - **Analyze** your model behavior and input sensitivity 11 | - **Showcase** your prototype with [awesome web apps](https://streamlit.io/gallery) 12 | 13 | Moreover, Streamlit enables interactive development with automatic rerun on files changes. 14 | 15 | ![Example of live coding an app in Streamlit|635x380](https://github.com/streamlit/docs/raw/main/public/images/Streamlit_overview.gif) 16 | 17 | 18 | !!! info 19 | 20 | Launch a minimal app with `PYTHONPATH=. streamlit run src/ui/run.py`. There is a built-in function to restore a model checkpoint stored on W&B, with automatic download if the checkpoint is not present in the local machine: 21 | -------------------------------------------------------------------------------- /docs/integrations/wandb.md: -------------------------------------------------------------------------------- 1 | 2 | # Weights and Biases 3 | 4 | Weights & Biases helps you keep track of your machine learning projects. Use tools to log hyperparameters and output metrics from your runs, then visualize and compare results and quickly share findings with your colleagues. 5 | 6 | [This](https://wandb.ai/gladia/nn-template?workspace=user-lucmos) is an example of a simple dashboard. 7 | 8 | ## Quickstart 9 | 10 | Login to your `wandb` account, running once `wandb login`. 11 | Configure the logging in `conf/logging/*`. 12 | 13 | !!! info 14 | 15 | Read more in the [docs](https://docs.wandb.ai/). Particularly useful the [`log` method](https://docs.wandb.ai/library/log), accessible from inside a PyTorch Lightning module with `self.logger.experiment.log`. 16 | -------------------------------------------------------------------------------- /docs/overrides/main.html: -------------------------------------------------------------------------------- 1 | {% extends "base.html" %} 2 | 3 | {% block outdated %} 4 | You're not viewing the latest version. 5 | 6 | Click here to go to latest. 7 | 8 | {% endblock %} 9 | -------------------------------------------------------------------------------- /docs/papers.md: -------------------------------------------------------------------------------- 1 | --- 2 | hide: 3 | - navigation 4 | - toc 5 | --- 6 | 7 | # Scientific Papers based on nn-template 8 | 9 | The following papers acknowledge the adoption of NN Template: 10 | 11 | !!! 
abstract "arXiv 2022" 12 | 13 | **Metric Based Few-Shot Graph Classification** 14 | 15 | *Donato Crisostomi, Simone Antonelli, Valentino Maiorca, Luca Moschella, Riccardo Marin, Emanuele Rodolà* 16 | 17 | [![](https://shields.io/badge/-Repository-emerald?style=for-the-badge&logo=github&labelColor=gray)](https://github.com/crisostomi/metric-few-shot-graph) 18 | 19 | !!! abstract "Computer Graphics Forum: CGF 2022" 20 | 21 | **Learning Spectral Unions of Partial Deformable 3D Shapes** 22 | 23 | *Luca Moschella, Simone Melzi, Luca Cosmo, Filippo Maggioli, Or Litany, Maks Ovsjanikov, Leonidas Guibas, Emanuele Rodolà* 24 | 25 | [![](https://shields.io/badge/-Repository-emerald?style=for-the-badge&logo=github&labelColor=gray)](https://github.com/lucmos/spectral-unions) 26 | 27 | !!! abstract "Findings of the Association for Computational Linguistics: EMNLP 2021" 28 | 29 | **WikiNEuRal: Combined Neural and Knowledge-based Silver Data Creation for Multilingual NER** 30 | 31 | *Simone Tedeschi, Valentino Maiorca, Niccolò Campolungo, Francesco Cecconi, and Roberto Navigli* 32 | 33 | [![](https://shields.io/badge/-Repository-emerald?style=for-the-badge&logo=github&labelColor=gray)](https://github.com/Babelscape/wikineural) 34 | 35 | !!! abstract "Findings of the Association for Computational Linguistics: EMNLP 2021" 36 | 37 | **Named Entity Recognition for Entity Linking: What Works and What's Next.** 38 | 39 | *Simone Tedeschi, Simone Conia, Francesco Cecconi, and Roberto Navigli* 40 | 41 | [![](https://shields.io/badge/-Repository-emerald?style=for-the-badge&logo=github&labelColor=gray)](https://github.com/Babelscape/ner4el) 42 | 43 | --- 44 | 45 | !!! tip "Please let us know if your paper also does and we'll add it to the list!" -------------------------------------------------------------------------------- /docs/project-structure/conf.md: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/grok-ai/nn-template/8ba02bba8f015e1eb7efb0d2ab8c9d433bd1c431/docs/project-structure/conf.md -------------------------------------------------------------------------------- /docs/project-structure/index.md: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/grok-ai/nn-template/8ba02bba8f015e1eb7efb0d2ab8c9d433bd1c431/docs/project-structure/index.md -------------------------------------------------------------------------------- /docs/project-structure/structure.md: -------------------------------------------------------------------------------- 1 | 2 | # Structure 3 | 4 | ```bash 5 | . 
6 | ├── conf 7 | │   ├── default.yaml 8 | │   ├── hydra 9 | │   │   └── default.yaml 10 | │   ├── nn 11 | │   │   └── default.yaml 12 | │   └── train 13 | │   └── default.yaml 14 | ├── data 15 | │   └── .gitignore 16 | ├── docs 17 | │   ├── index.md 18 | │   └── overrides 19 | │   └── main.html 20 | ├── .editorconfig 21 | ├── .env 22 | ├── .env.template 23 | ├── env.yaml 24 | ├── .flake8 25 | ├── .github 26 | │   └── workflows 27 | │   ├── publish.yml 28 | │   └── test_suite.yml 29 | ├── .gitignore 30 | ├── LICENSE 31 | ├── mkdocs.yml 32 | ├── .pre-commit-config.yaml 33 | ├── pyproject.toml 34 | ├── README.md 35 | ├── setup.cfg 36 | ├── setup.py 37 | ├── src 38 | │   └── awesome_project 39 | │   ├── data 40 | │   │   ├── datamodule.py 41 | │   │   ├── dataset.py 42 | │   │   └── __init__.py 43 | │   ├── __init__.py 44 | │   ├── modules 45 | │   │   ├── __init__.py 46 | │   │   └── module.py 47 | │   ├── pl_modules 48 | │   │   ├── __init__.py 49 | │   │   └── pl_module.py 50 | │   ├── run.py 51 | │   └── ui 52 | │   ├── __init__.py 53 | │   └── run.py 54 | └── tests 55 | ├── conftest.py 56 | ├── __init__.py 57 | ├── test_checkpoint.py 58 | ├── test_configuration.py 59 | ├── test_nn_core_integration.py 60 | ├── test_resume.py 61 | ├── test_storage.py 62 | └── test_training.py 63 | ``` -------------------------------------------------------------------------------- /hooks/post_gen_project.py: -------------------------------------------------------------------------------- 1 | import shutil 2 | import subprocess 3 | import sys 4 | import textwrap 5 | from dataclasses import dataclass, field 6 | from distutils.util import strtobool 7 | from typing import Dict, List, Optional 8 | 9 | 10 | def initialize_env_variables( 11 | env_file: str = ".env", env_file_template: str = ".env.template" 12 | ) -> None: 13 | """Initialize the .env file""" 14 | shutil.copy(src=env_file_template, dst=env_file) 15 | 16 | 17 | def bool_query(question: str, default: Optional[bool] = None) -> bool: 18 | """Ask a yes/no question via input() and return their boolean answer. 19 | 20 | Args: 21 | question: is a string that is presented to the user. 22 | default: is the presumed answer if the user just hits . 23 | 24 | Returns: 25 | the boolean representation of the user answer, or the default value if present. 
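    Example (illustrative; any yes/no question works):
        bool_query("Create the conda environment?", default=True)
        # prints "Create the conda environment? [Y/n] "; pressing <Enter> returns the
        # default (True here), while answering "n" returns False.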
26 | """ 27 | if default is None: 28 | prompt = " [y/n] " 29 | elif default: 30 | prompt = " [Y/n] " 31 | else: 32 | prompt = " [y/N] " 33 | 34 | while True: 35 | sys.stdout.write(question + prompt) 36 | choice = input().lower() 37 | 38 | if default is not None and not choice: 39 | return default 40 | 41 | try: 42 | return strtobool(choice) 43 | except ValueError: 44 | sys.stdout.write("Please respond with 'yes' or 'no' " "(or 'y' or 'n').\n") 45 | 46 | 47 | @dataclass 48 | class Dependency: 49 | expected: bool 50 | id: str 51 | 52 | 53 | @dataclass 54 | class Query: 55 | id: str 56 | interactive: bool 57 | default: bool 58 | prompt: str 59 | command: str 60 | autorun: bool 61 | dependencies: List[Dependency] = field(default_factory=list) 62 | 63 | 64 | SETUP_COMMANDS: List[Query] = [ 65 | Query( 66 | id="git_init", 67 | interactive=True, 68 | default=True, 69 | prompt="Initializing git repository...", 70 | command="git init\n" 71 | "git add --all\n" 72 | 'git commit -m "Initialize project from nn-template={{ cookiecutter.__version }}"', 73 | autorun=True, 74 | ), 75 | Query( 76 | id="git_remote", 77 | interactive=True, 78 | default=True, 79 | prompt="Adding an existing git remote...\n(You should create the remote from the web UI before proceeding!)", 80 | command="git remote add origin git@github.com:{{ cookiecutter.github_user }}/{{ cookiecutter.repository_name }}.git", 81 | autorun=True, 82 | dependencies=[ 83 | Dependency(id="git_init", expected=True), 84 | ], 85 | ), 86 | Query( 87 | id="git_push_main", 88 | interactive=True, 89 | default=True, 90 | prompt="Pushing default branch to existing remote...", 91 | command="git push -u origin HEAD", 92 | autorun=True, 93 | dependencies=[ 94 | Dependency(id="git_remote", expected=True), 95 | ], 96 | ), 97 | Query( 98 | id="conda_env", 99 | interactive=True, 100 | default=True, 101 | prompt="Creating conda environment...", 102 | command="conda env create -f env.yaml", 103 | autorun=True, 104 | ), 105 | Query( 106 | id="precommit_install", 107 | interactive=True, 108 | default=True, 109 | prompt="Installing pre-commits...", 110 | command="conda run -n {{ cookiecutter.conda_env_name }} pre-commit install", 111 | autorun=True, 112 | dependencies=[ 113 | Dependency(id="git_init", expected=True), 114 | Dependency(id="conda_env", expected=True), 115 | ], 116 | ), 117 | Query( 118 | id="mike_init", 119 | interactive=True, 120 | default=True, 121 | prompt="Initializing gh-pages branch for GitHub Pages...", 122 | command="conda run -n {{ cookiecutter.conda_env_name }} mike deploy --update-aliases 0.0 latest\n" 123 | "conda run -n {{ cookiecutter.conda_env_name }} mike set-default latest", 124 | autorun=True, 125 | dependencies=[ 126 | Dependency(id="conda_env", expected=True), 127 | Dependency(id="git_init", expected=True), 128 | ], 129 | ), 130 | Query( 131 | id="mike_push", 132 | interactive=True, 133 | default=True, 134 | prompt="Pushing 'gh-pages' branch to existing remote...", 135 | command="git push origin gh-pages", 136 | autorun=True, 137 | dependencies=[ 138 | Dependency(id="mike_init", expected=True), 139 | Dependency(id="git_remote", expected=True), 140 | ], 141 | ), 142 | Query( 143 | id="conda_activate", 144 | interactive=False, 145 | default=True, 146 | prompt="Activate your conda environment with:", 147 | command="cd {{ cookiecutter.repository_name }}\n" 148 | "conda activate {{ cookiecutter.conda_env_name }}\n" 149 | "pytest -v", 150 | autorun=False, 151 | dependencies=[ 152 | Dependency(id="conda_env", expected=True), 153 | ], 154 | ), 
155 | ] 156 | 157 | 158 | def should_execute_query(query: Query, answers: Dict[str, bool]) -> bool: 159 | if not query.dependencies: 160 | return True 161 | return all( 162 | dependency.expected == answers.get(dependency.id, False) 163 | for dependency in query.dependencies 164 | ) 165 | 166 | 167 | def setup(setup_commands) -> None: 168 | answers: Dict[str, bool] = {} 169 | 170 | for query in setup_commands: 171 | assert query.id not in answers 172 | 173 | if should_execute_query(query=query, answers=answers): 174 | if query.interactive: 175 | answers[query.id] = bool_query( 176 | question=f"\n" 177 | f"{query.prompt}\n" 178 | f"\n" 179 | f'{textwrap.indent(query.command, prefix=" ")}\n' 180 | f"\n" 181 | f"Execute?", 182 | default=query.default, 183 | ) 184 | else: 185 | print( 186 | f"\n" 187 | f"{query.prompt}\n" 188 | f"\n" 189 | f'{textwrap.indent(query.command, prefix=" ")}\n' 190 | ) 191 | answers[query.id] = True 192 | 193 | if answers[query.id] and (query.interactive or query.autorun): 194 | try: 195 | subprocess.run( 196 | query.command, 197 | check=True, 198 | text=True, 199 | shell=True, 200 | ) 201 | except subprocess.CalledProcessError: 202 | answers[query.id] = False 203 | print() 204 | 205 | 206 | initialize_env_variables() 207 | setup(setup_commands=SETUP_COMMANDS) 208 | 209 | print( 210 | "\nYou are all set!\n\n" 211 | "Remember that if you use PyCharm, you must:\n" 212 | ' - Mark the "src" directory as "Sources Root".\n' 213 | ' - Enable "Emulate terminal in output console" in the run configuration.\n' 214 | ) 215 | print( 216 | "Remember to:\n" 217 | " - Ensure the GitHub Actions in the repository are enabled.\n" 218 | " - Ensure the Github Pages are enabled to auto-publish the docs on each release." 219 | ) 220 | 221 | print("Have fun! :]") 222 | -------------------------------------------------------------------------------- /hooks/pre_gen_project.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/grok-ai/nn-template/8ba02bba8f015e1eb7efb0d2ab8c9d433bd1c431/hooks/pre_gen_project.py -------------------------------------------------------------------------------- /mkdocs.yml: -------------------------------------------------------------------------------- 1 | site_name: NN Template 2 | site_description: Generic template to bootstrap your PyTorch project with PyTorch Lightning, Hydra, W&B, DVC, and Streamlit. 
3 | repo_url: https://github.com/grok-ai/nn-template 4 | copyright: Copyright © 2021 - 2022 Valentino Maiorca | Luca Moschella 5 | nav: 6 | - Home: index.md 7 | - Getting started: 8 | - getting-started/index.md 9 | - Generating your project: getting-started/generation.md 10 | - Features: 11 | - Core: features/nncore.md 12 | - Restore: features/restore.md 13 | - Metadata: features/metadata.md 14 | - Tags: features/tags.md 15 | - Docs: features/docs.md 16 | - Tests: features/tests.md 17 | - Storage: features/storage.md 18 | - Determinism: features/determinism.md 19 | - Fast dev run: features/fastdevrun.md 20 | - Environment variables: features/envvars.md 21 | - CI/CD: features/cicd.md 22 | - Best practices: 23 | - Python Environment: features/conda.md 24 | - Tooling: features/bestpractices.md 25 | # - Project Structure: 26 | # - project-structure/index.md 27 | # - Structure: project-structure/structure.md 28 | # - Conf: project-structure/conf.md 29 | # - NN: project-structure/conf.md 30 | # - Train: project-structure/conf.md 31 | - Integrations: 32 | - PyTorch Lightning: integrations/lightning.md 33 | - Hydra: integrations/hydra.md 34 | - Weigth & Biases: integrations/wandb.md 35 | - Streamlit: integrations/streamlit.md 36 | - MkDocs: integrations/mkdocs.md 37 | - DVC: integrations/dvc.md 38 | - GitHub Actions: integrations/githubactions.md 39 | - Publications: papers.md 40 | - Changelog: 41 | - changelog/index.md 42 | - Upgrade: changelog/upgrade.md 43 | 44 | theme: 45 | name: material 46 | custom_dir: docs/overrides 47 | icon: 48 | repo: fontawesome/brands/github 49 | 50 | features: 51 | - content.code.annotate 52 | - navigation.indexes 53 | - navigation.instant 54 | - navigation.sections 55 | - navigation.tabs 56 | - navigation.tabs.sticky 57 | - navigation.top 58 | - navigation.tracking 59 | - search.highlight 60 | - search.share 61 | - search.suggest 62 | 63 | palette: 64 | - scheme: default 65 | primary: light green 66 | accent: green 67 | toggle: 68 | icon: material/weather-night 69 | name: Switch to dark mode 70 | - scheme: slate 71 | primary: green 72 | accent: green 73 | toggle: 74 | icon: material/weather-sunny 75 | name: Switch to light mode 76 | 77 | # Extensions 78 | markdown_extensions: 79 | - abbr 80 | - admonition 81 | - attr_list 82 | - def_list 83 | - footnotes 84 | - meta 85 | - md_in_html 86 | - toc: 87 | permalink: true 88 | - tables 89 | - pymdownx.arithmatex: 90 | generic: true 91 | - pymdownx.betterem: 92 | smart_enable: all 93 | - pymdownx.caret 94 | - pymdownx.mark 95 | - pymdownx.tilde 96 | - pymdownx.critic 97 | - pymdownx.details 98 | - pymdownx.highlight: 99 | anchor_linenums: true 100 | - pymdownx.superfences: 101 | custom_fences: 102 | - name: mermaid 103 | class: mermaid 104 | format: !!python/name:pymdownx.superfences.fence_code_format 105 | - pymdownx.inlinehilite 106 | - pymdownx.keys 107 | - pymdownx.smartsymbols 108 | - pymdownx.snippets 109 | - pymdownx.tabbed: 110 | alternate_style: true 111 | - pymdownx.tasklist: 112 | custom_checkbox: true 113 | - pymdownx.emoji: 114 | emoji_index: !!python/name:materialx.emoji.twemoji 115 | emoji_generator: !!python/name:materialx.emoji.to_svg 116 | 117 | extra_javascript: 118 | - javascripts/mathjax.js 119 | - https://polyfill.io/v3/polyfill.min.js?features=es6 120 | - https://cdn.jsdelivr.net/npm/mathjax@3/es5/tex-mml-chtml.js 121 | 122 | plugins: 123 | - search 124 | 125 | extra: 126 | generator: true 127 | version: 128 | provider: mike 129 | 
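# The `version` block above enables mike-powered versioning of this documentation site.
# Example (illustrative version number): publish and alias a new docs version with
#   mike deploy --push --update-aliases 0.2 latest
#   mike set-default latest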
-------------------------------------------------------------------------------- /{{ cookiecutter.repository_name }}/.editorconfig: -------------------------------------------------------------------------------- 1 | # EditorConfig is awesome: https://EditorConfig.org 2 | 3 | # top-most EditorConfig file 4 | root = true 5 | 6 | [*] 7 | end_of_line = lf 8 | insert_final_newline = true 9 | max_line_length = 120 10 | trim_trailing_whitespace = true 11 | 12 | # 4 space indentation 13 | [*.py] 14 | indent_style = space 15 | indent_size = 4 16 | charset = utf-8 17 | 18 | [*.{yaml, yml}] 19 | indent_size = 2 20 | -------------------------------------------------------------------------------- /{{ cookiecutter.repository_name }}/.env.template: -------------------------------------------------------------------------------- 1 | # .env.template is a template for .env file that can be versioned. 2 | 3 | # Set to 1 to show full stack trace on error, 0 to hide it 4 | HYDRA_FULL_ERROR=1 5 | 6 | # Configure where huggingface_hub will locally store data. 7 | HF_HOME="~/.cache/huggingface" 8 | 9 | # Configure the User Access Token to authenticate to the Hub 10 | # HUGGING_FACE_HUB_TOKEN= 11 | -------------------------------------------------------------------------------- /{{ cookiecutter.repository_name }}/.flake8: -------------------------------------------------------------------------------- 1 | [flake8] 2 | # https://github.com/pytorch/pytorch/blob/master/.flake8 3 | max-line-length = 120 4 | select = B,C,E,F,P,T4,W,B9 5 | extend-ignore = E203, E501 6 | -------------------------------------------------------------------------------- /{{ cookiecutter.repository_name }}/.github/workflows/publish.yml: -------------------------------------------------------------------------------- 1 | name: Publish 2 | 3 | on: 4 | release: 5 | types: 6 | - created 7 | 8 | env: 9 | CACHE_NUMBER: 0 # increase to reset cache manually 10 | CONDA_ENV_FILE: './env.yaml' 11 | CONDA_ENV_NAME: '{{ cookiecutter.conda_env_name }}' 12 | {% raw %} 13 | HUGGING_FACE_HUB_TOKEN: ${{secrets.HUGGING_FACE_HUB_TOKEN}} 14 | 15 | jobs: 16 | build: 17 | strategy: 18 | fail-fast: false 19 | matrix: 20 | python-version: ['3.9'] 21 | include: 22 | - os: ubuntu-20.04 23 | label: linux-64 24 | prefix: /usr/share/miniconda3/envs/ 25 | 26 | name: ${{ matrix.label }}-py${{ matrix.python-version }} 27 | runs-on: ${{ matrix.os }} 28 | 29 | steps: 30 | - name: Parametrize conda env name 31 | run: echo "PY_CONDA_ENV_NAME=${{ env.CONDA_ENV_NAME }}-${{ matrix.python-version }}" >> $GITHUB_ENV 32 | - name: echo conda env name 33 | run: echo ${{ env.PY_CONDA_ENV_NAME }} 34 | 35 | - name: Parametrize conda prefix 36 | run: echo "PY_PREFIX=${{ matrix.prefix }}${{ env.PY_CONDA_ENV_NAME }}" >> $GITHUB_ENV 37 | - name: echo conda prefix 38 | run: echo ${{ env.PY_PREFIX }} 39 | 40 | # extract the first two digits from the release note 41 | - name: Set release notes tag 42 | run: | 43 | export RELEASE_TAG_VERSION=${{ github.event.release.tag_name }} 44 | echo "RELEASE_TAG_VERSION=${RELEASE_TAG_VERSION%.*}">> $GITHUB_ENV 45 | 46 | - name: Echo release notes tag 47 | run: | 48 | echo "${RELEASE_TAG_VERSION}" 49 | 50 | - uses: actions/checkout@v2 51 | with: 52 | fetch-depth: 0 53 | 54 | # Remove the python version pin from the env.yml which could be inconsistent 55 | - name: Remove explicit python version from the environment 56 | shell: bash -l {0} 57 | run: | 58 | sed -Ei '/^\s*-?\s*python\s*([#=].*)?$/d' ${{ env.CONDA_ENV_FILE }} 59 | cat ${{ env.CONDA_ENV_FILE }} 
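          # The sed above deletes any `python` / `python=X.Y` line from env.yaml, so the
          # interpreter version is driven only by the workflow matrix via setup-miniconda.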
60 | 61 | - name: Setup Mambaforge 62 | uses: conda-incubator/setup-miniconda@v2 63 | with: 64 | miniforge-variant: Mambaforge 65 | miniforge-version: latest 66 | activate-environment: ${{ env.PY_CONDA_ENV_NAME }} 67 | python-version: ${{ matrix.python-version }} 68 | use-mamba: true 69 | 70 | - uses: actions/cache@v2 71 | name: Conda cache 72 | with: 73 | path: ${{ env.PY_PREFIX }} 74 | key: ${{ matrix.label }}-conda-${{ matrix.python-version }}-${{ env.CACHE_NUMBER }}-${{ env.PY_CONDA_ENV_NAME }}-${{ hashFiles(env.CONDA_ENV_FILE) }}-${{hashFiles('./setup.cfg') }} 75 | id: conda_cache 76 | 77 | - uses: actions/cache@v2 78 | name: Pip cache 79 | with: 80 | path: ~/.cache/pip 81 | key: ${{ matrix.label }}-pip-${{ matrix.python-version }}-${{ env.CACHE_NUMBER }}-${{ env.PY_CONDA_ENV_NAME }}-${{ hashFiles(env.CONDA_ENV_FILE) }}-${{hashFiles('./setup.cfg') }} 82 | 83 | - uses: actions/cache@v2 84 | name: Pre-commit cache 85 | with: 86 | path: ~/.cache/pre-commit 87 | key: ${{ matrix.label }}-pre-commit-${{ hashFiles('.pre-commit-config.yaml') }}-${{ matrix.python-version }}-${{ env.CACHE_NUMBER }}-${{ env.PY_CONDA_ENV_NAME }}-${{ hashFiles(env.CONDA_ENV_FILE) }}-${{hashFiles('./setup.cfg') }} 88 | 89 | # Ensure the hack for the python version worked 90 | - name: Ensure we have the right Python 91 | shell: bash -l {0} 92 | run: | 93 | echo "Installed Python: $(python --version)" 94 | echo "Expected: ${{ matrix.python-version }}" 95 | python --version | grep "Python ${{ matrix.python-version }}" 96 | 97 | # https://stackoverflow.com/questions/70520120/attributeerror-module-setuptools-distutils-has-no-attribute-version 98 | # https://github.com/pytorch/pytorch/pull/69904 99 | - name: Downgrade setuptools due to a but in PyTorch 1.10.1 100 | shell: bash -l {0} 101 | run: | 102 | pip install setuptools==59.5.0 --upgrade 103 | 104 | - name: Update conda environment 105 | run: mamba env update -n ${{ env.PY_CONDA_ENV_NAME }} -f ${{ env.CONDA_ENV_FILE }} 106 | if: steps.conda_cache.outputs.cache-hit != 'true' 107 | 108 | # Update pip env whether or not there was a conda cache hit 109 | - name: Update pip environment 110 | shell: bash -l {0} 111 | run: pip install -e ".[dev]" 112 | if: steps.conda_cache.outputs.cache-hit == 'true' 113 | 114 | - run: pip3 list 115 | shell: bash -l {0} 116 | - run: mamba info 117 | - run: mamba list 118 | 119 | # Ensure the hack for the python version worked 120 | - name: Ensure we have the right Python 121 | shell: bash -l {0} 122 | run: | 123 | echo "Installed Python: $(python --version)" 124 | echo "Expected: ${{ matrix.python-version }}" 125 | python --version | grep "Python ${{ matrix.python-version }}" 126 | 127 | - name: Run pre-commits 128 | shell: bash -l {0} 129 | run: | 130 | pre-commit install 131 | pre-commit run -v --all-files --show-diff-on-failure 132 | 133 | - name: Test with pytest 134 | shell: bash -l {0} 135 | run: | 136 | pytest -v 137 | 138 | - name: Build docs website 139 | shell: bash -l {0} 140 | run: | 141 | git config user.name ci-bot 142 | git config user.email ci-bot@ci.com 143 | mike deploy --rebase --push --update-aliases ${RELEASE_TAG_VERSION} latest 144 | 145 | # Uncomment to publish on PyPI on release 146 | # - name: Build SDist and wheel 147 | # run: pipx run build 148 | # 149 | # - name: Check metadata 150 | # run: pipx run twine check dist/* 151 | # 152 | # - name: Publish distribution 📦 to PyPI 153 | # uses: pypa/gh-action-pypi-publish@release/v1 154 | # with: 155 | # user: __token__ 156 | # password: ${{ secrets.PYPI_API_TOKEN }} 
157 | # 158 | #{% endraw %} 159 | -------------------------------------------------------------------------------- /{{ cookiecutter.repository_name }}/.github/workflows/test_suite.yml: -------------------------------------------------------------------------------- 1 | name: Test Suite 2 | 3 | on: 4 | push: 5 | branches: 6 | - main 7 | - develop 8 | 9 | pull_request: 10 | types: 11 | - opened 12 | - reopened 13 | - synchronize 14 | 15 | env: 16 | CACHE_NUMBER: 1 # increase to reset cache manually 17 | CONDA_ENV_FILE: './env.yaml' 18 | CONDA_ENV_NAME: '{{ cookiecutter.conda_env_name }}' 19 | {% raw %} 20 | HUGGING_FACE_HUB_TOKEN: ${{secrets.HUGGING_FACE_HUB_TOKEN}} 21 | 22 | jobs: 23 | build: 24 | 25 | strategy: 26 | fail-fast: false 27 | matrix: 28 | python-version: ['3.11'] 29 | include: 30 | - os: ubuntu-20.04 31 | label: linux-64 32 | prefix: /usr/share/miniconda3/envs/ 33 | 34 | # - os: macos-latest 35 | # label: osx-64 36 | # prefix: /Users/runner/miniconda3/envs/$CONDA_ENV_NAME 37 | 38 | # - os: windows-latest 39 | # label: win-64 40 | # prefix: C:\Miniconda3\envs\$CONDA_ENV_NAME 41 | 42 | name: ${{ matrix.label }}-py${{ matrix.python-version }} 43 | runs-on: ${{ matrix.os }} 44 | 45 | steps: 46 | - name: Parametrize conda env name 47 | run: echo "PY_CONDA_ENV_NAME=${{ env.CONDA_ENV_NAME }}-${{ matrix.python-version }}" >> $GITHUB_ENV 48 | - name: echo conda env name 49 | run: echo ${{ env.PY_CONDA_ENV_NAME }} 50 | 51 | - name: Parametrize conda prefix 52 | run: echo "PY_PREFIX=${{ matrix.prefix }}${{ env.PY_CONDA_ENV_NAME }}" >> $GITHUB_ENV 53 | - name: echo conda prefix 54 | run: echo ${{ env.PY_PREFIX }} 55 | 56 | - uses: actions/checkout@v2 57 | 58 | # Remove the python version pin from the env.yml which could be inconsistent 59 | - name: Remove explicit python version from the environment 60 | shell: bash -l {0} 61 | run: | 62 | sed -Ei '/^\s*-?\s*python\s*([#=].*)?$/d' ${{ env.CONDA_ENV_FILE }} 63 | cat ${{ env.CONDA_ENV_FILE }} 64 | 65 | # Install torch cpu-only 66 | - name: Install torch cpu only 67 | shell: bash -l {0} 68 | run: | 69 | sed -i '/nvidia\|cuda/d' ${{ env.CONDA_ENV_FILE }} 70 | cat ${{ env.CONDA_ENV_FILE }} 71 | 72 | - name: Setup Mambaforge 73 | uses: conda-incubator/setup-miniconda@v2 74 | with: 75 | miniforge-variant: Mambaforge 76 | miniforge-version: latest 77 | activate-environment: ${{ env.PY_CONDA_ENV_NAME }} 78 | python-version: ${{ matrix.python-version }} 79 | use-mamba: true 80 | 81 | - uses: actions/cache@v2 82 | name: Conda cache 83 | with: 84 | path: ${{ env.PY_PREFIX }} 85 | key: ${{ matrix.label }}-conda-${{ matrix.python-version }}-${{ env.CACHE_NUMBER }}-${{ env.PY_CONDA_ENV_NAME }}-${{ hashFiles(env.CONDA_ENV_FILE) }}-${{hashFiles('./setup.cfg') }} 86 | id: conda_cache 87 | 88 | - uses: actions/cache@v2 89 | name: Pip cache 90 | with: 91 | path: ~/.cache/pip 92 | key: ${{ matrix.label }}-pip-${{ matrix.python-version }}-${{ env.CACHE_NUMBER }}-${{ env.PY_CONDA_ENV_NAME }}-${{ hashFiles(env.CONDA_ENV_FILE) }}-${{hashFiles('./setup.cfg') }} 93 | 94 | - uses: actions/cache@v2 95 | name: Pre-commit cache 96 | with: 97 | path: ~/.cache/pre-commit 98 | key: ${{ matrix.label }}-pre-commit-${{ hashFiles('.pre-commit-config.yaml') }}-${{ matrix.python-version }}-${{ env.CACHE_NUMBER }}-${{ env.PY_CONDA_ENV_NAME }}-${{ hashFiles(env.CONDA_ENV_FILE) }}-${{hashFiles('./setup.cfg') }} 99 | 100 | # Ensure the hack for the python version worked 101 | - name: Ensure we have the right Python 102 | shell: bash -l {0} 103 | run: | 104 | echo "Installed Python: 
$(python --version)" 105 | echo "Expected: ${{ matrix.python-version }}" 106 | python --version | grep "Python ${{ matrix.python-version }}" 107 | 108 | 109 | # https://stackoverflow.com/questions/70520120/attributeerror-module-setuptools-distutils-has-no-attribute-version 110 | # https://github.com/pytorch/pytorch/pull/69904 111 | - name: Downgrade setuptools due to a but in PyTorch 1.10.1 112 | shell: bash -l {0} 113 | run: | 114 | pip install setuptools==59.5.0 --upgrade 115 | 116 | - name: Update conda environment 117 | run: mamba env update -n ${{ env.PY_CONDA_ENV_NAME }} -f ${{ env.CONDA_ENV_FILE }} 118 | if: steps.conda_cache.outputs.cache-hit != 'true' 119 | 120 | # Update pip env whether or not there was a conda cache hit 121 | - name: Update pip environment 122 | shell: bash -l {0} 123 | run: pip install -e ".[dev]" 124 | if: steps.conda_cache.outputs.cache-hit == 'true' 125 | 126 | - run: pip3 list 127 | shell: bash -l {0} 128 | - run: mamba info 129 | - run: mamba list 130 | 131 | # Ensure the hack for the python version worked 132 | - name: Ensure we have the right Python 133 | shell: bash -l {0} 134 | run: | 135 | echo "Installed Python: $(python --version)" 136 | echo "Expected: ${{ matrix.python-version }}" 137 | python --version | grep "Python ${{ matrix.python-version }}" 138 | 139 | - name: Run pre-commits 140 | shell: bash -l {0} 141 | run: | 142 | pre-commit install 143 | pre-commit run -v --all-files --show-diff-on-failure 144 | 145 | - name: Test with pytest 146 | shell: bash -l {0} 147 | run: | 148 | pytest -v 149 | # 150 | #{% endraw %} 151 | -------------------------------------------------------------------------------- /{{ cookiecutter.repository_name }}/.gitignore: -------------------------------------------------------------------------------- 1 | wandb 2 | multirun.yaml 3 | storage 4 | 5 | # ignore the _version.py file 6 | _version.py 7 | 8 | # .gitignore defaults for python and pycharm 9 | .idea 10 | 11 | # Byte-compiled / optimized / DLL files 12 | __pycache__/ 13 | *.py[cod] 14 | *$py.class 15 | 16 | # C extensions 17 | *.so 18 | 19 | # Distribution / packaging 20 | .Python 21 | build/ 22 | develop-eggs/ 23 | dist/ 24 | downloads/ 25 | eggs/ 26 | .eggs/ 27 | lib/ 28 | lib64/ 29 | parts/ 30 | sdist/ 31 | var/ 32 | wheels/ 33 | share/python-wheels/ 34 | *.egg-info/ 35 | .installed.cfg 36 | *.egg 37 | MANIFEST 38 | 39 | # PyInstaller 40 | # Usually these files are written by a python script from a template 41 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
42 | *.manifest 43 | *.spec 44 | 45 | # Installer logs 46 | pip-log.txt 47 | pip-delete-this-directory.txt 48 | 49 | # Unit test / coverage reports 50 | htmlcov/ 51 | .tox/ 52 | .nox/ 53 | .coverage 54 | .coverage.* 55 | .cache 56 | nosetests.xml 57 | coverage.xml 58 | *.cover 59 | *.py,cover 60 | .hypothesis/ 61 | .pytest_cache/ 62 | cover/ 63 | 64 | # Translations 65 | *.mo 66 | *.pot 67 | 68 | # Django stuff: 69 | *.log 70 | local_settings.py 71 | db.sqlite3 72 | db.sqlite3-journal 73 | 74 | # Flask stuff: 75 | instance/ 76 | .webassets-cache 77 | 78 | # Scrapy stuff: 79 | .scrapy 80 | 81 | # Sphinx documentation 82 | docs/_build/ 83 | 84 | # PyBuilder 85 | .pybuilder/ 86 | target/ 87 | 88 | # Jupyter Notebook 89 | .ipynb_checkpoints 90 | 91 | # IPython 92 | profile_default/ 93 | ipython_config.py 94 | 95 | # pyenv 96 | # For a library or package, you might want to ignore these files since the code is 97 | # intended to run in multiple environments; otherwise, check them in: 98 | # .python-version 99 | 100 | # pipenv 101 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 102 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 103 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 104 | # install all needed dependencies. 105 | #Pipfile.lock 106 | 107 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow 108 | __pypackages__/ 109 | 110 | # Celery stuff 111 | celerybeat-schedule 112 | celerybeat.pid 113 | 114 | # SageMath parsed files 115 | *.sage.py 116 | 117 | # Environments 118 | .env 119 | .venv 120 | env/ 121 | venv/ 122 | ENV/ 123 | env.bak/ 124 | venv.bak/ 125 | 126 | # Spyder project settings 127 | .spyderproject 128 | .spyproject 129 | 130 | # Rope project settings 131 | .ropeproject 132 | 133 | # mkdocs documentation 134 | /site 135 | 136 | # mypy 137 | .mypy_cache/ 138 | .dmypy.json 139 | dmypy.json 140 | 141 | # Pyre type checker 142 | .pyre/ 143 | 144 | # pytype static type analyzer 145 | .pytype/ 146 | 147 | # Cython debug symbols 148 | cython_debug/ 149 | 150 | # Covers JetBrains IDEs: IntelliJ, RubyMine, PhpStorm, AppCode, PyCharm, CLion, Android Studio, WebStorm and Rider 151 | # Reference: https://intellij-support.jetbrains.com/hc/en-us/articles/206544839 152 | 153 | # User-specific stuff 154 | .idea/**/workspace.xml 155 | .idea/**/tasks.xml 156 | .idea/**/usage.statistics.xml 157 | .idea/**/dictionaries 158 | .idea/**/shelf 159 | 160 | # Generated files 161 | .idea/**/contentModel.xml 162 | 163 | # Sensitive or high-churn files 164 | .idea/**/dataSources/ 165 | .idea/**/dataSources.ids 166 | .idea/**/dataSources.local.xml 167 | .idea/**/sqlDataSources.xml 168 | .idea/**/dynamic.xml 169 | .idea/**/uiDesigner.xml 170 | .idea/**/dbnavigator.xml 171 | 172 | # Gradle 173 | .idea/**/gradle.xml 174 | .idea/**/libraries 175 | 176 | # Gradle and Maven with auto-import 177 | # When using Gradle or Maven with auto-import, you should exclude module files, 178 | # since they will be recreated, and may cause churn. Uncomment if using 179 | # auto-import. 
180 | # .idea/artifacts 181 | # .idea/compiler.xml 182 | # .idea/jarRepositories.xml 183 | # .idea/modules.xml 184 | # .idea/*.iml 185 | # .idea/modules 186 | # *.iml 187 | # *.ipr 188 | 189 | # CMake 190 | cmake-build-*/ 191 | 192 | # Mongo Explorer plugin 193 | .idea/**/mongoSettings.xml 194 | 195 | # File-based project format 196 | *.iws 197 | 198 | # IntelliJ 199 | out/ 200 | 201 | # mpeltonen/sbt-idea plugin 202 | .idea_modules/ 203 | 204 | # JIRA plugin 205 | atlassian-ide-plugin.xml 206 | 207 | # Cursive Clojure plugin 208 | .idea/replstate.xml 209 | 210 | # Crashlytics plugin (for Android Studio and IntelliJ) 211 | com_crashlytics_export_strings.xml 212 | crashlytics.properties 213 | crashlytics-build.properties 214 | fabric.properties 215 | 216 | # Editor-based Rest Client 217 | .idea/httpRequests 218 | 219 | # Android studio 3.1+ serialized cache file 220 | .idea/caches/build_file_checksums.ser 221 | -------------------------------------------------------------------------------- /{{ cookiecutter.repository_name }}/.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | # See https://pre-commit.com for more information 2 | # See https://pre-commit.com/hooks.html for more hooks 3 | repos: 4 | - repo: https://github.com/pre-commit/pre-commit-hooks 5 | rev: v4.1.0 6 | hooks: 7 | - id: check-added-large-files # prevents giant files from being committed. 8 | args: ['--maxkb=4096'] 9 | - id: check-ast # simply checks whether the files parse as valid python. 10 | # - id: check-byte-order-marker # forbids files which have a utf-8 byte-order marker. 11 | # - id: check-builtin-literals # requires literal syntax when initializing empty or zero python builtin types. 12 | # - id: check-case-conflict # checks for files that would conflict in case-insensitive filesystems. 13 | - id: check-docstring-first # checks a common error of defining a docstring after code. 14 | - id: check-executables-have-shebangs # ensures that (non-binary) executables have a shebang. 15 | # - id: check-json # checks json files for parseable syntax. 16 | # - id: check-shebang-scripts-are-executable # ensures that (non-binary) files with a shebang are executable. 17 | # - id: pretty-format-json # sets a standard for formatting json files. 18 | - id: check-merge-conflict # checks for files that contain merge conflict strings. 19 | - id: check-symlinks # checks for symlinks which do not point to anything. 20 | - id: check-toml # checks toml files for parseable syntax. 21 | # - id: check-vcs-permalinks # ensures that links to vcs websites are permalinks. 22 | # - id: check-xml # checks xml files for parseable syntax. 23 | - id: check-yaml # checks yaml files for parseable syntax. 24 | - id: debug-statements # checks for debugger imports and py37+ `breakpoint()` calls in python source. 25 | - id: destroyed-symlinks # detects symlinks which are changed to regular files with a content of a path which that symlink was pointing to. 26 | # - id: detect-aws-credentials # detects *your* aws credentials from the aws cli credentials file. 27 | - id: detect-private-key # detects the presence of private keys. 28 | # - id: double-quote-string-fixer # replaces double quoted strings with single quoted strings. 29 | - id: end-of-file-fixer # ensures that a file is either empty, or ends with one newline. 30 | # - id: file-contents-sorter # sorts the lines in specified files (defaults to alphabetical). you must provide list of target files as input in your .pre-commit-config.yaml file. 
31 | # - id: fix-byte-order-marker # removes utf-8 byte order marker. 32 | # - id: fix-encoding-pragma # adds # -*- coding: utf-8 -*- to the top of python files. 33 | # - id: forbid-new-submodules # prevents addition of new git submodules. 34 | - id: mixed-line-ending # replaces or checks mixed line ending. 35 | args: ['--fix=no'] 36 | # - id: name-tests-test # this verifies that test files are named correctly. 37 | # - id: no-commit-to-branch # don't commit to branch 38 | # - id: requirements-txt-fixer # sorts entries in requirements.txt. 39 | # - id: sort-simple-yaml # sorts simple yaml files which consist only of top-level keys, preserving comments and blocks. 40 | - id: trailing-whitespace # trims trailing whitespace. 41 | 42 | - repo: https://github.com/PyCQA/isort.git 43 | rev: '5.12.0' 44 | hooks: 45 | - id: isort 46 | 47 | - repo: https://github.com/psf/black.git 48 | rev: '23.7.0' 49 | hooks: 50 | - id: black 51 | - id: black-jupyter 52 | 53 | - repo: https://github.com/asottile/blacken-docs.git 54 | rev: '1.16.0' 55 | hooks: 56 | - id: blacken-docs 57 | 58 | - repo: https://github.com/PyCQA/flake8.git 59 | rev: '6.1.0' 60 | hooks: 61 | - id: flake8 62 | additional_dependencies: 63 | - flake8-docstrings==1.7.0 64 | 65 | - repo: https://github.com/pycqa/pydocstyle.git 66 | rev: '6.3.0' 67 | hooks: 68 | - id: pydocstyle 69 | additional_dependencies: 70 | - toml 71 | 72 | - repo: https://github.com/kynan/nbstripout.git 73 | rev: '0.6.1' 74 | hooks: 75 | - id: nbstripout 76 | 77 | - repo: https://github.com/PyCQA/bandit 78 | rev: '1.7.5' 79 | hooks: 80 | - id: bandit 81 | args: ['-c', 'pyproject.toml', '--recursive', 'src'] 82 | additional_dependencies: 83 | - toml 84 | -------------------------------------------------------------------------------- /{{ cookiecutter.repository_name }}/LICENSE: -------------------------------------------------------------------------------- 1 | ../LICENSE -------------------------------------------------------------------------------- /{{ cookiecutter.repository_name }}/README.md: -------------------------------------------------------------------------------- 1 | # {{ cookiecutter.project_name }} 2 | 3 |

4 | CI 5 | Docs 6 | NN Template 7 | Python 8 | Code style: black 9 |

10 | 11 | {{ cookiecutter.project_description }} 12 | 13 | 14 | ## Installation 15 | 16 | ```bash 17 | pip install git+ssh://git@github.com/{{ cookiecutter.github_user }}/{{ cookiecutter.repository_name }}.git 18 | ``` 19 | 20 | 21 | ## Quickstart 22 | 23 | [comment]: <> (> Fill me!) 24 | 25 | 26 | ## Development installation 27 | 28 | Setup the development environment: 29 | 30 | ```bash 31 | git clone git@github.com:{{ cookiecutter.github_user }}/{{ cookiecutter.repository_name }}.git 32 | cd {{ cookiecutter.repository_name }} 33 | conda env create -f env.yaml 34 | conda activate {{ cookiecutter.conda_env_name }} 35 | pre-commit install 36 | ``` 37 | 38 | Run the tests: 39 | 40 | ```bash 41 | pre-commit run --all-files 42 | pytest -v 43 | ``` 44 | 45 | 46 | ### Update the dependencies 47 | 48 | Re-install the project in edit mode: 49 | 50 | ```bash 51 | pip install -e .[dev] 52 | ``` 53 | -------------------------------------------------------------------------------- /{{ cookiecutter.repository_name }}/conf/default.yaml: -------------------------------------------------------------------------------- 1 | # metadata specialised for each experiment 2 | core: 3 | project_name: {{ cookiecutter.repository_name }} 4 | storage_dir: ${oc.env:PROJECT_ROOT}/storage 5 | version: 0.0.1 6 | tags: null 7 | 8 | conventions: 9 | x_key: 'x' 10 | y_key: 'y' 11 | 12 | defaults: 13 | - hydra: default 14 | - nn: default 15 | - train: default 16 | - _self_ # as last argument to allow the override of parameters via this main config 17 | # Decomment this parameter to get parallel job running 18 | # - hydra/launcher: joblib 19 | -------------------------------------------------------------------------------- /{{ cookiecutter.repository_name }}/conf/hydra/default.yaml: -------------------------------------------------------------------------------- 1 | ## https://github.com/facebookresearch/hydra/issues/910 2 | # Not changing the working directory 3 | run: 4 | dir: . 5 | sweep: 6 | dir: . 7 | subdir: . 8 | 9 | # Not saving the .hydra directory 10 | output_subdir: null 11 | 12 | job: 13 | env_set: 14 | WANDB_START_METHOD: thread 15 | WANDB_DIR: ${oc.env:PROJECT_ROOT} 16 | 17 | defaults: 18 | # Disable hydra logging configuration, otherwise the basicConfig does not have any effect 19 | - override job_logging: none 20 | - override hydra_logging: none 21 | -------------------------------------------------------------------------------- /{{ cookiecutter.repository_name }}/conf/nn/data/dataset/vision/mnist.yaml: -------------------------------------------------------------------------------- 1 | # This class defines which dataset to use, 2 | # and also how to split in train/[val]/test. 
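# Example (illustrative command): select this dataset from the CLI with a Hydra group override:
#   python src/{{ cookiecutter.package_name }}/run.py nn/data/dataset=vision/mnist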
3 | _target_: {{ cookiecutter.package_name }}.utils.hf_io.load_hf_dataset 4 | name: "mnist" 5 | ref: "mnist" 6 | train_split: train 7 | # val_split: val 8 | val_percentage: 0.1 9 | test_split: test 10 | label_key: label 11 | data_key: image 12 | num_classes: 10 13 | input_shape: [1, 28, 28] 14 | standard_x_key: ${conventions.x_key} 15 | standard_y_key: ${conventions.y_key} 16 | transforms: 17 | _target_: {{ cookiecutter.package_name }}.utils.hf_io.HFTransform 18 | key: ${conventions.x_key} 19 | transform: 20 | _target_: torchvision.transforms.Compose 21 | transforms: 22 | - _target_: torchvision.transforms.ToTensor 23 | -------------------------------------------------------------------------------- /{{ cookiecutter.repository_name }}/conf/nn/data/default.yaml: -------------------------------------------------------------------------------- 1 | _target_: {{ cookiecutter.package_name }}.data.datamodule.MyDataModule 2 | 3 | val_images_fixed_idxs: [7371, 3963, 2861, 1701, 3172, 4 | 1749, 7023, 1606, 6481, 1377, 5 | 6003, 3593, 3410, 3399, 7277, 6 | 5337, 968, 8206, 288, 1968, 7 | 5677, 9156, 8139, 7660, 7089, 8 | 1893, 3845, 2084, 1944, 3375, 9 | 4848, 8704, 6038, 2183, 7422, 10 | 2682, 6878, 6127, 2941, 5823, 11 | 9129, 1798, 6477, 9264, 476, 12 | 3007, 4992, 1428, 9901, 5388] 13 | 14 | accelerator: ${train.trainer.accelerator} 15 | 16 | num_workers: 17 | train: 4 18 | val: 2 19 | test: 0 20 | 21 | batch_size: 22 | train: 512 23 | val: 128 24 | test: 16 25 | 26 | defaults: 27 | - _self_ 28 | - dataset: vision/mnist # pick one of the yamls in nn/data/ 29 | -------------------------------------------------------------------------------- /{{ cookiecutter.repository_name }}/conf/nn/default.yaml: -------------------------------------------------------------------------------- 1 | data: ??? 
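# `???` is OmegaConf's mandatory-missing marker: a value must be provided before use;
# here it is filled by the `data: default` entry in the defaults list below.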
2 | 3 | module: 4 | optimizer: 5 | _target_: torch.optim.Adam 6 | lr: 1e-3 7 | betas: [ 0.9, 0.999 ] 8 | eps: 1e-08 9 | weight_decay: 0 10 | 11 | # lr_scheduler: 12 | # _target_: torch.optim.lr_scheduler.CosineAnnealingWarmRestarts 13 | # T_0: 20 14 | # T_mult: 1 15 | # eta_min: 0 16 | # last_epoch: -1 17 | # verbose: False 18 | 19 | 20 | defaults: 21 | - _self_ 22 | - data: default 23 | - module: default 24 | -------------------------------------------------------------------------------- /{{ cookiecutter.repository_name }}/conf/nn/module/default.yaml: -------------------------------------------------------------------------------- 1 | _target_: {{ cookiecutter.package_name }}.pl_modules.pl_module.MyLightningModule 2 | x_key: ${conventions.x_key} 3 | y_key: ${conventions.y_key} 4 | 5 | defaults: 6 | - _self_ 7 | - model: cnn 8 | -------------------------------------------------------------------------------- /{{ cookiecutter.repository_name }}/conf/nn/module/model/cnn.yaml: -------------------------------------------------------------------------------- 1 | _target_: {{ cookiecutter.package_name }}.modules.module.CNN 2 | input_shape: ${nn.data.dataset.input_shape} 3 | -------------------------------------------------------------------------------- /{{ cookiecutter.repository_name }}/conf/train/default.yaml: -------------------------------------------------------------------------------- 1 | # reproducibility 2 | seed_index: 0 3 | deterministic: False 4 | 5 | # PyTorch Lightning Trainer https://pytorch-lightning.readthedocs.io/en/stable/common/trainer.html 6 | trainer: 7 | fast_dev_run: False # Enable this for debug purposes 8 | accelerator: 'gpu' 9 | devices: 1 10 | precision: 32 11 | max_epochs: 3 12 | max_steps: 10000 13 | num_sanity_val_steps: 2 14 | gradient_clip_val: 10.0 15 | val_check_interval: 1.0 16 | deterministic: ${train.deterministic} 17 | 18 | restore: 19 | ckpt_or_run_path: null 20 | mode: null # null, finetune, hotstart, continue 21 | 22 | monitor: 23 | metric: 'loss/val' 24 | mode: 'min' 25 | 26 | callbacks: 27 | - _target_: lightning.pytorch.callbacks.EarlyStopping 28 | patience: 42 29 | verbose: False 30 | monitor: ${train.monitor.metric} 31 | mode: ${train.monitor.mode} 32 | 33 | - _target_: lightning.pytorch.callbacks.ModelCheckpoint 34 | save_top_k: 1 35 | verbose: False 36 | monitor: ${train.monitor.metric} 37 | mode: ${train.monitor.mode} 38 | 39 | - _target_: lightning.pytorch.callbacks.LearningRateMonitor 40 | logging_interval: "step" 41 | log_momentum: False 42 | 43 | - _target_: lightning.pytorch.callbacks.progress.tqdm_progress.TQDMProgressBar 44 | refresh_rate: 20 45 | 46 | logging: 47 | upload: 48 | run_files: true 49 | source: true 50 | 51 | logger: 52 | _target_: lightning.pytorch.loggers.WandbLogger 53 | 54 | project: ${core.project_name} 55 | entity: null 56 | log_model: ${..upload.run_files} 57 | mode: 'online' 58 | tags: ${core.tags} 59 | 60 | wandb_watch: 61 | log: 'all' 62 | log_freq: 100 63 | -------------------------------------------------------------------------------- /{{ cookiecutter.repository_name }}/data/.gitignore: -------------------------------------------------------------------------------- 1 | * 2 | 3 | !.gitignore 4 | -------------------------------------------------------------------------------- /{{ cookiecutter.repository_name }}/docs/index.md: -------------------------------------------------------------------------------- 1 | ../README.md -------------------------------------------------------------------------------- /{{ 
cookiecutter.repository_name }}/docs/overrides/main.html: -------------------------------------------------------------------------------- 1 | {% raw %}{% extends "base.html" %} 2 | 3 | {% block outdated %} 4 | You're not viewing the latest version. 5 | 6 | Click here to go to latest. 7 | 8 | {% endblock %}{% endraw %} 9 | -------------------------------------------------------------------------------- /{{ cookiecutter.repository_name }}/env.yaml: -------------------------------------------------------------------------------- 1 | name: {{ cookiecutter.conda_env_name }} 2 | channels: 3 | - pytorch 4 | - nvidia 5 | - defaults 6 | 7 | dependencies: 8 | - python={{ cookiecutter.python_version }} 9 | - pytorch=2.0.* 10 | - torchvision 11 | - pytorch-cuda=11.8 12 | - pip 13 | - pip: 14 | - -e .[dev] 15 | -------------------------------------------------------------------------------- /{{ cookiecutter.repository_name }}/mkdocs.yml: -------------------------------------------------------------------------------- 1 | site_name: {{ cookiecutter.repository_name }} 2 | site_description: {{ cookiecutter.project_description }} 3 | repo_url: {{ cookiecutter.repository_url }} 4 | 5 | nav: 6 | - Home: index.md 7 | 8 | theme: 9 | name: material 10 | custom_dir: docs/overrides 11 | icon: 12 | repo: fontawesome/brands/github 13 | 14 | features: 15 | - content.code.annotate 16 | - navigation.indexes 17 | - navigation.instant 18 | - navigation.sections 19 | - navigation.tabs 20 | - navigation.tabs.sticky 21 | - navigation.top 22 | - navigation.tracking 23 | - search.highlight 24 | - search.share 25 | - search.suggest 26 | 27 | palette: 28 | - scheme: default 29 | primary: light green 30 | accent: green 31 | toggle: 32 | icon: material/weather-night 33 | name: Switch to dark mode 34 | - scheme: slate 35 | primary: green 36 | accent: green 37 | toggle: 38 | icon: material/weather-sunny 39 | name: Switch to light mode 40 | 41 | # Extensions 42 | markdown_extensions: 43 | - abbr 44 | - admonition 45 | - attr_list 46 | - def_list 47 | - footnotes 48 | - meta 49 | - md_in_html 50 | - toc: 51 | permalink: true 52 | - tables 53 | - pymdownx.arithmatex: 54 | generic: true 55 | - pymdownx.betterem: 56 | smart_enable: all 57 | - pymdownx.caret 58 | - pymdownx.mark 59 | - pymdownx.tilde 60 | - pymdownx.critic 61 | - pymdownx.details 62 | - pymdownx.highlight: 63 | anchor_linenums: true 64 | - pymdownx.superfences 65 | - pymdownx.inlinehilite 66 | - pymdownx.keys 67 | - pymdownx.smartsymbols 68 | - pymdownx.snippets 69 | - pymdownx.tabbed: 70 | alternate_style: true 71 | - pymdownx.tasklist: 72 | custom_checkbox: true 73 | 74 | extra_javascript: 75 | - javascripts/mathjax.js 76 | - https://polyfill.io/v3/polyfill.min.js?features=es6 77 | - https://cdn.jsdelivr.net/npm/mathjax@3/es5/tex-mml-chtml.js 78 | 79 | plugins: 80 | - search 81 | 82 | extra: 83 | version: 84 | provider: mike 85 | -------------------------------------------------------------------------------- /{{ cookiecutter.repository_name }}/pyproject.toml: -------------------------------------------------------------------------------- 1 | [tool.pytest.ini_options] 2 | minversion = "6.2" 3 | addopts = "-ra" 4 | testpaths = ["tests"] 5 | 6 | [tool.coverage.report] 7 | exclude_lines = [ 8 | "raise NotImplementedError", 9 | "raise NotImplementedError()", 10 | "pragma: nocover", 11 | "if __name__ == .__main__.:", 12 | ] 13 | 14 | [tool.black] 15 | line-length = 120 16 | include = '\.pyi?$' 17 | 18 | [tool.mypy] 19 | files= ["src/**/*.py", "test/*.py"] 20 | 
ignore_missing_imports = true 21 | 22 | [tool.isort] 23 | profile = 'black' 24 | line_length = 120 25 | known_third_party = ["numpy", "pytest", "wandb", "torch"] 26 | known_first_party = ["nn_core"] 27 | known_local_folder = "{{ cookiecutter.package_name }}" 28 | 29 | [tool.pydocstyle] 30 | convention = 'google' 31 | # ignore all missing docs errors 32 | add-ignore = ['D100', 'D101', 'D102', 'D103', 'D104', 'D105', 'D106', 'D107'] 33 | 34 | [tool.bandit] 35 | skips = ["B101"] 36 | 37 | [tool.setuptools_scm] 38 | write_to = "src/{{ cookiecutter.package_name }}/_version.py" 39 | write_to_template = '__version__ = "{version}"' 40 | 41 | [build-system] 42 | requires = ["setuptools==59.5", "wheel", "setuptools_scm[toml]>=6.3.1"] 43 | build-backend = "setuptools.build_meta" 44 | -------------------------------------------------------------------------------- /{{ cookiecutter.repository_name }}/setup.cfg: -------------------------------------------------------------------------------- 1 | [metadata] 2 | name = {{ cookiecutter.repository_name }} 3 | description = {{ cookiecutter.project_description }} 4 | url = {{ cookiecutter.repository_url }} 5 | long_description = file: README.md 6 | author = {{ cookiecutter.author }} 7 | author_email = {{ cookiecutter.author_email }} 8 | keywords = python 9 | license = MIT Licence 10 | 11 | [options] 12 | zip_safe = False 13 | include_package_data = True 14 | package_dir= 15 | =src 16 | packages=find: 17 | install_requires = 18 | nn-template-core==0.4.* 19 | anypy==0.0.* 20 | 21 | # Add project specific dependencies 22 | # Stuff easy to break with updates 23 | lightning==2.0.* 24 | torchmetrics==1.0.* 25 | hydra-core==1.3.* 26 | wandb 27 | streamlit 28 | # hydra-joblib-launcher 29 | 30 | # Stable stuff usually backward compatible 31 | rich 32 | dvc 33 | python-dotenv 34 | matplotlib 35 | stqdm 36 | 37 | [options.packages.find] 38 | where=src 39 | 40 | [options.package_data] 41 | * = *.txt, *.md 42 | 43 | [options.extras_require] 44 | docs = 45 | mkdocs 46 | mkdocs-material 47 | mike 48 | 49 | test = 50 | pytest 51 | pytest-cov 52 | 53 | dev = 54 | black 55 | flake8 56 | isort 57 | pre-commit 58 | bandit 59 | %(test)s 60 | %(docs)s 61 | -------------------------------------------------------------------------------- /{{ cookiecutter.repository_name }}/setup.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | import setuptools 4 | 5 | if __name__ == "__main__": 6 | setuptools.setup() 7 | -------------------------------------------------------------------------------- /{{ cookiecutter.repository_name }}/src/{{ cookiecutter.package_name }}/__init__.py: -------------------------------------------------------------------------------- 1 | import logging 2 | 3 | from nn_core.console_logging import NNRichHandler 4 | 5 | # Required workaround because PyTorch Lightning configures the logging on import, 6 | # thus the logging configuration defined in the __init__.py must be called before 7 | # the lightning import otherwise it has no effect. 8 | # See https://github.com/PyTorchLightning/pytorch-lightning/issues/1503 9 | lightning_logger = logging.getLogger("lightning.pytorch") 10 | # Remove all handlers associated with the lightning logger. 
11 | for handler in lightning_logger.handlers[:]: 12 | lightning_logger.removeHandler(handler) 13 | lightning_logger.propagate = True 14 | 15 | FORMAT = "%(message)s" 16 | logging.basicConfig( 17 | format=FORMAT, 18 | level=logging.INFO, 19 | datefmt="%Y-%m-%d %H:%M:%S", 20 | handlers=[ 21 | NNRichHandler( 22 | rich_tracebacks=True, 23 | show_level=True, 24 | show_path=True, 25 | show_time=True, 26 | omit_repeated_times=True, 27 | ) 28 | ], 29 | ) 30 | 31 | try: 32 | from ._version import __version__ as __version__ 33 | except ImportError: 34 | import sys 35 | 36 | print( 37 | "Project not installed in the current env, activate the correct env or install it with:\n\tpip install -e .", 38 | file=sys.stderr, 39 | ) 40 | __version__ = "unknown" 41 | -------------------------------------------------------------------------------- /{{ cookiecutter.repository_name }}/src/{{ cookiecutter.package_name }}/data/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/grok-ai/nn-template/8ba02bba8f015e1eb7efb0d2ab8c9d433bd1c431/{{ cookiecutter.repository_name }}/src/{{ cookiecutter.package_name }}/data/__init__.py -------------------------------------------------------------------------------- /{{ cookiecutter.repository_name }}/src/{{ cookiecutter.package_name }}/data/datamodule.py: -------------------------------------------------------------------------------- 1 | import logging 2 | from functools import cached_property, partial 3 | from pathlib import Path 4 | from typing import List, Mapping, Optional 5 | 6 | import hydra 7 | import lightning.pytorch as pl 8 | import omegaconf 9 | from omegaconf import DictConfig 10 | from torch.utils.data import DataLoader, Dataset 11 | from torch.utils.data.dataloader import default_collate 12 | from tqdm import tqdm 13 | 14 | from nn_core.common import PROJECT_ROOT 15 | from nn_core.nn_types import Split 16 | 17 | pylogger = logging.getLogger(__name__) 18 | 19 | 20 | class MetaData: 21 | def __init__(self, class_vocab: Mapping[str, int]): 22 | """The data information the Lightning Module will be provided with. 23 | 24 | This is a "bridge" between the Lightning DataModule and the Lightning Module. 25 | There is no constraint on the class name nor in the stored information, as long as it exposes the 26 | `save` and `load` methods. 27 | 28 | The Lightning Module will receive an instance of MetaData when instantiated, 29 | both in the train loop or when restored from a checkpoint. 30 | 31 | This decoupling allows the architecture to be parametric (e.g. in the number of classes) and 32 | DataModule/Trainer independent (useful in prediction scenarios). 33 | MetaData should contain all the information needed at test time, derived from its train dataset. 34 | 35 | Examples are the class names in a classification task or the vocabulary in NLP tasks. 36 | MetaData exposes `save` and `load`. Those are two user-defined methods that specify 37 | how to serialize and de-serialize the information contained in its attributes. 38 | This is needed for the checkpointing restore to work properly. 39 | 40 | Args: 41 | class_vocab: association between class names and their indices 42 | """ 43 | # example 44 | self.class_vocab: Mapping[str, int] = class_vocab 45 | 46 | def save(self, dst_path: Path) -> None: 47 | """Serialize the MetaData attributes into the zipped checkpoint in dst_path. 
48 | 49 | Args: 50 | dst_path: the root folder of the metadata inside the zipped checkpoint 51 | """ 52 | pylogger.debug(f"Saving MetaData to '{dst_path}'") 53 | 54 | # example 55 | (dst_path / "class_vocab.tsv").write_text( 56 | "\n".join(f"{key}\t{value}" for key, value in self.class_vocab.items()) 57 | ) 58 | 59 | @staticmethod 60 | def load(src_path: Path) -> "MetaData": 61 | """Deserialize the MetaData from the information contained inside the zipped checkpoint in src_path. 62 | 63 | Args: 64 | src_path: the root folder of the metadata inside the zipped checkpoint 65 | 66 | Returns: 67 | an instance of MetaData containing the information in the checkpoint 68 | """ 69 | pylogger.debug(f"Loading MetaData from '{src_path}'") 70 | 71 | # example 72 | lines = (src_path / "class_vocab.tsv").read_text(encoding="utf-8").splitlines() 73 | 74 | class_vocab = {} 75 | for line in lines: 76 | key, value = line.strip().split("\t") 77 | class_vocab[key] = value 78 | 79 | return MetaData( 80 | class_vocab=class_vocab, 81 | ) 82 | 83 | def __repr__(self) -> str: 84 | attributes = ",\n ".join([f"{key}={value}" for key, value in self.__dict__.items()]) 85 | return f"{self.__class__.__name__}(\n {attributes}\n)" 86 | 87 | 88 | def collate_fn(samples: List, split: Split, metadata: MetaData): 89 | """Custom collate function for dataloaders with access to split and metadata. 90 | 91 | Args: 92 | samples: A list of samples coming from the Dataset to be merged into a batch 93 | split: The data split (e.g. train/val/test) 94 | metadata: The MetaData instance coming from the DataModule or the restored checkpoint 95 | 96 | Returns: 97 | A batch generated from the given samples 98 | """ 99 | return default_collate(samples) 100 | 101 | 102 | class MyDataModule(pl.LightningDataModule): 103 | def __init__( 104 | self, 105 | dataset: DictConfig, 106 | num_workers: DictConfig, 107 | batch_size: DictConfig, 108 | accelerator: str, 109 | # example 110 | val_images_fixed_idxs: List[int], 111 | ): 112 | super().__init__() 113 | self.dataset = dataset 114 | self.num_workers = num_workers 115 | self.batch_size = batch_size 116 | # https://pytorch-lightning.readthedocs.io/en/stable/common/trainer.html#gpus 117 | self.pin_memory: bool = accelerator is not None and str(accelerator) == "gpu" 118 | 119 | self.train_dataset: Optional[Dataset] = None 120 | self.val_dataset: Optional[Dataset] = None 121 | self.test_dataset: Optional[Dataset] = None 122 | 123 | # example 124 | self.val_images_fixed_idxs: List[int] = val_images_fixed_idxs 125 | 126 | @cached_property 127 | def metadata(self) -> MetaData: 128 | """Data information to be fed to the Lightning Module as parameter. 129 | 130 | Examples are vocabularies, number of classes... 131 | 132 | Returns: 133 | metadata: everything the model should know about the data, wrapped in a MetaData object. 134 | """ 135 | # Since MetaData depends on the training data, we need to ensure the setup method has been called. 
136 | if self.train_dataset is None: 137 | self.setup(stage="fit") 138 | 139 | return MetaData(class_vocab={i: name for i, name in enumerate(self.train_dataset.features["y"].names)}) 140 | 141 | def prepare_data(self) -> None: 142 | # download only 143 | pass 144 | 145 | def setup(self, stage: Optional[str] = None): 146 | self.transform = hydra.utils.instantiate(self.dataset.transforms) 147 | 148 | self.hf_datasets = hydra.utils.instantiate(self.dataset) 149 | self.hf_datasets.set_transform(self.transform) 150 | 151 | # Here you should instantiate your dataset, you may also split the train into train and validation if needed. 152 | if (stage is None or stage == "fit") and (self.train_dataset is None and self.val_dataset is None): 153 | self.train_dataset = self.hf_datasets["train"] 154 | self.val_dataset = self.hf_datasets["val"] 155 | 156 | if stage is None or stage == "test": 157 | self.test_dataset = self.hf_datasets["test"] 158 | 159 | def train_dataloader(self) -> DataLoader: 160 | return DataLoader( 161 | self.train_dataset, 162 | shuffle=True, 163 | batch_size=self.batch_size.train, 164 | num_workers=self.num_workers.train, 165 | pin_memory=self.pin_memory, 166 | collate_fn=partial(collate_fn, split="train", metadata=self.metadata), 167 | ) 168 | 169 | def val_dataloader(self) -> DataLoader: 170 | return DataLoader( 171 | self.val_dataset, 172 | shuffle=False, 173 | batch_size=self.batch_size.val, 174 | num_workers=self.num_workers.val, 175 | pin_memory=self.pin_memory, 176 | collate_fn=partial(collate_fn, split="val", metadata=self.metadata), 177 | ) 178 | 179 | def test_dataloader(self) -> DataLoader: 180 | return DataLoader( 181 | self.test_dataset, 182 | shuffle=False, 183 | batch_size=self.batch_size.test, 184 | num_workers=self.num_workers.test, 185 | pin_memory=self.pin_memory, 186 | collate_fn=partial(collate_fn, split="test", metadata=self.metadata), 187 | ) 188 | 189 | def __repr__(self) -> str: 190 | return f"{self.__class__.__name__}(" f"{self.dataset=}, " f"{self.num_workers=}, " f"{self.batch_size=})" 191 | 192 | 193 | @hydra.main(config_path=str(PROJECT_ROOT / "conf"), config_name="default") 194 | def main(cfg: omegaconf.DictConfig) -> None: 195 | """Debug main to quickly develop the DataModule. 196 | 197 | Args: 198 | cfg: the hydra configuration 199 | """ 200 | m: pl.LightningDataModule = hydra.utils.instantiate(cfg.nn.data, _recursive_=False) 201 | m.metadata 202 | m.setup() 203 | 204 | for _ in tqdm(m.train_dataloader()): 205 | pass 206 | 207 | 208 | if __name__ == "__main__": 209 | main() 210 | -------------------------------------------------------------------------------- /{{ cookiecutter.repository_name }}/src/{{ cookiecutter.package_name }}/data/dataset.py: -------------------------------------------------------------------------------- 1 | import hydra 2 | import omegaconf 3 | from torch.utils.data import Dataset 4 | 5 | from nn_core.common import PROJECT_ROOT 6 | 7 | 8 | @hydra.main(config_path=str(PROJECT_ROOT / "conf"), config_name="default") 9 | def main(cfg: omegaconf.DictConfig) -> None: 10 | """Debug main to quickly develop the Dataset. 
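
    Run this module directly to instantiate the configured dataset during development,
    e.g. (illustrative, assuming the project and its environment are installed):
    `python src/{{ cookiecutter.package_name }}/data/dataset.py`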
11 | 12 | Args: 13 | cfg: the hydra configuration 14 | """ 15 | _: Dataset = hydra.utils.instantiate(cfg.nn.data.dataset, _recursive_=False) 16 | 17 | 18 | if __name__ == "__main__": 19 | main() 20 | -------------------------------------------------------------------------------- /{{ cookiecutter.repository_name }}/src/{{ cookiecutter.package_name }}/modules/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/grok-ai/nn-template/8ba02bba8f015e1eb7efb0d2ab8c9d433bd1c431/{{ cookiecutter.repository_name }}/src/{{ cookiecutter.package_name }}/modules/__init__.py -------------------------------------------------------------------------------- /{{ cookiecutter.repository_name }}/src/{{ cookiecutter.package_name }}/modules/module.py: -------------------------------------------------------------------------------- 1 | from typing import Tuple 2 | 3 | from torch import nn 4 | 5 | 6 | # https://medium.com/@nutanbhogendrasharma/pytorch-convolutional-neural-network-with-mnist-dataset-4e8a4265e118 7 | class CNN(nn.Module): 8 | def __init__(self, input_shape: Tuple[int], num_classes: int): 9 | super(CNN, self).__init__() 10 | self.model = nn.Sequential( 11 | nn.Conv2d( 12 | in_channels=input_shape[0], 13 | out_channels=16, 14 | kernel_size=5, 15 | stride=1, 16 | padding=2, 17 | ), 18 | nn.SiLU(), 19 | nn.MaxPool2d(kernel_size=2), 20 | nn.Conv2d(16, 32, 5, 1, 2), 21 | nn.SiLU(), 22 | nn.MaxPool2d(2), 23 | ) 24 | self.conv2 = nn.Sequential() 25 | self.out = nn.Linear(32 * 7 * 7, num_classes) 26 | 27 | def forward(self, x): 28 | x = self.model(x) 29 | # [batch_size, 32 * 7 * 7] 30 | x = x.view(x.size(0), -1) 31 | output = self.out(x) 32 | return output 33 | -------------------------------------------------------------------------------- /{{ cookiecutter.repository_name }}/src/{{ cookiecutter.package_name }}/pl_modules/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/grok-ai/nn-template/8ba02bba8f015e1eb7efb0d2ab8c9d433bd1c431/{{ cookiecutter.repository_name }}/src/{{ cookiecutter.package_name }}/pl_modules/__init__.py -------------------------------------------------------------------------------- /{{ cookiecutter.repository_name }}/src/{{ cookiecutter.package_name }}/pl_modules/pl_module.py: -------------------------------------------------------------------------------- 1 | import logging 2 | from typing import Any, Dict, Mapping, Optional, Sequence, Tuple, Union 3 | 4 | import hydra 5 | import lightning.pytorch as pl 6 | import omegaconf 7 | import torch 8 | import torch.nn.functional as F 9 | import torchmetrics 10 | from torch.optim import Optimizer 11 | 12 | from nn_core.common import PROJECT_ROOT 13 | from nn_core.model_logging import NNLogger 14 | 15 | from {{ cookiecutter.package_name }}.data.datamodule import MetaData 16 | 17 | pylogger = logging.getLogger(__name__) 18 | 19 | 20 | class MyLightningModule(pl.LightningModule): 21 | logger: NNLogger 22 | 23 | def __init__(self, model, metadata: Optional[MetaData] = None, *args, **kwargs) -> None: 24 | super().__init__() 25 | 26 | # Populate self.hparams with args and kwargs automagically! 27 | # We want to skip metadata since it is saved separately by the NNCheckpointIO object. 28 | # Be careful when modifying this instruction. 
If in doubt, don't do it :]
29 |         self.save_hyperparameters(logger=False, ignore=("metadata",))
30 | 
31 |         self.metadata = metadata
32 | 
33 |         # example
34 |         metric = torchmetrics.Accuracy(
35 |             task="multiclass",
36 |             num_classes=len(metadata.class_vocab) if metadata is not None else None,
37 |         )
38 |         self.train_acc = metric.clone()
39 |         self.val_acc = metric.clone()
40 |         self.test_acc = metric.clone()
41 | 
42 |         self.model = hydra.utils.instantiate(model, num_classes=len(metadata.class_vocab))
43 | 
44 |     def forward(self, x: torch.Tensor) -> torch.Tensor:
45 |         """Method for the forward pass.
46 | 
47 |         'training_step', 'validation_step' and 'test_step' should call
48 |         this method in order to compute the output predictions and the loss.
49 | 
50 |         Returns:
51 |             output_dict: forward output containing the predictions (output logits, etc.) and the loss if any.
52 |         """
53 |         # example
54 |         return self.model(x)
55 | 
56 |     def _step(self, batch: Dict[str, torch.Tensor], split: str) -> Mapping[str, Any]:
57 |         x = batch[self.hparams.x_key]
58 |         gt_y = batch[self.hparams.y_key]
59 | 
60 |         # example
61 |         logits = self(x)
62 |         loss = F.cross_entropy(logits, gt_y)
63 |         preds = torch.softmax(logits, dim=-1)
64 | 
65 |         metrics = getattr(self, f"{split}_acc")
66 |         metrics.update(preds, gt_y)
67 | 
68 |         self.log_dict(
69 |             {
70 |                 f"acc/{split}": metrics,
71 |                 f"loss/{split}": loss,
72 |             },
73 |             on_epoch=True,
74 |         )
75 | 
76 |         return {"logits": logits.detach(), "loss": loss}
77 | 
78 |     def training_step(self, batch: Any, batch_idx: int) -> Mapping[str, Any]:
79 |         return self._step(batch=batch, split="train")
80 | 
81 |     def validation_step(self, batch: Any, batch_idx: int) -> Mapping[str, Any]:
82 |         return self._step(batch=batch, split="val")
83 | 
84 |     def test_step(self, batch: Any, batch_idx: int) -> Mapping[str, Any]:
85 |         return self._step(batch=batch, split="test")
86 | 
87 |     def configure_optimizers(
88 |         self,
89 |     ) -> Union[Optimizer, Tuple[Sequence[Optimizer], Sequence[Any]]]:
90 |         """Choose what optimizers and learning-rate schedulers to use in your optimization.
91 | 
92 |         Normally you'd need one. But in the case of GANs or similar you might have multiple.
93 | 
94 |         Return:
95 |             Any of these 6 options.
96 |             - Single optimizer.
97 |             - List or Tuple - List of optimizers.
98 |             - Two lists - The first list has multiple optimizers, the second a list of LR schedulers (or lr_dict).
99 |             - Dictionary, with an 'optimizer' key, and (optionally) a 'lr_scheduler'
100 |               key whose value is a single LR scheduler or lr_dict.
101 |             - Tuple of dictionaries as described, with an optional 'frequency' key.
102 |             - None - Fit will run without any optimizer.
103 |         """
104 |         opt = hydra.utils.instantiate(self.hparams.optimizer, params=self.parameters(), _convert_="partial")
105 |         if "lr_scheduler" not in self.hparams:
106 |             return [opt]
107 |         scheduler = hydra.utils.instantiate(self.hparams.lr_scheduler, optimizer=opt)
108 |         return [opt], [scheduler]
109 | 
110 | 
111 | @hydra.main(config_path=str(PROJECT_ROOT / "conf"), config_name="default")
112 | def main(cfg: omegaconf.DictConfig) -> None:
113 |     """Debug main to quickly develop the Lightning Module.
114 | 115 | Args: 116 | cfg: the hydra configuration 117 | """ 118 | m: pl.LightningDataModule = hydra.utils.instantiate(cfg.nn.data, _recursive_=False) 119 | _: pl.LightningModule = hydra.utils.instantiate(cfg.nn.module, _recursive_=False, metadata=m.metadata) 120 | 121 | 122 | if __name__ == "__main__": 123 | main() 124 | -------------------------------------------------------------------------------- /{{ cookiecutter.repository_name }}/src/{{ cookiecutter.package_name }}/run.py: -------------------------------------------------------------------------------- 1 | import logging 2 | from typing import List, Optional 3 | 4 | import hydra 5 | import lightning.pytorch as pl 6 | import omegaconf 7 | import torch 8 | from lightning.pytorch import Callback 9 | from omegaconf import DictConfig, ListConfig 10 | 11 | from nn_core.callbacks import NNTemplateCore 12 | from nn_core.common import PROJECT_ROOT 13 | from nn_core.common.utils import enforce_tags, seed_index_everything 14 | from nn_core.model_logging import NNLogger 15 | from nn_core.serialization import NNCheckpointIO 16 | 17 | # Force the execution of __init__.py if this file is executed directly. 18 | import {{ cookiecutter.package_name }} # noqa 19 | from {{ cookiecutter.package_name }}.data.datamodule import MetaData 20 | 21 | pylogger = logging.getLogger(__name__) 22 | 23 | torch.set_float32_matmul_precision("high") 24 | 25 | 26 | def build_callbacks(cfg: ListConfig, *args: Callback) -> List[Callback]: 27 | """Instantiate the callbacks given their configuration. 28 | 29 | Args: 30 | cfg: a list of callbacks instantiable configuration 31 | *args: a list of extra callbacks already instantiated 32 | 33 | Returns: 34 | the complete list of callbacks to use 35 | """ 36 | callbacks: List[Callback] = list(args) 37 | 38 | for callback in cfg: 39 | pylogger.info(f"Adding callback <{callback['_target_'].split('.')[-1]}>") 40 | callbacks.append(hydra.utils.instantiate(callback, _recursive_=False)) 41 | 42 | return callbacks 43 | 44 | 45 | def run(cfg: DictConfig) -> str: 46 | """Generic train loop. 47 | 48 | Args: 49 | cfg: run configuration, defined by Hydra in /conf 50 | 51 | Returns: 52 | the run directory inside the storage_dir used by the current experiment 53 | """ 54 | seed_index_everything(cfg.train) 55 | 56 | fast_dev_run: bool = cfg.train.trainer.fast_dev_run 57 | if fast_dev_run: 58 | pylogger.info(f"Debug mode <{cfg.train.trainer.fast_dev_run=}>. 
Forcing debugger friendly configuration!")
59 |         # Debuggers don't like GPUs nor multiprocessing
60 |         cfg.train.trainer.accelerator = "cpu"
61 |         cfg.nn.data.num_workers.train = 0
62 |         cfg.nn.data.num_workers.val = 0
63 |         cfg.nn.data.num_workers.test = 0
64 | 
65 |     cfg.core.tags = enforce_tags(cfg.core.get("tags", None))
66 | 
67 |     # Instantiate datamodule
68 |     pylogger.info(f"Instantiating <{cfg.nn.data['_target_']}>")
69 |     datamodule: pl.LightningDataModule = hydra.utils.instantiate(cfg.nn.data, _recursive_=False)
70 |     datamodule.setup(stage=None)
71 | 
72 |     metadata: Optional[MetaData] = getattr(datamodule, "metadata", None)
73 |     if metadata is None:
74 |         pylogger.warning(f"No 'metadata' attribute found in datamodule <{datamodule.__class__.__name__}>")
75 | 
76 |     # Instantiate model
77 |     pylogger.info(f"Instantiating <{cfg.nn.module['_target_']}>")
78 |     model: pl.LightningModule = hydra.utils.instantiate(cfg.nn.module, _recursive_=False, metadata=metadata)
79 | 
80 |     # Instantiate the callbacks
81 |     template_core: NNTemplateCore = NNTemplateCore(
82 |         restore_cfg=cfg.train.get("restore", None),
83 |     )
84 |     callbacks: List[Callback] = build_callbacks(cfg.train.callbacks, template_core)
85 | 
86 |     storage_dir: str = cfg.core.storage_dir
87 | 
88 |     logger: NNLogger = NNLogger(logging_cfg=cfg.train.logging, cfg=cfg, resume_id=template_core.resume_id)
89 | 
90 |     pylogger.info("Instantiating the <Trainer>")
91 |     trainer = pl.Trainer(
92 |         default_root_dir=storage_dir,
93 |         plugins=[NNCheckpointIO(jailing_dir=logger.run_dir)],
94 |         logger=logger,
95 |         callbacks=callbacks,
96 |         **cfg.train.trainer,
97 |     )
98 | 
99 |     pylogger.info("Starting training!")
100 |     trainer.fit(model=model, datamodule=datamodule, ckpt_path=template_core.trainer_ckpt_path)
101 | 
102 |     if fast_dev_run:
103 |         pylogger.info("Skipping testing in 'fast_dev_run' mode!")
104 |     else:
105 |         if datamodule.test_dataset is not None and trainer.checkpoint_callback.best_model_path is not None:
106 |             pylogger.info("Starting testing!")
107 |             trainer.test(datamodule=datamodule)
108 | 
109 |     # Logger closing to release resources/avoid multi-run conflicts
110 |     if logger is not None:
111 |         logger.experiment.finish()
112 | 
113 |     return logger.run_dir
114 | 
115 | 
116 | @hydra.main(config_path=str(PROJECT_ROOT / "conf"), config_name="default")
117 | def main(cfg: omegaconf.DictConfig):
118 |     run(cfg)
119 | 
120 | 
121 | if __name__ == "__main__":
122 |     main()
123 | 
--------------------------------------------------------------------------------
/{{ cookiecutter.repository_name }}/src/{{ cookiecutter.package_name }}/ui/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/grok-ai/nn-template/8ba02bba8f015e1eb7efb0d2ab8c9d433bd1c431/{{ cookiecutter.repository_name }}/src/{{ cookiecutter.package_name }}/ui/__init__.py
--------------------------------------------------------------------------------
/{{ cookiecutter.repository_name }}/src/{{ cookiecutter.package_name }}/ui/run.py:
--------------------------------------------------------------------------------
1 | from pathlib import Path
2 | 
3 | import streamlit as st
4 | import wandb
5 | 
6 | from nn_core.serialization import load_model
7 | from nn_core.ui import select_checkpoint
8 | 
9 | from {{ cookiecutter.package_name }}.pl_modules.pl_module import MyLightningModule
10 | 
11 | 
12 | @st.cache(allow_output_mutation=True)
13 | def get_model(checkpoint_path: Path):
14 |     return load_model(module_class=MyLightningModule, checkpoint_path=checkpoint_path)
15 | 
16 | 
17 | if wandb.api.api_key is None:
18 |     st.error("You are not logged in on `Weights and Biases`: https://docs.wandb.ai/ref/cli/wandb-login")
19 |     st.stop()
20 | 
21 | st.sidebar.subheader(f"Logged in W&B as: {wandb.api.viewer()['entity']}")
22 | 
23 | checkpoint_path = select_checkpoint()
24 | model: MyLightningModule = get_model(checkpoint_path=checkpoint_path)
25 | model
26 | 
--------------------------------------------------------------------------------
/{{ cookiecutter.repository_name }}/src/{{ cookiecutter.package_name }}/utils/hf_io.py:
--------------------------------------------------------------------------------
1 | from collections import namedtuple
2 | from pathlib import Path
3 | from typing import Any, Callable, Dict, Sequence
4 | 
5 | import torch
6 | from anypy.data.metadata_dataset_dict import MetadataDatasetDict
7 | from datasets import Dataset, DatasetDict, load_dataset, load_from_disk
8 | from omegaconf import DictConfig
9 | 
10 | from nn_core.common import PROJECT_ROOT
11 | 
12 | DatasetParams = namedtuple("DatasetParams", ["name", "fine_grained", "train_split", "test_split", "hf_key"])
13 | 
14 | 
15 | class HFTransform:
16 |     def __init__(
17 |         self,
18 |         key: str,
19 |         transform: Callable[[torch.Tensor], torch.Tensor],
20 |     ):
21 |         """Apply a row-wise transform to a dataset column.
22 | 
23 |         Args:
24 |             key (str): The key of the column to transform.
25 |             transform (Callable[[torch.Tensor], torch.Tensor]): The transform to apply.
26 |         """
27 |         self.transform = transform
28 |         self.key = key
29 | 
30 |     def __call__(self, samples: Dict[str, Sequence[Any]]) -> Dict[str, Sequence[Any]]:
31 |         """Apply the transform to the samples.
32 | 
33 |         Args:
34 |             samples (Dict[str, Sequence[Any]]): The samples to transform.
35 | 
36 |         Returns:
37 |             Dict[str, Sequence[Any]]: The transformed samples.
38 |         """
39 |         samples[self.key] = [self.transform(data) for data in samples[self.key]]
40 |         return samples
41 | 
42 |     def __repr__(self) -> str:
43 |         return repr(self.transform)
44 | 
45 | 
46 | def preprocess_dataset(
47 |     dataset: Dataset,
48 |     cfg: Dict,
49 | ) -> Dataset:
50 |     """Preprocess a dataset.
51 | 
52 |     This function applies the following preprocessing steps:
53 |         - Rename the label column to the standard key.
54 |         - Rename the data column to the standard key.
55 | 
56 |     Do not apply transforms here, as the preprocessed dataset will be saved to disk once
57 |     and then reused; thus updates on the transforms will not be reflected in the dataset.
58 | 
59 |     Args:
60 |         dataset (Dataset): The dataset to preprocess.
61 |         cfg (Dict): The configuration.
62 | 
63 |     Returns:
64 |         Dataset: The preprocessed dataset.
65 |     """
66 |     dataset = dataset.rename_column(cfg["label_key"], cfg["standard_y_key"])
67 |     dataset = dataset.rename_column(cfg["data_key"], cfg["standard_x_key"])
68 |     return dataset
69 | 
70 | 
71 | def save_dataset_to_disk(dataset: MetadataDatasetDict, output_path: Path) -> None:
72 |     """Save a dataset to disk.
73 | 
74 |     Args:
75 |         dataset (MetadataDatasetDict): The dataset to save.
76 |         output_path (Path): The path to save the dataset to.
77 |     """
78 |     if not isinstance(output_path, Path):
79 |         output_path = Path(output_path)
80 | 
81 |     output_path.mkdir(parents=True, exist_ok=True)
82 | 
83 |     dataset.save_to_disk(output_path)
84 | 
85 | 
86 | def load_hf_dataset(**cfg: DictConfig) -> MetadataDatasetDict:
87 |     """Load a dataset from the HuggingFace datasets library.
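    Illustrative call (hedged: the keys are the ones read below and in `preprocess_dataset`,
    while the values are placeholders rather than the template's actual defaults):

        cfg = {
            "ref": "mnist",            # HuggingFace dataset reference
            "train_split": "train",
            "test_split": "test",
            "val_percentage": 0.1,     # alternatively, provide "val_split"
            "label_key": "label",      # renamed to the "standard_y_key" column
            "data_key": "image",       # renamed to the "standard_x_key" column
            "standard_x_key": "x",
            "standard_y_key": "y",
        }
        dataset_dict = load_hf_dataset(**cfg)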
88 | 89 | The returned dataset is a MetadataDatasetDict, which is a wrapper around a DatasetDict. 90 | It will contain the following splits: 91 | - train 92 | - val 93 | - test 94 | If `val_split` is not specified in the config, it will be created from the train split 95 | according to the `val_percentage` specified in the config. 96 | 97 | The returned dataset will be preprocessed and saved to disk, 98 | if it does not exist yet, and loaded from disk otherwise. 99 | 100 | Args: 101 | cfg: The configuration. 102 | 103 | Returns: 104 | Dataset: The loaded dataset. 105 | """ 106 | dataset_params: DatasetParams = DatasetParams( 107 | cfg["ref"], 108 | None, 109 | cfg["train_split"], 110 | cfg["test_split"], 111 | (cfg["ref"],), 112 | ) 113 | DATASET_KEY = "_".join( 114 | map( 115 | str, 116 | [v for k, v in dataset_params._asdict().items() if k != "hf_key" and v is not None], 117 | ) 118 | ) 119 | DATASET_DIR: Path = PROJECT_ROOT / "data" / "datasets" / DATASET_KEY 120 | 121 | if not DATASET_DIR.exists(): 122 | train_dataset = load_dataset( 123 | dataset_params.name, 124 | split=dataset_params.train_split, 125 | token=True, 126 | ) 127 | if "val_percentage" in cfg: 128 | train_val_dataset = train_dataset.train_test_split(test_size=cfg["val_percentage"], shuffle=True) 129 | train_dataset = train_val_dataset["train"] 130 | val_dataset = train_val_dataset["test"] 131 | elif "val_split" in cfg: 132 | val_dataset = load_dataset( 133 | dataset_params.name, 134 | split=cfg["val_split"], 135 | token=True, 136 | ) 137 | else: 138 | raise RuntimeError("Either val_percentage or val_split must be specified in the config.") 139 | 140 | test_dataset = load_dataset( 141 | dataset_params.name, 142 | split=dataset_params.test_split, 143 | token=True, 144 | ) 145 | 146 | dataset: DatasetDict = MetadataDatasetDict( 147 | train=train_dataset, 148 | val=val_dataset, 149 | test=test_dataset, 150 | ) 151 | 152 | dataset = preprocess_dataset(dataset, cfg) 153 | 154 | save_dataset_to_disk(dataset, DATASET_DIR) 155 | else: 156 | dataset: Dataset = load_from_disk(dataset_path=str(DATASET_DIR)) 157 | 158 | return dataset 159 | -------------------------------------------------------------------------------- /{{ cookiecutter.repository_name }}/tests/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/grok-ai/nn-template/8ba02bba8f015e1eb7efb0d2ab8c9d433bd1c431/{{ cookiecutter.repository_name }}/tests/__init__.py -------------------------------------------------------------------------------- /{{ cookiecutter.repository_name }}/tests/conftest.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import os 3 | import shutil 4 | from pathlib import Path 5 | from typing import Dict, Union 6 | 7 | import pytest 8 | from hydra import compose, initialize 9 | from hydra.core.hydra_config import HydraConfig 10 | from lightning.pytorch import seed_everything 11 | from omegaconf import DictConfig, OmegaConf, open_dict 12 | from pytest import FixtureRequest, TempPathFactory 13 | 14 | from nn_core.serialization import NNCheckpointIO 15 | 16 | from {{ cookiecutter.package_name }}.run import run 17 | 18 | logging.basicConfig(force=True, level=logging.DEBUG) 19 | 20 | seed_everything(42) 21 | 22 | TRAIN_MAX_NSTEPS = 1 23 | 24 | 25 | # 26 | # Base configurations 27 | # 28 | @pytest.fixture(scope="package") 29 | def cfg(tmp_path_factory: TempPathFactory) -> DictConfig: 30 | test_cfg_tmpdir = 
tmp_path_factory.mktemp("test_train_tmpdir") 31 | 32 | with initialize(config_path="../conf"): 33 | cfg = compose(config_name="default", return_hydra_config=True) 34 | HydraConfig().set_config(cfg) 35 | 36 | # Force the wandb dir to be in the temp folder 37 | os.environ["WANDB_DIR"] = str(test_cfg_tmpdir) 38 | 39 | # Force the storage dir to be in the temp folder 40 | cfg.core.storage_dir = str(test_cfg_tmpdir) 41 | 42 | yield cfg 43 | 44 | shutil.rmtree(test_cfg_tmpdir) 45 | 46 | 47 | # 48 | # Training configurations 49 | # 50 | @pytest.fixture(scope="package") 51 | def cfg_simple_train(cfg: DictConfig) -> DictConfig: 52 | cfg = OmegaConf.create(cfg) 53 | 54 | # Add test tag 55 | cfg.core.tags = ["testing"] 56 | 57 | # Disable gpus 58 | cfg.train.trainer.accelerator = "cpu" 59 | 60 | # Disable logger 61 | cfg.train.logging.logger.mode = "disabled" 62 | 63 | # Disable files upload because wandb in offline modes uses always /tmp 64 | # as run.dir, which causes conflicts between multiple trainings 65 | cfg.train.logging.upload.run_files = False 66 | 67 | # Disable multiple workers in test training 68 | cfg.nn.data.num_workers.train = 0 69 | cfg.nn.data.num_workers.val = 0 70 | cfg.nn.data.num_workers.test = 0 71 | 72 | # Minimize the amount of work in test training 73 | cfg.train.trainer.max_steps = TRAIN_MAX_NSTEPS 74 | cfg.train.trainer.val_check_interval = TRAIN_MAX_NSTEPS 75 | 76 | # Ensure the resuming is disabled 77 | with open_dict(config=cfg): 78 | cfg.train.restore = {} 79 | cfg.train.restore.ckpt_or_run_path = None 80 | cfg.train.restore.mode = None 81 | 82 | return cfg 83 | 84 | 85 | @pytest.fixture(scope="package") 86 | def cfg_fast_dev_run(cfg_simple_train: DictConfig) -> DictConfig: 87 | cfg_simple_train = OmegaConf.create(cfg_simple_train) 88 | 89 | # Enable the fast_dev_run flag 90 | cfg_simple_train.train.trainer.fast_dev_run = True 91 | return cfg_simple_train 92 | 93 | 94 | # 95 | # Training configurations aggregations 96 | # 97 | @pytest.fixture( 98 | scope="package", 99 | params=[ 100 | "cfg_simple_train", 101 | ], 102 | ) 103 | def cfg_all_not_dry(request: FixtureRequest): 104 | return request.getfixturevalue(request.param) 105 | 106 | 107 | @pytest.fixture( 108 | scope="package", 109 | params=[ 110 | "cfg_simple_train", 111 | "cfg_fast_dev_run", 112 | ], 113 | ) 114 | def cfg_all(request: FixtureRequest): 115 | return request.getfixturevalue(request.param) 116 | 117 | 118 | # 119 | # Training fixtures 120 | # 121 | @pytest.fixture( 122 | scope="package", 123 | ) 124 | def run_trainings_not_dry(cfg_all_not_dry: DictConfig) -> str: 125 | yield run(cfg=cfg_all_not_dry) 126 | 127 | 128 | @pytest.fixture( 129 | scope="package", 130 | ) 131 | def run_trainings(cfg_all: DictConfig) -> str: 132 | yield run(cfg=cfg_all) 133 | 134 | 135 | # 136 | # Utility functions 137 | # 138 | def get_checkpoint_path(storagedir: Union[str, Path]) -> Path: 139 | ckpts_path = Path(storagedir) / "checkpoints" 140 | checkpoint_path = next(ckpts_path.glob("*")) 141 | assert checkpoint_path 142 | return checkpoint_path 143 | 144 | 145 | def load_checkpoint(storagedir: Union[str, Path]) -> Dict: 146 | checkpoint = NNCheckpointIO.load(path=get_checkpoint_path(storagedir)) 147 | assert checkpoint 148 | return checkpoint 149 | -------------------------------------------------------------------------------- /{{ cookiecutter.repository_name }}/tests/test_checkpoint.py: -------------------------------------------------------------------------------- 1 | from importlib import import_module 2 | from pathlib 
import Path 3 | from typing import Any, Dict 4 | 5 | from lightning.pytorch import LightningModule 6 | from lightning.pytorch.core.saving import _load_state 7 | from omegaconf import DictConfig, OmegaConf 8 | 9 | from nn_core.serialization import NNCheckpointIO 10 | from tests.conftest import load_checkpoint 11 | 12 | from {{ cookiecutter.package_name }}.pl_modules.pl_module import MyLightningModule 13 | from {{ cookiecutter.package_name }}.run import run 14 | 15 | 16 | def test_load_checkpoint(run_trainings_not_dry: str, cfg_all_not_dry: DictConfig) -> None: 17 | ckpts_path = Path(run_trainings_not_dry) / "checkpoints" 18 | checkpoint_path = next(ckpts_path.glob("*")) 19 | assert checkpoint_path 20 | 21 | reference: str = cfg_all_not_dry.nn.module._target_ 22 | module_ref, class_ref = reference.rsplit(".", maxsplit=1) 23 | module_class: LightningModule = getattr(import_module(module_ref), class_ref) 24 | assert module_class is not None 25 | 26 | checkpoint = NNCheckpointIO.load(path=checkpoint_path) 27 | 28 | module = _load_state(cls=module_class, checkpoint=checkpoint, metadata=checkpoint["metadata"], strict=True) 29 | assert module is not None 30 | assert sum(p.numel() for p in module.parameters()) 31 | 32 | 33 | def _check_cfg_in_checkpoint(checkpoint: Dict, _cfg: DictConfig) -> Dict: 34 | assert "cfg" in checkpoint 35 | assert checkpoint["cfg"] == _cfg 36 | 37 | 38 | def _check_run_path_in_checkpoint(checkpoint: Dict) -> Dict: 39 | assert "run_path" in checkpoint 40 | assert checkpoint["run_path"] 41 | checkpoint["run_path"]: str 42 | assert checkpoint["run_path"].startswith("//") 43 | 44 | 45 | def test_cfg_in_checkpoint(run_trainings_not_dry: str, cfg_all_not_dry: DictConfig) -> None: 46 | checkpoint = load_checkpoint(run_trainings_not_dry) 47 | 48 | _check_cfg_in_checkpoint(checkpoint, cfg_all_not_dry) 49 | _check_run_path_in_checkpoint(checkpoint) 50 | 51 | 52 | class ModuleWithCustomCheckpoint(MyLightningModule): 53 | def on_save_checkpoint(self, checkpoint: Dict[str, Any]) -> None: 54 | checkpoint["test_key"] = "test_value" 55 | 56 | 57 | def test_on_save_checkpoint_hook(cfg_all_not_dry: DictConfig) -> None: 58 | cfg = OmegaConf.create(cfg_all_not_dry) 59 | cfg.nn.module._target_ = "tests.test_checkpoint.ModuleWithCustomCheckpoint" 60 | output_path = Path(run(cfg)) 61 | 62 | checkpoint = load_checkpoint(output_path) 63 | 64 | _check_cfg_in_checkpoint(checkpoint, cfg) 65 | _check_run_path_in_checkpoint(checkpoint) 66 | 67 | assert "test_key" in checkpoint 68 | assert checkpoint["test_key"] == "test_value" 69 | -------------------------------------------------------------------------------- /{{ cookiecutter.repository_name }}/tests/test_configuration.py: -------------------------------------------------------------------------------- 1 | import hydra 2 | from omegaconf import DictConfig 3 | 4 | from {{ cookiecutter.package_name }}.run import build_callbacks 5 | 6 | 7 | def test_configuration_parsing(cfg: DictConfig) -> None: 8 | assert cfg is not None 9 | 10 | 11 | def test_callbacks_instantiation(cfg: DictConfig) -> None: 12 | build_callbacks(cfg.train.callbacks) 13 | 14 | 15 | def test_model_instantiation(cfg: DictConfig) -> None: 16 | datamodule = hydra.utils.instantiate(cfg.nn.data, _recursive_=False) 17 | hydra.utils.instantiate(cfg.nn.module, metadata=datamodule.metadata, _recursive_=False) 18 | 19 | 20 | def test_cfg_parametrization(cfg_all: DictConfig): 21 | assert cfg_all 22 | -------------------------------------------------------------------------------- /{{ 
cookiecutter.repository_name }}/tests/test_nn_core_integration.py: -------------------------------------------------------------------------------- 1 | from nn_core.common import PROJECT_ROOT 2 | 3 | 4 | def test_project_root() -> None: 5 | assert PROJECT_ROOT 6 | assert (PROJECT_ROOT / "conf").exists() 7 | -------------------------------------------------------------------------------- /{{ cookiecutter.repository_name }}/tests/test_resume.py: -------------------------------------------------------------------------------- 1 | from omegaconf import DictConfig, OmegaConf 2 | from pytest import TempPathFactory 3 | 4 | from nn_core.serialization import NNCheckpointIO 5 | from tests.conftest import TRAIN_MAX_NSTEPS, get_checkpoint_path, load_checkpoint 6 | 7 | from {{ cookiecutter.package_name }}.run import run 8 | 9 | 10 | def test_resume(run_trainings_not_dry: str, cfg_all_not_dry: DictConfig, tmp_path_factory: TempPathFactory) -> None: 11 | old_checkpoint_path = get_checkpoint_path(run_trainings_not_dry) 12 | 13 | new_cfg = OmegaConf.create(cfg_all_not_dry) 14 | new_storage_dir = tmp_path_factory.mktemp("resumed_training") 15 | 16 | new_cfg.core.storage_dir = str(new_storage_dir) 17 | new_cfg.train.trainer.max_steps = 2 * TRAIN_MAX_NSTEPS 18 | 19 | new_cfg.train.restore.ckpt_or_run_path = str(old_checkpoint_path) 20 | new_cfg.train.restore.mode = "hotstart" 21 | 22 | new_training_dir = run(new_cfg) 23 | 24 | old_checkpoint = NNCheckpointIO.load(path=old_checkpoint_path) 25 | new_checkpoint = load_checkpoint(new_training_dir) 26 | 27 | assert old_checkpoint["run_path"] != new_checkpoint["run_path"] 28 | assert old_checkpoint["global_step"] * 2 == new_checkpoint["global_step"] 29 | -------------------------------------------------------------------------------- /{{ cookiecutter.repository_name }}/tests/test_seeding.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | from omegaconf import DictConfig, OmegaConf 3 | 4 | from nn_core.common.utils import seed_index_everything 5 | 6 | 7 | @pytest.mark.parametrize( 8 | "seed_index, expected_seed", 9 | [ 10 | (0, 1608637542), 11 | (30, 787716372), 12 | ], 13 | ) 14 | def test_seed_index_determinism(cfg_all: DictConfig, seed_index: int, expected_seed: int): 15 | cfg_all = OmegaConf.create(cfg_all) 16 | 17 | cfg_all.train.seed_index = seed_index 18 | current_seed = seed_index_everything(train_cfg=cfg_all.train, sampling_seed=42) 19 | assert current_seed == expected_seed 20 | -------------------------------------------------------------------------------- /{{ cookiecutter.repository_name }}/tests/test_storage.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | 3 | import yaml 4 | from omegaconf import DictConfig 5 | 6 | 7 | def test_storage_config(run_trainings_not_dry: str, cfg_all_not_dry: DictConfig) -> None: 8 | cfg_path = Path(run_trainings_not_dry) / "config.yaml" 9 | 10 | assert cfg_path.exists() 11 | 12 | with cfg_path.open() as f: 13 | loaded_cfg = yaml.safe_load(f) 14 | assert loaded_cfg == cfg_all_not_dry 15 | 16 | 17 | def test_storage_checkpoint(run_trainings_not_dry: str, cfg_all_not_dry: DictConfig) -> None: 18 | cktps_path = Path(run_trainings_not_dry) / "checkpoints" 19 | 20 | assert cktps_path.exists() 21 | 22 | checkpoints = list(cktps_path.glob("*")) 23 | assert len(checkpoints) == 1 24 | -------------------------------------------------------------------------------- /{{ cookiecutter.repository_name 
}}/tests/test_training.py: -------------------------------------------------------------------------------- 1 | def test_train_loop(run_trainings: str) -> None: 2 | assert run_trainings 3 | --------------------------------------------------------------------------------
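Putting the generated pieces together: the following is a minimal, hedged sketch of how a training run can be launched programmatically, mirroring what tests/conftest.py and src/{{ cookiecutter.package_name }}/run.py above already do. It reuses only names shown above (initialize, compose, HydraConfig, run); the relative config_path and the override keys follow the test fixtures and may need adjusting to the caller's location and to the actual contents of conf/.

# Hedged sketch: assumes it lives next to tests/conftest.py, so "../conf" resolves to conf/.
from hydra import compose, initialize
from hydra.core.hydra_config import HydraConfig

from {{ cookiecutter.package_name }}.run import run

with initialize(config_path="../conf"):
    cfg = compose(config_name="default", return_hydra_config=True)
    HydraConfig().set_config(cfg)

# Keep the smoke run small, CPU-only and offline, as the test fixtures do.
cfg.train.trainer.accelerator = "cpu"
cfg.train.trainer.max_steps = 1
cfg.train.logging.logger.mode = "disabled"

run_dir = run(cfg)  # returns the run directory inside cfg.core.storage_dir
print(run_dir)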