├── .github └── workflows │ ├── codeql-analysis.yml │ ├── coverage.yml │ └── docs.yml ├── .gitignore ├── .pre-commit-config.yaml ├── CITATION.cff ├── LICENSE ├── MANIFEST.ini ├── README.md ├── app.py ├── checkpoints ├── production │ └── README.md └── trials │ └── README.md ├── dash ├── history │ ├── app.py │ ├── assets │ │ └── styles.css │ ├── components.py │ └── utilities.py ├── model │ ├── app.py │ ├── assets │ │ └── styles.css │ ├── components.py │ └── utilities.py └── training │ ├── app.py │ ├── assets │ └── styles.css │ ├── components.py │ └── utilities.py ├── data ├── README.md └── predictions │ └── predictions.pt ├── docs ├── .authors.yml ├── .meta.yml ├── index.md └── reference │ ├── datamodule.md │ └── module.md ├── helpers.txt ├── mkdocs.yml ├── pyproject.toml ├── requirements.txt ├── requirements ├── base.txt ├── cli.txt ├── dev.txt ├── docs.txt ├── frontends.txt └── packaging.txt ├── setup.cfg ├── setup.py ├── src └── visionlab │ ├── __init__.py │ ├── cli.py │ ├── config.py │ ├── datamodule.py │ └── module.py └── tests ├── __init__.py ├── test_datamodule.py └── test_module.py /.github/workflows/codeql-analysis.yml: -------------------------------------------------------------------------------- 1 | # For most projects, this workflow file will not need changing; you simply need 2 | # to commit it to your repository. 3 | # 4 | # You may wish to alter this file to override the set of languages analyzed, 5 | # or to provide custom queries or build logic. 6 | # 7 | # ******** NOTE ******** 8 | # We have attempted to detect the languages in your repository. Please check 9 | # the `language` matrix defined below to confirm you have the correct set of 10 | # supported CodeQL languages. 
11 | # 12 | name: "CodeQL" 13 | 14 | on: 15 | push: 16 | branches: [ main ] 17 | pull_request: 18 | # The branches below must be a subset of the branches above 19 | branches: [ main ] 20 | schedule: 21 | - cron: '24 7 * * 1' 22 | 23 | jobs: 24 | analyze: 25 | name: Analyze 26 | runs-on: ubuntu-latest 27 | permissions: 28 | actions: read 29 | contents: read 30 | security-events: write 31 | 32 | strategy: 33 | fail-fast: false 34 | matrix: 35 | language: [ 'python' ] 36 | # CodeQL supports [ 'cpp', 'csharp', 'go', 'java', 'javascript', 'python', 'ruby' ] 37 | # Learn more about CodeQL language support at https://git.io/codeql-language-support 38 | 39 | steps: 40 | - name: Checkout repository 41 | uses: actions/checkout@v3 42 | 43 | # Initializes the CodeQL tools for scanning. 44 | - name: Initialize CodeQL 45 | uses: github/codeql-action/init@v2 46 | with: 47 | languages: ${{ matrix.language }} 48 | # If you wish to specify custom queries, you can do so here or in a config file. 49 | # By default, queries listed here will override any specified in a config file. 50 | # Prefix the list here with "+" to use these queries and those in the config file. 51 | # queries: ./path/to/local/query, your-org/your-repo/queries@main 52 | 53 | # Autobuild attempts to build any compiled languages (C/C++, C#, or Java). 54 | # If this step fails, then you should remove it and run the build manually (see below) 55 | - name: Autobuild 56 | uses: github/codeql-action/autobuild@v2 57 | 58 | # ℹ️ Command-line programs to run using the OS shell. 
59 | # 📚 https://git.io/JvXDl 60 | 61 | # ✏️ If the Autobuild fails above, remove it and uncomment the following three lines 62 | # and modify them (or add more) to build your code if your project 63 | # uses a compiled language 64 | 65 | #- run: | 66 | # make bootstrap 67 | # make release 68 | 69 | - name: Perform CodeQL Analysis 70 | uses: github/codeql-action/analyze@v2 71 | -------------------------------------------------------------------------------- /.github/workflows/coverage.yml: -------------------------------------------------------------------------------- 1 | name: Codecov 2 | on: [push, pull_request] 3 | jobs: 4 | run: 5 | runs-on: ubuntu-latest 6 | steps: 7 | - name: Checkout 8 | uses: actions/checkout@v4 9 | - name: Set up Python 3.11 10 | uses: actions/setup-python@v4 11 | with: 12 | python-version: 3.11 13 | - name: Install dependencies 14 | run: pip install '.[dev]' 15 | - name: Run tests and collect coverage 16 | run: | 17 | coverage run -m pytest 18 | coverage xml 19 | - name: Upload coverage to Codecov 20 | uses: codecov/codecov-action@v3 21 | 22 | -------------------------------------------------------------------------------- /.github/workflows/docs.yml: -------------------------------------------------------------------------------- 1 | name: MkDocs 2 | on: 3 | push: 4 | branches: 5 | - master 6 | - main 7 | permissions: 8 | contents: write 9 | jobs: 10 | deploy: 11 | runs-on: ubuntu-latest 12 | steps: 13 | - uses: actions/checkout@v4 14 | - uses: actions/setup-python@v4 15 | with: 16 | python-version: 3.x 17 | - run: echo "cache_id=$(date --utc '+%V')" >> $GITHUB_ENV 18 | - uses: actions/cache@v3 19 | with: 20 | key: mkdocs-material-${{ env.cache_id }} 21 | path: .cache 22 | restore-keys: | 23 | mkdocs-material- 24 | - run: pip install ".[docs]" 25 | - run: mkdocs build 26 | - run: cp README.md docs/index.md 27 | - run: mkdocs gh-deploy --force -------------------------------------------------------------------------------- /.gitignore: 
-------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/lightning* 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | .yarn* 30 | .pnp* 31 | node_modules 32 | .storage 33 | 34 | # PyInstaller 35 | # Usually these files are written by a python script from a template 36 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 37 | *.manifest 38 | *.spec 39 | 40 | # Installer logs 41 | pip-log.txt 42 | pip-delete-this-directory.txt 43 | 44 | # Unit test / coverage reports 45 | htmlcov/ 46 | .tox/ 47 | .nox/ 48 | .coverage 49 | .coverage.* 50 | .cache 51 | nosetests.xml 52 | coverage.xml 53 | *.cover 54 | *.py,cover 55 | .hypothesis/ 56 | .pytest_cache/ 57 | 58 | # Translations 59 | *.mo 60 | *.pot 61 | 62 | # Django stuff: 63 | *.log 64 | local_settings.py 65 | db.sqlite3 66 | db.sqlite3-journal 67 | 68 | # Flask stuff: 69 | instance/ 70 | .webassets-cache 71 | 72 | # Scrapy stuff: 73 | .scrapy 74 | 75 | # Sphinx documentation 76 | docs/_build/ 77 | 78 | # PyBuilder 79 | target/ 80 | 81 | # Jupyter Notebook 82 | .ipynb_checkpoints 83 | 84 | # IPython 85 | profile_default/ 86 | ipython_config.py 87 | 88 | # pyenv 89 | .python-version 90 | 91 | # pipenv 92 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 93 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 94 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 95 | # install all needed dependencies. 96 | #Pipfile.lock 97 | 98 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 99 | __pypackages__/ 100 | 101 | # Celery stuff 102 | celerybeat-schedule 103 | celerybeat.pid 104 | 105 | # SageMath parsed files 106 | *.sage.py 107 | 108 | # Environments 109 | .env 110 | .venv 111 | env/ 112 | venv/ 113 | ENV/ 114 | env.bak/ 115 | venv.bak/ 116 | .env*.local 117 | 118 | # Spyder project settings 119 | .spyderproject 120 | .spyproject 121 | 122 | # Rope project settings 123 | .ropeproject 124 | 125 | # mkdocs documentation 126 | /site 127 | 128 | # mypy 129 | .mypy_cache/ 130 | .dmypy.json 131 | dmypy.json 132 | 133 | # Pyre type checker 134 | .pyre/ 135 | 136 | # DS store 137 | .DS_Store 138 | 139 | # VS Code 140 | *.code-workspace 141 | .vscode 142 | 143 | # PyCharm 144 | .idea 145 | 146 | # TypeScript and JavaScript 147 | # dependencies 148 | /node_modules 149 | /.pnp 150 | .pnp.js 151 | 152 | # testing 153 | /coverage 154 | 155 | # next.js 156 | .next/ 157 | out/ 158 | 159 | # production 160 | build 161 | 162 | # misc 163 | .DS_Store 164 | *.pem 165 | 166 | # debug 167 | npm-debug.log* 168 | yarn-debug.log* 169 | yarn-error.log* 170 | .pnpm-debug.log* 171 | 172 | # vercel 173 | .vercel 174 | 175 | # typescript 176 | *.tsbuildinfo 177 | next-env.d.ts 178 | 179 | # turbo 180 | .turbo 181 | 182 | # ruff 183 | .ruff_cache/ 184 | 185 | # PROJECT 186 | # profiler, logger logs, model checkpoints, production models, data 187 | models/checkpoints/ 188 | models/onnx/ 189 | logs/csv/ 190 | data/cache/ 191 | data/training_split/ 192 | logs/csv/tensorboard -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | default_language_version: 2 | python: python3 3 | 4 | repos: 5 | - repo: https://github.com/pre-commit/pre-commit-hooks 6 | rev: v5.0.0 7 | hooks: 8 | - id: trailing-whitespace 9 | - id: end-of-file-fixer 10 | - id: check-yaml 11 | - id: check-added-large-files 12 | 13 | - repo: 
https://github.com/PyCQA/isort 14 | rev: 5.13.2 15 | hooks: 16 | - id: isort 17 | name: Format imports 18 | 19 | - repo: https://github.com/psf/black 20 | rev: 22.3.0 21 | hooks: 22 | - id: black 23 | name: Format code 24 | 25 | - repo: https://github.com/adamchainz/blacken-docs 26 | rev: v1.12.1 27 | hooks: 28 | - id: blacken-docs 29 | args: [--line-length=120] 30 | additional_dependencies: [black==21.12b0] 31 | 32 | - repo: https://github.com/PyCQA/flake8 33 | rev: 7.1.0 34 | hooks: 35 | - id: flake8 36 | args: ["--max-line-length=120", "--ignore=F401, W503"] 37 | name: Check PEP8 38 | -------------------------------------------------------------------------------- /CITATION.cff: -------------------------------------------------------------------------------- 1 | cff-version: 1.2.0 2 | message: "If you use this software, please cite it as below." 3 | authors: 4 | - family-names: Goheen 5 | given-names: Justin 6 | - name: "Justin R. Goheen" 7 | title: "Lightning Lab Example" 8 | version: 0.0.5 9 | date-released: 2022-08-06 10 | license: "Apache-2.0" 11 | repository-code: "https://github.com/JustinGoheen/lightning-vision-example" 12 | keywords: 13 | - machine learning 14 | - deep learning 15 | - artificial intelligence 16 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 
14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 
47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. 
Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 
122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. 
In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. 
We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright 2022 Justin R. Goheen 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /MANIFEST.ini: -------------------------------------------------------------------------------- 1 | include *.cff 2 | exclude *.toml 3 | exclude requirements.txt 4 | exclude __pycache__ 5 | exclude *.lightningai 6 | exclude *.devcontainer 7 | exclude *.github 8 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Vision Lab 2 | 3 | 16 | 17 | ## Overview 18 | 19 | Vision Lab is a public template for computer vision deep learning research projects using [TorchVision](https://pytorch.org/vision/stable/index.html) and Lightning AI's [PyTorch Lightning](https://lightning.ai/docs/pytorch/latest/). 20 | 21 | Use Vision Lab to train or finetune the default torchvision Vision Transformer or make it your own by implementing a new model and dataset after cloning the repo. 22 | 23 | You can fork Vision Lab with the [use this template](https://github.com/new?template_name=vision-lab&template_owner=JustinGoheen) button. 
24 | 25 | > [!NOTE] 26 | > Vision Lab was featured by Weights and Biases in [this community spotlight](https://wandb.ai/wandb_fc/repo-spotlight/reports/Community-Spotlight-Lightning-Pod-Series--Vmlldzo0MDI2OTc0) 27 | 28 | ## Source Module 29 | 30 | `visionlab.core` contains code for the Lightning Module and Trainer. 31 | 32 | `visionlab.components` contains experiment utilities grouped by purpose for cohesion. 33 | 34 | `visionlab.pipeline` contains code for data acquistion and preprocessing, and building a TorchDataset and LightningDataModule. 35 | 36 | `visionlab.serve` contains code for model serving APIs built with [FastAPI](https://fastapi.tiangolo.com/project-generation/#machine-learning-models-with-spacy-and-fastapi). 37 | 38 | `visionlab.cli` contains code for the command line interface built with [Typer](https://typer.tiangolo.com/) and [Rich](https://rich.readthedocs.io/en/stable/). 39 | 40 | `visionlab.pages` contains code for data apps built with [Streamlit](https://streamlit.io/). 41 | 42 | `visionlab.config` assists with project, trainer, and sweep configurations. 43 | 44 | ## Base Requirements and Extras 45 | 46 | Vision Lab installs minimal requirements out of the box, and provides extras to make creating robust virtual environments easier. To view the requirements, in [setup.cfg](setup.cfg), see `install_requires` for the base requirements and `options.extras_require` for the available extras. 47 | 48 | The recommended install is as follows: 49 | 50 | ```sh 51 | python3 -m venv .venv 52 | source .venv/bin/activate 53 | pip install -e ".[all]" 54 | ``` 55 | 56 | ## Using Vision Lab 57 | 58 | Vision Lab also enables use of a CLI named `lab` that is built with [Typer](https://typer.tiangolo.com). This CLI is available in the terminal after install. 
`lab`'s features can be viewed with: 59 | 60 | ```sh 61 | lab --help 62 | ``` 63 | 64 | A [fast dev run](https://lightning.ai/docs/pytorch/latest/common/trainer.html#fast-dev-run) can be run with: 65 | 66 | ```sh 67 | lab run dev 68 | ``` 69 | 70 | A longer demo run can be initiated with: 71 | 72 | ```sh 73 | lab run demo 74 | ``` 75 | 76 | -------------------------------------------------------------------------------- /app.py: -------------------------------------------------------------------------------- 1 | # Copyright Justin R. Goheen. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | import os 16 | 17 | from lightning import LightningApp 18 | 19 | from visionlab import config 20 | from visionlab.components import TrainerFlow 21 | 22 | os.environ["WANDB_CONFIG_DIR"] = config.ExperimentManager.WANDB_CONFIG_DIR 23 | 24 | # TODO give a really verbose example of payload 25 | sweep_payload = dict( 26 | project_name="visionlab", # the wandb project name 27 | trial_count=2, # low trial count for proof of concept (POC) 28 | machine="default", # "gpu-rtx" if is_cloud_run else "default" 29 | idle_timeout=60, # time in seconds; wandb needs time to finish logging sweep 30 | interruptible=False, # set to True for spot instances.
False because not supported yet 31 | trainer_init_flags=config.Sweep.fast_trainer_flags, # sets low max epochs for POC 32 | wandb_save_dir=config.Paths.wandb_logs, # where wandb will push logs to locally 33 | model_kwargs=config.Module.model_kwargs, # args required by ViT 34 | ) 35 | 36 | # TODO give a really verbose example of payload 37 | trainer_payload = dict( 38 | tune=True, # let trainer know to expect a tuned config payload 39 | machine="default", # "gpu-rtx" if is_cloud_run else "default" 40 | idle_timeout=30, # time in seconds; give wandb time to finish 41 | interruptible=False, # set to True for spot instances. False because not supported yet 42 | trainer_flags=config.Trainer.fast_flags, # sets low max epochs for POC 43 | model_kwargs=config.Module.model_kwargs, # args required by ViT 44 | ) 45 | 46 | 47 | app = LightningApp(TrainerFlow(sweep_payload=sweep_payload, trainer_payload=trainer_payload)) 48 | -------------------------------------------------------------------------------- /checkpoints/production/README.md: -------------------------------------------------------------------------------- 1 | directory for fully trained checkpoints -------------------------------------------------------------------------------- /checkpoints/trials/README.md: -------------------------------------------------------------------------------- 1 | directory for experiment checkpoints -------------------------------------------------------------------------------- /dash/history/app.py: -------------------------------------------------------------------------------- 1 | import os 2 | from pathlib import Path 3 | 4 | import dash 5 | import dash_bootstrap_components as dbc 6 | from dash import html 7 | from dash.dependencies import Input, Output 8 | from components import NavBar, Body 9 | 10 | from utilities import create_figure 11 | 12 | 13 | this_file = Path(__file__) 14 | this_studio_idx = [i for i, j in enumerate(this_file.parents) if j.name.endswith("this_studio")][0] 
this_studio = this_file.parents[this_studio_idx]
csvlogs = os.path.join(this_studio, "vision-lab", "logs", "csv")

app = dash.Dash(__name__, external_stylesheets=[dbc.themes.BOOTSTRAP])
app.layout = html.Div([NavBar, html.Br(), Body])


@app.callback(
    Output("metric-graph", "figure"),
    [Input("dropdown", "value")],
)
def update_figure(label_value):
    """Rebuild the run-metrics figure when a run version is selected.

    Args:
        label_value: The run directory name chosen in the ``dropdown``
            component (a subdirectory of the CSV logs directory).

    Returns:
        A plotly figure produced by ``create_figure`` from that run's
        ``metrics.csv``.
    """
    # Removed leftover debug print of the dropdown value.
    return create_figure(os.path.join(csvlogs, label_value, "metrics.csv"))


# Guard the server start so importing this module (e.g. from tests or
# tooling) does not launch a web server as a side effect.
if __name__ == "__main__":
    app.run_server(port=8000)
*/ 14 | 15 | .pretty-container { 16 | border-radius: 8px; 17 | background-color: #f9f9f9; 18 | margin: 5px; 19 | padding: 10px; 20 | position: relative; 21 | box-shadow: 2px 2px 2px lightgrey; 22 | } 23 | 24 | .metric-container { 25 | border-radius: 8px; 26 | background-color: #f9f9f9; 27 | margin: 10px; 28 | margin-top: 0%; 29 | padding: 10px; 30 | position: relative; 31 | box-shadow: 2px 2px 2px lightgrey; 32 | } 33 | 34 | .metric-card-text { 35 | margin: 0px; 36 | padding: 0px; 37 | font-family: FreightSans, Helvetica Neue, Helvetica, Arial, sans-serif; 38 | color: darkslategray 39 | } 40 | 41 | .model-card-container { 42 | border-radius: 8px; 43 | background-color: #f9f9f9; 44 | margin: 0px; 45 | padding: 0px; 46 | position: relative; 47 | box-shadow: 2px 2px 2px lightgrey; 48 | } 49 | 50 | .model-card-text { 51 | margin: 0px; 52 | padding: 1px; 53 | font-family: FreightSans, Helvetica Neue, Helvetica, Arial, sans-serif; 54 | color: darkslategray; 55 | } 56 | 57 | .card-title { 58 | margin: 0px; 59 | padding: 0px; 60 | font-family: Ucityweb, sans-serif; 61 | font-weight: normal; 62 | } 63 | 64 | .app-title { 65 | font-family: Montserrat, sans-serif 66 | } 67 | 68 | #left-fig .modebar { 69 | margin-top: 10px; 70 | } 71 | 72 | #right-fig .modebar { 73 | margin-top: 10px; 74 | } -------------------------------------------------------------------------------- /dash/history/components.py: -------------------------------------------------------------------------------- 1 | import os 2 | from pathlib import Path 3 | 4 | import dash_bootstrap_components as dbc 5 | from dash import dcc, html 6 | 7 | from utilities import create_figure 8 | 9 | this_file = Path(__file__) 10 | this_studio_idx = [i for i, j in enumerate(this_file.parents) if j.name.endswith("this_studio")][0] 11 | this_studio = this_file.parents[this_studio_idx] 12 | csvlogs = os.path.join(this_studio, "vision-lab", "logs", "csv") 13 | 14 | RUNS = os.listdir(csvlogs) 15 | 16 | NavBar = dbc.NavbarSimple( 17 
| brand="VisionTransformer Base 32 Run Metrics", 18 | color="#792ee5", 19 | dark=True, 20 | fluid=True, 21 | className="app-title", 22 | ) 23 | 24 | Control = dbc.Card( 25 | dbc.CardBody( 26 | [ 27 | html.H1("Run Version", className="card-title"), 28 | dcc.Dropdown( 29 | options=RUNS, 30 | value=RUNS[0], 31 | multi=False, 32 | id="dropdown", 33 | searchable=True, 34 | ), 35 | ] 36 | ), 37 | className="model-card-container", 38 | ) 39 | 40 | SideBar = dbc.Col([Control], width=3) 41 | 42 | Graph = dbc.Col( 43 | dcc.Loading( 44 | [ 45 | dcc.Graph( 46 | id="metric-graph", 47 | figure=create_figure(os.path.join(csvlogs, RUNS[0], "metrics.csv")), 48 | config={ 49 | "responsive": True, 50 | "displayModeBar": True, 51 | "displaylogo": False, 52 | }, 53 | ), 54 | dcc.Interval(id="interval-component", interval=1 * 1000, n_intervals=0), # in milliseconds 55 | ] 56 | ) 57 | ) 58 | 59 | Body = dbc.Container(dbc.Row([SideBar, Graph]), fluid=True) 60 | -------------------------------------------------------------------------------- /dash/history/utilities.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | from typing import Union 3 | 4 | import pandas as pd 5 | import plotly.graph_objects as go 6 | 7 | 8 | def create_figure(path: Union[str, Path]): 9 | if isinstance(path, str): 10 | run_name = path.split("/")[-2] 11 | else: 12 | run_name = path.parent.name 13 | run_name = " ".join(run_name.split("_")).title() 14 | 15 | data = pd.read_csv(path).drop("step", axis=1) 16 | fig = go.Figure() 17 | fig.add_trace(go.Scatter(x=data.index, y=data["training-loss"])) 18 | fig.update_layout( 19 | title=dict( 20 | text=f"Run Metrics: {run_name}", 21 | font_family="Ucityweb, sans-serif", 22 | font=dict(size=24), 23 | y=0.90, 24 | yanchor="bottom", 25 | x=0.5, 26 | ) 27 | ) 28 | return fig 29 | -------------------------------------------------------------------------------- /dash/model/app.py: 
# Copyright Justin R. Goheen.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Dash app showing a ground-truth image and the model's predicted class."""

import dash
import dash_bootstrap_components as dbc
import torch
from dash import html
from dash.dependencies import Input, Output
from components import Body, create_figure, find_index, NavBar

from visionlab import config

# Loaded once at startup; the callback below reads these as module state.
PREDICTIONS = torch.load(config.Paths.predictions)
DATASET = torch.load(config.Paths.test_split)
LABELNAMES = DATASET.classes


app = dash.Dash(__name__, external_stylesheets=[dbc.themes.BOOTSTRAP])
app.layout = html.Div(
    [
        NavBar,
        html.Br(),
        Body,
    ]
)


@app.callback(
    [Output("gt-fig", "figure"), Output("pred-card", "children")],
    [Input("dropdown", "value")],
)
def update_figure(label_value):
    """Refresh the ground-truth image and predicted label for the chosen class.

    Args:
        label_value: Class name selected in the dropdown.

    Returns:
        Tuple of (Plotly figure of the sample image, predicted class name).
    """
    xidx = 0  # position of the image tensor within a dataset sample
    labelidx = 1  # position of the label within a dataset sample
    idx = find_index(DATASET, label=LABELNAMES.index(label_value), label_idx=labelidx)
    gt = DATASET[idx][xidx]
    # NOTE(review): components.py indexes PREDICTIONS[idx][0] for the same
    # lookup — confirm which inner index is correct for the predictions tensor.
    pred = LABELNAMES[torch.argmax(PREDICTIONS[idx][labelidx])]
    fig = create_figure(gt, "Ground Truth")
    return fig, pred


if __name__ == "__main__":
    # Guard so importing this module (e.g. from tests or tooling) does not
    # start the development server as a side effect.
    app.run_server(port=8000)
/* # Copyright Justin R. Goheen. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. */ 14 | 15 | .pretty-container { 16 | border-radius: 8px; 17 | background-color: #f9f9f9; 18 | margin: 5px; 19 | padding: 10px; 20 | position: relative; 21 | box-shadow: 2px 2px 2px lightgrey; 22 | } 23 | 24 | .metric-container { 25 | border-radius: 8px; 26 | background-color: #f9f9f9; 27 | margin: 10px; 28 | margin-top: 0%; 29 | padding: 10px; 30 | position: relative; 31 | box-shadow: 2px 2px 2px lightgrey; 32 | } 33 | 34 | .metric-card-text { 35 | margin: 0px; 36 | padding: 0px; 37 | font-family: FreightSans, Helvetica Neue, Helvetica, Arial, sans-serif; 38 | color: darkslategray 39 | } 40 | 41 | .model-card-container { 42 | border-radius: 8px; 43 | background-color: #f9f9f9; 44 | margin: 0px; 45 | padding: 0px; 46 | position: relative; 47 | box-shadow: 2px 2px 2px lightgrey; 48 | } 49 | 50 | .model-card-text { 51 | margin: 0px; 52 | padding: 1px; 53 | font-family: FreightSans, Helvetica Neue, Helvetica, Arial, sans-serif; 54 | color: darkslategray; 55 | } 56 | 57 | .card-title { 58 | margin: 0px; 59 | padding: 0px; 60 | font-family: Ucityweb, sans-serif; 61 | font-weight: normal; 62 | } 63 | 64 | .app-title { 65 | font-family: Montserrat, sans-serif 66 | } 67 | 68 | #left-fig .modebar { 69 | margin-top: 10px; 70 | } 71 | 72 | #right-fig .modebar { 73 | margin-top: 10px; 74 | } 
-------------------------------------------------------------------------------- /dash/model/components.py: -------------------------------------------------------------------------------- 1 | # Copyright Justin R. Goheen. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | 16 | import dash_bootstrap_components as dbc 17 | import torch 18 | from dash import dcc, html 19 | from utilities import create_figure, find_index, make_metrics_summary, make_model_summary 20 | 21 | from visionlab import config 22 | 23 | PREDICTIONS = torch.load(config.Paths.predictions) 24 | DATASET = torch.load(config.Paths.test_split) 25 | LABELNAMES = DATASET.classes 26 | LABELS = list(range(len(LABELNAMES))) 27 | LABELIDX = find_index(DATASET, label=LABELS[0], label_idx=1) 28 | 29 | 30 | # MODEL SUMMARY 31 | model_summary = make_model_summary() 32 | metrics = make_metrics_summary() 33 | metrics_names = list(metrics.keys()) 34 | metrics_values = [round(i, 4) for i in list(metrics.values())] 35 | 36 | # APP LAYOUT 37 | NavBar = dbc.NavbarSimple( 38 | brand="VisionTransformer Base 32", 39 | color="#792ee5", 40 | dark=True, 41 | fluid=True, 42 | className="app-title", 43 | ) 44 | 45 | Control = dbc.Card( 46 | dbc.CardBody( 47 | [ 48 | html.H1("Label", className="card-title"), 49 | dcc.Dropdown( 50 | options=LABELNAMES, 51 | value=LABELNAMES[0], 52 | multi=False, 53 | id="dropdown", 54 | searchable=True, 55 | ), 56 | ] 57 | ), 58 | 
className="model-card-container", 59 | ) 60 | 61 | ModelCard = dbc.Card( 62 | [ 63 | dbc.CardBody( 64 | [ 65 | html.H1("Model Card", id="model_card", className="card-title"), 66 | html.Br(), 67 | html.H3("Layers", className="card-title"), 68 | model_summary["layers"], 69 | html.Br(), 70 | html.H3("Parameters", className="card-title"), 71 | html.P( 72 | f"{model_summary['params'][0]}", 73 | id="model_info_1", 74 | className="model-card-text", 75 | ), 76 | html.P( 77 | f"{model_summary['params'][1]}", 78 | id="model_info_2", 79 | className="model-card-text", 80 | ), 81 | html.P( 82 | f"{model_summary['params'][2]}", 83 | id="model_info_3", 84 | className="model-card-text", 85 | ), 86 | html.P( 87 | f"{model_summary['params'][3]}", 88 | id="model_info_4", 89 | className="model-card-text", 90 | ), 91 | ] 92 | ), 93 | ], 94 | className="model-card-container", 95 | ) 96 | 97 | SideBar = dbc.Col( 98 | [ 99 | Control, 100 | html.Br(), 101 | ModelCard, 102 | ], 103 | width=3, 104 | ) 105 | 106 | GroundTruth = dcc.Loading( 107 | [ 108 | dcc.Graph( 109 | id="gt-fig", 110 | figure=create_figure(DATASET[LABELIDX][0], "Ground Truth"), 111 | config={ 112 | "responsive": True, 113 | "displayModeBar": True, 114 | "displaylogo": False, 115 | }, 116 | ) 117 | ] 118 | ) 119 | 120 | Prediction = dbc.Card( 121 | [ 122 | dbc.CardHeader("Predicted Class"), 123 | dbc.CardBody( 124 | [ 125 | html.H3( 126 | LABELNAMES[torch.argmax(PREDICTIONS[LABELIDX][0])], 127 | id="pred-card", 128 | ), 129 | ] 130 | ), 131 | ], 132 | ) 133 | 134 | Metrics = dbc.Row( 135 | [ 136 | dbc.Col( 137 | [ 138 | dbc.Card( 139 | [ 140 | html.H4(metrics_names[0], className="card-title"), 141 | html.P(metrics_values[0], id="metric_1_text", className="metric-card-text"), 142 | ], 143 | id="metric_1_card", 144 | className="metric-container", 145 | ) 146 | ], 147 | width=3, 148 | ), 149 | dbc.Col( 150 | [ 151 | dbc.Card( 152 | [ 153 | html.H4(metrics_names[1], className="card-title"), 154 | html.P(metrics_values[1], 
id="metric_2_text", className="metric-card-text"), 155 | ], 156 | id="metric_2_card", 157 | className="metric-container", 158 | ) 159 | ], 160 | width=3, 161 | ), 162 | dbc.Col( 163 | [ 164 | dbc.Card( 165 | [ 166 | html.H4(metrics_names[2], className="card-title"), 167 | html.P(metrics_values[2], id="metric_3_text", className="metric-card-text"), 168 | ], 169 | id="metric_3_card", 170 | className="metric-container", 171 | ) 172 | ], 173 | width=3, 174 | ), 175 | # dbc.Col( 176 | # [ 177 | # dbc.Card( 178 | # [ 179 | # html.H4(metrics_names[3], className="card-title"), 180 | # html.P(metrics_values[3], id="metric_4_text", className="metric-card-text"), 181 | # ], 182 | # id="metric_4_card", 183 | # className="metric-container", 184 | # ) 185 | # ], 186 | # width=3, 187 | # ), 188 | ], 189 | id="metrics", 190 | justify="center", 191 | ) 192 | 193 | Graphs = dbc.Row( 194 | [ 195 | dbc.Col([GroundTruth], className="pretty-container", width=4), 196 | dbc.Col(width=1), 197 | dbc.Col([Prediction], width=4), 198 | ], 199 | justify="center", 200 | align="middle", 201 | className="graph-row", 202 | ) 203 | 204 | MainArea = dbc.Col([Metrics, html.Br(), Graphs]) 205 | 206 | Body = dbc.Container([dbc.Row([SideBar, MainArea])], fluid=True) 207 | -------------------------------------------------------------------------------- /dash/model/utilities.py: -------------------------------------------------------------------------------- 1 | # Copyright Justin R. Goheen. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
def make_metrics_summary():
    """Read the most recent CSV log and return the latest metric values.

    Assumes CSV-logger run directories are named ``version_0`` ..
    ``version_{n-1}`` with no gaps, so the newest run is
    ``version_{len(logs) - 1}`` — TODO confirm runs are never deleted.

    Returns:
        dict mapping display names ("Training Loss", "Val Loss", "Val Acc")
        to the most recently logged values.
    """
    logsdir = os.path.join("logs", "csv")
    logs = os.listdir(logsdir)
    most_recent = os.path.join(logsdir, f"version_{len(logs) - 1}", "metrics.csv")
    summary = pd.read_csv(most_recent)
    # Training and validation metrics land on different rows, so the final
    # row may hold only a training-loss value (val columns NaN); step back
    # one row in that case. (The original if/elif tested the same condition
    # and its exact negation — collapsed to a single conditional.)
    index = -1 if not pd.isna(summary["val-loss"].iloc[-1]) else -2
    return {
        "Training Loss": summary["training-loss"].iloc[index - 1],
        "Val Loss": summary["val-loss"].iloc[index],
        "Val Acc": summary["val-acc"].iloc[index],
    }


def create_figure(image, title_text):
    """Render a single image tensor as a Plotly figure.

    Args:
        image: Image tensor in CHW layout (converted to HWC for display).
        title_text: Title placed beneath the image.

    Returns:
        A Plotly Express ``imshow`` figure, 300px tall.
    """
    # CHW -> HWC, the layout px.imshow expects
    image = np.transpose(image.numpy(), (1, 2, 0))
    fig = px.imshow(image)
    fig.update_layout(
        title=dict(
            text=title_text,
            font_family="Ucityweb, sans-serif",
            font=dict(size=24),
            y=0.05,
            yanchor="bottom",
            x=0.5,
        ),
        height=300,
    )
    return fig
def make_model_param_text(model_summary: list):
    """Format the parameter-count lines of a Lightning ``ModelSummary``.

    Args:
        model_summary: The summary string split into lines; the final four
            lines hold trainable / non-trainable / total params and the
            estimated parameter size.

    Returns:
        A list of ``"label: value"`` strings, one per summary line, with the
        last entry relabeled "Est. params size (MB)".
    """
    model_params = [line.split(" ") for line in model_summary[-4:]]
    # keep only the leading count and the trailing word of each line
    model_params = [[parts[0]] + [parts[-1]] for parts in model_params]
    model_params = [[token.strip() for token in pair] for pair in model_params]
    model_params = [pair[::-1] for pair in model_params]  # label first, value second
    model_params[-1][0] = "Est. params size (MB)"
    return ["".join([pair[0], ": ", pair[-1]]) for pair in model_params]


def make_model_summary():
    """Load the newest production checkpoint and summarize the model.

    Returns:
        dict with a Dash table of layers ("layers") and formatted
        parameter-count strings ("params").
    """
    available_checkpoints = os.listdir(config.Paths.ckpts)
    available_checkpoints.remove("README.md")
    # NOTE(review): os.listdir order is arbitrary; taking index 0 assumes a
    # single checkpoint lives here — confirm whether several can accumulate.
    latest_checkpoint = available_checkpoints[0]
    chkpt_filename = os.path.join(config.Paths.ckpts, latest_checkpoint)
    model = VisionTransformer.load_from_checkpoint(chkpt_filename)
    summary_lines = str(ModelSummary(model)).split("\n")
    return {
        "layers": make_model_layer_table(summary_lines),
        "params": make_model_param_text(summary_lines),
    }


def find_index(dataset, label, label_idx):
    """Return the index of the first sample whose label matches ``label``.

    Args:
        dataset: Indexable/iterable collection of tuple-like samples.
        label: Label value to search for.
        label_idx: Position of the label within each sample.

    Returns:
        Index of the first matching sample, or ``None`` when no sample
        carries ``label`` (previously the function fell off the end and
        returned ``None`` implicitly).
    """
    for index, sample in enumerate(dataset):
        if sample[label_idx] == label:
            return index
    return None
enumerate(this_file.parents) if j.name.endswith("this_studio")][0] 15 | this_studio = this_file.parents[this_studio_idx] 16 | csvlogs = os.path.join(this_studio, "vision-lab", "logs", "csv") 17 | 18 | runs = os.listdir(csvlogs) 19 | numruns = len(runs) 20 | tgtrun = numruns - 1 21 | 22 | app = dash.Dash(__name__, external_stylesheets=[dbc.themes.BOOTSTRAP]) 23 | app.layout = html.Div([NavBar, html.Br(), Body]) 24 | 25 | 26 | @app.callback( 27 | Output("metric-graph", "figure"), 28 | [Input("interval-component", "n_intervals")], 29 | ) 30 | def update_figure(n_intervals): 31 | fig = create_figure(os.path.join(csvlogs, runs[tgtrun], "metrics.csv")) 32 | return fig 33 | 34 | 35 | app.run_server(port=8000) 36 | -------------------------------------------------------------------------------- /dash/training/assets/styles.css: -------------------------------------------------------------------------------- 1 | /* # Copyright Justin R. Goheen. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
*/ 14 | 15 | .pretty-container { 16 | border-radius: 8px; 17 | background-color: #f9f9f9; 18 | margin: 5px; 19 | padding: 10px; 20 | position: relative; 21 | box-shadow: 2px 2px 2px lightgrey; 22 | } 23 | 24 | .metric-container { 25 | border-radius: 8px; 26 | background-color: #f9f9f9; 27 | margin: 10px; 28 | margin-top: 0%; 29 | padding: 10px; 30 | position: relative; 31 | box-shadow: 2px 2px 2px lightgrey; 32 | } 33 | 34 | .metric-card-text { 35 | margin: 0px; 36 | padding: 0px; 37 | font-family: FreightSans, Helvetica Neue, Helvetica, Arial, sans-serif; 38 | color: darkslategray 39 | } 40 | 41 | .model-card-container { 42 | border-radius: 8px; 43 | background-color: #f9f9f9; 44 | margin: 0px; 45 | padding: 0px; 46 | position: relative; 47 | box-shadow: 2px 2px 2px lightgrey; 48 | } 49 | 50 | .model-card-text { 51 | margin: 0px; 52 | padding: 1px; 53 | font-family: FreightSans, Helvetica Neue, Helvetica, Arial, sans-serif; 54 | color: darkslategray; 55 | } 56 | 57 | .card-title { 58 | margin: 0px; 59 | padding: 0px; 60 | font-family: Ucityweb, sans-serif; 61 | font-weight: normal; 62 | } 63 | 64 | .app-title { 65 | font-family: Montserrat, sans-serif 66 | } 67 | 68 | #left-fig .modebar { 69 | margin-top: 10px; 70 | } 71 | 72 | #right-fig .modebar { 73 | margin-top: 10px; 74 | } -------------------------------------------------------------------------------- /dash/training/components.py: -------------------------------------------------------------------------------- 1 | import os 2 | from pathlib import Path 3 | from time import sleep 4 | 5 | import dash_bootstrap_components as dbc 6 | from dash import dcc, html 7 | 8 | from utilities import create_figure 9 | 10 | this_file = Path(__file__) 11 | this_studio_idx = [i for i, j in enumerate(this_file.parents) if j.name.endswith("this_studio")][0] 12 | this_studio = this_file.parents[this_studio_idx] 13 | csvlogs = os.path.join(this_studio, "vision-lab", "logs", "csv") 14 | 15 | runs = os.listdir(csvlogs) 16 | 
def create_figure(path: Union[str, Path]) -> "go.Figure":
    """Build a line chart of training loss from a CSV metrics log.

    Args:
        path: Location of a ``metrics.csv`` file written by the CSV logger.
            The parent directory name (e.g. ``version_0``) supplies the run
            name shown in the title.

    Returns:
        A Plotly figure of ``training-loss`` against row index, or an empty
        figure when the file does not exist yet (the log appears only after
        the first metrics flush).
    """
    # Normalize once so the run-name logic is identical for str and Path
    # inputs; the previous str branch split on "/" and was not OS-portable.
    path = Path(path)
    # e.g. "version_0" -> "Version 0"
    run_name = " ".join(path.parent.name.split("_")).title()

    if not path.exists():
        return go.Figure()

    data = pd.read_csv(path).drop("step", axis=1)
    fig = go.Figure()
    fig.add_trace(go.Scatter(x=data.index, y=data["training-loss"]))
    fig.update_layout(
        title=dict(
            text=f"Run Metrics: {run_name}",
            font_family="Ucityweb, sans-serif",
            font=dict(size=24),
            y=0.90,
            yanchor="bottom",
            x=0.5,
        )
    )
    return fig
/data/predictions/predictions.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jxtngx/vision-lab/289395593ffe2a6b8128547badabb0bf884198ec/data/predictions/predictions.pt -------------------------------------------------------------------------------- /docs/.authors.yml: -------------------------------------------------------------------------------- 1 | authors: 2 | justingoheen: 3 | name: Justin Goheen # Author name 4 | description: AI Engineer and Advocate # Author description 5 | avatar: https://avatars.githubusercontent.com/u/26209687 # Author avatar 6 | -------------------------------------------------------------------------------- /docs/.meta.yml: -------------------------------------------------------------------------------- 1 | comments: true 2 | hide: 3 | - feedback 4 | tags: 5 | - Lightning Labs 6 | -------------------------------------------------------------------------------- /docs/index.md: -------------------------------------------------------------------------------- 1 | # Vision Lab 2 | 3 | ## Overview 4 | 5 | Vision lab is a public template for computer vision deep learning research projects using Lightning AI's [PyTorch Lightning](https://lightning.ai/docs/pytorch/latest/). 6 | 7 | Use Vision Lab to train or finetune the default torchvision Vision Transformer or make it your own by implementing a new model and dataset. 8 | 9 | The recommended way for Vision lab users to create new repos is with the [use this template](https://docs.github.com/en/repositories/creating-and-managing-repositories/creating-a-repository-from-a-template) button. 10 | 11 | ## Source Module 12 | 13 | `visionlab.core` should contain code for the Lightning Module and Trainer. 14 | 15 | `visionlab.components` should contain experiment utilities grouped by purpose for cohesion. 
16 | 17 | `visionlab.pipeline` should contain code for data acquisition and preprocessing, and building a TorchDataset and LightningDataModule. 18 | 19 | `visionlab.api` should contain code for model serving APIs built with [FastAPI](https://fastapi.tiangolo.com/project-generation/#machine-learning-models-with-spacy-and-fastapi). 20 | 21 | `visionlab.cli` should contain code for the command line interface built with [Typer](https://typer.tiangolo.com/) and [Rich](https://rich.readthedocs.io/en/stable/). 22 | 23 | `visionlab.pages` should contain code for data apps built with Streamlit. 24 | 25 | `visionlab.config` can assist with project, trainer, and sweep configurations. 26 | 27 | ## Base Requirements and Extras 28 | 29 | Vision lab installs minimal requirements out of the box, and provides extras to make creating robust virtual environments easier. To view the requirements, in [setup.cfg](setup.cfg), see `install_requires` for the base requirements and `options.extras_require` for the available extras.
30 | 31 | The recommended install is as follows: 32 | 33 | ```sh 34 | python3 -m venv .venv 35 | source .venv/bin/activate 36 | pip install -e ".[all]" 37 | ``` 38 | -------------------------------------------------------------------------------- /docs/reference/datamodule.md: -------------------------------------------------------------------------------- 1 | ::: visionlab.CifarDataModule 2 | options: 3 | heading_level: 2 4 | show_root_heading: true -------------------------------------------------------------------------------- /docs/reference/module.md: -------------------------------------------------------------------------------- 1 | ::: visionlab.VisionTransformer 2 | options: 3 | heading_level: 2 4 | show_root_heading: true -------------------------------------------------------------------------------- /helpers.txt: -------------------------------------------------------------------------------- 1 | # setting mps fallback 2 | export PYTORCH_ENABLE_MPS_FALLBACK=1 3 | -------------------------------------------------------------------------------- /mkdocs.yml: -------------------------------------------------------------------------------- 1 | site_name: Vision Lab 2 | site_author: Justin Goheen 3 | repo_url: https://github.com/JustinGoheen/vision-lab 4 | repo_name: Justin Goheen 5 | 6 | theme: 7 | name: material 8 | icon: 9 | repo: fontawesome/brands/github 10 | font: 11 | text: Roboto 12 | code: Roboto Mono 13 | features: 14 | - navigation.sections 15 | - navigation.path 16 | - navigation.top 17 | - content.code.copy 18 | palette: 19 | # Palette toggle for automatic mode 20 | - scheme: default 21 | primary: black 22 | toggle: 23 | icon: material/brightness-7 24 | name: Switch to dark mode 25 | - scheme: slate 26 | primary: black 27 | toggle: 28 | icon: material/brightness-4 29 | name: Switch to light mode 30 | 31 | plugins: 32 | - search 33 | - mkdocstrings: 34 | default_handler: python 35 | handlers: 36 | python: 37 | paths: [src] 38 | options: 39 | 
docstring_style: google 40 | show_source: false 41 | line_length: 120 42 | 43 | extra: 44 | social: 45 | - icon: fontawesome/brands/github 46 | link: https://github.com/JustinGoheen 47 | - icon: fontawesome/brands/twitter 48 | link: https://twitter.com/Justin_Goheen 49 | - icon: fontawesome/brands/linkedin 50 | link: https://www.linkedin.com/in/justingoheen/ 51 | - icon: fontawesome/brands/discord 52 | link: https://discord.gg/XncpTy7DSt 53 | 54 | markdown_extensions: 55 | - abbr 56 | - admonition 57 | - attr_list 58 | - def_list 59 | - footnotes 60 | - md_in_html 61 | - pymdownx.superfences: 62 | custom_fences: 63 | - name: mermaid 64 | class: mermaid 65 | format: !!python/name:pymdownx.superfences.fence_code_format 66 | 67 | - pymdownx.highlight: 68 | anchor_linenums: true 69 | line_spans: __span 70 | pygments_lang_class: true 71 | - pymdownx.inlinehilite 72 | - pymdownx.snippets 73 | - pymdownx.superfences 74 | 75 | extra_javascript: 76 | - javascripts/mathjax.js 77 | - https://polyfill.io/v3/polyfill.min.js?features=es6 78 | - https://cdn.jsdelivr.net/npm/mathjax@3/es5/tex-mml-chtml.js 79 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | # Copyright Justin R. Goheen. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | [build-system] 16 | requires = [ 17 | "setuptools", 18 | "wheel", 19 | ] 20 | 21 | [tool.black] 22 | line-length = 120 23 | 24 | [tool.isort] 25 | known_first_party = [ 26 | "visionlab", 27 | ] 28 | profile = "black" 29 | line_length = 120 30 | force_sort_within_sections = "False" 31 | order_by_type = "False" 32 | 33 | [tool.ruff] 34 | line-length = 120 35 | # Enable Pyflakes `E` and `F` codes by default. 36 | select = [ 37 | "E", "W", # see: https://pypi.org/project/pycodestyle 38 | "F", # see: https://pypi.org/project/pyflakes 39 | ] 40 | ignore = [ 41 | "E731", # Do not assign a lambda expression, use a def 42 | ] 43 | # Exclude a variety of commonly ignored directories. 44 | exclude = [ 45 | ".git", 46 | "docs" 47 | ] 48 | ignore-init-module-imports = true 49 | 50 | [tool.ruff.mccabe] 51 | # Unlike Flake8, default to a complexity level of 10. 52 | max-complexity = 10 53 | 54 | [tool.mypy] 55 | files = ["visionlab"] 56 | install_types = true 57 | non_interactive = true 58 | disallow_untyped_defs = true 59 | ignore_missing_imports = true 60 | show_error_codes = true 61 | warn_redundant_casts = true 62 | warn_unused_configs = true 63 | warn_unused_ignores = true 64 | allow_redefinition = true 65 | # disable this rule as the Trainer attributes are defined in the connectors, not in its __init__ 66 | disable_error_code = "attr-defined" 67 | # style choices 68 | warn_no_return = "False" 69 | 70 | # do not add type hints to lightnig_pod/cli/seed/ or core.module 71 | # because lightning already defines types 72 | [[tool.mypy.overrides]] 73 | module = [ 74 | "visionlab.core.module", 75 | "visionlab.cli.seed.core.module", 76 | "visionlab.cli.seed.core.trainer", 77 | "visionlab.cli.seed.pipeline.datamodule", 78 | "visionlab.cli.seed.pipeline.dataset", 79 | ] 80 | ignore_errors = true 81 | -------------------------------------------------------------------------------- /requirements.txt: 
-------------------------------------------------------------------------------- 1 | # this should only be used by CircleCI for installing base and testing requirements 2 | -r requirements/base.txt 3 | -r requirements/dev.txt 4 | -------------------------------------------------------------------------------- /requirements/base.txt: -------------------------------------------------------------------------------- 1 | pytorch-lightning 2 | torch 3 | torchvision 4 | wandb 5 | optuna 6 | torch-tb-profiler -------------------------------------------------------------------------------- /requirements/cli.txt: -------------------------------------------------------------------------------- 1 | click 2 | rich 3 | -------------------------------------------------------------------------------- /requirements/dev.txt: -------------------------------------------------------------------------------- 1 | # FORMMATING, LINTING, TESTING, STATIC TYPE CHECKING, COVERAGE 2 | black 3 | pytest 4 | mypy 5 | bandit 6 | coverage 7 | pre-commit 8 | isort 9 | ruff 10 | -------------------------------------------------------------------------------- /requirements/docs.txt: -------------------------------------------------------------------------------- 1 | # DOC TOOLS, MARKDOWN SUPPORT, THEMES 2 | mkdocs-material 3 | mkdocstrings[python] 4 | mkdocs-glightbox -------------------------------------------------------------------------------- /requirements/frontends.txt: -------------------------------------------------------------------------------- 1 | streamlit 2 | plotly 3 | 4 | -------------------------------------------------------------------------------- /requirements/packaging.txt: -------------------------------------------------------------------------------- 1 | setuptools 2 | build 3 | twine 4 | -------------------------------------------------------------------------------- /setup.cfg: -------------------------------------------------------------------------------- 1 | # Copyright 
Justin R. Goheen. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | [options] 16 | zip_safe = False 17 | packages = find: 18 | install_requires = 19 | pytorch-lightning 20 | torch 21 | torchvision 22 | torch-tb-profiler 23 | wandb 24 | typer 25 | rich 26 | 27 | [options.extras_require] 28 | dev = 29 | ruff 30 | black 31 | pytest 32 | mypy 33 | bandit 34 | coverage 35 | isort 36 | pre-commit 37 | docs = 38 | mkdocs-material 39 | mkdocstrings[python] 40 | mkdocs-glightbox 41 | dev-all = 42 | visionlab[dev, docs] 43 | fe = 44 | plotly 45 | dash 46 | dash-bootstrap-components 47 | all = 48 | visionlab[dev-all, fe] 49 | 50 | 51 | [options.entry_points] 52 | console_scripts = 53 | lab = visionlab.cli:app 54 | 55 | [flake8] 56 | max-line-length = 120 57 | 58 | [tool:pytest] 59 | testpaths = 60 | /tests 61 | norecursedirs = 62 | .git 63 | .github 64 | *.egg-info 65 | addopts = 66 | --disable-pytest-warnings 67 | filterwarnings = 68 | # IGNORE THIRD PARTY LIBRARY WARNINGS 69 | # ignore tensorboard proto warnings 70 | ignore: Call to deprecated* 71 | # ignore torchvision transform warning 72 | ignore: .* and will be removed in Pillow 10 73 | # ignore torch distributed warning 74 | ignore: torch.distributed*. 
75 | # ignore PL UserWarning 76 | ignore: You are trying to `self.log()`* 77 | 78 | [coverage:run] 79 | disable_warnings = ["couldnt-parse"] 80 | 81 | [coverage:report] 82 | ignore_errors = true 83 | exclude_lines = ["pragma: no cover"] 84 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | # Copyright Justin R. Goheen. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
from pathlib import Path

from setuptools import setup

# resolve the repository root relative to this file so builds work from any CWD
rootdir = Path(__file__).parent
# read the long description explicitly as UTF-8: Path.read_text() without an
# encoding uses the platform default (e.g. cp1252 on Windows), which can fail
# on non-ASCII README content
long_description = (rootdir / "README.md").read_text(encoding="utf-8")

setup(
    name="visionlab",
    package_dir={"": "src"},
    packages=["visionlab"],
    version="0.0.1",
    description="An End to End Vision Transformer Example",
    long_description=long_description,
    long_description_content_type="text/markdown",
    author="Justin Goheen",
    license="Apache 2.0",
    author_email="",
    classifiers=[
        "Environment :: Console",
        "Topic :: Scientific/Engineering :: Artificial Intelligence",
        "Topic :: Scientific/Engineering :: Information Analysis",
        "Operating System :: OS Independent",
        "Programming Language :: Python :: 3.8",
        "Programming Language :: Python :: 3.9",
        "Programming Language :: Python :: 3.10",
        "Programming Language :: Python :: 3.11",
    ],
)

# --------------------------------------------------------------------------
# src/visionlab/__init__.py
# --------------------------------------------------------------------------
# Copyright Justin R. Goheen.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from visionlab.module import VisionTransformer  # noqa: F401
from visionlab.datamodule import CifarDataModule  # noqa: F401

# --------------------------------------------------------------------------
# src/visionlab/cli.py
# --------------------------------------------------------------------------
# Copyright Justin R. Goheen.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import shutil
from pathlib import Path

import torch
import typer

# NOTE(review): this module imported `lightning.pytorch` while the rest of the
# package (config.py, datamodule.py, module.py) uses `pytorch_lightning`. The
# two distributions define distinct LightningModule/Trainer classes, so a
# `lightning.pytorch.Trainer` rejects a `pytorch_lightning.LightningModule`
# at `fit()` time. Unified on `pytorch_lightning` to match the package.
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint
from pytorch_lightning.loggers import CSVLogger

from visionlab import config, CifarDataModule, VisionTransformer

this_file = Path(__file__)
# index of the first parent directory named "*this_studio"; fall back to the
# repository root (src/visionlab/cli.py -> parents[2]) so importing this module
# outside a Lightning Studio no longer raises IndexError at import time
this_studio_idx = next(
    (i for i, parent in enumerate(this_file.parents) if parent.name.endswith("this_studio")),
    2,
)
this_studio = this_file.parents[this_studio_idx]
# default CSV log destination; presumably only meaningful inside a Studio —
# TODO confirm against the deployment environment
csvlogs = os.path.join(this_studio, "vision-lab", "logs", "csv")

app = typer.Typer()
docs_app = typer.Typer()
run_app = typer.Typer()
app.add_typer(docs_app, name="docs")
app.add_typer(run_app, name="run")


@app.callback()
def callback() -> None:
    """Root Typer callback; forces `lab` to behave as a command group."""
    pass


# Docs
@docs_app.command("build")
def build_docs() -> None:
    """Sync docs/index.md from the README, then build the static site."""
    # copy the README *before* building: the original copied it after
    # `mkdocs build`, so the freshly built site always contained the
    # previous (stale) index page
    shutil.copyfile(src="README.md", dst="docs/index.md")
    os.system("mkdocs build")


@docs_app.command("serve")
def serve_docs() -> None:
    """Serve the documentation locally with live reload."""
    os.system("mkdocs serve")


# Run
@run_app.command("dev")
def run_dev() -> None:
    """Run a single fast_dev_run batch to smoke-test the training loop."""
    datamodule = CifarDataModule()
    model = VisionTransformer()
    trainer = Trainer(fast_dev_run=True)
    trainer.fit(model=model, datamodule=datamodule)


@run_app.command("trainer", context_settings={"allow_extra_args": True, "ignore_unknown_options": True})
def run_trainer(
    devices: str = "auto",
    accelerator: str = "auto",
    strategy: str = "auto",
    max_epochs: int = 10,
    predict: bool = True,
) -> None:
    """Run a full training session, then optionally test and persist predictions.

    Args:
        devices: passed through to the Trainer (count or "auto").
        accelerator: hardware accelerator selector ("auto", "cpu", "gpu", ...).
        strategy: distributed strategy selector.
        max_epochs: upper bound on training epochs.
        predict: when True, run test + predict with the best checkpoint and
            save the predictions tensor to ``config.Paths.predictions``.
    """
    datamodule = CifarDataModule()
    model = VisionTransformer()
    trainer = Trainer(
        devices=devices,
        accelerator=accelerator,
        strategy=strategy,
        max_epochs=max_epochs,
        enable_checkpointing=True,
        callbacks=[
            # "val-loss" matches the hyphenated metric name logged by
            # VisionTransformer.validation_step
            EarlyStopping(monitor="val-loss", mode="min"),
            ModelCheckpoint(dirpath=config.Paths.ckpts, filename="model"),
        ],
        logger=CSVLogger(save_dir=config.Paths.logs, name="csv"),
        log_every_n_steps=1,
    )
    trainer.fit(model=model, datamodule=datamodule)

    if predict:
        trainer.test(ckpt_path="best", datamodule=datamodule)
        predictions = trainer.predict(model, datamodule.test_dataloader())
        torch.save(predictions, config.Paths.predictions)


# --------------------------------------------------------------------------
# src/visionlab/config.py
# --------------------------------------------------------------------------
# Copyright Justin R. Goheen.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import sys
from functools import partial
from pathlib import Path

import torch
from pytorch_lightning.accelerators.mps import MPSAccelerator
from pytorch_lightning.callbacks import EarlyStopping
from torchvision import transforms


class Settings:
    """Static, project-wide settings."""

    mps_available = MPSAccelerator.is_available()
    seed = 42  # global RNG seed passed to seed_everything
    projectname = "visionlab"
    data_version = "0"  # version tag baked into persisted split filenames
    # reuse the cached availability flag instead of probing the accelerator twice
    maybe_use_mps = dict(accelerator="mps", devices=1) if mps_available else {}
    precision_dtype = "16-mixed" if mps_available else "32-true"
    platform = sys.platform


class Paths:
    """Absolute filesystem paths derived from this file's location."""

    filepath = Path(__file__)
    project = filepath.parents[2]  # repo root: src/visionlab/config.py -> ../../..
    package = filepath.parent
    # logs
    logs = os.path.join(project, "logs")
    torch_profiler = os.path.join(logs, "torch_profiler")
    simple_profiler = os.path.join(logs, "simple_profiler")
    tuned_configs = os.path.join(logs, "tuned_configs")
    # models
    ckpts = os.path.join(project, "checkpoints")
    model = os.path.join(project, "checkpoints", "onnx", "model.onnx")
    predictions = os.path.join(project, "data", "predictions", "predictions.pt")
    # data
    dataset = os.path.join(project, "data", "cache")
    splits = os.path.join(project, "data", "training_split")
    train_split = os.path.join(splits, f"v{Settings.data_version}-train.pt")
    val_split = os.path.join(splits, f"v{Settings.data_version}-val.pt")
    test_split = os.path.join(splits, f"v{Settings.data_version}-test.pt")


class Module:
    """Default keyword arguments for visionlab.module.VisionTransformer."""

    module_kwargs = dict(
        lr=1e-3,
        optimizer="Adam",
    )
    # NOTE(review): num_classes=10 here conflicts with the datamodule's
    # CIFAR100 default (100 classes) and the module default num_classes=100 —
    # confirm which dataset these kwargs target before training with them.
    model_kwargs = dict(
        image_size=32,
        num_classes=10,
        progress=False,
        weights=False,
    )
    model_hyperameters = dict(
        dropout=0.25,
        attention_dropout=0.25,
        norm_layer=partial(torch.nn.LayerNorm, eps=1e-6),
        conv_stem_configs=None,
    )
    # correctly spelled alias; the misspelled name above is kept so existing
    # callers remain unbroken
    model_hyperparameters = model_hyperameters


class Trainer:
    """Canned flag sets for pytorch_lightning.Trainer."""

    train_flags = dict(
        max_epochs=100,
        precision=Settings.precision_dtype,
        # monitor "val-loss": the LightningModule logs hyphenated metric names
        # (see module.validation_step). The original monitored "val_loss",
        # a metric that is never logged, so EarlyStopping could never fire.
        callbacks=[EarlyStopping(monitor="val-loss", mode="min")],
        **Settings.maybe_use_mps,
    )
    fast_flags = dict(
        max_epochs=2,
        precision=Settings.precision_dtype,
        **Settings.maybe_use_mps,
    )


class DataModule:
    """Batch size, normalization statistics, and torchvision transforms."""

    batch_size = 128
    # NOTE(review): these match the published CIFAR-10 channel statistics and
    # the AutoAugment policy below is CIFAR10, while the datamodule defaults
    # to CIFAR100 — confirm the intended dataset.
    mean = [0.49139968, 0.48215841, 0.44653091]
    stddev = [0.24703223, 0.24348513, 0.26158784]
    inverse_mean = [-i for i in mean]
    inverse_stddev = [1 / i for i in stddev]
    cifar_norm = transforms.Normalize(mean=mean, std=stddev)
    test_transform = transforms.Compose(
        [
            transforms.ToTensor(),
        ]
    )
    train_transform = transforms.Compose(
        [
            transforms.AutoAugment(policy=transforms.AutoAugmentPolicy.CIFAR10),
            transforms.ToTensor(),
        ]
    )
    norm_train_transform = transforms.Compose(
        [
            transforms.AutoAugment(policy=transforms.AutoAugmentPolicy.CIFAR10),
            transforms.ToTensor(),
            cifar_norm,
        ]
    )
    norm_test_transform = transforms.Compose(
        [
            transforms.ToTensor(),
            cifar_norm,
        ]
    )
    # see https://discuss.pytorch.org/t/simple-way-to-inverse-transform-normalization/4821
    inverse_transform = transforms.Compose(
        [
            transforms.Normalize(mean=[0.0, 0.0, 0.0], std=inverse_stddev),
            transforms.Normalize(mean=inverse_mean, std=[1.0, 1.0, 1.0]),
        ]
    )
# --------------------------------------------------------------------------
# src/visionlab/datamodule.py
# --------------------------------------------------------------------------
# Copyright Justin R. Goheen.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import multiprocessing
import os
from typing import Callable, Type, Union

import torch
import torchvision
import pytorch_lightning as pl
from torchvision.datasets import CIFAR100
from torchvision.datasets import VisionDataset
from pytorch_lightning.utilities.types import EVAL_DATALOADERS, TRAIN_DATALOADERS
from torch.utils.data import DataLoader, random_split

from visionlab import config

# use half the logical cores for dataloader workers
NUMWORKERS = int(multiprocessing.cpu_count() // 2)


class CifarDataModule(pl.LightningDataModule):
    """A LightningDataModule that downloads CIFAR and persists versioned splits.

    Args:
        dataset: the torchvision dataset *class* to instantiate (default CIFAR100).
        data_cache: directory where the raw dataset archive is cached.
        data_splits: directory where versioned split tensors are persisted.
        train_size: fraction of the training set used for training (rest is val).
        num_workers: dataloader worker count.
        train_transforms: transform pipeline applied to train/val data.
        test_transforms: transform pipeline applied to test data.
        batch_size: dataloader batch size.
        data_version: integer version tag baked into split filenames.
        reversion: when True, request re-creation of the split version; raises
            if that version already exists on disk.
    """

    def __init__(
        self,
        # annotation fixed: a dataset class is passed, not an instance
        dataset: Type[VisionDataset] = CIFAR100,
        data_cache: str = config.Paths.dataset,
        data_splits: str = config.Paths.splits,
        train_size: float = 0.8,
        num_workers: int = NUMWORKERS,
        train_transforms: Callable = config.DataModule.train_transform,
        test_transforms: Callable = config.DataModule.test_transform,
        batch_size: int = config.DataModule.batch_size,
        data_version: int = 0,
        reversion: bool = False,
    ):
        super().__init__()
        self.data_cache = data_cache
        self.data_splits = data_splits
        self.dataset = dataset
        self.train_size = train_size
        self.num_workers = num_workers
        self.train_transforms = train_transforms
        self.test_transforms = test_transforms
        self.batch_size = batch_size
        self.data_version = data_version
        self.reversion = reversion
        # snapshot of cache existence at construction time; drives the
        # download decision in prepare_data
        self.data_cache_exists = os.path.isdir(self.data_cache)

    def prepare_data(self) -> None:
        """Download the dataset cache (if needed) and persist versioned splits."""
        versioned_files = (
            f"v{self.data_version}-train.pt",
            f"v{self.data_version}-val.pt",
            f"v{self.data_version}-test.pt",
        )

        # the splits directory may not exist on a fresh clone; the original
        # called os.listdir on it unconditionally and raised FileNotFoundError
        os.makedirs(self.data_splits, exist_ok=True)

        # a version exists only when *all three* split files are present; the
        # original used any(), treating a partially written version as complete
        version_exists = all(v in os.listdir(self.data_splits) for v in versioned_files)

        # fail fast before any downloading/persisting work is done
        if self.reversion and version_exists:
            raise ValueError("a split version of the same version number already exists")

        if not self.data_cache_exists:
            # torchvision creates the cache directory and downloads the archive
            self.dataset(self.data_cache, download=True)

        if not version_exists:
            # persist exactly once; the original could call _persist_splits
            # two or three times on a cold start, and also silently overwrote
            # an existing version whenever the cache was missing
            self._persist_splits()

    def setup(self, stage: Union[str, None] = None) -> None:
        """Load the persisted splits required for the given trainer stage."""
        if stage == "fit" or stage is None:
            self.train_data = torch.load(os.path.join(self.data_splits, f"v{self.data_version}-train.pt"))
            self.val_data = torch.load(os.path.join(self.data_splits, f"v{self.data_version}-val.pt"))
        if stage == "test" or stage is None:
            self.test_data = torch.load(os.path.join(self.data_splits, f"v{self.data_version}-test.pt"))

    def _persist_splits(self) -> None:
        """Create train/val/test splits and save them for reproducibility."""
        # seed before splitting so the random_split is reproducible
        pl.seed_everything(config.Settings.seed)
        torchvision.disable_beta_transforms_warning()
        dataset = self.dataset(self.data_cache, train=True, transform=self.train_transforms)
        # fractional lengths require torch >= 1.13
        train_data, val_data = random_split(dataset, lengths=[self.train_size, 1 - self.train_size])
        test_data = self.dataset(self.data_cache, train=False, transform=self.test_transforms)
        torch.save(train_data, os.path.join(self.data_splits, f"v{self.data_version}-train.pt"))
        torch.save(val_data, os.path.join(self.data_splits, f"v{self.data_version}-val.pt"))
        torch.save(test_data, os.path.join(self.data_splits, f"v{self.data_version}-test.pt"))

    def train_dataloader(self) -> TRAIN_DATALOADERS:
        """the dataloader used during training"""
        return DataLoader(self.train_data, shuffle=True, num_workers=self.num_workers, batch_size=self.batch_size)

    def test_dataloader(self) -> EVAL_DATALOADERS:
        """the dataloader used during testing"""
        return DataLoader(self.test_data, shuffle=False, num_workers=self.num_workers, batch_size=self.batch_size)

    def val_dataloader(self) -> EVAL_DATALOADERS:
        """the dataloader used during validation"""
        return DataLoader(self.val_data, shuffle=False, num_workers=self.num_workers, batch_size=self.batch_size)


# --------------------------------------------------------------------------
# src/visionlab/module.py
# --------------------------------------------------------------------------
from functools import partial
from typing import List, Optional

import torch.nn.functional as F
from torch import nn, optim
from torchmetrics.functional import accuracy
from torchvision import models
from torchvision.models.vision_transformer import ConvStemConfig

import pytorch_lightning as pl
# (class VisionTransformer follows)
class VisionTransformer(pl.LightningModule):
    """A custom PyTorch Lightning LightningModule for torchvision VisionTransformers

    Args:
        optimizer: "Adam". A valid [torch.optim](https://pytorch.org/docs/stable/optim.html) name.
        lr: 1e-3
        accuracy_task: "multiclass". One of (binary, multiclass, multilabel).
        image_size: 32
        num_classes: 100
        dropout: 0.0
        attention_dropout: 0.0
        norm_layer: None. Callable producing the normalization layer; defaults
            to ``partial(nn.LayerNorm, eps=1e-6)``.
        conv_stem_configs: None
        progress: False
        weights: False. When True, load the torchvision default pretrained
            weights for the chosen backbone.
        vit_type: one of (b_16, b_32, l_16, l_32). Default is b_32.
    """

    def __init__(
        self,
        optimizer: str = "Adam",
        lr: float = 1e-3,
        accuracy_task: str = "multiclass",
        image_size: int = 32,
        num_classes: int = 100,
        dropout: float = 0.0,
        attention_dropout: float = 0.0,
        # NOTE(review): this argument is actually a callable factory (e.g. a
        # functools.partial of nn.LayerNorm), not an nn.Module instance; the
        # annotation is kept unchanged for interface stability — TODO confirm
        norm_layer: Optional[nn.Module] = None,
        conv_stem_configs: Optional[List[ConvStemConfig]] = None,
        progress: bool = False,
        weights: bool = False,
        vit_type: str = "b_32",
    ):
        super().__init__()

        if vit_type not in ("b_16", "b_32", "l_16", "l_32"):
            raise ValueError("vit_type must be one of (b_16, b_32, l_16, l_32)")

        if not norm_layer:
            norm_layer = partial(nn.LayerNorm, eps=1e-6)

        if weights:
            # resolve e.g. models.ViT_B_32_Weights and select a concrete enum
            # member: torchvision's WeightsEnum.verify rejects the bare enum
            # class that the original passed through
            weights = getattr(models, f"ViT_{vit_type.upper()}_Weights").DEFAULT
        else:
            # torchvision's weights API expects None (not False) for
            # "train from scratch"; verify() raises a TypeError on False
            weights = None

        vit_kwargs = dict(
            image_size=image_size,
            num_classes=num_classes,
            dropout=dropout,
            attention_dropout=attention_dropout,
            norm_layer=norm_layer,
            conv_stem_configs=conv_stem_configs,
        )

        # resolve the builder function, e.g. models.vit_b_32
        vision_transformer = getattr(models, f"vit_{vit_type}")

        self.model = vision_transformer(
            weights=weights,
            progress=progress,
            **vit_kwargs,
        )
        self.optimizer = getattr(optim, optimizer)  # optimizer *class*, built in configure_optimizers
        self.lr = lr
        self.accuracy_task = accuracy_task
        self.num_classes = num_classes
        self.save_hyperparameters()

    def forward(self, x):
        """calls .forward of a given model flow"""
        return self.model(x)

    # *args absorbs batch_idx etc., consistent with validation_step/test_step
    def training_step(self, batch, *args):
        """runs a training step sequence"""
        x, y = batch
        y_hat = self.model(x)
        loss = F.cross_entropy(y_hat, y)
        self.log("training-loss", loss)
        return loss

    def validation_step(self, batch, *args):
        """runs a validation step sequence"""
        x, y = batch
        y_hat = self.model(x)
        loss = F.cross_entropy(y_hat, y)
        # hyphenated metric names; EarlyStopping monitors must match exactly
        self.log("val-loss", loss)
        acc = accuracy(
            y_hat.argmax(dim=-1),
            y,
            task=self.accuracy_task,
            num_classes=self.num_classes,
        )
        self.log("val-acc", acc)

    def test_step(self, batch, *args):
        """runs a test step sequence"""
        x, y = batch
        y_hat = self.model(x)
        acc = accuracy(
            y_hat.argmax(dim=-1),
            y,
            task=self.accuracy_task,
            num_classes=self.num_classes,
        )
        self.log("test-acc", acc)

    def predict_step(self, batch, *args):
        """returns predicted logits from the trained model"""
        x, _ = batch  # labels are unused during prediction
        return self(x)

    def configure_optimizers(self):
        """configures the ``torch.optim`` used in training loop"""
        optimizer = self.optimizer(self.parameters(), lr=self.lr)
        return optimizer


# --------------------------------------------------------------------------
# tests/__init__.py  (empty)
# tests/test_datamodule.py
# --------------------------------------------------------------------------
# Copyright Justin R. Goheen.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
from pathlib import Path

import torch

from visionlab import CifarDataModule
from visionlab.config import Paths


def test_module_not_abstract():
    """The datamodule can be instantiated (no abstract hooks remain)."""
    _ = CifarDataModule()


def test_prepare_data():
    data_module = CifarDataModule()
    data_module.prepare_data()
    # assert the cache directory was populated rather than pinning the stale
    # "LabDataset" name, which does not match the torchvision CIFAR100 layout
    assert os.path.isdir(Paths.dataset) and os.listdir(Paths.dataset)
    # the versioned split files must also have been persisted
    assert os.listdir(Paths.splits)


def test_setup():
    data_module = CifarDataModule()
    data_module.prepare_data()
    # use stage=None so *all* splits load: setup("fit") only sets train/val,
    # and the original's assertion on "test_data" could never pass after it
    data_module.setup()
    data_keys = ["train_data", "test_data", "val_data"]
    assert all(key in dir(data_module) for key in data_keys)


def test_trainloader():
    data_module = CifarDataModule()
    data_module.prepare_data()
    data_module.setup("fit")
    loader = data_module.train_dataloader()
    sample = loader.dataset[0][0]
    assert isinstance(sample, torch.Tensor)


def test_testloader():
    data_module = CifarDataModule()
    data_module.prepare_data()
    data_module.setup("fit")
    data_module.setup("test")
    loader = data_module.test_dataloader()
    sample = loader.dataset[0][0]
    assert isinstance(sample, torch.Tensor)


def test_valloader():
    data_module = CifarDataModule()
    data_module.prepare_data()
    data_module.setup()
    loader = data_module.val_dataloader()
    sample = loader.dataset[0][0]
    assert isinstance(sample, torch.Tensor)


# --------------------------------------------------------------------------
# tests/test_module.py
# --------------------------------------------------------------------------
# Copyright Justin R. Goheen.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import torch

from visionlab import VisionTransformer


def test_module_not_abstract():
    """The module can be instantiated (no abstract hooks remain)."""
    _ = VisionTransformer()


def test_module_forward():
    # CIFAR-sized RGB input: the module defaults to image_size=32 and
    # num_classes=100, so forward returns a single logits tensor of shape
    # (batch, 100). The original fed a flat (1, 784) tensor and unpacked the
    # return value into two names — both of which fail at runtime.
    input_sample = torch.randn((1, 3, 32, 32))
    model = VisionTransformer()
    preds = model.forward(input_sample)
    assert preds.shape == (1, 100)


def test_module_training_step():
    # a realistic (images, class-index labels) batch; the original passed a
    # flat image tensor paired with a bare int, which cross_entropy rejects
    batch = torch.randn((2, 3, 32, 32)), torch.randint(0, 100, (2,))
    model = VisionTransformer()
    loss = model.training_step(batch)
    assert isinstance(loss, torch.Tensor)


def test_optimizer():
    model = VisionTransformer()
    optimizer = model.configure_optimizers()
    optimizer_base_class = optimizer.__class__.__base__.__name__
    assert optimizer_base_class == "Optimizer"