├── .github
│   └── workflows
│       ├── publish_docs.yml
│       └── test_suite.yml
├── .gitignore
├── LICENSE
├── README.md
├── cookiecutter.json
├── docs
│   ├── changelog
│   │   ├── index.md
│   │   └── upgrade.md
│   ├── features
│   │   ├── bestpractices.md
│   │   ├── cicd.md
│   │   ├── conda.md
│   │   ├── determinism.md
│   │   ├── docs.md
│   │   ├── envvars.md
│   │   ├── fastdevrun.md
│   │   ├── metadata.md
│   │   ├── nncore.md
│   │   ├── restore.md
│   │   ├── storage.md
│   │   ├── tags.md
│   │   └── tests.md
│   ├── getting-started
│   │   ├── generation.md
│   │   └── index.md
│   ├── index.md
│   ├── integrations
│   │   ├── dvc.md
│   │   ├── githubactions.md
│   │   ├── hydra.md
│   │   ├── lightning.md
│   │   ├── mkdocs.md
│   │   ├── streamlit.md
│   │   └── wandb.md
│   ├── overrides
│   │   └── main.html
│   ├── papers.md
│   └── project-structure
│       ├── conf.md
│       ├── index.md
│       └── structure.md
├── hooks
│   ├── post_gen_project.py
│   └── pre_gen_project.py
├── mkdocs.yml
└── {{ cookiecutter.repository_name }}
    ├── .editorconfig
    ├── .env.template
    ├── .flake8
    ├── .github
    │   └── workflows
    │       ├── publish.yml
    │       └── test_suite.yml
    ├── .gitignore
    ├── .pre-commit-config.yaml
    ├── LICENSE
    ├── README.md
    ├── conf
    │   ├── default.yaml
    │   ├── hydra
    │   │   └── default.yaml
    │   ├── nn
    │   │   ├── data
    │   │   │   ├── dataset
    │   │   │   │   └── vision
    │   │   │   │       └── mnist.yaml
    │   │   │   └── default.yaml
    │   │   ├── default.yaml
    │   │   └── module
    │   │       ├── default.yaml
    │   │       └── model
    │   │           └── cnn.yaml
    │   └── train
    │       └── default.yaml
    ├── data
    │   └── .gitignore
    ├── docs
    │   ├── index.md
    │   └── overrides
    │       └── main.html
    ├── env.yaml
    ├── mkdocs.yml
    ├── pyproject.toml
    ├── setup.cfg
    ├── setup.py
    ├── src
    │   └── {{ cookiecutter.package_name }}
    │       ├── __init__.py
    │       ├── data
    │       │   ├── __init__.py
    │       │   ├── datamodule.py
    │       │   └── dataset.py
    │       ├── modules
    │       │   ├── __init__.py
    │       │   └── module.py
    │       ├── pl_modules
    │       │   ├── __init__.py
    │       │   └── pl_module.py
    │       ├── run.py
    │       ├── ui
    │       │   ├── __init__.py
    │       │   └── run.py
    │       └── utils
    │           └── hf_io.py
    └── tests
        ├── __init__.py
        ├── conftest.py
        ├── test_checkpoint.py
        ├── test_configuration.py
        ├── test_nn_core_integration.py
        ├── test_resume.py
        ├── test_seeding.py
        ├── test_storage.py
        └── test_training.py
/.github/workflows/publish_docs.yml:
--------------------------------------------------------------------------------
1 | name: Publish docs
2 |
3 | on:
4 | release:
5 | types:
6 | - created
7 |
8 | jobs:
9 | build:
10 | strategy:
11 | fail-fast: false
12 | matrix:
13 | python-version: ['3.9']
14 | include:
15 | - os: ubuntu-20.04
16 | label: linux-64
17 | prefix: /usr/share/miniconda3/envs/
18 |
19 | name: ${{ matrix.label }}-py${{ matrix.python-version }}
20 | runs-on: ${{ matrix.os }}
21 |
22 | steps:
23 | - uses: actions/checkout@v2
24 | with:
25 | fetch-depth: 0
26 |
27 | - name: Set up Python
28 | uses: actions/setup-python@v1
29 |
30 | - name: Install Dependencies
31 | shell: bash -l {0}
32 | run: |
33 | pip install cookiecutter mkdocs mkdocs-material mike
34 |
35 | # extract the major.minor components from the release tag
36 | - name: Set release notes tag
37 | run: |
38 | export RELEASE_TAG_VERSION=${{ github.event.release.tag_name }}
39 | echo "RELEASE_TAG_VERSION=${RELEASE_TAG_VERSION%.*}">> $GITHUB_ENV
40 |
41 | - name: Echo release notes tag
42 | run: |
43 | echo "${RELEASE_TAG_VERSION}"
44 |
45 | - name: Build docs website
46 | shell: bash -l {0}
47 | run: |
48 | git config user.name ci-bot
49 | git config user.email ci-bot@ci.com
50 | mike deploy --push --rebase --update-aliases ${RELEASE_TAG_VERSION} latest
51 |
--------------------------------------------------------------------------------
/.github/workflows/test_suite.yml:
--------------------------------------------------------------------------------
1 | name: Test Suite
2 |
3 | on:
4 | push:
5 | branches:
6 | - main
7 | - develop
8 |
9 | pull_request:
10 | types:
11 | - opened
12 | - reopened
13 | - synchronize
14 |
15 | env:
16 | CACHE_NUMBER: 4 # increase to reset cache manually
17 | CONDA_ENV_FILE: 'env.yaml'
18 | CONDA_ENV_NAME: 'project-test'
19 | COOKIECUTTER_PROJECT_NAME: 'project-test'
20 | HUGGING_FACE_HUB_TOKEN: ${{secrets.HUGGING_FACE_HUB_TOKEN}}
21 |
22 | jobs:
23 | build:
24 |
25 | strategy:
26 | fail-fast: false
27 | matrix:
28 | python-version: ['3.11']
29 | include:
30 | - os: ubuntu-20.04
31 | label: linux-64
32 | prefix: /usr/share/miniconda3/envs/
33 |
34 | # - os: macos-latest
35 | # label: osx-64
36 | # prefix: /Users/runner/miniconda3/envs/$CONDA_ENV_NAME
37 |
38 | # - os: windows-latest
39 | # label: win-64
40 | # prefix: C:\Miniconda3\envs\$CONDA_ENV_NAME
41 |
42 | name: ${{ matrix.label }}-py${{ matrix.python-version }}
43 | runs-on: ${{ matrix.os }}
44 |
45 | steps:
46 | - name: Parametrize conda env name
47 | run: echo "PY_CONDA_ENV_NAME=${{ env.CONDA_ENV_NAME }}-${{ matrix.python-version }}" >> $GITHUB_ENV
48 | - name: echo conda env name
49 | run: echo ${{ env.PY_CONDA_ENV_NAME }}
50 |
51 | - name: Parametrize conda prefix
52 | run: echo "PY_PREFIX=${{ matrix.prefix }}${{ env.PY_CONDA_ENV_NAME }}" >> $GITHUB_ENV
53 | - name: echo conda prefix
54 | run: echo ${{ env.PY_PREFIX }}
55 |
56 | - name: Define generated project files paths
57 | run: |
58 | echo "PROJECT_SETUPCFG_FILE=${{ env.COOKIECUTTER_PROJECT_NAME }}/setup.cfg" >> $GITHUB_ENV
59 | echo "PROJECT_CONDAENV_FILE=${{ env.COOKIECUTTER_PROJECT_NAME }}/${{ env.CONDA_ENV_FILE }}" >> $GITHUB_ENV
60 | echo "PROJECT_PRECOMMIT_FILE=${{ env.COOKIECUTTER_PROJECT_NAME }}/.pre-commit-config.yaml" >> $GITHUB_ENV
61 |
62 | - uses: actions/checkout@v2
63 |
64 | # COOKIECUTTER GENERATION
65 | - name: Set up Python
66 | uses: actions/setup-python@v1
67 |
68 | - name: Install Dependencies
69 | shell: bash -l {0}
70 | run: |
71 | pip install cookiecutter
72 |
73 | - name: Generate Repo
74 | shell: bash -l {0}
75 | run: |
76 | echo -e 'n\nn\nn\n' | cookiecutter . --no-input project_name=${{ env.COOKIECUTTER_PROJECT_NAME }}
77 |
78 | - name: Init git into generated repo
79 | shell: bash -l {0}
80 | run: |
81 | git config --global user.name ci-bot
82 | git config --global user.email ci-bot@ci.com
83 | git init
84 | git add --all
85 | git commit -m "Initial commit"
86 | working-directory: ${{ env.COOKIECUTTER_PROJECT_NAME }}
87 |
88 | - name: Define cache key postfix
89 | run: |
90 | echo "CACHE_KEY_POSTFIX=${{ matrix.label }}-${{ matrix.python-version }}-${{ env.CACHE_NUMBER }}-${{ env.PY_CONDA_ENV_NAME }}-${{ hashFiles(env.PROJECT_CONDAENV_FILE) }}-${{ hashFiles(env.PROJECT_SETUPCFG_FILE) }}" >> $GITHUB_ENV
91 |
92 | - name: Echo cache keys
93 | run: |
94 | echo ${{ env.PROJECT_SETUPCFG_FILE }}
95 | echo ${{ env.PROJECT_CONDAENV_FILE }}
96 | echo ${{ env.PROJECT_PRECOMMIT_FILE }}
97 | echo ${{ env.CACHE_KEY_POSTFIX }}
98 |
99 | # GENERATED PROJECT CI/CD
100 |
101 | # Remove the python version pin from the env.yml which could be inconsistent
102 | - name: Remove explicit python version from the environment
103 | shell: bash -l {0}
104 | run: |
105 | sed -Ei '/^\s*-?\s*python\s*([#=].*)?$/d' ${{ env.CONDA_ENV_FILE }}
106 | cat ${{ env.CONDA_ENV_FILE }}
107 | working-directory: ${{ env.COOKIECUTTER_PROJECT_NAME }}
108 |
109 | # Install torch cpu-only
110 | - name: Install torch cpu only
111 | shell: bash -l {0}
112 | run: |
113 | sed -i '/nvidia\|cuda/d' ${{ env.CONDA_ENV_FILE }}
114 | cat ${{ env.CONDA_ENV_FILE }}
115 | working-directory: ${{ env.COOKIECUTTER_PROJECT_NAME }}
116 |
117 | - name: Setup Mambaforge
118 | uses: conda-incubator/setup-miniconda@v2
119 | with:
120 | miniforge-variant: Mambaforge
121 | miniforge-version: latest
122 | activate-environment: ${{ env.PY_CONDA_ENV_NAME }}
123 | python-version: ${{ matrix.python-version }}
124 | use-mamba: true
125 |
126 | - uses: actions/cache@v2
127 | name: Conda cache
128 | with:
129 | path: ${{ env.PY_PREFIX }}
130 | key: conda-${{ env.CACHE_KEY_POSTFIX }}
131 | id: conda_cache
132 |
133 | - uses: actions/cache@v2
134 | name: Pip cache
135 | with:
136 | path: ~/.cache/pip
137 | key: pip-${{ env.CACHE_KEY_POSTFIX }}
138 |
139 | - uses: actions/cache@v2
140 | name: Pre-commit cache
141 | with:
142 | path: ~/.cache/pre-commit
143 | key: pre-commit-${{ hashFiles(env.PROJECT_PRECOMMIT_FILE) }}-conda-${{ env.CACHE_KEY_POSTFIX }}
144 |
145 | # Ensure the hack for the python version worked
146 | - name: Ensure we have the right Python
147 | shell: bash -l {0}
148 | run: |
149 | echo "Installed Python: $(python --version)"
150 | echo "Expected: ${{ matrix.python-version }}"
151 | python --version | grep "Python ${{ matrix.python-version }}"
152 | working-directory: ${{ env.COOKIECUTTER_PROJECT_NAME }}
153 |
154 | # https://stackoverflow.com/questions/70520120/attributeerror-module-setuptools-distutils-has-no-attribute-version
155 | # https://github.com/pytorch/pytorch/pull/69904
156 | - name: Downgrade setuptools due to a bug in PyTorch 1.10.1
157 | shell: bash -l {0}
158 | run: |
159 | pip install setuptools==59.5.0 --upgrade
160 | working-directory: ${{ env.COOKIECUTTER_PROJECT_NAME }}
161 |
162 | - name: Update conda environment
163 | run: mamba env update -n ${{ env.PY_CONDA_ENV_NAME }} -f ${{ env.CONDA_ENV_FILE }}
164 | if: steps.conda_cache.outputs.cache-hit != 'true'
165 | working-directory: ${{ env.COOKIECUTTER_PROJECT_NAME }}
166 |
167 | # Update the pip environment when the conda env was restored from cache (a cache miss already installs it via `mamba env update`)
168 | - name: Update pip environment
169 | shell: bash -l {0}
170 | run: pip install -e ".[dev]"
171 | if: steps.conda_cache.outputs.cache-hit == 'true'
172 | working-directory: ${{ env.COOKIECUTTER_PROJECT_NAME }}
173 |
174 | - run: pip3 list
175 | shell: bash -l {0}
176 | working-directory: ${{ env.COOKIECUTTER_PROJECT_NAME }}
177 |
178 | - run: mamba info
179 | working-directory: ${{ env.COOKIECUTTER_PROJECT_NAME }}
180 |
181 | - run: mamba list
182 | working-directory: ${{ env.COOKIECUTTER_PROJECT_NAME }}
183 |
184 | # Ensure the hack for the python version worked
185 | - name: Ensure we have the right Python
186 | shell: bash -l {0}
187 | run: |
188 | echo "Installed Python: $(python --version)"
189 | echo "Expected: ${{ matrix.python-version }}"
190 | python --version | grep "Python ${{ matrix.python-version }}"
191 | working-directory: ${{ env.COOKIECUTTER_PROJECT_NAME }}
192 |
193 | - name: Run pre-commits
194 | shell: bash -l {0}
195 | run: |
196 | pre-commit install
197 | pre-commit run -v --all-files --show-diff-on-failure
198 | working-directory: ${{ env.COOKIECUTTER_PROJECT_NAME }}
199 |
200 | - name: Test with pytest
201 | shell: bash -l {0}
202 | run: |
203 | pytest -v
204 | working-directory: ${{ env.COOKIECUTTER_PROJECT_NAME }}
205 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | wandb
2 | multirun.yaml
3 | storage
4 |
5 | # ignore the _version.py file
6 | _version.py
7 |
8 | # .gitignore defaults for python and pycharm
9 | .idea
10 |
11 | # Byte-compiled / optimized / DLL files
12 | __pycache__/
13 | *.py[cod]
14 | *$py.class
15 |
16 | # C extensions
17 | *.so
18 |
19 | # Distribution / packaging
20 | .Python
21 | build/
22 | develop-eggs/
23 | dist/
24 | downloads/
25 | eggs/
26 | .eggs/
27 | lib/
28 | lib64/
29 | parts/
30 | sdist/
31 | var/
32 | wheels/
33 | share/python-wheels/
34 | *.egg-info/
35 | .installed.cfg
36 | *.egg
37 | MANIFEST
38 |
39 | # PyInstaller
40 | # Usually these files are written by a python script from a template
41 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
42 | *.manifest
43 | *.spec
44 |
45 | # Installer logs
46 | pip-log.txt
47 | pip-delete-this-directory.txt
48 |
49 | # Unit test / coverage reports
50 | htmlcov/
51 | .tox/
52 | .nox/
53 | .coverage
54 | .coverage.*
55 | .cache
56 | nosetests.xml
57 | coverage.xml
58 | *.cover
59 | *.py,cover
60 | .hypothesis/
61 | .pytest_cache/
62 | cover/
63 |
64 | # Translations
65 | *.mo
66 | *.pot
67 |
68 | # Django stuff:
69 | *.log
70 | local_settings.py
71 | db.sqlite3
72 | db.sqlite3-journal
73 |
74 | # Flask stuff:
75 | instance/
76 | .webassets-cache
77 |
78 | # Scrapy stuff:
79 | .scrapy
80 |
81 | # Sphinx documentation
82 | docs/_build/
83 |
84 | # PyBuilder
85 | .pybuilder/
86 | target/
87 |
88 | # Jupyter Notebook
89 | .ipynb_checkpoints
90 |
91 | # IPython
92 | profile_default/
93 | ipython_config.py
94 |
95 | # pyenv
96 | # For a library or package, you might want to ignore these files since the code is
97 | # intended to run in multiple environments; otherwise, check them in:
98 | # .python-version
99 |
100 | # pipenv
101 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
102 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
103 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
104 | # install all needed dependencies.
105 | #Pipfile.lock
106 |
107 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow
108 | __pypackages__/
109 |
110 | # Celery stuff
111 | celerybeat-schedule
112 | celerybeat.pid
113 |
114 | # SageMath parsed files
115 | *.sage.py
116 |
117 | # Environments
118 | .env
119 | .venv
120 | env/
121 | venv/
122 | ENV/
123 | env.bak/
124 | venv.bak/
125 |
126 | # Spyder project settings
127 | .spyderproject
128 | .spyproject
129 |
130 | # Rope project settings
131 | .ropeproject
132 |
133 | # mkdocs documentation
134 | /site
135 |
136 | # mypy
137 | .mypy_cache/
138 | .dmypy.json
139 | dmypy.json
140 |
141 | # Pyre type checker
142 | .pyre/
143 |
144 | # pytype static type analyzer
145 | .pytype/
146 |
147 | # Cython debug symbols
148 | cython_debug/
149 |
150 | # Covers JetBrains IDEs: IntelliJ, RubyMine, PhpStorm, AppCode, PyCharm, CLion, Android Studio, WebStorm and Rider
151 | # Reference: https://intellij-support.jetbrains.com/hc/en-us/articles/206544839
152 |
153 | # User-specific stuff
154 | .idea/**/workspace.xml
155 | .idea/**/tasks.xml
156 | .idea/**/usage.statistics.xml
157 | .idea/**/dictionaries
158 | .idea/**/shelf
159 |
160 | # Generated files
161 | .idea/**/contentModel.xml
162 |
163 | # Sensitive or high-churn files
164 | .idea/**/dataSources/
165 | .idea/**/dataSources.ids
166 | .idea/**/dataSources.local.xml
167 | .idea/**/sqlDataSources.xml
168 | .idea/**/dynamic.xml
169 | .idea/**/uiDesigner.xml
170 | .idea/**/dbnavigator.xml
171 |
172 | # Gradle
173 | .idea/**/gradle.xml
174 | .idea/**/libraries
175 |
176 | # Gradle and Maven with auto-import
177 | # When using Gradle or Maven with auto-import, you should exclude module files,
178 | # since they will be recreated, and may cause churn. Uncomment if using
179 | # auto-import.
180 | # .idea/artifacts
181 | # .idea/compiler.xml
182 | # .idea/jarRepositories.xml
183 | # .idea/modules.xml
184 | # .idea/*.iml
185 | # .idea/modules
186 | # *.iml
187 | # *.ipr
188 |
189 | # CMake
190 | cmake-build-*/
191 |
192 | # Mongo Explorer plugin
193 | .idea/**/mongoSettings.xml
194 |
195 | # File-based project format
196 | *.iws
197 |
198 | # IntelliJ
199 | out/
200 |
201 | # mpeltonen/sbt-idea plugin
202 | .idea_modules/
203 |
204 | # JIRA plugin
205 | atlassian-ide-plugin.xml
206 |
207 | # Cursive Clojure plugin
208 | .idea/replstate.xml
209 |
210 | # Crashlytics plugin (for Android Studio and IntelliJ)
211 | com_crashlytics_export_strings.xml
212 | crashlytics.properties
213 | crashlytics-build.properties
214 | fabric.properties
215 |
216 | # Editor-based Rest Client
217 | .idea/httpRequests
218 |
219 | # Android studio 3.1+ serialized cache file
220 | .idea/caches/build_file_checksums.ser
221 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2021 Valentino Maiorca, Luca Moschella
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # NN Template
2 |
3 |
4 |
5 |
6 |
7 |
8 |
9 |
10 |
26 |
27 |
28 |
29 | "We demand rigidly defined areas of doubt and uncertainty."
30 |
31 |
32 |
33 |
34 | Generic template to bootstrap your [PyTorch](https://pytorch.org/get-started/locally/) project,
35 | read more in the [documentation](https://grok-ai.github.io/nn-template).
36 |
37 |
38 | [Demo on asciinema](https://asciinema.org/a/475623)
39 |
40 | ## Get started
41 |
42 | If you already know [cookiecutter](https://github.com/cookiecutter/cookiecutter), just generate your project with:
43 |
44 | ```bash
45 | cookiecutter https://github.com/grok-ai/nn-template
46 | ```
47 |
48 |
49 | Otherwise, install it first with `pip install cookiecutter`.
50 |
51 | Cookiecutter manages the setup stages and delivers a personalized, ready-to-run project.
54 |
55 |
56 |
57 | More details in the [documentation](https://grok-ai.github.io/nn-template/latest/getting-started/generation/).
58 |
59 | ## Strengths
60 |
61 | - **Actually works for [research](https://grok-ai.github.io/nn-template/latest/papers/)**!
62 | - Guided setup to customize project bootstrapping;
63 | - Fast prototyping of new ideas, no need to build a new code base from scratch;
64 | - Less boilerplate with no impact on the learning curve (as long as you know the integrated tools);
65 | - Ensure experiment reproducibility;
66 | - Automate testing, stylish documentation deployment and PyPI uploads via GitHub Actions;
67 | - Enforce Python [best practices](https://grok-ai.github.io/nn-template/latest/features/bestpractices/);
68 | - Many more in the [documentation](https://grok-ai.github.io/nn-template/latest/features/nncore/);
69 |
70 | ## Integrations
71 |
72 | Avoid writing boilerplate code to integrate:
73 |
74 | - [PyTorch Lightning](https://github.com/PyTorchLightning/pytorch-lightning), lightweight PyTorch wrapper for high-performance AI research.
75 | - [Hydra](https://github.com/facebookresearch/hydra), a framework for elegantly configuring complex applications.
76 | - [Hugging Face Datasets](https://huggingface.co/docs/datasets/index), a library for easily accessing and sharing datasets.
77 | - [Weights and Biases](https://wandb.ai/home), organize and analyze machine learning experiments. *(educational account available)*
78 | - [Streamlit](https://streamlit.io/), turns data scripts into shareable web apps in minutes.
79 | - [MkDocs](https://www.mkdocs.org/) and [Material for MkDocs](https://squidfunk.github.io/mkdocs-material/), a fast, simple and downright gorgeous static site generator.
80 | - [DVC](https://dvc.org/doc/start/data-versioning), track large files, directories, or ML models. Think "Git for data".
81 | - [GitHub Actions](https://github.com/features/actions), to run the tests, publish the documentation and upload to PyPI automatically.
82 | - Python best practices for developing and publishing research projects.
83 |
84 | ## Maintainers
85 |
86 | - Valentino Maiorca [@Flegyas](https://github.com/Flegyas)
87 | - Luca Moschella [@lucmos](https://github.com/lucmos)
88 |
--------------------------------------------------------------------------------
/cookiecutter.json:
--------------------------------------------------------------------------------
1 | {
2 | "author": "Paul Erdős",
3 | "author_email": "paul@erdos.com",
4 | "github_user": "erdos",
5 | "project_name": "Awesome Project",
6 | "project_description": "A new awesome project.",
7 | "repository_name": "{{ cookiecutter.project_name.strip().lower().replace(' ', '-') }}",
8 | "package_name": "{{ cookiecutter.project_name.strip().lower().replace(' ', '-').replace('-', '_') }}",
9 | "repository_url": "https://github.com/{{ cookiecutter.github_user }}/{{ cookiecutter.project_name.strip().lower().replace(' ', '-') }}",
10 | "conda_env_name": "{{ cookiecutter.project_name.strip().lower().replace(' ', '-') }}",
11 | "python_version": "3.11",
12 | "__version": "0.4.0"
13 | }
14 |
--------------------------------------------------------------------------------
/docs/changelog/index.md:
--------------------------------------------------------------------------------
1 | # Changelog
2 |
3 | See the changelog in the [releases](https://github.com/grok-ai/nn-template/releases) page.
4 |
--------------------------------------------------------------------------------
/docs/changelog/upgrade.md:
--------------------------------------------------------------------------------
1 | The need for upgrading the template itself is lessened thanks to the `nn-template-core` library decoupling.
2 |
3 | !!! info
4 |
5 | Update the `nn-template-core` library by changing its version constraint in the `setup.cfg`.
6 |
7 | ---
8 |
9 | However, you can also use [cruft](https://github.com/cruft/cruft) to automate the template updates themselves!
10 | [cruft on GitHub](https://github.com/cruft/cruft)
11 |
--------------------------------------------------------------------------------
/docs/features/bestpractices.md:
--------------------------------------------------------------------------------
1 | # Tooling
2 |
3 | The template configures all the tooling necessary for a modern Python project.
4 |
5 | These include:
6 |
7 | - [**EditorConfig**](https://editorconfig.org/) maintains consistent coding styles for multiple developers.
8 | - [**Black**](https://black.readthedocs.io/en/stable/index.html) the uncompromising code formatter.
9 | - [**isort**](https://github.com/PyCQA/isort) sorts imports alphabetically and automatically separates them into sections and by type.
10 | - [**flake8**](https://flake8.pycqa.org/en/latest/) checks coding style (PEP8), programming errors and cyclomatic complexity.
11 | - [**pydocstyle**](http://www.pydocstyle.org/en/stable/) static analysis tool for checking compliance with Python docstring conventions.
12 | - [**MyPy**](http://mypy-lang.org/) static type checker for Python.
13 | - [**Coverage**](https://coverage.readthedocs.io/en/6.2/) measures code coverage of Python programs.
14 | - [**bandit**](https://github.com/PyCQA/bandit) security linter from PyCQA.
15 | - [**pre-commit**](https://pre-commit.com/) framework for managing and maintaining pre-commit hooks.
16 |
17 | ## Pre commits
18 |
19 | The pre-commit configuration is defined in `.pre-commit-config.yaml`,
20 | and includes the most important checks and auto-fixes to perform.
21 |
22 | If one of the pre-commit hooks fails, **the commit is aborted**, catching errors before they land in the history.
23 |
24 | !!! info
25 |
26 | The pre-commits are also run in the CI/CD as part of the Test Suite. This helps
27 | guarantee that all the contributors respect the code conventions in their
28 | pull requests.
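The same hooks can be run locally at any time; these are the same commands used by the Test Suite workflow:

```bash
pre-commit install                                     # run the hooks automatically on every commit
pre-commit run -v --all-files --show-diff-on-failure   # run every hook on the whole codebase once
```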
29 |
--------------------------------------------------------------------------------
/docs/features/cicd.md:
--------------------------------------------------------------------------------
1 | # CI/CD
2 |
3 | The generated project contains two [GitHub Actions](https://github.com/features/actions) workflows, to run the Test Suite and to publish your project.
4 |
5 | !!! note
6 | You need to enable the GitHub Actions from the settings in your repository.
7 |
8 | !!! important
9 | All the workflows already implement the logic needed to **cache** the conda and pip environments
10 | between workflow runs.
11 |
12 | !!! warning
13 | The annotated tags used in the git repository to manage releases should follow the [semantic versioning](https://semver.org/)
14 | convention: `<major>.<minor>.<patch>`
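For example, a release for version `0.2.1` could be tagged with (the version number is illustrative):

```bash
git tag -a 0.2.1 -m "Release 0.2.1"
git push origin 0.2.1
```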
15 |
16 |
17 |
18 | ## Test Suite
19 |
20 | The Test Suite runs automatically for each commit in a Pull Request.
21 | It is successful if:
22 |
23 | - The pre-commits do not raise any errors
24 | - All the tests pass
25 |
26 | After that, the PRs are marked with ✔️ or ❌ depending on the test suite results.
27 |
28 | ## Publish docs
29 |
30 | The first time, you should use `mike` to create the `gh-pages` branch and
31 | specify the default docs version:
32 |
33 | ```bash
34 | mike deploy 0.1 latest --push
35 | mike set-default latest
36 | ```
37 |
38 | !!! warning
39 |
40 | You do not need to execute these commands if you accepted the optional cookiecutter setup step.
41 |
42 | !!! info
43 |
44 | Remember to enable the GitHub Pages from the repository settings.
45 |
46 |
47 | After that, the docs are built and automatically published `on release` on GitHub Pages.
48 | This means that every time you publish a new release in your project an associated version of the documentation is published.
49 |
50 | !!! important
51 |
52 | The documentation version uses only the `<major>.<minor>` part of the release tag, discarding the patch version.
53 |
54 | ## Publish PyPi
55 |
56 | To publish your package on PyPI it is enough to configure
57 | the PyPI token in the GitHub repository `secrets` and uncomment the following in the
58 | `publish.yml` workflow:
59 |
60 | ```yaml
61 | - name: Build SDist and wheel
62 | run: pipx run build
63 |
64 | - name: Check metadata
65 | run: pipx run twine check dist/*
66 |
67 | - name: Publish distribution 📦 to PyPI
68 | uses: pypa/gh-action-pypi-publish@release/v1
69 | with:
70 | user: __token__
71 | password: ${{ secrets.PYPI_API_TOKEN }}
72 | ```
73 |
74 | In this way, on each GitHub release the package gets published on PyPi and the associated documentation
75 | is published on GitHub Pages.
76 |
--------------------------------------------------------------------------------
/docs/features/conda.md:
--------------------------------------------------------------------------------
1 | # Python Environment
2 |
3 | The generated project is a Python package, whose dependencies are defined in the `setup.cfg`.
4 | The development setup comprises a `conda` environment which installs the package itself in editable mode.
5 |
6 | ## Dependencies
7 |
8 | All the project dependencies should be defined in the `setup.cfg` as `pip` dependencies.
9 | In rare cases, it is useful to specify conda dependencies --- they will not be resolved when installing the package
10 | from PyPi.
11 |
12 | This division is useful when installing particular or optimized packages such as `PyTorch` and `PyTorch Geometric`.
13 |
14 | !!! hint
15 |
16 | It is possible to manage the Python version to use in the conda `env.yaml`.
17 |
18 | !!! info
19 |
20 | This organization allows `conda` and `pip` dependencies to co-exist, which in practice happens a lot in
21 | research projects.
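As a rough sketch of this split (package names and pins are illustrative; your generated `env.yaml` will differ):

```yaml
# env.yaml -- illustrative sketch
name: awesome-project
channels:
  - pytorch
  - defaults
dependencies:
  - python=3.11
  - pytorch        # conda-managed, optimized build
  - pip
  - pip:
      - -e .[dev]  # everything declared in setup.cfg
```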
22 |
23 | ## Update
24 |
25 | In order to update the `pip` dependencies after changing the `setup.cfg` it is enough to run:
26 |
27 | ```bash
28 | pip install -e '.[dev]'
29 | ```
30 |
--------------------------------------------------------------------------------
/docs/features/determinism.md:
--------------------------------------------------------------------------------
1 | # Determinism
2 |
3 | The template always logs the seed utilized in order to guarantee **reproducibility**.
4 |
5 | The user specifies a `seed_index` value in the configuration `train/default.yaml`:
6 |
7 | ```yaml
8 | seed_index: 1
9 | deterministic: False
10 | ```
11 |
12 | This value indexes an array of deterministic but randomly generated seeds, e.g.:
13 |
14 | ```bash
15 | Setting seed 1273642419 from seeds[1]
16 | ```
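The mechanism can be pictured with a minimal sketch (an illustration of the idea, not the exact `nn-template-core` implementation):

```python
import random


def resolve_seed(seed_index: int, n_seeds: int = 10) -> int:
    """Map a small, human-friendly index to a reproducible random seed."""
    rng = random.Random(0)  # fixed generator: seeds[i] is identical on every machine
    seeds = [rng.randint(0, 2**31 - 1) for _ in range(n_seeds)]
    seed = seeds[seed_index]
    print(f"Setting seed {seed} from seeds[{seed_index}]")
    return seed
```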
17 |
18 |
19 | !!! hint
20 |
21 | This setup allows you to easily run the same experiment with different seeds in a reproducible way.
22 | It is enough to run a Hydra multi-run over the `seed_index`.
23 |
24 | The following would run the same experiment with four different seeds, which can be analyzed
25 | in the logger dashboard:
26 |
27 | ```bash
28 | python src/project/run.py -m train.seed_index=1,2,3,4
29 | ```
30 |
31 | !!! info
32 |
33 | The `deterministic: False` option controls the use of deterministic algorithms
34 | in PyTorch; it is forwarded to the Lightning Trainer.
35 |
--------------------------------------------------------------------------------
/docs/features/docs.md:
--------------------------------------------------------------------------------
1 | # Documentation
2 |
3 | `MkDocs` and `Material for MkDocs` are already configured in the generated project.
4 |
5 | In order to create your docs it is enough to:
6 |
7 | 1. Modify the `nav` index in the `mkdocs.yml`, which describes how to organize the pages.
8 | An example of the `nav` is the following:
9 |
10 | ```yaml
11 | nav:
12 | - Home: index.md
13 | - Getting started:
14 | - Generating your project: getting-started/generation.md
15 | - Structure: getting-started/structure.md
16 | ```
17 |
18 | 2. Create all the files referenced in the `nav` relative to the `docs/` folder.
19 |
20 | ```bash
21 | ❯ tree docs
22 | docs
23 | ├── getting-started
24 | │ ├── generation.md
25 | │ └── structure.md
26 | └── index.md
27 | ```
28 |
29 | 3. To preview your documentation it is enough to run `mkdocs serve`. To manually deploy the documentation
30 | see [`mike`](https://github.com/jimporter/mike), or see the integrated GitHub Action to [publish the docs on release](https://grok-ai.github.io/nn-template/latest/features/cicd/#publish-docs).
31 |
--------------------------------------------------------------------------------
/docs/features/envvars.md:
--------------------------------------------------------------------------------
1 |
2 | # Environment Variables
3 |
4 | System specific variables (e.g. absolute paths) should not be under version control, otherwise there will be conflicts between different users.
5 |
6 | The best way to handle system specific variables is through environment variables.
7 |
8 | You can define new environment variables in a `.env` file in the project root. A copy of this file (e.g. `.env.template`) can be under version control to ease new project configurations.
9 |
10 | To define a new variable write inside `.env`:
11 |
12 | ```bash
13 | export MY_VAR=/home/user/my_system_path
14 | ```
15 |
16 |
17 | You can dynamically resolve the variable from anywhere:
18 |
19 | === "python"
20 |
21 | In Python code use:
22 |
23 | ```python
24 | get_env("MY_VAR")
25 | ```
26 |
27 | === "yaml"
28 |
29 | In the Hydra `yaml` configurations:
30 |
31 | ```yaml
32 | ${oc.env:MY_VAR}
33 | ```
34 |
35 | === "posix"
36 |
37 | In posix shells:
38 |
39 | ```bash
40 | . .env
41 | echo $MY_VAR
42 | ```
43 |
--------------------------------------------------------------------------------
/docs/features/fastdevrun.md:
--------------------------------------------------------------------------------
1 | # Fast Dev Run
2 |
3 | The template expands the Lightning `fast_dev_run` mode to be more debugging friendly.
4 |
5 | It will also:
6 |
7 | - Disable multiple workers in the dataloaders
8 | - Use the CPU and not the GPU
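A run in this mode can typically be launched by flipping the corresponding flag from the command line (assuming, as in the default configuration, that the Trainer options are exposed under `train.trainer`; the package name is illustrative):

```bash
python src/awesome_project/run.py train.trainer.fast_dev_run=True
```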
9 |
10 | !!! info
11 |
12 | It is possible to modify this behaviour by simply modifying the `run.py` file.
13 |
--------------------------------------------------------------------------------
/docs/features/metadata.md:
--------------------------------------------------------------------------------
1 | # MetaData
2 |
3 | The *bridge* between the Lightning DataModule and the Lightning Module.
4 |
5 | It is responsible for collecting data information to be fed to the module.
6 | The Lightning Module will receive an instance of MetaData when instantiated,
7 | both in the training loop and when restored from a checkpoint.
8 |
9 | !!! warning
10 |
11 | MetaData exposes `save` and `load`. Those are two user-defined methods that specify how to serialize and de-serialize the information contained in its attributes.
12 | This is needed for checkpoint restore to work properly and **must always be
13 | implemented** wherever metadata is needed.
14 |
15 | This decoupling allows the architecture to be parametric (e.g. in the number of classes) and
16 | DataModule/Trainer independent (useful in prediction scenarios).
17 | Examples are the class names in a classification task or the vocabulary in NLP tasks.
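As a minimal sketch (attribute names and serialization format are illustrative, not the template's exact implementation):

```python
import json
from pathlib import Path


class MetaData:
    def __init__(self, class_vocab: dict):
        self.class_vocab = class_vocab

    def save(self, dst_path: Path) -> None:
        # Serialize everything the module needs to be rebuilt from a checkpoint.
        (dst_path / "class_vocab.json").write_text(json.dumps(self.class_vocab))

    @staticmethod
    def load(src_path: Path) -> "MetaData":
        # Invert `save`: rebuild the MetaData from the serialized files.
        return MetaData(class_vocab=json.loads((src_path / "class_vocab.json").read_text()))
```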
18 |
19 |
--------------------------------------------------------------------------------
/docs/features/nncore.md:
--------------------------------------------------------------------------------
1 | # NN Template core
2 |
3 | Most of the logic is abstracted from the template into an accompanying library: [`nn-template-core`](https://pypi.org/project/nn-template-core/).
4 |
5 | This library contains the logic necessary for the restore, logging, and many other functionalities implemented in the template.
6 |
7 | !!! info
8 |
9 | This decoupling eases the updating of the template, reaching a desirable tradeoff:
10 |
11 | - `template`: easy to use and customize, hard to update
12 | - `library`: hard to customize, easy to update
13 |
14 | With our approach, updating most of the functionality is as easy as bumping a Python
15 | dependency, while maintaining the flexibility of a template.
16 |
17 |
18 | !!! warning
19 |
20 | It is important to **not** remove the `NNTemplateCore` callback from the instantiated callbacks
21 | in the template. It is used to inject personalized behaviour in the training loop.
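For reference, the callback is simply one element of the Trainer callbacks list built in `run.py`; the snippet below is a simplified sketch, and the import path and `restore_cfg` argument are assumptions to double-check against your generated project:

```python
import pytorch_lightning as pl
from nn_core.callbacks import NNTemplateCore  # assumed import path from nn-template-core

callbacks = [
    NNTemplateCore(restore_cfg=None),  # keep this callback: it injects the template's restore/logging logic
    pl.callbacks.ModelCheckpoint(),    # your own callbacks live alongside it
]
trainer = pl.Trainer(callbacks=callbacks, fast_dev_run=True)
```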
22 |
--------------------------------------------------------------------------------
/docs/features/restore.md:
--------------------------------------------------------------------------------
1 | # Restore
2 |
3 | The template offers a way to restore a previous run from the configuration.
4 | The relevant configuration block is in `conf/train/default.yaml`:
5 |
6 | ```yaml
7 | restore:
8 | ckpt_or_run_path: null
9 | mode: null # null, finetune, hotstart, continue
10 | ```
11 |
12 | ## ckpt_or_run_path
13 |
14 | The `ckpt_or_run_path` can be a path towards a Lightning checkpoint or a run identifier for the logger.
15 | With W&B as the logger, this is called a `run_path` and has the form `entity/project/run_id`.
16 |
17 | !!! warning
18 |
19 | If `ckpt_or_run_path` points to a checkpoint, that checkpoint must have been saved with
20 | this template, because additional information is attached to the checkpoint to guarantee
21 | a correct restore. These include the `run_path` itself and the whole configuration used.
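Both fields can also be set directly from the command line; for example, to continue an interrupted W&B run (the entry point path and identifiers are illustrative):

```bash
python src/awesome_project/run.py train.restore.ckpt_or_run_path=my-entity/awesome-project/1a2b3c4d train.restore.mode=continue
```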
22 |
23 | ## mode
24 |
25 | We support 4 different modes for restoring an experiment:
26 |
27 | === "null"
28 |
29 | ```yaml
30 | restore:
31 | mode: null
32 | ```
33 | In this `mode` no restore happens, and `ckpt_or_run_path` is ignored.
34 |
35 |
36 | !!! example "Use Case"
37 |
38 | This is the default option and allows the user to train the model from
39 | scratch logging into a new run.
40 |
41 |
42 | === "finetune"
43 |
44 | ```yaml
45 | restore:
46 | mode: finetune
47 | ```
48 | In this `mode` only the model weights are restored, both the `Trainer` state and the logger run
49 | are *not restored*.
50 |
51 |
52 | !!! example "Use Case"
53 |
54 | As the name suggests, one of the most common use cases is fine-tuning
55 | a trained model while logging into a new run with a novel training
56 | regimen.
57 |
58 | === "hotstart"
59 |
60 | ```yaml
61 | restore:
62 | mode: hotstart
63 | ```
64 | In this `mode` the training continues from the checkpoint restoring the `Trainer` state **but** the logging does not.
65 | A new run is created on the logger dashboard.
66 |
67 |
68 | !!! example "Use Case"
69 |
70 | Perform different tests in separate logging runs branching from the same trained
71 | model.
72 |
73 |
74 | === "continue"
75 |
76 | ```yaml
77 | restore:
78 | mode: continue
79 | ```
80 | In this `mode` the training continues from the checkpoint **and** the logging continues
81 | in the previous run. No new run is created on the logger dashboard.
82 |
83 |
84 | !!! example "Use Case"
85 |
86 | The training execution was interrupted and the user wants to continue it.
87 |
88 |
89 | !!! tldr "Restore summary"
90 |
91 | | | null | finetune | hotstart | continue |
92 | |---------------|------|--------------------|--------------------|--------------------|
93 | | **Model weights** | :x: | :white_check_mark: | :white_check_mark: | :white_check_mark: |
94 | | **Trainer state** | :x: | :x: | :white_check_mark: | :white_check_mark: |
95 | | **Logging run** | :x: | :x: | :x: | :white_check_mark: |
96 |
--------------------------------------------------------------------------------
/docs/features/storage.md:
--------------------------------------------------------------------------------
1 | # Storage
2 |
3 | The checkpoints and other data produced by the experiment are stored in
4 | a logger-agnostic folder defined by the configuration key `core.storage_dir`.
5 |
6 | This is the organization of the `storage_dir`:
7 |
8 | ```bash
9 | storage
10 | └──
11 | └──
12 | ├── checkpoints
13 | │ └── .ckpt.zip
14 | └── config.yaml
15 | ```
16 |
17 | In the configuration it is possible to specify whether the run files
18 | stored inside the `storage_dir` should be uploaded to the cloud:
19 |
20 | ```yaml
21 | logging:
22 | upload:
23 | run_files: true
24 | source: true
25 | ```
26 |
--------------------------------------------------------------------------------
/docs/features/tags.md:
--------------------------------------------------------------------------------
1 | # Tags
2 |
3 | Each run should be `tagged` in order to easily filter runs in the logger dashboard.
4 | Unfortunately, it is easy to forget to tag each run correctly.
5 |
6 | We interactively ask for a list of comma-separated tags if they are not already defined in the configuration:
7 | ```
8 | WARNING No tags provided, asking for tags...
9 | Enter a list of comma separated tags (develop):
10 | ```
11 |
12 | !!! info
13 |
14 | If the current experiment is a sweep comprising multiple runs and no tags are defined,
15 | an error is raised instead:
16 | ```
17 | ERROR You need to specify 'core.tags' in a multi-run setting!
18 | ```
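Tags can also be set ahead of time from the command line (the entry point path is illustrative), which avoids the interactive prompt and works in multi-run sweeps as well:

```bash
python src/awesome_project/run.py "core.tags=[develop, baseline]"
```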
19 |
--------------------------------------------------------------------------------
/docs/features/tests.md:
--------------------------------------------------------------------------------
1 | # Tests
2 |
3 | The generated project includes automated tests that use the **current configuration** defined in your project.
4 |
5 | You should write additional tests specific to each project, but running the included ones should at least
6 | give an idea of whether the code and its fundamental operations work as expected.
7 |
8 |
9 | !!! info
10 |
11 | You can execute the tests with:
12 |
13 | ```bash
14 | pytest -v
15 | ```
16 |
--------------------------------------------------------------------------------
/docs/getting-started/generation.md:
--------------------------------------------------------------------------------
1 | # Initial Setup
2 |
3 | ## Cookiecutter
4 |
5 | `nn-template` is, by definition, a template to generate projects. It's a robust **starting point**
6 | for your projects, something that lets you skip the initial boilerplate in configuring the environment, tests and such.
7 | Since it is a blueprint to build upon, it has no utility in being installed via pip or similar tools.
8 |
9 | Instead, we rely on [cookiecutter](https://cookiecutter.readthedocs.io) to manage the setup stages and deliver to you a
10 | ready-to-run project. It is a general-purpose tool that enables users to add their water of choice (variable
11 | configurations) to their particular Cup-a-Soup (the template to be set up).
12 |
13 | !!! hint "Installing cookiecutter"
14 |
15 | `cookiecutter` can be installed via pip in any Python-enabled environment (it won't be the same used by the project once
16 | instantiated). Our advice is to install `cookiecutter` as a system utility via [pipx](https://github.com/pypa/pipx):
17 |
18 | ```shell
19 | pipx install cookiecutter
20 | ```
21 |
22 | Then, we need to tell cookiecutter which template to work on:
23 |
24 | ```shell
25 | cookiecutter https://github.com/grok-ai/nn-template.git
26 | ```
27 |
28 | It will clone the nn-template repository in the background, call its interactive setup, and build your project's folder
29 | according to the given parametrization.
30 |
31 | The parametrized setup will take care of:
32 |
33 | - Setting up the development environment of a Python package
34 | - Initializing a clean Git repository and adding the GitHub remote of choice
35 | - Creating a **new Conda environment** to execute your code in
36 |
37 | This extra step via cookiecutter is done to avoid a lot of manual parametrization, unavoidable when cloning a template
38 | repository from scratch. **Trust us, it is totally worth the bother!**
39 |
40 | ## Building Blocks
41 | The generated project already contains a minimal working example. You are **free to modify anything** you want except for
42 | a few essential and high-level things that keep everything working. **(again, this is not a framework!)**.
43 | In particular, maintain:
44 |
45 | - Any `LightningLogger` you may want to use wrapped in a `NNLogger`
46 | - The `NNTemplateCore` Lightning callback
47 |
48 | !!! hint "nn-template main components"
49 |
50 | The template bootstraps the project with most of the needed boilerplate.
51 | The remaining components to implement for your project are the following:
52 |
53 | 1. Implement data pipeline
54 | 1. Dataset
55 | 2. Pytorch Lightning DataModule
56 | 2. Implement neural modules
57 | 1. Model
58 | 2. Pytorch Lightning Module
59 |
60 | ## FAQs
61 |
62 | ??? question "What is The Answer to the Ultimate Question of Life, the Universe, and Everything?"
63 |
64 | 42
65 |
66 | ??? question "Why are the logs badly formatted in PyCharm?"
67 |
68 | This is due to the fact that we are using [Rich](https://rich.readthedocs.io/en/stable/introduction.html) to handle
69 | the logging, and Rich is not compatible with customized terminals. As its documentation says:
70 |
71 | "*PyCharm users will need to enable “emulate terminal” in output console option in run/debug configuration to see styled output.*"
72 |
73 | ??? question "Why are file paths not interactive in the terminal's output?"
74 |
75 | [We would like to know, too.](https://youtrack.jetbrains.com/issue/PY-46305)
76 |
77 | ??? question "How can I exclude specific file paths from pre-commit checks (e.g. pydocstyle)?"
78 |
79 | While we encourage everyone to keep best-practices and standards enforced via the pre-commit utility, we also take
80 | into account situations where you just copy/paste code from the Internet and fixing it would be tedious.
81 | In those cases, the file `.pre-commit-config.yaml` has you covered. Each hook can receive an additional property,
82 | namely `exclude` where you can specify single files or patterns to be excluded when running that hook.
83 |
84 | For example, if you want to exclude a file named `ugly_but_working_code.py` from an annoying hook `annoying_hook` (most likely `pydocstyle`):
85 | ```yaml
86 | - repo: https://github.com/slow_coding/annoying_hook.git
87 | hooks:
88 | - id: annoying_hook
89 | exclude: ugly_but_working_code.py
90 | ```
91 |
92 | ## Future Features
93 |
94 | - [ ] Optuna support
95 | - [ ] Support different loggers other than WandB
96 |
--------------------------------------------------------------------------------
/docs/getting-started/index.md:
--------------------------------------------------------------------------------
1 | # Principles behind nn-template
2 |
3 | When developing neural models ourselves, we often struggled with:
4 |
5 | - **Reproducibility**. We strongly believe in the reproducibility requirement of scientific work.
6 | - **Framework Learning**. Even when you find (or code yourself) the best framework to fit your needs, you still end up
7 | in messy situations when collaborating since others have to learn to use it;
8 | - **Avoiding boilerplate**. We got bored of writing the same code over and over in
9 | every project to handle the typical ML pipeline.
10 |
11 | Over the course of the years, we fine-tuned our toolbox to reach this local minimum with respect to our self-imposed
12 | requirements. After many epochs of training, the result is **nn-template**.
13 |
14 |
15 | !!! warning "nn-template is not a framework"
16 |
17 | - It does not aim to sidestep the need to write code.
18 | - It does not constrain your workflow more than PyTorch Lightning does.
19 |
20 |
--------------------------------------------------------------------------------
/docs/index.md:
--------------------------------------------------------------------------------
1 | ---
2 | hide:
3 | - navigation
4 | - toc
5 | ---
6 |
7 | # NN Template
8 |
9 |
10 |
11 |
12 |
13 |
14 |
15 |
16 |
17 |
18 | "We demand rigidly defined areas of doubt and uncertainty."
19 |
20 |
21 |
22 | ---
23 |
24 | ```bash
25 | cookiecutter https://github.com/grok-ai/nn-template
26 | ```
27 |
28 | ---
29 |
30 | [Demo on asciinema](https://asciinema.org/a/475623)
31 |
32 | ---
33 |
34 | Generic cookiecutter template to bootstrap [PyTorch](https://pytorch.org/get-started/locally/) projects
35 | and to avoid writing boilerplate code to integrate:
36 |
37 | - [PyTorch Lightning](https://github.com/PyTorchLightning/pytorch-lightning), lightweight PyTorch wrapper for high-performance AI research.
38 | - [Hydra](https://github.com/facebookresearch/hydra), a framework for elegantly configuring complex applications.
39 | - [Hugging Face Datasets](https://huggingface.co/docs/datasets/index), a library for easily accessing and sharing datasets.
40 | - [Weights and Biases](https://wandb.ai/home), organize and analyze machine learning experiments. *(educational account available)*
41 | - [Streamlit](https://streamlit.io/), turns data scripts into shareable web apps in minutes.
42 | - [MkDocs](https://www.mkdocs.org/) and [Material for MkDocs](https://squidfunk.github.io/mkdocs-material/), a fast, simple and downright gorgeous static site generator.
43 | - [DVC](https://dvc.org/doc/start/data-versioning), track large files, directories, or ML models. Think "Git for data".
44 | - [GitHub Actions](https://github.com/features/actions), to run the tests, publish the documentation and upload to PyPI automatically.
45 | - Python best practices for developing and publishing research projects.
46 |
47 |
48 | !!! help "cookiecutter"
49 |
50 | This is a *parametrized* template that uses [cookiecutter](https://github.com/cookiecutter/cookiecutter).
51 | Install cookiecutter with:
52 |
53 | ```pip install cookiecutter```
54 |
--------------------------------------------------------------------------------
/docs/integrations/dvc.md:
--------------------------------------------------------------------------------
1 |
2 | # Data Version Control
3 |
4 | DVC runs alongside `git` and uses the current commit hash to version control the data.
5 |
6 | Initialize the `dvc` repository:
7 |
8 | ```bash
9 | $ dvc init
10 | ```
11 |
12 | To start tracking a file or directory, use `dvc add`:
13 |
14 | ```bash
15 | $ dvc add data/ImageNet
16 | ```
17 |
18 | DVC stores information about the added file (or a directory) in a special `.dvc` file named `data/ImageNet.dvc`, a small text file with a human-readable format.
19 | This file can be easily versioned like source code with Git, as a placeholder for the original data (which gets listed in `.gitignore`):
20 |
21 | ```bash
22 | git add data/ImageNet.dvc data/.gitignore
23 | git commit -m "Add raw data"
24 | ```
25 |
26 | ## Making changes
27 |
28 | When you make a change to a file or directory, run `dvc add` again to track the latest version:
29 |
30 | ```bash
31 | $ dvc add data/ImageNet
32 | ```
33 |
34 | ## Switching between versions
35 |
36 | The regular workflow is to use `git checkout` first to switch a branch, checkout a commit, or a revision of a `.dvc` file, and then run `dvc checkout` to sync data:
37 |
38 | ```bash
39 | $ git checkout <...>
40 | $ dvc checkout
41 | ```
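To share the tracked data between machines, you would typically also configure a DVC remote and push/pull the data (the remote URL is illustrative):

```bash
$ dvc remote add -d myremote s3://my-bucket/dvc-storage
$ dvc push    # upload the tracked data to the remote
$ dvc pull    # download it on another machine after `git pull`
```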
42 |
43 | !!! info
44 |
45 | Read more in the DVC [docs](https://dvc.org/doc/start/data-versioning)!
46 |
--------------------------------------------------------------------------------
/docs/integrations/githubactions.md:
--------------------------------------------------------------------------------
1 | # GitHub Actions
2 |
3 | Automate, customize, and execute your software development workflows right in your repository with GitHub Actions.
4 |
5 | !!! info
6 |
7 | The template offers workflows to automatically run tests and pre-commits on pull requests, publish on PyPi and the docs on GitHub Pages on release.
8 |
--------------------------------------------------------------------------------
/docs/integrations/hydra.md:
--------------------------------------------------------------------------------
1 |
2 | # Hydra
3 |
4 | Hydra is an open-source Python framework that simplifies the development of research and other complex applications. The key feature is the ability to dynamically create a hierarchical configuration by composition and override it through config files and the command line. The name Hydra comes from its ability to run multiple similar jobs - much like a Hydra with multiple heads.
5 |
6 | The basic functionalities are intuitive: it is enough to change the configuration files in `conf/*` according to your preferences. Everything will be logged in `wandb` automatically.
7 |
8 | Consider creating new root configurations `conf/myawesomeexp.yaml` instead of always using the default `conf/default.yaml`.
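Such a root configuration usually just composes the existing config groups and overrides a few values; a minimal sketch (the overridden keys are illustrative):

```yaml
# conf/myawesomeexp.yaml -- illustrative sketch
defaults:
  - default      # start from the existing root configuration
  - _self_       # let the overrides below take precedence

train:
  seed_index: 2  # experiment-specific overrides
```

It can then be selected at launch time with `--config-name myawesomeexp`.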
9 |
10 |
11 | ## Multi-run
12 |
13 | You can easily perform hyperparameters [sweeps](https://hydra.cc/docs/advanced/override_grammar/extended), which override the configuration defined in `/conf/*`.
14 |
15 | The easiest one is the grid search. It executes the code with every possible combination of the specified hyperparameters:
16 |
17 | ```bash
18 | python src/run.py -m optim.optimizer.lr=0.02,0.002,0.0002 optim.lr_scheduler.T_mult=1,2 optim.optimizer.weight_decay=0,1e-5
19 | ```
20 |
21 | You can explore aggregate statistics or compare and analyze each run in the W&B dashboard.
22 |
23 |
24 | !!! info
25 |
26 | We recommend going through at least the [Basic Tutorial](https://hydra.cc/docs/tutorials/basic/your_first_app/simple_cli); keep in mind that Hydra builds on top of [OmegaConf](https://omegaconf.readthedocs.io/en/latest/index.html).
27 |
--------------------------------------------------------------------------------
/docs/integrations/lightning.md:
--------------------------------------------------------------------------------
1 |
2 | # PyTorch Lightning
3 |
4 | Lightning makes coding complex networks simple.
5 | It is not a high-level framework like `keras`, but it enforces a neat code organization and encapsulation.
6 |
7 | You should be somewhat familiar with PyTorch and [PyTorch Lightning](https://pytorch-lightning.readthedocs.io/en/stable/index.html) before using this template.
8 |
9 | 
10 |
--------------------------------------------------------------------------------
/docs/integrations/mkdocs.md:
--------------------------------------------------------------------------------
1 | # MkDocs
2 |
3 | MkDocs is a fast, simple and downright gorgeous static site generator that's geared towards building project documentation.
4 |
5 | Documentation source files are written in Markdown, and configured with a single YAML configuration file.
6 |
7 | ## Material for MkDocs
8 |
9 | [Material for MkDocs](https://squidfunk.github.io/mkdocs-material/) is a theme for MkDocs, a static site generator geared towards (technical) project documentation.
10 |
11 | !!! hint
12 |
13 | The template comes with Material for MkDocs already configured,
14 | to create your documentation you only need to write markdown files and define the `nav`.
15 |
16 | See the [Documentation](https://grok-ai.github.io/nn-template/latest/features/docs/) page to get started!
17 |
--------------------------------------------------------------------------------
/docs/integrations/streamlit.md:
--------------------------------------------------------------------------------
1 |
2 | # Streamlit
3 | [Streamlit](https://docs.streamlit.io/) is an open-source Python library that makes
4 | it easy to create and share beautiful, custom web apps for machine learning and data science.
5 |
6 | In just a few minutes, you can build and deploy powerful data apps to:
7 |
8 | - **Explore** your data
9 | - **Interact** with your model
10 | - **Analyze** your model behavior and input sensitivity
11 | - **Showcase** your prototype with [awesome web apps](https://streamlit.io/gallery)
12 |
13 | Moreover, Streamlit enables interactive development with automatic rerun on file changes.
14 |
15 | 
16 |
17 |
18 | !!! info
19 |
20 | Launch a minimal app with `PYTHONPATH=. streamlit run src/ui/run.py`. There is a built-in function to restore a model checkpoint stored on W&B, with automatic download if the checkpoint is not present on the local machine.
21 |
--------------------------------------------------------------------------------
/docs/integrations/wandb.md:
--------------------------------------------------------------------------------
1 |
2 | # Weights and Biases
3 |
4 | Weights & Biases helps you keep track of your machine learning projects. Use tools to log hyperparameters and output metrics from your runs, then visualize and compare results and quickly share findings with your colleagues.
5 |
6 | [This](https://wandb.ai/gladia/nn-template?workspace=user-lucmos) is an example of a simple dashboard.
7 |
8 | ## Quickstart
9 |
10 | Log in to your `wandb` account by running `wandb login` once.
11 | Configure the logging in `conf/logging/*`.
12 |
13 | !!! info
14 |
15 | Read more in the [docs](https://docs.wandb.ai/). Particularly useful is the [`log` method](https://docs.wandb.ai/library/log), accessible from inside a PyTorch Lightning module with `self.logger.experiment.log`.
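For instance, inside a LightningModule you can mix standard Lightning logging with direct calls to the underlying W&B run; the module below is a toy sketch:

```python
import pytorch_lightning as pl
import torch
import torch.nn.functional as F


class ToyModule(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(8, 1)

    def training_step(self, batch, batch_idx):
        x, y = batch
        loss = F.mse_loss(self.layer(x), y)
        self.log("loss/train", loss)                               # standard Lightning logging, forwarded to W&B
        self.logger.experiment.log({"batches_seen": batch_idx})    # direct access to the underlying wandb run
        return loss

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=1e-3)
```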
16 |
--------------------------------------------------------------------------------
/docs/overrides/main.html:
--------------------------------------------------------------------------------
1 | {% extends "base.html" %}
2 |
3 | {% block outdated %}
4 | You're not viewing the latest version.
5 | <a href="{{ '../' ~ base_url }}">
6 | <strong>Click here to go to latest.</strong>
7 | </a>
8 | {% endblock %}
9 |
--------------------------------------------------------------------------------
/docs/papers.md:
--------------------------------------------------------------------------------
1 | ---
2 | hide:
3 | - navigation
4 | - toc
5 | ---
6 |
7 | # Scientific Papers based on nn-template
8 |
9 | The following papers acknowledge the adoption of NN Template:
10 |
11 | !!! abstract "arXiv 2022"
12 |
13 | **Metric Based Few-Shot Graph Classification**
14 |
15 | *Donato Crisostomi, Simone Antonelli, Valentino Maiorca, Luca Moschella, Riccardo Marin, Emanuele Rodolà*
16 |
17 | [crisostomi/metric-few-shot-graph](https://github.com/crisostomi/metric-few-shot-graph)
18 |
19 | !!! abstract "Computer Graphics Forum: CGF 2022"
20 |
21 | **Learning Spectral Unions of Partial Deformable 3D Shapes**
22 |
23 | *Luca Moschella, Simone Melzi, Luca Cosmo, Filippo Maggioli, Or Litany, Maks Ovsjanikov, Leonidas Guibas, Emanuele Rodolà*
24 |
25 | [lucmos/spectral-unions](https://github.com/lucmos/spectral-unions)
26 |
27 | !!! abstract "Findings of the Association for Computational Linguistics: EMNLP 2021"
28 |
29 | **WikiNEuRal: Combined Neural and Knowledge-based Silver Data Creation for Multilingual NER**
30 |
31 | *Simone Tedeschi, Valentino Maiorca, Niccolò Campolungo, Francesco Cecconi, and Roberto Navigli*
32 |
33 | [Babelscape/wikineural](https://github.com/Babelscape/wikineural)
34 |
35 | !!! abstract "Findings of the Association for Computational Linguistics: EMNLP 2021"
36 |
37 | **Named Entity Recognition for Entity Linking: What Works and What's Next.**
38 |
39 | *Simone Tedeschi, Simone Conia, Francesco Cecconi, and Roberto Navigli*
40 |
41 | [Babelscape/ner4el](https://github.com/Babelscape/ner4el)
42 |
43 | ---
44 |
45 | !!! tip "Please let us know if your paper uses nn-template too, and we'll add it to the list!"
--------------------------------------------------------------------------------
/docs/project-structure/conf.md:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/grok-ai/nn-template/8ba02bba8f015e1eb7efb0d2ab8c9d433bd1c431/docs/project-structure/conf.md
--------------------------------------------------------------------------------
/docs/project-structure/index.md:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/grok-ai/nn-template/8ba02bba8f015e1eb7efb0d2ab8c9d433bd1c431/docs/project-structure/index.md
--------------------------------------------------------------------------------
/docs/project-structure/structure.md:
--------------------------------------------------------------------------------
1 |
2 | # Structure
3 |
4 | ```bash
5 | .
6 | ├── conf
7 | │ ├── default.yaml
8 | │ ├── hydra
9 | │ │ └── default.yaml
10 | │ ├── nn
11 | │ │ └── default.yaml
12 | │ └── train
13 | │ └── default.yaml
14 | ├── data
15 | │ └── .gitignore
16 | ├── docs
17 | │ ├── index.md
18 | │ └── overrides
19 | │ └── main.html
20 | ├── .editorconfig
21 | ├── .env
22 | ├── .env.template
23 | ├── env.yaml
24 | ├── .flake8
25 | ├── .github
26 | │ └── workflows
27 | │ ├── publish.yml
28 | │ └── test_suite.yml
29 | ├── .gitignore
30 | ├── LICENSE
31 | ├── mkdocs.yml
32 | ├── .pre-commit-config.yaml
33 | ├── pyproject.toml
34 | ├── README.md
35 | ├── setup.cfg
36 | ├── setup.py
37 | ├── src
38 | │ └── awesome_project
39 | │ ├── data
40 | │ │ ├── datamodule.py
41 | │ │ ├── dataset.py
42 | │ │ └── __init__.py
43 | │ ├── __init__.py
44 | │ ├── modules
45 | │ │ ├── __init__.py
46 | │ │ └── module.py
47 | │ ├── pl_modules
48 | │ │ ├── __init__.py
49 | │ │ └── pl_module.py
50 | │ ├── run.py
51 | │ └── ui
52 | │ ├── __init__.py
53 | │ └── run.py
54 | └── tests
55 | ├── conftest.py
56 | ├── __init__.py
57 | ├── test_checkpoint.py
58 | ├── test_configuration.py
59 | ├── test_nn_core_integration.py
60 | ├── test_resume.py
61 | ├── test_storage.py
62 | └── test_training.py
63 | ```
--------------------------------------------------------------------------------
/hooks/post_gen_project.py:
--------------------------------------------------------------------------------
1 | import shutil
2 | import subprocess
3 | import sys
4 | import textwrap
5 | from dataclasses import dataclass, field
6 | from distutils.util import strtobool
7 | from typing import Dict, List, Optional
8 |
9 |
10 | def initialize_env_variables(
11 | env_file: str = ".env", env_file_template: str = ".env.template"
12 | ) -> None:
13 | """Initialize the .env file"""
14 | shutil.copy(src=env_file_template, dst=env_file)
15 |
16 |
17 | def bool_query(question: str, default: Optional[bool] = None) -> bool:
18 | """Ask a yes/no question via input() and return their boolean answer.
19 |
20 | Args:
21 | question: is a string that is presented to the user.
22 | default: is the presumed answer if the user just hits Enter.
23 |
24 | Returns:
25 | the boolean representation of the user answer, or the default value if present.
26 | """
27 | if default is None:
28 | prompt = " [y/n] "
29 | elif default:
30 | prompt = " [Y/n] "
31 | else:
32 | prompt = " [y/N] "
33 |
34 | while True:
35 | sys.stdout.write(question + prompt)
36 | choice = input().lower()
37 |
38 | if default is not None and not choice:
39 | return default
40 |
41 | try:
42 | return strtobool(choice)
43 | except ValueError:
44 | sys.stdout.write("Please respond with 'yes' or 'no' " "(or 'y' or 'n').\n")
45 |
46 |
47 | @dataclass
48 | class Dependency:
49 | expected: bool
50 | id: str
51 |
52 |
53 | @dataclass
54 | class Query:
55 | id: str
56 | interactive: bool
57 | default: bool
58 | prompt: str
59 | command: str
60 | autorun: bool
61 | dependencies: List[Dependency] = field(default_factory=list)
62 |
63 |
64 | SETUP_COMMANDS: List[Query] = [
65 | Query(
66 | id="git_init",
67 | interactive=True,
68 | default=True,
69 | prompt="Initializing git repository...",
70 | command="git init\n"
71 | "git add --all\n"
72 | 'git commit -m "Initialize project from nn-template={{ cookiecutter.__version }}"',
73 | autorun=True,
74 | ),
75 | Query(
76 | id="git_remote",
77 | interactive=True,
78 | default=True,
79 | prompt="Adding an existing git remote...\n(You should create the remote from the web UI before proceeding!)",
80 | command="git remote add origin git@github.com:{{ cookiecutter.github_user }}/{{ cookiecutter.repository_name }}.git",
81 | autorun=True,
82 | dependencies=[
83 | Dependency(id="git_init", expected=True),
84 | ],
85 | ),
86 | Query(
87 | id="git_push_main",
88 | interactive=True,
89 | default=True,
90 | prompt="Pushing default branch to existing remote...",
91 | command="git push -u origin HEAD",
92 | autorun=True,
93 | dependencies=[
94 | Dependency(id="git_remote", expected=True),
95 | ],
96 | ),
97 | Query(
98 | id="conda_env",
99 | interactive=True,
100 | default=True,
101 | prompt="Creating conda environment...",
102 | command="conda env create -f env.yaml",
103 | autorun=True,
104 | ),
105 | Query(
106 | id="precommit_install",
107 | interactive=True,
108 | default=True,
109 | prompt="Installing pre-commits...",
110 | command="conda run -n {{ cookiecutter.conda_env_name }} pre-commit install",
111 | autorun=True,
112 | dependencies=[
113 | Dependency(id="git_init", expected=True),
114 | Dependency(id="conda_env", expected=True),
115 | ],
116 | ),
117 | Query(
118 | id="mike_init",
119 | interactive=True,
120 | default=True,
121 | prompt="Initializing gh-pages branch for GitHub Pages...",
122 | command="conda run -n {{ cookiecutter.conda_env_name }} mike deploy --update-aliases 0.0 latest\n"
123 | "conda run -n {{ cookiecutter.conda_env_name }} mike set-default latest",
124 | autorun=True,
125 | dependencies=[
126 | Dependency(id="conda_env", expected=True),
127 | Dependency(id="git_init", expected=True),
128 | ],
129 | ),
130 | Query(
131 | id="mike_push",
132 | interactive=True,
133 | default=True,
134 | prompt="Pushing 'gh-pages' branch to existing remote...",
135 | command="git push origin gh-pages",
136 | autorun=True,
137 | dependencies=[
138 | Dependency(id="mike_init", expected=True),
139 | Dependency(id="git_remote", expected=True),
140 | ],
141 | ),
142 | Query(
143 | id="conda_activate",
144 | interactive=False,
145 | default=True,
146 | prompt="Activate your conda environment with:",
147 | command="cd {{ cookiecutter.repository_name }}\n"
148 | "conda activate {{ cookiecutter.conda_env_name }}\n"
149 | "pytest -v",
150 | autorun=False,
151 | dependencies=[
152 | Dependency(id="conda_env", expected=True),
153 | ],
154 | ),
155 | ]
156 |
157 |
158 | def should_execute_query(query: Query, answers: Dict[str, bool]) -> bool:
159 | if not query.dependencies:
160 | return True
161 | return all(
162 | dependency.expected == answers.get(dependency.id, False)
163 | for dependency in query.dependencies
164 | )
165 |
166 |
167 | def setup(setup_commands) -> None:
168 | answers: Dict[str, bool] = {}
169 |
170 | for query in setup_commands:
171 | assert query.id not in answers
172 |
173 | if should_execute_query(query=query, answers=answers):
174 | if query.interactive:
175 | answers[query.id] = bool_query(
176 | question=f"\n"
177 | f"{query.prompt}\n"
178 | f"\n"
179 | f'{textwrap.indent(query.command, prefix=" ")}\n'
180 | f"\n"
181 | f"Execute?",
182 | default=query.default,
183 | )
184 | else:
185 | print(
186 | f"\n"
187 | f"{query.prompt}\n"
188 | f"\n"
189 | f'{textwrap.indent(query.command, prefix=" ")}\n'
190 | )
191 | answers[query.id] = True
192 |
193 | if answers[query.id] and (query.interactive or query.autorun):
194 | try:
195 | subprocess.run(
196 | query.command,
197 | check=True,
198 | text=True,
199 | shell=True,
200 | )
201 | except subprocess.CalledProcessError:
202 | answers[query.id] = False
203 | print()
204 |
205 |
206 | initialize_env_variables()
207 | setup(setup_commands=SETUP_COMMANDS)
208 |
209 | print(
210 | "\nYou are all set!\n\n"
211 | "Remember that if you use PyCharm, you must:\n"
212 | ' - Mark the "src" directory as "Sources Root".\n'
213 | ' - Enable "Emulate terminal in output console" in the run configuration.\n'
214 | )
215 | print(
216 | "Remember to:\n"
217 | " - Ensure the GitHub Actions in the repository are enabled.\n"
218 | " - Ensure the Github Pages are enabled to auto-publish the docs on each release."
219 | )
220 |
221 | print("Have fun! :]")
222 |
--------------------------------------------------------------------------------
/hooks/pre_gen_project.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/grok-ai/nn-template/8ba02bba8f015e1eb7efb0d2ab8c9d433bd1c431/hooks/pre_gen_project.py
--------------------------------------------------------------------------------
/mkdocs.yml:
--------------------------------------------------------------------------------
1 | site_name: NN Template
2 | site_description: Generic template to bootstrap your PyTorch project with PyTorch Lightning, Hydra, W&B, DVC, and Streamlit.
3 | repo_url: https://github.com/grok-ai/nn-template
4 | copyright: Copyright © 2021 - 2022 Valentino Maiorca | Luca Moschella
5 | nav:
6 | - Home: index.md
7 | - Getting started:
8 | - getting-started/index.md
9 | - Generating your project: getting-started/generation.md
10 | - Features:
11 | - Core: features/nncore.md
12 | - Restore: features/restore.md
13 | - Metadata: features/metadata.md
14 | - Tags: features/tags.md
15 | - Docs: features/docs.md
16 | - Tests: features/tests.md
17 | - Storage: features/storage.md
18 | - Determinism: features/determinism.md
19 | - Fast dev run: features/fastdevrun.md
20 | - Environment variables: features/envvars.md
21 | - CI/CD: features/cicd.md
22 | - Best practices:
23 | - Python Environment: features/conda.md
24 | - Tooling: features/bestpractices.md
25 | # - Project Structure:
26 | # - project-structure/index.md
27 | # - Structure: project-structure/structure.md
28 | # - Conf: project-structure/conf.md
29 | # - NN: project-structure/conf.md
30 | # - Train: project-structure/conf.md
31 | - Integrations:
32 | - PyTorch Lightning: integrations/lightning.md
33 | - Hydra: integrations/hydra.md
34 | - Weights & Biases: integrations/wandb.md
35 | - Streamlit: integrations/streamlit.md
36 | - MkDocs: integrations/mkdocs.md
37 | - DVC: integrations/dvc.md
38 | - GitHub Actions: integrations/githubactions.md
39 | - Publications: papers.md
40 | - Changelog:
41 | - changelog/index.md
42 | - Upgrade: changelog/upgrade.md
43 |
44 | theme:
45 | name: material
46 | custom_dir: docs/overrides
47 | icon:
48 | repo: fontawesome/brands/github
49 |
50 | features:
51 | - content.code.annotate
52 | - navigation.indexes
53 | - navigation.instant
54 | - navigation.sections
55 | - navigation.tabs
56 | - navigation.tabs.sticky
57 | - navigation.top
58 | - navigation.tracking
59 | - search.highlight
60 | - search.share
61 | - search.suggest
62 |
63 | palette:
64 | - scheme: default
65 | primary: light green
66 | accent: green
67 | toggle:
68 | icon: material/weather-night
69 | name: Switch to dark mode
70 | - scheme: slate
71 | primary: green
72 | accent: green
73 | toggle:
74 | icon: material/weather-sunny
75 | name: Switch to light mode
76 |
77 | # Extensions
78 | markdown_extensions:
79 | - abbr
80 | - admonition
81 | - attr_list
82 | - def_list
83 | - footnotes
84 | - meta
85 | - md_in_html
86 | - toc:
87 | permalink: true
88 | - tables
89 | - pymdownx.arithmatex:
90 | generic: true
91 | - pymdownx.betterem:
92 | smart_enable: all
93 | - pymdownx.caret
94 | - pymdownx.mark
95 | - pymdownx.tilde
96 | - pymdownx.critic
97 | - pymdownx.details
98 | - pymdownx.highlight:
99 | anchor_linenums: true
100 | - pymdownx.superfences:
101 | custom_fences:
102 | - name: mermaid
103 | class: mermaid
104 | format: !!python/name:pymdownx.superfences.fence_code_format
105 | - pymdownx.inlinehilite
106 | - pymdownx.keys
107 | - pymdownx.smartsymbols
108 | - pymdownx.snippets
109 | - pymdownx.tabbed:
110 | alternate_style: true
111 | - pymdownx.tasklist:
112 | custom_checkbox: true
113 | - pymdownx.emoji:
114 | emoji_index: !!python/name:materialx.emoji.twemoji
115 | emoji_generator: !!python/name:materialx.emoji.to_svg
116 |
117 | extra_javascript:
118 | - javascripts/mathjax.js
119 | - https://polyfill.io/v3/polyfill.min.js?features=es6
120 | - https://cdn.jsdelivr.net/npm/mathjax@3/es5/tex-mml-chtml.js
121 |
122 | plugins:
123 | - search
124 |
125 | extra:
126 | generator: true
127 | version:
128 | provider: mike
129 |
--------------------------------------------------------------------------------
/{{ cookiecutter.repository_name }}/.editorconfig:
--------------------------------------------------------------------------------
1 | # EditorConfig is awesome: https://EditorConfig.org
2 |
3 | # top-most EditorConfig file
4 | root = true
5 |
6 | [*]
7 | end_of_line = lf
8 | insert_final_newline = true
9 | max_line_length = 120
10 | trim_trailing_whitespace = true
11 |
12 | # 4 space indentation
13 | [*.py]
14 | indent_style = space
15 | indent_size = 4
16 | charset = utf-8
17 |
18 | [*.{yaml,yml}]
19 | indent_size = 2
20 |
--------------------------------------------------------------------------------
/{{ cookiecutter.repository_name }}/.env.template:
--------------------------------------------------------------------------------
1 | # .env.template is a template for .env file that can be versioned.
2 |
3 | # Set to 1 to show full stack trace on error, 0 to hide it
4 | HYDRA_FULL_ERROR=1
5 |
6 | # Configure where huggingface_hub will locally store data.
7 | HF_HOME="~/.cache/huggingface"
8 |
9 | # Configure the User Access Token to authenticate to the Hub
10 | # HUGGING_FACE_HUB_TOKEN=
11 |
--------------------------------------------------------------------------------
/{{ cookiecutter.repository_name }}/.flake8:
--------------------------------------------------------------------------------
1 | [flake8]
2 | # https://github.com/pytorch/pytorch/blob/master/.flake8
3 | max-line-length = 120
4 | select = B,C,E,F,P,T4,W,B9
5 | extend-ignore = E203, E501
6 |
--------------------------------------------------------------------------------
/{{ cookiecutter.repository_name }}/.github/workflows/publish.yml:
--------------------------------------------------------------------------------
1 | name: Publish
2 |
3 | on:
4 | release:
5 | types:
6 | - created
7 |
8 | env:
9 | CACHE_NUMBER: 0 # increase to reset cache manually
10 | CONDA_ENV_FILE: './env.yaml'
11 | CONDA_ENV_NAME: '{{ cookiecutter.conda_env_name }}'
12 | {% raw %}
13 | HUGGING_FACE_HUB_TOKEN: ${{secrets.HUGGING_FACE_HUB_TOKEN}}
14 |
15 | jobs:
16 | build:
17 | strategy:
18 | fail-fast: false
19 | matrix:
20 | python-version: ['3.9']
21 | include:
22 | - os: ubuntu-20.04
23 | label: linux-64
24 | prefix: /usr/share/miniconda3/envs/
25 |
26 | name: ${{ matrix.label }}-py${{ matrix.python-version }}
27 | runs-on: ${{ matrix.os }}
28 |
29 | steps:
30 | - name: Parametrize conda env name
31 | run: echo "PY_CONDA_ENV_NAME=${{ env.CONDA_ENV_NAME }}-${{ matrix.python-version }}" >> $GITHUB_ENV
32 | - name: echo conda env name
33 | run: echo ${{ env.PY_CONDA_ENV_NAME }}
34 |
35 | - name: Parametrize conda prefix
36 | run: echo "PY_PREFIX=${{ matrix.prefix }}${{ env.PY_CONDA_ENV_NAME }}" >> $GITHUB_ENV
37 | - name: echo conda prefix
38 | run: echo ${{ env.PY_PREFIX }}
39 |
40 | # extract the major.minor version from the release tag
41 | - name: Set release notes tag
42 | run: |
43 | export RELEASE_TAG_VERSION=${{ github.event.release.tag_name }}
44 | echo "RELEASE_TAG_VERSION=${RELEASE_TAG_VERSION%.*}">> $GITHUB_ENV
45 |
46 | - name: Echo release notes tag
47 | run: |
48 | echo "${RELEASE_TAG_VERSION}"
49 |
50 | - uses: actions/checkout@v2
51 | with:
52 | fetch-depth: 0
53 |
54 | # Remove the python version pin from env.yaml, which could conflict with the matrix python version
55 | - name: Remove explicit python version from the environment
56 | shell: bash -l {0}
57 | run: |
58 | sed -Ei '/^\s*-?\s*python\s*([#=].*)?$/d' ${{ env.CONDA_ENV_FILE }}
59 | cat ${{ env.CONDA_ENV_FILE }}
60 |
61 | - name: Setup Mambaforge
62 | uses: conda-incubator/setup-miniconda@v2
63 | with:
64 | miniforge-variant: Mambaforge
65 | miniforge-version: latest
66 | activate-environment: ${{ env.PY_CONDA_ENV_NAME }}
67 | python-version: ${{ matrix.python-version }}
68 | use-mamba: true
69 |
70 | - uses: actions/cache@v2
71 | name: Conda cache
72 | with:
73 | path: ${{ env.PY_PREFIX }}
74 | key: ${{ matrix.label }}-conda-${{ matrix.python-version }}-${{ env.CACHE_NUMBER }}-${{ env.PY_CONDA_ENV_NAME }}-${{ hashFiles(env.CONDA_ENV_FILE) }}-${{hashFiles('./setup.cfg') }}
75 | id: conda_cache
76 |
77 | - uses: actions/cache@v2
78 | name: Pip cache
79 | with:
80 | path: ~/.cache/pip
81 | key: ${{ matrix.label }}-pip-${{ matrix.python-version }}-${{ env.CACHE_NUMBER }}-${{ env.PY_CONDA_ENV_NAME }}-${{ hashFiles(env.CONDA_ENV_FILE) }}-${{hashFiles('./setup.cfg') }}
82 |
83 | - uses: actions/cache@v2
84 | name: Pre-commit cache
85 | with:
86 | path: ~/.cache/pre-commit
87 | key: ${{ matrix.label }}-pre-commit-${{ hashFiles('.pre-commit-config.yaml') }}-${{ matrix.python-version }}-${{ env.CACHE_NUMBER }}-${{ env.PY_CONDA_ENV_NAME }}-${{ hashFiles(env.CONDA_ENV_FILE) }}-${{hashFiles('./setup.cfg') }}
88 |
89 | # Ensure the hack for the python version worked
90 | - name: Ensure we have the right Python
91 | shell: bash -l {0}
92 | run: |
93 | echo "Installed Python: $(python --version)"
94 | echo "Expected: ${{ matrix.python-version }}"
95 | python --version | grep "Python ${{ matrix.python-version }}"
96 |
97 | # https://stackoverflow.com/questions/70520120/attributeerror-module-setuptools-distutils-has-no-attribute-version
98 | # https://github.com/pytorch/pytorch/pull/69904
99 | - name: Downgrade setuptools due to a bug in PyTorch 1.10.1
100 | shell: bash -l {0}
101 | run: |
102 | pip install setuptools==59.5.0 --upgrade
103 |
104 | - name: Update conda environment
105 | run: mamba env update -n ${{ env.PY_CONDA_ENV_NAME }} -f ${{ env.CONDA_ENV_FILE }}
106 | if: steps.conda_cache.outputs.cache-hit != 'true'
107 |
108 | # Update the pip environment only on a conda cache hit (a fresh conda env update already installs the pip deps)
109 | - name: Update pip environment
110 | shell: bash -l {0}
111 | run: pip install -e ".[dev]"
112 | if: steps.conda_cache.outputs.cache-hit == 'true'
113 |
114 | - run: pip3 list
115 | shell: bash -l {0}
116 | - run: mamba info
117 | - run: mamba list
118 |
119 | # Ensure the hack for the python version worked
120 | - name: Ensure we have the right Python
121 | shell: bash -l {0}
122 | run: |
123 | echo "Installed Python: $(python --version)"
124 | echo "Expected: ${{ matrix.python-version }}"
125 | python --version | grep "Python ${{ matrix.python-version }}"
126 |
127 | - name: Run pre-commits
128 | shell: bash -l {0}
129 | run: |
130 | pre-commit install
131 | pre-commit run -v --all-files --show-diff-on-failure
132 |
133 | - name: Test with pytest
134 | shell: bash -l {0}
135 | run: |
136 | pytest -v
137 |
138 | - name: Build docs website
139 | shell: bash -l {0}
140 | run: |
141 | git config user.name ci-bot
142 | git config user.email ci-bot@ci.com
143 | mike deploy --rebase --push --update-aliases ${RELEASE_TAG_VERSION} latest
144 |
145 | # Uncomment to publish on PyPI on release
146 | # - name: Build SDist and wheel
147 | # run: pipx run build
148 | #
149 | # - name: Check metadata
150 | # run: pipx run twine check dist/*
151 | #
152 | # - name: Publish distribution 📦 to PyPI
153 | # uses: pypa/gh-action-pypi-publish@release/v1
154 | # with:
155 | # user: __token__
156 | # password: ${{ secrets.PYPI_API_TOKEN }}
157 | #
158 | #{% endraw %}
159 |
--------------------------------------------------------------------------------
/{{ cookiecutter.repository_name }}/.github/workflows/test_suite.yml:
--------------------------------------------------------------------------------
1 | name: Test Suite
2 |
3 | on:
4 | push:
5 | branches:
6 | - main
7 | - develop
8 |
9 | pull_request:
10 | types:
11 | - opened
12 | - reopened
13 | - synchronize
14 |
15 | env:
16 | CACHE_NUMBER: 1 # increase to reset cache manually
17 | CONDA_ENV_FILE: './env.yaml'
18 | CONDA_ENV_NAME: '{{ cookiecutter.conda_env_name }}'
19 | {% raw %}
20 | HUGGING_FACE_HUB_TOKEN: ${{secrets.HUGGING_FACE_HUB_TOKEN}}
21 |
22 | jobs:
23 | build:
24 |
25 | strategy:
26 | fail-fast: false
27 | matrix:
28 | python-version: ['3.11']
29 | include:
30 | - os: ubuntu-20.04
31 | label: linux-64
32 | prefix: /usr/share/miniconda3/envs/
33 |
34 | # - os: macos-latest
35 | # label: osx-64
36 | # prefix: /Users/runner/miniconda3/envs/$CONDA_ENV_NAME
37 |
38 | # - os: windows-latest
39 | # label: win-64
40 | # prefix: C:\Miniconda3\envs\$CONDA_ENV_NAME
41 |
42 | name: ${{ matrix.label }}-py${{ matrix.python-version }}
43 | runs-on: ${{ matrix.os }}
44 |
45 | steps:
46 | - name: Parametrize conda env name
47 | run: echo "PY_CONDA_ENV_NAME=${{ env.CONDA_ENV_NAME }}-${{ matrix.python-version }}" >> $GITHUB_ENV
48 | - name: echo conda env name
49 | run: echo ${{ env.PY_CONDA_ENV_NAME }}
50 |
51 | - name: Parametrize conda prefix
52 | run: echo "PY_PREFIX=${{ matrix.prefix }}${{ env.PY_CONDA_ENV_NAME }}" >> $GITHUB_ENV
53 | - name: echo conda prefix
54 | run: echo ${{ env.PY_PREFIX }}
55 |
56 | - uses: actions/checkout@v2
57 |
58 | # Remove the python version pin from env.yaml, which could conflict with the matrix python version
59 | - name: Remove explicit python version from the environment
60 | shell: bash -l {0}
61 | run: |
62 | sed -Ei '/^\s*-?\s*python\s*([#=].*)?$/d' ${{ env.CONDA_ENV_FILE }}
63 | cat ${{ env.CONDA_ENV_FILE }}
64 |
65 | # Install torch cpu-only
66 | - name: Install torch cpu only
67 | shell: bash -l {0}
68 | run: |
69 | sed -i '/nvidia\|cuda/d' ${{ env.CONDA_ENV_FILE }}
70 | cat ${{ env.CONDA_ENV_FILE }}
71 |
72 | - name: Setup Mambaforge
73 | uses: conda-incubator/setup-miniconda@v2
74 | with:
75 | miniforge-variant: Mambaforge
76 | miniforge-version: latest
77 | activate-environment: ${{ env.PY_CONDA_ENV_NAME }}
78 | python-version: ${{ matrix.python-version }}
79 | use-mamba: true
80 |
81 | - uses: actions/cache@v2
82 | name: Conda cache
83 | with:
84 | path: ${{ env.PY_PREFIX }}
85 | key: ${{ matrix.label }}-conda-${{ matrix.python-version }}-${{ env.CACHE_NUMBER }}-${{ env.PY_CONDA_ENV_NAME }}-${{ hashFiles(env.CONDA_ENV_FILE) }}-${{hashFiles('./setup.cfg') }}
86 | id: conda_cache
87 |
88 | - uses: actions/cache@v2
89 | name: Pip cache
90 | with:
91 | path: ~/.cache/pip
92 | key: ${{ matrix.label }}-pip-${{ matrix.python-version }}-${{ env.CACHE_NUMBER }}-${{ env.PY_CONDA_ENV_NAME }}-${{ hashFiles(env.CONDA_ENV_FILE) }}-${{hashFiles('./setup.cfg') }}
93 |
94 | - uses: actions/cache@v2
95 | name: Pre-commit cache
96 | with:
97 | path: ~/.cache/pre-commit
98 | key: ${{ matrix.label }}-pre-commit-${{ hashFiles('.pre-commit-config.yaml') }}-${{ matrix.python-version }}-${{ env.CACHE_NUMBER }}-${{ env.PY_CONDA_ENV_NAME }}-${{ hashFiles(env.CONDA_ENV_FILE) }}-${{hashFiles('./setup.cfg') }}
99 |
100 | # Ensure the hack for the python version worked
101 | - name: Ensure we have the right Python
102 | shell: bash -l {0}
103 | run: |
104 | echo "Installed Python: $(python --version)"
105 | echo "Expected: ${{ matrix.python-version }}"
106 | python --version | grep "Python ${{ matrix.python-version }}"
107 |
108 |
109 | # https://stackoverflow.com/questions/70520120/attributeerror-module-setuptools-distutils-has-no-attribute-version
110 | # https://github.com/pytorch/pytorch/pull/69904
111 | - name: Downgrade setuptools due to a bug in PyTorch 1.10.1
112 | shell: bash -l {0}
113 | run: |
114 | pip install setuptools==59.5.0 --upgrade
115 |
116 | - name: Update conda environment
117 | run: mamba env update -n ${{ env.PY_CONDA_ENV_NAME }} -f ${{ env.CONDA_ENV_FILE }}
118 | if: steps.conda_cache.outputs.cache-hit != 'true'
119 |
120 | # Update the pip environment only on a conda cache hit (a fresh conda env update already installs the pip deps)
121 | - name: Update pip environment
122 | shell: bash -l {0}
123 | run: pip install -e ".[dev]"
124 | if: steps.conda_cache.outputs.cache-hit == 'true'
125 |
126 | - run: pip3 list
127 | shell: bash -l {0}
128 | - run: mamba info
129 | - run: mamba list
130 |
131 | # Ensure the hack for the python version worked
132 | - name: Ensure we have the right Python
133 | shell: bash -l {0}
134 | run: |
135 | echo "Installed Python: $(python --version)"
136 | echo "Expected: ${{ matrix.python-version }}"
137 | python --version | grep "Python ${{ matrix.python-version }}"
138 |
139 | - name: Run pre-commits
140 | shell: bash -l {0}
141 | run: |
142 | pre-commit install
143 | pre-commit run -v --all-files --show-diff-on-failure
144 |
145 | - name: Test with pytest
146 | shell: bash -l {0}
147 | run: |
148 | pytest -v
149 | #
150 | #{% endraw %}
151 |
--------------------------------------------------------------------------------
/{{ cookiecutter.repository_name }}/.gitignore:
--------------------------------------------------------------------------------
1 | wandb
2 | multirun.yaml
3 | storage
4 |
5 | # ignore the _version.py file
6 | _version.py
7 |
8 | # .gitignore defaults for python and pycharm
9 | .idea
10 |
11 | # Byte-compiled / optimized / DLL files
12 | __pycache__/
13 | *.py[cod]
14 | *$py.class
15 |
16 | # C extensions
17 | *.so
18 |
19 | # Distribution / packaging
20 | .Python
21 | build/
22 | develop-eggs/
23 | dist/
24 | downloads/
25 | eggs/
26 | .eggs/
27 | lib/
28 | lib64/
29 | parts/
30 | sdist/
31 | var/
32 | wheels/
33 | share/python-wheels/
34 | *.egg-info/
35 | .installed.cfg
36 | *.egg
37 | MANIFEST
38 |
39 | # PyInstaller
40 | # Usually these files are written by a python script from a template
41 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
42 | *.manifest
43 | *.spec
44 |
45 | # Installer logs
46 | pip-log.txt
47 | pip-delete-this-directory.txt
48 |
49 | # Unit test / coverage reports
50 | htmlcov/
51 | .tox/
52 | .nox/
53 | .coverage
54 | .coverage.*
55 | .cache
56 | nosetests.xml
57 | coverage.xml
58 | *.cover
59 | *.py,cover
60 | .hypothesis/
61 | .pytest_cache/
62 | cover/
63 |
64 | # Translations
65 | *.mo
66 | *.pot
67 |
68 | # Django stuff:
69 | *.log
70 | local_settings.py
71 | db.sqlite3
72 | db.sqlite3-journal
73 |
74 | # Flask stuff:
75 | instance/
76 | .webassets-cache
77 |
78 | # Scrapy stuff:
79 | .scrapy
80 |
81 | # Sphinx documentation
82 | docs/_build/
83 |
84 | # PyBuilder
85 | .pybuilder/
86 | target/
87 |
88 | # Jupyter Notebook
89 | .ipynb_checkpoints
90 |
91 | # IPython
92 | profile_default/
93 | ipython_config.py
94 |
95 | # pyenv
96 | # For a library or package, you might want to ignore these files since the code is
97 | # intended to run in multiple environments; otherwise, check them in:
98 | # .python-version
99 |
100 | # pipenv
101 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
102 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
103 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
104 | # install all needed dependencies.
105 | #Pipfile.lock
106 |
107 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow
108 | __pypackages__/
109 |
110 | # Celery stuff
111 | celerybeat-schedule
112 | celerybeat.pid
113 |
114 | # SageMath parsed files
115 | *.sage.py
116 |
117 | # Environments
118 | .env
119 | .venv
120 | env/
121 | venv/
122 | ENV/
123 | env.bak/
124 | venv.bak/
125 |
126 | # Spyder project settings
127 | .spyderproject
128 | .spyproject
129 |
130 | # Rope project settings
131 | .ropeproject
132 |
133 | # mkdocs documentation
134 | /site
135 |
136 | # mypy
137 | .mypy_cache/
138 | .dmypy.json
139 | dmypy.json
140 |
141 | # Pyre type checker
142 | .pyre/
143 |
144 | # pytype static type analyzer
145 | .pytype/
146 |
147 | # Cython debug symbols
148 | cython_debug/
149 |
150 | # Covers JetBrains IDEs: IntelliJ, RubyMine, PhpStorm, AppCode, PyCharm, CLion, Android Studio, WebStorm and Rider
151 | # Reference: https://intellij-support.jetbrains.com/hc/en-us/articles/206544839
152 |
153 | # User-specific stuff
154 | .idea/**/workspace.xml
155 | .idea/**/tasks.xml
156 | .idea/**/usage.statistics.xml
157 | .idea/**/dictionaries
158 | .idea/**/shelf
159 |
160 | # Generated files
161 | .idea/**/contentModel.xml
162 |
163 | # Sensitive or high-churn files
164 | .idea/**/dataSources/
165 | .idea/**/dataSources.ids
166 | .idea/**/dataSources.local.xml
167 | .idea/**/sqlDataSources.xml
168 | .idea/**/dynamic.xml
169 | .idea/**/uiDesigner.xml
170 | .idea/**/dbnavigator.xml
171 |
172 | # Gradle
173 | .idea/**/gradle.xml
174 | .idea/**/libraries
175 |
176 | # Gradle and Maven with auto-import
177 | # When using Gradle or Maven with auto-import, you should exclude module files,
178 | # since they will be recreated, and may cause churn. Uncomment if using
179 | # auto-import.
180 | # .idea/artifacts
181 | # .idea/compiler.xml
182 | # .idea/jarRepositories.xml
183 | # .idea/modules.xml
184 | # .idea/*.iml
185 | # .idea/modules
186 | # *.iml
187 | # *.ipr
188 |
189 | # CMake
190 | cmake-build-*/
191 |
192 | # Mongo Explorer plugin
193 | .idea/**/mongoSettings.xml
194 |
195 | # File-based project format
196 | *.iws
197 |
198 | # IntelliJ
199 | out/
200 |
201 | # mpeltonen/sbt-idea plugin
202 | .idea_modules/
203 |
204 | # JIRA plugin
205 | atlassian-ide-plugin.xml
206 |
207 | # Cursive Clojure plugin
208 | .idea/replstate.xml
209 |
210 | # Crashlytics plugin (for Android Studio and IntelliJ)
211 | com_crashlytics_export_strings.xml
212 | crashlytics.properties
213 | crashlytics-build.properties
214 | fabric.properties
215 |
216 | # Editor-based Rest Client
217 | .idea/httpRequests
218 |
219 | # Android studio 3.1+ serialized cache file
220 | .idea/caches/build_file_checksums.ser
221 |
--------------------------------------------------------------------------------
/{{ cookiecutter.repository_name }}/.pre-commit-config.yaml:
--------------------------------------------------------------------------------
1 | # See https://pre-commit.com for more information
2 | # See https://pre-commit.com/hooks.html for more hooks
3 | repos:
4 | - repo: https://github.com/pre-commit/pre-commit-hooks
5 | rev: v4.1.0
6 | hooks:
7 | - id: check-added-large-files # prevents giant files from being committed.
8 | args: ['--maxkb=4096']
9 | - id: check-ast # simply checks whether the files parse as valid python.
10 | # - id: check-byte-order-marker # forbids files which have a utf-8 byte-order marker.
11 | # - id: check-builtin-literals # requires literal syntax when initializing empty or zero python builtin types.
12 | # - id: check-case-conflict # checks for files that would conflict in case-insensitive filesystems.
13 | - id: check-docstring-first # checks a common error of defining a docstring after code.
14 | - id: check-executables-have-shebangs # ensures that (non-binary) executables have a shebang.
15 | # - id: check-json # checks json files for parseable syntax.
16 | # - id: check-shebang-scripts-are-executable # ensures that (non-binary) files with a shebang are executable.
17 | # - id: pretty-format-json # sets a standard for formatting json files.
18 | - id: check-merge-conflict # checks for files that contain merge conflict strings.
19 | - id: check-symlinks # checks for symlinks which do not point to anything.
20 | - id: check-toml # checks toml files for parseable syntax.
21 | # - id: check-vcs-permalinks # ensures that links to vcs websites are permalinks.
22 | # - id: check-xml # checks xml files for parseable syntax.
23 | - id: check-yaml # checks yaml files for parseable syntax.
24 | - id: debug-statements # checks for debugger imports and py37+ `breakpoint()` calls in python source.
25 | - id: destroyed-symlinks # detects symlinks which are changed to regular files with a content of a path which that symlink was pointing to.
26 | # - id: detect-aws-credentials # detects *your* aws credentials from the aws cli credentials file.
27 | - id: detect-private-key # detects the presence of private keys.
28 | # - id: double-quote-string-fixer # replaces double quoted strings with single quoted strings.
29 | - id: end-of-file-fixer # ensures that a file is either empty, or ends with one newline.
30 | # - id: file-contents-sorter # sorts the lines in specified files (defaults to alphabetical). you must provide list of target files as input in your .pre-commit-config.yaml file.
31 | # - id: fix-byte-order-marker # removes utf-8 byte order marker.
32 | # - id: fix-encoding-pragma # adds # -*- coding: utf-8 -*- to the top of python files.
33 | # - id: forbid-new-submodules # prevents addition of new git submodules.
34 | - id: mixed-line-ending # replaces or checks mixed line ending.
35 | args: ['--fix=no']
36 | # - id: name-tests-test # this verifies that test files are named correctly.
37 | # - id: no-commit-to-branch # don't commit to branch
38 | # - id: requirements-txt-fixer # sorts entries in requirements.txt.
39 | # - id: sort-simple-yaml # sorts simple yaml files which consist only of top-level keys, preserving comments and blocks.
40 | - id: trailing-whitespace # trims trailing whitespace.
41 |
42 | - repo: https://github.com/PyCQA/isort.git
43 | rev: '5.12.0'
44 | hooks:
45 | - id: isort
46 |
47 | - repo: https://github.com/psf/black.git
48 | rev: '23.7.0'
49 | hooks:
50 | - id: black
51 | - id: black-jupyter
52 |
53 | - repo: https://github.com/asottile/blacken-docs.git
54 | rev: '1.16.0'
55 | hooks:
56 | - id: blacken-docs
57 |
58 | - repo: https://github.com/PyCQA/flake8.git
59 | rev: '6.1.0'
60 | hooks:
61 | - id: flake8
62 | additional_dependencies:
63 | - flake8-docstrings==1.7.0
64 |
65 | - repo: https://github.com/pycqa/pydocstyle.git
66 | rev: '6.3.0'
67 | hooks:
68 | - id: pydocstyle
69 | additional_dependencies:
70 | - toml
71 |
72 | - repo: https://github.com/kynan/nbstripout.git
73 | rev: '0.6.1'
74 | hooks:
75 | - id: nbstripout
76 |
77 | - repo: https://github.com/PyCQA/bandit
78 | rev: '1.7.5'
79 | hooks:
80 | - id: bandit
81 | args: ['-c', 'pyproject.toml', '--recursive', 'src']
82 | additional_dependencies:
83 | - toml
84 |
--------------------------------------------------------------------------------
/{{ cookiecutter.repository_name }}/LICENSE:
--------------------------------------------------------------------------------
1 | ../LICENSE
--------------------------------------------------------------------------------
/{{ cookiecutter.repository_name }}/README.md:
--------------------------------------------------------------------------------
1 | # {{ cookiecutter.project_name }}
2 |
3 |
4 |
5 |
6 |
7 |
8 |
9 |
10 |
11 | {{ cookiecutter.project_description }}
12 |
13 |
14 | ## Installation
15 |
16 | ```bash
17 | pip install git+ssh://git@github.com/{{ cookiecutter.github_user }}/{{ cookiecutter.repository_name }}.git
18 | ```
19 |
20 |
21 | ## Quickstart
22 |
23 | [comment]: <> (> Fill me!)
24 |
25 |
26 | ## Development installation
27 |
28 | Setup the development environment:
29 |
30 | ```bash
31 | git clone git@github.com:{{ cookiecutter.github_user }}/{{ cookiecutter.repository_name }}.git
32 | cd {{ cookiecutter.repository_name }}
33 | conda env create -f env.yaml
34 | conda activate {{ cookiecutter.conda_env_name }}
35 | pre-commit install
36 | ```
37 |
38 | Run the tests:
39 |
40 | ```bash
41 | pre-commit run --all-files
42 | pytest -v
43 | ```
44 |
45 |
46 | ### Update the dependencies
47 |
48 | Re-install the project in edit mode:
49 |
50 | ```bash
51 | pip install -e .[dev]
52 | ```
53 |
--------------------------------------------------------------------------------
/{{ cookiecutter.repository_name }}/conf/default.yaml:
--------------------------------------------------------------------------------
1 | # metadata specialised for each experiment
2 | core:
3 | project_name: {{ cookiecutter.repository_name }}
4 | storage_dir: ${oc.env:PROJECT_ROOT}/storage
5 | version: 0.0.1
6 | tags: null
7 |
8 | conventions:
9 | x_key: 'x'
10 | y_key: 'y'
11 |
12 | defaults:
13 | - hydra: default
14 | - nn: default
15 | - train: default
16 | - _self_ # as last argument to allow the override of parameters via this main config
17 | # Uncomment the following line to run jobs in parallel
18 | # - hydra/launcher: joblib
19 |
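20 | # Usage sketch (assuming `run.py` is the Hydra entry point, like the other `@hydra.main` scripts in src/):
21 | # with the joblib launcher enabled, a Hydra multirun such as
22 | #   python src/{{ cookiecutter.package_name }}/run.py -m nn.module.optimizer.lr=1e-3,1e-4
23 | # launches the two runs as parallel jobs instead of sequentially.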
--------------------------------------------------------------------------------
/{{ cookiecutter.repository_name }}/conf/hydra/default.yaml:
--------------------------------------------------------------------------------
1 | ## https://github.com/facebookresearch/hydra/issues/910
2 | # Not changing the working directory
3 | run:
4 | dir: .
5 | sweep:
6 | dir: .
7 | subdir: .
8 |
9 | # Not saving the .hydra directory
10 | output_subdir: null
11 |
12 | job:
13 | env_set:
14 | WANDB_START_METHOD: thread
15 | WANDB_DIR: ${oc.env:PROJECT_ROOT}
16 |
17 | defaults:
18 | # Disable hydra logging configuration, otherwise the basicConfig does not have any effect
19 | - override job_logging: none
20 | - override hydra_logging: none
21 |
--------------------------------------------------------------------------------
/{{ cookiecutter.repository_name }}/conf/nn/data/dataset/vision/mnist.yaml:
--------------------------------------------------------------------------------
1 | # This class defines which dataset to use,
2 | # and also how to split in train/[val]/test.
3 | _target_: {{ cookiecutter.package_name }}.utils.hf_io.load_hf_dataset
4 | name: "mnist"
5 | ref: "mnist"
6 | train_split: train
7 | # val_split: val
8 | val_percentage: 0.1
9 | test_split: test
10 | label_key: label
11 | data_key: image
12 | num_classes: 10
13 | input_shape: [1, 28, 28]
14 | standard_x_key: ${conventions.x_key}
15 | standard_y_key: ${conventions.y_key}
16 | transforms:
17 | _target_: {{ cookiecutter.package_name }}.utils.hf_io.HFTransform
18 | key: ${conventions.x_key}
19 | transform:
20 | _target_: torchvision.transforms.Compose
21 | transforms:
22 | - _target_: torchvision.transforms.ToTensor
23 |
--------------------------------------------------------------------------------
/{{ cookiecutter.repository_name }}/conf/nn/data/default.yaml:
--------------------------------------------------------------------------------
1 | _target_: {{ cookiecutter.package_name }}.data.datamodule.MyDataModule
2 |
3 | val_images_fixed_idxs: [7371, 3963, 2861, 1701, 3172,
4 | 1749, 7023, 1606, 6481, 1377,
5 | 6003, 3593, 3410, 3399, 7277,
6 | 5337, 968, 8206, 288, 1968,
7 | 5677, 9156, 8139, 7660, 7089,
8 | 1893, 3845, 2084, 1944, 3375,
9 | 4848, 8704, 6038, 2183, 7422,
10 | 2682, 6878, 6127, 2941, 5823,
11 | 9129, 1798, 6477, 9264, 476,
12 | 3007, 4992, 1428, 9901, 5388]
13 |
14 | accelerator: ${train.trainer.accelerator}
15 |
16 | num_workers:
17 | train: 4
18 | val: 2
19 | test: 0
20 |
21 | batch_size:
22 | train: 512
23 | val: 128
24 | test: 16
25 |
26 | defaults:
27 | - _self_
28 | - dataset: vision/mnist # pick one of the yamls in nn/data/
29 |
--------------------------------------------------------------------------------
/{{ cookiecutter.repository_name }}/conf/nn/default.yaml:
--------------------------------------------------------------------------------
1 | data: ???
2 |
3 | module:
4 | optimizer:
5 | _target_: torch.optim.Adam
6 | lr: 1e-3
7 | betas: [ 0.9, 0.999 ]
8 | eps: 1e-08
9 | weight_decay: 0
10 |
11 | # lr_scheduler:
12 | # _target_: torch.optim.lr_scheduler.CosineAnnealingWarmRestarts
13 | # T_0: 20
14 | # T_mult: 1
15 | # eta_min: 0
16 | # last_epoch: -1
17 | # verbose: False
18 |
19 |
20 | defaults:
21 | - _self_
22 | - data: default
23 | - module: default
24 |
--------------------------------------------------------------------------------
/{{ cookiecutter.repository_name }}/conf/nn/module/default.yaml:
--------------------------------------------------------------------------------
1 | _target_: {{ cookiecutter.package_name }}.pl_modules.pl_module.MyLightningModule
2 | x_key: ${conventions.x_key}
3 | y_key: ${conventions.y_key}
4 |
5 | defaults:
6 | - _self_
7 | - model: cnn
8 |
--------------------------------------------------------------------------------
/{{ cookiecutter.repository_name }}/conf/nn/module/model/cnn.yaml:
--------------------------------------------------------------------------------
1 | _target_: {{ cookiecutter.package_name }}.modules.module.CNN
2 | input_shape: ${nn.data.dataset.input_shape}
3 |
--------------------------------------------------------------------------------
/{{ cookiecutter.repository_name }}/conf/train/default.yaml:
--------------------------------------------------------------------------------
1 | # reproducibility
2 | seed_index: 0
3 | deterministic: False
4 |
5 | # PyTorch Lightning Trainer https://pytorch-lightning.readthedocs.io/en/stable/common/trainer.html
6 | trainer:
7 | fast_dev_run: False # Enable this for debug purposes
8 | accelerator: 'gpu'
9 | devices: 1
10 | precision: 32
11 | max_epochs: 3
12 | max_steps: 10000
13 | num_sanity_val_steps: 2
14 | gradient_clip_val: 10.0
15 | val_check_interval: 1.0
16 | deterministic: ${train.deterministic}
17 |
18 | restore:
19 | ckpt_or_run_path: null
20 | mode: null # null, finetune, hotstart, continue
21 |
22 | monitor:
23 | metric: 'loss/val'
24 | mode: 'min'
25 |
26 | callbacks:
27 | - _target_: lightning.pytorch.callbacks.EarlyStopping
28 | patience: 42
29 | verbose: False
30 | monitor: ${train.monitor.metric}
31 | mode: ${train.monitor.mode}
32 |
33 | - _target_: lightning.pytorch.callbacks.ModelCheckpoint
34 | save_top_k: 1
35 | verbose: False
36 | monitor: ${train.monitor.metric}
37 | mode: ${train.monitor.mode}
38 |
39 | - _target_: lightning.pytorch.callbacks.LearningRateMonitor
40 | logging_interval: "step"
41 | log_momentum: False
42 |
43 | - _target_: lightning.pytorch.callbacks.progress.tqdm_progress.TQDMProgressBar
44 | refresh_rate: 20
45 |
46 | logging:
47 | upload:
48 | run_files: true
49 | source: true
50 |
51 | logger:
52 | _target_: lightning.pytorch.loggers.WandbLogger
53 |
54 | project: ${core.project_name}
55 | entity: null
56 | log_model: ${..upload.run_files}
57 | mode: 'online'
58 | tags: ${core.tags}
59 |
60 | wandb_watch:
61 | log: 'all'
62 | log_freq: 100
63 |
--------------------------------------------------------------------------------
/{{ cookiecutter.repository_name }}/data/.gitignore:
--------------------------------------------------------------------------------
1 | *
2 |
3 | !.gitignore
4 |
--------------------------------------------------------------------------------
/{{ cookiecutter.repository_name }}/docs/index.md:
--------------------------------------------------------------------------------
1 | ../README.md
--------------------------------------------------------------------------------
/{{ cookiecutter.repository_name }}/docs/overrides/main.html:
--------------------------------------------------------------------------------
1 | {% raw %}{% extends "base.html" %}
2 |
3 | {% block outdated %}
4 | You're not viewing the latest version.
5 |
6 | Click here to go to latest.
7 |
8 | {% endblock %}{% endraw %}
9 |
--------------------------------------------------------------------------------
/{{ cookiecutter.repository_name }}/env.yaml:
--------------------------------------------------------------------------------
1 | name: {{ cookiecutter.conda_env_name }}
2 | channels:
3 | - pytorch
4 | - nvidia
5 | - defaults
6 |
7 | dependencies:
8 | - python={{ cookiecutter.python_version }}
9 | - pytorch=2.0.*
10 | - torchvision
11 | - pytorch-cuda=11.8
12 | - pip
13 | - pip:
14 | - -e .[dev]
15 |
--------------------------------------------------------------------------------
/{{ cookiecutter.repository_name }}/mkdocs.yml:
--------------------------------------------------------------------------------
1 | site_name: {{ cookiecutter.repository_name }}
2 | site_description: {{ cookiecutter.project_description }}
3 | repo_url: {{ cookiecutter.repository_url }}
4 |
5 | nav:
6 | - Home: index.md
7 |
8 | theme:
9 | name: material
10 | custom_dir: docs/overrides
11 | icon:
12 | repo: fontawesome/brands/github
13 |
14 | features:
15 | - content.code.annotate
16 | - navigation.indexes
17 | - navigation.instant
18 | - navigation.sections
19 | - navigation.tabs
20 | - navigation.tabs.sticky
21 | - navigation.top
22 | - navigation.tracking
23 | - search.highlight
24 | - search.share
25 | - search.suggest
26 |
27 | palette:
28 | - scheme: default
29 | primary: light green
30 | accent: green
31 | toggle:
32 | icon: material/weather-night
33 | name: Switch to dark mode
34 | - scheme: slate
35 | primary: green
36 | accent: green
37 | toggle:
38 | icon: material/weather-sunny
39 | name: Switch to light mode
40 |
41 | # Extensions
42 | markdown_extensions:
43 | - abbr
44 | - admonition
45 | - attr_list
46 | - def_list
47 | - footnotes
48 | - meta
49 | - md_in_html
50 | - toc:
51 | permalink: true
52 | - tables
53 | - pymdownx.arithmatex:
54 | generic: true
55 | - pymdownx.betterem:
56 | smart_enable: all
57 | - pymdownx.caret
58 | - pymdownx.mark
59 | - pymdownx.tilde
60 | - pymdownx.critic
61 | - pymdownx.details
62 | - pymdownx.highlight:
63 | anchor_linenums: true
64 | - pymdownx.superfences
65 | - pymdownx.inlinehilite
66 | - pymdownx.keys
67 | - pymdownx.smartsymbols
68 | - pymdownx.snippets
69 | - pymdownx.tabbed:
70 | alternate_style: true
71 | - pymdownx.tasklist:
72 | custom_checkbox: true
73 |
74 | extra_javascript:
75 | - javascripts/mathjax.js
76 | - https://polyfill.io/v3/polyfill.min.js?features=es6
77 | - https://cdn.jsdelivr.net/npm/mathjax@3/es5/tex-mml-chtml.js
78 |
79 | plugins:
80 | - search
81 |
82 | extra:
83 | version:
84 | provider: mike
85 |
--------------------------------------------------------------------------------
/{{ cookiecutter.repository_name }}/pyproject.toml:
--------------------------------------------------------------------------------
1 | [tool.pytest.ini_options]
2 | minversion = "6.2"
3 | addopts = "-ra"
4 | testpaths = ["tests"]
5 |
6 | [tool.coverage.report]
7 | exclude_lines = [
8 | "raise NotImplementedError",
9 | "raise NotImplementedError()",
10 | "pragma: nocover",
11 | "if __name__ == .__main__.:",
12 | ]
13 |
14 | [tool.black]
15 | line-length = 120
16 | include = '\.pyi?$'
17 |
18 | [tool.mypy]
19 | files = ["src/**/*.py", "tests/*.py"]
20 | ignore_missing_imports = true
21 |
22 | [tool.isort]
23 | profile = 'black'
24 | line_length = 120
25 | known_third_party = ["numpy", "pytest", "wandb", "torch"]
26 | known_first_party = ["nn_core"]
27 | known_local_folder = "{{ cookiecutter.package_name }}"
28 |
29 | [tool.pydocstyle]
30 | convention = 'google'
31 | # ignore all missing docs errors
32 | add-ignore = ['D100', 'D101', 'D102', 'D103', 'D104', 'D105', 'D106', 'D107']
33 |
34 | [tool.bandit]
35 | skips = ["B101"]
36 |
37 | [tool.setuptools_scm]
38 | write_to = "src/{{ cookiecutter.package_name }}/_version.py"
39 | write_to_template = '__version__ = "{version}"'
40 |
41 | [build-system]
42 | requires = ["setuptools==59.5", "wheel", "setuptools_scm[toml]>=6.3.1"]
43 | build-backend = "setuptools.build_meta"
44 |
--------------------------------------------------------------------------------
/{{ cookiecutter.repository_name }}/setup.cfg:
--------------------------------------------------------------------------------
1 | [metadata]
2 | name = {{ cookiecutter.repository_name }}
3 | description = {{ cookiecutter.project_description }}
4 | url = {{ cookiecutter.repository_url }}
5 | long_description = file: README.md
6 | author = {{ cookiecutter.author }}
7 | author_email = {{ cookiecutter.author_email }}
8 | keywords = python
9 | license = MIT License
10 |
11 | [options]
12 | zip_safe = False
13 | include_package_data = True
14 | package_dir=
15 | =src
16 | packages=find:
17 | install_requires =
18 | nn-template-core==0.4.*
19 | anypy==0.0.*
20 |
21 | # Add project specific dependencies
22 | # Stuff easy to break with updates
23 | lightning==2.0.*
24 | torchmetrics==1.0.*
25 | hydra-core==1.3.*
26 | wandb
27 | streamlit
28 | # hydra-joblib-launcher
29 |
30 | # Stable stuff usually backward compatible
31 | rich
32 | dvc
33 | python-dotenv
34 | matplotlib
35 | stqdm
36 |
37 | [options.packages.find]
38 | where=src
39 |
40 | [options.package_data]
41 | * = *.txt, *.md
42 |
43 | [options.extras_require]
44 | docs =
45 | mkdocs
46 | mkdocs-material
47 | mike
48 |
49 | test =
50 | pytest
51 | pytest-cov
52 |
53 | dev =
54 | black
55 | flake8
56 | isort
57 | pre-commit
58 | bandit
59 | %(test)s
60 | %(docs)s
61 |
--------------------------------------------------------------------------------
/{{ cookiecutter.repository_name }}/setup.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 |
3 | import setuptools
4 |
5 | if __name__ == "__main__":
6 | setuptools.setup()
7 |
--------------------------------------------------------------------------------
/{{ cookiecutter.repository_name }}/src/{{ cookiecutter.package_name }}/__init__.py:
--------------------------------------------------------------------------------
1 | import logging
2 |
3 | from nn_core.console_logging import NNRichHandler
4 |
5 | # Required workaround because PyTorch Lightning configures the logging on import,
6 | # thus the logging configuration defined in the __init__.py must be called before
7 | # the lightning import otherwise it has no effect.
8 | # See https://github.com/PyTorchLightning/pytorch-lightning/issues/1503
9 | lightning_logger = logging.getLogger("lightning.pytorch")
10 | # Remove all handlers associated with the lightning logger.
11 | for handler in lightning_logger.handlers[:]:
12 | lightning_logger.removeHandler(handler)
13 | lightning_logger.propagate = True
14 |
15 | FORMAT = "%(message)s"
16 | logging.basicConfig(
17 | format=FORMAT,
18 | level=logging.INFO,
19 | datefmt="%Y-%m-%d %H:%M:%S",
20 | handlers=[
21 | NNRichHandler(
22 | rich_tracebacks=True,
23 | show_level=True,
24 | show_path=True,
25 | show_time=True,
26 | omit_repeated_times=True,
27 | )
28 | ],
29 | )
30 |
31 | try:
32 | from ._version import __version__ as __version__
33 | except ImportError:
34 | import sys
35 |
36 | print(
37 | "Project not installed in the current env, activate the correct env or install it with:\n\tpip install -e .",
38 | file=sys.stderr,
39 | )
40 | __version__ = "unknown"
41 |
--------------------------------------------------------------------------------
/{{ cookiecutter.repository_name }}/src/{{ cookiecutter.package_name }}/data/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/grok-ai/nn-template/8ba02bba8f015e1eb7efb0d2ab8c9d433bd1c431/{{ cookiecutter.repository_name }}/src/{{ cookiecutter.package_name }}/data/__init__.py
--------------------------------------------------------------------------------
/{{ cookiecutter.repository_name }}/src/{{ cookiecutter.package_name }}/data/datamodule.py:
--------------------------------------------------------------------------------
1 | import logging
2 | from functools import cached_property, partial
3 | from pathlib import Path
4 | from typing import List, Mapping, Optional
5 |
6 | import hydra
7 | import lightning.pytorch as pl
8 | import omegaconf
9 | from omegaconf import DictConfig
10 | from torch.utils.data import DataLoader, Dataset
11 | from torch.utils.data.dataloader import default_collate
12 | from tqdm import tqdm
13 |
14 | from nn_core.common import PROJECT_ROOT
15 | from nn_core.nn_types import Split
16 |
17 | pylogger = logging.getLogger(__name__)
18 |
19 |
20 | class MetaData:
21 | def __init__(self, class_vocab: Mapping[str, int]):
22 | """The data information the Lightning Module will be provided with.
23 |
24 | This is a "bridge" between the Lightning DataModule and the Lightning Module.
25 | There is no constraint on the class name nor in the stored information, as long as it exposes the
26 | `save` and `load` methods.
27 |
28 | The Lightning Module will receive an instance of MetaData when instantiated,
29 | both in the train loop or when restored from a checkpoint.
30 |
31 | This decoupling allows the architecture to be parametric (e.g. in the number of classes) and
32 | DataModule/Trainer independent (useful in prediction scenarios).
33 | MetaData should contain all the information needed at test time, derived from its train dataset.
34 |
35 | Examples are the class names in a classification task or the vocabulary in NLP tasks.
36 | MetaData exposes `save` and `load`. Those are two user-defined methods that specify
37 | how to serialize and de-serialize the information contained in its attributes.
38 | This is needed for the checkpointing restore to work properly.
39 |
40 | Args:
41 | class_vocab: association between class names and their indices
42 | """
43 | # example
44 | self.class_vocab: Mapping[str, int] = class_vocab
45 |
46 | def save(self, dst_path: Path) -> None:
47 | """Serialize the MetaData attributes into the zipped checkpoint in dst_path.
48 |
49 | Args:
50 | dst_path: the root folder of the metadata inside the zipped checkpoint
51 | """
52 | pylogger.debug(f"Saving MetaData to '{dst_path}'")
53 |
54 | # example
55 | (dst_path / "class_vocab.tsv").write_text(
56 | "\n".join(f"{key}\t{value}" for key, value in self.class_vocab.items())
57 | )
58 |
59 | @staticmethod
60 | def load(src_path: Path) -> "MetaData":
61 | """Deserialize the MetaData from the information contained inside the zipped checkpoint in src_path.
62 |
63 | Args:
64 | src_path: the root folder of the metadata inside the zipped checkpoint
65 |
66 | Returns:
67 | an instance of MetaData containing the information in the checkpoint
68 | """
69 | pylogger.debug(f"Loading MetaData from '{src_path}'")
70 |
71 | # example
72 | lines = (src_path / "class_vocab.tsv").read_text(encoding="utf-8").splitlines()
73 |
74 | class_vocab = {}
75 | for line in lines:
76 | key, value = line.strip().split("\t")
77 | class_vocab[key] = value
78 |
79 | return MetaData(
80 | class_vocab=class_vocab,
81 | )
82 |
83 | def __repr__(self) -> str:
84 | attributes = ",\n ".join([f"{key}={value}" for key, value in self.__dict__.items()])
85 | return f"{self.__class__.__name__}(\n {attributes}\n)"
86 |
87 |
88 | def collate_fn(samples: List, split: Split, metadata: MetaData):
89 | """Custom collate function for dataloaders with access to split and metadata.
90 |
91 | Args:
92 | samples: A list of samples coming from the Dataset to be merged into a batch
93 | split: The data split (e.g. train/val/test)
94 | metadata: The MetaData instance coming from the DataModule or the restored checkpoint
95 |
96 | Returns:
97 | A batch generated from the given samples
98 | """
99 | return default_collate(samples)
100 |
101 |
102 | class MyDataModule(pl.LightningDataModule):
103 | def __init__(
104 | self,
105 | dataset: DictConfig,
106 | num_workers: DictConfig,
107 | batch_size: DictConfig,
108 | accelerator: str,
109 | # example
110 | val_images_fixed_idxs: List[int],
111 | ):
112 | super().__init__()
113 | self.dataset = dataset
114 | self.num_workers = num_workers
115 | self.batch_size = batch_size
116 | # https://pytorch-lightning.readthedocs.io/en/stable/common/trainer.html#gpus
117 | self.pin_memory: bool = accelerator is not None and str(accelerator) == "gpu"
118 |
119 | self.train_dataset: Optional[Dataset] = None
120 | self.val_dataset: Optional[Dataset] = None
121 | self.test_dataset: Optional[Dataset] = None
122 |
123 | # example
124 | self.val_images_fixed_idxs: List[int] = val_images_fixed_idxs
125 |
126 | @cached_property
127 | def metadata(self) -> MetaData:
128 | """Data information to be fed to the Lightning Module as parameter.
129 |
130 | Examples are vocabularies, number of classes...
131 |
132 | Returns:
133 | metadata: everything the model should know about the data, wrapped in a MetaData object.
134 | """
135 | # Since MetaData depends on the training data, we need to ensure the setup method has been called.
136 | if self.train_dataset is None:
137 | self.setup(stage="fit")
138 |
139 | return MetaData(class_vocab={i: name for i, name in enumerate(self.train_dataset.features["y"].names)})
140 |
141 | def prepare_data(self) -> None:
142 | # download only
143 | pass
144 |
145 | def setup(self, stage: Optional[str] = None):
146 | self.transform = hydra.utils.instantiate(self.dataset.transforms)
147 |
148 | self.hf_datasets = hydra.utils.instantiate(self.dataset)
149 | self.hf_datasets.set_transform(self.transform)
150 |
151 | # Here you should instantiate your dataset, you may also split the train into train and validation if needed.
152 | if (stage is None or stage == "fit") and (self.train_dataset is None and self.val_dataset is None):
153 | self.train_dataset = self.hf_datasets["train"]
154 | self.val_dataset = self.hf_datasets["val"]
155 |
156 | if stage is None or stage == "test":
157 | self.test_dataset = self.hf_datasets["test"]
158 |
159 | def train_dataloader(self) -> DataLoader:
160 | return DataLoader(
161 | self.train_dataset,
162 | shuffle=True,
163 | batch_size=self.batch_size.train,
164 | num_workers=self.num_workers.train,
165 | pin_memory=self.pin_memory,
166 | collate_fn=partial(collate_fn, split="train", metadata=self.metadata),
167 | )
168 |
169 | def val_dataloader(self) -> DataLoader:
170 | return DataLoader(
171 | self.val_dataset,
172 | shuffle=False,
173 | batch_size=self.batch_size.val,
174 | num_workers=self.num_workers.val,
175 | pin_memory=self.pin_memory,
176 | collate_fn=partial(collate_fn, split="val", metadata=self.metadata),
177 | )
178 |
179 | def test_dataloader(self) -> DataLoader:
180 | return DataLoader(
181 | self.test_dataset,
182 | shuffle=False,
183 | batch_size=self.batch_size.test,
184 | num_workers=self.num_workers.test,
185 | pin_memory=self.pin_memory,
186 | collate_fn=partial(collate_fn, split="test", metadata=self.metadata),
187 | )
188 |
189 | def __repr__(self) -> str:
190 | return f"{self.__class__.__name__}(" f"{self.dataset=}, " f"{self.num_workers=}, " f"{self.batch_size=})"
191 |
192 |
193 | @hydra.main(config_path=str(PROJECT_ROOT / "conf"), config_name="default")
194 | def main(cfg: omegaconf.DictConfig) -> None:
195 | """Debug main to quickly develop the DataModule.
196 |
197 | Args:
198 | cfg: the hydra configuration
199 | """
200 | m: pl.LightningDataModule = hydra.utils.instantiate(cfg.nn.data, _recursive_=False)
201 |     m.metadata  # access the property to build the MetaData (this triggers setup if needed)
202 | m.setup()
203 |
204 | for _ in tqdm(m.train_dataloader()):
205 | pass
206 |
207 |
208 | if __name__ == "__main__":
209 | main()
210 |
--------------------------------------------------------------------------------
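A minimal sketch of the MetaData round-trip implied by the `load` method above: it expects a tab-separated `class_vocab.tsv` inside the checkpoint directory, with one `key<TAB>value` pair per line. The vocabulary values below are illustrative only; in the template they come from the training split's label names.

from pathlib import Path
from tempfile import TemporaryDirectory

# Hypothetical vocabulary; in the template it is built from the train split's label names.
class_vocab = {"0": "zero", "1": "one"}

with TemporaryDirectory() as tmp:
    src_path = Path(tmp)

    # Write the layout that MetaData.load reads back.
    lines = [f"{key}\t{value}" for key, value in class_vocab.items()]
    (src_path / "class_vocab.tsv").write_text("\n".join(lines), encoding="utf-8")

    # Mirror the parsing done in MetaData.load.
    restored = {}
    for line in (src_path / "class_vocab.tsv").read_text(encoding="utf-8").splitlines():
        key, value = line.strip().split("\t")
        restored[key] = value

    assert restored == class_vocab
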
/{{ cookiecutter.repository_name }}/src/{{ cookiecutter.package_name }}/data/dataset.py:
--------------------------------------------------------------------------------
1 | import hydra
2 | import omegaconf
3 | from torch.utils.data import Dataset
4 |
5 | from nn_core.common import PROJECT_ROOT
6 |
7 |
8 | @hydra.main(config_path=str(PROJECT_ROOT / "conf"), config_name="default")
9 | def main(cfg: omegaconf.DictConfig) -> None:
10 | """Debug main to quickly develop the Dataset.
11 |
12 | Args:
13 | cfg: the hydra configuration
14 | """
15 | _: Dataset = hydra.utils.instantiate(cfg.nn.data.dataset, _recursive_=False)
16 |
17 |
18 | if __name__ == "__main__":
19 | main()
20 |
--------------------------------------------------------------------------------
/{{ cookiecutter.repository_name }}/src/{{ cookiecutter.package_name }}/modules/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/grok-ai/nn-template/8ba02bba8f015e1eb7efb0d2ab8c9d433bd1c431/{{ cookiecutter.repository_name }}/src/{{ cookiecutter.package_name }}/modules/__init__.py
--------------------------------------------------------------------------------
/{{ cookiecutter.repository_name }}/src/{{ cookiecutter.package_name }}/modules/module.py:
--------------------------------------------------------------------------------
1 | from typing import Tuple
2 |
3 | from torch import nn
4 |
5 |
6 | # https://medium.com/@nutanbhogendrasharma/pytorch-convolutional-neural-network-with-mnist-dataset-4e8a4265e118
7 | class CNN(nn.Module):
8 |     def __init__(self, input_shape: Tuple[int, ...], num_classes: int):
9 |         super().__init__()
10 | self.model = nn.Sequential(
11 | nn.Conv2d(
12 | in_channels=input_shape[0],
13 | out_channels=16,
14 | kernel_size=5,
15 | stride=1,
16 | padding=2,
17 | ),
18 | nn.SiLU(),
19 | nn.MaxPool2d(kernel_size=2),
20 | nn.Conv2d(16, 32, 5, 1, 2),
21 | nn.SiLU(),
22 | nn.MaxPool2d(2),
23 | )
24 | self.conv2 = nn.Sequential()
25 | self.out = nn.Linear(32 * 7 * 7, num_classes)
26 |
27 | def forward(self, x):
28 | x = self.model(x)
29 | # [batch_size, 32 * 7 * 7]
30 | x = x.view(x.size(0), -1)
31 | output = self.out(x)
32 | return output
33 |
--------------------------------------------------------------------------------
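A quick shape check for the CNN above, assuming the MNIST-sized input (1 channel, 28x28) that the `32 * 7 * 7` flatten size implies: the convolutions preserve the spatial resolution (`kernel_size=5` with `padding=2`), while each `MaxPool2d` halves it, 28 -> 14 -> 7. The snippet below re-creates the convolutional stack outside the template package purely for illustration.

import torch
from torch import nn

# Illustrative copy of the feature extractor defined in CNN.__init__ above.
features = nn.Sequential(
    nn.Conv2d(1, 16, kernel_size=5, stride=1, padding=2),
    nn.SiLU(),
    nn.MaxPool2d(kernel_size=2),
    nn.Conv2d(16, 32, 5, 1, 2),
    nn.SiLU(),
    nn.MaxPool2d(2),
)

x = torch.randn(4, 1, 28, 28)  # dummy MNIST-shaped batch
out = features(x)
print(out.shape)  # torch.Size([4, 32, 7, 7])
assert out.flatten(start_dim=1).shape[-1] == 32 * 7 * 7  # matches the Linear layer input
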
/{{ cookiecutter.repository_name }}/src/{{ cookiecutter.package_name }}/pl_modules/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/grok-ai/nn-template/8ba02bba8f015e1eb7efb0d2ab8c9d433bd1c431/{{ cookiecutter.repository_name }}/src/{{ cookiecutter.package_name }}/pl_modules/__init__.py
--------------------------------------------------------------------------------
/{{ cookiecutter.repository_name }}/src/{{ cookiecutter.package_name }}/pl_modules/pl_module.py:
--------------------------------------------------------------------------------
1 | import logging
2 | from typing import Any, Dict, Mapping, Optional, Sequence, Tuple, Union
3 |
4 | import hydra
5 | import lightning.pytorch as pl
6 | import omegaconf
7 | import torch
8 | import torch.nn.functional as F
9 | import torchmetrics
10 | from torch.optim import Optimizer
11 |
12 | from nn_core.common import PROJECT_ROOT
13 | from nn_core.model_logging import NNLogger
14 |
15 | from {{ cookiecutter.package_name }}.data.datamodule import MetaData
16 |
17 | pylogger = logging.getLogger(__name__)
18 |
19 |
20 | class MyLightningModule(pl.LightningModule):
21 | logger: NNLogger
22 |
23 | def __init__(self, model, metadata: Optional[MetaData] = None, *args, **kwargs) -> None:
24 | super().__init__()
25 |
26 | # Populate self.hparams with args and kwargs automagically!
27 | # We want to skip metadata since it is saved separately by the NNCheckpointIO object.
28 | # Be careful when modifying this instruction. If in doubt, don't do it :]
29 | self.save_hyperparameters(logger=False, ignore=("metadata",))
30 |
31 | self.metadata = metadata
32 |
33 | # example
34 | metric = torchmetrics.Accuracy(
35 | task="multiclass",
36 | num_classes=len(metadata.class_vocab) if metadata is not None else None,
37 | )
38 | self.train_acc = metric.clone()
39 | self.val_acc = metric.clone()
40 | self.test_acc = metric.clone()
41 |
42 | self.model = hydra.utils.instantiate(model, num_classes=len(metadata.class_vocab))
43 |
44 | def forward(self, x: torch.Tensor) -> torch.Tensor:
45 | """Method for the forward pass.
46 |
47 | 'training_step', 'validation_step' and 'test_step' should call
48 | this method in order to compute the output predictions and the loss.
49 |
50 | Returns:
51 |             output_dict: forward output containing the predictions (output logits, etc.) and the loss, if any.
52 | """
53 | # example
54 | return self.model(x)
55 |
56 | def _step(self, batch: Dict[str, torch.Tensor], split: str) -> Mapping[str, Any]:
57 | x = batch[self.hparams.x_key]
58 | gt_y = batch[self.hparams.y_key]
59 |
60 | # example
61 | logits = self(x)
62 | loss = F.cross_entropy(logits, gt_y)
63 | preds = torch.softmax(logits, dim=-1)
64 |
65 | metrics = getattr(self, f"{split}_acc")
66 | metrics.update(preds, gt_y)
67 |
68 | self.log_dict(
69 | {
70 | f"acc/{split}": metrics,
71 | f"loss/{split}": loss,
72 | },
73 | on_epoch=True,
74 | )
75 |
76 | return {"logits": logits.detach(), "loss": loss}
77 |
78 | def training_step(self, batch: Any, batch_idx: int) -> Mapping[str, Any]:
79 | return self._step(batch=batch, split="train")
80 |
81 | def validation_step(self, batch: Any, batch_idx: int) -> Mapping[str, Any]:
82 | return self._step(batch=batch, split="val")
83 |
84 | def test_step(self, batch: Any, batch_idx: int) -> Mapping[str, Any]:
85 | return self._step(batch=batch, split="test")
86 |
87 | def configure_optimizers(
88 | self,
89 | ) -> Union[Optimizer, Tuple[Sequence[Optimizer], Sequence[Any]]]:
90 | """Choose what optimizers and learning-rate schedulers to use in your optimization.
91 |
92 | Normally you'd need one. But in the case of GANs or similar you might have multiple.
93 |
94 | Return:
95 | Any of these 6 options.
96 | - Single optimizer.
97 | - List or Tuple - List of optimizers.
98 | - Two lists - The first list has multiple optimizers, the second a list of LR schedulers (or lr_dict).
99 | - Dictionary, with an 'optimizer' key, and (optionally) a 'lr_scheduler'
100 | key whose value is a single LR scheduler or lr_dict.
101 | - Tuple of dictionaries as described, with an optional 'frequency' key.
102 | - None - Fit will run without any optimizer.
103 | """
104 | opt = hydra.utils.instantiate(self.hparams.optimizer, params=self.parameters(), _convert_="partial")
105 | if "lr_scheduler" not in self.hparams:
106 | return [opt]
107 | scheduler = hydra.utils.instantiate(self.hparams.lr_scheduler, optimizer=opt)
108 | return [opt], [scheduler]
109 |
110 |
111 | @hydra.main(config_path=str(PROJECT_ROOT / "conf"), config_name="default")
112 | def main(cfg: omegaconf.DictConfig) -> None:
113 | """Debug main to quickly develop the Lightning Module.
114 |
115 | Args:
116 | cfg: the hydra configuration
117 | """
118 | m: pl.LightningDataModule = hydra.utils.instantiate(cfg.nn.data, _recursive_=False)
119 | _: pl.LightningModule = hydra.utils.instantiate(cfg.nn.module, _recursive_=False, metadata=m.metadata)
120 |
121 |
122 | if __name__ == "__main__":
123 | main()
124 |
--------------------------------------------------------------------------------
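`configure_optimizers` above reads `self.hparams.optimizer` (and, optionally, `self.hparams.lr_scheduler`) and hands them to `hydra.utils.instantiate`, so each entry is expected to be a `_target_`-style config. A minimal sketch of that contract, assuming standard torch classes; the actual values live in the template's module configuration and may differ.

import hydra
import torch
from omegaconf import OmegaConf

# Illustrative hparams fragment; the template stores this in the Hydra config tree.
hparams = OmegaConf.create(
    {
        "optimizer": {"_target_": "torch.optim.Adam", "lr": 1e-3},
        "lr_scheduler": {"_target_": "torch.optim.lr_scheduler.StepLR", "step_size": 10},
    }
)

model = torch.nn.Linear(4, 2)  # stand-in for the Lightning module's parameters

opt = hydra.utils.instantiate(hparams.optimizer, params=model.parameters(), _convert_="partial")
scheduler = hydra.utils.instantiate(hparams.lr_scheduler, optimizer=opt)
print(type(opt).__name__, type(scheduler).__name__)  # Adam StepLR
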
/{{ cookiecutter.repository_name }}/src/{{ cookiecutter.package_name }}/run.py:
--------------------------------------------------------------------------------
1 | import logging
2 | from typing import List, Optional
3 |
4 | import hydra
5 | import lightning.pytorch as pl
6 | import omegaconf
7 | import torch
8 | from lightning.pytorch import Callback
9 | from omegaconf import DictConfig, ListConfig
10 |
11 | from nn_core.callbacks import NNTemplateCore
12 | from nn_core.common import PROJECT_ROOT
13 | from nn_core.common.utils import enforce_tags, seed_index_everything
14 | from nn_core.model_logging import NNLogger
15 | from nn_core.serialization import NNCheckpointIO
16 |
17 | # Force the execution of __init__.py if this file is executed directly.
18 | import {{ cookiecutter.package_name }} # noqa
19 | from {{ cookiecutter.package_name }}.data.datamodule import MetaData
20 |
21 | pylogger = logging.getLogger(__name__)
22 |
23 | torch.set_float32_matmul_precision("high")
24 |
25 |
26 | def build_callbacks(cfg: ListConfig, *args: Callback) -> List[Callback]:
27 | """Instantiate the callbacks given their configuration.
28 |
29 | Args:
30 |         cfg: a list of instantiable callback configurations
31 | *args: a list of extra callbacks already instantiated
32 |
33 | Returns:
34 | the complete list of callbacks to use
35 | """
36 | callbacks: List[Callback] = list(args)
37 |
38 | for callback in cfg:
39 | pylogger.info(f"Adding callback <{callback['_target_'].split('.')[-1]}>")
40 | callbacks.append(hydra.utils.instantiate(callback, _recursive_=False))
41 |
42 | return callbacks
43 |
44 |
45 | def run(cfg: DictConfig) -> str:
46 | """Generic train loop.
47 |
48 | Args:
49 | cfg: run configuration, defined by Hydra in /conf
50 |
51 | Returns:
52 | the run directory inside the storage_dir used by the current experiment
53 | """
54 | seed_index_everything(cfg.train)
55 |
56 | fast_dev_run: bool = cfg.train.trainer.fast_dev_run
57 | if fast_dev_run:
58 | pylogger.info(f"Debug mode <{cfg.train.trainer.fast_dev_run=}>. Forcing debugger friendly configuration!")
59 | # Debuggers don't like GPUs nor multiprocessing
60 | cfg.train.trainer.accelerator = "cpu"
61 | cfg.nn.data.num_workers.train = 0
62 | cfg.nn.data.num_workers.val = 0
63 | cfg.nn.data.num_workers.test = 0
64 |
65 | cfg.core.tags = enforce_tags(cfg.core.get("tags", None))
66 |
67 | # Instantiate datamodule
68 | pylogger.info(f"Instantiating <{cfg.nn.data['_target_']}>")
69 | datamodule: pl.LightningDataModule = hydra.utils.instantiate(cfg.nn.data, _recursive_=False)
70 | datamodule.setup(stage=None)
71 |
72 | metadata: Optional[MetaData] = getattr(datamodule, "metadata", None)
73 | if metadata is None:
74 | pylogger.warning(f"No 'metadata' attribute found in datamodule <{datamodule.__class__.__name__}>")
75 |
76 | # Instantiate model
77 | pylogger.info(f"Instantiating <{cfg.nn.module['_target_']}>")
78 | model: pl.LightningModule = hydra.utils.instantiate(cfg.nn.module, _recursive_=False, metadata=metadata)
79 |
80 | # Instantiate the callbacks
81 | template_core: NNTemplateCore = NNTemplateCore(
82 | restore_cfg=cfg.train.get("restore", None),
83 | )
84 | callbacks: List[Callback] = build_callbacks(cfg.train.callbacks, template_core)
85 |
86 | storage_dir: str = cfg.core.storage_dir
87 |
88 | logger: NNLogger = NNLogger(logging_cfg=cfg.train.logging, cfg=cfg, resume_id=template_core.resume_id)
89 |
90 |     pylogger.info("Instantiating the Trainer")
91 | trainer = pl.Trainer(
92 | default_root_dir=storage_dir,
93 | plugins=[NNCheckpointIO(jailing_dir=logger.run_dir)],
94 | logger=logger,
95 | callbacks=callbacks,
96 | **cfg.train.trainer,
97 | )
98 |
99 | pylogger.info("Starting training!")
100 | trainer.fit(model=model, datamodule=datamodule, ckpt_path=template_core.trainer_ckpt_path)
101 |
102 | if fast_dev_run:
103 | pylogger.info("Skipping testing in 'fast_dev_run' mode!")
104 | else:
105 | if datamodule.test_dataset is not None and trainer.checkpoint_callback.best_model_path is not None:
106 | pylogger.info("Starting testing!")
107 | trainer.test(datamodule=datamodule)
108 |
109 | # Logger closing to release resources/avoid multi-run conflicts
110 | if logger is not None:
111 | logger.experiment.finish()
112 |
113 | return logger.run_dir
114 |
115 |
116 | @hydra.main(config_path=str(PROJECT_ROOT / "conf"), config_name="default")
117 | def main(cfg: omegaconf.DictConfig):
118 | run(cfg)
119 |
120 |
121 | if __name__ == "__main__":
122 | main()
123 |
--------------------------------------------------------------------------------
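`build_callbacks` above instantiates each entry of `cfg.train.callbacks` non-recursively and appends it to any already-constructed callbacks passed in (the template passes the `NNTemplateCore` instance). A sketch of the same instantiation loop in isolation, with illustrative Lightning callbacks as targets; the real list lives in the train configuration. The `loss/val` key matches what the Lightning module logs in `_step`.

import hydra
from omegaconf import OmegaConf

# Illustrative callback configs in the same _target_ format run.py iterates over.
callbacks_cfg = OmegaConf.create(
    [
        {"_target_": "lightning.pytorch.callbacks.EarlyStopping", "monitor": "loss/val", "patience": 3},
        {"_target_": "lightning.pytorch.callbacks.LearningRateMonitor", "logging_interval": "step"},
    ]
)

# Same behaviour as build_callbacks(callbacks_cfg): instantiate each entry without recursing.
callbacks = [hydra.utils.instantiate(cb, _recursive_=False) for cb in callbacks_cfg]
print([type(cb).__name__ for cb in callbacks])  # ['EarlyStopping', 'LearningRateMonitor']
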
/{{ cookiecutter.repository_name }}/src/{{ cookiecutter.package_name }}/ui/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/grok-ai/nn-template/8ba02bba8f015e1eb7efb0d2ab8c9d433bd1c431/{{ cookiecutter.repository_name }}/src/{{ cookiecutter.package_name }}/ui/__init__.py
--------------------------------------------------------------------------------
/{{ cookiecutter.repository_name }}/src/{{ cookiecutter.package_name }}/ui/run.py:
--------------------------------------------------------------------------------
1 | from pathlib import Path
2 |
3 | import streamlit as st
4 | import wandb
5 |
6 | from nn_core.serialization import load_model
7 | from nn_core.ui import select_checkpoint
8 |
9 | from {{ cookiecutter.package_name }}.pl_modules.pl_module import MyLightningModule
10 |
11 |
12 | @st.cache(allow_output_mutation=True)
13 | def get_model(checkpoint_path: Path):
14 | return load_model(module_class=MyLightningModule, checkpoint_path=checkpoint_path)
15 |
16 |
17 | if wandb.api.api_key is None:
18 | st.error("You are not logged in on `Weights and Biases`: https://docs.wandb.ai/ref/cli/wandb-login")
19 | st.stop()
20 |
21 | st.sidebar.subheader(f"Logged in W&B as: {wandb.api.viewer()['entity']}")
22 |
23 | checkpoint_path = select_checkpoint()
24 | model: MyLightningModule = get_model(checkpoint_path=checkpoint_path)
25 | model  # Streamlit "magic": a bare variable on its own line is rendered in the app
26 |
--------------------------------------------------------------------------------
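Outside the Streamlit app, the same checkpoint loading can be done directly with `nn_core.serialization.load_model`, as `get_model` above does. A sketch under the assumption that a previous run already produced a checkpoint; the path below is a placeholder, in practice it comes from `select_checkpoint()` or from the run's storage directory.

from pathlib import Path

from nn_core.serialization import load_model

from {{ cookiecutter.package_name }}.pl_modules.pl_module import MyLightningModule

# Placeholder path: replace with an actual checkpoint produced by a training run.
checkpoint_path = Path("<storage-dir>") / "<run-id>" / "checkpoints" / "<checkpoint-file>"

model: MyLightningModule = load_model(module_class=MyLightningModule, checkpoint_path=checkpoint_path)
model.eval()
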
/{{ cookiecutter.repository_name }}/src/{{ cookiecutter.package_name }}/utils/hf_io.py:
--------------------------------------------------------------------------------
1 | from collections import namedtuple
2 | from pathlib import Path
3 | from typing import Any, Callable, Dict, Sequence
4 |
5 | import torch
6 | from anypy.data.metadata_dataset_dict import MetadataDatasetDict
7 | from datasets import Dataset, DatasetDict, load_dataset, load_from_disk
8 | from omegaconf import DictConfig
9 |
10 | from nn_core.common import PROJECT_ROOT
11 |
12 | DatasetParams = namedtuple("DatasetParams", ["name", "fine_grained", "train_split", "test_split", "hf_key"])
13 |
14 |
15 | class HFTransform:
16 | def __init__(
17 | self,
18 | key: str,
19 | transform: Callable[[torch.Tensor], torch.Tensor],
20 | ):
21 | """Apply a row-wise transform to a dataset column.
22 |
23 | Args:
24 | key (str): The key of the column to transform.
25 | transform (Callable[[torch.Tensor], torch.Tensor]): The transform to apply.
26 | """
27 | self.transform = transform
28 | self.key = key
29 |
30 | def __call__(self, samples: Dict[str, Sequence[Any]]) -> Dict[str, Sequence[Any]]:
31 | """Apply the transform to the samples.
32 |
33 | Args:
34 | samples (Dict[str, Sequence[Any]]): The samples to transform.
35 |
36 | Returns:
37 | Dict[str, Sequence[Any]]: The transformed samples.
38 | """
39 | samples[self.key] = [self.transform(data) for data in samples[self.key]]
40 | return samples
41 |
42 | def __repr__(self) -> str:
43 | return repr(self.transform)
44 |
45 |
46 | def preprocess_dataset(
47 | dataset: Dataset,
48 | cfg: Dict,
49 | ) -> Dataset:
50 | """Preprocess a dataset.
51 |
52 | This function applies the following preprocessing steps:
53 | - Rename the label column to the standard key.
54 | - Rename the data column to the standard key.
55 |
56 | Do not apply transforms here, as the preprocessed dataset will be saved to disk once
57 |     and then reused; thus, updates to the transforms would not be reflected in the dataset.
58 |
59 | Args:
60 | dataset (Dataset): The dataset to preprocess.
61 | cfg (Dict): The configuration.
62 |
63 | Returns:
64 | Dataset: The preprocessed dataset.
65 | """
66 | dataset = dataset.rename_column(cfg["label_key"], cfg["standard_y_key"])
67 | dataset = dataset.rename_column(cfg["data_key"], cfg["standard_x_key"])
68 | return dataset
69 |
70 |
71 | def save_dataset_to_disk(dataset: MetadataDatasetDict, output_path: Path) -> None:
72 | """Save a dataset to disk.
73 |
74 | Args:
75 | dataset (MetadataDatasetDict): The dataset to save.
76 | output_path (Path): The path to save the dataset to.
77 | """
78 | if not isinstance(output_path, Path):
79 | output_path = Path(output_path)
80 |
81 | output_path.mkdir(parents=True, exist_ok=True)
82 |
83 | dataset.save_to_disk(output_path)
84 |
85 |
86 | def load_hf_dataset(**cfg: DictConfig) -> MetadataDatasetDict:
87 | """Load a dataset from the HuggingFace datasets library.
88 |
89 | The returned dataset is a MetadataDatasetDict, which is a wrapper around a DatasetDict.
90 | It will contain the following splits:
91 | - train
92 | - val
93 | - test
94 | If `val_split` is not specified in the config, it will be created from the train split
95 | according to the `val_percentage` specified in the config.
96 |
97 | The returned dataset will be preprocessed and saved to disk,
98 | if it does not exist yet, and loaded from disk otherwise.
99 |
100 | Args:
101 | cfg: The configuration.
102 |
103 | Returns:
104 | Dataset: The loaded dataset.
105 | """
106 | dataset_params: DatasetParams = DatasetParams(
107 | cfg["ref"],
108 | None,
109 | cfg["train_split"],
110 | cfg["test_split"],
111 | (cfg["ref"],),
112 | )
113 | DATASET_KEY = "_".join(
114 | map(
115 | str,
116 | [v for k, v in dataset_params._asdict().items() if k != "hf_key" and v is not None],
117 | )
118 | )
119 | DATASET_DIR: Path = PROJECT_ROOT / "data" / "datasets" / DATASET_KEY
120 |
121 | if not DATASET_DIR.exists():
122 | train_dataset = load_dataset(
123 | dataset_params.name,
124 | split=dataset_params.train_split,
125 | token=True,
126 | )
127 | if "val_percentage" in cfg:
128 | train_val_dataset = train_dataset.train_test_split(test_size=cfg["val_percentage"], shuffle=True)
129 | train_dataset = train_val_dataset["train"]
130 | val_dataset = train_val_dataset["test"]
131 | elif "val_split" in cfg:
132 | val_dataset = load_dataset(
133 | dataset_params.name,
134 | split=cfg["val_split"],
135 | token=True,
136 | )
137 | else:
138 | raise RuntimeError("Either val_percentage or val_split must be specified in the config.")
139 |
140 | test_dataset = load_dataset(
141 | dataset_params.name,
142 | split=dataset_params.test_split,
143 | token=True,
144 | )
145 |
146 | dataset: DatasetDict = MetadataDatasetDict(
147 | train=train_dataset,
148 | val=val_dataset,
149 | test=test_dataset,
150 | )
151 |
152 | dataset = preprocess_dataset(dataset, cfg)
153 |
154 | save_dataset_to_disk(dataset, DATASET_DIR)
155 | else:
156 |         dataset: DatasetDict = load_from_disk(dataset_path=str(DATASET_DIR))
157 |
158 | return dataset
159 |
--------------------------------------------------------------------------------
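A hedged usage sketch for `load_hf_dataset`, wiring up the config keys the function and `preprocess_dataset` actually read (`ref`, `train_split`, `test_split`, `val_percentage` or `val_split`, plus the column-renaming keys). The values below are illustrative; the real ones live in the template's Hydra dataset config. The first call downloads and caches the preprocessed splits under PROJECT_ROOT / "data" / "datasets".

from {{ cookiecutter.package_name }}.utils.hf_io import load_hf_dataset

# Illustrative values only; the template drives this through Hydra rather than a literal dict.
dataset_cfg = {
    "ref": "mnist",          # HuggingFace dataset id
    "train_split": "train",
    "test_split": "test",
    "val_percentage": 0.1,   # carve val out of train (alternatively provide "val_split")
    "label_key": "label",    # original column names...
    "data_key": "image",
    "standard_y_key": "y",   # ...renamed to the standard keys the DataModule expects
    "standard_x_key": "x",
}

datasets = load_hf_dataset(**dataset_cfg)
print(list(datasets.keys()))        # expected: ['train', 'val', 'test']
print(datasets["train"].features["y"])
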
/{{ cookiecutter.repository_name }}/tests/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/grok-ai/nn-template/8ba02bba8f015e1eb7efb0d2ab8c9d433bd1c431/{{ cookiecutter.repository_name }}/tests/__init__.py
--------------------------------------------------------------------------------
/{{ cookiecutter.repository_name }}/tests/conftest.py:
--------------------------------------------------------------------------------
1 | import logging
2 | import os
3 | import shutil
4 | from pathlib import Path
5 | from typing import Dict, Union
6 |
7 | import pytest
8 | from hydra import compose, initialize
9 | from hydra.core.hydra_config import HydraConfig
10 | from lightning.pytorch import seed_everything
11 | from omegaconf import DictConfig, OmegaConf, open_dict
12 | from pytest import FixtureRequest, TempPathFactory
13 |
14 | from nn_core.serialization import NNCheckpointIO
15 |
16 | from {{ cookiecutter.package_name }}.run import run
17 |
18 | logging.basicConfig(force=True, level=logging.DEBUG)
19 |
20 | seed_everything(42)
21 |
22 | TRAIN_MAX_NSTEPS = 1
23 |
24 |
25 | #
26 | # Base configurations
27 | #
28 | @pytest.fixture(scope="package")
29 | def cfg(tmp_path_factory: TempPathFactory) -> DictConfig:
30 | test_cfg_tmpdir = tmp_path_factory.mktemp("test_train_tmpdir")
31 |
32 | with initialize(config_path="../conf"):
33 | cfg = compose(config_name="default", return_hydra_config=True)
34 | HydraConfig().set_config(cfg)
35 |
36 | # Force the wandb dir to be in the temp folder
37 | os.environ["WANDB_DIR"] = str(test_cfg_tmpdir)
38 |
39 | # Force the storage dir to be in the temp folder
40 | cfg.core.storage_dir = str(test_cfg_tmpdir)
41 |
42 | yield cfg
43 |
44 | shutil.rmtree(test_cfg_tmpdir)
45 |
46 |
47 | #
48 | # Training configurations
49 | #
50 | @pytest.fixture(scope="package")
51 | def cfg_simple_train(cfg: DictConfig) -> DictConfig:
52 | cfg = OmegaConf.create(cfg)
53 |
54 | # Add test tag
55 | cfg.core.tags = ["testing"]
56 |
57 | # Disable gpus
58 | cfg.train.trainer.accelerator = "cpu"
59 |
60 | # Disable logger
61 | cfg.train.logging.logger.mode = "disabled"
62 |
63 |     # Disable file uploads because wandb in offline mode always uses /tmp
64 | # as run.dir, which causes conflicts between multiple trainings
65 | cfg.train.logging.upload.run_files = False
66 |
67 | # Disable multiple workers in test training
68 | cfg.nn.data.num_workers.train = 0
69 | cfg.nn.data.num_workers.val = 0
70 | cfg.nn.data.num_workers.test = 0
71 |
72 | # Minimize the amount of work in test training
73 | cfg.train.trainer.max_steps = TRAIN_MAX_NSTEPS
74 | cfg.train.trainer.val_check_interval = TRAIN_MAX_NSTEPS
75 |
76 | # Ensure the resuming is disabled
77 | with open_dict(config=cfg):
78 | cfg.train.restore = {}
79 | cfg.train.restore.ckpt_or_run_path = None
80 | cfg.train.restore.mode = None
81 |
82 | return cfg
83 |
84 |
85 | @pytest.fixture(scope="package")
86 | def cfg_fast_dev_run(cfg_simple_train: DictConfig) -> DictConfig:
87 | cfg_simple_train = OmegaConf.create(cfg_simple_train)
88 |
89 | # Enable the fast_dev_run flag
90 | cfg_simple_train.train.trainer.fast_dev_run = True
91 | return cfg_simple_train
92 |
93 |
94 | #
95 | # Training configurations aggregations
96 | #
97 | @pytest.fixture(
98 | scope="package",
99 | params=[
100 | "cfg_simple_train",
101 | ],
102 | )
103 | def cfg_all_not_dry(request: FixtureRequest):
104 | return request.getfixturevalue(request.param)
105 |
106 |
107 | @pytest.fixture(
108 | scope="package",
109 | params=[
110 | "cfg_simple_train",
111 | "cfg_fast_dev_run",
112 | ],
113 | )
114 | def cfg_all(request: FixtureRequest):
115 | return request.getfixturevalue(request.param)
116 |
117 |
118 | #
119 | # Training fixtures
120 | #
121 | @pytest.fixture(
122 | scope="package",
123 | )
124 | def run_trainings_not_dry(cfg_all_not_dry: DictConfig) -> str:
125 | yield run(cfg=cfg_all_not_dry)
126 |
127 |
128 | @pytest.fixture(
129 | scope="package",
130 | )
131 | def run_trainings(cfg_all: DictConfig) -> str:
132 | yield run(cfg=cfg_all)
133 |
134 |
135 | #
136 | # Utility functions
137 | #
138 | def get_checkpoint_path(storagedir: Union[str, Path]) -> Path:
139 | ckpts_path = Path(storagedir) / "checkpoints"
140 | checkpoint_path = next(ckpts_path.glob("*"))
141 | assert checkpoint_path
142 | return checkpoint_path
143 |
144 |
145 | def load_checkpoint(storagedir: Union[str, Path]) -> Dict:
146 | checkpoint = NNCheckpointIO.load(path=get_checkpoint_path(storagedir))
147 | assert checkpoint
148 | return checkpoint
149 |
--------------------------------------------------------------------------------
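The fixtures above are layered: a new test can depend on `run_trainings_not_dry` to get a finished (1-step) training run and then inspect its artifacts with the helpers at the bottom of the file. A minimal sketch of such a test, assuming it lives next to the other test modules and that the Lightning checkpoint keeps its standard `state_dict` key.

from omegaconf import DictConfig

from tests.conftest import load_checkpoint


def test_checkpoint_contains_state_dict(run_trainings_not_dry: str, cfg_all_not_dry: DictConfig) -> None:
    # run_trainings_not_dry is the run directory returned by run(cfg); reuse the conftest helper on it.
    checkpoint = load_checkpoint(run_trainings_not_dry)
    assert "state_dict" in checkpoint
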
/{{ cookiecutter.repository_name }}/tests/test_checkpoint.py:
--------------------------------------------------------------------------------
1 | from importlib import import_module
2 | from pathlib import Path
3 | from typing import Any, Dict
4 |
5 | from lightning.pytorch import LightningModule
6 | from lightning.pytorch.core.saving import _load_state
7 | from omegaconf import DictConfig, OmegaConf
8 |
9 | from nn_core.serialization import NNCheckpointIO
10 | from tests.conftest import load_checkpoint
11 |
12 | from {{ cookiecutter.package_name }}.pl_modules.pl_module import MyLightningModule
13 | from {{ cookiecutter.package_name }}.run import run
14 |
15 |
16 | def test_load_checkpoint(run_trainings_not_dry: str, cfg_all_not_dry: DictConfig) -> None:
17 | ckpts_path = Path(run_trainings_not_dry) / "checkpoints"
18 | checkpoint_path = next(ckpts_path.glob("*"))
19 | assert checkpoint_path
20 |
21 | reference: str = cfg_all_not_dry.nn.module._target_
22 | module_ref, class_ref = reference.rsplit(".", maxsplit=1)
23 | module_class: LightningModule = getattr(import_module(module_ref), class_ref)
24 | assert module_class is not None
25 |
26 | checkpoint = NNCheckpointIO.load(path=checkpoint_path)
27 |
28 | module = _load_state(cls=module_class, checkpoint=checkpoint, metadata=checkpoint["metadata"], strict=True)
29 | assert module is not None
30 | assert sum(p.numel() for p in module.parameters())
31 |
32 |
33 | def _check_cfg_in_checkpoint(checkpoint: Dict, _cfg: DictConfig) -> None:
34 | assert "cfg" in checkpoint
35 | assert checkpoint["cfg"] == _cfg
36 |
37 |
38 | def _check_run_path_in_checkpoint(checkpoint: Dict) -> None:
39 | assert "run_path" in checkpoint
40 | assert checkpoint["run_path"]
41 | checkpoint["run_path"]: str
42 | assert checkpoint["run_path"].startswith("//")
43 |
44 |
45 | def test_cfg_in_checkpoint(run_trainings_not_dry: str, cfg_all_not_dry: DictConfig) -> None:
46 | checkpoint = load_checkpoint(run_trainings_not_dry)
47 |
48 | _check_cfg_in_checkpoint(checkpoint, cfg_all_not_dry)
49 | _check_run_path_in_checkpoint(checkpoint)
50 |
51 |
52 | class ModuleWithCustomCheckpoint(MyLightningModule):
53 | def on_save_checkpoint(self, checkpoint: Dict[str, Any]) -> None:
54 | checkpoint["test_key"] = "test_value"
55 |
56 |
57 | def test_on_save_checkpoint_hook(cfg_all_not_dry: DictConfig) -> None:
58 | cfg = OmegaConf.create(cfg_all_not_dry)
59 | cfg.nn.module._target_ = "tests.test_checkpoint.ModuleWithCustomCheckpoint"
60 | output_path = Path(run(cfg))
61 |
62 | checkpoint = load_checkpoint(output_path)
63 |
64 | _check_cfg_in_checkpoint(checkpoint, cfg)
65 | _check_run_path_in_checkpoint(checkpoint)
66 |
67 | assert "test_key" in checkpoint
68 | assert checkpoint["test_key"] == "test_value"
69 |
--------------------------------------------------------------------------------
/{{ cookiecutter.repository_name }}/tests/test_configuration.py:
--------------------------------------------------------------------------------
1 | import hydra
2 | from omegaconf import DictConfig
3 |
4 | from {{ cookiecutter.package_name }}.run import build_callbacks
5 |
6 |
7 | def test_configuration_parsing(cfg: DictConfig) -> None:
8 | assert cfg is not None
9 |
10 |
11 | def test_callbacks_instantiation(cfg: DictConfig) -> None:
12 | build_callbacks(cfg.train.callbacks)
13 |
14 |
15 | def test_model_instantiation(cfg: DictConfig) -> None:
16 | datamodule = hydra.utils.instantiate(cfg.nn.data, _recursive_=False)
17 | hydra.utils.instantiate(cfg.nn.module, metadata=datamodule.metadata, _recursive_=False)
18 |
19 |
20 | def test_cfg_parametrization(cfg_all: DictConfig):
21 | assert cfg_all
22 |
--------------------------------------------------------------------------------
/{{ cookiecutter.repository_name }}/tests/test_nn_core_integration.py:
--------------------------------------------------------------------------------
1 | from nn_core.common import PROJECT_ROOT
2 |
3 |
4 | def test_project_root() -> None:
5 | assert PROJECT_ROOT
6 | assert (PROJECT_ROOT / "conf").exists()
7 |
--------------------------------------------------------------------------------
/{{ cookiecutter.repository_name }}/tests/test_resume.py:
--------------------------------------------------------------------------------
1 | from omegaconf import DictConfig, OmegaConf
2 | from pytest import TempPathFactory
3 |
4 | from nn_core.serialization import NNCheckpointIO
5 | from tests.conftest import TRAIN_MAX_NSTEPS, get_checkpoint_path, load_checkpoint
6 |
7 | from {{ cookiecutter.package_name }}.run import run
8 |
9 |
10 | def test_resume(run_trainings_not_dry: str, cfg_all_not_dry: DictConfig, tmp_path_factory: TempPathFactory) -> None:
11 | old_checkpoint_path = get_checkpoint_path(run_trainings_not_dry)
12 |
13 | new_cfg = OmegaConf.create(cfg_all_not_dry)
14 | new_storage_dir = tmp_path_factory.mktemp("resumed_training")
15 |
16 | new_cfg.core.storage_dir = str(new_storage_dir)
17 | new_cfg.train.trainer.max_steps = 2 * TRAIN_MAX_NSTEPS
18 |
19 | new_cfg.train.restore.ckpt_or_run_path = str(old_checkpoint_path)
20 | new_cfg.train.restore.mode = "hotstart"
21 |
22 | new_training_dir = run(new_cfg)
23 |
24 | old_checkpoint = NNCheckpointIO.load(path=old_checkpoint_path)
25 | new_checkpoint = load_checkpoint(new_training_dir)
26 |
27 | assert old_checkpoint["run_path"] != new_checkpoint["run_path"]
28 | assert old_checkpoint["global_step"] * 2 == new_checkpoint["global_step"]
29 |
--------------------------------------------------------------------------------
/{{ cookiecutter.repository_name }}/tests/test_seeding.py:
--------------------------------------------------------------------------------
1 | import pytest
2 | from omegaconf import DictConfig, OmegaConf
3 |
4 | from nn_core.common.utils import seed_index_everything
5 |
6 |
7 | @pytest.mark.parametrize(
8 | "seed_index, expected_seed",
9 | [
10 | (0, 1608637542),
11 | (30, 787716372),
12 | ],
13 | )
14 | def test_seed_index_determinism(cfg_all: DictConfig, seed_index: int, expected_seed: int):
15 | cfg_all = OmegaConf.create(cfg_all)
16 |
17 | cfg_all.train.seed_index = seed_index
18 | current_seed = seed_index_everything(train_cfg=cfg_all.train, sampling_seed=42)
19 | assert current_seed == expected_seed
20 |
--------------------------------------------------------------------------------
/{{ cookiecutter.repository_name }}/tests/test_storage.py:
--------------------------------------------------------------------------------
1 | from pathlib import Path
2 |
3 | import yaml
4 | from omegaconf import DictConfig
5 |
6 |
7 | def test_storage_config(run_trainings_not_dry: str, cfg_all_not_dry: DictConfig) -> None:
8 | cfg_path = Path(run_trainings_not_dry) / "config.yaml"
9 |
10 | assert cfg_path.exists()
11 |
12 | with cfg_path.open() as f:
13 | loaded_cfg = yaml.safe_load(f)
14 | assert loaded_cfg == cfg_all_not_dry
15 |
16 |
17 | def test_storage_checkpoint(run_trainings_not_dry: str, cfg_all_not_dry: DictConfig) -> None:
18 | cktps_path = Path(run_trainings_not_dry) / "checkpoints"
19 |
20 | assert cktps_path.exists()
21 |
22 | checkpoints = list(cktps_path.glob("*"))
23 | assert len(checkpoints) == 1
24 |
--------------------------------------------------------------------------------
/{{ cookiecutter.repository_name }}/tests/test_training.py:
--------------------------------------------------------------------------------
1 | def test_train_loop(run_trainings: str) -> None:
2 | assert run_trainings
3 |
--------------------------------------------------------------------------------