├── minai ├── version.py ├── __init__.py ├── _modidx.py └── core.py ├── .github └── workflows │ └── pypi.yml ├── settings.ini ├── README.md ├── CONTRIBUTING.md ├── setup.py ├── CODE_OF_CONDUCT.md ├── .gitignore ├── LICENSE ├── tutorial_01.ipynb └── llm_example.ipynb /minai/version.py: -------------------------------------------------------------------------------- 1 | __version__ = "0.0.5" -------------------------------------------------------------------------------- /minai/__init__.py: -------------------------------------------------------------------------------- 1 | __version__ = "0.0.1" 2 | from .core import * 3 | -------------------------------------------------------------------------------- /.github/workflows/pypi.yml: -------------------------------------------------------------------------------- 1 | # This workflows will upload a Python Package using Twine when a release is created 2 | # For more information see: https://help.github.com/en/actions/language-and-framework-guides/using-python-with-github-actions#publishing-to-package-registries 3 | 4 | name: Upload Python Package 5 | 6 | on: 7 | release: 8 | types: [created] 9 | 10 | jobs: 11 | deploy: 12 | 13 | runs-on: ubuntu-latest 14 | 15 | steps: 16 | - uses: actions/checkout@v2 17 | - name: Set up Python 18 | uses: actions/setup-python@v2 19 | with: 20 | python-version: '3.x' 21 | - name: Install dependencies 22 | run: | 23 | python -m pip install --upgrade pip 24 | pip install setuptools wheel twine 25 | - name: Build and publish 26 | env: 27 | TWINE_USERNAME: ${{ secrets.PYPI_USERNAME }} 28 | TWINE_PASSWORD: ${{ secrets.PYPI_PASSWORD }} 29 | run: | 30 | python setup.py sdist bdist_wheel 31 | twine upload dist/* -------------------------------------------------------------------------------- /settings.ini: -------------------------------------------------------------------------------- 1 | [DEFAULT] 2 | lib_name = minai 3 | repo = minai 4 | user = fastai 5 | description = A miniture AI training framework for PyTorch 6 | keywords = python 7 | author = Jeremy Howard and Jonathan Whitaker 8 | author_email = infos@fast.ai 9 | copyright = fast.ai 10 | branch = master 11 | version = 0.0.1 12 | min_python = 3.8 13 | audience = Developers 14 | language = English 15 | custom_sidebar = False 16 | license = apache2 17 | status = 2 18 | nbs_path = . 19 | doc_path = _docs 20 | requirements = torch fastcore fastprogress torcheval matplotlib numpy 21 | dev_requirements = nbdev accelerate datasets torchvision transformers 22 | git_url = https://github.com/fastai/minai/ 23 | lib_path = minai 24 | title = minai 25 | doc_host = https://minai.fast.ai 26 | doc_baseurl = / 27 | host = github 28 | tst_flags = 29 | recursive = True 30 | black_formatting = False 31 | readme_nb = index.ipynb 32 | allowed_metadata_keys = 33 | allowed_cell_metadata_keys = 34 | jupyter_hooks = True 35 | clean_ids = False 36 | conda_user = fastai 37 | clear_all = False 38 | put_version_in_init = True 39 | 40 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # minai 2 | 3 | The mini version of fastai's miniai PyTorch framework created during the fastai course 2022-2023. 4 | 5 | ## Installation 6 | 7 | ```bash 8 | pip install minai 9 | ``` 10 | 11 | or to install from source clone this repo and run: 12 | 13 | ```bash 14 | pip install -e . 15 | ``` 16 | 17 | ## Usage 18 | 19 | This is still a work in progress - I'll add example usage soon. But in general, for examples from the course where you have `from miniai.something import X` you should be able to do `from minai import X`. You can do `import minai as mi` or even `from minai import *` for quick access to all the functions and things, if you're so inclined. 20 | 21 | Tutorial 1 has a minimal example of fitting a model using minai - open it in Google colab [here](https://colab.research.google.com/github/AnswerDotAI/minai/blob/main/tutorial_01.ipynb). 22 | 23 | Tutorial 2 shows callbacks in action on a slightly more complex task - open it in Google colab [here](https://colab.research.google.com/github/AnswerDotAI/minai/blob/main/tutorial_02.ipynb). 24 | 25 | 26 | An example of the library in action: [this notebook](https://colab.research.google.com/drive/1b3CeZB2FfRGr5NPYDVvk34hyZFBtgub5?usp=sharing) shows how to train a diffusion model on spectrograms to generate birdcalls, using minai. It is covered in the final lesson of Part 2 of the FastAI course. 27 | 28 | And a lovely demo of use in the wild is [this report by Thomas Capelle](https://wandb.ai/capecape/miniai_ddpm/reports/Next-Frame-Prediction-Using-Diffusion-The-fastai-Approach--VmlldzozMzcyMTYy) where he uses diffusion models to predict the next frame of an image sequence. 29 | -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # How to contribute 2 | 3 | ## How to get started 4 | 5 | Clone the `fastcore` repository. 6 | 7 | Inside the resulting local repository, run the following command to do an 8 | [editable 9 | install](https://stackoverflow.com/questions/35064426/when-would-the-e-editable-option-be-useful-with-pip-install) of the `fastcore` library with tools we use to support development: 10 | 11 | ``` 12 | pip install -e '.[dev]' 13 | ``` 14 | 15 | Run the following command to install git hooks that run automatic scripts during each commit and merge. These scripts strip the notebooks of superfluous metadata (and avoid merge conflicts): 16 | 17 | ``` 18 | nbdev_install_hooks 19 | ``` 20 | 21 | ## Did you find a bug? 22 | 23 | * Ensure the bug was not already reported by searching on GitHub under Issues. 24 | * If you're unable to find an open issue addressing the problem, open a new one. Be sure to include a title and clear description, as much relevant information as possible, and a code sample or an executable test case demonstrating the expected behavior that is not occurring. 25 | * Be sure to add the complete error messages. 26 | 27 | #### Did you write a patch that fixes a bug? 28 | 29 | * Open a new GitHub pull request with the patch. 30 | * Ensure that your PR includes a test that fails without your patch, and pass with it. 31 | * Ensure the PR description clearly describes the problem and solution. Include the relevant issue number if applicable. 32 | 33 | ## PR submission guidelines 34 | 35 | * Keep each PR focused. While it's more convenient, do not combine several unrelated fixes together. Create as many branches as needing to keep each PR focused. 36 | * Do not mix style changes/fixes with "functional" changes. It's very difficult to review such PRs and it most likely get rejected. 37 | * Do not add/remove vertical whitespace. Preserve the original style of the file you edit as much as you can. 38 | * Do not turn an already submitted PR into your development playground. If after you submitted PR, you discovered that more work is needed - close the PR, do the required work and then submit a new PR. Otherwise each of your commits requires attention from maintainers of the project. 39 | * If, however, you submitted a PR and received a request for changes, you should proceed with commits inside that PR, so that the maintainer can see the incremental fixes and won't need to review the whole PR again. In the exception case where you realize it'll take many many commits to complete the requests, then it's probably best to close the PR, do the work and then submit it again. Use common sense where you'd choose one way over another. 40 | 41 | ## Do you want to contribute to the documentation? 42 | 43 | * Docs are automatically created from the notebooks in the nbs folder. 44 | 45 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from pkg_resources import parse_version 2 | from configparser import ConfigParser 3 | import setuptools,re,sys 4 | assert parse_version(setuptools.__version__)>=parse_version('36.2') 5 | 6 | # note: all settings are in settings.ini; edit there, not here 7 | config = ConfigParser(delimiters=['=']) 8 | config.read('settings.ini') 9 | cfg = config['DEFAULT'] 10 | 11 | cfg_keys = 'version description keywords author author_email'.split() 12 | expected = cfg_keys + "lib_name user branch license status min_python audience language".split() 13 | for o in expected: assert o in cfg, "missing expected setting: {}".format(o) 14 | setup_cfg = {o:cfg[o] for o in cfg_keys} 15 | 16 | if len(sys.argv)>1 and sys.argv[1]=='version': 17 | print(setup_cfg['version']) 18 | exit() 19 | 20 | licenses = { 21 | 'apache2': ('Apache Software License 2.0','OSI Approved :: Apache Software License'), 22 | } 23 | statuses = [ '1 - Planning', '2 - Pre-Alpha', '3 - Alpha', 24 | '4 - Beta', '5 - Production/Stable', '6 - Mature', '7 - Inactive' ] 25 | py_versions = '2.0 2.1 2.2 2.3 2.4 2.5 2.6 2.7 3.0 3.1 3.2 3.3 3.4 3.5 3.6 3.7 3.8 3.9 3.10'.split() 26 | min_python = cfg['min_python'] 27 | lic = licenses[cfg['license']] 28 | 29 | requirements = ['pip', 'packaging'] 30 | if cfg.get('requirements'): requirements += cfg.get('requirements','').split() 31 | if cfg.get('pip_requirements'): requirements += cfg.get('pip_requirements','').split() 32 | dev_requirements = (cfg.get('dev_requirements') or '').split() 33 | 34 | long_description = open('README.md', encoding="utf8").read() 35 | # ![png](docs/images/output_13_0.png) 36 | for ext in ['png', 'svg']: 37 | long_description = re.sub(r'!\['+ext+'\]\((.*)\)', '!['+ext+']('+'https://raw.githubusercontent.com/{}/{}'.format(cfg['user'],cfg['lib_name'])+'/'+cfg['branch']+'/\\1)', long_description) 38 | long_description = re.sub(r'src=\"(.*)\.'+ext+'\"', 'src=\"https://raw.githubusercontent.com/{}/{}'.format(cfg['user'],cfg['lib_name'])+'/'+cfg['branch']+'/\\1.'+ext+'\"', long_description) 39 | 40 | setuptools.setup( 41 | name = 'minai', 42 | license = lic[0], 43 | classifiers = [ 44 | 'Development Status :: ' + statuses[int(cfg['status'])], 45 | 'Intended Audience :: ' + cfg['audience'].title(), 46 | 'License :: ' + lic[1], 47 | 'Natural Language :: ' + cfg['language'].title(), 48 | ] + ['Programming Language :: Python :: '+o for o in py_versions[py_versions.index(min_python):]], 49 | url = cfg['git_url'], 50 | packages = setuptools.find_packages(), 51 | include_package_data = True, 52 | install_requires = requirements, 53 | extras_require={ 'dev': dev_requirements }, 54 | python_requires = '>=' + cfg['min_python'], 55 | long_description = long_description, 56 | long_description_content_type = 'text/markdown', 57 | zip_safe = False, 58 | entry_points = { 59 | 'console_scripts': cfg.get('console_scripts','').split(), 60 | 'nbdev': [f'{cfg.get("lib_path")}={cfg.get("lib_path")}._modidx:d'] 61 | }, 62 | **setup_cfg) 63 | 64 | -------------------------------------------------------------------------------- /CODE_OF_CONDUCT.md: -------------------------------------------------------------------------------- 1 | # Contributor Covenant Code of Conduct 2 | 3 | ## Our Pledge 4 | 5 | In the interest of fostering an open and welcoming environment, we as 6 | contributors and maintainers pledge to make participation in our project and 7 | our community a harassment-free experience for everyone, regardless of age, body 8 | size, disability, ethnicity, gender identity and expression, level of experience, 9 | education, socio-economic status, nationality, personal appearance, race, 10 | religion, or sexual identity and orientation. 11 | 12 | Examples of unacceptable behavior by participants include: 13 | 14 | * The use of sexualized language or imagery and unwelcome sexual attention or 15 | advances 16 | * Trolling, insulting/derogatory comments, and personal or political attacks 17 | * Public or private harassment 18 | * Publishing others' private information, such as a physical or electronic 19 | address, without explicit permission 20 | 21 | These examples of unacceptable behaviour are requirements; we will not allow them 22 | in any fast.ai project, including this one. 23 | 24 | ## Our Standards 25 | 26 | Examples of behavior that contributes to creating a positive environment 27 | include: 28 | 29 | * Using welcoming and inclusive language 30 | * Being respectful of differing viewpoints and experiences 31 | * Gracefully accepting constructive criticism 32 | * Focusing on what is best for the community 33 | * Showing empathy towards other community members 34 | 35 | These examples are shown only to help you participate effectively -- they are not 36 | requirements, just requests and guidance. 37 | 38 | ## Our Responsibilities 39 | 40 | Project maintainers are responsible for clarifying the standards of acceptable 41 | behavior and are expected to take appropriate and fair corrective action in 42 | response to any instances of unacceptable behavior. 43 | 44 | Project maintainers have the right and responsibility to remove, edit, or 45 | reject comments, commits, code, wiki edits, issues, and other contributions 46 | that are not aligned to this Code of Conduct, or to ban temporarily or 47 | permanently any contributor for other behaviors that they deem inappropriate, 48 | threatening, offensive, or harmful. 49 | 50 | ## Scope 51 | 52 | This Code of Conduct applies both within project spaces and in public spaces 53 | when an individual is representing the project or its community. Examples of 54 | representing a project or community include using an official project e-mail 55 | address, posting via an official social media account or acting as an appointed 56 | representative at an online or offline event. Representation of a project may be 57 | further defined and clarified by project maintainers. 58 | 59 | ## Enforcement 60 | 61 | Instances of abusive, harassing or otherwise unacceptable behavior may be 62 | reported by contacting the project team at info@fast.ai. All 63 | complaints will be reviewed and investigated and will result in a response that 64 | is deemed necessary and appropriate to the circumstances. The project team is 65 | obligated to maintain confidentiality with regard to the reporter of an incident. 66 | Further details of specific enforcement policies may be posted separately. 67 | 68 | Project maintainers who do not follow or enforce the Code of Conduct in good 69 | faith may face temporary or permanent repercussions as determined by other 70 | members of the project's leadership. 71 | 72 | ## Attribution 73 | 74 | This Code of Conduct is adapted from the [Contributor Covenant][homepage], version 1.4, 75 | available at https://www.contributor-covenant.org/version/1/4/code-of-conduct.html 76 | 77 | [homepage]: https://www.contributor-covenant.org 78 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | models/ 2 | 3 | ### Python ### 4 | # Byte-compiled / optimized / DLL files 5 | __pycache__/ 6 | *.py[cod] 7 | *$py.class 8 | 9 | # C extensions 10 | *.so 11 | 12 | # Distribution / packaging 13 | .Python 14 | build/ 15 | develop-eggs/ 16 | dist/ 17 | downloads/ 18 | eggs/ 19 | .eggs/ 20 | lib/ 21 | lib64/ 22 | parts/ 23 | sdist/ 24 | var/ 25 | wheels/ 26 | share/python-wheels/ 27 | *.egg-info/ 28 | .installed.cfg 29 | *.egg 30 | MANIFEST 31 | 32 | # PyInstaller 33 | # Usually these files are written by a python script from a template 34 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 35 | *.manifest 36 | *.spec 37 | 38 | # Installer logs 39 | pip-log.txt 40 | pip-delete-this-directory.txt 41 | 42 | # Unit test / coverage reports 43 | htmlcov/ 44 | .tox/ 45 | .nox/ 46 | .coverage 47 | .coverage.* 48 | .cache 49 | nosetests.xml 50 | coverage.xml 51 | *.cover 52 | *.py,cover 53 | .hypothesis/ 54 | .pytest_cache/ 55 | cover/ 56 | 57 | # Translations 58 | *.mo 59 | *.pot 60 | 61 | # Django stuff: 62 | *.log 63 | local_settings.py 64 | db.sqlite3 65 | db.sqlite3-journal 66 | 67 | # Flask stuff: 68 | instance/ 69 | .webassets-cache 70 | 71 | # Scrapy stuff: 72 | .scrapy 73 | 74 | # Sphinx documentation 75 | docs/_build/ 76 | 77 | # PyBuilder 78 | .pybuilder/ 79 | target/ 80 | 81 | # Jupyter Notebook 82 | .ipynb_checkpoints 83 | 84 | # IPython 85 | profile_default/ 86 | ipython_config.py 87 | 88 | # pyenv 89 | # For a library or package, you might want to ignore these files since the code is 90 | # intended to run in multiple environments; otherwise, check them in: 91 | # .python-version 92 | 93 | # pipenv 94 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 95 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 96 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 97 | # install all needed dependencies. 98 | #Pipfile.lock 99 | 100 | # poetry 101 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 102 | # This is especially recommended for binary packages to ensure reproducibility, and is more 103 | # commonly ignored for libraries. 104 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 105 | #poetry.lock 106 | 107 | # pdm 108 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 109 | #pdm.lock 110 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 111 | # in version control. 112 | # https://pdm.fming.dev/#use-with-ide 113 | .pdm.toml 114 | 115 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 116 | __pypackages__/ 117 | 118 | # Celery stuff 119 | celerybeat-schedule 120 | celerybeat.pid 121 | 122 | # SageMath parsed files 123 | *.sage.py 124 | 125 | # Environments 126 | .env 127 | .venv 128 | env/ 129 | venv/ 130 | ENV/ 131 | env.bak/ 132 | venv.bak/ 133 | 134 | # Spyder project settings 135 | .spyderproject 136 | .spyproject 137 | 138 | # Rope project settings 139 | .ropeproject 140 | 141 | # mkdocs documentation 142 | /site 143 | 144 | # mypy 145 | .mypy_cache/ 146 | .dmypy.json 147 | dmypy.json 148 | 149 | # Pyre type checker 150 | .pyre/ 151 | 152 | # pytype static type analyzer 153 | .pytype/ 154 | 155 | # Cython debug symbols 156 | cython_debug/ 157 | 158 | # PyCharm 159 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 160 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 161 | # and can be added to the global gitignore or merged into this file. For a more nuclear 162 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 163 | #.idea/ 164 | 165 | ### Python Patch ### 166 | # Poetry local configuration file - https://python-poetry.org/docs/configuration/#local-configuration 167 | poetry.toml 168 | 169 | # ruff 170 | .ruff_cache/ 171 | 172 | # LSP config files 173 | pyrightconfig.json 174 | 175 | # End of https://www.toptal.com/developers/gitignore/api/python 176 | 177 | 178 | ## wandb 179 | wandb/ 180 | artifacts/ 181 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /minai/_modidx.py: -------------------------------------------------------------------------------- 1 | # Autogenerated by nbdev 2 | 3 | d = { 'settings': { 'branch': 'master', 4 | 'doc_baseurl': '/', 5 | 'doc_host': 'https://minai.fast.ai', 6 | 'git_url': 'https://github.com/fastai/minai/', 7 | 'lib_path': 'minai'}, 8 | 'syms': { 'minai.core': { 'minai.core.AccelerateCB': ('core.html#acceleratecb', 'minai/core.py'), 9 | 'minai.core.AccelerateCB.__init__': ('core.html#acceleratecb.__init__', 'minai/core.py'), 10 | 'minai.core.AccelerateCB.after_fit': ('core.html#acceleratecb.after_fit', 'minai/core.py'), 11 | 'minai.core.AccelerateCB.backward': ('core.html#acceleratecb.backward', 'minai/core.py'), 12 | 'minai.core.AccelerateCB.before_fit': ('core.html#acceleratecb.before_fit', 'minai/core.py'), 13 | 'minai.core.ActivationStats': ('core.html#activationstats', 'minai/core.py'), 14 | 'minai.core.ActivationStats.__init__': ('core.html#activationstats.__init__', 'minai/core.py'), 15 | 'minai.core.ActivationStats.color_dim': ('core.html#activationstats.color_dim', 'minai/core.py'), 16 | 'minai.core.ActivationStats.dead_chart': ('core.html#activationstats.dead_chart', 'minai/core.py'), 17 | 'minai.core.ActivationStats.plot_stats': ('core.html#activationstats.plot_stats', 'minai/core.py'), 18 | 'minai.core.BaseSchedCB': ('core.html#baseschedcb', 'minai/core.py'), 19 | 'minai.core.BaseSchedCB.__init__': ('core.html#baseschedcb.__init__', 'minai/core.py'), 20 | 'minai.core.BaseSchedCB._step': ('core.html#baseschedcb._step', 'minai/core.py'), 21 | 'minai.core.BaseSchedCB.before_fit': ('core.html#baseschedcb.before_fit', 'minai/core.py'), 22 | 'minai.core.BatchSchedCB': ('core.html#batchschedcb', 'minai/core.py'), 23 | 'minai.core.BatchSchedCB.after_batch': ('core.html#batchschedcb.after_batch', 'minai/core.py'), 24 | 'minai.core.BatchTransformCB': ('core.html#batchtransformcb', 'minai/core.py'), 25 | 'minai.core.BatchTransformCB.__init__': ('core.html#batchtransformcb.__init__', 'minai/core.py'), 26 | 'minai.core.BatchTransformCB.before_batch': ('core.html#batchtransformcb.before_batch', 'minai/core.py'), 27 | 'minai.core.Callback': ('core.html#callback', 'minai/core.py'), 28 | 'minai.core.CancelBatchException': ('core.html#cancelbatchexception', 'minai/core.py'), 29 | 'minai.core.CancelEpochException': ('core.html#cancelepochexception', 'minai/core.py'), 30 | 'minai.core.CancelFitException': ('core.html#cancelfitexception', 'minai/core.py'), 31 | 'minai.core.CapturePreds': ('core.html#capturepreds', 'minai/core.py'), 32 | 'minai.core.CapturePreds.after_batch': ('core.html#capturepreds.after_batch', 'minai/core.py'), 33 | 'minai.core.CapturePreds.after_fit': ('core.html#capturepreds.after_fit', 'minai/core.py'), 34 | 'minai.core.CapturePreds.before_fit': ('core.html#capturepreds.before_fit', 'minai/core.py'), 35 | 'minai.core.CycleDL': ('core.html#cycledl', 'minai/core.py'), 36 | 'minai.core.CycleDL.__init__': ('core.html#cycledl.__init__', 'minai/core.py'), 37 | 'minai.core.CycleDL.__iter__': ('core.html#cycledl.__iter__', 'minai/core.py'), 38 | 'minai.core.CycleDL.__len__': ('core.html#cycledl.__len__', 'minai/core.py'), 39 | 'minai.core.DataLoaders': ('core.html#dataloaders', 'minai/core.py'), 40 | 'minai.core.DataLoaders.__init__': ('core.html#dataloaders.__init__', 'minai/core.py'), 41 | 'minai.core.DataLoaders.from_dd': ('core.html#dataloaders.from_dd', 'minai/core.py'), 42 | 'minai.core.Dataset': ('core.html#dataset', 'minai/core.py'), 43 | 'minai.core.Dataset.__getitem__': ('core.html#dataset.__getitem__', 'minai/core.py'), 44 | 'minai.core.Dataset.__init__': ('core.html#dataset.__init__', 'minai/core.py'), 45 | 'minai.core.Dataset.__len__': ('core.html#dataset.__len__', 'minai/core.py'), 46 | 'minai.core.DeviceCB': ('core.html#devicecb', 'minai/core.py'), 47 | 'minai.core.DeviceCB.__init__': ('core.html#devicecb.__init__', 'minai/core.py'), 48 | 'minai.core.DeviceCB.before_batch': ('core.html#devicecb.before_batch', 'minai/core.py'), 49 | 'minai.core.DeviceCB.before_fit': ('core.html#devicecb.before_fit', 'minai/core.py'), 50 | 'minai.core.EpochSchedCB': ('core.html#epochschedcb', 'minai/core.py'), 51 | 'minai.core.EpochSchedCB.after_epoch': ('core.html#epochschedcb.after_epoch', 'minai/core.py'), 52 | 'minai.core.GeneralRelu': ('core.html#generalrelu', 'minai/core.py'), 53 | 'minai.core.GeneralRelu.__init__': ('core.html#generalrelu.__init__', 'minai/core.py'), 54 | 'minai.core.GeneralRelu.forward': ('core.html#generalrelu.forward', 'minai/core.py'), 55 | 'minai.core.HasLearnCB': ('core.html#haslearncb', 'minai/core.py'), 56 | 'minai.core.HasLearnCB.after_fit': ('core.html#haslearncb.after_fit', 'minai/core.py'), 57 | 'minai.core.HasLearnCB.before_fit': ('core.html#haslearncb.before_fit', 'minai/core.py'), 58 | 'minai.core.Hook': ('core.html#hook', 'minai/core.py'), 59 | 'minai.core.Hook.__del__': ('core.html#hook.__del__', 'minai/core.py'), 60 | 'minai.core.Hook.__init__': ('core.html#hook.__init__', 'minai/core.py'), 61 | 'minai.core.Hook.remove': ('core.html#hook.remove', 'minai/core.py'), 62 | 'minai.core.Hooks': ('core.html#hooks', 'minai/core.py'), 63 | 'minai.core.Hooks.__del__': ('core.html#hooks.__del__', 'minai/core.py'), 64 | 'minai.core.Hooks.__delitem__': ('core.html#hooks.__delitem__', 'minai/core.py'), 65 | 'minai.core.Hooks.__enter__': ('core.html#hooks.__enter__', 'minai/core.py'), 66 | 'minai.core.Hooks.__exit__': ('core.html#hooks.__exit__', 'minai/core.py'), 67 | 'minai.core.Hooks.__init__': ('core.html#hooks.__init__', 'minai/core.py'), 68 | 'minai.core.Hooks.remove': ('core.html#hooks.remove', 'minai/core.py'), 69 | 'minai.core.HooksCallback': ('core.html#hookscallback', 'minai/core.py'), 70 | 'minai.core.HooksCallback.__init__': ('core.html#hookscallback.__init__', 'minai/core.py'), 71 | 'minai.core.HooksCallback.__iter__': ('core.html#hookscallback.__iter__', 'minai/core.py'), 72 | 'minai.core.HooksCallback.__len__': ('core.html#hookscallback.__len__', 'minai/core.py'), 73 | 'minai.core.HooksCallback._hookfunc': ('core.html#hookscallback._hookfunc', 'minai/core.py'), 74 | 'minai.core.HooksCallback.after_fit': ('core.html#hookscallback.after_fit', 'minai/core.py'), 75 | 'minai.core.HooksCallback.before_fit': ('core.html#hookscallback.before_fit', 'minai/core.py'), 76 | 'minai.core.LRFinderCB': ('core.html#lrfindercb', 'minai/core.py'), 77 | 'minai.core.LRFinderCB.__init__': ('core.html#lrfindercb.__init__', 'minai/core.py'), 78 | 'minai.core.LRFinderCB.after_batch': ('core.html#lrfindercb.after_batch', 'minai/core.py'), 79 | 'minai.core.LRFinderCB.before_fit': ('core.html#lrfindercb.before_fit', 'minai/core.py'), 80 | 'minai.core.LRFinderCB.cleanup_fit': ('core.html#lrfindercb.cleanup_fit', 'minai/core.py'), 81 | 'minai.core.Learner': ('core.html#learner', 'minai/core.py'), 82 | 'minai.core.Learner.__getattr__': ('core.html#learner.__getattr__', 'minai/core.py'), 83 | 'minai.core.Learner.__init__': ('core.html#learner.__init__', 'minai/core.py'), 84 | 'minai.core.Learner._fit': ('core.html#learner._fit', 'minai/core.py'), 85 | 'minai.core.Learner._one_batch': ('core.html#learner._one_batch', 'minai/core.py'), 86 | 'minai.core.Learner._one_epoch': ('core.html#learner._one_epoch', 'minai/core.py'), 87 | 'minai.core.Learner.callback': ('core.html#learner.callback', 'minai/core.py'), 88 | 'minai.core.Learner.fit': ('core.html#learner.fit', 'minai/core.py'), 89 | 'minai.core.Learner.one_epoch': ('core.html#learner.one_epoch', 'minai/core.py'), 90 | 'minai.core.Learner.training': ('core.html#learner.training', 'minai/core.py'), 91 | 'minai.core.MetricsCB': ('core.html#metricscb', 'minai/core.py'), 92 | 'minai.core.MetricsCB.__init__': ('core.html#metricscb.__init__', 'minai/core.py'), 93 | 'minai.core.MetricsCB._log': ('core.html#metricscb._log', 'minai/core.py'), 94 | 'minai.core.MetricsCB.after_batch': ('core.html#metricscb.after_batch', 'minai/core.py'), 95 | 'minai.core.MetricsCB.after_epoch': ('core.html#metricscb.after_epoch', 'minai/core.py'), 96 | 'minai.core.MetricsCB.before_epoch': ('core.html#metricscb.before_epoch', 'minai/core.py'), 97 | 'minai.core.MetricsCB.before_fit': ('core.html#metricscb.before_fit', 'minai/core.py'), 98 | 'minai.core.MixedPrecision': ('core.html#mixedprecision', 'minai/core.py'), 99 | 'minai.core.MixedPrecision.__init__': ('core.html#mixedprecision.__init__', 'minai/core.py'), 100 | 'minai.core.MixedPrecision.after_loss': ('core.html#mixedprecision.after_loss', 'minai/core.py'), 101 | 'minai.core.MixedPrecision.backward': ('core.html#mixedprecision.backward', 'minai/core.py'), 102 | 'minai.core.MixedPrecision.before_batch': ('core.html#mixedprecision.before_batch', 'minai/core.py'), 103 | 'minai.core.MixedPrecision.before_fit': ('core.html#mixedprecision.before_fit', 'minai/core.py'), 104 | 'minai.core.MixedPrecision.step': ('core.html#mixedprecision.step', 'minai/core.py'), 105 | 'minai.core.MomentumLearner': ('core.html#momentumlearner', 'minai/core.py'), 106 | 'minai.core.MomentumLearner.__init__': ('core.html#momentumlearner.__init__', 'minai/core.py'), 107 | 'minai.core.MomentumLearner.zero_grad': ('core.html#momentumlearner.zero_grad', 'minai/core.py'), 108 | 'minai.core.ProgressCB': ('core.html#progresscb', 'minai/core.py'), 109 | 'minai.core.ProgressCB.__init__': ('core.html#progresscb.__init__', 'minai/core.py'), 110 | 'minai.core.ProgressCB._log': ('core.html#progresscb._log', 'minai/core.py'), 111 | 'minai.core.ProgressCB.after_batch': ('core.html#progresscb.after_batch', 'minai/core.py'), 112 | 'minai.core.ProgressCB.before_epoch': ('core.html#progresscb.before_epoch', 'minai/core.py'), 113 | 'minai.core.ProgressCB.before_fit': ('core.html#progresscb.before_fit', 'minai/core.py'), 114 | 'minai.core.RandCopy': ('core.html#randcopy', 'minai/core.py'), 115 | 'minai.core.RandCopy.__init__': ('core.html#randcopy.__init__', 'minai/core.py'), 116 | 'minai.core.RandCopy.forward': ('core.html#randcopy.forward', 'minai/core.py'), 117 | 'minai.core.RandErase': ('core.html#randerase', 'minai/core.py'), 118 | 'minai.core.RandErase.__init__': ('core.html#randerase.__init__', 'minai/core.py'), 119 | 'minai.core.RandErase.forward': ('core.html#randerase.forward', 'minai/core.py'), 120 | 'minai.core.RecorderCB': ('core.html#recordercb', 'minai/core.py'), 121 | 'minai.core.RecorderCB.__init__': ('core.html#recordercb.__init__', 'minai/core.py'), 122 | 'minai.core.RecorderCB.after_batch': ('core.html#recordercb.after_batch', 'minai/core.py'), 123 | 'minai.core.RecorderCB.before_fit': ('core.html#recordercb.before_fit', 'minai/core.py'), 124 | 'minai.core.RecorderCB.plot': ('core.html#recordercb.plot', 'minai/core.py'), 125 | 'minai.core.SingleBatchCB': ('core.html#singlebatchcb', 'minai/core.py'), 126 | 'minai.core.SingleBatchCB.after_batch': ('core.html#singlebatchcb.after_batch', 'minai/core.py'), 127 | 'minai.core.TfmDataset': ('core.html#tfmdataset', 'minai/core.py'), 128 | 'minai.core.TfmDataset.__getitem__': ('core.html#tfmdataset.__getitem__', 'minai/core.py'), 129 | 'minai.core.TfmDataset.__init__': ('core.html#tfmdataset.__init__', 'minai/core.py'), 130 | 'minai.core.TrainCB': ('core.html#traincb', 'minai/core.py'), 131 | 'minai.core.TrainCB.__init__': ('core.html#traincb.__init__', 'minai/core.py'), 132 | 'minai.core.TrainCB.backward': ('core.html#traincb.backward', 'minai/core.py'), 133 | 'minai.core.TrainCB.get_loss': ('core.html#traincb.get_loss', 'minai/core.py'), 134 | 'minai.core.TrainCB.predict': ('core.html#traincb.predict', 'minai/core.py'), 135 | 'minai.core.TrainCB.step': ('core.html#traincb.step', 'minai/core.py'), 136 | 'minai.core.TrainCB.zero_grad': ('core.html#traincb.zero_grad', 'minai/core.py'), 137 | 'minai.core.TrainLearner': ('core.html#trainlearner', 'minai/core.py'), 138 | 'minai.core.TrainLearner.__init__': ('core.html#trainlearner.__init__', 'minai/core.py'), 139 | 'minai.core.TrainLearner.backward': ('core.html#trainlearner.backward', 'minai/core.py'), 140 | 'minai.core.TrainLearner.get_loss': ('core.html#trainlearner.get_loss', 'minai/core.py'), 141 | 'minai.core.TrainLearner.predict': ('core.html#trainlearner.predict', 'minai/core.py'), 142 | 'minai.core.TrainLearner.step': ('core.html#trainlearner.step', 'minai/core.py'), 143 | 'minai.core.TrainLearner.zero_grad': ('core.html#trainlearner.zero_grad', 'minai/core.py'), 144 | 'minai.core._flops': ('core.html#_flops', 'minai/core.py'), 145 | 'minai.core._get_inp': ('core.html#_get_inp', 'minai/core.py'), 146 | 'minai.core._get_lbl': ('core.html#_get_lbl', 'minai/core.py'), 147 | 'minai.core._get_preds': ('core.html#_get_preds', 'minai/core.py'), 148 | 'minai.core._rand_copy1': ('core.html#_rand_copy1', 'minai/core.py'), 149 | 'minai.core._rand_erase1': ('core.html#_rand_erase1', 'minai/core.py'), 150 | 'minai.core.append_stats': ('core.html#append_stats', 'minai/core.py'), 151 | 'minai.core.capture_preds': ('core.html#capture_preds', 'minai/core.py'), 152 | 'minai.core.clean_ipython_hist': ('core.html#clean_ipython_hist', 'minai/core.py'), 153 | 'minai.core.clean_mem': ('core.html#clean_mem', 'minai/core.py'), 154 | 'minai.core.clean_tb': ('core.html#clean_tb', 'minai/core.py'), 155 | 'minai.core.collate_device': ('core.html#collate_device', 'minai/core.py'), 156 | 'minai.core.collate_dict': ('core.html#collate_dict', 'minai/core.py'), 157 | 'minai.core.get_dls': ('core.html#get_dls', 'minai/core.py'), 158 | 'minai.core.get_grid': ('core.html#get_grid', 'minai/core.py'), 159 | 'minai.core.get_hist': ('core.html#get_hist', 'minai/core.py'), 160 | 'minai.core.get_min': ('core.html#get_min', 'minai/core.py'), 161 | 'minai.core.lr_find': ('core.html#lr_find', 'minai/core.py'), 162 | 'minai.core.rand_copy': ('core.html#rand_copy', 'minai/core.py'), 163 | 'minai.core.rand_erase': ('core.html#rand_erase', 'minai/core.py'), 164 | 'minai.core.run_cbs': ('core.html#run_cbs', 'minai/core.py'), 165 | 'minai.core.set_seed': ('core.html#set_seed', 'minai/core.py'), 166 | 'minai.core.show_image': ('core.html#show_image', 'minai/core.py'), 167 | 'minai.core.show_image_batch': ('core.html#show_image_batch', 'minai/core.py'), 168 | 'minai.core.show_images': ('core.html#show_images', 'minai/core.py'), 169 | 'minai.core.subplots': ('core.html#subplots', 'minai/core.py'), 170 | 'minai.core.summary': ('core.html#summary', 'minai/core.py'), 171 | 'minai.core.to_cpu': ('core.html#to_cpu', 'minai/core.py'), 172 | 'minai.core.to_device': ('core.html#to_device', 'minai/core.py'), 173 | 'minai.core.with_cbs': ('core.html#with_cbs', 'minai/core.py'), 174 | 'minai.core.with_cbs.__call__': ('core.html#with_cbs.__call__', 'minai/core.py'), 175 | 'minai.core.with_cbs.__init__': ('core.html#with_cbs.__init__', 'minai/core.py')}, 176 | 'minai.version': {}}} 177 | -------------------------------------------------------------------------------- /minai/core.py: -------------------------------------------------------------------------------- 1 | # AUTOGENERATED! DO NOT EDIT! File to edit: ../core.ipynb. 2 | 3 | # %% auto 0 4 | __all__ = ['def_device', 'set_seed', 'Dataset', 'TfmDataset', 'get_dls', 'collate_dict', 'DataLoaders', 'show_image', 'subplots', 5 | 'get_grid', 'show_images', 'to_device', 'to_cpu', 'collate_device', 'CancelFitException', 6 | 'CancelBatchException', 'CancelEpochException', 'Callback', 'run_cbs', 'with_cbs', 'CycleDL', 'Learner', 7 | 'TrainLearner', 'TrainCB', 'MomentumLearner', 'DeviceCB', 'SingleBatchCB', 'MetricsCB', 'ProgressCB', 8 | 'CapturePreds', 'capture_preds', 'show_image_batch', 'LRFinderCB', 'lr_find', 'RecorderCB', 'BaseSchedCB', 9 | 'BatchSchedCB', 'EpochSchedCB', 'HasLearnCB', 'MixedPrecision', 'AccelerateCB', 'append_stats', 'get_min', 10 | 'Hook', 'Hooks', 'HooksCallback', 'get_hist', 'ActivationStats', 'summary', 'BatchTransformCB', 11 | 'GeneralRelu', 'rand_erase', 'RandErase', 'rand_copy', 'RandCopy', 'clean_ipython_hist', 'clean_tb', 12 | 'clean_mem'] 13 | 14 | # %% ../core.ipynb 1 15 | import sys, gc, traceback, math, typing, random, numpy as np 16 | from collections.abc import Mapping 17 | from copy import copy 18 | from itertools import zip_longest 19 | from functools import partial, wraps 20 | from operator import attrgetter, itemgetter 21 | 22 | import matplotlib.pyplot as plt 23 | import fastcore.all as fc 24 | from fastprogress import progress_bar, master_bar 25 | 26 | import torch, torch.nn.functional as F 27 | from torch import nn, optim 28 | from torch.utils.data import DataLoader 29 | from torch.optim.lr_scheduler import ExponentialLR 30 | from torch.utils.data import default_collate 31 | 32 | from torcheval.metrics import Mean 33 | 34 | # %% ../core.ipynb 2 35 | try: from accelerate import Accelerator 36 | except: Accelerator=None 37 | 38 | # %% ../core.ipynb 6 39 | def set_seed(seed, deterministic=False): 40 | torch.use_deterministic_algorithms(deterministic) 41 | torch.manual_seed(seed) 42 | random.seed(seed) 43 | np.random.seed(seed) 44 | 45 | # %% ../core.ipynb 14 46 | class Dataset(): 47 | "Simple dataset that combines two collections" 48 | def __init__(self, x, y): self.x,self.y = x,y 49 | def __len__(self): return len(self.x) 50 | def __getitem__(self, i): return self.x[i],self.y[i] 51 | 52 | # %% ../core.ipynb 16 53 | class TfmDataset(Dataset): 54 | "Dataset subclass that transforms items" 55 | def __init__(self, x, y, tfm_x=None, tfm_y=None): 56 | super().__init__(x,y) 57 | self.tfm_x,self.tfm_y = tfm_x,tfm_y 58 | 59 | def __getitem__(self, i): 60 | x,y = self.x[i],self.y[i] 61 | return self.tfm_x(x) if self.tfm_x else x, self.tfm_y(y) if self.tfm_y else y 62 | 63 | # %% ../core.ipynb 18 64 | def get_dls(train_ds, valid_ds, bs, **kwargs): 65 | "Convert train and validation datasets to data loaders" 66 | return (DataLoader(train_ds, batch_size=bs, shuffle=True, **kwargs), 67 | DataLoader(valid_ds, batch_size=bs*2, **kwargs)) 68 | 69 | # %% ../core.ipynb 22 70 | def collate_dict(ds): 71 | get = itemgetter(*ds.features) 72 | def _f(b): return get(default_collate(b)) 73 | return _f 74 | 75 | # %% ../core.ipynb 26 76 | class DataLoaders: 77 | "Convert a `DatasetDict` into a pair of `DataLoader`s" 78 | def __init__(self, *dls): self.train,self.valid = dls[:2] 79 | 80 | @classmethod 81 | def from_dd(cls, dd, batch_size, as_tuple=True, **kwargs): 82 | f = collate_dict(dd['train']) 83 | return cls(*get_dls(*dd.values(), bs=batch_size, collate_fn=f)) 84 | 85 | # %% ../core.ipynb 32 86 | @fc.delegates(plt.Axes.imshow) 87 | def show_image(im, ax=None, figsize=None, title=None, noframe=True, **kwargs): 88 | "Show a PIL or PyTorch image on `ax`." 89 | if fc.hasattrs(im, ('cpu','permute','detach')): 90 | im = im.detach().cpu() 91 | if len(im.shape)==3 and im.shape[0]<5: im=im.permute(1,2,0) 92 | elif not isinstance(im,np.ndarray): im=np.array(im) 93 | if im.shape[-1]==1: im=im[...,0] 94 | if ax is None: _,ax = plt.subplots(figsize=figsize) 95 | ax.imshow(im, **kwargs) 96 | if title is not None: ax.set_title(title) 97 | ax.set_xticks([]) 98 | ax.set_yticks([]) 99 | if noframe: ax.axis('off') 100 | return ax 101 | 102 | # %% ../core.ipynb 33 103 | @fc.delegates(plt.subplots, keep=True) 104 | def subplots( 105 | nrows:int=1, # Number of rows in returned axes grid 106 | ncols:int=1, # Number of columns in returned axes grid 107 | figsize:tuple=None, # Width, height in inches of the returned figure 108 | imsize:int=3, # Size (in inches) of images that will be displayed in the returned figure 109 | suptitle:str=None, # Title to be set to returned figure 110 | **kwargs 111 | ): # fig and axs 112 | "A figure and set of subplots to display images of `imsize` inches" 113 | if figsize is None: figsize=(ncols*imsize, nrows*imsize) 114 | fig,ax = plt.subplots(nrows, ncols, figsize=figsize, **kwargs) 115 | if suptitle is not None: fig.suptitle(suptitle) 116 | if nrows*ncols==1: ax = np.array([ax]) 117 | return fig,ax 118 | 119 | # %% ../core.ipynb 34 120 | @fc.delegates(subplots) 121 | def get_grid( 122 | n:int, # Number of axes 123 | nrows:int=None, # Number of rows, defaulting to `int(math.sqrt(n))` 124 | ncols:int=None, # Number of columns, defaulting to `ceil(n/rows)` 125 | title:str=None, # If passed, title set to the figure 126 | weight:str='bold', # Title font weight 127 | size:int=14, # Title font size 128 | **kwargs, 129 | ): # fig and axs 130 | "Return a grid of `n` axes, `rows` by `cols`" 131 | if nrows: ncols = ncols or int(np.floor(n/nrows)) 132 | elif ncols: nrows = nrows or int(np.ceil(n/ncols)) 133 | else: 134 | nrows = int(math.sqrt(n)) 135 | ncols = int(np.floor(n/nrows)) 136 | fig,axs = subplots(nrows, ncols, **kwargs) 137 | for i in range(n, nrows*ncols): axs.flat[i].set_axis_off() 138 | if title is not None: fig.suptitle(title, weight=weight, size=size) 139 | return fig,axs 140 | 141 | # %% ../core.ipynb 35 142 | @fc.delegates(subplots) 143 | def show_images(ims:list, # Images to show 144 | nrows:typing.Union[int, None]=None, # Number of rows in grid 145 | ncols:typing.Union[int, None]=None, # Number of columns in grid (auto-calculated if None) 146 | titles:typing.Union[list, None]=None, # Optional list of titles for each image 147 | **kwargs): 148 | "Show all images `ims` as subplots with `rows` using `titles`" 149 | axs = get_grid(len(ims), nrows, ncols, **kwargs)[1].flat 150 | for im,t,ax in zip_longest(ims, [] if titles is None else titles, axs): show_image(im, ax=ax, title=t) 151 | 152 | # %% ../core.ipynb 41 153 | def_device = 'mps' if torch.backends.mps.is_available() else 'cuda' if torch.cuda.is_available() else 'cpu' 154 | 155 | # %% ../core.ipynb 42 156 | def to_device(x, device=def_device): 157 | if isinstance(x, torch.Tensor): return x.to(device) 158 | if isinstance(x, Mapping): return {k:v.to(device) for k,v in x.items()} 159 | return type(x)(to_device(o, device) for o in x) 160 | 161 | # %% ../core.ipynb 43 162 | def to_cpu(x): 163 | if isinstance(x, Mapping): return {k:to_cpu(v) for k,v in x.items()} 164 | if isinstance(x, list): return [to_cpu(o) for o in x] 165 | if isinstance(x, tuple): return tuple(to_cpu(list(x))) 166 | return x.detach().cpu() 167 | 168 | # %% ../core.ipynb 44 169 | def collate_device(b): return to_device(default_collate(b)) 170 | 171 | # %% ../core.ipynb 49 172 | class CancelFitException(Exception): pass 173 | class CancelBatchException(Exception): pass 174 | class CancelEpochException(Exception): pass 175 | 176 | # %% ../core.ipynb 50 177 | class Callback(): order = 0 178 | 179 | # %% ../core.ipynb 51 180 | def run_cbs(cbs, method_nm, learn=None): 181 | for cb in sorted(cbs, key=attrgetter('order')): 182 | method = getattr(cb, method_nm, None) 183 | if method is not None: method(learn) 184 | 185 | # %% ../core.ipynb 52 186 | class with_cbs: 187 | def __init__(self, nm): self.nm = nm 188 | def __call__(self, f): 189 | def _f(o, *args, **kwargs): 190 | try: 191 | o.callback(f'before_{self.nm}') 192 | f(o, *args, **kwargs) 193 | o.callback(f'after_{self.nm}') 194 | except globals()[f'Cancel{self.nm.title()}Exception']: pass 195 | finally: o.callback(f'cleanup_{self.nm}') 196 | return _f 197 | 198 | # %% ../core.ipynb 54 199 | from itertools import cycle 200 | 201 | # %% ../core.ipynb 55 202 | class CycleDL(): 203 | def __init__(self, items, sz=None): 204 | self.items = items 205 | self.sz = len(items) if sz is None else sz 206 | self.it = None 207 | 208 | def __len__(self): return len(self.items) if self.sz is None else self.sz 209 | def __iter__(self): 210 | if self.it is None: self.it = cycle(iter(self.items)) 211 | for i in range(self.sz): yield next(self.it) 212 | 213 | # %% ../core.ipynb 57 214 | class Learner(): 215 | def __init__(self, model, dls=(0,), loss_func=F.mse_loss, lr=0.1, cbs=None, opt_func=optim.SGD, epoch_sz=None): 216 | cbs = fc.L(cbs) 217 | fc.store_attr() 218 | 219 | @with_cbs('batch') 220 | def _one_batch(self): 221 | self.predict() 222 | self.callback('after_predict') 223 | self.get_loss() 224 | self.callback('after_loss') 225 | if self.training: 226 | self.backward() 227 | self.callback('after_backward') 228 | self.step() 229 | self.callback('after_step') 230 | self.zero_grad() 231 | 232 | @with_cbs('epoch') 233 | def _one_epoch(self): 234 | for self.iter,self.batch in enumerate(self.dl): self._one_batch() 235 | 236 | def one_epoch(self, training): 237 | self.model.train(training) 238 | self.dl = self.train_dl if training else self.dls.valid 239 | self._one_epoch() 240 | 241 | @with_cbs('fit') 242 | def _fit(self, train, valid): 243 | self.train_dl = self.dls.train 244 | if self.epoch_sz is not None: self.train_dl = CycleDL(self.train_dl, self.epoch_sz) 245 | for self.epoch in self.epochs: 246 | if train: self.one_epoch(True) 247 | if valid: 248 | with torch.inference_mode(): self.one_epoch(False) 249 | 250 | def fit(self, n_epochs=1, train=True, valid=True, cbs=None, lr=None): 251 | cbs = fc.L(cbs) 252 | self.cbs += cbs 253 | try: 254 | self.n_epochs = n_epochs 255 | self.epochs = range(n_epochs) 256 | if lr is None: lr = self.lr 257 | if self.opt_func: self.opt = self.opt_func(self.model.parameters(), lr) 258 | self._fit(train, valid) 259 | finally: 260 | for cb in cbs: self.cbs.remove(cb) 261 | 262 | def __getattr__(self, name): 263 | if name in ('predict','get_loss','backward','step','zero_grad'): return partial(self.callback, name) 264 | raise AttributeError(name) 265 | 266 | def callback(self, method_nm): run_cbs(self.cbs, method_nm, self) 267 | 268 | @property 269 | def training(self): return self.model.training 270 | 271 | # %% ../core.ipynb 58 272 | def _get_inp(b, n_inp, inp_nm): 273 | if inp_nm is not None: return [b[inp_nm]] 274 | return b[:n_inp] 275 | 276 | def _get_lbl(b, n_inp, lbl_nm): 277 | if lbl_nm is not None: return [b[lbl_nm]] 278 | return b[n_inp:] 279 | 280 | def _get_preds(b, preds_nm): 281 | return b if preds_nm is None else getattr(b, preds_nm) 282 | 283 | # %% ../core.ipynb 59 284 | class TrainLearner(Learner): 285 | def __init__(self, model, dls, loss_func, lr=None, cbs=None, opt_func=torch.optim.SGD, epoch_sz=None, 286 | n_inp=1, inp_nm=None, lbl_nm=None, preds_nm=None): 287 | super().__init__(model, dls, loss_func, lr, cbs, opt_func=opt_func, epoch_sz=epoch_sz) 288 | self.n_inp,self.inp_nm,self.lbl_nm,self.preds_nm = n_inp,inp_nm,lbl_nm,preds_nm 289 | 290 | def predict(self): 291 | inps = _get_inp(self.batch, self.n_inp, self.inp_nm) 292 | self.preds = self.model(*inps) 293 | 294 | def get_loss(self): 295 | lbls = _get_lbl(self.batch, self.n_inp, self.lbl_nm) 296 | preds = _get_preds(self.preds, self.preds_nm) 297 | self.loss = self.loss_func(preds, *lbls) 298 | 299 | def backward(self): self.loss.backward() 300 | def step(self): self.opt.step() 301 | def zero_grad(self): self.opt.zero_grad() 302 | 303 | # %% ../core.ipynb 60 304 | class TrainCB(Callback): 305 | def __init__(self, n_inp=1, inp_nm=None, lbl_nm=None, preds_nm=None): 306 | self.n_inp = n_inp 307 | self.n_inp,self.inp_nm,self.lbl_nm,self.preds_nm = n_inp,inp_nm,lbl_nm,preds_nm 308 | 309 | def predict(self, learn): 310 | inps = _get_inp(learn.batch, self.n_inp, self.inp_nm) 311 | learn.preds = learn.model(*inps) 312 | 313 | def get_loss(self, learn): 314 | lbls = _get_lbl(learn.batch, self.n_inp, self.lbl_nm) 315 | preds = _get_preds(learn.preds, self.preds_nm) 316 | learn.loss = learn.loss_func(preds, *lbls) 317 | 318 | def backward(self, learn): learn.loss.backward() 319 | def step(self, learn): learn.opt.step() 320 | def zero_grad(self, learn): learn.opt.zero_grad() 321 | 322 | # %% ../core.ipynb 61 323 | class MomentumLearner(TrainLearner): 324 | def __init__(self, model, dls, loss_func, lr=None, cbs=None, opt_func=torch.optim.SGD, epoch_sz=None, 325 | n_inp=1, inp_nm=None, lbl_nm=None, preds_nm=None, mom=0.85): 326 | self.mom = mom 327 | super().__init__(model, dls, loss_func, lr, cbs, opt_func=opt_func, epoch_sz=epoch_sz, n_inp=n_inp, 328 | inp_nm=inp_nm, lbl_nm=lbl_nm, preds_nm=preds_nm) 329 | 330 | def zero_grad(self): 331 | with torch.no_grad(): 332 | for p in self.model.parameters(): 333 | if p.grad is not None: 334 | p.grad.detach_() 335 | p.grad *= self.mom 336 | 337 | # %% ../core.ipynb 64 338 | class DeviceCB(Callback): 339 | def __init__(self, device=def_device): fc.store_attr() 340 | def before_fit(self, learn): 341 | if hasattr(learn.model, 'to'): learn.model.to(self.device) 342 | def before_batch(self, learn): learn.batch = to_device(learn.batch, device=self.device) 343 | 344 | # %% ../core.ipynb 66 345 | class SingleBatchCB(Callback): 346 | order = 1 347 | def after_batch(self, learn): raise CancelFitException() 348 | 349 | # %% ../core.ipynb 68 350 | class MetricsCB(Callback): 351 | def __init__(self, *ms, **metrics): 352 | for o in ms: metrics[type(o).__name__] = o 353 | self.metrics = metrics 354 | self.all_metrics = copy(metrics) 355 | self.all_metrics['loss'] = self.loss = Mean() 356 | 357 | def _log(self, d): print(d) 358 | def before_fit(self, learn): learn.metrics = self 359 | def before_epoch(self, learn): [o.reset() for o in self.all_metrics.values()] 360 | 361 | def after_epoch(self, learn): 362 | log = {k:f'{v.compute():.3f}' for k,v in self.all_metrics.items()} 363 | log['epoch'] = learn.epoch 364 | log['train'] = 'train' if learn.model.training else 'eval' 365 | self._log(log) 366 | 367 | def after_batch(self, learn): 368 | x,y,*_ = to_cpu(learn.batch) 369 | for m in self.metrics.values(): m.update(to_cpu(learn.preds), y) 370 | self.loss.update(to_cpu(learn.loss), weight=len(x)) 371 | 372 | # %% ../core.ipynb 70 373 | class ProgressCB(Callback): 374 | order = MetricsCB.order+1 375 | def __init__(self, plot=False): self.plot = plot 376 | def before_fit(self, learn): 377 | learn.epochs = self.mbar = master_bar(learn.epochs) 378 | self.first = True 379 | if hasattr(learn, 'metrics'): learn.metrics._log = self._log 380 | self.losses = [] 381 | 382 | def _log(self, d): 383 | if self.first: 384 | self.mbar.write(list(d), table=True) 385 | self.first = False 386 | self.mbar.write(list(d.values()), table=True) 387 | 388 | def before_epoch(self, learn): learn.dl = progress_bar(learn.dl, leave=False, parent=self.mbar) 389 | def after_batch(self, learn): 390 | learn.dl.comment = f'{learn.loss:.3f}' 391 | if self.plot and hasattr(learn, 'metrics') and learn.training: 392 | self.losses.append(learn.loss.item()) 393 | self.mbar.update_graph([[fc.L.range(self.losses), self.losses]]) 394 | 395 | # %% ../core.ipynb 77 396 | class CapturePreds(Callback): 397 | def before_fit(self, learn): self.all_inps,self.all_preds,self.all_targs = [],[],[] 398 | def after_batch(self, learn): 399 | self.all_inps. append(to_cpu(learn.batch[0])) 400 | self.all_preds.append(to_cpu(learn.preds)) 401 | self.all_targs.append(to_cpu(learn.batch[1])) 402 | def after_fit(self, learn): 403 | self.all_preds,self.all_targs,self.all_inps = map(torch.cat, [self.all_preds,self.all_targs,self.all_inps]) 404 | 405 | # %% ../core.ipynb 78 406 | @fc.patch 407 | def capture_preds(self: Learner, cbs=None, inps=False): 408 | cp = CapturePreds() 409 | with torch.inference_mode(): self.fit(1, train=False, cbs=[cp]+fc.L(cbs)) 410 | res = cp.all_preds,cp.all_targs 411 | if inps: res = res+(cp.all_inps,) 412 | return res 413 | 414 | # %% ../core.ipynb 83 415 | @fc.patch 416 | @fc.delegates(show_images) 417 | def show_image_batch(self:Learner, max_n=9, cbs=None, **kwargs): 418 | self.fit(1, cbs=[SingleBatchCB()]+fc.L(cbs)) 419 | xb,yb = self.batch 420 | feat = fc.nested_attr(self.dls, 'train.dataset.features') 421 | if feat is None: titles = np.array(yb) 422 | else: 423 | names = feat['label'].names 424 | titles = [names[i] for i in yb] 425 | show_images(xb[:max_n], titles=titles[:max_n], **kwargs) 426 | 427 | # %% ../core.ipynb 87 428 | class LRFinderCB(Callback): 429 | def __init__(self, gamma=1.3, max_mult=3): fc.store_attr() 430 | 431 | def before_fit(self, learn): 432 | self.sched = ExponentialLR(learn.opt, self.gamma) 433 | self.lrs,self.losses = [],[] 434 | self.min = math.inf 435 | 436 | def after_batch(self, learn): 437 | if not learn.training: raise CancelEpochException() 438 | self.lrs.append(learn.opt.param_groups[0]['lr']) 439 | loss = to_cpu(learn.loss) 440 | self.losses.append(loss) 441 | if loss < self.min: self.min = loss 442 | if loss > self.min*self.max_mult: 443 | raise CancelFitException() 444 | self.sched.step() 445 | 446 | def cleanup_fit(self, learn): 447 | plt.plot(self.lrs, self.losses) 448 | plt.xscale('log') 449 | 450 | # %% ../core.ipynb 88 451 | @fc.patch 452 | def lr_find(self:Learner, gamma=1.3, max_mult=3, start_lr=1e-5, max_epochs=10): 453 | self.fit(max_epochs, lr=start_lr, cbs=LRFinderCB(gamma=gamma, max_mult=max_mult)) 454 | 455 | # %% ../core.ipynb 91 456 | class RecorderCB(Callback): 457 | def __init__(self, **d): self.d = d 458 | def before_fit(self, learn): 459 | self.recs = {k:[] for k in self.d} 460 | self.pg = learn.opt.param_groups[0] 461 | 462 | def after_batch(self, learn): 463 | if not learn.training: return 464 | for k,v in self.d.items(): 465 | self.recs[k].append(v(self)) 466 | 467 | def plot(self): 468 | for k,v in self.recs.items(): 469 | plt.plot(v, label=k) 470 | plt.legend() 471 | plt.show() 472 | 473 | # %% ../core.ipynb 92 474 | class BaseSchedCB(Callback): 475 | def __init__(self, sched): self.sched = sched 476 | def before_fit(self, learn): self.schedo = self.sched(learn.opt) 477 | def _step(self, learn): 478 | if learn.training: self.schedo.step() 479 | 480 | # %% ../core.ipynb 93 481 | class BatchSchedCB(BaseSchedCB): 482 | def after_batch(self, learn): self._step(learn) 483 | 484 | # %% ../core.ipynb 94 485 | class EpochSchedCB(BaseSchedCB): 486 | def after_epoch(self, learn): self._step(learn) 487 | 488 | # %% ../core.ipynb 95 489 | class HasLearnCB(Callback): 490 | def before_fit(self, learn): self.learn = learn 491 | def after_fit(self, learn): self.learn = None 492 | 493 | # %% ../core.ipynb 97 494 | class MixedPrecision(TrainCB): 495 | order = DeviceCB.order+10 496 | def __init__(self, n_inp=1, dtype=torch.bfloat16): 497 | super().__init__(n_inp=n_inp) 498 | self.dtype=dtype 499 | 500 | def before_fit(self, learn): self.scaler = torch.cuda.amp.GradScaler() 501 | 502 | def before_batch(self, learn): 503 | self.autocast = torch.autocast("cuda", dtype=self.dtype) 504 | self.autocast.__enter__() 505 | 506 | def after_loss(self, learn): self.autocast.__exit__(None, None, None) 507 | 508 | def backward(self, learn): self.scaler.scale(learn.loss).backward() 509 | 510 | def step(self, learn): 511 | self.scaler.step(learn.opt) 512 | self.scaler.update() 513 | 514 | # %% ../core.ipynb 99 515 | class AccelerateCB(TrainCB): 516 | order = DeviceCB.order+10 517 | def __init__(self, n_inp=1, mixed_precision="fp16"): 518 | super().__init__(n_inp=n_inp) 519 | self.acc = Accelerator(mixed_precision=mixed_precision) 520 | 521 | def before_fit(self, learn): 522 | learn.model,learn.opt,learn.dls.train,learn.dls.valid = self.acc.prepare( 523 | learn.model, learn.opt, learn.dls.train, learn.dls.valid) 524 | 525 | def after_fit(self, learn): learn.model = self.acc.unwrap_model(learn.model) 526 | def backward(self, learn): self.acc.backward(learn.loss) 527 | 528 | # %% ../core.ipynb 101 529 | def append_stats(hook, mod, inp, outp): 530 | if not hasattr(hook,'stats'): hook.stats = ([],[],[]) 531 | acts = to_cpu(outp).float() 532 | hook.stats[0].append(acts.mean()) 533 | hook.stats[1].append(acts.std()) 534 | hook.stats[2].append(acts.abs().histc(40,0,10)) 535 | 536 | # %% ../core.ipynb 102 537 | def get_min(h): 538 | h1 = torch.stack(h.stats[2]).t().float() 539 | return h1[0]/h1.sum(0) 540 | 541 | # %% ../core.ipynb 103 542 | class Hook(): 543 | def __init__(self, m, f): self.hook = m.register_forward_hook(partial(f, self)) 544 | def remove(self): self.hook.remove() 545 | def __del__(self): self.remove() 546 | 547 | # %% ../core.ipynb 104 548 | class Hooks(list): 549 | def __init__(self, ms, f): super().__init__([Hook(m, f) for m in ms]) 550 | def __enter__(self, *args): return self 551 | def __exit__ (self, *args): self.remove() 552 | def __del__(self): self.remove() 553 | def __delitem__(self, i): 554 | self[i].remove() 555 | super().__delitem__(i) 556 | def remove(self): 557 | for h in self: h.remove() 558 | 559 | # %% ../core.ipynb 105 560 | class HooksCallback(Callback): 561 | def __init__(self, hookfunc, mod_filter=fc.noop, on_train=True, on_valid=False, mods=None): 562 | fc.store_attr() 563 | super().__init__() 564 | 565 | def before_fit(self, learn): 566 | if self.mods: mods=self.mods 567 | else: mods = fc.filter_ex(learn.model.modules(), self.mod_filter) 568 | self.hooks = Hooks(mods, partial(self._hookfunc, learn)) 569 | 570 | def _hookfunc(self, learn, *args, **kwargs): 571 | if (self.on_train and learn.training) or (self.on_valid and not learn.training): self.hookfunc(*args, **kwargs) 572 | 573 | def after_fit(self, learn): self.hooks.remove() 574 | def __iter__(self): return iter(self.hooks) 575 | def __len__(self): return len(self.hooks) 576 | 577 | # %% ../core.ipynb 106 578 | # Thanks to @ste for initial version of histgram plotting code 579 | def get_hist(h): return torch.stack(h.stats[2]).t().float().log1p() 580 | 581 | # %% ../core.ipynb 107 582 | class ActivationStats(HooksCallback): 583 | def __init__(self, mod_filter=fc.noop): super().__init__(append_stats, mod_filter) 584 | 585 | def color_dim(self, figsize=(11,5)): 586 | fig,axes = get_grid(len(self), figsize=figsize) 587 | for ax,h in zip(axes.flat, self): 588 | show_image(get_hist(h), ax, origin='lower') 589 | 590 | def dead_chart(self, figsize=(11,5)): 591 | fig,axes = get_grid(len(self), figsize=figsize) 592 | for ax,h in zip(axes.flatten(), self): 593 | ax.plot(get_min(h)) 594 | ax.set_ylim(0,1) 595 | 596 | def plot_stats(self, figsize=(10,4)): 597 | fig,axs = plt.subplots(1,2, figsize=figsize) 598 | for h in self: 599 | for i in 0,1: axs[i].plot(h.stats[i]) 600 | axs[0].set_title('Means') 601 | axs[1].set_title('Stdevs') 602 | plt.legend(fc.L.range(self)) 603 | 604 | # %% ../core.ipynb 114 605 | def _flops(x, h, w): 606 | if x.dim()<3: return x.numel() 607 | if x.dim()==4: return x.numel()*h*w 608 | 609 | # %% ../core.ipynb 115 610 | @fc.patch 611 | def summary(self:Learner): 612 | res = '|Module|Input|Output|Num params|MFLOPS|\n|--|--|--|--|--|\n' 613 | totp,totf = 0,0 614 | def _f(hook, mod, inp, outp): 615 | nonlocal res,totp,totf 616 | nparms = sum(o.numel() for o in mod.parameters()) 617 | totp += nparms 618 | *_,h,w = outp.shape 619 | flops = sum(_flops(o, h, w) for o in mod.parameters())/1e6 620 | totf += flops 621 | res += f'|{type(mod).__name__}|{tuple(inp[0].shape)}|{tuple(outp.shape)}|{nparms}|{flops:.1f}|\n' 622 | with Hooks(self.model, _f) as hooks: self.fit(1, lr=1, cbs=SingleBatchCB()) 623 | print(f"Tot params: {totp}; MFLOPS: {totf:.1f}") 624 | if fc.IN_NOTEBOOK: 625 | from IPython.display import Markdown 626 | return Markdown(res) 627 | else: print(res) 628 | 629 | # %% ../core.ipynb 117 630 | class BatchTransformCB(Callback): 631 | def __init__(self, tfm, on_train=True, on_val=True): fc.store_attr() 632 | 633 | def before_batch(self, learn): 634 | if (self.on_train and learn.training) or (self.on_val and not learn.training): 635 | learn.batch = self.tfm(learn.batch) 636 | 637 | # %% ../core.ipynb 119 638 | class GeneralRelu(nn.Module): 639 | def __init__(self, leak=None, sub=None, maxv=None): 640 | super().__init__() 641 | self.leak,self.sub,self.maxv = leak,sub,maxv 642 | 643 | def forward(self, x): 644 | x = F.leaky_relu(x,self.leak) if self.leak is not None else F.relu(x) 645 | if self.sub is not None: x -= self.sub 646 | if self.maxv is not None: x.clamp_max_(self.maxv) 647 | return x 648 | 649 | # %% ../core.ipynb 122 650 | def _rand_erase1(x, pct, xm, xs, mn, mx): 651 | szx = int(pct*x.shape[-2]) 652 | szy = int(pct*x.shape[-1]) 653 | stx = int(random.random()*(1-pct)*x.shape[-2]) 654 | sty = int(random.random()*(1-pct)*x.shape[-1]) 655 | nn.init.normal_(x[:,:,stx:stx+szx,sty:sty+szy], mean=xm, std=xs) 656 | x.clamp_(mn, mx) 657 | 658 | # %% ../core.ipynb 123 659 | def rand_erase(x, pct=0.2, min_num=0, max_num = 4): 660 | xm,xs,mn,mx = x.mean(),x.std(),x.min(),x.max() 661 | num = random.randint(min_num, max_num) 662 | for i in range(num): _rand_erase1(x, pct, xm, xs, mn, mx) 663 | return x 664 | 665 | # %% ../core.ipynb 125 666 | class RandErase(nn.Module): 667 | def __init__(self, pct=0.2, max_num=4): 668 | super().__init__() 669 | self.pct,self.max_num = pct,max_num 670 | def forward(self, x): return rand_erase(x, self.pct, self.max_num) 671 | 672 | # %% ../core.ipynb 126 673 | def _rand_copy1(x, pct): 674 | szx = int(pct*x.shape[-2]) 675 | szy = int(pct*x.shape[-1]) 676 | stx1 = int(random.random()*(1-pct)*x.shape[-2]) 677 | sty1 = int(random.random()*(1-pct)*x.shape[-1]) 678 | stx2 = int(random.random()*(1-pct)*x.shape[-2]) 679 | sty2 = int(random.random()*(1-pct)*x.shape[-1]) 680 | x[:,:,stx1:stx1+szx,sty1:sty1+szy] = x[:,:,stx2:stx2+szx,sty2:sty2+szy] 681 | 682 | # %% ../core.ipynb 127 683 | def rand_copy(x, pct=0.2, min_num=0, max_num=4): 684 | num = random.randint(min_num, max_num) 685 | for i in range(num): _rand_copy1(x, pct) 686 | return x 687 | 688 | # %% ../core.ipynb 129 689 | class RandCopy(nn.Module): 690 | def __init__(self, pct=0.2, max_num=4): 691 | super().__init__() 692 | self.pct,self.max_num = pct,max_num 693 | def forward(self, x): return rand_copy(x, self.pct, self.max_num) 694 | 695 | # %% ../core.ipynb 131 696 | def clean_ipython_hist(): 697 | # Code in this function mainly copied from IPython source 698 | if not 'get_ipython' in globals(): return 699 | ip = get_ipython() 700 | user_ns = ip.user_ns 701 | ip.displayhook.flush() 702 | pc = ip.displayhook.prompt_count + 1 703 | for n in range(1, pc): user_ns.pop('_i'+repr(n),None) 704 | user_ns.update(dict(_i='',_ii='',_iii='')) 705 | hm = ip.history_manager 706 | hm.input_hist_parsed[:] = [''] * pc 707 | hm.input_hist_raw[:] = [''] * pc 708 | hm._i = hm._ii = hm._iii = hm._i00 = '' 709 | 710 | # %% ../core.ipynb 132 711 | def clean_tb(): 712 | # h/t Piotr Czapla 713 | if hasattr(sys, 'last_traceback'): 714 | traceback.clear_frames(sys.last_traceback) 715 | delattr(sys, 'last_traceback') 716 | if hasattr(sys, 'last_type'): delattr(sys, 'last_type') 717 | if hasattr(sys, 'last_value'): delattr(sys, 'last_value') 718 | 719 | # %% ../core.ipynb 133 720 | def clean_mem(): 721 | clean_tb() 722 | clean_ipython_hist() 723 | gc.collect() 724 | torch.cuda.empty_cache() 725 | -------------------------------------------------------------------------------- /tutorial_01.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# Miniminiai Tutorial 1: Introduction\n", 8 | "\n", 9 | "This is a minimal example to get you started, showing the basic flow of training a model using (mini)miniai.\n", 10 | "\n" 11 | ] 12 | }, 13 | { 14 | "cell_type": "markdown", 15 | "metadata": {}, 16 | "source": [ 17 | "# Setup\n", 18 | "\n", 19 | "Installing the library and importing a few useful things" 20 | ] 21 | }, 22 | { 23 | "cell_type": "code", 24 | "execution_count": null, 25 | "metadata": {}, 26 | "outputs": [], 27 | "source": [ 28 | "# Imports\n", 29 | "import torch.nn as nn\n", 30 | "import minai as mi # So we can see what is from minai in this tutorial\n", 31 | "import torchvision.transforms.functional as TF\n", 32 | "from datasets import load_dataset\n", 33 | "from torcheval.metrics import MulticlassAccuracy" 34 | ] 35 | }, 36 | { 37 | "cell_type": "markdown", 38 | "metadata": {}, 39 | "source": [ 40 | "# Preparing the DataLoaders\n", 41 | "\n", 42 | "The dataloaders is just a tiny wrapper around two pytorch dataloaders, dls.train and dls.valid. You can create your dataloaders with `dls=DataLoaders(train_dl, valid_dl)` or use the `from_dd` method like we do here to load them from a DatasetDict (for datasets from huggingface with the datasets library):\n", 43 | "\n" 44 | ] 45 | }, 46 | { 47 | "cell_type": "code", 48 | "execution_count": null, 49 | "metadata": {}, 50 | "outputs": [ 51 | { 52 | "name": "stderr", 53 | "output_type": "stream", 54 | "text": [ 55 | "WARNING:datasets.builder:Found cached dataset mnist (/root/.cache/huggingface/datasets/mnist/mnist/1.0.0/fda16c03c4ecfb13f165ba7e29cf38129ce035011519968cdaf74894ce91c9d4)\n" 56 | ] 57 | }, 58 | { 59 | "data": { 60 | "application/vnd.jupyter.widget-view+json": { 61 | "model_id": "65be2a364ec8486a9bd46d20deed89a1", 62 | "version_major": 2, 63 | "version_minor": 0 64 | }, 65 | "text/plain": [ 66 | " 0%| | 0/2 [00:00" 111 | ] 112 | }, 113 | "metadata": {}, 114 | "output_type": "display_data" 115 | } 116 | ], 117 | "source": [ 118 | "# The library has some useful utility functions such as:\n", 119 | "mi.show_images(xb[:5], ncols=5, titles=list(yb[:5].numpy()))" 120 | ] 121 | }, 122 | { 123 | "cell_type": "markdown", 124 | "metadata": {}, 125 | "source": [ 126 | "You can do a lot of fancy stuff with your collate function if your data requires more processing or augmentation." 127 | ] 128 | }, 129 | { 130 | "cell_type": "markdown", 131 | "metadata": {}, 132 | "source": [ 133 | "# Prepare the Model\n", 134 | "\n", 135 | "The model can be pretty much any PyTorch model, no changes needed here :)" 136 | ] 137 | }, 138 | { 139 | "cell_type": "code", 140 | "execution_count": null, 141 | "metadata": {}, 142 | "outputs": [], 143 | "source": [ 144 | "model = nn.Sequential(\n", 145 | " nn.Conv2d(1, 16, kernel_size=3, stride=2, padding=1),\n", 146 | " nn.ReLU(),\n", 147 | " nn.Conv2d(16, 16, kernel_size=3, stride=2, padding=1),\n", 148 | " nn.ReLU(),\n", 149 | " nn.Conv2d(16, 10, kernel_size=3, stride=2, padding=1),\n", 150 | " nn.ReLU(),\n", 151 | " nn.AdaptiveAvgPool2d(1),\n", 152 | " nn.Flatten()\n", 153 | ")" 154 | ] 155 | }, 156 | { 157 | "cell_type": "markdown", 158 | "metadata": {}, 159 | "source": [ 160 | "# Create and Fit the Learner\n", 161 | "\n", 162 | "The heart of (mini)miniai is the Learner class. It pulls together the data, model and loss function, and can be extended in all sorts of cool ways using callbacks. Here's a somewhat minimal example, training our model on this classification task and plotting some stats as we do so:" 163 | ] 164 | }, 165 | { 166 | "cell_type": "code", 167 | "execution_count": null, 168 | "metadata": {}, 169 | "outputs": [ 170 | { 171 | "data": { 172 | "text/html": [ 173 | "\n", 174 | "\n" 189 | ], 190 | "text/plain": [ 191 | "" 192 | ] 193 | }, 194 | "metadata": {}, 195 | "output_type": "display_data" 196 | }, 197 | { 198 | "data": { 199 | "text/html": [ 200 | "\n", 201 | " \n", 202 | " \n", 203 | " \n", 204 | " \n", 205 | " \n", 206 | " \n", 207 | " \n", 208 | " \n", 209 | " \n", 210 | " \n", 211 | " \n", 212 | " \n", 213 | " \n", 214 | " \n", 215 | " \n", 216 | " \n", 217 | " \n", 218 | " \n", 219 | " \n", 220 | " \n", 221 | " \n", 222 | " \n", 223 | " \n", 224 | " \n", 225 | " \n", 226 | " \n", 227 | " \n", 228 | " \n", 229 | " \n", 230 | " \n", 231 | " \n", 232 | " \n", 233 | " \n", 234 | " \n", 235 | " \n", 236 | " \n", 237 | " \n", 238 | " \n", 239 | " \n", 240 | " \n", 241 | " \n", 242 | " \n", 243 | " \n", 244 | " \n", 245 | " \n", 246 | " \n", 247 | "
accuracylossepochtrain
0.3371.8530train
0.5871.2610eval
0.7070.9061train
0.8010.6481eval
0.8190.5862train
0.8380.5222eval
" 248 | ], 249 | "text/plain": [ 250 | "" 251 | ] 252 | }, 253 | "metadata": {}, 254 | "output_type": "display_data" 255 | } 256 | ], 257 | "source": [ 258 | "# There are callbacks for all sorts of things, here are some common ones:\n", 259 | "cbs = [\n", 260 | " mi.TrainCB(), # Handles the core steps in the training loop. Can be left out if using TrainLearner\n", 261 | " mi.DeviceCB(), # Handles making sure data and model are on the right device\n", 262 | " mi.MetricsCB(accuracy=MulticlassAccuracy()), # Keep track of any relevant metrics\n", 263 | " mi.ProgressCB(), # Displays metrics and loss during training, optionally plot=True for a pretty graph\n", 264 | "]\n", 265 | "\n", 266 | "# Nothing fancy for the loss function\n", 267 | "loss_fn = nn.CrossEntropyLoss()\n", 268 | "\n", 269 | "# The learner takes a model, dataloaders and loss function, plus some optional extras like a list of callbacks\n", 270 | "learn = mi.Learner(model, dls, loss_fn, lr=0.1, cbs=cbs)\n", 271 | "\n", 272 | "# And fit does the magic :)\n", 273 | "learn.fit(3)" 274 | ] 275 | }, 276 | { 277 | "attachments": {}, 278 | "cell_type": "markdown", 279 | "metadata": {}, 280 | "source": [ 281 | "When I get around to making more complex tutorials I'll try to show some of the other existing callbacks in action, but for most tasks this is pretty much all you need! The model (`learn.model`) is just a regular PyTorch model, so you can save it and load it later somewhere else without needing any minai code at all. " 282 | ] 283 | } 284 | ], 285 | "metadata": { 286 | "kernelspec": { 287 | "display_name": "python3", 288 | "language": "python", 289 | "name": "python3" 290 | }, 291 | "widgets": { 292 | "application/vnd.jupyter.widget-state+json": { 293 | "243a6081666e4c1497d740e8c895b52a": { 294 | "model_module": "@jupyter-widgets/base", 295 | "model_module_version": "1.2.0", 296 | "model_name": "LayoutModel", 297 | "state": { 298 | "_model_module": "@jupyter-widgets/base", 299 | "_model_module_version": "1.2.0", 300 | "_model_name": "LayoutModel", 301 | "_view_count": null, 302 | "_view_module": "@jupyter-widgets/base", 303 | "_view_module_version": "1.2.0", 304 | "_view_name": "LayoutView", 305 | "align_content": null, 306 | "align_items": null, 307 | "align_self": null, 308 | "border": null, 309 | "bottom": null, 310 | "display": null, 311 | "flex": null, 312 | "flex_flow": null, 313 | "grid_area": null, 314 | "grid_auto_columns": null, 315 | "grid_auto_flow": null, 316 | "grid_auto_rows": null, 317 | "grid_column": null, 318 | "grid_gap": null, 319 | "grid_row": null, 320 | "grid_template_areas": null, 321 | "grid_template_columns": null, 322 | "grid_template_rows": null, 323 | "height": null, 324 | "justify_content": null, 325 | "justify_items": null, 326 | "left": null, 327 | "margin": null, 328 | "max_height": null, 329 | "max_width": null, 330 | "min_height": null, 331 | "min_width": null, 332 | "object_fit": null, 333 | "object_position": null, 334 | "order": null, 335 | "overflow": null, 336 | "overflow_x": null, 337 | "overflow_y": null, 338 | "padding": null, 339 | "right": null, 340 | "top": null, 341 | "visibility": null, 342 | "width": null 343 | } 344 | }, 345 | "2e415547b49048fb8a7fd9c0cf93fd8c": { 346 | "model_module": "@jupyter-widgets/base", 347 | "model_module_version": "1.2.0", 348 | "model_name": "LayoutModel", 349 | "state": { 350 | "_model_module": "@jupyter-widgets/base", 351 | "_model_module_version": "1.2.0", 352 | "_model_name": "LayoutModel", 353 | "_view_count": null, 354 | "_view_module": "@jupyter-widgets/base", 355 | "_view_module_version": "1.2.0", 356 | "_view_name": "LayoutView", 357 | "align_content": null, 358 | "align_items": null, 359 | "align_self": null, 360 | "border": null, 361 | "bottom": null, 362 | "display": null, 363 | "flex": null, 364 | "flex_flow": null, 365 | "grid_area": null, 366 | "grid_auto_columns": null, 367 | "grid_auto_flow": null, 368 | "grid_auto_rows": null, 369 | "grid_column": null, 370 | "grid_gap": null, 371 | "grid_row": null, 372 | "grid_template_areas": null, 373 | "grid_template_columns": null, 374 | "grid_template_rows": null, 375 | "height": null, 376 | "justify_content": null, 377 | "justify_items": null, 378 | "left": null, 379 | "margin": null, 380 | "max_height": null, 381 | "max_width": null, 382 | "min_height": null, 383 | "min_width": null, 384 | "object_fit": null, 385 | "object_position": null, 386 | "order": null, 387 | "overflow": null, 388 | "overflow_x": null, 389 | "overflow_y": null, 390 | "padding": null, 391 | "right": null, 392 | "top": null, 393 | "visibility": null, 394 | "width": null 395 | } 396 | }, 397 | "2fa7aece2e854579b6856dc227d90383": { 398 | "model_module": "@jupyter-widgets/controls", 399 | "model_module_version": "1.5.0", 400 | "model_name": "HTMLModel", 401 | "state": { 402 | "_dom_classes": [], 403 | "_model_module": "@jupyter-widgets/controls", 404 | "_model_module_version": "1.5.0", 405 | "_model_name": "HTMLModel", 406 | "_view_count": null, 407 | "_view_module": "@jupyter-widgets/controls", 408 | "_view_module_version": "1.5.0", 409 | "_view_name": "HTMLView", 410 | "description": "", 411 | "description_tooltip": null, 412 | "layout": "IPY_MODEL_f19e76960da346debdde6e0b89f291cf", 413 | "placeholder": "​", 414 | "style": "IPY_MODEL_b94f987ac374487da181bf3c73d3dec4", 415 | "value": " 2/2 [00:00<00:00, 92.52it/s]" 416 | } 417 | }, 418 | "44e64afc3b3f49f9b81b05a88fc924c8": { 419 | "model_module": "@jupyter-widgets/controls", 420 | "model_module_version": "1.5.0", 421 | "model_name": "ProgressStyleModel", 422 | "state": { 423 | "_model_module": "@jupyter-widgets/controls", 424 | "_model_module_version": "1.5.0", 425 | "_model_name": "ProgressStyleModel", 426 | "_view_count": null, 427 | "_view_module": "@jupyter-widgets/base", 428 | "_view_module_version": "1.2.0", 429 | "_view_name": "StyleView", 430 | "bar_color": null, 431 | "description_width": "" 432 | } 433 | }, 434 | "65be2a364ec8486a9bd46d20deed89a1": { 435 | "model_module": "@jupyter-widgets/controls", 436 | "model_module_version": "1.5.0", 437 | "model_name": "HBoxModel", 438 | "state": { 439 | "_dom_classes": [], 440 | "_model_module": "@jupyter-widgets/controls", 441 | "_model_module_version": "1.5.0", 442 | "_model_name": "HBoxModel", 443 | "_view_count": null, 444 | "_view_module": "@jupyter-widgets/controls", 445 | "_view_module_version": "1.5.0", 446 | "_view_name": "HBoxView", 447 | "box_style": "", 448 | "children": [ 449 | "IPY_MODEL_862e35aad3494ea1837743ae186ec3b9", 450 | "IPY_MODEL_a3c7c7de120549ecaabff8615a6ecc10", 451 | "IPY_MODEL_2fa7aece2e854579b6856dc227d90383" 452 | ], 453 | "layout": "IPY_MODEL_243a6081666e4c1497d740e8c895b52a" 454 | } 455 | }, 456 | "862e35aad3494ea1837743ae186ec3b9": { 457 | "model_module": "@jupyter-widgets/controls", 458 | "model_module_version": "1.5.0", 459 | "model_name": "HTMLModel", 460 | "state": { 461 | "_dom_classes": [], 462 | "_model_module": "@jupyter-widgets/controls", 463 | "_model_module_version": "1.5.0", 464 | "_model_name": "HTMLModel", 465 | "_view_count": null, 466 | "_view_module": "@jupyter-widgets/controls", 467 | "_view_module_version": "1.5.0", 468 | "_view_name": "HTMLView", 469 | "description": "", 470 | "description_tooltip": null, 471 | "layout": "IPY_MODEL_2e415547b49048fb8a7fd9c0cf93fd8c", 472 | "placeholder": "​", 473 | "style": "IPY_MODEL_eec16c8f02f147c98ce4d21f40f3782f", 474 | "value": "100%" 475 | } 476 | }, 477 | "a3c7c7de120549ecaabff8615a6ecc10": { 478 | "model_module": "@jupyter-widgets/controls", 479 | "model_module_version": "1.5.0", 480 | "model_name": "FloatProgressModel", 481 | "state": { 482 | "_dom_classes": [], 483 | "_model_module": "@jupyter-widgets/controls", 484 | "_model_module_version": "1.5.0", 485 | "_model_name": "FloatProgressModel", 486 | "_view_count": null, 487 | "_view_module": "@jupyter-widgets/controls", 488 | "_view_module_version": "1.5.0", 489 | "_view_name": "ProgressView", 490 | "bar_style": "success", 491 | "description": "", 492 | "description_tooltip": null, 493 | "layout": "IPY_MODEL_afd0138634664df38f89b59f6949970a", 494 | "max": 2, 495 | "min": 0, 496 | "orientation": "horizontal", 497 | "style": "IPY_MODEL_44e64afc3b3f49f9b81b05a88fc924c8", 498 | "value": 2 499 | } 500 | }, 501 | "afd0138634664df38f89b59f6949970a": { 502 | "model_module": "@jupyter-widgets/base", 503 | "model_module_version": "1.2.0", 504 | "model_name": "LayoutModel", 505 | "state": { 506 | "_model_module": "@jupyter-widgets/base", 507 | "_model_module_version": "1.2.0", 508 | "_model_name": "LayoutModel", 509 | "_view_count": null, 510 | "_view_module": "@jupyter-widgets/base", 511 | "_view_module_version": "1.2.0", 512 | "_view_name": "LayoutView", 513 | "align_content": null, 514 | "align_items": null, 515 | "align_self": null, 516 | "border": null, 517 | "bottom": null, 518 | "display": null, 519 | "flex": null, 520 | "flex_flow": null, 521 | "grid_area": null, 522 | "grid_auto_columns": null, 523 | "grid_auto_flow": null, 524 | "grid_auto_rows": null, 525 | "grid_column": null, 526 | "grid_gap": null, 527 | "grid_row": null, 528 | "grid_template_areas": null, 529 | "grid_template_columns": null, 530 | "grid_template_rows": null, 531 | "height": null, 532 | "justify_content": null, 533 | "justify_items": null, 534 | "left": null, 535 | "margin": null, 536 | "max_height": null, 537 | "max_width": null, 538 | "min_height": null, 539 | "min_width": null, 540 | "object_fit": null, 541 | "object_position": null, 542 | "order": null, 543 | "overflow": null, 544 | "overflow_x": null, 545 | "overflow_y": null, 546 | "padding": null, 547 | "right": null, 548 | "top": null, 549 | "visibility": null, 550 | "width": null 551 | } 552 | }, 553 | "b94f987ac374487da181bf3c73d3dec4": { 554 | "model_module": "@jupyter-widgets/controls", 555 | "model_module_version": "1.5.0", 556 | "model_name": "DescriptionStyleModel", 557 | "state": { 558 | "_model_module": "@jupyter-widgets/controls", 559 | "_model_module_version": "1.5.0", 560 | "_model_name": "DescriptionStyleModel", 561 | "_view_count": null, 562 | "_view_module": "@jupyter-widgets/base", 563 | "_view_module_version": "1.2.0", 564 | "_view_name": "StyleView", 565 | "description_width": "" 566 | } 567 | }, 568 | "eec16c8f02f147c98ce4d21f40f3782f": { 569 | "model_module": "@jupyter-widgets/controls", 570 | "model_module_version": "1.5.0", 571 | "model_name": "DescriptionStyleModel", 572 | "state": { 573 | "_model_module": "@jupyter-widgets/controls", 574 | "_model_module_version": "1.5.0", 575 | "_model_name": "DescriptionStyleModel", 576 | "_view_count": null, 577 | "_view_module": "@jupyter-widgets/base", 578 | "_view_module_version": "1.2.0", 579 | "_view_name": "StyleView", 580 | "description_width": "" 581 | } 582 | }, 583 | "f19e76960da346debdde6e0b89f291cf": { 584 | "model_module": "@jupyter-widgets/base", 585 | "model_module_version": "1.2.0", 586 | "model_name": "LayoutModel", 587 | "state": { 588 | "_model_module": "@jupyter-widgets/base", 589 | "_model_module_version": "1.2.0", 590 | "_model_name": "LayoutModel", 591 | "_view_count": null, 592 | "_view_module": "@jupyter-widgets/base", 593 | "_view_module_version": "1.2.0", 594 | "_view_name": "LayoutView", 595 | "align_content": null, 596 | "align_items": null, 597 | "align_self": null, 598 | "border": null, 599 | "bottom": null, 600 | "display": null, 601 | "flex": null, 602 | "flex_flow": null, 603 | "grid_area": null, 604 | "grid_auto_columns": null, 605 | "grid_auto_flow": null, 606 | "grid_auto_rows": null, 607 | "grid_column": null, 608 | "grid_gap": null, 609 | "grid_row": null, 610 | "grid_template_areas": null, 611 | "grid_template_columns": null, 612 | "grid_template_rows": null, 613 | "height": null, 614 | "justify_content": null, 615 | "justify_items": null, 616 | "left": null, 617 | "margin": null, 618 | "max_height": null, 619 | "max_width": null, 620 | "min_height": null, 621 | "min_width": null, 622 | "object_fit": null, 623 | "object_position": null, 624 | "order": null, 625 | "overflow": null, 626 | "overflow_x": null, 627 | "overflow_y": null, 628 | "padding": null, 629 | "right": null, 630 | "top": null, 631 | "visibility": null, 632 | "width": null 633 | } 634 | } 635 | } 636 | } 637 | }, 638 | "nbformat": 4, 639 | "nbformat_minor": 4 640 | } 641 | -------------------------------------------------------------------------------- /llm_example.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "import os\n", 10 | "# os.environ['CUDA_VISIBLE_DEVICES'] = '1'" 11 | ] 12 | }, 13 | { 14 | "cell_type": "code", 15 | "execution_count": null, 16 | "metadata": {}, 17 | "outputs": [], 18 | "source": [ 19 | "import torch, numpy as np\n", 20 | "from minai.core import *\n", 21 | "from datasets import load_dataset\n", 22 | "from torch.utils.data import DataLoader\n", 23 | "from torch import nn, tensor\n", 24 | "from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig" 25 | ] 26 | }, 27 | { 28 | "cell_type": "code", 29 | "execution_count": null, 30 | "metadata": {}, 31 | "outputs": [], 32 | "source": [ 33 | "set_seed(42)" 34 | ] 35 | }, 36 | { 37 | "cell_type": "markdown", 38 | "metadata": {}, 39 | "source": [ 40 | "## Prepare" 41 | ] 42 | }, 43 | { 44 | "cell_type": "code", 45 | "execution_count": null, 46 | "metadata": {}, 47 | "outputs": [ 48 | { 49 | "name": "stderr", 50 | "output_type": "stream", 51 | "text": [ 52 | "The model was loaded with use_flash_attention_2=True, which is deprecated and may be removed in a future release. Please use `attn_implementation=\"flash_attention_2\"` instead.\n" 53 | ] 54 | }, 55 | { 56 | "data": { 57 | "application/vnd.jupyter.widget-view+json": { 58 | "model_id": "3983fccc4a974a2d9d6fc0fd89e94913", 59 | "version_major": 2, 60 | "version_minor": 0 61 | }, 62 | "text/plain": [ 63 | "Loading checkpoint shards: 0%| | 0/2 [00:00 Context: CREATE TABLE table_name_74 (week VARCHAR, attendance INTEGER)\\nQuestion: How many weeks had an attendance of over 68,000?\\nAnswer: SELECT COUNT(week) FROM table_name_74 WHERE attendance > 68 OFFSET 000 68 OFFSET 000` beginning-of-sequence token, to help the model recognize the beginning of the input sequence, in the input `xb[0]`.\n", 373 | "- the label `yb[0]` is equal to `xb[0]`, except by not having the beginning of sequence token, by having one more token at the end. That last token is the next token to predict.\n" 374 | ] 375 | }, 376 | { 377 | "cell_type": "markdown", 378 | "metadata": {}, 379 | "source": [ 380 | "## Training" 381 | ] 382 | }, 383 | { 384 | "cell_type": "code", 385 | "execution_count": null, 386 | "metadata": {}, 387 | "outputs": [], 388 | "source": [ 389 | "dls = DataLoaders(train_dataloader, eval_dataloader)" 390 | ] 391 | }, 392 | { 393 | "cell_type": "code", 394 | "execution_count": null, 395 | "metadata": {}, 396 | "outputs": [], 397 | "source": [ 398 | "def loss_fn(x, y):\n", 399 | " return torch.nn.functional.cross_entropy(x.view(-1, x.shape[-1]), y.view(-1))" 400 | ] 401 | }, 402 | { 403 | "cell_type": "code", 404 | "execution_count": null, 405 | "metadata": {}, 406 | "outputs": [], 407 | "source": [ 408 | "# from peft import get_peft_config, get_peft_model, LoraConfig, TaskType\n", 409 | "\n", 410 | "# peft_config = LoraConfig(\n", 411 | "# task_type=TaskType.CAUSAL_LM, inference_mode=False, r=32, lora_alpha=16, lora_dropout=0.1,\n", 412 | "# target_modules=[l+\"_proj\" for l in [\"k\", 'v', \"q\", \"o\", \"gate\", \"up\", \"down\"]]\n", 413 | "# )\n", 414 | "# m = get_peft_model(m, peft_config)" 415 | ] 416 | }, 417 | { 418 | "cell_type": "code", 419 | "execution_count": null, 420 | "metadata": {}, 421 | "outputs": [], 422 | "source": [ 423 | "from torch import optim" 424 | ] 425 | }, 426 | { 427 | "cell_type": "code", 428 | "execution_count": null, 429 | "metadata": {}, 430 | "outputs": [], 431 | "source": [ 432 | "prog = ProgressCB(plot=True)\n", 433 | "cbs = [DeviceCB(), MetricsCB()]" 434 | ] 435 | }, 436 | { 437 | "cell_type": "code", 438 | "execution_count": null, 439 | "metadata": {}, 440 | "outputs": [], 441 | "source": [ 442 | "# Just freeze embeddings for small memory decrease\n", 443 | "m.model.embed_tokens.weight.requires_grad_(False);" 444 | ] 445 | }, 446 | { 447 | "cell_type": "markdown", 448 | "metadata": {}, 449 | "source": [ 450 | "`requires_grad == False` tells PyTorch not to track gradient values. Gradient values are used to update weights in training. So setting requires_grad to False turns off training of those weights, the weights used to define initial embedding layer which maps a token value to an embedding vector. Training fewer weights requires less memory.\n", 451 | "\n", 452 | "But why can turn off _these_ weights? Embedding layers are relatively expensive in terms of parameter count.\n", 453 | "\n", 454 | "But why can we get away with it? Empirically, it turns out (✨) that freezing them often has only a minimal impact on downstream task performance." 455 | ] 456 | }, 457 | { 458 | "cell_type": "code", 459 | "execution_count": null, 460 | "metadata": {}, 461 | "outputs": [ 462 | { 463 | "data": { 464 | "text/plain": [ 465 | "32" 466 | ] 467 | }, 468 | "execution_count": null, 469 | "metadata": {}, 470 | "output_type": "execute_result" 471 | } 472 | ], 473 | "source": [ 474 | "len(m.model.layers)" 475 | ] 476 | }, 477 | { 478 | "cell_type": "code", 479 | "execution_count": null, 480 | "metadata": {}, 481 | "outputs": [], 482 | "source": [ 483 | "# Or freeze first n layers for larger decrease (in this case, 24). (Can freeze up to len(m.model.layers)-1)\n", 484 | "n_freeze = 24\n", 485 | "for param in m.parameters(): param.requires_grad = False\n", 486 | "for param in m.lm_head.parameters(): param.requires_grad = True\n", 487 | "for param in m.model.layers[n_freeze:].parameters(): param.requires_grad = True" 488 | ] 489 | }, 490 | { 491 | "cell_type": "code", 492 | "execution_count": null, 493 | "metadata": {}, 494 | "outputs": [], 495 | "source": [ 496 | "from functools import partial\n", 497 | "optim = partial(torch.optim.Adam, betas=(0.9,0.99), eps=1e-5)" 498 | ] 499 | }, 500 | { 501 | "cell_type": "code", 502 | "execution_count": null, 503 | "metadata": {}, 504 | "outputs": [], 505 | "source": [ 506 | "lr = 1e-3\n", 507 | "sz = len(dls.train)//50" 508 | ] 509 | }, 510 | { 511 | "cell_type": "markdown", 512 | "metadata": {}, 513 | "source": [ 514 | "`epoch_sz` controls how many batches are seen before validation metrics are reported. It lets us define a pseudo \"epoch\" which is smaller than a true epoch" 515 | ] 516 | }, 517 | { 518 | "cell_type": "code", 519 | "execution_count": null, 520 | "metadata": {}, 521 | "outputs": [], 522 | "source": [ 523 | "cb_trn = TrainCB(preds_nm='logits')\n", 524 | "learn = MomentumLearner(m, dls, loss_func=loss_fn, lr=lr, cbs=cbs, preds_nm='logits', epoch_sz=sz, mom=0.9)\n", 525 | "# learn = TrainLearner(m, dls, loss_func=loss_fn, lr=lr, cbs=cbs, preds_nm='logits', epoch_sz=sz)\n", 526 | "# learn = Learner(m, dls, loss_func=loss_fn, lr=lr, cbs=cbs+[cb_trn], epoch_sz=sz) #, opt_func=optim)" 527 | ] 528 | }, 529 | { 530 | "cell_type": "code", 531 | "execution_count": null, 532 | "metadata": {}, 533 | "outputs": [], 534 | "source": [ 535 | "m.gradient_checkpointing_enable()" 536 | ] 537 | }, 538 | { 539 | "cell_type": "code", 540 | "execution_count": null, 541 | "metadata": {}, 542 | "outputs": [], 543 | "source": [ 544 | "# NB lr_find does *not* reset model, so recreate it afterwards\n", 545 | "# learn.lr_find(max_mult=10)" 546 | ] 547 | }, 548 | { 549 | "cell_type": "markdown", 550 | "metadata": {}, 551 | "source": [ 552 | "Train for 1 \"epoch\"" 553 | ] 554 | }, 555 | { 556 | "cell_type": "code", 557 | "execution_count": null, 558 | "metadata": {}, 559 | "outputs": [ 560 | { 561 | "data": { 562 | "text/html": [ 563 | "\n", 564 | "\n" 579 | ], 580 | "text/plain": [ 581 | "" 582 | ] 583 | }, 584 | "metadata": {}, 585 | "output_type": "display_data" 586 | }, 587 | { 588 | "data": { 589 | "text/html": [ 590 | "\n", 591 | " \n", 592 | " \n", 593 | " \n", 594 | " \n", 595 | " \n", 596 | " \n", 597 | " \n", 598 | " \n", 599 | " \n", 600 | " \n", 601 | " \n", 602 | " \n", 603 | " \n", 604 | " \n", 605 | " \n", 606 | " \n", 607 | " \n", 608 | " \n", 609 | " \n", 610 | "
lossepochtrain
2.0840train
1.0520eval
" 611 | ], 612 | "text/plain": [ 613 | "" 614 | ] 615 | }, 616 | "metadata": {}, 617 | "output_type": "display_data" 618 | }, 619 | { 620 | "name": "stderr", 621 | "output_type": "stream", 622 | "text": [ 623 | "/home/algal/miniconda3/envs/mbert/lib/python3.11/site-packages/torch/utils/checkpoint.py:87: UserWarning: None of the inputs have requires_grad=True. Gradients will be None\n", 624 | " warnings.warn(\n" 625 | ] 626 | }, 627 | { 628 | "data": { 629 | "image/png": "iVBORw0KGgoAAAANSUhEUgAAAfcAAAFfCAYAAABTOoWkAAAAOnRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjEwLjAsIGh0dHBzOi8vbWF0cGxvdGxpYi5vcmcvlHJYcgAAAAlwSFlzAAAPYQAAD2EBqD+naQAAMrdJREFUeJzt3Xl8VOWhPvDnzJptZrJnsieEnYSwBJBFQItYVMSFimtBLa0t2iql12JvK1Ivsep1K4qKtwo/F7QtuCtFhYiyJSg7JCwJ2ck+k3UmM3N+fyQZCBDIJDM5M2ee7+czH5LJyczDkPDMe95z3iOIoiiCiIiIZEMhdQAiIiJyL5Y7ERGRzLDciYiIZIblTkREJDMsdyIiIplhuRMREckMy52IiEhmVAP9hA6HA+Xl5dDpdBAEYaCfnoiIyGeJoojGxkbExcVBoeh5fD7g5V5eXo7ExMSBfloiIiLZKCkpQUJCQo9fH/By1+l0ADqC6fX6gX56IiIin2U2m5GYmOjs0p4MeLl37YrX6/UsdyIioj643LQ2D6gjIiKSGZY7ERGRzLDciYiIZGbA59yJiEje7HY72tvbpY7hk9RqNZRKZb8fh+VORERuIYoiKisr0dDQIHUUnxYaGgqj0divtWBY7kRE5BZdxR4dHY2goCAuVOYiURTR0tKCqqoqAEBsbGyfH4vlTkRE/Wa3253FHhERIXUcnxUYGAgAqKqqQnR0dJ930fOAOiIi6reuOfagoCCJk/i+rtewP8ctsNyJiMhtuCu+/9zxGrLciYiIZMbny/274zW4c+0u/OWjQ1JHISIi8go+X+7tdgd2nKzFtwXVUkchIiI/l5KSghdeeEHqGL5/tPy4pDAIAlBU24LqRguidFqpIxERkQ+ZOXMmxowZ45ZSzs3NRXBwcP9D9ZPPj9wNQWoMi+m49N3e03USpyEiIrkRRRE2m61X20ZFRXnFGQM+X+4AkJUSBgDILaqXOAkREXURRREtVtuA30RR7HXGRYsWIScnBy+++CIEQYAgCHjrrbcgCAI2b96MrKwsaLVabN++HSdPnsS8efMQExODkJAQTJgwAV999VW3xzt/t7wgCHjjjTdw8803IygoCEOGDMHHH3/srpe4Rz6/Wx4AspLD8fauYuQVceROROQtWtvtGPmXzQP+vEdWXosgTe/q7cUXX0RBQQHS09OxcuVKAMDhw4cBAP/1X/+FZ599FoMGDUJoaChKS0tx3XXX4cknn0RAQADWrVuHuXPnIj8/H0lJST0+xxNPPIGnn34azzzzDP7+97/jrrvuwunTpxEeHt7/v2wPZDVyP1RuRou1d7tOiIiIDAYDNBoNgoKCYDQaYTQanavCrVy5Etdccw3S0tIQERGBzMxM/OpXv0JGRgaGDBmCJ598EoMGDbrsSHzRokW44447MHjwYKxatQrNzc3Ys2ePR/9eshi5x4cGItYQgApTG/aVNGBKWqTUkYiI/F6gWokjK6+V5HndISsrq9vnzc3NeOKJJ/Dpp5+ivLwcNpsNra2tKC4uvuTjjB492vlxcHAwdDqdc/14T5FFuQuCgKyUcHyyvxx5RfUsdyIiLyAIQq93j3uj8496/8Mf/oDNmzfj2WefxeDBgxEYGIj58+fDarVe8nHUanW3zwVBgMPhcHvec7m8W76srAx33303IiIiEBQUhDFjxmDv3r2eyOaSrOSug+o4705ERL2n0Whgt9svu9327duxaNEi3HzzzcjIyIDRaERRUZHnA/aBS2+p6uvrMXXqVFx11VX44osvEB0djZMnTyI0NNRD8Xqva979x+IG2B0ilAqub0xERJeXkpKC3bt3o6ioCCEhIT2OqgcPHoyNGzdi7ty5EAQBf/7znz0+Au8rl0buf/vb35CYmIg333wTEydOREpKCn7yk58gLS3NU/l6bbhRjxCtCk0WG45VmqWOQ0REPmLZsmVQKpUYOXIkoqKiepxDf/755xEWFoYpU6Zg7ty5uPbaazFu3LgBTts7gujCCYEjR47Etddei9LSUuTk5CA+Ph6/+c1vsHjx4h6/x2KxwGKxOD83m81ITEyEyWSCXq/vX/rz/Pwfe/BtQTWeuHEUFk5JcetjExFRz9ra2lBYWIjU1FQEBARIHcenXeq1NJvNMBgMl+1Ql0bup06dwpo1azBkyBBs3rwZDzzwAH77299i/fr1PX5PdnY2DAaD85aYmOjKU7pkAufdiYiIXCt3h8OBcePGYdWqVRg7dix+9atfYfHixVizZk2P37N8+XKYTCbnraSkpN+hezI+5Wy5u7JCERERkZy4VO6xsbEYOXJkt/tGjBhxyXP8tFot9Hp9t5unjEkMhUoh4IzZgtL6Vo89DxERkTdzqdynTp2K/Pz8bvcVFBQgOTnZraH6Kkijwqh4AwBg72muM09ERP7JpXJ/5JFHsGvXLqxatQonTpzAu+++i9dffx1LlizxVD6Xcd6diEg63npqmC9xx2vo0nnuEyZMwKZNm7B8+XKsXLkSqampeOGFF3DXXXf1O4i7ZKWE4Y3vCpHHK8QREQ0YjUYDhUKB8vJyREVFQaPRQBC43ogrRFGE1WpFdXU1FAoFNBpNnx/L5XUBb7jhBtxwww19fkJPG5/ccZWd/DONMLW0wxCkvsx3EBFRfykUCqSmpqKiogLl5eVSx/FpQUFBSEpKgkLR92u7+e6ivz2I0mmRGhmMwppm7C2uw9XDY6SORETkFzQaDZKSkmCz2Xq1nCtdSKlUQqVS9Xuvh+zKHehYZ76wphl5RfUsdyKiASQIAtRq9QUXS6GBJYvruZ9vQkrHrnnOuxMRkT+SZbl3LWazr7QBFht3DRERkX+RZbkPigxGeLAGVpsDh8pMUschIiIaULIsd0EQnNd35655IiLyN7Isd+DsvHsuy52IiPyMbMs9q3Pefe/pOjgcvIgMERH5D9mW+6g4A7QqBepb2nGqpknqOERERANGtuWuUSkwJjEUAHfNExGRf5FtuQM8352IiPyTrMu9a9497zSvEEdERP5D1uU+LjkMggCcrm1BlblN6jhEREQDQtblrg9QY1iMDgCQd5q75omIyD/IutwBzrsTEZH/kX25c96diIj8jezLvWvkfrjcjGaLTeI0REREnif7co8LDUScIQB2h4h9JQ1SxyEiIvI42Zc7AGQ515nnrnkiIpI/vyj3Cc515nlQHRERyZ9flHvXyP2H0/Ww2R0SpyEiIvIsvyj3oTE66AJUaLbacayyUeo4REREHuUX5a5UCBiX1LFrnvPuREQkd35R7sDZeXcuZkNERHLnN+XeNe+ed7oOoihKnIaIiMhz/KbcMxNCoVYKOGO2oLS+Veo4REREHuM35R6oUWJUnAEA592JiEje/KbcgbPz7rmcdyciIhnzq3Lvmnffy4vIEBGRjPlXuSd3jNwLzjShocUqcRoiIiLP8KtyjwjRYlBUMAAuRUtERPLlV+UOnB29c96diIjkyv/Kvet8dx4xT0REMuV35T6hs9wPlJrQ1m6XOA0REZH7+V25p0QEITJEA6vdgUNlJqnjEBERuZ3flbsgCBjPeXciIpIxvyt34Oyuec67ExGRHPlluTsXsymuh8PBi8gQEZG8+GW5j4rTI0CtQENLO05WN0kdh4iIyK38stzVSgXGJnLenYiI5Mkvyx0AsjovIsN5dyIikhs/LveOefdcXkSGiIhkxm/LfVxSKBQCUFLXijPmNqnjEBERuY1L5b5ixQoIgtDtZjQaPZXNo3QBagw36gEAeZx3JyIiGXF55D5q1ChUVFQ4bwcPHvRErgExIaXroDrumiciIvlQufwNKpXPjtbPNz4lHOt2nkYe592JiEhGXB65Hz9+HHFxcUhNTcXtt9+OU6dOXXJ7i8UCs9nc7eYtukbuR8rNaLLYJE5DRETkHi6V+6RJk7B+/Xps3rwZa9euRWVlJaZMmYLa2toevyc7OxsGg8F5S0xM7Hdod4k1BCI+NBAOEdhX3CB1HCIiIrdwqdznzJmDW2+9FRkZGZg1axY+++wzAMC6det6/J7ly5fDZDI5byUlJf1L7GacdyciIrnp16lwwcHByMjIwPHjx3vcRqvVQq/Xd7t5k/FdF5HhvDsREclEv8rdYrHg6NGjiI2NdVeeAdc1cv+xuAE2u0PiNERERP3nUrkvW7YMOTk5KCwsxO7duzF//nyYzWYsXLjQU/k8bmi0DroAFVqsdhytaJQ6DhERUb+5VO6lpaW44447MGzYMNxyyy3QaDTYtWsXkpOTPZXP4xQKAVnJnHcnIiL5cOk89w0bNngqh6SyUsKxNb8aeafrcN+0VKnjEBER9Yvfri1/rrMj93qIoihxGiIiov5huQPITAyFWimgutGC4roWqeMQERH1C8sdQIBaiYx4AwBeRIaIiHwfy73TBJ7vTkREMsFy7zT+nHl3IiIiX8Zy79RV7ieqmlDfbJU4DRERUd+x3DtFhGiRFhUMANh7mqN3IiLyXSz3c3TNu+dy3p2IiHwYy/0cWV0H1XHenYiIfBjL/Rxdi9kcKG1AW7td4jRERER9w3I/R3JEEKJ1WrTbRfxQzNE7ERH5Jpb7OQRBwKRBEQCA3ac4705ERL6J5X6eSakd8+67C2slTkJERNQ3LPfzXDGoo9x/LG6AxcZ5dyIi8j0s9/OkRYUgMkQDi82B/SUmqeMQERG5jOV+HkEQMLFr1/wp7ponIiLfw3K/iEmpnQfVFfKgOiIi8j0s94uY1Dnvvvd0PdrtDonTEBERuYblfhFDo3UIDVKjtd2OA6WcdyciIt/Ccr8IhULAxBSeEkdERL6J5d4DLmZDRES+iuXeg67FbPKK6mDjvDsREfkQlnsPRsTqoQtQodlqx+Fys9RxiIiIeo3l3gMl592JiMhHsdwvoeuUOM67ExGRL2G5X0LXYjZ7iupgd4gSpyEiIuodlvsljIrTI0SrQmObDUcrOO9ORES+geV+CSqlAuOTwwBwKVoiIvIdLPfLODvvzoPqiIjIN7DcL+PceXcH592JiMgHsNwvY3SCAYFqJRpa2lFQ1Sh1HCIiostiuV+G+tx5d54SR0REPoDl3gtdS9FyMRsiIvIFLPde6LqIzJ7COogi592JiMi7sdx7ITPRAK1KgZomK05WN0kdh4iI6JJY7r2gVSkxNikUALCL8+5EROTlWO691HVKHBezISIib8dy76VzF7PhvDsREXkzlnsvjUsKg0apQFWjBUW1LVLHISIi6hHLvZcC1EpkJhoAcClaIiLybix3F3DenYiIfAHL3QWcdyciIl/AcnfB+OQwqBQCyk1tKK1vlToOERHRRfWr3LOzsyEIAh5++GE3xfFuQRoVMhI65t13cd6diIi8VJ/LPTc3F6+//jpGjx7tzjxej/PuRETk7fpU7k1NTbjrrruwdu1ahIWFXXJbi8UCs9nc7ebLnPPuvIgMERF5qT6V+5IlS3D99ddj1qxZl902OzsbBoPBeUtMTOzLU3qNrOQwKASgpK4V5Q2cdyciIu/jcrlv2LABP/zwA7Kzs3u1/fLly2EymZy3kpISl0N6E12AGunxnee7c/ROREReyKVyLykpwe9+9zu8/fbbCAgI6NX3aLVa6PX6bjdf57y+Oy8iQ0REXsilct+7dy+qqqowfvx4qFQqqFQq5OTk4KWXXoJKpYLdbvdUTq/Cg+qIiMibqVzZ+Cc/+QkOHjzY7b57770Xw4cPx6OPPgqlUunWcN5qQmo4BAEorGlGlbkN0fre7cUgIiIaCC6Vu06nQ3p6erf7goODERERccH9cmYIVGOEUY8jFWbsKqzDjZlxUkciIiJy4gp1fXTuUrRERETexKWR+8Vs27bNDTF8z6TUCLz5fRHn3YmIyOtw5N5HEzuPmD9R1YSaJovEaYiIiM5iufdReLAGw2J0AIA9HL0TEZEXYbn3A+fdiYjIG7Hc+4HnuxMRkTdiufdD17z7scpG1DdbJU5DRETUgeXeD1E6LdKiggEAe4o4eiciIu/Acu+nSYM6d81znXkiIvISLPd+cl5EhleIIyIiL8Fy76crOkfuRyrMMLW2S5yGiIiI5d5vMfoApEQEQRSBPM67ExGRF2C5uwFPiSMiIm/CcncDLmZDRETehOXuBl1HzB8qN6PJYpM4DRER+TuWuxvEhwYiISwQdofIeXciIpIcy91NOO9ORETeguXuJldw3p2IiLwEy91Nus53P1BqQouV8+5ERCQdlrubJIQFIs4QAJtDxA+nG6SOQ0REfozl7iaCIJxdZ55L0RIRkYRY7m7kXGeeF5EhIiIJsdzdqGvkvq+kAW3tdonTEBGRv2K5u1FKRBCidVpY7Q78WNwgdRwiIvJTLHc34rw7ERF5A5a7m3HenYiIpMZyd7OuxWx+KK6HxcZ5dyIiGngsdzdLiwpBZIgGFpsDB0pNUschIiI/xHJ3M0EQMDGVS9ESEZF0WO4ewIvIEBGRlFjuHjCpc9597+l6tNsdEqchIiJ/w3L3gKHROoQGqdFiteNgGefdiYhoYLHcPUChEDAxhafEERGRNFjuHsLFbIiISCosdw/pWswmr6geNs67ExHRAGK5e8iIWD30ASo0WWxY9GYuSupapI5ERER+guXuIUqFgJXz0qFVKfDdiRpc+8K3WLejCA6HKHU0IiKSOZa7B900Nh5fPjwdE1PD0WK14/GPD2PB6ztxqrpJ6mheTRT5BoiIqD9Y7h6WGhmMDYuvwF/njUKwRonconrMeXE7Xss5ybn485Q1tGLe6u9w0ys70GrluvxERH3Fch8ACoWAeyanYPMj03HlkEhYbA5kf3EMt67ZgfzKRqnjeYXTtc247dWd2F9qwv6SBqzfWSR1JCIin8VyH0AJYUFYf99EPD1/NHQBKuwvNeGGv2/HS18f9+uV7E5UNeG213airKEVugAVAGBNzkk0trVLnIyIyDex3AeYIAi4LSsRXy2dgVkjYtBuF/HclgLcuPp7HPLD1eyOVpix4LWdOGO2YGhMCLY8MgNpUcFoaGnH/31XKHU8IiKfxHKXSIw+AGt/Ph4v3TEWYUFqHK0wY97L3+PpL4+hrd0/5psPlDbg9td3obbZilFxemz45WQYDQFYes0wAMAb2wtR32yVOCURke9huUtIEATcmBmHLUtn4IbRsbA7RLyy7SSuf2k79p6ulzqeR+UV1eGutbtham3H2KRQvLv4CoQHawAAc9KNGBmrR5PFhldzTkqclIjI97hU7mvWrMHo0aOh1+uh1+sxefJkfPHFF57K5jciQ7RYfec4vHr3eETptDhZ3Yz5r+7Ayk+OoMVqkzqe2+04UYN7/m8PGi02TEoNx/+7fxIMgWrn1xUKAX+4tmP0vm5nEarMbVJFJSLySS6Ve0JCAp566ink5eUhLy8PV199NebNm4fDhw97Kp9f+Wm6EV89MgPzxydAFIF/fF+In76wHTtO1kgdzW225lfh3rdy0dpux5VDIvHWvRMRolVdsN3MYVEYnxyGtnYHVm89IUFSIiLfJYj9XDEkPDwczzzzDO6///5ebW82m2EwGGAymaDX6/vz1LK2Lb8Kj208iHJTx6j1rklJ+OOc4dAFqC/znd7ry0OVeOi9H9BuFzFrRAxevmsstCplj9vvPFmLO9buglop4Jvfz0RieNAApiUi8j697dA+z7nb7XZs2LABzc3NmDx5co/bWSwWmM3mbje6vJnDorH5kem4+4okAMA7u4sx+/lvsfbbU6htskicznUf7SvDknc7iv36jFisuXvcJYsdACanRWDa4Ei020W89PXxAUpKROT7XB65Hzx4EJMnT0ZbWxtCQkLw7rvv4rrrrutx+xUrVuCJJ5644H6O3Htv58laPPrvAyjuvPiMWilg9kgjbp+YiKlpkVAoBIkTXtoHuSV4dOMBiCJwy7h4PH3raKiUvXtfua+kATe9/D0UAvCfR2ZgcHSIh9MSEXmv3o7cXS53q9WK4uJiNDQ04N///jfeeOMN5OTkYOTIkRfd3mKxwGI5O9I0m81ITExkubuo1WrHph/L8H5uMfaXnj0fPiEsEAuyEvGzrEQYDQESJry49TuL8JePOo7JuGtSEv46L93lNyO/WJeHr46ewfWjY/HyneM8EZOIyCd4rNzPN2vWLKSlpeG1115zazDq2ZFyMzbkFmPTj2VobOs4ml4hAFcNi8btE5Nw1bCoXo+MPem1nJPI/uIYAOC+qan48w0jIAiu72U4WmHGdS9thygCn/12GkbFGdwdlYjIJ3h8zr2LKIrdRubkeSPj9Fg5Lx17HpuF527LxMSUcDhE4OtjVVi8Pg9TnvoGz2w+huJaaa4hL4oiXviqwFnsD141uM/FDgAjYvWYOzoOAPDcfwrclpOISK5cGrk/9thjmDNnDhITE9HY2IgNGzbgqaeewpdffolrrrmmV4/BkbtnnKhqwgd5JfjX3lLUnbOq27TBkVgwIRGzR8Vc9gA2dxBFEX/7Mt+5+Myy2UPx4NVD+v24hTXNmPVcDuwOERt/MwXjksL6/ZhERL7GI7vl77//fnz99deoqKiAwWDA6NGj8eijj/a62F0JRn1jtTmw5cgZbMgtxncnatD1rxserMEtY+Nx+8REDI7WeeS5HQ4RKz89grd2FAEA/vv6EfjFlYPc9viP/usA3s8rwZS0CLy7+Aq3PS4Rka8YsDl3V7HcB05JXQs+yCvBB3klOGM+O3WSlRyGBRMSMSrOgEidBuFBmn7P0dsdIv606SA25JYAAJ68KR13X5Hcr8c8X1lDK656Zhusdgfe+cUkTB0c6dbHJyLydix3crLZHcgpqMZ7e0qwNb8Kdkf3f3JBAEID1YgM0SIiRIPIEG3Hx8EaROo6/owI0SKq8+vB560oZ7M7sOyf+/HhvnIoBODp+ZmYPz7BI3+XFR8fxls7ijAmMRSbfjOlz/P4RES+iOVOF3XG3IZ/7S3F5wcrcMbchrpmKxwu/gQEqpWICOkqfA3qW9qx93Q9VAoBL9w+Bjd0HvzmCVWNbZj+9Fa0tTvwxs+zMGtkjMeei4jI27DcqVfsDhH1LVbUNllR02RBTZPF+XFtkxW1zRZUN1lR2/m1tnbHRR9Ho1Rg9Z1jMXuU0eOZn/riGF7NOYnhRh0+/+2VXr+IDxGRu/S2Qy+8Ygf5FaVCcO6GH4ZLH2gniiJarHbUNllR3WRBbZMFtc1W1LdYMX1IFNLjB+b88wdmDMI7u07jWGUjPjtYgbmZnttTQETki1ju1GuCICBYq0KwVoWkCOku4hIapMHi6YPw3JYCPL+lAHPSjV6xaA8Rkbfg/4jkk+6blorwYA1O1TRj4w9lUschIvIqLHfySSFaFX49Iw0A8OLXx2Gx2SVORETkPVju5LPumZyMGL0WZQ2t2LCnROo4REReg+VOPitArXQubbt66wm0Wjl6JyICWO7k4xZkJSIhLBDVjRas21kkdRwiIq/AciefplEp8PCsoQCAV3NOwtzWLnEiIiLpsdzJ5908Nh6Do0PQ0NKON7YXSh2HiEhyLHfyeUqFgKXXdIze/2/7qW6XvCUi8kcsd5KFn44yYlScHs1Wu/Na8kRE/orlTrKgUAhYNnsYAGDdjiKcMbdJnIiISDosd5KNmcOiMD45DBabA6u/OSF1HCIiybDcSTYE4ezofUNuMUrqWiROREQkDZY7ycrktAhMGxyJdruIF746LnUcIiJJsNxJdpZd2zF63/RjKU5UNUqchoho4LHcSXbGJIbimpExcIjA81s4eici/8NyJ1n6/eyhEATgs4MV+GgfLwlLRP6F5U6yNNyox50TkwAAv9uwD69sOwFRFCVORUQ0MFjuJFt/nZeOX0xLBQA8/WU+Htt0CDa7Q+JURESex3In2VIoBPz3DSPx+NyREATgvT3FWLw+D80Wm9TRiIg8iuVOsnfv1FS8evd4BKgV2JpfjQWv70QVV7AjIhljuZNfuHaUEe8tvgIRwRocKjPj5ld2oOAMT5MjInliuZPfGJsUho2/mYLUyGCUNbTi1jU7sONkjdSxiIjcjuVOfiU5Ihgbfz0FWclhaGyzYeE/9uDDH3mqHBHJC8ud/E5YsAZv/2ISrs+IRbtdxMPv78PLW3mqHBHJB8ud/FKAWom/3zEWv5w+CADwzOZ8PLbpIE+VIyJZYLmT31IoBDx23QisnDcKCgF4b08J7l+XhyaeKkdEPo7lTn7v55NT8No9WQhQK5BTUI3bXt2JMzxVjoh8GMudCMA1I2Pw/i8nIzJEgyMVZtz88vfIr+SpckTkm1juRJ0yE0Ox6TdTMSgqGOWmNsxfswM7TvBUOSLyPSx3onMkhgdh46+nYGJKOBotNix8cw82/lAqdSwiIpew3InOExqkwfr7J2JuZhza7SKWfrAfL319nKfKEZHPYLkTXUSAWokXF4zBAzPSAADPbSnAo/8+gHaeKkdEPoDlTtQDhULAH+cMx5M3pUMhAB/klSL782NSxyIiuiyWO9Fl3H1FMlbfOQ4A8I/vC7HzZK3EiYiILo3lTtQL12XE4o6JiQCAZf/cj8a2dokTERH1jOVO1Et/un4kEsMDUdbQiv/57KjUcYiIesRyJ+qlEK0Kz87PhCAAG3JL8M2xM1JHIiK6KJY7kQsmDYrA/VNTAQCP/vsg6putEiciIrqQS+WenZ2NCRMmQKfTITo6GjfddBPy8/M9lY3IKy27dhgGR4egutGCP390SOo4REQXcKncc3JysGTJEuzatQtbtmyBzWbD7Nmz0dzc7Kl8RF4nQK3Ec7dlQqkQ8OmBCnyyv1zqSERE3QhiP5bdqq6uRnR0NHJycjB9+vRefY/ZbIbBYIDJZIJer+/rUxNJ7rktBXjp6+MIDVLjPw9PR7Q+QOpIRCRzve3Qfs25m0wmAEB4eHiP21gsFpjN5m43Ijl46OrBSI/Xo6GlHY/++wCXpyUir9HnchdFEUuXLsW0adOQnp7e43bZ2dkwGAzOW2JiYl+fksirqJUKPHfbGGhUCmzNr8b7uSVSRyIiAtCPcn/wwQdx4MABvPfee5fcbvny5TCZTM5bSQn/AyT5GBqjw7LZQwEAf/30CErqWiRORETUx3J/6KGH8PHHH2Pr1q1ISEi45LZarRZ6vb7bjUhO7p82CBNTwtFstWPZP/fD4eDueSKSlkvlLooiHnzwQWzcuBHffPMNUlNTPZWLyGcoFQKe/VkmgjRK7C6sw5s7iqSORER+zqVyX7JkCd5++228++670Ol0qKysRGVlJVpbWz2Vj8gnJEUE4U/XjwAAPP3lMZyoapI4ERH5M5fKfc2aNTCZTJg5cyZiY2Odt/fff99T+Yh8xp0TkzBjaBQsNgd+/8E+2HjtdyKSiMu75S92W7RokYfiEfkOQRDwt1tHQx+gwv5SE17ZdlLqSETkp7i2PJEbGQ0BWDmv49TQl74+jkNlJokTEZE/YrkTudm8MXGYk26EzSFi6Qf70NZulzoSEfkZljuRmwmCgCdvSkdkiAYFZ5rw/JYCqSMRkZ9huRN5QESIFtm3jAYAvL79FHKL6iRORET+hOVO5CHXjIzB/PEJEEXg9x/sR7PFJnUkIvITLHciD/rL3JGIDw1EcV0LVn1+VOo4ROQnWO5EHqQPUOOZ+R2759/ZXYycgmqJExGRP2C5E3nYlMGRWDQlBQDw6L8OwNTSLm0gIpI9ljvRAHj0p8MxKDIYleY2rPjksNRxiEjmWO5EAyBQo8Szt2VCIQCbfizDl4cqpI5ERDLGcicaIOOSwvDAjDQAwGObDqGqsU3iREQkVyx3ogH0u1lDMNyoQ12zFdOf3opFb+7BW98XoqimWepoRCQjgiiK4kA+odlshsFggMlkgl6vH8inJvIKBWca8cv1eSiqbel2f3JEEGYOjcLMYdG4YlAEAjVKiRISkbfqbYey3IkkIIoiCs40YVt+FbblVyPvdB3a7Wd/FTUqBSalhmPmsGjMGBqFtKhgCIIgYWIi8gYsdyIf0mSxYceJGmwrqEZOfjXKGlq7fT0hLBAzh0VhxtBoTEmLQLBWJVFSIpISy53IR4miiJPVTdiWX41t+dXYU1gHq93h/LpGqcCE1DDMHBqNGcOiMCQ6hKN6Ij/BcieSiRarDTtP1naUfUEVSuq6j+rjDAGY0Tmqnzo4AroAtURJicjTWO5EMiSKIgprmrEtvxo5BdXYdaoWFtvZUb1KIWB8chhmDovGzGFRGG7UcVRPJCMsdyI/0Gq1Y1dhLXI6y77wvFPqYvRazOg8An/q4EgYAjmqJ/JlLHciP3S6trlzrr4KO0/Voq397KheqRAwLinUeQT+qDg9R/VEPoblTuTn2trt2FNY55yrP1XdfVQfpdNi+pAozBwWhSuHRCI0SCNRUiLqLZY7EXVTUtfiPNVux8katFjtzq8pBGBMYseofk66EUNidBImJaKesNyJqEcWmx15RfXYll+FnIJqFJxpcn5NIQCPzBqK31w1GEoFd9sTeROWOxH1WllDK3Lyq7H5cCVyCqoBAFMHR+D5BWMQrQuQOB0RdWG5E1Gf/GtvKf784SG0ttsRGaLBCwvGYtqQSKljERF636G8KhwRdTN/fAI+eWgahht1qGmy4p5/7Mazm/NhO2eVPCLybix3IrrA4OgQfLhkKu6clARRBFZvPYE71u5Chan18t9MRJJjuRPRRQWolVh1cwb+fsdYhGhVyC2qx3Uvbsc3x85IHY2ILoPlTkSXNDczDp8+NA0Z8QbUt7Tjvrfy8D+fHYHVxt30RN6K5U5El5USGYx//XoyFk1JAQCs3V6I217biZK6FmmDEdFFsdyJqFe0KiVW3DgKr90zHvoAFfaVNOC6l7bjy0MVUkcbcKIoosrchp0na/HenmJ8ffQM7I4BPfGI6JJ4KhwRuay0vgUPvfcjfixuAAAsnJyM5deNQIBa6dbnEUURJ6qasLuwDscqzTAEqmE0BCJWHwCjIQCxhgCEB2s8tkZ+W7sdhTXNOFXdjFPVTThV04yT1U0orG5Go8XWbduEsEAsmpKC2yYkQs/L7pKH8Dx3IvKodrsDz27Ox2vfngIAjIrTY/Wd45AaGdznx7TZHThSYcaewjrsKaxDblEd6lvaL/k9GpUCxnPK3mgI6Cz/QMR23hcRou1xtT1RFFFpbnMW+MnqjgI/Vd2MclMrevofUiEAieFBSI4IxoHSBjR05gzWKPGzrEQsmpKClH68FkQXw3InogGxNb8Kv/9gP+qarQjWKLHqlgzMGxPfq+9ta7djf0kDcovqsLuwDj+crkfzOWveA0CAWoGxiWHITAxFi9WGClMbKk1tqDC1oabJ0qvnUSkExOgDEKPXItYQiCidFrXNVpyqbkJhTXO3dfbPpw9QIS06BIMiQzAoKhhpUcEYFBWC5IggaFUdeyparXZ8uK8M//iuEMerOpbyFQTgJ8Ojcd/UVExOi/DpK/C1tdvRarUjNEjt038POWC5E9GAqTS14bfv/Yg9RXUAgNsnJOLxuaMQqOm+m76xrR17T9cjt6hjZL6/xATreYvj6AJUmJASjomp4ZiQEo6MeAM0qosfHmS1OXDG3IZKc0fZn+ks/Upzq/NNwBlzGy43Ha5UCEgKD3IW96DIzj+jghHhwm5/URTx3YkavPl9Eb45VuW8f7hRh/umpuLGMXFun7roba4Wqx2m1vZuN/N5f5pa22Fus12wXdeZEYMigzE3Mw43jolDWlTIgP89iOVORAPMZnfgpa+P4+9bT0AUgWExOqy6JQPVjW3YU1iPPUW1OFJuvqBoo3RaTDynzIcZdW69YI3N7kBNkxUVplbniP9MYxvCgjTOEk8KD+rxDURfnaxuwrodRfhnXila2zv2DIQHa3DXpCTcc0UyovXuX7O/ytyGg2UmHCoz41C5CSermtDQWd42Nx/wlx6vx7zMeNyQGYtYQ6BbH5t6xnInIkl8f6IGv9uwr8dd5knhQZiQEo5JqeGYkBqOlIggWe/qNbW04/28YqzbcRplDR0r/KmVAm4YHYf7pqYiI8Hg8mOKoohyUxsOlZlwuMzUUejlZlQ3XnqaQq0UYAhUQx+ohj5ADUNg95s+UHXOx+dsE6SGAOCro2fw0b5ybD9e4zw7QBCAiSnhuHFMHK5Lj0VYsMblvw/1HsudiCRT3WjBH/61H98WVGNwdAgmpoZjYmoEJqaEw2jwz6vM2ewO/OfIGfzju0Lkna533j8hJQz3TU3FNSNjoFJeuPdAFEWU1LXiUHlniZeZcLjcjLpm6wXbKgQgLSoEGfEGjIo3YIRRh4gQrbO0A9VKt7yRqm2y4PNDlfhkX7lzKgboOLZh+tAozBsTh1kjYhCsVfX7uag7ljsRSc5md1y0sPzdgdIGvPl9ET7ZX+7cXR4fGoiFU5IxfWgU8isbcbjcjEOdZW5us13wGCqFgCExOqTH6ZEeb0B6vAEjYnUI0gxsoZY1tOLT/eX4aF85jlSYnfcHqpWYNTIGN2bGYcbQKLdPe/grljsRkZc7Y27D27tO453dxRcdiXfRKBUYZtQhPb6zyOMMGGbUSXJw3qWcqGrEx/vK8fH+chTVnl290BCoxpx0I24cE4dJqRFuPabC37DciYh8RFu7HR/tK8NbO06jqKYZw4w6ZMQbkB6vx6g4A4bG6Hxq5CuKIg6UmvDx/nJ8sr8cVeccCxCt0+KG0XG4angU0uMMnKN3EcudiMgHiaIoqwMM7Q4Ruwtr8fG+cnx+sOKCKYb40ECkx+udxwmkxxkQpdNKlNb7sdyJiMirWGx2fFtQg88OlGNfSUO3XffnMuoDnHstMjqPJ4jRa2X1pqevPFbu3377LZ555hns3bsXFRUV2LRpE2666Sa3ByMiInkzt7XjcJkZh8s7Dhw8WGbCqZrmiy75GxmicR5v0HXsQXxooN8Vfm871OXDKpubm5GZmYl7770Xt956a79CEhGR/9IHqDE5LQKT0yKc9zVbbDhaYXYuxnO43ITjVU2oabJiW341tuVXO7cNC1KfU/LuySQIAoLUSgRplAjUqDr/VCL4nI+DNOd8Xa1EkFYJjVLhVW80XC73OXPmYM6cOZ7IQkREfi5Yq0JWSjiyUsKd97Va7ThWacahcjMOlZpwqNyEgjONqG9px/bjNRKmPUup6HhTEKjp/sbgZ+MTcPvEpAHP4/ETIi0WCyyWs0dKms3mS2xNRETUXaBGibFJYRibFOa8z2Kzo6CyCYfKTajt5QWEesPmENHaeaGcZosdre02tFjtaLF23Ndi7f5517UR7A4RjRbbBZcCvnJIpNuyucLj5Z6dnY0nnnjC009DRER+RKtSIiPB0Kfle93JZnegxflmoKP4W9u7yt8m2QV2PF7uy5cvx9KlS52fm81mJCYmevppiYiIPE6lVECvVEAfoJY6SjceL3etVgutlucsEhERDRTfWfKIiIiIesXlkXtTUxNOnDjh/LywsBD79u1DeHg4kpIG/ohAIiIi6s7lcs/Ly8NVV13l/LxrPn3hwoV466233BaMiIiI+sblcp85cyYGeMVaIiIicgHn3ImIiGSG5U5ERCQzLHciIiKZYbkTERHJDMudiIhIZjy+Qt35uo605wVkiIiIXNPVnZc7a23Ay72xsREAuL48ERFRHzU2NsJg6PmiOYI4wCetOxwOlJeXQ6fTue3C9l0XoykpKYFer3fLY1Lv8fWXFl9/afH1l5a/vf6iKKKxsRFxcXFQKHqeWR/wkbtCoUBCQoJHHluv1/vFP6634usvLb7+0uLrLy1/ev0vNWLvwgPqiIiIZIblTkREJDOyKHetVovHH3+c142XCF9/afH1lxZff2nx9b+4AT+gjoiIiDxLFiN3IiIiOovlTkREJDMsdyIiIplhuRMREckMy52IiEhmfL7cX3nlFaSmpiIgIADjx4/H9u3bpY7kF1asWAFBELrdjEaj1LFk7dtvv8XcuXMRFxcHQRDw4Ycfdvu6KIpYsWIF4uLiEBgYiJkzZ+Lw4cPShJWhy73+ixYtuuB34oorrpAmrMxkZ2djwoQJ0Ol0iI6Oxk033YT8/Pxu2/DnvzufLvf3338fDz/8MP70pz/hxx9/xJVXXok5c+aguLhY6mh+YdSoUaioqHDeDh48KHUkWWtubkZmZiZWr1590a8//fTTeO6557B69Wrk5ubCaDTimmuucV6sifrncq8/APz0pz/t9jvx+eefD2BC+crJycGSJUuwa9cubNmyBTabDbNnz0Zzc7NzG/78n0f0YRMnThQfeOCBbvcNHz5c/OMf/yhRIv/x+OOPi5mZmVLH8FsAxE2bNjk/dzgcotFoFJ966innfW1tbaLBYBBfffVVCRLK2/mvvyiK4sKFC8V58+ZJksffVFVViQDEnJwcURT5838xPjtyt1qt2Lt3L2bPnt3t/tmzZ2PHjh0SpfIvx48fR1xcHFJTU3H77bfj1KlTUkfyW4WFhaisrOz2+6DVajFjxgz+Pgygbdu2ITo6GkOHDsXixYtRVVUldSRZMplMAIDw8HAA/Pm/GJ8t95qaGtjtdsTExHS7PyYmBpWVlRKl8h+TJk3C+vXrsXnzZqxduxaVlZWYMmUKamtrpY7ml7p+5vn7IJ05c+bgnXfewTfffIP//d//RW5uLq6++mpYLBapo8mKKIpYunQppk2bhvT0dAD8+b+YAb/kq7udf014URTddp146tmcOXOcH2dkZGDy5MlIS0vDunXrsHTpUgmT+Tf+PkhnwYIFzo/T09ORlZWF5ORkfPbZZ7jlllskTCYvDz74IA4cOIDvvvvugq/x5/8snx25R0ZGQqlUXvCurKqq6oJ3b+R5wcHByMjIwPHjx6WO4pe6zlTg74P3iI2NRXJyMn8n3Oihhx7Cxx9/jK1btyIhIcF5P3/+L+Sz5a7RaDB+/Hhs2bKl2/1btmzBlClTJErlvywWC44ePYrY2Fipo/il1NRUGI3Gbr8PVqsVOTk5/H2QSG1tLUpKSvg74QaiKOLBBx/Exo0b8c033yA1NbXb1/nzfyGf3i2/dOlS3HPPPcjKysLkyZPx+uuvo7i4GA888IDU0WRv2bJlmDt3LpKSklBVVYUnn3wSZrMZCxculDqabDU1NeHEiRPOzwsLC7Fv3z6Eh4cjKSkJDz/8MFatWoUhQ4ZgyJAhWLVqFYKCgnDnnXdKmFo+LvX6h4eHY8WKFbj11lsRGxuLoqIiPPbYY4iMjMTNN98sYWp5WLJkCd5991189NFH0Ol0zhG6wWBAYGAgBEHgz//5JD1W3w1efvllMTk5WdRoNOK4ceOcp0aQZy1YsECMjY0V1Wq1GBcXJ95yyy3i4cOHpY4la1u3bhUBXHBbuHChKIodpwM9/vjjotFoFLVarTh9+nTx4MGD0oaWkUu9/i0tLeLs2bPFqKgoUa1Wi0lJSeLChQvF4uJiqWPLwsVedwDim2++6dyGP//d8XruREREMuOzc+5ERER0cSx3IiIimWG5ExERyQzLnYiISGZY7kRERDLDciciIpIZljsREZHMsNyJiIhkhuVOREQkMyx3IiIimWG5ExERycz/Byj2Cg9lH7ClAAAAAElFTkSuQmCC", 630 | "text/plain": [ 631 | "
" 632 | ] 633 | }, 634 | "metadata": {}, 635 | "output_type": "display_data" 636 | } 637 | ], 638 | "source": [ 639 | "learn.fit(1, cbs=prog)" 640 | ] 641 | }, 642 | { 643 | "cell_type": "code", 644 | "execution_count": null, 645 | "metadata": {}, 646 | "outputs": [], 647 | "source": [ 648 | "# learn.model.save_pretrained('models/sql_1ep_636')" 649 | ] 650 | }, 651 | { 652 | "cell_type": "code", 653 | "execution_count": null, 654 | "metadata": {}, 655 | "outputs": [ 656 | { 657 | "data": { 658 | "text/plain": [ 659 | "16.106410496" 660 | ] 661 | }, 662 | "execution_count": null, 663 | "metadata": {}, 664 | "output_type": "execute_result" 665 | } 666 | ], 667 | "source": [ 668 | "#SGD\n", 669 | "torch.cuda.max_memory_allocated()/1_000_000_000" 670 | ] 671 | }, 672 | { 673 | "cell_type": "code", 674 | "execution_count": null, 675 | "metadata": {}, 676 | "outputs": [ 677 | { 678 | "data": { 679 | "text/plain": [ 680 | "16.106410496" 681 | ] 682 | }, 683 | "execution_count": null, 684 | "metadata": {}, 685 | "output_type": "execute_result" 686 | } 687 | ], 688 | "source": [ 689 | "#adam\n", 690 | "torch.cuda.max_memory_allocated()/1_000_000_000" 691 | ] 692 | }, 693 | { 694 | "cell_type": "markdown", 695 | "metadata": {}, 696 | "source": [ 697 | "## Testing" 698 | ] 699 | }, 700 | { 701 | "cell_type": "code", 702 | "execution_count": null, 703 | "metadata": {}, 704 | "outputs": [], 705 | "source": [ 706 | "prompt = \"Context:\" + eval_dataset[0]['context'] + \"\\nQuestion:\" + eval_dataset[0]['question'] + \"\\nAnswer:\"\n", 707 | "tokenized_prompt = tokenizer(prompt, return_tensors='pt')['input_ids'].cuda()" 708 | ] 709 | }, 710 | { 711 | "cell_type": "code", 712 | "execution_count": null, 713 | "metadata": {}, 714 | "outputs": [], 715 | "source": [ 716 | "with torch.inference_mode():\n", 717 | " output = m.generate(tokenized_prompt, max_new_tokens=90)" 718 | ] 719 | }, 720 | { 721 | "cell_type": "code", 722 | "execution_count": null, 723 | "metadata": {}, 724 | "outputs": [ 725 | { 726 | "name": "stdout", 727 | "output_type": "stream", 728 | "text": [ 729 | "Context:CREATE TABLE table_name_95 (tournament VARCHAR, score VARCHAR, outcome VARCHAR, surface VARCHAR)\n", 730 | "Question:Which tournament has an Outcome of runner-up, a Surface of hard, and a Score of 6–4, 6–2?\n", 731 | "Answer:Answer:\n" 732 | ] 733 | } 734 | ], 735 | "source": [ 736 | "print(prompt + tokenizer.decode(output[0][len(tokenized_prompt[0]):], skip_special_tokens=True))" 737 | ] 738 | }, 739 | { 740 | "cell_type": "markdown", 741 | "metadata": {}, 742 | "source": [ 743 | "To produce better inference results, it would be better to use an instruct model and to ensure the dataset is in the format expected by the model." 744 | ] 745 | }, 746 | { 747 | "cell_type": "markdown", 748 | "metadata": {}, 749 | "source": [ 750 | "## fin -" 751 | ] 752 | }, 753 | { 754 | "cell_type": "code", 755 | "execution_count": null, 756 | "metadata": {}, 757 | "outputs": [], 758 | "source": [] 759 | } 760 | ], 761 | "metadata": { 762 | "kernelspec": { 763 | "display_name": "python3", 764 | "language": "python", 765 | "name": "python3" 766 | } 767 | }, 768 | "nbformat": 4, 769 | "nbformat_minor": 4 770 | } 771 | --------------------------------------------------------------------------------