├── .flake8 ├── .github └── workflows │ ├── docs.yml │ └── tests.yml ├── .gitignore ├── LICENSE ├── README.md ├── docs ├── CNAME ├── assets │ ├── api.png │ ├── cover.png │ ├── favicon.ico │ └── logo-white.png ├── index.md ├── models │ ├── image-classification.md │ ├── object-detection.md │ └── text-classification.md ├── requirements.txt ├── stylesheets │ ├── extra.css │ └── extra.js └── tutorial │ ├── deployment.md │ ├── experiment-tracking.md │ ├── intro.md │ ├── quickstart.ipynb │ ├── raw-models.md │ └── training.md ├── examples └── mnist-cnn-example.py ├── mkdocs.yml ├── requirements.txt ├── setup.py ├── tests ├── conftest.py ├── image_classification │ ├── test_preprocessing.py │ ├── test_sklearn_models.py │ └── test_torch_models.py ├── test_main.py └── test_model_wrapper.py └── traintool ├── __init__.py ├── _version.txt ├── image_classification ├── __init__.py ├── preprocessing.py ├── sklearn_models.py ├── torch_models.py └── visualization.py ├── main.py ├── model_wrapper.py └── utils.py /.flake8: -------------------------------------------------------------------------------- 1 | [flake8] 2 | max-line-length = 88 3 | extend-ignore = E203,W291,W293 4 | -------------------------------------------------------------------------------- /.github/workflows/docs.yml: -------------------------------------------------------------------------------- 1 | name: docs 2 | 3 | on: 4 | push: 5 | branches: 6 | - master 7 | 8 | jobs: 9 | deploy: 10 | runs-on: ubuntu-latest 11 | steps: 12 | - uses: actions/checkout@v2 13 | - name: Set up Python 3.7 14 | uses: actions/setup-python@v2 15 | with: 16 | python-version: 3.7 17 | - name: Install dependencies for docs 18 | run: pip install -r docs/requirements.txt 19 | - name: Deploy to gh-pages 20 | run: mkdocs gh-deploy --force -------------------------------------------------------------------------------- /.github/workflows/tests.yml: -------------------------------------------------------------------------------- 1 | name: tests 2 | 3 | on: [push] 4 | 5 | jobs: 6 | build: 7 | 8 | runs-on: ubuntu-latest 9 | strategy: 10 | matrix: 11 | python-version: [3.6, 3.7, 3.8] 12 | # <3.6 is not possible because we use f-strings and type annotations 13 | # 3.9 gives an error right now because pytorch doesn't support it yet apparently. 14 | 15 | steps: 16 | - uses: actions/checkout@v2 17 | - name: Set up Python 18 | uses: actions/setup-python@v2 19 | with: 20 | python-version: ${{ matrix.python-version }} 21 | - name: Install dependencies 22 | run: | 23 | python -m pip install --upgrade pip 24 | pip install pytest pytest-cov 25 | if [ -f requirements.txt ]; then pip install -r requirements.txt; fi 26 | # - name: Lint with flake8 27 | # run: | 28 | # # stop the build if there are Python syntax errors or undefined names 29 | # flake8 . --count --select=E9,F63,F7,F82 --show-source --statistics 30 | # # exit-zero treats all errors as warnings. The GitHub editor is 127 chars wide 31 | # flake8 . 
--count --exit-zero --max-complexity=10 --max-line-length=127 --statistics 32 | - name: Test with pytest 33 | run: | 34 | pytest --cov=./ --cov-report=xml 35 | - name: Upload coverage to Codecov 36 | uses: codecov/codecov-action@v1 37 | with: 38 | token: ${{ secrets.CODECOV_TOKEN }} # not required for public repos 39 | file: ./coverage.xml # optional 40 | files: ./coverage1.xml,./coverage2.xml # optional 41 | flags: unittests # optional 42 | name: codecov-umbrella # optional 43 | fail_ci_if_error: true # optional (default = false) -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | data 2 | settings.json 3 | .vscode 4 | traintool-experiments 5 | .VSCodeCounter 6 | 7 | 8 | # Byte-compiled / optimized / DLL files 9 | __pycache__/ 10 | *.py[cod] 11 | *$py.class 12 | 13 | # C extensions 14 | *.so 15 | 16 | # Distribution / packaging 17 | .Python 18 | build/ 19 | develop-eggs/ 20 | dist/ 21 | downloads/ 22 | eggs/ 23 | .eggs/ 24 | lib/ 25 | lib64/ 26 | parts/ 27 | sdist/ 28 | var/ 29 | wheels/ 30 | pip-wheel-metadata/ 31 | share/python-wheels/ 32 | *.egg-info/ 33 | .installed.cfg 34 | *.egg 35 | MANIFEST 36 | 37 | # PyInstaller 38 | # Usually these files are written by a python script from a template 39 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 40 | *.manifest 41 | *.spec 42 | 43 | # Installer logs 44 | pip-log.txt 45 | pip-delete-this-directory.txt 46 | 47 | # Unit test / coverage reports 48 | htmlcov/ 49 | .tox/ 50 | .nox/ 51 | .coverage 52 | .coverage.* 53 | .cache 54 | nosetests.xml 55 | coverage.xml 56 | *.cover 57 | *.py,cover 58 | .hypothesis/ 59 | .pytest_cache/ 60 | 61 | # Translations 62 | *.mo 63 | *.pot 64 | 65 | # Django stuff: 66 | *.log 67 | local_settings.py 68 | db.sqlite3 69 | db.sqlite3-journal 70 | 71 | # Flask stuff: 72 | instance/ 73 | .webassets-cache 74 | 75 | # Scrapy stuff: 76 | .scrapy 77 | 78 | # Sphinx documentation 79 | docs/_build/ 80 | 81 | # PyBuilder 82 | target/ 83 | 84 | # Jupyter Notebook 85 | .ipynb_checkpoints 86 | 87 | # IPython 88 | profile_default/ 89 | ipython_config.py 90 | 91 | # pyenv 92 | .python-version 93 | 94 | # pipenv 95 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 96 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 97 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 98 | # install all needed dependencies. 99 | #Pipfile.lock 100 | 101 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 102 | __pypackages__/ 103 | 104 | # Celery stuff 105 | celerybeat-schedule 106 | celerybeat.pid 107 | 108 | # SageMath parsed files 109 | *.sage.py 110 | 111 | # Environments 112 | .env 113 | .venv 114 | env/ 115 | venv/ 116 | ENV/ 117 | env.bak/ 118 | venv.bak/ 119 | 120 | # Spyder project settings 121 | .spyderproject 122 | .spyproject 123 | 124 | # Rope project settings 125 | .ropeproject 126 | 127 | # mkdocs documentation 128 | /site 129 | 130 | # mypy 131 | .mypy_cache/ 132 | .dmypy.json 133 | dmypy.json 134 | 135 | # Pyre type checker 136 | .pyre/ 137 | .pheasant_cache 138 | 139 | config.json 140 | *config.json 141 | */tmp/* 142 | 143 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. 
For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright 2020 Johannes Rieke 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 |
14 | Train off-the-shelf machine learning models in one line of code 15 |
25 | Try it out in Google Colab • Documentation 26 | 
27 | 28 | --- 29 | 30 | traintool is the easiest Python library for **applied machine learning**. It allows you 31 | to train off-the-shelf models with minimal code: just give your data 32 | and the model name, and traintool takes care of the rest. It combines **pre-implemented 33 | models** (built on top of sklearn & pytorch) with powerful **utilities** that get you 34 | started in seconds (automatic visualizations, experiment tracking, intelligent data 35 | preprocessing, API deployment). 36 | 37 | 38 | Alpha Release: traintool is in an early alpha release. The API can and will change 39 | without notice. If you find a bug, please file an issue on 40 | [Github](https://github.com/jrieke/traintool) or 41 | [write me](mailto:johannes.rieke@gmail.com). 42 | 43 | 44 | 45 | 50 | 51 | 69 | 70 | 71 | ## Installation 72 | 73 | ```bash 74 | pip install traintool 75 | ``` 76 | 77 | 78 | ## Features 79 | 80 | - **Minimum coding —** traintool is designed to require as few lines of code as 81 | possible. It offers a sleek and intuitive interface that gets you started in seconds. 82 | Training a model just takes a single line: 83 | 84 | ```python 85 | traintool.train("resnet18", train_data, test_data, config={"optimizer": "adam", "lr": 0.1}) 86 | ``` 87 | 88 | - **Pre-implemented models —** At the heart of traintool are fully implemented and tested 89 | models – from simple classifiers to deep neural networks; built on sklearn, pytorch, 90 | or tensorflow. Here are just a few of the models you can use: 91 | 92 | ```python 93 | "svc", "random-forest", "alexnet", "resnet50", "inception_v3", ... 94 | ``` 95 | 96 | - **Automatic visualizations & experiment tracking —** traintool automatically 97 | calculates metrics, creates beautiful visualizations (in 98 | [tensorboard](https://www.tensorflow.org/tensorboard) or 99 | [comet.ml](https://www.comet.ml/)), and stores experiment data and 100 | model checkpoints – without needing a single additional line of code. 101 | 102 | - **Ready for your data —** traintool understands numpy arrays, pytorch datasets, 103 | and files. It automatically converts and preprocesses everything based on the model you 104 | use. 105 | 106 | - **Instant deployment —** In one line of code, you can deploy your model to a REST 107 | API that you can query from anywhere. Just call: 108 | 109 | ```python 110 | model.deploy() 111 | ``` 112 | 113 | 114 | 123 | 124 | 125 | 126 | 127 | ## Example: Image classification on MNIST 128 | 129 | Run this example interactively in Google Colab: 130 | 131 | [](https://colab.research.google.com/github/jrieke/traintool/blob/master/docs/tutorial/quickstart.ipynb) 132 | 133 | ```python 134 | import mnist 135 | import traintool 136 | 137 | # Load MNIST data as numpy 138 | train_data = [mnist.train_images(), mnist.train_labels()] 139 | test_data = [mnist.test_images(), mnist.test_labels()] 140 | 141 | # Train SVM classifier 142 | svc = traintool.train("svc", train_data=train_data, test_data=test_data) 143 | 144 | # Train ResNet with custom hyperparameters 145 | resnet = traintool.train("resnet18", train_data=train_data, test_data=test_data, 146 | config={"lr": 0.1, "optimizer": "adam"}) 147 | 148 | # Make prediction 149 | result = resnet.predict(test_data[0][0]) 150 | print(result["predicted_class"]) 151 | 152 | # Deploy to REST API 153 | resnet.deploy() 154 | 155 | # Get underlying pytorch model (e.g. 
for custom analysis) 156 | pytorch_model = resnet.raw()["model"] 157 | ``` 158 | 159 | For more information, check out the 160 | [complete tutorial](https://traintool.jrieke.com/tutorial/quickstart/). 161 | 162 | 163 | ## Get in touch! 164 | 165 | You have a question on traintool, want to use it in production, or miss a feature? I'm 166 | happy to hear from you! Write me at [johannes.rieke@gmail.com](mailto:johannes.rieke@gmail.com). 167 | -------------------------------------------------------------------------------- /docs/CNAME: -------------------------------------------------------------------------------- 1 | traintool.jrieke.com -------------------------------------------------------------------------------- /docs/assets/api.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jrieke/traintool/0ada1ff20a9a570be5cfb4ac21f3f914604c9833/docs/assets/api.png -------------------------------------------------------------------------------- /docs/assets/cover.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jrieke/traintool/0ada1ff20a9a570be5cfb4ac21f3f914604c9833/docs/assets/cover.png -------------------------------------------------------------------------------- /docs/assets/favicon.ico: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jrieke/traintool/0ada1ff20a9a570be5cfb4ac21f3f914604c9833/docs/assets/favicon.ico -------------------------------------------------------------------------------- /docs/assets/logo-white.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jrieke/traintool/0ada1ff20a9a570be5cfb4ac21f3f914604c9833/docs/assets/logo-white.png -------------------------------------------------------------------------------- /docs/index.md: -------------------------------------------------------------------------------- 1 |
14 | Train off-the-shelf machine learning models in one line of code 15 |
25 | Try it out in Google Colab • Documentation 26 | 
27 | 28 | --- 29 | 30 | traintool is the easiest Python library for **applied machine learning**. It allows you 31 | to train off-the-shelf models with minimal code: just give your data 32 | and the model name, and traintool takes care of the rest. It combines **pre-implemented 33 | models** (built on top of sklearn & pytorch) with powerful **utilities** that get you 34 | started in seconds (automatic visualizations, experiment tracking, intelligent data 35 | preprocessing, API deployment). 36 | 37 | 38 | !!! warning "Alpha Release" 39 | traintool is in an early alpha release. The API can and will change 40 | without notice. If you find a bug, please file an issue on 41 | [Github](https://github.com/jrieke/traintool) or [write me](mailto:johannes.rieke@gmail.com). 42 | 43 | 44 | 45 | 50 | 51 | 69 | 70 | 71 | ## Installation 72 | 73 | ```bash 74 | pip install traintool 75 | ``` 76 | 77 | 78 | 79 | ## Features 80 | 81 | - **Minimum coding —** traintool is designed to require as few lines of code as 82 | possible. It offers a sleek and intuitive interface that gets you started in seconds. 83 | Training a model just takes a single line: 84 | 85 | traintool.train("resnet18", train_data, test_data, config={"optimizer": "adam", "lr": 0.1}) 86 | 87 | 88 | - **Pre-implemented models —** At the heart of traintool are fully implemented and tested 89 | models – from simple classifiers to deep neural networks; built on sklearn, pytorch, 90 | or tensorflow. Here are just a few of the models you can use: 91 | 92 | "svc", "random-forest", "alexnet", "resnet50", "inception_v3", ... 93 | 94 | - **Automatic visualizations & experiment tracking —** traintool automatically 95 | calculates metrics, creates beautiful visualizations (in 96 | [tensorboard](https://www.tensorflow.org/tensorboard) or 97 | [comet.ml](https://www.comet.ml/)), and stores experiment data and 98 | model checkpoints – without needing a single additional line of code. 99 | 100 | - **Ready for your data —** traintool understands numpy arrays, pytorch datasets, 101 | and files. It automatically converts and preprocesses everything based on the model you 102 | use. 103 | 104 | - **Instant deployment —** In one line of code, you can deploy your model to a REST 105 | API that you can query from anywhere. Just call: 106 | 107 | model.deploy() 108 | 109 | 110 | 119 | 120 | 121 | ## Example: Image classification on MNIST 122 | 123 | Run this example interactively in Google Colab: 124 | 125 | [](https://colab.research.google.com/github/jrieke/traintool/blob/master/docs/tutorial/quickstart.ipynb) 126 | 127 | ```python 128 | import mnist 129 | import traintool 130 | 131 | # Load MNIST data as numpy 132 | train_data = [mnist.train_images(), mnist.train_labels()] 133 | test_data = [mnist.test_images(), mnist.test_labels()] 134 | 135 | # Train SVM classifier 136 | svc = traintool.train("svc", train_data=train_data, test_data=test_data) 137 | 138 | # Train ResNet with custom hyperparameters 139 | resnet = traintool.train("resnet18", train_data=train_data, test_data=test_data, 140 | config={"lr": 0.1, "optimizer": "adam"}) 141 | 142 | # Make prediction 143 | result = resnet.predict(test_data[0][0]) 144 | print(result["predicted_class"]) 145 | 146 | # Deploy to REST API 147 | resnet.deploy() 148 | 149 | # Get underlying pytorch model (e.g. for custom analysis) 150 | pytorch_model = resnet.raw()["model"] 151 | ``` 152 | 153 | For more information, check out the 154 | [complete tutorial](https://traintool.jrieke.com/tutorial/quickstart/). 
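Saved models can be loaded again later via `traintool.load`. Here is a minimal sketch (the experiment ID below is a hypothetical placeholder; the real ID and the exact load command are printed to the console during training):

```python
# Load a saved model by its experiment ID or directory path (placeholder ID)
resnet = traintool.load("2021-01-01_12-00-00")

# The loaded model behaves like the freshly trained one
result = resnet.predict(test_data[0][0])
```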
155 | 156 | 157 | ## Get in touch! 158 | 159 | You have a question on traintool, want to use it in production, or miss a feature? I'm 160 | happy to hear from you! Write me at [johannes.rieke@gmail.com](mailto:johannes.rieke@gmail.com). 161 | -------------------------------------------------------------------------------- /docs/models/image-classification.md: -------------------------------------------------------------------------------- 1 | # Image classification 2 | 3 |  4 | 5 | Image classification models classify an image into one of several categories or 6 | classes, based on the image content (e.g. "cat" or "dog"). 7 | 8 | ## Input formats 9 | 10 | ### Numpy arrays 11 | 12 | Each dataset should be a list of two elements: The first element is a numpy array 13 | of all images of shape `(number of images, color channels (1 or 3), height, width)`. The second 14 | element is an array of labels (as integer indices). 15 | 16 | Example: 17 | 18 | ```python 19 | train_images = np.zeros((32, 3, 256, 256)) # 32 images with 3 color channels and size 256x256 20 | train_labels = np.zeros(32, dtype=int) 21 | 22 | traintool.train(..., train_data=[train_images, train_labels]) 23 | ``` 24 | 25 | 33 | 34 | ### Files 35 | 36 | Image files should be arranged in one folder per class, similar to this: 37 | 38 | ``` 39 | train 40 | +-- dogs 41 | | +-- funny-dog.jpg 42 | | +-- another-dog.png 43 | +-- cats 44 | | +-- brown-cat.png 45 | | +-- black-cat.png 46 | ... 47 | ``` 48 | 49 | Then simply pass the directory path to the `train` function: 50 | 51 | ```python 52 | traintool.train(..., train_data="./train") 53 | ``` 54 | 55 | 56 | ## Scikit-learn models 57 | 58 | These models implement simple classification algorithms that should train in a 59 | reasonable amount of time. Note that they are not GPU-accelerated, so they can still 60 | take quite long on large datasets. 61 | 62 | **Preprocessing:** Image files are first loaded and resized to 28 x 28 pixels. All images (numpy 63 | or files) are then flattened and scaled to mean 0, standard deviation 1 (based on the 64 | train set). 65 | 66 | **Config parameters:** 67 | 68 | - `num_samples`: Set the number of samples to train on. This can be used to train on a 69 | subset of the data. Defaults to None (i.e. train on all data). 70 | - `num_samples_to_plot`: Set the number of samples to plot to tensorboard for each 71 | dataset. Defaults to 5. 72 | - All other config parameters are forwarded to the constructor of the underlying sklearn object. 73 | 74 | **Models:** 75 | 76 | - `random-forest`: A random forest classifier, from [sklearn.ensemble.RandomForestClassifier](https://scikit-learn.org/stable/modules/generated/sklearn.ensemble.RandomForestClassifier.html) 77 | - `gradient-boosting`: Gradient boosting for classification, from [sklearn.ensemble.GradientBoostingClassifier](https://scikit-learn.org/stable/modules/generated/sklearn.ensemble.GradientBoostingClassifier.html) 78 | - `gaussian-process`: Gaussian process classification based on Laplace approximation, from [sklearn.gaussian_process.GaussianProcessClassifier](https://scikit-learn.org/stable/modules/generated/sklearn.gaussian_process.GaussianProcessClassifier.html#sklearn.gaussian_process.GaussianProcessClassifier) 79 | - `logistic-regression`: Logistic Regression (aka logit, MaxEnt) classifier, from [sklearn.linear_model.LogisticRegression](https://scikit-learn.org/stable/modules/generated/sklearn.linear_model.LogisticRegression.html) 80 | - `sgd`: Linear classifiers (SVM, logistic regression, etc.) 
with SGD training, from [sklearn.linear_model.SGDClassifier](https://scikit-learn.org/stable/modules/generated/sklearn.linear_model.SGDClassifier.html) 81 | - `perceptron`: A perceptron classifier, from [sklearn.linear_model.Perceptron](https://scikit-learn.org/stable/modules/generated/sklearn.linear_model.Perceptron.html) 82 | - `passive-aggressive`: Passive aggressive classifier, from [sklearn.linear_model.PassiveAggressiveClassifier](https://scikit-learn.org/stable/modules/generated/sklearn.linear_model.PassiveAggressiveClassifier.html) 83 | - `gaussian-nb`: Gaussian Naive Bayes, from [sklearn.naive_bayes.GaussianNB](https://scikit-learn.org/stable/modules/generated/sklearn.naive_bayes.GaussianNB.html) 84 | - `k-neighbors`: Classifier implementing the k-nearest neighbors vote, from [sklearn.neighbors.KNeighborsClassifier](https://scikit-learn.org/stable/modules/generated/sklearn.neighbors.KNeighborsClassifier.html) 85 | - `mlp`: Multi-layer Perceptron classifier, from [sklearn.neural_network.MLPClassifier](https://scikit-learn.org/stable/modules/generated/sklearn.neural_network.MLPClassifier.html) 86 | - `svc`: C-Support Vector Classification, from [sklearn.svm.SVC](https://scikit-learn.org/stable/modules/generated/sklearn.svm.SVC.html) 87 | - `linear-svc`: Linear Support Vector Classification, from [sklearn.svm.LinearSVC](https://scikit-learn.org/stable/modules/generated/sklearn.svm.LinearSVC.html) 88 | - `decision-tree`: A decision tree classifier, from [sklearn.tree.DecisionTreeClassifier](https://scikit-learn.org/stable/modules/generated/sklearn.tree.DecisionTreeClassifier.html) 89 | - `extra-tree`: An extra-trees classifier, from [sklearn.ensemble.ExtraTreesClassifier](https://scikit-learn.org/stable/modules/generated/sklearn.ensemble.ExtraTreesClassifier.html) 90 | 91 | 92 | 93 | ## PyTorch models 94 | 95 | These models implement deep neural networks that can give better results on complex 96 | datasets. They are GPU-accelerated if run on a machine with a GPU. 97 | 98 | **Preprocessing:** All images (numpy or files) are rescaled to 256 x 256, then 99 | center-cropped to 224 x 224, and normalized with the ImageNet channel means and standard deviations (the standard preprocessing for the torchvision models). 100 | 101 | **Config parameters:** 102 | 103 | - `num_classes`: The number of classes/different output labels (and therefore number of 104 | output neurons of the network). Defaults to None, in which case it will be automatically 105 | inferred from the data. 106 | - `num_samples`: Set the number of samples to train on. This can be used to train on a 107 | subset of the data. Defaults to None (i.e. train on all data). 108 | - `num_samples_to_plot`: Set the number of samples to plot to tensorboard for each 109 | dataset. Defaults to 5. 110 | - `pretrained`: Whether to use pretrained weights for the models (trained on ImageNet). 111 | Note that this requires that there are 1000 classes (the ImageNet classes). Defaults to 112 | False. 113 | 114 | **Models:** 115 | 116 | More information in the [torchvision docs](https://pytorch.org/docs/stable/torchvision/models.html). 
117 | 118 | - `alexnet`: AlexNet model architecture from the [“One weird trick…”](https://arxiv.org/abs/1404.5997) paper 119 | - `vgg11`, `vgg11_bn`, `vgg13`, `vgg13_bn`, `vgg16`, `vgg16_bn`, `vgg19`, or `vgg19_bn`: VGG model variants from [“Very Deep Convolutional Networks For Large-Scale Image Recognition”](https://arxiv.org/pdf/1409.1556.pdf) 120 | - `resnet18`, `resnet34`, `resnet50`, `resnet101`, or `resnet152`: ResNet model variants from [“Deep Residual Learning for Image Recognition”](https://arxiv.org/pdf/1512.03385.pdf) 121 | - `squeezenet1_0`, or `squeezenet1_1`: SqueezeNet model variants from the [“SqueezeNet: AlexNet-level accuracy with 50x fewer parameters and <0.5MB model size”](https://arxiv.org/abs/1602.07360) paper. 122 | - `densenet121`, `densenet169`, `densenet161`, or `densenet201`: Densenet model variants from [“Densely Connected Convolutional Networks”](https://arxiv.org/pdf/1608.06993.pdf) 123 | - `inception_v3`: Inception v3 model architecture from [“Rethinking the Inception Architecture for Computer Vision”](http://arxiv.org/abs/1512.00567) 124 | - `googlenet`: GoogLeNet (Inception v1) model architecture from [“Going Deeper with Convolutions”](http://arxiv.org/abs/1409.4842) 125 | - `shufflenet_v2_x0_5`, `shufflenet_v2_x1_0`, `shufflenet_v2_x1_5`, or `shufflenet_v2_x2_0`: ShuffleNetV2 variants, as described in [“ShuffleNet V2: Practical Guidelines for Efficient CNN Architecture Design”](https://arxiv.org/abs/1807.11164) 126 | - `mobilenet_v2`: MobileNetV2 architecture from [“MobileNetV2: Inverted Residuals and Linear Bottlenecks”](https://arxiv.org/abs/1801.04381) 127 | - `resnext50_32x4d` or `resnext101_32x8d`: ResNeXt model variants from [“Aggregated Residual Transformation for Deep Neural Networks”](https://arxiv.org/pdf/1611.05431.pdf) 128 | - `wide_resnet50_2` or `wide_resnet101_2`: Wide ResNet-50-2 model variants from [“Wide Residual Networks”](https://arxiv.org/pdf/1605.07146.pdf) 129 | - `mnasnet0_5`, `mnasnet0_75`, `mnasnet1_0`, or `mnasnet1_3`: MNASNet variants from [“MnasNet: Platform-Aware Neural Architecture Search for Mobile”](https://arxiv.org/pdf/1807.11626.pdf) 130 | 131 | 132 | -------------------------------------------------------------------------------- /docs/models/object-detection.md: -------------------------------------------------------------------------------- 1 | # Object detection 2 | 3 | Coming soon! -------------------------------------------------------------------------------- /docs/models/text-classification.md: -------------------------------------------------------------------------------- 1 | # Text classification 2 | 3 | Coming soon! 
-------------------------------------------------------------------------------- /docs/requirements.txt: -------------------------------------------------------------------------------- 1 | mkdocs 2 | mkdocs-material 3 | pygments 4 | pymdown-extensions 5 | mkdocs-autolinks-plugin 6 | mknotebooks 7 | #mkdocs-table-reader-plugin 8 | #mkdocs-jupyter # NOTE: this requires nbconvert==5.6.1 9 | #mkdocs-exclude 10 | #mkdocs-click 11 | #mkdocstrings 12 | -------------------------------------------------------------------------------- /docs/stylesheets/extra.css: -------------------------------------------------------------------------------- 1 | /* header different color*/ 2 | /* 3 | .md-header { 4 | background-color: #232f3e !important; 5 | border-bottom: 1px solid #1b2532 !important; 6 | } 7 | */ 8 | 9 | 10 | /* Active tab bold & color */ 11 | .md-tabs__link { 12 | font-size: .75rem !important; 13 | } 14 | .md-tabs__link--active { 15 | font-weight: bold !important; 16 | color: var(--md-accent-fg-color); 17 | } 18 | .md-tabs__link:hover { 19 | color: var(--md-accent-fg-color); 20 | } 21 | 22 | /* navigation bar active color*/ 23 | .md-nav__item .md-nav__link--active { 24 | font-weight: bold !important; 25 | color: var(--md-accent-fg-color); 26 | } 27 | .md-nav__item .md-nav__link--active { 28 | font-weight: bold !important; 29 | font-size: .75rem !important; 30 | color: var(--md-accent-fg-color); 31 | } 32 | 33 | /* frontpage elements */ 34 | .tx-hero h1 { 35 | font-size: 2.41rem !important; 36 | } 37 | a.md-button.md-button--primary { 38 | background-color: var(--md-accent-fg-color); 39 | border-color: var(--md-accent-fg-color); 40 | color: #ffffff; 41 | } 42 | a.md-button.md-button--primary:hover { 43 | color: #000000; 44 | } 45 | a.md-button.md-button:hover { 46 | color: #000000; 47 | } 48 | 49 | 50 | /* More visible headings*/ 51 | .md-main h1 { 52 | margin: 0.5em 0 1.0em 0; 53 | color: #333; 54 | font-weight: bold; 55 | font-size: 36px; 56 | line-height: 40px; 57 | /* 58 | counter-increment: section-1; 59 | counter-reset: section-2 section-3 section-4; 60 | */ 61 | } 62 | .md-main h2 { 63 | margin: 1.5em 0 0.4em 0; 64 | color: #595959; 65 | font-weight: normal; 66 | font-size: 30px; 67 | line-height: 36px; 68 | border-bottom: 1px solid #fff; 69 | box-shadow: 0 1px 0 rgba(0,0,0,0.1); 70 | padding-bottom: 10px; 71 | /* 72 | counter-increment: section-2; 73 | counter-reset: section-3 section-4; 74 | */ 75 | } 76 | .md-main h3 { 77 | margin: 1.2em 0 0.4em 0; 78 | color: #595959; 79 | font-weight: normal; 80 | font-size: 26px; 81 | line-height: 40px; 82 | /* 83 | counter-increment: section-3; 84 | counter-reset: section-4; 85 | */ 86 | } 87 | .md-main h4 { 88 | margin: 1.0em 0 0.4em 0; 89 | color: #333; 90 | font-weight: normal; 91 | font-size: 20px; 92 | line-height: 20px; 93 | /* counter-increment: section-4; */ 94 | } 95 | 96 | 97 | /* Define class for block image in block documentation trimmed*/ 98 | .blockimage { 99 | width: 200px; 100 | margin-right: 20px; 101 | margin-top: 7px; 102 | position:relative; 103 | float:left; 104 | } 105 | .blockimage img { 106 | width:100%; 107 | max-width:100%; 108 | float:left; 109 | } 110 | 111 | 112 | /*Table styling*/ 113 | table { 114 | padding: 0; } 115 | table tr { 116 | border-top: 1px solid #cccccc; 117 | background-color: white; 118 | margin: 0; 119 | padding: 0; } 120 | table tr:nth-child(2n) { 121 | background-color: #f8f8f8; } 122 | table tr th { 123 | font-weight: bold; 124 | border: 0px solid #cccccc; 125 | text-align: centre; 126 | margin: 0; 
127 | padding: 6px 13px; } 128 | table tr td { 129 | border: 0px solid #cccccc; 130 | text-align: center; 131 | margin: 0; 132 | padding: 6px 13px; } 133 | table tr th :first-child, table tr td :first-child { 134 | margin-top: 0; } 135 | table tr th :last-child, table tr td :last-child { 136 | margin-bottom: 0; } 137 | 138 | /* Necessary CSS for tabbed extension */ 139 | .tabbed-set { 140 | display: flex; 141 | position: relative; 142 | flex-wrap: wrap; 143 | } 144 | .tabbed-set .highlight { 145 | background: #ddd; 146 | } 147 | .tabbed-set .tabbed-content { 148 | display: none; 149 | order: 99; 150 | width: 100%; 151 | } 152 | .tabbed-set label { 153 | width: auto; 154 | margin: 0 0.5em; 155 | padding: 0.25em; 156 | font-size: 120%; 157 | cursor: pointer; 158 | } 159 | .tabbed-set input { 160 | position: absolute; 161 | opacity: 0; 162 | } 163 | .tabbed-set input:nth-child(n+1) { 164 | color: #333333; 165 | } 166 | .tabbed-set input:nth-child(n+1):checked + label { 167 | color: #FF5252; 168 | } 169 | .tabbed-set input:nth-child(n+1):checked + label + .tabbed-content { 170 | display: block; 171 | } 172 | 173 | 174 | /* Admonition settings option*/ 175 | .md-typeset .admonition.settings, .md-typeset details.settings { 176 | border-left: .22rem solid #448aff; 177 | } 178 | .md-typeset .admonition.settings>.admonition-title, .md-typeset details.settings>.admonition-title, .md-typeset details.settings>summary { 179 | border-bottom: .1rem solid rgba(236, 243, 255, .1); 180 | background-color: rgba(236, 243, 255, .1); 181 | } 182 | .md-typeset .admonition.settings>.admonition-title:before, .md-typeset details.settings>.admonition-title:before, .md-typeset details.settings>summary:before { 183 | color: #448aff; 184 | content: "settings"} 185 | 186 | 187 | /* Code block size, but then too small in blocks*/ 188 | /*code { font-size: 0.75em !important; }*/ 189 | 190 | 191 | /* Site width etc.*/ 192 | .md-grid { 193 | max-width: 64rem !important; 194 | } 195 | /* 196 | .framed-python{ 197 | margin-top:-70px; 198 | overflow:hidden; 199 | } 200 | .framed-r{ 201 | margin-top:0px; 202 | overflow:hidden; 203 | } 204 | .framed-r-api{ 205 | margin-top:-50px; 206 | overflow:hidden; 207 | } 208 | .framed-github{ 209 | height:100vh !important; 210 | width:100% !important; 211 | } 212 | */ 213 | 214 | 215 | /*Toc right margins*/ 216 | /* 217 | @media only screen and (min-width: 76.25em){ 218 | .framed-python{ 219 | margin-left:-45px; 220 | } 221 | .framed-r-api{ 222 | margin-left:-45px; 223 | } 224 | .md-sidebar--secondary { 225 | margin-left: 100% !important; 226 | } 227 | } 228 | */ 229 | 230 | 231 | /* mkdocstrings styling */ 232 | /* Indentation of function doc */ 233 | div.doc-contents:not(.first) { 234 | padding-left: 25px; 235 | border-left: 4px solid rgb(230, 230, 230); 236 | margin-bottom: 80px; 237 | } 238 | /* Don't capitalize names. */ 239 | h5.doc-heading { 240 | text-transform: none !important; 241 | } 242 | /* Don't use vertical space on hidden ToC entries. */ 243 | h6.hidden-toc { 244 | margin: 0 !important; 245 | position: relative; 246 | top: -70px; 247 | } 248 | h6.hidden-toc::before { 249 | margin-top: 0 !important; 250 | padding-top: 0 !important; 251 | } 252 | /* Don't show permalink of hidden ToC entries. */ 253 | h6.hidden-toc a.headerlink { 254 | display: none; 255 | } 256 | /* Avoid breaking parameters name, etc. in table cells. */ 257 | td code { 258 | word-break: normal !important; 259 | } 260 | /* For pieces of Markdown rendered in table cells. 
*/ 261 | td p { 262 | margin-top: 0; 263 | margin-bottom: 0; 264 | } 265 | 266 | -------------------------------------------------------------------------------- /docs/stylesheets/extra.js: -------------------------------------------------------------------------------- 1 | 2 | /** 3 | 4 | // Table of contents always expanded 5 | document.addEventListener("DOMContentLoaded", function() { 6 | load_navpane(); 7 | }); 8 | 9 | function load_navpane() { 10 | var width = window.innerWidth; 11 | if (width <= 1200) { 12 | return; 13 | } 14 | 15 | var nav = document.getElementsByClassName("md-nav"); 16 | for(var i = 0; i < nav.length; i++) { 17 | if (typeof nav.item(i).style === "undefined") { 18 | continue; 19 | } 20 | 21 | if (nav.item(i).getAttribute("data-md-level") && nav.item(i).getAttribute("data-md-component")) { 22 | nav.item(i).style.display = 'block'; 23 | nav.item(i).style.overflow = 'visible'; 24 | } 25 | } 26 | 27 | var nav = document.getElementsByClassName("md-nav__toggle"); 28 | for(var i = 0; i < nav.length; i++) { 29 | nav.item(i).checked = true; 30 | } 31 | } 32 | **/ 33 | 34 | 35 | // Open links externally. 36 | var links = document.links; 37 | 38 | for (var i = 0, linksLength = links.length; i < linksLength; i++) { 39 | if (links[i].hostname != window.location.hostname) { 40 | links[i].target = '_blank'; 41 | } 42 | } -------------------------------------------------------------------------------- /docs/tutorial/deployment.md: -------------------------------------------------------------------------------- 1 | # Deployment 2 | 3 | traintool can easily deploy your model through a REST API. This allows you to access the model from a website or application without shipping it with your code. 4 | 5 | Deployment uses [FastAPI](https://fastapi.tiangolo.com/) under the hood, which makes the API fully compatible with [OpenAPI/Swagger](https://github.com/OAI/OpenAPI-Specification) and [JSON Schema](http://json-schema.org/). 6 | 7 | 8 | ## Deploying a model 9 | 10 | To deploy a model after training or loading, simply run: 11 | 12 | ```python 13 | model.deploy() 14 | ``` 15 | 16 | Note that the call to `deploy` is blocking, i.e. it should be run in a separate script. Also, it might not work well with Jupyter notebooks. 17 | 18 | !!! tip 19 | By default, the API will run on 127.0.0.1 at port 8000, but you can modify this, e.g. `model.deploy(host="0.0.0.0", port=8001)`. 20 | 21 | 22 | ## Accessing the API 23 | 24 | To access the API, navigate your browser to http://127.0.0.1:8000/. If everything worked out, you should see some basic information about the deployed model like below: 25 | 26 |  27 | 28 | To find out more about the API, check out the API docs at http://127.0.0.1:8000/docs. They contain information about all endpoints and required data types. 29 | 30 | 31 | ## Making predictions 32 | 33 | If you want to make a prediction with the API, you need to make a POST request to the `/predict` endpoint (http://127.0.0.1:8000/predict). The request body should look like this: 34 | 35 | ```json 36 | { 37 | "image": [[[0, 0.5, 0, 1], [0, 1, 0, 0.5]]] 38 | } 39 | ``` 40 | 41 | 42 | `"image"` is a list of lists with shape `color channels x height x width` (here: a grayscale image with 1 channel, height 2, and width 4). You can easily get this list format from a numpy array with [numpy.ndarray.tolist](https://numpy.org/doc/stable/reference/generated/numpy.ndarray.tolist.html). Note that you cannot pass raw numpy arrays into the request because they are not JSON serializable. 
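For example, you can send such a request from Python with the [requests](https://requests.readthedocs.io/) library. This is a minimal sketch; it assumes the API runs at the default address and that the model was trained on grayscale 28 x 28 images (adjust the array shape to your data):

```python
import numpy as np
import requests

# Hypothetical example image: 1 color channel, 28 x 28 pixels
image = np.random.rand(1, 28, 28)

# tolist() turns the numpy array into nested lists, which are JSON serializable
response = requests.post(
    "http://127.0.0.1:8000/predict",
    json={"image": image.tolist()},
)
response.raise_for_status()
print(response.json()["predicted_class"])
```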
43 | 44 | As in training, images can be RGB (3 color channels) or grayscale (1 color channel). They will be automatically preprocessed in the same way as the train data. If you used numpy images for training, make sure the image here has the same size and pixel range. If you used files, everything should be converted to the correct format automatically. 45 | 46 | !!! tip 47 | You can easily try out the `/predict` endpoint if you go to the API docs (http://127.0.0.1:8000/docs), click on `/predict` and then on the "Try it out" button on the right. 48 | 49 | The endpoint will return a JSON object which is very similar to the dictionary returned by `model.predict(...)`. Numpy arrays are again converted to lists of lists (convert back with [numpy.asarray](https://numpy.org/doc/stable/reference/generated/numpy.asarray.html)). The JSON should look like this: 50 | 51 | ```json 52 | { 53 | "predicted_class": 2, 54 | "probabilities": [ 55 | 0.1, 56 | 0.8, 57 | 0.1 58 | ], 59 | "runtime": "0:00:00.088831" 60 | } 61 | ``` 62 | 63 | 64 | 65 | 66 | -------------------------------------------------------------------------------- /docs/tutorial/experiment-tracking.md: -------------------------------------------------------------------------------- 1 | # Experiment tracking 2 | 3 | traintool tracks common metrics automatically (e.g. accuracy on train and test set) 4 | and has different options to store and visualize them. 5 | 6 | 7 | ## Tensorboard 8 | 9 | [Tensorboard](https://www.tensorflow.org/tensorboard) is a popular visualization toolkit from Google's tensorflow framework. By default, traintool automatically stores logs for tensorboard along with the model, so that you can visualize the metrics of your experiments. 10 | 11 | To start tensorboard, run this in your terminal (from the project directory): 12 | 13 | ```bash 14 | tensorboard --logdir traintool-experiments 15 | ``` 16 | 17 | Navigate your browser to [http://localhost:6006/](http://localhost:6006/) and you should see the tensorboard window. 18 | 19 | 20 | 21 | On the bottom left, you can select all the different runs (same names as the directories in `traintool-experiments`); on the right side, you can view the metrics. 22 | 23 | 24 | 25 | ## Comet.ml 26 | 27 | You can store these metrics in [comet.ml](https://www.comet.ml/), a popular platform 28 | for experiment tracking. They offer free accounts (you can sign up with your Github 29 | account), and free premium for students & academia. 30 | 31 | Once you have your account, log in to comet.ml, click on your profile in the upper 32 | right corner, go to Settings, and click "Generate API Key". Pass this API key along to the 33 | `train` function like this: 34 | 35 | ```python 36 | traintool.train("resnet18", train_data=train_data, test_data=test_data, 37 | comet_config={"api_key": YOUR_API_KEY, "project_name": OPTIONAL_PROJECT_NAME}) 38 | ``` 39 | 40 | Now you can head on over to [comet.ml](https://www.comet.ml/) and follow the metrics in 41 | real time! 42 | -------------------------------------------------------------------------------- /docs/tutorial/intro.md: -------------------------------------------------------------------------------- 1 | # Intro 2 | 3 | This tutorial shows you everything that **traintool** can do. 4 | 5 | We will train a few different models on MNIST, use automated experiment tracking, deploy 6 | the models via REST APIs, and get access to the underlying, raw models. 
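As a quick preview, the whole workflow looks roughly like this sketch (each step is explained in its own section; `train_data`, `test_data`, and `test_image` are placeholders for your data):

```python
import traintool

# Train an off-the-shelf model, here a support vector classifier
model = traintool.train("svc", train_data=train_data, test_data=test_data)

# Make a prediction, deploy a REST API (blocking call), or access the raw model
result = model.predict(test_image)
model.deploy()
raw_objects = model.raw()
```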
7 | 8 | 9 | ## Installation 10 | 11 | If you haven't installed traintool yet, now is a good time: 12 | 13 | ```bash 14 | pip install git+https://github.com/jrieke/traintool 15 | ``` 16 | 17 | 18 | ## Dataset 19 | 20 | We will use the MNIST dataset throughout this tutorial. Just in case you never heard of 21 | it: MNIST is a popular dataset for image classification. It contains images of 22 | handwritten digits and the task is to predict which digit is shown on a given image. 23 | Below are some examples. 24 | 25 |  26 | -------------------------------------------------------------------------------- /docs/tutorial/quickstart.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "nbformat": 4, 3 | "nbformat_minor": 0, 4 | "metadata": { 5 | "colab": { 6 | "name": "quickstart.ipynb", 7 | "provenance": [], 8 | "collapsed_sections": [] 9 | }, 10 | "kernelspec": { 11 | "name": "python3", 12 | "display_name": "Python 3" 13 | } 14 | }, 15 | "cells": [ 16 | { 17 | "cell_type": "markdown", 18 | "metadata": { 19 | "id": "ak4L7hONDZdx" 20 | }, 21 | "source": [ 22 | "# Quickstart\n", 23 | "\n", 24 | "Welcome to **traintool**!\n", 25 | "\n", 26 | "In this quickstart, we will train a few models on MNIST. This should give you a rough overview of what traintool can do. \n", 27 | "\n", 28 | "You can follow along interactively in **Google Colab** (a free Jupyter notebook service):\n", 29 | "\n", 30 | "[](https://colab.research.google.com/github/jrieke/traintool/blob/master/docs/tutorial/quickstart.ipynb)\n", 31 | "\n", 32 | "*We highly recommend to use Colab for this tutorial because it gives you free GPU access, which makes training much faster. Important: To enable GPU support, click on \"Runtime\" -> \"Change runtime type\", select \"GPU\" and hit \"Save\".*\n", 33 | "\n", 34 | "---\n" 35 | ] 36 | }, 37 | { 38 | "cell_type": "markdown", 39 | "metadata": { 40 | "id": "ZHmp7C-GEpLH" 41 | }, 42 | "source": [ 43 | "First, let's install traintool:" 44 | ] 45 | }, 46 | { 47 | "cell_type": "code", 48 | "metadata": { 49 | "id": "sBitaippDXMG" 50 | }, 51 | "source": [ 52 | "!pip install -U git+https://github.com/jrieke/traintool" 53 | ], 54 | "execution_count": null, 55 | "outputs": [] 56 | }, 57 | { 58 | "cell_type": "markdown", 59 | "metadata": { 60 | "id": "IJunE3-NEuCW" 61 | }, 62 | "source": [ 63 | "Next, we import traintool and load the mnist dataset (installed with traintool):" 64 | ] 65 | }, 66 | { 67 | "cell_type": "code", 68 | "metadata": { 69 | "id": "1-dX_WnO1j08" 70 | }, 71 | "source": [ 72 | "import traintool\n", 73 | "import mnist" 74 | ], 75 | "execution_count": null, 76 | "outputs": [] 77 | }, 78 | { 79 | "cell_type": "code", 80 | "metadata": { 81 | "id": "IagipCnO1wfi" 82 | }, 83 | "source": [ 84 | "train_images = mnist.train_images()[:, None] # add color dimension\n", 85 | "train_labels = mnist.train_labels()\n", 86 | "test_images = mnist.test_images()[:, None]\n", 87 | "test_labels = mnist.test_labels()\n", 88 | "\n", 89 | "print(\"Images shape:\", train_images.shape)\n", 90 | "print(\"Labels shape:\", train_labels.shape)" 91 | ], 92 | "execution_count": null, 93 | "outputs": [] 94 | }, 95 | { 96 | "cell_type": "markdown", 97 | "metadata": { 98 | "id": "NWfMdByHE5qM" 99 | }, 100 | "source": [ 101 | "As you can see, all data from the `mnist` package comes as numpy arrays. Images have the shape `num samples x color channels x height x width`. Note that traintool can handle numpy arrays like here as well as image files on your machine (see here)." 
102 | ] 103 | }, 104 | { 105 | "cell_type": "markdown", 106 | "metadata": { 107 | "id": "BTyAJVPGMOgs" 108 | }, 109 | "source": [ 110 | "## Your first model" 111 | ] 112 | }, 113 | { 114 | "cell_type": "markdown", 115 | "metadata": { 116 | "id": "x6mlF9hTGaEr" 117 | }, 118 | "source": [ 119 | "Let's train our first model! We will use a very simple model, a support vector classifier (called `svc` in traintool). Training it requires only one line of code:\n", 120 | "\n", 121 | "*Note: We use the config parameter `num_samples` here to train only on a subset of the data to make it faster.*" 122 | ] 123 | }, 124 | { 125 | "cell_type": "code", 126 | "metadata": { 127 | "id": "1OZvJ2kT2CVO" 128 | }, 129 | "source": [ 130 | "svc = traintool.train(\"svc\", \n", 131 | " train_data=[train_images, train_labels], \n", 132 | " test_data=[test_images, test_labels], \n", 133 | " config={\"num_samples\": 500})" 134 | ], 135 | "execution_count": null, 136 | "outputs": [] 137 | }, 138 | { 139 | "cell_type": "markdown", 140 | "metadata": { 141 | "id": "rBm1LQJnGnHl" 142 | }, 143 | "source": [ 144 | "That looks very simple – but under the hood, a lot of stuff happened:\n", 145 | "\n", 146 | "1) traintool printed some **general information** about the experiment: Its ID, which model and configuration was used, where the model is saved and how you can load it later. \n", 147 | "\n", 148 | "2) Then, it **preprocessed** the data. It automatically converted all data to the correct format and applied some light preprocessing that makes sense with this model. \n", 149 | "\n", 150 | "3) It created and **trained** the model. Under the hood, traintool uses different frameworks for this step (e.g. scikit-learn or pytorch) but as a user, you don't have to worry about any of this. After training, traintool printed the resulting accuracies (should be 80-85 % here).\n", 151 | "\n", 152 | "4) traintool automatically **saved** the model, console output and tensorboard logs into a time-stamped folder (see below)." 153 | ] 154 | }, 155 | { 156 | "source": [], 157 | "cell_type": "markdown", 158 | "metadata": {} 159 | }, 160 | { 161 | "source": [ 162 | "## Making predictions" 163 | ], 164 | "cell_type": "markdown", 165 | "metadata": {} 166 | }, 167 | { 168 | "cell_type": "markdown", 169 | "metadata": { 170 | "id": "Mt2auldciv1f" 171 | }, 172 | "source": [ 173 | "To make a prediction with this model, simply use its `predict` function:" 174 | ] 175 | }, 176 | { 177 | "cell_type": "code", 178 | "metadata": { 179 | "id": "dQ6woUskinsi" 180 | }, 181 | "source": [ 182 | "svc.predict(test_images[0])" 183 | ], 184 | "execution_count": null, 185 | "outputs": [] 186 | }, 187 | { 188 | "cell_type": "markdown", 189 | "metadata": { 190 | "id": "XVhM_hkLjAUF" 191 | }, 192 | "source": [ 193 | "This gives you a dictionary with the predicted class and probabilities for each class. Note that for now, `predict` can only process a single image at a time. As the `train` method, it works with numpy arrays and image files (see here)." 194 | ] 195 | }, 196 | { 197 | "cell_type": "markdown", 198 | "metadata": { 199 | "id": "zNhKQxCdMQ9G" 200 | }, 201 | "source": [ 202 | "## Using other models" 203 | ] 204 | }, 205 | { 206 | "cell_type": "markdown", 207 | "metadata": { 208 | "id": "BPamL3DbJ72B" 209 | }, 210 | "source": [ 211 | "\n", 212 | "Ok, now what if you want to train a different model? 
196 | { 197 | "cell_type": "markdown", 198 | "metadata": { 199 | "id": "zNhKQxCdMQ9G" 200 | }, 201 | "source": [ 202 | "## Using other models" 203 | ] 204 | }, 205 | { 206 | "cell_type": "markdown", 207 | "metadata": { 208 | "id": "BPamL3DbJ72B" 209 | }, 210 | "source": [ 211 | "\n", 212 | "Ok, now what if you want to train a different model? traintool makes this very easy: You only have to call the `train` function with a different model name – no need to rewrite the implementation or change the data just because you use a model from a different framework!\n", 213 | "\n", 214 | "Let's train a residual network (`resnet18`), a deep neural network from pytorch (make sure to use a GPU!):" 215 | ] 216 | }, 217 | { 218 | "cell_type": "code", 219 | "metadata": { 220 | "id": "lskB7rhw2E37" 221 | }, 222 | "source": [ 223 | "resnet = traintool.train(\"resnet18\", \n", 224 | " train_data=[train_images, train_labels],\n", 225 | " test_data=[test_images, test_labels],\n", 226 | " config={\"batch_size\": 128, \"print_every\": 10, \"num_epochs\": 2, \"num_samples\": 10000})" 227 | ], 228 | "execution_count": null, 229 | "outputs": [] 230 | }, 231 | { 232 | "cell_type": "markdown", 233 | "metadata": { 234 | "id": "rYUNQIAEUSu3" 235 | }, 236 | "source": [ 237 | "And with this same simple command, you can train any model supported by traintool! See [here](https://traintool.jrieke.com/models/image-classification/) for a list of models. \n", 238 | "\n", 239 | "As you may have noticed, we set some parameters with the `config` argument above. `config` is the central place to define hyperparameters for training. The supported hyperparameters vary from model to model – it's best to have a look at the overview page linked above. \n", 240 | "\n" 241 | ] 242 | }, 243 | { 244 | "cell_type": "markdown", 245 | "metadata": { 246 | "id": "j8Mdq6AXcomF" 247 | }, 248 | "source": [ 249 | "## Experiment tracking" 250 | ] 251 | }, 252 | { 253 | "cell_type": "markdown", 254 | "metadata": { 255 | "id": "BlkMmxlZctbs" 256 | }, 257 | "source": [ 258 | "traintool automatically keeps track of all experiments you run. Each experiment is stored in a time-stamped folder in `./traintool-experiments`. Have a look at this folder now to see the experiments you ran above! (If you are on Colab, click on the folder icon on the top left.)\n", 259 | "\n", 260 | "*Tip: You can disable saving with `save=False`.*" 261 | ] 262 | }, 263 | { 264 | "cell_type": "markdown", 265 | "metadata": { 266 | "id": "cyngzaQteHQ0" 267 | }, 268 | "source": [ 269 | "Each experiment folder contains:\n", 270 | "\n", 271 | "- `info.yml`: General information about the experiment\n", 272 | "- `stdout.log`: The entire console output\n", 273 | "- model files and possibly checkpoints (e.g. the pytorch binary `model.pt` for resnet18)\n", 274 | "- tensorboard logs (see below)" 275 | ] 276 | },
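{
"cell_type": "markdown",
"metadata": {},
"source": [
"You can also list the experiment folders straight from the notebook. This is a plain shell command and assumes you ran the training cells above, so the folder exists:"
]
},
{
"cell_type": "code",
"metadata": {},
"source": [
"!ls traintool-experiments"
],
"execution_count": null,
"outputs": []
},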
277 | { 278 | "source": [ 279 | "## Visualizations" 280 | ], 281 | "cell_type": "markdown", 282 | "metadata": {} 283 | }, 284 | { 285 | "cell_type": "markdown", 286 | "metadata": { 287 | "id": "tl41Bq9teL-Y" 288 | }, 289 | "source": [ 290 | "traintool writes all metrics and evaluations to [tensorboard](https://www.tensorflow.org/tensorboard), a powerful visualization platform from tensorflow. Let's start tensorboard now: If you are on a local machine, open a terminal in this directory and type `tensorboard --logdir traintool-experiments`. If you are on Colab, just run the cell below:" 291 | ] 292 | }, 293 | { 294 | "cell_type": "code", 295 | "metadata": { 296 | "id": "Wk2eD-9mCeRn" 297 | }, 298 | "source": [ 299 | "%load_ext tensorboard\n", 300 | "%tensorboard --logdir traintool-experiments/" 301 | ], 302 | "execution_count": null, 303 | "outputs": [] 304 | }, 305 | { 306 | "cell_type": "markdown", 307 | "metadata": { 308 | "id": "Ldx4CtPHe5jC" 309 | }, 310 | "source": [ 311 | "Let's see what's going on here: On the bottom left, you can select individual experiments. On the right, you should by default see scalar metrics: the loss and accuracy for train and test set. You can also click on `Images` at the top to see some sample images from both datasets along with classification results (use the sliders to look at different epochs!).\n", 312 | "\n", 313 | "*Tip: You can also store metrics in comet.ml, see here.*" 314 | ] 315 | }, 316 | { 317 | "cell_type": "markdown", 318 | "metadata": { 319 | "id": "1bzWx_7Lh3Zn" 320 | }, 321 | "source": [ 322 | "## Other functions" 323 | ] 324 | }, 325 | { 326 | "cell_type": "markdown", 327 | "metadata": { 328 | "id": "Z66dv-Rnh6-D" 329 | }, 330 | "source": [ 331 | "Before we end this quickstart, let's look at three other important functions:\n", 332 | "\n", 333 | "- **Loading:** To load a saved model, just pass its ID (or directory path) to `traintool.load(...)`. Check out the line starting with `Load via:` in the console output above – it shows you directly which command to call.\n", 334 | "- **Deployment:** traintool can easily deploy your trained model through a REST API. Simply call `model.deploy()` to start the server (note that this call is blocking!). More information here.\n", 335 | "- **Raw models:** traintool models are implemented in different frameworks, e.g. scikit-learn or pytorch. You can get access to the raw models by calling `model.raw()`. " 336 | ] 337 | }, 338 | { 339 | "cell_type": "markdown", 340 | "metadata": { 341 | "id": "abU5z1zVkWRS" 342 | }, 343 | "source": [ 344 | "\n", 345 | "---" 346 | ] 347 | }, 348 | { 349 | "cell_type": "markdown", 350 | "metadata": { 351 | "id": "D8FUSxzjj6pu" 352 | }, 353 | "source": [ 354 | "That's it! You should now be able to start using traintool. Make sure to read the complete tutorial and documentation to learn more! \n", 355 | "\n", 356 | "Please also consider leaving a ⭐ on our [GitHub](https://github.com/jrieke/traintool)." 357 | ] 358 | } 359 | ] 360 | } -------------------------------------------------------------------------------- /docs/tutorial/raw-models.md: -------------------------------------------------------------------------------- 1 | # Accessing raw models 2 | 3 | traintool is built on top of powerful machine learning libraries like scikit-learn or 4 | pytorch. After training, it gives you full access to the raw models with: 5 | 6 | ```python 7 | model.raw() 8 | ``` 9 | 10 | This returns a dict of all underlying model objects. It usually contains the model 11 | itself (`model.raw()["model"]`) but might also contain some other 12 | objects, e.g. data scalers (`model.raw()["scaler"]`). 13 |
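For example, for a model that is implemented in scikit-learn (such as `svc`), you could
grab the fitted estimator out of this dict and inspect it directly. This is only a
sketch; the exact object type depends on which model you trained:

```python
raw = model.raw()  # dict of underlying objects
sklearn_model = raw["model"]  # e.g. the fitted scikit-learn estimator for "svc"
print(type(sklearn_model))
```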
-------------------------------------------------------------------------------- /docs/tutorial/training.md: -------------------------------------------------------------------------------- 1 | # Training and Prediction 2 | 3 | ## Your first model 4 | 5 | As a first example, we'll train a very simple model: A 6 | [Support Vector Machine](https://en.wikipedia.org/wiki/Support_vector_machine) or SVM. 7 | We will use the image classification dataset MNIST throughout this tutorial, so let's 8 | load it now (the `mnist` package was installed along with traintool): 9 | 10 | ```python 11 | import mnist 12 | train_data = [mnist.train_images(), mnist.train_labels()] 13 | test_data = [mnist.test_images(), mnist.test_labels()] 14 | ``` 15 | 16 | !!! tip 17 | The code above loads the data as numpy arrays, but traintool can also deal with 18 | files and pytorch datasets (see here). More data formats will be added soon. 19 | 20 | Training the SVM classifier is now very simple: 21 | 22 | ```python 23 | import traintool 24 | svc = traintool.train("svc", train_data=train_data, test_data=test_data) 25 | ``` 26 | 27 | That's it! traintool will take care of reading and converting the data, applying some 28 | light preprocessing, training and saving the model, and tracking all metrics. It will 29 | also print out the final loss and accuracy (the test accuracy should be around XX % 30 | here). 31 | 32 | 33 | ## Making predictions 34 | 35 | Of course, you can make predictions with the trained model. Let's run it on an image from 36 | the test set: 37 | 38 | ```python 39 | pred = svc.predict(test_data[0][0]) 40 | print("Predicted:", pred["predicted_class"], " - Is:", test_data[1][0]) 41 | ``` 42 | 43 | This should print out the predicted class and the ground truth. Note that `pred` is a 44 | dictionary with the predicted class (`pred["predicted_class"]`) and the probabilities 45 | for each class (`pred["probabilities"]`). 46 | 47 | !!! tip 48 | Again, we use a numpy array for the test image here, but traintool can also handle 49 | pytorch tensors and files. You can even pass in a whole batch of images 50 | (e.g. `test_data[0][0:2]`). 51 | 52 | 53 | ## Using other models 54 | 55 | Now, let's try a more advanced model. We will train a [Residual Network](https://arxiv.org/abs/1512.03385) 56 | (ResNet), a modern deep neural network. Usually, switching to this model from an SVM 57 | would require you to use an advanced framework like pytorch or tensorflow and rewrite 58 | most of your codebase. With traintool, it's as simple as replacing the model name in the `train` method: 59 | 60 | ```python 61 | resnet = traintool.train("resnet", train_data=train_data, test_data=test_data) 62 | ``` 63 | 64 | And this syntax stays the same for every other model that traintool supports! This makes 65 | it really easy to compare a bunch of different models on your dataset and see what 66 | performs best. 67 | 68 | 69 | ## Custom hyperparameters 70 | 71 | In machine learning, most models have some hyperparameters that control the training 72 | process (e.g. the learning rate). traintool uses sensible defaults specific to each 73 | model, but gives you the flexibility to fully customize everything. 74 | 75 | First, let's find out which hyperparameters the model supports and what their defaults 76 | are: 77 | 78 | ```python 79 | print(traintool.default_hyperparameters("resnet")) 80 | ``` 81 | 82 | This should print out a dictionary of hyperparameters and their defaults. Now, we want to 83 | change the learning rate and use a different optimizer. To do this, simply pass a 84 | `config` dict to the train method: 85 | 86 | ```python 87 | config = {"lr": 0.1, "optimizer": "adam"} 88 | better_resnet = traintool.train("resnet", config=config, train_data=train_data, test_data=test_data) 89 | ``` 90 | 91 |
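If you only want to tweak a few values, one handy pattern is to copy the defaults and
override selected keys. This is a sketch and assumes that the defaults are returned as a
plain dict containing the `lr` and `optimizer` keys used above:

```python
defaults = traintool.default_hyperparameters("resnet")
config = dict(defaults, lr=0.1, optimizer="adam")  # override only what you need
custom_resnet = traintool.train("resnet", config=config, train_data=train_data, test_data=test_data)
```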
92 | ## Saving and loading models 93 | 94 | There are two options to save a model to disk. Either use the `save` method after 95 | training like this: 96 | 97 | ```python 98 | model = traintool.train("...") 99 | model.save("path/to/dir") 100 | ``` 101 | 102 | Or you can specify an output directory directly during training. This makes sense for 103 | long-running processes, so you don't lose all progress if your machine is 104 | interrupted: 105 | 106 | ```python 107 | model = traintool.train("...", save="path/to/dir") 108 | ``` 109 | 110 | In both cases, loading a model works via: 111 | 112 | ```python 113 | model = traintool.load("path/to/dir") 114 | ``` 115 | 116 | 117 |
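To make the round trip concrete, here is a short sketch that saves the SVM from above,
loads it back, and runs a prediction. The directory name is just an example:

```python
svc.save("my-models/svc-mnist")  # save the trained model
restored = traintool.load("my-models/svc-mnist")  # load it back
print(restored.predict(test_data[0][0])["predicted_class"])
```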