├── .coveragerc ├── .deepsource.toml ├── .github ├── CODEOWNERS ├── FUNDING.yml ├── ISSUE_TEMPLATE │ ├── bug_report.md │ └── feature_request.md ├── labeler.yml ├── pull_request_template.md └── workflows │ ├── labeler.yml │ ├── latest-changes.yml.off │ ├── main.yml │ ├── markdown_links.yml │ ├── mkdocs_ci.yml.off │ ├── publish-to-pypi.yml │ ├── stale.yml │ └── welcome.yml ├── .gitignore ├── .pep8speaks.yml ├── .pre-commit-config.yaml ├── .readthedocs.yaml ├── CITATION.cff ├── CODE_OF_CONDUCT.md ├── CONTRIBUTING.md ├── LICENSE ├── MANIFEST.in ├── Makefile ├── README.md ├── docs ├── 404.md ├── CHANGELOG.md ├── CNAME ├── examples ├── gradsflow │ ├── autotasks │ │ ├── autotasks.md │ │ └── engine.md │ ├── callbacks.md │ ├── core.md │ ├── data.md │ ├── models │ │ ├── base.md │ │ ├── model.md │ │ ├── tracker.md │ │ └── utils.md │ ├── tuner.md │ └── utility.md ├── index.md ├── overrides │ └── main.html └── requirements.txt ├── examples ├── nbs │ ├── 01-ImageClassification.ipynb │ ├── 02-TextClassification.ipynb │ ├── 03-TextSummarization.ipynb │ ├── 04-RayDataset.ipynb │ ├── 05-model_fit.ipynb │ ├── 06-AutoModel_fit.ipynb │ ├── 2021-10-3-huggingface-training.ipynb │ └── Pix2Pix_explained_with_code.ipynb └── src │ ├── models │ └── hello_world.py │ └── tasks │ ├── image_classifier.py │ └── text_classification.py ├── gradsflow ├── __init__.py ├── autotasks │ ├── __init__.py │ ├── autoclassification │ │ ├── __init__.py │ │ ├── image.py │ │ └── text │ │ │ ├── __init__.py │ │ │ └── text.py │ ├── autosummarization.py │ ├── autotasks.py │ └── engine │ │ ├── __init__.py │ │ ├── autoclassifier.py │ │ ├── automodel.py │ │ └── backend.py ├── callbacks │ ├── __init__.py │ ├── base.py │ ├── comet.py │ ├── gpu.py │ ├── logger.py │ ├── progress.py │ ├── raytune.py │ ├── runner.py │ ├── tensorboard.py │ ├── training.py │ └── wandb.py ├── core │ ├── __init__.py │ ├── base.py │ └── metrics.py ├── data │ ├── __init__.py │ ├── autodata.py │ ├── base.py │ ├── common.py │ ├── image.py │ ├── 
mixins.py │ └── ray_dataset.py ├── models │ ├── __init__.py │ ├── base.py │ ├── constants.py │ ├── exceptions.py │ ├── model.py │ ├── tracker.py │ └── utils.py ├── tuner │ ├── __init__.py │ ├── automodel.py │ └── tuner.py └── utility │ ├── __init__.py │ ├── common.py │ ├── data.py │ └── imports.py ├── mkdocs.yml ├── pyproject.toml ├── setup.cfg ├── setup.py ├── sonar-project.properties └── tests ├── __init__.py ├── __main__.py ├── autotasks ├── test_autotasks.py ├── test_autotrainer.py ├── test_core_automodel.py ├── test_image.py ├── test_summarization.py └── test_text.py ├── callbacks ├── test_logger.py ├── test_runner.py └── test_wandb.py ├── conftest.py ├── core └── test_base.py ├── data ├── __init__.py ├── test_autodata.py ├── test_common.py ├── test_image_data.py ├── test_mixins.py └── test_ray_dataset.py ├── dummies.py ├── models ├── test_exceptions.py ├── test_model.py ├── test_tracker.py └── test_utils.py ├── tuner ├── test_automodel.py └── test_tuner.py └── utility ├── __init__.py ├── test_common.py └── test_data.py /.coveragerc: -------------------------------------------------------------------------------- 1 | [report] 2 | exclude_lines = 3 | pragma: no cover 4 | def __repr__ 5 | if _environ.get("GF_CI") 6 | if self.debug: 7 | if settings.DEBUG 8 | raise AssertionError 9 | raise NotImplementedError 10 | if 0: 11 | if __name__ == .__main__.: 12 | 13 | omit = 14 | tests/__init__.py 15 | -------------------------------------------------------------------------------- /.deepsource.toml: -------------------------------------------------------------------------------- 1 | version = 1 2 | 3 | test_patterns = ["tests/**"] 4 | 5 | exclude_patterns = [ 6 | "tests/**", 7 | "examples/**" 8 | ] 9 | 10 | [[analyzers]] 11 | name = "python" 12 | enabled = true 13 | 14 | [analyzers.meta] 15 | runtime_version = "3.x.x" 16 | 17 | [[transformers]] 18 | name = "black" 19 | enabled = true 20 | 21 | [[transformers]] 22 | name = "isort" 23 | enabled = true 24 | 
-------------------------------------------------------------------------------- /.github/CODEOWNERS: -------------------------------------------------------------------------------- 1 | # Lines starting with '#' are comments. 2 | # Each line is a file pattern followed by one or more owners. 3 | 4 | # More details are here: https://help.github.com/articles/about-codeowners/ 5 | 6 | # The '*' pattern is global owners. 7 | 8 | # Order is important. The last matching pattern has the most precedence. 9 | # The folders are ordered as follows: 10 | 11 | # In each subsection folders are ordered first by depth, then alphabetically. 12 | # This should make it easy to add new rules without breaking existing ones. 13 | 14 | # Global rule: 15 | * @aniketmaurya 16 | 17 | # tests 18 | /tests/** @aniketmaurya 19 | -------------------------------------------------------------------------------- /.github/FUNDING.yml: -------------------------------------------------------------------------------- 1 | # These are supported funding model platforms 2 | 3 | github: [aniketmaurya] 4 | patreon: # Replace with a single Patreon username 5 | open_collective: # Replace with a single Open Collective username 6 | ko_fi: aniketmaurya 7 | tidelift: # Replace with a single Tidelift platform-name/package-name e.g., npm/babel 8 | community_bridge: # Replace with a single Community Bridge project-name e.g., cloud-foundry 9 | liberapay: # Replace with a single Liberapay username 10 | issuehunt: # gradsflow 11 | otechie: # Replace with a single Otechie username 12 | custom: # Replace with up to 4 custom sponsorship URLs e.g., ['link1', 'link2'] 13 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/bug_report.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: 🐛 Bug report 3 | about: Create a report to help us improve 4 | title: '' 5 | labels: '' 6 | assignees: '' 7 | 8 | --- 9 | 10 | 11 | #### 
Bug description 12 | 13 | 14 | #### Expected result 15 | 16 | 17 | #### Actual result 18 | 19 | 20 | #### Steps to reproduce 21 | 22 | 23 | 1. 24 | 2. 25 | 3. 26 | #### Context 27 | 28 | 29 | 30 | #### Your Environment 31 | 32 | 33 | * Version used: 34 | * Operating System and version: 35 | * Link to your fork: 36 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/feature_request.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: 🚀 Feature request 3 | about: Suggest an idea for this project 4 | title: '' 5 | labels: '' 6 | assignees: '' 7 | 8 | --- 9 | 10 | #### Is your feature request related to a problem? Please describe. 11 | A clear and concise description of what the problem is. Ex. I'm always frustrated when [...] 12 | 13 | #### Describe the solution you'd like 14 | A clear and concise description of what you want to happen. 15 | 16 | #### Describe alternatives you've considered 17 | A clear and concise description of any alternative solutions or features you've considered. 18 | 19 | #### Additional context 20 | Add any other context or screenshots about the feature request here. 
21 | -------------------------------------------------------------------------------- /.github/labeler.yml: -------------------------------------------------------------------------------- 1 | # Add 'docs' to any changes within 'docs' folder or any subfolders 2 | documentation: 3 | - docs/**/* 4 | 5 | example: 6 | - examples/**/* 7 | 8 | test: 9 | - tests/**/* 10 | 11 | CI: 12 | - .github/**/* 13 | - "*.yaml" 14 | - "*.yml" 15 | -------------------------------------------------------------------------------- /.github/pull_request_template.md: -------------------------------------------------------------------------------- 1 | #### Changes 2 | 3 | 4 | 5 | 6 | Fixes # (issue) 7 | 8 | 9 | #### Type of change 10 | 11 | - [ ] 📚 Documentation Update 12 | - [ ] 🧪 Tests Cases 13 | - [ ] 🐞 Bug fix (non-breaking change which fixes an issue) 14 | - [ ] 🔬 New feature (non-breaking change which adds functionality) 15 | - [ ] 🚨 Breaking change (fix or feature that would cause existing functionality to not work as expected) 16 | - [ ] 📝 This change requires a documentation update 17 | 18 | 19 | #### Checklist 20 | 21 | - [ ] My code follows the style guidelines of this project 22 | - [ ] I have performed a self-review of my own code 23 | - [ ] I have commented my code, particularly in hard-to-understand areas 24 | - [ ] I have made corresponding changes to the documentation 25 | - [ ] My changes generate no new warnings 26 | - [ ] Did you update CHANGELOG (docs/CHANGELOG.md) in case of a major change? 
27 | -------------------------------------------------------------------------------- /.github/workflows/labeler.yml: -------------------------------------------------------------------------------- 1 | name: "Pull Request Labeler" 2 | on: 3 | - pull_request_target 4 | 5 | jobs: 6 | triage: 7 | runs-on: ubuntu-latest 8 | steps: 9 | - uses: actions/labeler@v3 10 | with: 11 | repo-token: "${{ secrets.GITHUB_TOKEN }}" 12 | -------------------------------------------------------------------------------- /.github/workflows/latest-changes.yml.off: -------------------------------------------------------------------------------- 1 | name: Latest Changes 2 | 3 | on: 4 | pull_request_target: 5 | branches: 6 | - main 7 | types: 8 | - closed 9 | # For manually triggering it 10 | workflow_dispatch: 11 | inputs: 12 | number: 13 | description: PR number 14 | required: true 15 | 16 | jobs: 17 | latest-changes: 18 | runs-on: ubuntu-latest 19 | steps: 20 | - uses: actions/checkout@v2 21 | with: 22 | token: ${{ secrets.ACTIONS_TOKEN }} 23 | - uses: docker://tiangolo/latest-changes:0.0.3 24 | with: 25 | token: ${{ secrets.GITHUB_TOKEN }} 26 | latest_changes_file: docs/CHANGELOG.md 27 | latest_changes_header: '## 0.0.3\n' 28 | debug_logs: true 29 | -------------------------------------------------------------------------------- /.github/workflows/main.yml: -------------------------------------------------------------------------------- 1 | name: pytest 2 | on: 3 | push: 4 | branches: [ main ] 5 | pull_request: 6 | branches: [ main ] 7 | 8 | 9 | jobs: 10 | pytest: 11 | runs-on: ${{ matrix.os }} 12 | timeout-minutes: 15 13 | strategy: 14 | matrix: 15 | os: [ ubuntu-latest, macos-latest ] 16 | python-version: ["3.9", "3.10"] 17 | include: 18 | - os: ubuntu-latest 19 | path: ~/.cache/pip 20 | - os: macos-latest 21 | path: ~/Library/Caches/pip 22 | env: 23 | OS: ${{ matrix.os }} 24 | PYTHON: '3.10' 25 | 26 | 27 | steps: 28 | - uses: actions/checkout@v2 29 | with: 30 | fetch-depth: 0 # 
Shallow clones should be disabled for a better relevancy of analysis 31 | 32 | - name: Set up Python ${{ matrix.python-version }} 33 | uses: actions/setup-python@v2 34 | with: 35 | python-version: ${{ matrix.python-version }} 36 | 37 | - name: Cache pip 38 | uses: actions/cache@v2 39 | with: 40 | path: ${{ matrix.path }} 41 | key: ${{ runner.os }}-pip-${{ hashFiles('**/requirements.txt') }} 42 | restore-keys: | 43 | ${{ runner.os }}-pip- 44 | ${{ runner.os }}- 45 | 46 | - name: Install dependencies 47 | run: | 48 | python --version 49 | pip --version 50 | python -m pip install --upgrade pip build 51 | pip install -e '.[dev,test]' 52 | pip list 53 | shell: bash 54 | 55 | - name: Prepare Test 56 | run: | 57 | python tests # download test data 58 | 59 | - name: Run Test with Coverage 60 | run: | 61 | coverage erase 62 | coverage run -m pytest 63 | 64 | - name: Generate Coverage Report 65 | run: | 66 | coverage report -m -i 67 | coverage xml -i 68 | 69 | - name: Upload Coverage to Codecov 70 | if: runner.os != 'macOS' 71 | uses: codecov/codecov-action@v1 72 | with: 73 | token: ${{ secrets.CODECOV_TOKEN }} 74 | file: ./coverage.xml 75 | flags: unittests 76 | env_vars: OS,PYTHON 77 | name: codecov-umbrella 78 | fail_ci_if_error: false 79 | 80 | # - name: SonarCloud Scan 81 | # if: runner.os != 'macOS' && env.SONAR_TOKEN != null 82 | # uses: SonarSource/sonarcloud-github-action@master 83 | # env: 84 | # GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} # Needed to get PR information, if any 85 | # SONAR_TOKEN: ${{ secrets.SONAR_TOKEN }} 86 | -------------------------------------------------------------------------------- /.github/workflows/markdown_links.yml: -------------------------------------------------------------------------------- 1 | name: Check Markdown links 2 | 3 | on: push 4 | 5 | jobs: 6 | markdown-link-check: 7 | runs-on: ubuntu-latest 8 | steps: 9 | - uses: actions/checkout@master 10 | - uses: gaurav-nelson/github-action-markdown-link-check@v1 11 | 
-------------------------------------------------------------------------------- /.github/workflows/mkdocs_ci.yml.off: -------------------------------------------------------------------------------- 1 | name: MkDocs 2 | on: 3 | push: 4 | branches: 5 | - master 6 | - main 7 | jobs: 8 | deploy: 9 | runs-on: ubuntu-latest 10 | steps: 11 | - uses: actions/checkout@v2 12 | - uses: actions/setup-python@v2 13 | with: 14 | python-version: 3.x 15 | - run: pip install -r docs/requirements.txt 16 | - run: mkdocs gh-deploy --force 17 | -------------------------------------------------------------------------------- /.github/workflows/publish-to-pypi.yml: -------------------------------------------------------------------------------- 1 | name: Publish Python 🐍 distributions 📦 to PyPI and TestPyPI 2 | 3 | on: push 4 | 5 | jobs: 6 | build-n-publish: 7 | name: Build and publish Python 🐍 distributions 📦 to PyPI and TestPyPI 8 | runs-on: ubuntu-latest 9 | 10 | steps: 11 | - uses: actions/checkout@master 12 | - name: Set up Python 3.10 13 | uses: actions/setup-python@v1 14 | with: 15 | python-version: "3.10" 16 | 17 | - name: Install pypa/build 18 | run: >- 19 | python -m 20 | pip install 21 | build 22 | --user 23 | - name: Build a binary wheel and a source tarball 24 | run: >- 25 | python -m build 26 | 27 | # - name: Publish distribution 📦 to Test PyPI 28 | # if: startsWith(github.ref, 'refs/tags') 29 | # uses: pypa/gh-action-pypi-publish@master 30 | # with: 31 | # password: ${{ secrets.TEST_PYPI_API_TOKEN }} 32 | # repository_url: https://test.pypi.org/legacy/ 33 | 34 | - name: Publish distribution 📦 to PyPI 35 | if: startsWith(github.ref, 'refs/tags') 36 | uses: pypa/gh-action-pypi-publish@master 37 | with: 38 | password: ${{ secrets.PYPI_API_TOKEN }} 39 | -------------------------------------------------------------------------------- /.github/workflows/stale.yml: -------------------------------------------------------------------------------- 1 | name: Mark stale issues and 
pull requests 2 | 3 | on: 4 | schedule: 5 | - cron: "30 1 * * *" 6 | 7 | jobs: 8 | stale: 9 | 10 | runs-on: ubuntu-latest 11 | 12 | steps: 13 | - uses: actions/stale@v3 14 | with: 15 | repo-token: ${{ secrets.GITHUB_TOKEN }} 16 | stale-issue-message: 'Stale issue message' 17 | stale-pr-message: 'Stale pull request message' 18 | stale-issue-label: 'no-issue-activity' 19 | stale-pr-label: 'no-pr-activity' 20 | -------------------------------------------------------------------------------- /.github/workflows/welcome.yml: -------------------------------------------------------------------------------- 1 | name: Greet New Contributors 2 | 3 | on: [pull_request_target, issues] 4 | 5 | jobs: 6 | greeting: 7 | runs-on: ubuntu-latest 8 | steps: 9 | - uses: actions/first-interaction@v1 10 | with: 11 | repo-token: ${{ secrets.GITHUB_TOKEN }} 12 | issue-message: "👋 @${{github.actor}}! Thank you for opening your first issue in this repo. We are so happy that you have decided to contribute and value your contribution. Please read these materials before proceeding: [Contributing Guide](https://github.com/gradsflow/gradsflow/blob/master/CONTRIBUTING.md) and [Code of Conduct](https://github.com/gradsflow/gradsflow/blob/master/CODE_OF_CONDUCT.md)." 13 | pr-message: "👋 @${{github.actor}}! Thank you for opening your first pull request in this repo. We are so happy that you have decided to contribute and value your contribution. Please read these materials before proceeding: [Contributing Guide](https://github.com/gradsflow/gradsflow/blob/master/CONTRIBUTING.md) and [Code of Conduct](https://github.com/gradsflow/gradsflow/blob/master/CODE_OF_CONDUCT.md)." 
14 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 
92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .env 106 | .venv 107 | env/ 108 | venv/ 109 | ENV/ 110 | env.bak/ 111 | venv.bak/ 112 | 113 | # Spyder project settings 114 | .spyderproject 115 | .spyproject 116 | 117 | # Rope project settings 118 | .ropeproject 119 | 120 | # mkdocs documentation 121 | /site 122 | 123 | # mypy 124 | .mypy_cache/ 125 | .dmypy.json 126 | dmypy.json 127 | 128 | # Pyre type checker 129 | .pyre/ 130 | 131 | .idea/ 132 | lightning_logs/ 133 | *.zip 134 | *.jpg 135 | *.jpeg 136 | *.png 137 | *.gif 138 | 139 | /data 140 | .vscode/ 141 | -------------------------------------------------------------------------------- /.pep8speaks.yml: -------------------------------------------------------------------------------- 1 | # File : .pep8speaks.yml 2 | 3 | scanner: 4 | diff_only: True # If False, the entire file touched by the Pull Request is scanned for errors. If True, only the diff is scanned. 5 | linter: pycodestyle # Other option is flake8 6 | 7 | pycodestyle: # Same as scanner.linter value. Other option is flake8 8 | max-line-length: 100 # Default is 79 in PEP 8 9 | ignore: # Errors and warnings to ignore 10 | - W504 # line break after binary operator 11 | - E402 # module level import not at top of file 12 | - E731 # do not assign a lambda expression, use a def 13 | - C406 # Unnecessary list literal - rewrite as a dict literal. 14 | - E741 # ambiguous variable name 15 | 16 | no_blank_comment: True # If True, no comment is made on PR without any errors. 17 | descending_issues_order: False # If True, PEP 8 issues in message will be displayed in descending order of line numbers in the file 18 | 19 | message: # Customize the comment made by the bot 20 | opened: # Messages when a new PR is submitted 21 | header: "Hello @{name}! 
Thanks for opening this PR. " 22 | # The keyword {name} is converted into the author's username 23 | footer: "Do see the [Hitchhiker's guide to code style](https://goo.gl/hqbW4r)" 24 | # The messages can be written as they would over GitHub 25 | updated: # Messages when new commits are added to the PR 26 | header: "Hello @{name}! Thanks for updating this PR. " 27 | footer: "" # Why comment the link to the style guide every time? :) 28 | no_errors: "There are currently no PEP 8 issues detected in this Pull Request. Cheers! :beers: " 29 | -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | # See https://pre-commit.com for more information 2 | # See https://pre-commit.com/hooks.html for more hooks 3 | default_language_version: 4 | python: python3.10 5 | 6 | ci: 7 | autofix_prs: true 8 | autoupdate_commit_msg: '[pre-commit.ci] pre-commit suggestions' 9 | autoupdate_schedule: quarterly 10 | # submodules: true 11 | 12 | repos: 13 | - repo: https://github.com/pre-commit/pre-commit-hooks 14 | rev: v5.0.0 15 | hooks: 16 | - id: trailing-whitespace 17 | - id: end-of-file-fixer 18 | - id: check-added-large-files 19 | 20 | - repo: https://github.com/psf/black 21 | rev: 24.8.0 22 | hooks: 23 | - id: black 24 | name: "Black: The uncompromising Python code formatter" 25 | 26 | - repo: https://github.com/PyCQA/isort 27 | rev: 5.13.2 28 | hooks: 29 | - id: isort 30 | name: "Sort Imports" 31 | args: [ "--profile black" ] 32 | 33 | - repo: https://github.com/codespell-project/codespell 34 | rev: v2.3.0 35 | hooks: 36 | - id: codespell 37 | args: 38 | - --ignore-words-list 39 | - "ans,hist" 40 | - --skip 41 | - "*.bib,*.ipynb" 42 | 43 | - repo: https://github.com/asottile/pyupgrade 44 | rev: v3.17.0 45 | hooks: 46 | - id: pyupgrade 47 | args: [ --py36-plus ] 48 | 49 | - repo: https://github.com/PyCQA/bandit 50 | rev: 1.7.10 51 | hooks: 52 | -
id: bandit 53 | language_version: python3 54 | exclude: tests/ 55 | args: 56 | - -s 57 | - "B404,B602,B603,B607,B101" 58 | 59 | - repo: local 60 | hooks: 61 | - id: clean 62 | name: clean 63 | entry: make 64 | args: [ "clean" ] 65 | language: system 66 | pass_filenames: false 67 | 68 | 69 | - repo: https://github.com/kynan/nbstripout 70 | rev: 0.7.1 71 | hooks: 72 | - id: nbstripout 73 | args: 74 | - --max-size=500k 75 | - --drop-empty-cells 76 | -------------------------------------------------------------------------------- /.readthedocs.yaml: -------------------------------------------------------------------------------- 1 | # .readthedocs.yaml 2 | # Read the Docs configuration file 3 | # See https://docs.readthedocs.io/en/stable/config-file/v2.html for details 4 | 5 | # Required 6 | version: 2 7 | 8 | mkdocs: 9 | configuration: mkdocs.yml 10 | 11 | # Optionally set the version of Python and requirements required to build your docs. NOTE(review): `python.version` is deprecated in the RTD v2 config (builds now require `build.os` + `build.tools.python`), and 3.8 disagrees with the 3.9/3.10 used elsewhere in this repo — confirm and migrate. 12 | python: 13 | version: 3.8 14 | install: 15 | - requirements: docs/requirements.txt 16 | -------------------------------------------------------------------------------- /CITATION.cff: -------------------------------------------------------------------------------- 1 | cff-version: 1.2.0 2 | message: "If you use GradsFlow, please cite it as below."
3 | authors: 4 | - family-names: "Maurya" 5 | given-names: "Aniket" 6 | orcid: "https://orcid.org/0000-0002-0202-4810" 7 | title: "gradsflow" 8 | doi: 10.5281/zenodo.5245150 9 | date-released: 2021-08-24 10 | url: "https://github.com/gradsflow/gradsflow" 11 | -------------------------------------------------------------------------------- /CODE_OF_CONDUCT.md: -------------------------------------------------------------------------------- 1 | # Contributor Covenant Code of Conduct 2 | 3 | ## Our Pledge 4 | 5 | We as members, contributors, and leaders pledge to make participation in our 6 | community a harassment-free experience for everyone, regardless of age, body 7 | size, visible or invisible disability, ethnicity, sex characteristics, gender 8 | identity and expression, level of experience, education, socio-economic status, 9 | nationality, personal appearance, race, religion, or sexual identity 10 | and orientation. 11 | 12 | We pledge to act and interact in ways that contribute to an open, welcoming, 13 | diverse, inclusive, and healthy community. 
14 | 15 | ## Our Standards 16 | 17 | Examples of behavior that contributes to a positive environment for our 18 | community include: 19 | 20 | * Demonstrating empathy and kindness toward other people 21 | * Being respectful of differing opinions, viewpoints, and experiences 22 | * Giving and gracefully accepting constructive feedback 23 | * Accepting responsibility and apologizing to those affected by our mistakes, 24 | and learning from the experience 25 | * Focusing on what is best not just for us as individuals, but for the 26 | overall community 27 | 28 | Examples of unacceptable behavior include: 29 | 30 | * The use of sexualized language or imagery, and sexual attention or 31 | advances of any kind 32 | * Trolling, insulting or derogatory comments, and personal or political attacks 33 | * Public or private harassment 34 | * Publishing others' private information, such as a physical or email 35 | address, without their explicit permission 36 | * Other conduct which could reasonably be considered inappropriate in a 37 | professional setting 38 | 39 | ## Enforcement Responsibilities 40 | 41 | Community leaders are responsible for clarifying and enforcing our standards of 42 | acceptable behavior and will take appropriate and fair corrective action in 43 | response to any behavior that they deem inappropriate, threatening, offensive, 44 | or harmful. 45 | 46 | Community leaders have the right and responsibility to remove, edit, or reject 47 | comments, commits, code, wiki edits, issues, and other contributions that are 48 | not aligned to this Code of Conduct, and will communicate reasons for moderation 49 | decisions when appropriate. 50 | 51 | ## Scope 52 | 53 | This Code of Conduct applies within all community spaces, and also applies when 54 | an individual is officially representing the community in public spaces. 
55 | Examples of representing our community include using an official e-mail address, 56 | posting via an official social media account, or acting as an appointed 57 | representative at an online or offline event. 58 | 59 | ## Enforcement 60 | 61 | Instances of abusive, harassing, or otherwise unacceptable behavior may be 62 | reported to the community leaders responsible for enforcement at 63 | hello@gradsflow.com. 64 | All complaints will be reviewed and investigated promptly and fairly. 65 | 66 | All community leaders are obligated to respect the privacy and security of the 67 | reporter of any incident. 68 | 69 | ## Enforcement Guidelines 70 | 71 | Community leaders will follow these Community Impact Guidelines in determining 72 | the consequences for any action they deem in violation of this Code of Conduct: 73 | 74 | ### 1. Correction 75 | 76 | **Community Impact**: Use of inappropriate language or other behavior deemed 77 | unprofessional or unwelcome in the community. 78 | 79 | **Consequence**: A private, written warning from community leaders, providing 80 | clarity around the nature of the violation and an explanation of why the 81 | behavior was inappropriate. A public apology may be requested. 82 | 83 | ### 2. Warning 84 | 85 | **Community Impact**: A violation through a single incident or series 86 | of actions. 87 | 88 | **Consequence**: A warning with consequences for continued behavior. No 89 | interaction with the people involved, including unsolicited interaction with 90 | those enforcing the Code of Conduct, for a specified period of time. This 91 | includes avoiding interactions in community spaces as well as external channels 92 | like social media. Violating these terms may lead to a temporary or 93 | permanent ban. 94 | 95 | ### 3. Temporary Ban 96 | 97 | **Community Impact**: A serious violation of community standards, including 98 | sustained inappropriate behavior. 
99 | 100 | **Consequence**: A temporary ban from any sort of interaction or public 101 | communication with the community for a specified period of time. No public or 102 | private interaction with the people involved, including unsolicited interaction 103 | with those enforcing the Code of Conduct, is allowed during this period. 104 | Violating these terms may lead to a permanent ban. 105 | 106 | ### 4. Permanent Ban 107 | 108 | **Community Impact**: Demonstrating a pattern of violation of community 109 | standards, including sustained inappropriate behavior, harassment of an 110 | individual, or aggression toward or disparagement of classes of individuals. 111 | 112 | **Consequence**: A permanent ban from any sort of public interaction within 113 | the community. 114 | 115 | ## Attribution 116 | 117 | This Code of Conduct is adapted from the [Contributor Covenant][homepage], 118 | version 2.0, available at 119 | https://www.contributor-covenant.org/version/2/0/code_of_conduct.html. 120 | 121 | Community Impact Guidelines were inspired by [Mozilla's code of conduct 122 | enforcement ladder](https://github.com/mozilla/diversity). 123 | 124 | [homepage]: https://www.contributor-covenant.org 125 | 126 | For answers to common questions about this code of conduct, see the FAQ at 127 | https://www.contributor-covenant.org/faq. Translations are available at 128 | https://www.contributor-covenant.org/translations. 129 | -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # Contributing guidelines 2 | 3 | 👍🎉 First off, thanks for taking the time to contribute! 🎉👍 4 | 5 | The following is a set of guidelines for contributing to Gradsflow and its packages, 6 | which are hosted in the GradsFlow Organization on GitHub. 7 | These are mostly guidelines, not rules. 
8 | Use your best judgment, and feel free to propose changes to this document in a pull request. 9 | 10 | We welcome any kind of contribution to our software, from a simple comment or question to a 11 | full-fledged [pull request](https://help.github.com/articles/about-pull-requests/). 12 | Please read and follow our [Code of Conduct](CODE_OF_CONDUCT.md). 13 | 14 | A contribution can be one of the following cases: 15 | 16 | 1. you have a question; 17 | 1. you think you may have found a bug (including unexpected behavior); 18 | 1. you want to make some kind of change to the code base 19 | (e.g. to fix a bug, to add a new feature, to update documentation); 20 | 1. you want to make a new release of the code base. 21 | 22 | The sections below outline the steps in each case. 23 | 24 | ## You have a question 25 | 26 | 1. Use the search functionality [here](https://github.com/gradsflow/gradsflow/issues) to see if someone already 27 | filed the same issue or check out [Docs](https://docs.gradsflow.com). 28 | 2. If your issue search did not yield any relevant results, make a new issue. 29 | 3. Apply the "Question" label; apply other labels when relevant. 30 | 4. You can join our Slack group as well. 31 | 32 | ## You think you may have found a bug 33 | 34 | 1. use the search functionality [here](https://github.com/gradsflow/gradsflow/issues) to see if someone already filed the same issue; 35 | 1. if your issue search did not yield any relevant results, make a new issue, making sure to provide enough information to the rest of the community to understand the cause and context of the problem. Depending on the issue, you may want to include: 36 | - the [SHA hashcode](https://help.github.com/articles/autolinked-references-and-urls/#commit-shas) of the commit that is causing your problem; 37 | - some identifying information (name and version number) for dependencies you're using; 38 | - information about the operating system; 39 | 1.
apply relevant labels to the newly created issue. 40 | 41 | ## You want to make some kind of change to the code base 42 | 43 | 1. (**important**) announce your plan to the rest of the community *before you start working*. This announcement should be in the form of a (new) issue; 44 | 1. (**important**) wait until some kind of consensus is reached about your idea being a good idea; 45 | 1. if needed, fork the repository to your own Github profile and create your own feature branch off of the latest master commit. While working on your feature branch, make sure to stay up to date with the master branch by pulling in changes, possibly from the 'upstream' repository (follow the instructions [here](https://help.github.com/articles/configuring-a-remote-for-a-fork/) and [here](https://help.github.com/articles/syncing-a-fork/)); 46 | 1. make sure the existing tests still work by running ``pytest``; 47 | 1. add your own tests (if necessary); 48 | 1. update or expand the documentation; 49 | 1. update the `docs/CHANGELOG.md` file with change; 50 | 1. push your feature branch to (your fork of) the https://github.com/gradsflow/gradsflow repository on GitHub; 51 | 1. create the pull request, e.g. following the instructions [here](https://help.github.com/articles/creating-a-pull-request/). 52 | 53 | In case you feel like you've made a valuable contribution, but you don't know how to write or run tests for it, or how to generate the documentation: don't let this discourage you from making the pull request; we can help you! Just go ahead and submit the pull request, but keep in mind that you might be asked to append additional commits to your pull request. 
54 | -------------------------------------------------------------------------------- /MANIFEST.in: -------------------------------------------------------------------------------- 1 | include pyproject.toml 2 | include setup.cfg 3 | include LICENSE 4 | include CONTRIBUTING.md 5 | include README.md 6 | recursive-exclude ** __pycache__/ 7 | prune tests/ 8 | -------------------------------------------------------------------------------- /Makefile: -------------------------------------------------------------------------------- 1 | build-docs: 2 | cp README.md docs/index.md 3 | 4 | docsserve: 5 | mkdocs serve --dirtyreload --livereload 6 | 7 | test: 8 | python tests/__init__.py 9 | pytest 10 | 11 | coverage: ## Run tests with coverage 12 | coverage erase 13 | coverage run -m pytest 14 | coverage report -m 15 | coverage xml 16 | 17 | clean: 18 | rm -rf dist 19 | find . -type f -name "*.DS_Store" -ls -delete 20 | find . | grep -E "(__pycache__|\.pyc|\.pyo)" | xargs rm -rf 21 | find . | grep -E ".pytest_cache" | xargs rm -rf 22 | find . | grep -E ".ipynb_checkpoints" | xargs rm -rf 23 | rm -f .coverage 24 | 25 | style: 26 | black . 27 | isort --profile black . 28 | 29 | build: clean 30 | python -m build 31 | 32 | test_pypi: build 33 | twine upload -r testpypi dist/* 34 | 35 | pypi: build 36 | twine upload dist/* 37 | 38 | push: 39 | git push && git push --tags 40 | 41 | install: style clean 42 | flit install --deps none 43 | -------------------------------------------------------------------------------- /docs/404.md: -------------------------------------------------------------------------------- 1 | # Oops! The page you are looking for does not exist. 
2 | -------------------------------------------------------------------------------- /docs/CNAME: -------------------------------------------------------------------------------- 1 | docs.gradsflow.com 2 | -------------------------------------------------------------------------------- /docs/examples: -------------------------------------------------------------------------------- 1 | ../examples -------------------------------------------------------------------------------- /docs/gradsflow/autotasks/autotasks.md: -------------------------------------------------------------------------------- 1 | ::: gradsflow.autotasks.autotasks 2 | 3 | --- 4 | 5 | ::: gradsflow.autotasks.autoclassification 6 | 7 | --- 8 | 9 | ::: gradsflow.autotasks.autosummarization 10 | -------------------------------------------------------------------------------- /docs/gradsflow/autotasks/engine.md: -------------------------------------------------------------------------------- 1 | ::: gradsflow.autotasks.engine.backend 2 | 3 | --- 4 | 5 | ::: gradsflow.autotasks.engine.automodel 6 | 7 | --- 8 | 9 | ::: gradsflow.autotasks.engine.autoclassifier 10 | -------------------------------------------------------------------------------- /docs/gradsflow/callbacks.md: -------------------------------------------------------------------------------- 1 | ::: gradsflow.callbacks.wandb.WandbCallback 2 | 3 | --- 4 | 5 | ::: gradsflow.callbacks.gpu.EmissionTrackerCallback 6 | 7 | --- 8 | 9 | ::: gradsflow.callbacks.comet.CometCallback 10 | 11 | --- 12 | -------------------------------------------------------------------------------- /docs/gradsflow/core.md: -------------------------------------------------------------------------------- 1 | Core Building blocks 2 | 3 | ::: gradsflow.core.base.BaseAutoModel 4 | -------------------------------------------------------------------------------- /docs/gradsflow/data.md: -------------------------------------------------------------------------------- 1 | ::: 
gradsflow.data.ray_dataset 2 | 3 | --- 4 | 5 | ::: gradsflow.data.image 6 | 7 | --- 8 | 9 | ::: gradsflow.data.autodata.AutoDataset 10 | 11 | --- 12 | 13 | ::: gradsflow.data.common 14 | -------------------------------------------------------------------------------- /docs/gradsflow/models/base.md: -------------------------------------------------------------------------------- 1 | ::: gradsflow.models.base.BaseModel 2 | -------------------------------------------------------------------------------- /docs/gradsflow/models/model.md: -------------------------------------------------------------------------------- 1 | ::: gradsflow.models.model.Model 2 | -------------------------------------------------------------------------------- /docs/gradsflow/models/tracker.md: -------------------------------------------------------------------------------- 1 | ::: gradsflow.models.tracker.Tracker 2 | -------------------------------------------------------------------------------- /docs/gradsflow/models/utils.md: -------------------------------------------------------------------------------- 1 | ::: gradsflow.models.utils.available_losses 2 | 3 | --- 4 | 5 | ::: gradsflow.models.utils.available_metrics 6 | -------------------------------------------------------------------------------- /docs/gradsflow/tuner.md: -------------------------------------------------------------------------------- 1 | ::: gradsflow.tuner 2 | -------------------------------------------------------------------------------- /docs/gradsflow/utility.md: -------------------------------------------------------------------------------- 1 | ::: gradsflow.utility 2 | -------------------------------------------------------------------------------- /docs/overrides/main.html: -------------------------------------------------------------------------------- 1 | {% extends "base.html" %} 2 | 3 | {% block extrahead %} 4 | {% set title = config.site_name %} 5 | {% if page and page.meta and page.meta.title %} 6 | {% 
set title = title ~ " - " ~ page.meta.title %} 7 | {% elif page and page.title and not page.is_homepage %} 8 | {% set title = title ~ " - " ~ page.title | striptags %} 9 | {% endif %} 10 | 11 | 12 | An open-source AutoML Library based on PyTorch 13 | 14 | 15 | 16 | 17 | 18 | 19 | 20 | 21 | 22 | 23 | 24 | 25 | 26 | 27 | 28 | 29 | {% endblock %} 30 | 31 | {% block outdated %} 32 | You're not viewing the latest version. 33 | 34 | Click here to go to latest. 35 | 36 | {% endblock %} 37 | 38 | 39 | {% set extracopyright %} 40 | | Apache License 2.0 41 | {% endset %} 42 | -------------------------------------------------------------------------------- /docs/requirements.txt: -------------------------------------------------------------------------------- 1 | # git+https://github.com/gradsflow/gradsflow@main 2 | mkdocs==1.4.2 3 | mkdocstrings==0.20.0 4 | mkdocs-material==8.5.11 5 | mkdocstrings-python==0.8.3 6 | mkdocs-material-extensions==1.1.1 7 | mkdocs-git-revision-date-localized-plugin==1.1.0 8 | mkdocs-macros-plugin==0.7.0 9 | mkdocs-autorefs==0.4.1 10 | mkdocs-jupyter==0.22.0 11 | tags-macros-plugin @ git+https://github.com/jldiaz/mkdocs-plugin-tags.git@d26e2f124e4f3471639d426459e281080988fe7a 12 | mkdocs-meta-descriptions-plugin==2.2.0 13 | jupyter_contrib_nbextensions 14 | comet_ml 15 | lightning-flash[image,text]>=0.5.1 16 | wandb 17 | tensorboard 18 | -------------------------------------------------------------------------------- /examples/nbs/03-TextSummarization.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": { 6 | "id": "xYgiH2TkkX7x", 7 | "pycharm": { 8 | "name": "#%% md\n" 9 | } 10 | }, 11 | "source": [ 12 | "\"Open\n", 13 | "\n", 14 | "`!pip install lightning-flash`\n", 15 | "\n", 16 | "`!pip install -U gradsflow`" 17 | ] 18 | }, 19 | { 20 | "cell_type": "code", 21 | "execution_count": null, 22 | "metadata": { 23 | "id": "2_posF7Rj8sH", 24 | 
"pycharm": { 25 | "name": "#%%\n" 26 | } 27 | }, 28 | "outputs": [], 29 | "source": [ 30 | "from flash.core.data.utils import download_data\n", 31 | "from flash.text import SummarizationData, SummarizationTask\n", 32 | "\n", 33 | "# 1. Download the data\n", 34 | "download_data(\"https://pl-flash-data.s3.amazonaws.com/xsum.zip\", \"data/\")\n", 35 | "\n", 36 | "# 2. Load the data\n", 37 | "datamodule = SummarizationData.from_csv(\n", 38 | " \"input\",\n", 39 | " \"target\",\n", 40 | " train_file=\"data/xsum/train.csv\",\n", 41 | " val_file=\"data/xsum/valid.csv\",\n", 42 | " test_file=\"data/xsum/test.csv\",\n", 43 | " batch_size=4,\n", 44 | ")" 45 | ] 46 | }, 47 | { 48 | "cell_type": "code", 49 | "execution_count": null, 50 | "metadata": { 51 | "id": "YCYFfKhDkVVK", 52 | "pycharm": { 53 | "name": "#%%\n" 54 | } 55 | }, 56 | "outputs": [], 57 | "source": [ 58 | "from gradsflow import AutoSummarization\n", 59 | "\n", 60 | "suggested_conf = dict(\n", 61 | " optimizers=[\"adam\"],\n", 62 | " lr=(5e-4, 1e-3),\n", 63 | ")\n", 64 | "\n", 65 | "model = AutoSummarization(\n", 66 | " datamodule,\n", 67 | " suggested_backbones=\"sshleifer/distilbart-cnn-12-6\",\n", 68 | " suggested_conf=suggested_conf,\n", 69 | " max_epochs=1,\n", 70 | " optimization_metric=\"train_loss\",\n", 71 | " timeout=600,\n", 72 | ")\n", 73 | "\n", 74 | "print(\"AutoSummarization initialised!\")\n", 75 | "model.hp_tune()" 76 | ] 77 | } 78 | ], 79 | "metadata": { 80 | "colab": { 81 | "name": "03-TextSummarization.ipynb", 82 | "provenance": [] 83 | }, 84 | "kernelspec": { 85 | "display_name": "Python 3", 86 | "language": "python", 87 | "name": "python3" 88 | }, 89 | "language_info": { 90 | "codemirror_mode": { 91 | "name": "ipython", 92 | "version": 3 93 | }, 94 | "file_extension": ".py", 95 | "mimetype": "text/x-python", 96 | "name": "python", 97 | "nbconvert_exporter": "python", 98 | "pygments_lexer": "ipython3", 99 | "version": "3.7.10" 100 | } 101 | }, 102 | "nbformat": 4, 103 | "nbformat_minor": 4 
104 | } 105 | -------------------------------------------------------------------------------- /examples/nbs/05-model_fit.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "id": "0", 6 | "metadata": {}, 7 | "source": [ 8 | "\"Open" 9 | ] 10 | }, 11 | { 12 | "cell_type": "code", 13 | "execution_count": null, 14 | "id": "1", 15 | "metadata": {}, 16 | "outputs": [], 17 | "source": [ 18 | "import torchvision\n", 19 | "from timm import create_model\n", 20 | "from torch.utils.data import DataLoader\n", 21 | "from torchvision import transforms as T\n", 22 | "\n", 23 | "from gradsflow import AutoDataset, Model\n", 24 | "from gradsflow.callbacks import (\n", 25 | " CSVLogger,\n", 26 | " EmissionTrackerCallback,\n", 27 | " ModelCheckpoint,\n", 28 | ")\n", 29 | "from gradsflow.data.common import random_split_dataset" 30 | ] 31 | }, 32 | { 33 | "cell_type": "code", 34 | "execution_count": null, 35 | "id": "2", 36 | "metadata": {}, 37 | "outputs": [ 38 | { 39 | "name": "stdout", 40 | "output_type": "stream", 41 | "text": [ 42 | "Files already downloaded and verified\n" 43 | ] 44 | } 45 | ], 46 | "source": [ 47 | "# Replace dataloaders with your custom dataset and you are all set to train your model\n", 48 | "image_size = (64, 64)\n", 49 | "batch_size = 4\n", 50 | "\n", 51 | "to_rgb = lambda x: x.convert(\"RGB\")\n", 52 | "\n", 53 | "augs = T.Compose([to_rgb, T.AutoAugment(), T.Resize(image_size), T.ToTensor()])\n", 54 | "data = torchvision.datasets.Caltech101(\"~/\", download=True, transform=augs)\n", 55 | "train_data, val_data = random_split_dataset(data, 0.99)\n", 56 | "train_dl = DataLoader(train_data, batch_size=batch_size)\n", 57 | "val_dl = DataLoader(val_data, batch_size=batch_size)\n", 58 | "num_classes = len(data.categories)" 59 | ] 60 | }, 61 | { 62 | "cell_type": "code", 63 | "execution_count": null, 64 | "id": "3", 65 | "metadata": {}, 66 | "outputs": [ 67 | { 68 | 
"name": "stderr", 69 | "output_type": "stream", 70 | "text": [ 71 | "CODECARBON : No CPU tracking mode found. Falling back on CPU constant mode.\n", 72 | "/Users/aniket/miniconda3/envs/am/lib/python3.9/site-packages/apscheduler/util.py:95: PytzUsageWarning: The zone attribute is specific to pytz's interface; please migrate to a new time zone provider. For more details on how to do so, see https://pytz-deprecation-shim.readthedocs.io/en/latest/migration.html\n", 73 | " if obj.zone == 'local':\n", 74 | "/Users/aniket/miniconda3/envs/am/lib/python3.9/site-packages/apscheduler/triggers/interval.py:66: PytzUsageWarning: The normalize method is no longer necessary, as this time zone supports the fold attribute (PEP 495). For more details on migrating to a PEP 495-compliant implementation, see https://pytz-deprecation-shim.readthedocs.io/en/latest/migration.html\n", 75 | " return self.timezone.normalize(next_fire_time)\n" 76 | ] 77 | } 78 | ], 79 | "source": [ 80 | "cbs = [\n", 81 | " CSVLogger(\n", 82 | " verbose=True,\n", 83 | " ),\n", 84 | " ModelCheckpoint(),\n", 85 | " EmissionTrackerCallback(),\n", 86 | " # CometCallback(offline=True),\n", 87 | "]" 88 | ] 89 | }, 90 | { 91 | "cell_type": "code", 92 | "execution_count": null, 93 | "id": "4", 94 | "metadata": {}, 95 | "outputs": [ 96 | { 97 | "data": { 98 | "application/vnd.jupyter.widget-view+json": { 99 | "model_id": "e4eb63f7dc584cfcb22d92f15c96bbf3", 100 | "version_major": 2, 101 | "version_minor": 0 102 | }, 103 | "text/plain": [ 104 | "Output()" 105 | ] 106 | }, 107 | "metadata": {}, 108 | "output_type": "display_data" 109 | } 110 | ], 111 | "source": [ 112 | "autodataset = AutoDataset(train_dl, val_dl, num_classes=num_classes)\n", 113 | "cnn = create_model(\"resnet18\", pretrained=False, num_classes=num_classes)\n", 114 | "\n", 115 | "model = Model(cnn)\n", 116 | "\n", 117 | "model.compile(\"crossentropyloss\", \"adam\", metrics=[\"accuracy\"])\n", 118 | "model.fit(autodataset, max_epochs=10, 
steps_per_epoch=10, callbacks=cbs)" 119 | ] 120 | } 121 | ], 122 | "metadata": { 123 | "interpreter": { 124 | "hash": "e2d961b663a5ae03743cd178a74853be9b21def56a249d21ac1502fcfb05a9ce" 125 | }, 126 | "kernelspec": { 127 | "display_name": "Python 3 (ipykernel)", 128 | "language": "python", 129 | "name": "python3" 130 | }, 131 | "language_info": { 132 | "codemirror_mode": { 133 | "name": "ipython", 134 | "version": 3 135 | }, 136 | "file_extension": ".py", 137 | "mimetype": "text/x-python", 138 | "name": "python", 139 | "nbconvert_exporter": "python", 140 | "pygments_lexer": "ipython3", 141 | "version": "3.9.9" 142 | }, 143 | "toc": { 144 | "base_numbering": 1, 145 | "nav_menu": {}, 146 | "number_sections": true, 147 | "sideBar": true, 148 | "skip_h1_title": false, 149 | "title_cell": "Table of Contents", 150 | "title_sidebar": "Contents", 151 | "toc_cell": false, 152 | "toc_position": {}, 153 | "toc_section_display": true, 154 | "toc_window_display": false 155 | } 156 | }, 157 | "nbformat": 4, 158 | "nbformat_minor": 5 159 | } 160 | -------------------------------------------------------------------------------- /examples/src/models/hello_world.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021 GradsFlow. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
class Net(nn.Module):
    """Small LeNet-style CNN classifying 3x32x32 RGB images into 10 classes."""

    def __init__(self):
        super().__init__()
        # Two conv -> ReLU -> max-pool stages, then three fully connected layers.
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        """Forward pass; expects input of shape (N, 3, 32, 32), returns (N, 10) logits."""
        for conv in (self.conv1, self.conv2):
            x = self.pool(F.relu(conv(x)))
        x = torch.flatten(x, 1)  # flatten all dimensions except batch
        x = F.relu(self.fc2(F.relu(self.fc1(x))))
        return self.fc3(x)
14 | 15 | import torchvision 16 | from torch.utils.data import DataLoader 17 | from torchvision import transforms as T 18 | 19 | from gradsflow import AutoImageClassifier 20 | from gradsflow.data.common import random_split_dataset 21 | 22 | # Replace dataloaders with your custom dataset and you are all set to train your model 23 | image_size = (64, 64) 24 | batch_size = 4 25 | 26 | to_rgb = lambda x: x.convert("RGB") 27 | 28 | # TODO: Add argument parser 29 | if __name__ == "__main__": 30 | augs = T.Compose([to_rgb, T.AutoAugment(), T.Resize(image_size), T.ToTensor()]) 31 | data = torchvision.datasets.CIFAR10("~/data", download=True, transform=augs) 32 | train_data, val_data = random_split_dataset(data, 0.01) 33 | train_dl = DataLoader(train_data, batch_size=batch_size) 34 | val_dl = DataLoader(val_data, batch_size=batch_size) 35 | 36 | num_classes = len(data.classes) 37 | 38 | model = AutoImageClassifier( 39 | train_dataloader=train_dl, 40 | val_dataloader=val_dl, 41 | num_classes=num_classes, 42 | max_epochs=5, 43 | optimization_metric="train_loss", 44 | max_steps=1, 45 | n_trials=1, 46 | ) 47 | print("AutoImageClassifier initialised!") 48 | 49 | model.hp_tune() 50 | -------------------------------------------------------------------------------- /examples/src/tasks/text_classification.py: -------------------------------------------------------------------------------- 1 | from flash.core.data.utils import download_data 2 | from flash.text import TextClassificationData 3 | 4 | from gradsflow import AutoTextClassifier 5 | 6 | download_data("https://pl-flash-data.s3.amazonaws.com/imdb.zip", "./data/") 7 | 8 | print("Creating datamodule...") 9 | datamodule = TextClassificationData.from_csv( 10 | "review", "sentiment", train_file="data/imdb/train.csv", val_file="data/imdb/valid.csv", batch_size=4 11 | ) 12 | 13 | suggested_conf = dict( 14 | optimizer=["adam", "adamw", "sgd"], 15 | lr=(5e-4, 1e-3), 16 | ) 17 | 18 | model = AutoTextClassifier( 19 | datamodule, 20 | 
suggested_backbones=["prajjwal1/bert-tiny"], 21 | suggested_conf=suggested_conf, 22 | max_epochs=1, 23 | optimization_metric="val_accuracy", 24 | n_trials=1, 25 | ) 26 | 27 | print("AutoTextClassifier initialised!") 28 | model.hp_tune(finetune=True) 29 | -------------------------------------------------------------------------------- /gradsflow/__init__.py: -------------------------------------------------------------------------------- 1 | """An open-source AutoML Library based on PyTorch""" 2 | 3 | # Copyright (c) 2021 GradsFlow. All rights reserved. 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 
16 | 17 | import logging 18 | from os import environ as _environ 19 | 20 | from gradsflow.autotasks.autoclassification.image import AutoImageClassifier 21 | from gradsflow.autotasks.autoclassification.text import AutoTextClassifier 22 | from gradsflow.autotasks.autosummarization import AutoSummarization 23 | from gradsflow.autotasks.autotasks import autotask, available_tasks 24 | from gradsflow.autotasks.engine.automodel import AutoModel 25 | from gradsflow.data import AutoDataset 26 | from gradsflow.models.model import Model 27 | from gradsflow.tuner.automodel import AutoModelV2 28 | from gradsflow.tuner.tuner import Tuner 29 | 30 | __version__ = "0.0.8.post1" 31 | logging.basicConfig(level=_environ.get("LOG_LEVEL", "WARNING")) 32 | -------------------------------------------------------------------------------- /gradsflow/autotasks/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021 GradsFlow. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | from .autoclassification.image import AutoImageClassifier 16 | from .autoclassification.text import AutoTextClassifier 17 | from .autosummarization import AutoSummarization 18 | from .autotasks import autotask, available_tasks 19 | -------------------------------------------------------------------------------- /gradsflow/autotasks/autoclassification/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021 GradsFlow. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | -------------------------------------------------------------------------------- /gradsflow/autotasks/autoclassification/image.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021 GradsFlow. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
import timm

from gradsflow.autotasks.engine.autoclassifier import AutoClassifier
from gradsflow.autotasks.engine.backend import BackendType
from gradsflow.models.model import Model


# noinspection PyTypeChecker
class AutoImageClassifier(AutoClassifier):
    """
    Automatically find an Image Classification Model.

    Args:
        datamodule Optional[DataModule]: PL Lightning DataModule with `num_classes` property.
        train_dataloader Optional[DataLoader]: torch dataloader
        val_dataloader Optional[DataLoader]: torch dataloader
        num_classes Optional[int]: number of classes
        max_epochs [int]: default=10.
        n_trials [int]: default=100.
        optimization_metric [Optional[str]]: defaults None
        suggested_backbones Union[List, str, None]: defaults None
        suggested_conf [Optional[dict] = None]: This sets Trial suggestions for optimizer,
            learning rate, and all the hyperparameters.
        timeout [int]: Hyperparameter search will stop after timeout.
        backend_type Optional[str]: Training loop code. Defaults to None.

    Examples:
        ```python
        from flash.core.data.utils import download_data
        from flash.image import ImageClassificationData

        from gradsflow import AutoImageClassifier

        # 1. Create the DataModule
        download_data("https://pl-flash-data.s3.amazonaws.com/hymenoptera_data.zip", "./data")

        datamodule = ImageClassificationData.from_folders(
            train_folder="data/hymenoptera_data/train/",
            val_folder="data/hymenoptera_data/val/",
        )

        model = AutoImageClassifier(datamodule,
                                    max_epochs=10,
                                    optimization_metric="val_accuracy",
                                    timeout=300)
        model.hp_tune()
        ```
    """

    _DEFAULT_BACKBONES = ["ssl_resnet18", "ssl_resnet50"]

    def __init__(self, *args, **kwargs):
        # Always use the native gradsflow ("gf") training backend for image tasks.
        super().__init__(*args, **kwargs, backend=BackendType.gf.value)

    def build_model(self, config: dict) -> Model:
        """Build an image-classification ``Model`` from a ``ray.tune`` trial config
        (or from a ``_search_space`` dictionary).

        Args:
            config: trial hyperparameters. Expected keys:
                ``backbone`` (str): a ``timm`` model name, e.g. ``resnet18`` or
                    ``ssl_resnet50`` — see ``timm.list_models()`` for the full list.
                ``optimizer`` (str): optimizer name accepted by ``Model.compile``.
                ``lr`` (float): learning rate for the model.

        Returns:
            A compiled ``Model`` (cross-entropy loss, accuracy metric) ready to train.
        """
        backbone = config["backbone"]

        # Pretrained timm backbone with a classification head sized to the dataset.
        cnn = timm.create_model(backbone, pretrained=True, num_classes=self.num_classes)
        model = Model(cnn)
        model.compile(
            loss="crossentropyloss", optimizer=config["optimizer"], learning_rate=config["lr"], metrics="accuracy"
        )
        return model
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | from .text import AutoTextClassifier 15 | -------------------------------------------------------------------------------- /gradsflow/autotasks/autoclassification/text/text.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021 GradsFlow. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
# noinspection PyTypeChecker
class AutoTextClassifier(AutoClassifier):
    """
    Automatically find Text Classification Model

    Examples:
        ```python
        from gradsflow import AutoTextClassifier

        from flash.core.data.utils import download_data
        from flash.text import TextClassificationData

        download_data("https://pl-flash-data.s3.amazonaws.com/imdb.zip", "./data/")
        datamodule = TextClassificationData.from_csv(
            "review",
            "sentiment",
            train_file="data/imdb/train.csv",
            val_file="data/imdb/valid.csv",
            batch_size=4,
        )

        model = AutoTextClassifier(datamodule,
                                   suggested_backbones=['sgugger/tiny-distilbert-classification'],
                                   max_epochs=10,
                                   optimization_metric="val_accuracy",
                                   timeout=300)
        model.hp_tune()
        ```

    Arguments:
        datamodule Optional[DataModule]: PL Lightning DataModule with `num_classes` property.
        train_dataloader Optional[DataLoader]: torch dataloader
        val_dataloader Optional[DataLoader]: torch dataloader
        num_classes Optional[int]: number of classes
        max_epochs [int]: default=10.
        max_steps [int]: default=-1. Presumably -1 means no per-epoch step limit — TODO confirm.
        n_trials [int]: default=100.
        optimization_metric [Optional[str]]: defaults None
        suggested_backbones Union[List, str, None]: defaults None
        suggested_conf [Optional[dict] = None]: This sets Trial suggestions for optimizer,
            learning rate, and all the hyperparameters.
        timeout [int]: Hyperparameter search will stop after timeout.
    """

    # HuggingFace model ids tried when `suggested_backbones` is not provided.
    _DEFAULT_BACKBONES = [
        "distilbert-base-uncased-finetuned-sst-2-english",
        "sgugger/tiny-distilbert-classification",
    ]

    def __init__(self, *args, max_steps=-1, **kwargs):
        super().__init__(*args, max_steps=max_steps, **kwargs)
        # Different dataset wrappers expose the label count under different keys.
        meta = self.auto_dataset.meta
        self.num_classes = meta.get("num_labels") or meta.get("num_classes")
        logger.debug(f"num_classes = {self.num_classes}")

    def build_model(self, config: dict) -> torch.nn.Module:
        """Build a `TextClassifier` from a `ray.tune` hyperparameter config
        or via `_search_space` dictionary arguments.

        Args:
            config: hyperparameter dict with keys:
                backbone (str): HuggingFace text-classification model id,
                    e.g. ``sgugger/tiny-distilbert-classification``
                    (check Lightning-Flash for the full model list).
                optimizer (str): key into ``self._OPTIMIZER_INDEX`` (torch.optim classes).
                lr (float): learning rate.
        """
        from flash.text.classification import TextClassifier
        from torchmetrics import Accuracy

        backbone = config["backbone"]
        optimizer = config["optimizer"]
        learning_rate = config["lr"]

        return TextClassifier(
            self.num_classes,
            backbone=backbone,
            optimizer=self._OPTIMIZER_INDEX[optimizer],
            learning_rate=learning_rate,
            metrics=Accuracy(num_classes=self.num_classes),
        )
# noinspection PyTypeChecker
class AutoSummarization(AutoClassifier):
    """
    Automatically finds Text Summarization Model

    Args:
        datamodule Optional[DataModule]: PL Lightning DataModule with `num_classes` property.
        train_dataloader Optional[DataLoader]: torch dataloader
        val_dataloader Optional[DataLoader]: torch dataloader
        num_classes Optional[int]: number of classes
        max_epochs [int]: default=10.
        n_trials [int]: default=100.
        optimization_metric [Optional[str]]: defaults None
        suggested_backbones Union[List, str, None]: defaults None
        suggested_conf [Optional[dict] = None]: This sets Trial suggestions for optimizer,
            learning rate, and all the hyperparameters.
        timeout [int]: Hyperparameter search will stop after timeout.

    Examples:
        ```python
        from gradsflow import AutoSummarization

        from flash.core.data.utils import download_data
        from flash.text import SummarizationData, SummarizationTask

        # 1. Download the data
        download_data("https://pl-flash-data.s3.amazonaws.com/xsum.zip", "data/")
        # 2. Load the data
        datamodule = SummarizationData.from_csv(
            "input",
            "target",
            train_file="data/xsum/train.csv",
            val_file="data/xsum/valid.csv",
            test_file="data/xsum/test.csv",
        )

        model = AutoSummarization(datamodule,
                                  max_epochs=10,
                                  optimization_metric="val_accuracy",
                                  timeout=300)
        model.hp_tune()
        ```
    """

    # Seq2seq summarization checkpoints tried when no backbone is suggested.
    _DEFAULT_BACKBONES = [
        "sshleifer/distilbart-cnn-12-6",
        "sshleifer/distilbart-xsum-12-3",
    ]

    def build_model(self, config: dict) -> torch.nn.Module:
        """Build a `SummarizationTask` from a `ray.tune` hyperparameter config
        or via `_search_space` dictionary arguments.

        Args:
            config: hyperparameter dict with keys:
                backbone (str): summarization backbone name, e.g.
                    ``sshleifer/distilbart-cnn-12-6``, ``sshleifer/distilbart-xsum-12-3``
                    (check Lightning-Flash for the full model list).
                optimizer (str): key into ``self._OPTIMIZER_INDEX`` (torch.optim classes).
                lr (float): learning rate.
        """
        from flash.text.seq2seq import SummarizationTask

        backbone = config["backbone"]
        optimizer = config["optimizer"]
        learning_rate = config["lr"]

        return SummarizationTask(
            backbone=backbone,
            optimizer=self._OPTIMIZER_INDEX[optimizer],
            learning_rate=learning_rate,
        )
# Registry of task name -> AutoModel implementation.
SUPPORTED_TASKS = {
    "image-classification": AutoImageClassifier,
    "text-classification": AutoTextClassifier,
    "summarization": AutoSummarization,
}


def available_tasks() -> List[str]:
    """Return the names of all supported autotasks."""
    return [*SUPPORTED_TASKS]


def autotask(
    datamodule: Optional["pl.LightningDataModule"] = None,
    train_dataloader: Optional[DataLoader] = None,
    val_dataloader: Optional[DataLoader] = None,
    num_classes: Optional[int] = None,
    task: Optional[str] = None,
    data_type: Optional[str] = None,
    max_epochs: int = 10,
    max_steps: int = 10,
    n_trials: int = 100,
    optimization_metric: Optional[str] = None,
    suggested_backbones: Union[List, str, None] = None,
    suggested_conf: Optional[dict] = None,
    timeout: int = 600,
    prune: bool = True,
):
    """Instantiate the `AutoModel` implementation registered for ``task``.

    Args:
        datamodule: PL Lightning DataModule with `num_classes` property.
        train_dataloader: torch dataloader
        val_dataloader: torch dataloader
        num_classes: number of classes
        task: type of task. Check available autotasks with `available_tasks()`.
        data_type: default=None. type of data - image, text or infer.
        max_epochs: default=10.
        max_steps: default=10. forwarded to the selected task.
        n_trials: default=100.
        optimization_metric: defaults None
        suggested_backbones: defaults None
        suggested_conf: This sets Trial suggestions for optimizer,
            learning rate, and all the hyperparameters.
        timeout: Hyperparameter search will stop after timeout.
        prune: default=True. forwarded to the selected task.

    Returns:
        Implementation of `AutoModel` for the task type.

    Raises:
        UserWarning: when neither ``task`` nor ``data_type`` is provided, or
            when ``task`` is not a supported autotask.
    """
    # NOTE(review): `data_type` only participates in the guard below; it is
    # not otherwise used to infer a task here.
    if not (task or data_type):
        raise UserWarning("either task or data_type must be set!")

    if task not in SUPPORTED_TASKS:
        raise UserWarning(f"Unknown task {task}, available autotasks are {list(SUPPORTED_TASKS.keys())}")

    task_kwargs = dict(
        datamodule=datamodule,
        train_dataloader=train_dataloader,
        val_dataloader=val_dataloader,
        num_classes=num_classes,
        max_epochs=max_epochs,
        max_steps=max_steps,
        n_trials=n_trials,
        optimization_metric=optimization_metric,
        suggested_backbones=suggested_backbones,
        suggested_conf=suggested_conf,
        timeout=timeout,
        prune=prune,
    )
    return SUPPORTED_TASKS[task](**task_kwargs)
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | -------------------------------------------------------------------------------- /gradsflow/autotasks/engine/autoclassifier.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021 GradsFlow. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
class AutoClassifier(AutoModel):
    """Implements `AutoModel` for classification autotasks.

    Subclasses override `_DEFAULT_BACKBONES` and implement `build_model`.
    """

    # Task-specific default backbone names, overridden by each subclass.
    _DEFAULT_BACKBONES = []

    def __init__(
        self,
        datamodule: Optional["pl.LightningDataModule"] = None,
        train_dataloader: Optional[DataLoader] = None,
        val_dataloader: Optional[DataLoader] = None,
        num_classes: Optional[int] = None,
        max_epochs: int = 10,
        max_steps: int = 10,
        n_trials: int = 100,
        optimization_metric: Optional[str] = None,
        suggested_backbones: Union[List, str, None] = None,
        suggested_conf: Optional[dict] = None,
        timeout: int = 600,
        prune: bool = True,
        backend: Optional[str] = None,
    ):
        super().__init__(
            datamodule=datamodule,
            train_dataloader=train_dataloader,
            val_dataloader=val_dataloader,
            num_classes=num_classes,
            max_epochs=max_epochs,
            max_steps=max_steps,
            optimization_metric=optimization_metric,
            n_trials=n_trials,
            suggested_conf=suggested_conf,
            timeout=timeout,
            prune=prune,
            backend=backend,
        )

        # Accept a single backbone name or a sequence; fall back to the
        # subclass defaults when nothing is suggested.
        if isinstance(suggested_backbones, (str, list, tuple)):
            self.suggested_backbones = listify(suggested_backbones)
        elif suggested_backbones is None:
            self.suggested_backbones = self._DEFAULT_BACKBONES
        else:
            raise UserWarning("Invalid suggested_backbone type!")

        self.num_classes = num_classes

    def forward(self, x):
        """Forward pass through the built model.

        Raises:
            UserWarning: if called before `hp_tune()` has built a model.
        """
        if not self.model:
            raise UserWarning("model not initialized yet, run `hp_tune()` first.")
        return self.model(x)

    # noinspection PyTypeChecker
    def _create_search_space(self) -> Dict[str, object]:
        """Create hyperparameter config from `ray.tune`

        Returns:
            key-value pair of `ray.tune` _search_space. Values are `ray.tune`
            sampling objects (choice/loguniform), not plain strings.
        """
        trial_backbone = tune.choice(self.suggested_backbones)
        trial_lr = tune.loguniform(*self.suggested_lr)
        trial_optimizer = tune.choice(self.suggested_optimizers)
        hparams = {
            "backbone": trial_backbone,
            "lr": trial_lr,
            "optimizer": trial_optimizer,
        }
        return hparams

    @abstractmethod
    def build_model(self, config: dict) -> torch.nn.Module:
        """Every Task implementing AutoClassifier has to implement a
        build model method that can build `torch.nn.Module` from dictionary config
        and return the model.
        """
class BackendType(Enum):
    """Supported training backends.

    Duplicate values create enum aliases: `torch` is an alias of `gf`,
    and `default` is an alias of `lightning`.
    """

    # Remove torch
    lightning = "lightning"
    gf = "gf"
    torch = "gf"
    default = "lightning"


class Backend:
    """Dispatches the hyperparameter-search training objective to the selected
    backend: native gradsflow `Model.fit` or a Lightning/Flash `Trainer`.
    """

    # Mapping of optimizer name -> torch.optim class, built once at class creation.
    _OPTIMIZER_INDEX = module_to_cls_index(torch.optim, True)

    def __init__(
        self,
        autodataset: AutoDataset,
        model_builder: Callable,
        optimization_metric: Optional[str],
        max_epochs: int = 10,
        max_steps: Optional[int] = None,
        backend: Optional[str] = None,
    ):
        # Args:
        #   autodataset: wraps the dataloaders / datamodule used for training.
        #   model_builder: callable mapping a hyperparameter config to a model.
        #   optimization_metric: metric name read back from training results.
        #   backend: one of `BackendType` values; defaults to "lightning".
        self.model_builder = model_builder
        self.backend_type = (backend or BackendType.default.value).lower()
        self.autodataset = autodataset
        self.optimization_metric = optimization_metric
        self.max_epochs = max_epochs
        self.max_steps = max_steps

    def _gf_objective(self, search_space: Dict, trainer_config: Dict, **_):
        """Train with the native gradsflow `Model` API and return its tracker.

        Extra keyword arguments (e.g. `gpu`, `finetune`) are accepted and
        ignored via `**_`.
        """
        autodataset = self.autodataset
        model = self.model_builder(search_space)
        tracker = model.fit(
            autodataset=autodataset,
            steps_per_epoch=self.max_steps,
            callbacks=trainer_config.get("callback_runner", ("tune_checkpoint", "tune_report")),
            show_progress=False,
            **trainer_config,
        )
        return tracker

    # noinspection PyTypeChecker
    def _lightning_objective(
        self, config: Dict, trainer_config: Dict, gpu: Optional[float] = 0, finetune: bool = False
    ):
        # NOTE(review): `gpu` is unused here; device selection relies on
        # accelerator/devices="auto" below — confirm intended.
        val_check_interval = 1.0
        if self.max_steps:
            # Validate once just before the step budget runs out.
            val_check_interval = max(self.max_steps - 1, 1.0)

        datamodule = self.autodataset.datamodule
        model = self.model_builder(config)

        # Flash `Task` models get the Flash Trainer (supports `finetune`);
        # plain Lightning modules get the PL Trainer.
        trainer_cls = FlashTrainer if isinstance(model, Task) else PLTrainer

        trainer: "L.Trainer" = trainer_cls(
            logger=True,
            accelerator="auto",
            devices="auto",
            max_epochs=self.max_epochs,
            max_steps=self.max_steps,
            callbacks=[report_checkpoint_callback()],
            val_check_interval=val_check_interval,
            **trainer_config,
        )

        hparams = dict(model=model.hparams)
        trainer.logger.log_hyperparams(hparams)
        if finetune:
            trainer.finetune(model, datamodule=datamodule)
        else:
            trainer.fit(model, datamodule=datamodule)

        logger.debug(trainer.callback_metrics)
        return trainer.callback_metrics[self.optimization_metric].item()

    def optimization_objective(
        self, config: dict, trainer_config: dict, finetune: bool = False, gpu: Optional[float] = 0.0
    ):
        """
        Defines lightning_objective function which is used by tuner to minimize/maximize the metric.

        Args:
            config dict: key value pair of hyperparameters.
            trainer_config dict: configurations passed directly to Lightning Trainer.
            finetune bool: when True the lightning backend calls `finetune`
                instead of `fit`; ignored by the gf backend.
            gpu Optional[float]: GPU per trial

        Raises:
            NotImplementedError: for an unrecognised `backend_type`.
        """
        if self.backend_type == BackendType.lightning.value:
            return self._lightning_objective(config, trainer_config=trainer_config, gpu=gpu, finetune=finetune)

        if self.backend_type in (BackendType.gf.value,):
            return self._gf_objective(config, trainer_config=trainer_config, gpu=gpu, finetune=finetune)

        raise NotImplementedError(f"Trainer not implemented for backend_type: {self.backend_type}")
def dummy(x=None, **__):
    """No-op default callable: returns its first argument unchanged."""
    return x


# TODO: set self.MODE in each callback stage train | val
class Callback(ABC):
    """Callback objects define events on which it will run during the model training cycle."""

    # Event names accepted by `with_event`.
    # NOTE(review): step-level hooks exist below but "train_step"/"val_step"
    # are not listed here, so `with_event` would reject them — confirm callers.
    _events = ("forward", "step", "train_epoch", "val_epoch", "epoch", "fit")
    _name: str = "Callback"

    def __init__(self, model: Optional["Model"] = None):
        self.model = model

    @property
    def name(self) -> str:
        return self._name

    def with_event(self, event_type: str, func: Callable, exception, final_fn: Callable = dummy):
        """Calls a function with event wrapped around. Inspired from FastAI.
        Ref: https://github.com/fastai/fastai/blob/6e44b354f4d12bdfa2c9530f38f851c54a05764d/fastai/learner.py#L162

        Runs `on_<event_type>_start` then `func()`. If `func` raises
        `exception`, `on_<event_type>_cancel` runs and the error is swallowed;
        `on_<event_type>_end` and `final_fn` run afterwards in either case.
        """
        assert event_type in self._events, f"event_type is {event_type} but should be {self._events}"
        start_event = f"on_{event_type}_start"
        end_event = f"on_{event_type}_end"
        cancel_event = f"on_{event_type}_cancel"
        try:
            getattr(self, start_event)()
            func()
        except exception:
            getattr(self, cancel_event)()
        getattr(self, end_event)()
        final_fn()

    def on_fit_start(self):
        """Called on each `model.fit(...)`"""

    def on_fit_end(
        self,
    ):
        """Called after `model.fit(...)`"""

    def on_fit_cancel(
        self,
    ):
        """Called after `model.fit(...)`is cancelled"""

    def on_train_epoch_start(
        self,
    ):
        """Called on start of training epoch"""

    def on_train_epoch_end(self, *args, **kwargs):
        """Called after end of training epoch"""

    def on_train_epoch_cancel(self):
        """Called after training epoch is cancelled"""

    def on_val_epoch_start(
        self,
    ):
        """Called on start of validation epoch"""

    def on_val_epoch_end(self, *args, **kwargs):
        """called after validation epoch ends"""

    def on_val_epoch_cancel(self):
        """called after validation epoch cancelled"""

    def on_train_step_start(self):
        """called before `train_step`"""

    def on_train_step_end(self, *args, **kwargs):
        """Called after training step"""

    def on_train_step_cancel(self):
        """Called after training step is cancelled"""

    def on_val_step_start(self):
        """Called on validation step"""

    def on_val_step_end(self, *args, **kwargs):
        """Called after validation step"""

    def on_val_step_cancel(self):
        """Called after validation step is cancelled"""

    def on_epoch_start(self):
        """Called Before each Epoch"""

    def on_epoch_end(self):
        """Called after each epoch"""

    def on_epoch_cancel(self):
        """Called after epoch is cancelled"""

    def on_forward_start(self):
        """Called before model.forward(...)"""

    def on_forward_end(self):
        """Called after model.forward(...)"""

    def on_forward_cancel(self):
        """Called after model.forward(...) is cancelled"""

    def clean(self):
        """Clean up"""
    def __init__(
        self,
        project_name: str = "awesome-project",
        workspace: Optional[str] = None,
        experiment_id: Optional[str] = None,
        api_key: Optional[str] = os.environ.get("COMET_API_KEY"),
        code_file: str = CURRENT_FILE,
        offline: bool = False,
        **kwargs,
    ):
        """
        Args:
            project_name: name of the Comet project.
            workspace: Comet workspace name.
            experiment_id: resume an existing experiment when provided.
            api_key: Comet API key.
                NOTE(review): the default is read from `COMET_API_KEY` at
                import time, not at call time — confirm intended.
            code_file: source file uploaded via `log_code` when fit starts.
            offline: log experiment offline.
            **kwargs: forwarded to the comet_ml experiment constructor.
        """
        super().__init__(
            model=None,
        )
        self._code_file = code_file
        self._experiment_id = experiment_id
        self.experiment = self._create_experiment(
            project_name=project_name,
            workspace=workspace,
            api_key=api_key,
            offline=offline,
            experiment_id=experiment_id,
            **kwargs,
        )
        # Prefixes used when logging step/epoch metrics.
        self._train_prefix = "train"
        self._val_prefix = "val"

    @staticmethod
    @requires("comet_ml", "CometCallback requires comet_ml to be installed!")
    def _create_experiment(
        project_name: str,
        workspace: str,
        offline: bool = False,
        api_key: Optional[str] = None,
        experiment_id: Optional[str] = None,
        **kwargs,
    ) -> "BaseExperiment":
        """Instantiate the comet_ml experiment class matching the
        (offline, resume-existing) combination."""
        from comet_ml import (
            ExistingExperiment,
            ExistingOfflineExperiment,
            Experiment,
            OfflineExperiment,
        )

        if offline:
            # Offline experiments: resume an existing one when an id is given.
            if experiment_id:
                experiment = ExistingOfflineExperiment(
                    project_name=project_name, workspace=workspace, previous_experiment=experiment_id, **kwargs
                )
            else:
                experiment = OfflineExperiment(project_name=project_name, workspace=workspace, **kwargs)
        else:
            if experiment_id:
                experiment = ExistingExperiment(
                    project_name=project_name,
                    workspace=workspace,
                    api_key=api_key,
                    previous_experiment=experiment_id,
                    **kwargs,
                )
            else:
                experiment = Experiment(project_name=project_name, workspace=workspace, api_key=api_key, **kwargs)
        return experiment

    def on_fit_start(self):
        # Record the model graph and upload the source file once per fit.
        self.experiment.set_model_graph(self.model.learner)
        self.experiment.log_code(self._code_file)

    def on_train_epoch_start(
        self,
    ):
        # Switch the comet context so subsequent metrics log under "train".
        self.experiment.train()

    def on_val_epoch_start(
        self,
    ):
        # Switch the comet context so subsequent metrics log under "validate".
        self.experiment.validate()

    def _step(self, prefix: str, outputs: dict):
        """Log step-level loss and metrics under the given prefix."""
        step = self.model.tracker.mode(prefix).steps
        loss = outputs["loss"].item()
        self.experiment.log_metrics(outputs.get("metrics", {}), step=step, prefix=prefix)
        self.experiment.log_metric(f"{prefix}_step_loss", loss, step=step)

    def on_train_step_end(self, outputs: dict = None, **_):
        self._step(prefix=self._train_prefix, outputs=outputs)

    def on_val_step_end(self, outputs: dict = None, **_):
        self._step(prefix=self._val_prefix, outputs=outputs)

    def on_epoch_end(self):
        # Log epoch-level loss and metrics for both train and validation contexts.
        epoch = self.model.tracker.current_epoch
        train_loss = self.model.tracker.train_loss
        train_metrics = self.model.tracker.train_metrics
        val_loss = self.model.tracker.val_loss
        val_metrics = self.model.tracker.val_metrics

        self.experiment.train()
        self.experiment.log_metric("train_epoch_loss", train_loss, epoch=epoch)
        self.experiment.log_metrics(train_metrics, epoch=epoch, prefix=self._train_prefix)

        self.experiment.validate()
        self.experiment.log_metric("val_epoch_loss", val_loss, epoch=epoch)
        self.experiment.log_metrics(val_metrics, epoch=epoch, prefix=self._val_prefix)
        self.experiment.log_epoch_end(epoch)
class EmissionTrackerCallback(Callback):
    """
    Tracks the carbon emissions produced by deep neural networks using
    [CodeCarbon](https://github.com/mlco2/codecarbon). To use this callback first install codecarbon using
    `pip install codecarbon`.
    For offline use, you must have to specify the [country code](https://github.com/mlco2/codecarbon#offline-mode).

    Args:
        offline: whether to use internet connection or not. You will have to provide the country code `country_iso_code` for offline use.
        **kwargs: passed directly to codecarbon class.
    """

    _name = "EmissionTrackerCallback"

    @requires("codecarbon", "install codecarbon to use EmissionTrackerCallback")
    def __init__(self, offline: bool = False, **kwargs):
        from codecarbon import EmissionsTracker, OfflineEmissionsTracker

        # Offline tracking requires a `country_iso_code` kwarg (see class docstring).
        if offline:
            self._emission_tracker = OfflineEmissionsTracker(**kwargs)
        else:
            self._emission_tracker = EmissionsTracker(**kwargs)
        # Tracking starts as soon as the callback is constructed, not at fit start.
        self._emission_tracker.start()

        super().__init__(model=None)

    def on_fit_end(self):
        # Stop the tracker and log the total emissions it reports
        # (kg per the codecarbon docs — confirm for the installed version).
        emissions: float = self._emission_tracker.stop()
        logger.info(f"Emissions: {emissions} kg")
import logging
import os
from pathlib import Path
from typing import Optional

import pandas as pd

from gradsflow.callbacks.base import Callback
from gradsflow.utility.common import to_item

logger = logging.getLogger(__name__)


class CSVLogger(Callback):
    """
    Saves Model training metrics as CSV. The file is rewritten after every epoch
    so it stays complete even if training stops early.

    Args:
        filename: filename of the csv
        path: folder path location of the csv. Defaults to the current working
            directory, resolved when the callback is created.
        verbose: Whether to log each row as it is written
    """

    _name = "CSVLogger"

    def __init__(self, filename: str = "./experiment.csv", path: Optional[str] = None, verbose: bool = False):
        super().__init__(model=None)
        self.filename = filename
        # BUG FIX: `path=os.getcwd()` as a *default argument* was evaluated once
        # at import time, pinning the csv location to whatever the cwd happened
        # to be when this module was first imported. Resolve it at creation time.
        self.path = path or os.getcwd()
        self._dst = Path(self.path) / Path(filename)
        self._logs = []  # one dict per finished epoch
        self.verbose = verbose

    def on_epoch_end(self):
        tracker = self.model.tracker
        epoch = tracker.current_epoch
        train_loss = tracker.train_loss
        val_loss = tracker.val_loss

        # Prefix metric names with their mode and unwrap tensor values to plain numbers.
        train_metrics = to_item({"train/" + k: v.avg for k, v in tracker.train_metrics.items()})
        val_metrics = to_item({"val/" + k: v.avg for k, v in tracker.val_metrics.items()})

        data = {"epoch": epoch, "train_loss": train_loss, "val_loss": val_loss, **train_metrics, **val_metrics}
        if self.verbose:
            logger.info(f"verbose csv_logger on_epoch_end: {data}")
        self._logs.append(data)
        # Rewrite the whole file each epoch: cheap for typical run lengths and crash-safe.
        pd.DataFrame(self._logs).to_csv(self._dst, index=False)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Dict, Optional

from rich.progress import BarColumn, Progress, RenderableColumn, TimeRemainingColumn

from gradsflow.callbacks.base import Callback


class ProgressCallback(Callback):
    """Renders live training progress with `rich`: one bar for overall fit
    progress, per-epoch bars for train/val steps, and a metrics table that is
    refreshed after every step/epoch.

    Args:
        model: the model whose tracker supplies epoch counts and the metrics table.
        progress_kwargs: extra keyword arguments forwarded to `rich.progress.Progress`.
    """

    _name = "ProgressCallback"

    def __init__(self, model, progress_kwargs: Optional[Dict] = None):
        super().__init__(model)
        progress_kwargs = progress_kwargs or {}
        tracker = self.model.tracker
        self.bar_column = BarColumn()
        # The metrics table is rendered as an extra column next to the bars.
        self.table_column = RenderableColumn(tracker.create_table())

        self.progress = Progress(
            "[progress.description]{task.description}",
            self.bar_column,
            "[progress.percentage]{task.percentage:>3.0f}%",
            TimeRemainingColumn(),
            self.table_column,
            **progress_kwargs,
        )
        # Expose the Progress object on the tracker so other components can print through it.
        tracker.progress = self.progress
        # Task ids, created lazily in the corresponding on_* hooks.
        self.fit_prog = None
        self.train_prog_bar = None
        self.val_prog_bar = None

    def on_fit_start(self):
        self.progress.start()
        epochs = self.model.tracker.max_epochs
        # Resume support: start the bar from the tracker's current epoch.
        completed = self.model.tracker.current_epoch
        self.fit_prog = self.progress.add_task("[red]Progress", total=epochs, completed=completed)

    def on_fit_end(self):
        self.progress.stop()

    def on_epoch_end(self):
        self.progress.update(self.fit_prog, advance=1)

    def on_train_epoch_start(self):
        # NOTE(review): assumes dataloader_length["train"] is a step count (may be
        # None for iterable datasets) — rich treats total=None as indeterminate.
        n = self.model.autodataset.dataloader_length["train"]
        self.train_prog_bar = self.progress.add_task("[green]Learning", total=n)

    def on_train_epoch_end(self, *args, **kwargs):
        # Remove the per-epoch bar; a fresh one is created next epoch.
        self.progress.remove_task(self.train_prog_bar)
        self.table_column.renderable = self.model.tracker.create_table()

    def on_train_step_end(self, *args, **kwargs):
        self.progress.update(self.train_prog_bar, advance=1)
        # Refresh the metrics table after every step so values update live.
        self.table_column.renderable = self.model.tracker.create_table()

    def on_val_epoch_start(self):
        val_len = self.model.autodataset.dataloader_length["val"]
        if val_len is None:
            # No validation data: skip creating a validation bar.
            return
        self.val_prog_bar = self.progress.add_task("[blue]Validating...", total=val_len)

    def on_val_epoch_end(self, *args, **kwargs):
        # NOTE(review): this guards on `val_dataloader` while on_val_epoch_start
        # guards on `dataloader_length["val"]` — confirm the two stay in sync.
        val_dl = self.model.autodataset.val_dataloader
        if not val_dl:
            return
        self.table_column.renderable = self.model.tracker.create_table()
        self.progress.remove_task(self.val_prog_bar)

    def on_val_step_end(self, *args, **kwargs):
        self.progress.update(self.val_prog_bar, advance=1)
        self.table_column.renderable = self.model.tracker.create_table()

    def clean(self):
        # Ensure the live display is torn down when the callback is discarded.
        self.progress.stop()
import os
from typing import Optional

from ray import tune

from gradsflow.callbacks.base import Callback

# Default metric mapping reported to Ray Tune when none is supplied.
_METRICS = {
    "val_accuracy": "val_accuracy",
    "train_accuracy": "train_accuracy",
}


def report_checkpoint_callback(metrics: Optional[dict] = None, filename: Optional[str] = None):
    """Build a PyTorch-Lightning `TuneReportCheckpointCallback` that reports
    ``metrics`` and checkpoints to ``filename`` at the end of every validation.

    Args:
        metrics: mapping of reported-name -> logged-name; defaults to `_METRICS`.
        filename: checkpoint filename inside the Tune checkpoint dir.
    """
    from ray.tune.integration.pytorch_lightning import TuneReportCheckpointCallback

    metrics = metrics or _METRICS
    filename = filename or "filename"
    callback = TuneReportCheckpointCallback(metrics=metrics, filename=filename, on="validation_end")

    return callback


class TorchTuneCheckpointCallback(Callback):
    """Saves the model into the Ray Tune checkpoint directory at every epoch end."""

    _name = "TorchTuneCheckpointCallback"

    def on_epoch_end(self):
        epoch = self.model.tracker.current_epoch

        with tune.checkpoint_dir(epoch) as checkpoint_dir:
            path = os.path.join(checkpoint_dir, "filename")
            self.model.save(path)


class TorchTuneReport(Callback):
    """Reports train/val losses and per-epoch average metrics to Ray Tune."""

    _name = "TorchTuneReport"

    def on_epoch_end(self):
        val_loss = self.model.tracker.val_loss
        train_loss = self.model.tracker.train_loss
        # BUG FIX: these two assignments were swapped (`val_tracker` was assigned
        # `tracker.train.metrics` and vice versa), so validation metrics were
        # reported under the "train_" prefix and training metrics under "val_".
        train_tracker = self.model.tracker.train.metrics
        val_tracker = self.model.tracker.val.metrics

        train_metrics = {"train_" + k.lower(): v.avg for k, v in train_tracker.items()}
        val_metrics = {"val_" + k.lower(): v.avg for k, v in val_tracker.items()}

        tune.report(val_loss=val_loss, train_loss=train_loss, **train_metrics, **val_metrics)
import typing
from collections import OrderedDict
from typing import Any, Dict, List, Optional, Union

from gradsflow.callbacks.base import Callback
from gradsflow.callbacks.progress import ProgressCallback
from gradsflow.callbacks.raytune import TorchTuneCheckpointCallback, TorchTuneReport
from gradsflow.callbacks.training import TrainEvalCallback
from gradsflow.utility import listify

if typing.TYPE_CHECKING:
    from gradsflow.models.model import Model


class CallbackRunner(Callback):
    """
    Fans every training hook out to each registered callback, in insertion order.

    Args:
        model: the `Model` the callbacks operate on.
        *callbacks: `Callback` instances or registered names (see `available_callbacks`).
    """

    _name: str = "CallbackRunner"
    # Registry of callbacks that can be requested by name.
    _AVAILABLE_CALLBACKS: Dict[str, Any] = {
        "training": TrainEvalCallback,
        "tune_checkpoint": TorchTuneCheckpointCallback,
        "tune_report": TorchTuneReport,
        "progress": ProgressCallback,
    }

    def __init__(self, model: "Model", *callbacks: Union[str, Callback]):
        super().__init__(model)
        # Keyed by each callback's `_name`; OrderedDict keeps dispatch order
        # deterministic. (A dead `self.callbacks = []` assignment that was
        # immediately overwritten has been removed.)
        self.callbacks: Dict[str, Callback] = OrderedDict()
        for callback in callbacks:
            self.append(callback)

    # skipcq: W0212
    def append(self, callback: Union[str, Callback]):
        """Register a callback given either a registered name or an instance.

        Raises:
            NotImplementedError: if the name is unknown or the object is not a `Callback`.
        """
        if isinstance(callback, str):
            try:
                callback_cls = self._AVAILABLE_CALLBACKS[callback]
            except KeyError as err:
                # Chain the original lookup failure for easier debugging.
                raise NotImplementedError(f"callback is not implemented {callback}") from err
            callback_obj: Callback = callback_cls(model=self.model)
        elif isinstance(callback, Callback):
            callback.model = self.model
            callback_obj = callback
        else:
            # Previously a non-str/non-Callback argument was silently ignored.
            raise NotImplementedError(f"callback is not implemented {callback}")
        self.callbacks[callback_obj._name] = callback_obj

    def available_callbacks(self):
        """Return the names accepted by `append`/`__init__`."""
        return list(self._AVAILABLE_CALLBACKS.keys())

    def on_train_epoch_end(self, *args, **kwargs):
        for callback in self.callbacks.values():
            callback.on_train_epoch_end(*args, **kwargs)

    def on_train_epoch_start(self):
        for callback in self.callbacks.values():
            callback.on_train_epoch_start()

    def on_fit_start(self):
        for callback in self.callbacks.values():
            callback.on_fit_start()

    def on_fit_end(self):
        for callback in self.callbacks.values():
            callback.on_fit_end()

    def on_val_epoch_start(self):
        for callback in self.callbacks.values():
            callback.on_val_epoch_start()

    def on_val_epoch_end(self, *args, **kwargs):
        for callback in self.callbacks.values():
            callback.on_val_epoch_end(*args, **kwargs)

    def on_train_step_start(self):
        for callback in self.callbacks.values():
            callback.on_train_step_start()

    def on_train_step_end(self, *args, **kwargs):
        for callback in self.callbacks.values():
            callback.on_train_step_end(*args, **kwargs)

    def on_val_step_start(self):
        for callback in self.callbacks.values():
            callback.on_val_step_start()

    def on_val_step_end(self, *args, **kwargs):
        for callback in self.callbacks.values():
            callback.on_val_step_end(*args, **kwargs)

    def on_epoch_start(self):
        for callback in self.callbacks.values():
            callback.on_epoch_start()

    def on_epoch_end(self):
        for callback in self.callbacks.values():
            callback.on_epoch_end()

    def on_forward_start(self):
        for callback in self.callbacks.values():
            callback.on_forward_start()

    def on_forward_end(self):
        for callback in self.callbacks.values():
            callback.on_forward_end()

    def clean(self, keep: Optional[Union[List[str], str]] = None):
        """Remove all the callbacks except callback names provided in keep"""
        for callback in self.callbacks.values():
            callback.clean()
        # Drop everything not explicitly kept.
        for key in set(self.callbacks) - set(listify(keep)):
            self.callbacks.pop(key)
import os
from pathlib import Path
from typing import Optional

from gradsflow.callbacks.base import Callback


class TrainEvalCallback(Callback):
    """Default callback: runs the optimization step (unless auto-optimization is
    disabled on the model) and feeds losses/metrics into the model tracker."""

    _name = "TrainEvalCallback"

    def on_train_step_start(self):
        # Clear gradients accumulated by the previous step.
        self.model.optimizer.zero_grad()

    def on_train_step_end(self, *args, outputs: dict = None, **kwargs):
        MODE = "train"
        # ----- AUTO OPTIMIZATION -----
        if not self.model.disable_auto_optimization:
            self.model.backward(outputs["loss"])
            self.model.optimizer.step()

        # ----- METRIC UPDATES -----
        tracker = self.model.tracker
        loss = outputs["loss"]
        tracker.track_loss(loss, mode=MODE)
        tracker.track_metrics(outputs.get("metrics", {}), mode=MODE)

    def on_val_step_end(self, *args, outputs: dict = None, **kwargs):
        MODE = "val"
        # ----- METRIC UPDATES -----
        tracker = self.model.tracker
        loss = outputs["loss"]
        tracker.track_loss(loss, mode=MODE)
        tracker.track_metrics(outputs.get("metrics", {}), mode=MODE)

    def on_train_epoch_start(self):
        # Switch to train mode and reset per-epoch running averages.
        self.model.train()
        self.model.metrics.reset()
        self.model.tracker.train.reset()

    def on_val_epoch_start(self):
        # Switch to eval mode and reset per-epoch running averages.
        self.model.eval()
        self.model.metrics.reset()
        self.model.tracker.val.reset()


class ModelCheckpoint(Callback):
    """
    Saves a model checkpoint at the end of every epoch.

    Args:
        filename: name of checkpoint
        path: folder path location of the model checkpoint. Created (including
            missing parents) if it does not exist. Defaults to the current
            working directory, resolved when the callback is created.
        save_extra: whether to save extra details like tracker
    """

    _name = "ModelCheckpoint"

    def __init__(self, filename: Optional[str] = None, path: Optional[str] = None, save_extra: bool = False):
        super().__init__(model=None)
        filename = Path(filename or "model")
        # BUG FIX: `path=os.getcwd()` as a default argument was evaluated once at
        # import time; resolve the cwd when the callback is created instead.
        self.path = Path(path or os.getcwd())
        # BUG FIX: create missing parent directories too — `mkdir(exist_ok=True)`
        # alone fails for nested, not-yet-existing paths.
        self.path.mkdir(parents=True, exist_ok=True)
        self._dst = self.path / filename
        self.save_extra = save_extra

    def on_epoch_end(self):
        epoch = self.model.tracker.current_epoch
        path = f"{self._dst}_epoch={epoch}_.pt"
        self.model.save(path, save_extra=self.save_extra)
import logging
import os
from typing import Dict, List, Optional

from gradsflow.callbacks.base import Callback
from gradsflow.utility.imports import is_installed, requires

logger = logging.getLogger(__name__)

if is_installed("wandb"):
    import wandb

CURRENT_FILE = os.path.dirname(os.path.realpath(__file__))


def define_metrics(default_step_metric: str = "global_step"):
    """Register wandb metric definitions: epoch-level metrics are plotted against
    `epoch`; everything else against ``default_step_metric``."""
    min_max_def: Dict[str, List[str]] = {
        "min": ["train/step_loss", "train/epoch_loss", "val/epoch_loss"],
        "max": ["train/acc*", "val/acc*"],
    }
    for summary, metric_list in min_max_def.items():
        for metric in metric_list:
            # NOTE(review): "train/step_loss" never matches this filter, so it
            # gets no summary definition — confirm whether that is intended.
            if "epoch" in metric or "val" in metric:
                wandb.define_metric(metric, summary=summary, step_metric="epoch")
    wandb.define_metric("*", step_metric=default_step_metric)


class WandbCallback(Callback):
    """
    [Weights & Biases](https://www.wandb.com/) Logging callback. To use this callback `pip install wandb`.
    Any metric that contains `epoch` will be plotted with `epoch` and all the other metrics will be plotted against
    `global_step` which is total training steps. You can change the default axis by providing `default_step_metric`.

    Args:
        log_model: Whether to upload model artifact to Wandb
        code_file: path of the code you want to upload as artifact to Wandb
        default_step_metric: Metrics will be plotted against the `default_step_metric`. Default value is `global_step`.

    ```python
    from gradsflow.callbacks import WandbCallback
    from timm import create_model

    cnn = create_model("resnet18", pretrained=False, num_classes=1)
    model = Model(cnn)
    model.compile()
    cb = WandbCallback()
    autodataset = None  # create your dataset
    model.fit(autodataset, callbacks=cb)
    ```
    """

    @requires("wandb", "WandbCallback requires wandb to be installed!")
    def __init__(self, log_model: bool = False, code_file: Optional[str] = None, default_step_metric="global_step"):
        super().__init__()
        if wandb.run is None:
            # BUG FIX: the original message was two adjacent string literals with
            # no separator, rendering as "...WandbCallback()Calling wandb.init()".
            logger.warning("wandb.init() was not called before initializing WandbCallback(). Calling wandb.init()")
            wandb.init()
        self._code_file = code_file
        self._train_prefix = "train"
        self._val_prefix = "val"
        self._log_model = log_model
        self._setup(default_step_metric)

    def _setup(self, default_step_metric):
        define_metrics(default_step_metric)

    def on_fit_start(self):
        # Optionally upload the model and/or a source file as wandb artifacts.
        if self._log_model:
            wandb.log_artifact(self.model.learner)
        if self._code_file:
            wandb.log_artifact(self._code_file)

    def _apply_prefix(self, data: dict, prefix: str):
        """Return ``data`` with every key renamed to ``prefix/key``."""
        data = {f"{prefix}/{k}": v for k, v in data.items()}
        return data

    def on_train_step_end(self, outputs: dict = None, **_):
        prefix = "train"
        global_step = self.model.tracker.global_step
        loss = outputs["loss"].item()
        # log train step loss
        wandb.log({f"{prefix}/step_loss": loss, "train_step": global_step}, commit=False)

        # log train step metrics
        metrics = outputs.get("metrics", {})
        metrics = self._apply_prefix(metrics, prefix)
        wandb.log(metrics, commit=False)

        # https://docs.wandb.ai/guides/track/log#how-do-i-use-custom-x-axes
        wandb.log({"global_step": global_step})

    def on_epoch_end(self):
        epoch = self.model.tracker.current_epoch
        train_loss = self.model.tracker.train_loss
        train_metrics = self.model.tracker.train_metrics.to_dict()
        val_loss = self.model.tracker.val_loss
        val_metrics = self.model.tracker.val_metrics.to_dict()

        train_metrics = self._apply_prefix(train_metrics, prefix=self._train_prefix)
        val_metrics = self._apply_prefix(val_metrics, prefix=self._val_prefix)
        train_metrics.update({"epoch": epoch})
        val_metrics.update({"epoch": epoch})

        # Accumulate everything with commit=False, then flush once with an empty log.
        wandb.log({"train/epoch_loss": train_loss, "epoch": epoch}, commit=False)
        wandb.log({"val/epoch_loss": val_loss, "epoch": epoch}, commit=False)
        wandb.log(train_metrics, commit=False)
        wandb.log(val_metrics, commit=False)
        wandb.log({})
from abc import ABC, abstractmethod
from dataclasses import asdict, dataclass
from typing import Dict, Optional

import numpy as np
import torch

from gradsflow.utility.common import AverageMeter, GDict, module_to_cls_index


class BaseAutoModel(ABC):
    """
    The main class for AutoML which consists of everything required for training a model -
    data, model and trainer.
    """

    # Maps optimizer names to the corresponding `torch.optim` classes.
    _OPTIMIZER_INDEX = module_to_cls_index(torch.optim, True)

    @abstractmethod
    def _create_search_space(self) -> Dict[str, str]:
        """creates search space"""
        raise NotImplementedError

    @abstractmethod
    def build_model(self, search_space: dict):
        """Build model from dictionary _search_space"""
        raise NotImplementedError


@dataclass(init=False)
class TrackingValues:
    """Running per-mode (train/val) state for a single epoch: average loss,
    latest step loss, and a per-metric running average."""

    loss: Optional[AverageMeter] = None  # Average loss in a single Epoch
    steps: Optional[int] = None  # Step per epoch
    step_loss: Optional[float] = None  # most recent per-step loss
    metrics: Optional[Dict[str, AverageMeter]] = None  # Average value in a single Epoch

    def __init__(self):
        self.metrics = GDict()
        self.loss = AverageMeter(name="loss")

    def update_loss(self, loss: float):
        """Record the latest step loss and fold it into the epoch average."""
        assert isinstance(loss, (int, float, np.ndarray)), f"loss must be int | float | np.ndarray but got {type(loss)}"
        self.step_loss = loss
        self.loss.update(loss)

    def update_metrics(self, metrics: Dict[str, float]):
        """Update `TrackingValues` metrics. mode can be train or val"""
        for key, value in metrics.items():
            # Lazily create a meter the first time a metric name appears.
            try:
                meter = self.metrics[key]
            except KeyError:
                meter = self.metrics[key] = AverageMeter(name=key)
            meter.update(value)

    def to_dict(self) -> dict:
        return asdict(self)

    def reset(self):
        """Values are Reset on start of each `on_*_epoch_start`"""
        self.loss.reset()
        for meter in self.metrics.values():
            meter.reset()
from typing import Dict, Union

import torch
import torchmetrics
from torchmetrics import Metric, MetricCollection

from gradsflow.utility.common import module_to_cls_index

_tm_classes = module_to_cls_index(torchmetrics, lower_key=False)

# Keep only names beginning with an ASCII uppercase letter (i.e. class names),
# then expose them under lowercase keys for case-insensitive lookup.
metrics_classes: Dict[str, Metric] = {k: v for k, v in _tm_classes.items() if "A" <= k[0] <= "Z"}
metrics_classes = {k.lower(): v for k, v in metrics_classes.items()}


class MetricsContainer:
    """Holds a `torchmetrics.MetricCollection` on a fixed device and accepts
    metrics either as instances or by (lowercase) name."""

    def __init__(self, device):
        self._device = device
        self._metrics: MetricCollection = MetricCollection([])

    @property
    def metrics(self):
        return self._metrics

    def compile_metrics(self, *metrics: Union[str, Metric]) -> None:
        """Initialize metrics collection and add provided `*metrics` to the container."""
        # Drop any previously compiled metrics before registering the new ones.
        if len(self._metrics) > 0:
            self._metrics = MetricCollection([])
        self.add_metrics(*metrics)

    def add_metrics(self, *metrics: Union[str, Metric]) -> None:
        """Append metrics (instances or registered names) to the collection."""
        for metric in metrics:
            if isinstance(metric, Metric):
                metric_obj = metric
            elif isinstance(metric, str):
                metric_cls = metrics_classes.get(metric)
                assert (
                    metric_cls is not None
                ), f"metrics {metric} is not available! Available metrics are {tuple(metrics_classes.keys())}"
                metric_obj = metric_cls()
            else:
                raise NotImplementedError(f"metrics not implemented for {metric}! Please see `torchmetrics`.")
            self._metrics.add_metrics(metric_obj)
        self._metrics.to(self._device)

    def _update(self, preds, target):
        """Iteratively update all the `torchmetrics` value"""
        self._metrics.update(preds, target)

    def compute(self):
        return self._metrics.compute()

    def calculate_metrics(self, preds, target) -> Dict[str, torch.Tensor]:
        """Iteratively update the compiled metrics and return the new computed values"""
        self._update(preds, target)
        return self.compute()

    def reset(self):
        """Reset the values of each of the compiled metrics"""
        self._metrics.reset()
import dataclasses
from typing import Union

from torch.utils.data import DataLoader, Dataset

from gradsflow.data.ray_dataset import RayDataset


@dataclasses.dataclass(init=False)
class Data:
    """Bundles a dataset together with the dataloader built from it."""

    dataloader: DataLoader
    dataset: Union[RayDataset, Dataset]


class BaseAutoDataset:
    """Common mutable state shared by auto-dataset implementations."""

    def __init__(self):
        # Arbitrary metadata about the dataset (e.g. task-specific info).
        self.meta = {}
        self.datamodule = None
        self.train_dataset = None
        self.val_dataset = None
        self.num_classes = None
        # Dataloaders and their lengths are populated lazily by subclasses.
        self._train_dataloader = None
        self._val_dataloader = None
        self._train_dataloader_length = None
        self._val_dataloader_length = None
"""Provide some common functionalities/utilities for Datasets"""

from typing import List

from torch.utils.data import Dataset, random_split


def random_split_dataset(data: Dataset, pct=0.9) -> List[Dataset]:
    """
    Randomly partition a dataset into two subsets.

    The first subset receives ``int(len(data) * pct)`` samples and the second
    receives the remainder.

    Args:
        data: pytorch Dataset object with `__len__` implementation.
        pct: fraction of samples assigned to the first split.
    """
    total = len(data)
    first = int(total * pct)
    return random_split(data, (first, total - first))
import logging
import os
from pathlib import Path
from typing import List, Optional, Tuple, Union

from torch.utils.data import DataLoader
from torchvision import transforms as T
from torchvision.datasets import ImageFolder

from gradsflow.data.base import Data
from gradsflow.data.ray_dataset import RayImageFolder

logger = logging.getLogger("data.image")


def get_augmentations(image_size: tuple = (224, 224), auto_augment_policy: bool = True):
    """Compose resize + to-tensor transforms, optionally inserting torchvision's
    `AutoAugment` policy between them."""
    steps = [T.Resize(image_size)]
    if auto_augment_policy:
        steps.append(T.AutoAugment())
    steps.append(T.ToTensor())
    return T.Compose(steps)


def image_dataset_from_directory(
    directory: Union[List[str], Path, str],
    transform=None,
    image_size=(224, 224),
    batch_size: int = 1,
    shuffle: bool = False,
    pin_memory: bool = True,
    num_workers: Optional[int] = None,
    ray_data: bool = False,
) -> Data:
    """
    Create a dataset and dataloader for an image-folder layout
    (one sub-folder per class).

    Args:
        directory: root folder (or folders) containing the images.
        transform: transforms applied to each image; pass ``True`` to use the
            defaults from `get_augmentations`.
        image_size: target size used when ``transform is True``.
        batch_size: samples per batch.
        shuffle: whether the dataloader shuffles samples.
        pin_memory: passed through to `DataLoader`.
        num_workers: dataloader worker count; falsy values default to ``os.cpu_count()``.
        ray_data: use `RayImageFolder` instead of torchvision's `ImageFolder`.

    Returns:
        A `Data` object holding both the dataset and its dataloader.
    """
    data = Data()
    workers = num_workers or os.cpu_count()
    if transform is True:
        transform = get_augmentations(image_size)
    dataset_cls = RayImageFolder if ray_data else ImageFolder
    data.dataset = dataset_cls(directory, transform=transform)
    logger.info("ds created")
    data.dataloader = DataLoader(
        data.dataset,
        batch_size=batch_size,
        pin_memory=pin_memory,
        shuffle=shuffle,
        num_workers=workers,
    )
    return data


def get_fake_data(
    image_size: Tuple[int, int], num_classes=10, batch_size=1, pin_memory=False, shuffle=True, num_workers=0
):
    """Build a 100-sample torchvision `FakeData` dataset + dataloader (useful for tests)."""
    from torchvision.datasets import FakeData

    data = Data()
    transform = get_augmentations(image_size=image_size)
    data.dataset = FakeData(size=100, image_size=[3, *image_size], num_classes=num_classes, transform=transform)
    data.dataloader = DataLoader(
        data.dataset,
        batch_size=batch_size,
        pin_memory=pin_memory,
        shuffle=shuffle,
        num_workers=num_workers,
    )

    return data
from typing import Dict, List, Tuple, Union

import torch

from gradsflow.utility import default_device


class DataMixin:
    """Mixin with helpers to pick inputs/targets out of a batch and move data to
    the compute device."""

    # Index (or key) into a batch where the model inputs live.
    INPUT_KEY = 0  # other common value - inputs, images, text
    # Index (or key) into a batch where the ground-truth targets live.
    OUTPUT_KEY = 1  # other common values - target, ground
    # Resolved once at class-definition time ("cuda" if available, else "cpu").
    device = default_device()

    def fetch_inputs(self, data: Union[List, Dict]):
        """Return the inputs portion of `data` (element at `INPUT_KEY`)."""
        return data[self.INPUT_KEY]

    def fetch_target(self, data: Union[List, Dict]):
        """Return the target portion of `data` (element at `OUTPUT_KEY`)."""
        return data[self.OUTPUT_KEY]

    @classmethod
    def send_to_device(cls, data: Union[List, Dict, Tuple, torch.Tensor, int, float]):
        """Recursively move `data` onto `cls.device`.

        Scalars and strings pass through untouched; tensors are moved with
        ``.to(device)``; lists/tuples are converted element-wise (note: tuples
        come back as lists); dicts are rebuilt with converted values.

        Raises:
            NotImplementedError: for any unsupported data type.
        """
        if isinstance(data, (int, float, str)):
            return data

        if isinstance(data, torch.Tensor):
            return data.to(cls.device)

        if isinstance(data, (list, tuple)):
            return list(map(cls.send_to_device, data))
        if isinstance(data, dict):
            return {k: cls.send_to_device(v) for k, v in data.items()}
        raise NotImplementedError(
            f"send_to_device is not implemented for data of type {type(data)}! Please raise an issue/pr"
        )

--------------------------------------------------------------------------------
/gradsflow/data/ray_dataset.py:
--------------------------------------------------------------------------------
"""Mimics torch.data.Dataset for ray.data integration"""

# Copyright (c) 2021 GradsFlow. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import io
from typing import Callable, List, Union

import ray
from PIL import Image
from torch.utils.data import IterableDataset


class RayDataset(IterableDataset):
    """Iterable dataset backed by ``ray.data.read_binary_files``.

    Rows stream through ``iter_rows``; each row pairs a file path with its raw
    bytes (``include_paths=True``).
    """

    def __init__(self, path: Union[List[str], str]):
        # path: file/folder path(s) readable by ray.data (local or remote filesystem).
        self.path = path
        self.ds = ray.data.read_binary_files(path, include_paths=True)

    def __iter__(self):
        return self.ds.iter_rows()

    def __len__(self):
        # NOTE(review): delegates to input_files(), which queries the underlying
        # ray dataset - presumably cheap/cached, but confirm before calling in a hot loop.
        return len(self.input_files)

    def map_(self, func, *args, **kwargs) -> None:
        """Inplace Map for ray.data
        Time complexity: O(dataset size / parallelism)

        See https://docs.ray.io/en/latest/data/dataset.html#transforming-datasets"""
        self.ds = self.ds.map(func, *args, **kwargs)

    def map_batch_(self, func, batch_size: int = 2, **kwargs) -> None:
        """Inplace Map for ray.data
        Time complexity: O(dataset size / parallelism)
        See https://docs.ray.io/en/latest/data/dataset.html#transforming-datasets"""
        self.ds = self.ds.map_batches(func, batch_size=batch_size, **kwargs)

    @property
    def input_files(self):
        # Source file paths contained in the ray dataset.
        return self.ds.input_files()


class RayImageFolder(RayDataset):
    """Read image datasets
    ```
    root/dog/xxx.png
    root/dog/xxy.png
    root/dog/[...]/xxz.png

    root/cat/123.png
    root/cat/nsdf3.png
    root/cat/[...]/asd932_.png
    ```
    """

    def __init__(self, path, transform: Union[Callable, None] = None):
        super().__init__(path)
        # transform: optional callable applied to each PIL image in __iter__.
        self.transform = transform

    @staticmethod
    def file_to_class(files: Union[str, List]):
        """Derive the class label (parent directory name) from one path or a list of paths."""
        # NOTE(review): splits on "/" - assumes POSIX-style paths; confirm behavior on Windows.
        file_list = []
        if isinstance(files, (tuple, list)):
            for file in files:
                file_list.append(file.split("/")[-2])
            return file_list
        return files.split("/")[-2]

    def find_classes(self) -> List[str]:
        """Return the sorted unique class names found in the folder tree."""
        files = self.input_files
        return sorted(list(set(map(self.file_to_class, files))))

    def __iter__(self):
        # Yields (image, class_name) pairs; image is a PIL.Image unless
        # `transform` maps it to something else (e.g. a tensor).
        for data in self.ds.iter_rows():
            x = Image.open(io.BytesIO(data[1]))
            target = self.file_to_class(data[0])
            if self.transform:
                x = self.transform(x)
            yield x, target

--------------------------------------------------------------------------------
/gradsflow/models/__init__.py:
--------------------------------------------------------------------------------
# Copyright (c) 2021 GradsFlow. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .model import Model
from .utils import available_losses, available_metrics

--------------------------------------------------------------------------------
/gradsflow/models/constants.py:
--------------------------------------------------------------------------------
# Copyright (c) 2021 GradsFlow. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Registry key under which the learner (model) object is stored.
LEARNER = "learner"

--------------------------------------------------------------------------------
/gradsflow/models/exceptions.py:
--------------------------------------------------------------------------------
# Copyright (c) 2021 GradsFlow. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging

logger = logging.getLogger(__name__)


class EpochCancel(Exception):
    """Control-flow exception raised to stop the current training epoch early."""

    def __init__(self):
        super().__init__()
        # Logged at construction so the cancellation is visible even if the
        # exception is swallowed upstream.
        logger.info("epoch cancelled")


class FitCancel(Exception):
    """Control-flow exception raised to abort `model.fit` entirely."""

    def __init__(self):
        super().__init__()
        logger.info("model.fit cancelled")

--------------------------------------------------------------------------------
/gradsflow/models/tracker.py:
--------------------------------------------------------------------------------
# Copyright (c) 2021 GradsFlow. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from dataclasses import dataclass
from typing import Dict, List, Optional

from rich import box
from rich.table import Table

from gradsflow.core.base import TrackingValues
from gradsflow.utility.common import GDict, to_item

logger = logging.getLogger(__name__)


@dataclass(init=False)
class BaseTracker:
    # NOTE(review): with init=False these TrackingValues defaults are class-level
    # instances shared by every subclass instance until `reset()` replaces them
    # with fresh ones - confirm this sharing is intended.
    global_step: int = 0  # Global training steps
    max_epochs: int = 0
    current_epoch: int = 0  # current train current_epoch
    steps_per_epoch: Optional[int] = None
    train: TrackingValues = TrackingValues()
    val: TrackingValues = TrackingValues()


class Tracker(BaseTracker):
    """
    Tracks loss, accuracy and model weights during model.fit()
    """

    def __init__(self):
        # metrics: per-epoch running averages keyed by metric name.
        self.train.metrics = GDict()
        self.val.metrics = GDict()
        # logs: flat per-step history of {"current_epoch", "<mode>/<key>": value}.
        self.logs: List[Dict] = []

    def __getitem__(self, key: str):  # skipcq: PYL-R1705
        """
        1. key= `train | val` then return respective `TrackingValues` object
        2. key=`metrics` then return a dictionary of metrics
        3. key=`loss` then return a dictionary of losses
        Args:
            key: train, val, metrics or loss

        Returns:
            `TrackingValues` or a Dictionary

        Raises:
            KeyError: for any other key.
        """
        if key in ("train", "val"):
            return self.mode(key)
        elif key == "metrics":
            return {"train": self.train_metrics, "val": self.val_metrics}
        elif key == "loss":
            return {"train": self.train_loss, "val": self.val_loss}

        raise KeyError(f"key {key} is not implemented!")

    @property
    def train_loss(self):
        # Average training loss over the current epoch.
        return self.train.loss.avg

    @property
    def val_loss(self):
        # Average validation loss over the current epoch.
        return self.val.loss.avg

    @property
    def train_metrics(self) -> GDict:
        return self.train.metrics

    @property
    def val_metrics(self) -> GDict:
        return self.val.metrics

    def mode(self, mode) -> TrackingValues:
        """Return the TrackingValues bucket for `mode` ("train" or "val")."""
        if mode == "train":
            return self.train
        if mode == "val":
            return self.val

        raise KeyError(f"mode {mode} is not implemented!")

    def _append_logs(self, key, value):
        """Append Key Value pairs to `Tracker.logs`"""
        # TODO: accept a list of keys and values as well.
        epoch = self.current_epoch
        data = {"current_epoch": epoch, key: to_item(value)}
        self.logs.append(data)

    def track_loss(self, loss: float, mode: str):
        """Tracks loss by adding to `Tracker.logs` and maintaining average loss in a single Epoch with `TrackingValues`.
        Update loss with `TrackingValues.update_loss(loss)` which is called with `TrainEvalCallback` at `*_step_end`.
        Args:
            loss: Step Loss
            mode: can be train | val
        """
        loss = to_item(loss)
        value_tracker = self.mode(mode)
        value_tracker.update_loss(loss)
        key = mode + "/" + "loss"
        self._append_logs(key, loss)

    def track_metrics(self, metric: Dict[str, float], mode: str):
        """Tracks metrics by adding to `Tracker.logs` and maintaining average metric in a single Epoch with `TrackingValues`.
        Update metrics with `TrackingValues.update_metrics(metrics)` which is called with `TrainEvalCallback` at `*_step_end`.
        Args:
            metric: Step metric
            mode: can be train | val
        """
        value_tracker = self.mode(mode)

        # Track values that averages with epoch
        value_tracker.update_metrics(metric)

        # _append_logs value for each step in a dict
        for k, v in metric.items():
            k = mode + "/" + k
            self._append_logs(k, v)

    def create_table(self) -> Table:
        """Build a one-row rich Table of the current epoch's loss/metric averages."""
        headings = ["i", "train/loss"]
        row = [self.current_epoch, self.train_loss]

        # Validation columns appear only once a val loss has actually been computed.
        if self.val.loss.computed:
            headings.append("val/loss")
            row.append(self.val_loss)

        for metric_name, value in self.train_metrics.items():
            headings.append("train/" + metric_name)
            row.append(value.avg)

        for metric_name, value in self.val_metrics.items():
            headings.append("val/" + metric_name)
            row.append(value.avg)

        # Render floats with 3 decimals; everything else via str().
        row = list(map(lambda x: f"{x: .3f}" if isinstance(x, float) else str(x), row))
        table = Table(*headings, expand=True, box=box.SIMPLE)
        table.add_row(*row)
        return table

    def reset(self):
        """Resets epochs, logs and train & val `TrackingValues`."""
        logger.debug("Reset Tracker")
        self.max_epochs = 0
        self.current_epoch = 0
        self.steps_per_epoch = None
        self.train = TrackingValues()
        self.val = TrackingValues()
        self.logs = []

--------------------------------------------------------------------------------
/gradsflow/models/utils.py:
--------------------------------------------------------------------------------
# Copyright (c) 2021 GradsFlow. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Callable, Dict, List, Optional, Union

import numpy as np
import torch
import torchmetrics
from torch import nn
from torchmetrics import Metric

from gradsflow.utility.common import filter_list, module_to_cls_index

# Scalar-like values accepted by loss/metric helpers.
SCALAR = Union[torch.Tensor, np.float64, float, int]
_nn_classes = module_to_cls_index(nn)
_tm_classes = module_to_cls_index(torchmetrics, lower_key=False)

# Every torch.nn class whose (lower-cased) name contains "loss".
losses: Dict[str, Callable] = {k: v for k, v in _nn_classes.items() if "loss" in k}

# Keep only names starting with an uppercase ASCII letter (65..90 == 'A'..'Z'),
# i.e. class names rather than module-level helpers, then lower-case the keys
# so lookups are case-insensitive.
metrics: Dict[str, Metric] = {k: v for k, v in _tm_classes.items() if 65 <= ord(k[0]) <= 90}
metrics = {k.lower(): v for k, v in metrics.items()}


def available_losses(pattern: Optional[str] = None) -> List[str]:
    """Get available loss functions
    ```python
    >> available_losses()
    >> # crossentropy, binarycrossentropy, mae, ...

    # Filter available losses with regex pattern
    >> available_losses("m.e")
    >> # ["mae", "mse"]
    ```
    """
    loss_keys = list(losses.keys())
    return filter_list(loss_keys, pattern)


def available_metrics(pattern: Optional[str] = None) -> List[str]:
    """Get available Metrics
    ```python
    >> available_metrics()
    >> # accuracy, F1, RMSE, ...

    # Filter available metrics with regex pattern
    >> available_metrics("acc.*")
    >> # ["accuracy"]
    ```
    """
    metric_keys = list(metrics.keys())
    return filter_list(metric_keys, pattern)

--------------------------------------------------------------------------------
/gradsflow/tuner/__init__.py:
--------------------------------------------------------------------------------
# Copyright (c) 2021 GradsFlow. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from .automodel import AutoModelV2
from .tuner import Tuner

--------------------------------------------------------------------------------
/gradsflow/tuner/tuner.py:
--------------------------------------------------------------------------------
# Copyright (c) 2021 GradsFlow. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Dict, List, Sequence, Union

import ray
from ray import tune
from ray.tune.sample import Domain


class ComplexObject:
    """Class to store and retrieve large size objects and convert it to `ray.tune.Domain`.
    Objects will be stored with `ray.put` and retrieved with `ray.get`.
    """

    def __init__(self):
        # Each element is a ray ObjectRef produced by `ray.put`.
        self.values: List[Any] = []

    def __len__(self):
        return len(self.values)

    def append(self, value: Any):
        """Store `value` in the ray object store and keep its reference."""
        self.values.append(ray.put(value))

    def get_complex_object(self, idx):
        """Retrieve (materialize) the object stored at position `idx`."""
        return ray.get(self.values[idx])

    def to_choice(self) -> Domain:
        """converts to ray.tune Domain"""
        # Sample over indices only; actual objects are resolved later via
        # `get_complex_object`.
        indices = tuple(range(len(self.values)))
        return tune.choice(indices)


class Tuner:
    """
    Supports `ray.tune` methods and provide an easy way for tuning large size complex objects like Models.
    """

    def __init__(self):
        # _search_space: name -> tune Domain (or plain scalar via `scalar()`).
        self._search_space: Dict[str, Domain] = {}
        # _complex_objects: name -> ComplexObject backing an index-choice Domain.
        self._complex_objects: Dict[str, ComplexObject] = {}

    def update_search_space(self, k: str, v: Union[Domain, ComplexObject]):
        """
        Update search space with value `ray.tune(...)` or `gradsflow.tuner.ComplexObject`
        Args:
            k: hyperparameter name
            v: hyperparameter value - `ray.tune(...)` or `gradsflow.tuner.ComplexObject` object.
        """
        if isinstance(v, Domain):
            self._search_space[k] = v
        elif isinstance(v, ComplexObject):
            # NOTE(review): this assert is redundant - the elif above already
            # guarantees `v` is a ComplexObject.
            assert isinstance(v, ComplexObject), f"Selected is_complex but object is of type {type(v)}"
            self._search_space[k] = v.to_choice()
            self._complex_objects[k] = v
        else:
            # NOTE(review): raising UserWarning (a Warning subclass) as an
            # exception is unconventional; a TypeError may have been intended -
            # confirm callers before changing.
            raise UserWarning(f"Tuner Search space doesn't support {type(v)}")
        assert isinstance(
            self._search_space[k], Domain
        ), "search space should only contain object of type `tune.Domain`"

    def suggest_complex(self, key: str, *values: Sequence) -> ComplexObject:
        """
        Use this method when you want to search models or any large object.
        It will also update search space with the provided key.
        Args:
            key: hyperparameter name
            *values: values for the hyperparameter

        Returns:
            `ComplexObject`
        """
        complex_object = ComplexObject()
        for _, v in enumerate(values):
            complex_object.append(v)

        object_choice = complex_object.to_choice()
        self._search_space[key] = object_choice
        self._complex_objects[key] = complex_object
        return complex_object

    def scalar(self, key: str, value):
        """This sets a scalar value and will not be used for tuning"""
        self._search_space[key] = value

    def choice(self, key: str, *values) -> Domain:
        """Tune for categorical values"""
        x = tune.choice(values)
        self._search_space[key] = x
        return x

    def loguniform(self, key: str, lower: float, upper: float, base: float = 10) -> Domain:
        """Register a log-uniform sampling distribution for `key`."""
        x = tune.loguniform(lower, upper, base)
        self._search_space[key] = x
        return x

    def union(self, tuner: "Tuner") -> "Tuner":
        """Inplace Merge of two Tuners"""
        self._search_space.update(tuner._search_space)
        self._complex_objects.update(tuner._complex_objects)
        return self

    @staticmethod
    def merge(*tuners: "Tuner") -> "Tuner":
        """Merge any number of Tuners into a fresh Tuner (later tuners win on key clashes)."""
        final_tuner = Tuner()
        for t in tuners:
            final_tuner.union(t)
        return final_tuner

    @property
    def value(self) -> Dict[str, Domain]:
        # The raw search-space dict handed to ray.tune.
        return self._search_space

    def get_complex_object(self, key: str, idx: int):
        """Get registered complex object value from key at given index"""
        return self._complex_objects[key].get_complex_object(idx)

    def get(self, key: str) -> Union[Domain, ComplexObject]:
        """Return the ComplexObject (preferred) or Domain registered under `key`."""
        if key in self._complex_objects:
            return self._complex_objects[key]
        if key in self._search_space:
            return self._search_space[key]
        raise KeyError(f"key={key} is not available in tuner! Available={tuple(self._search_space.keys())}")

--------------------------------------------------------------------------------
/gradsflow/utility/__init__.py:
--------------------------------------------------------------------------------
# Copyright (c) 2021 GradsFlow. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .common import AverageMeter, default_device, listify
from .data import download

--------------------------------------------------------------------------------
/gradsflow/utility/common.py:
--------------------------------------------------------------------------------
# Copyright (c) 2021 GradsFlow. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import dataclasses
import inspect
import logging
import os
import re
import sys
import warnings
from glob import glob
from pathlib import Path
from typing import Any, Dict, List, Optional, Union

import numpy as np
import torch


def get_file_extension(path: str) -> str:
    """Returns extension of the file"""
    return os.path.basename(path).split(".")[-1]


def get_files(folder: str):
    """Fetch every file from given folder recursively."""
    folder = str(Path(folder) / "**" / "*")
    return glob(folder, recursive=True)


def default_device():
    """Return "cuda" when a GPU is visible to torch, otherwise "cpu"."""
    return "cuda" if torch.cuda.is_available() else "cpu"


def module_to_cls_index(module, lower_key: bool = True) -> dict:
    """Fetch classes from module and create a Dictionary with key as class name and value as Class"""
    class_members = inspect.getmembers(sys.modules[module.__name__], inspect.isclass)
    mapping = {}
    for k, v in class_members:
        if lower_key:
            k = k.lower()
        mapping[k] = v

    return mapping


def listify(item: Any) -> List:
    """Convert any scalar value into list.

    ``None`` maps to ``[]``; lists pass through unchanged; tuples/sets are
    converted with ``list()``; scalars - including falsy ones such as ``0``,
    ``0.0`` and ``""`` - are wrapped in a one-element list; any other iterable
    is materialized, and non-iterables are wrapped.
    """
    # Bug fix: `if not item` also matched valid falsy scalars (0, 0.0, ""),
    # silently returning []. Only None means "nothing".
    if item is None:
        return []
    if isinstance(item, list):
        return item
    if isinstance(item, (tuple, set)):
        return list(item)
    if isinstance(item, (int, float, str)):
        return [item]
    try:
        return list(item)
    except TypeError:
        return [item]


# ref: https://github.com/rwightman/pytorch-image-models/blob/b544ad4d3fcd02057ab9f43b118290f2a089566f/timm/utils/metrics.py#L7
@dataclasses.dataclass(init=False)
class AverageMeter:
    """Computes and stores the average and current value.
    `val` is the running value, `avg` is the average value over an epoch.
    """

    avg: Optional[float] = 0

    def __init__(self, name=None):
        self.name = name
        # computed flips to True on the first update; used to detect "never updated".
        self.computed = False
        self.val: Optional[float] = None
        self.sum: Optional[float] = None
        self.count: Optional[int] = None
        self.reset()

    def reset(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        """Updates the average meter value with new data. It also converts `torch.Tensor` to primitive datatype.

        Args:
            val: new observation (tensor or scalar).
            n: weight/number of samples this observation represents.
        """
        self.computed = True
        val = to_item(val)
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count


def to_item(data: Any) -> Union[int, float, str, np.ndarray, Dict]:
    """
    Converts torch.Tensor into cpu numpy format, recursing through lists,
    tuples and dict values. Non-tensor leaves are returned unchanged.

    Args:
        data: torch.Tensor contained in any Iterable or Dictionary.
    """

    if isinstance(data, (int, float, str)):
        return data
    if isinstance(data, (list, tuple)):
        return type(data)(map(to_item, data))
    if isinstance(data, dict):
        return {k: to_item(v) for k, v in data.items()}

    if torch.is_tensor(data):
        if data.requires_grad:
            data = data.detach()
        # Bug fix: return here - previously control fell through and the
        # "didn't convert" message was logged even after a successful conversion.
        return data.cpu().numpy()

    logging.info("to_item didn't convert any value.")
    return data


def filter_list(arr: List[str], pattern: Optional[str] = None) -> List[str]:
    """Filter a list of strings with given pattern
    ```python
    >> arr = ['crossentropy', 'binarycrossentropy', 'softmax', 'mae',]
    >> filter_list(arr, ".*entropy*")
    >> # ["crossentropy", "binarycrossentropy"]
    ```
    """
    if pattern is None:
        return arr

    p = re.compile(pattern)
    return [s for s in arr if p.match(s)]


class GDict(dict):
    """dict subclass whose `to_dict` converts dataclass values via `dataclasses.asdict`."""

    def to_dict(self):
        clone = self.copy()
        for k in clone.keys():
            value = clone[k]
            try:
                clone[k] = dataclasses.asdict(value)
            except TypeError:
                # Not a dataclass instance - keep the value as-is.
                continue
        return clone

# ------------------------------------------------------------------------------
# /gradsflow/utility/data.py:
# ------------------------------------------------------------------------------
# Copyright (c) 2021 GradsFlow. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | from smart_open import smart_open 16 | 17 | 18 | def download(path): 19 | """Read any filesystem or cloud file""" 20 | with smart_open(path) as fr: 21 | return fr.read() 22 | -------------------------------------------------------------------------------- /gradsflow/utility/imports.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021 GradsFlow. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | import functools 15 | import importlib 16 | from typing import Optional 17 | 18 | 19 | def is_installed(module_name: str) -> bool: 20 | try: 21 | return importlib.util.find_spec(module_name) is not None 22 | except AttributeError: 23 | return False 24 | 25 | 26 | def requires(package_name: str, err_msg: Optional[str] = None): 27 | def inner_fn(func): 28 | @functools.wraps(func) 29 | def wrapper(*args, **kwargs): 30 | msg = err_msg or f"{package_name} Module must be installed to use!" 
31 | if not is_installed(package_name): 32 | raise ModuleNotFoundError(msg) 33 | return func(*args, **kwargs) 34 | 35 | return wrapper 36 | 37 | return inner_fn 38 | -------------------------------------------------------------------------------- /mkdocs.yml: -------------------------------------------------------------------------------- 1 | site_name: GradsFlow 2 | site_description: "An open-source AutoML Library based on PyTorch" 3 | site_author: Aniket Maurya 4 | copyright: 'Copyright © 2023 GradsFlow' 5 | 6 | #banner_url: https://ik.imagekit.io/gradsflow/logo/v2/gf-logo-geadsflow-orange-white-bg_y5v0LNPvdeM.png 7 | repo_url: https://github.com/gradsflow/gradsflow/ 8 | repo_name: gradsflow/gradsflow 9 | 10 | theme: 11 | name: material 12 | custom_dir: docs/overrides 13 | palette: 14 | - scheme: default 15 | primary: deep orange 16 | accent: indigo 17 | toggle: 18 | icon: material/weather-sunny 19 | name: Switch to dark mode 20 | 21 | - scheme: slate 22 | primary: deep orange 23 | accent: indigo 24 | toggle: 25 | icon: material/weather-night 26 | name: Switch to light mode 27 | 28 | logo: https://ik.imagekit.io/gradsflow/logo/v2/gf-logo-gflow-black_9ZE8jOuKXTm.svg?updatedAt=1633488021249 29 | favicon: https://ik.imagekit.io/gradsflow/logo/v2/gf-logo-gflow-white_vCxfpINvg.svg 30 | features: 31 | - search.suggest 32 | - search.highlight 33 | 34 | # Necessary for search to work properly 35 | include_search_page: false 36 | search_index_only: true 37 | 38 | markdown_extensions: 39 | - meta 40 | - pymdownx.highlight 41 | - pymdownx.superfences 42 | - pymdownx.details 43 | - pymdownx.superfences 44 | - admonition 45 | - pymdownx.emoji: 46 | emoji_index: !!python/name:materialx.emoji.twemoji 47 | emoji_generator: !!python/name:materialx.emoji.to_svg 48 | - toc: 49 | permalink: true 50 | 51 | plugins: 52 | - git-revision-date-localized 53 | - search 54 | - autorefs 55 | - mkdocs-jupyter 56 | - mkdocstrings: 57 | default_handler: python 58 | handlers: 59 | python: 60 | 
rendering: 61 | show_source: false 62 | 63 | extra: 64 | homepage: https://docs.gradsflow.com 65 | 66 | nav: 67 | - Intro: 'index.md' 68 | - Examples: 69 | - Auto Image Classification: 'examples/nbs/01-ImageClassification.ipynb' 70 | - Auto Text Classification: 'examples/nbs/02-TextClassification.ipynb' 71 | - Auto Summarization: 'examples/nbs/03-TextSummarization.ipynb' 72 | - Remote Dataset Loading: 'examples/nbs/04-RayDataset.ipynb' 73 | - Model Training: 'examples/nbs/05-model_fit.ipynb' 74 | - AutoModel - HyperParameter Search: 'examples/nbs/06-AutoModel_fit.ipynb' 75 | - Pix2Pix GAN Code Explanation: 'examples/nbs/Pix2Pix_explained_with_code.ipynb' 76 | - 🤗 HuggingFace Training Example: 'examples/nbs/2021-10-3-huggingface-training.ipynb' 77 | - API References: 78 | - Model: 79 | - gradsflow/models/base.md 80 | - gradsflow/models/model.md 81 | - gradsflow/models/tracker.md 82 | - gradsflow/models/utils.md 83 | - Tuner: gradsflow/tuner 84 | - AutoTasks: 85 | - gradsflow/autotasks/autotasks.md 86 | - gradsflow/autotasks/engine.md 87 | - Data: gradsflow/data 88 | - Callbacks: gradsflow/callbacks 89 | - Core: gradsflow/core 90 | - utility: gradsflow/utility.md 91 | - Release Notes: 'CHANGELOG.md' 92 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [build-system] 2 | requires = [ 3 | "setuptools >= 40.9.0", 4 | "wheel", 5 | ] 6 | build-backend = "setuptools.build_meta" 7 | 8 | 9 | [tool.isort] 10 | profile = "black" 11 | 12 | [tool.black] 13 | line_length = 120 14 | 15 | 16 | [tool.pytest.ini_options] 17 | norecursedirs = ["tests/autotasks", "tests/tuner"] 18 | -------------------------------------------------------------------------------- /setup.cfg: -------------------------------------------------------------------------------- 1 | [metadata] 2 | name = gradsflow 3 | version = attr: gradsflow.__version__ 4 | author = Aniket
Maurya 5 | author_email = aniket@gradsflow.com 6 | description = An open-source AutoML Library based on PyTorch 7 | long_description = file: README.md, LICENSE 8 | long_description_content_type = text/markdown 9 | url = https://github.com/gradsflow/gradsflow 10 | project_urls = 11 | Bug Tracker = https://github.com/gradsflow/gradsflow/issues 12 | Documentation = https://docs.gradsflow.com 13 | Source Code = https://github.com/gradsflow/gradsflow 14 | 15 | classifiers = 16 | ; How mature is this project? Common values are 17 | ; 3 - Alpha, 4 - Beta, 5 - Production/Stable 18 | Development Status :: 4 - Beta 19 | Intended Audience :: Developers 20 | Programming Language :: Python :: 3 21 | License :: OSI Approved :: Apache Software License 22 | Operating System :: OS Independent 23 | 24 | keywords = AutoML, Pytorch, Deep Learning 25 | 26 | [options] 27 | packages = find: 28 | python_requires = >=3.8 29 | install_requires = 30 | torch >=1.13.1 31 | torchvision 32 | ray[default,tune] >=2.2.0 33 | timm>=0.6.12 34 | rich>=13.3.1 35 | smart_open >=5.1,<=5.2.1 36 | torchmetrics >=0.11.1 37 | lightning >=1.9.2 38 | 39 | [options.extras_require] 40 | dev = codecarbon >=1.2.0; wandb; tensorboard 41 | test = pytest; coverage; pytest-sugar; pytest-randomly 42 | 43 | [options.packages.find] #optional 44 | exclude=tests, docs, examples 45 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021 GradsFlow. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | from setuptools import setup 15 | 16 | if __name__ == "__main__": 17 | setup() 18 | -------------------------------------------------------------------------------- /sonar-project.properties: -------------------------------------------------------------------------------- 1 | sonar.projectKey=gradsflow_gradsflow 2 | sonar.organization=gflow 3 | 4 | # This is the name and version displayed in the SonarCloud UI. 5 | #sonar.projectName=gradsflow 6 | #sonar.projectVersion=1.0 7 | 8 | # Path is relative to the sonar-project.properties file. Replace "\" by "/" on Windows. 9 | sonar.sources=gradsflow 10 | 11 | # Encoding of the source code. Default is default system encoding 12 | sonar.sourceEncoding=UTF-8 13 | 14 | #---- Language properties ---- 15 | sonar.language=py 16 | sonar.python.version=3 17 | sonar.python.coverage.reportPaths=coverage.xml 18 | sonar.tests=tests 19 | sonar.exclusions=tests/** 20 | -------------------------------------------------------------------------------- /tests/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021 GradsFlow. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | import os 16 | import sys 17 | import warnings 18 | 19 | try: 20 | import gradsflow 21 | except ImportError: 22 | sys.path.append("./") 23 | 24 | os.environ["GF_CI"] = "true" 25 | 26 | warnings.filterwarnings("ignore") 27 | -------------------------------------------------------------------------------- /tests/__main__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021 GradsFlow. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | import urllib.request 16 | import zipfile 17 | from pathlib import Path 18 | 19 | cwd = Path.cwd() 20 | (Path.cwd() / "data").mkdir(exist_ok=True) 21 | 22 | urllib.request.urlretrieve( 23 | "https://github.com/gradsflow/test-data/archive/refs/tags/cat-dog-v0.zip", 24 | f"{cwd}/data/test-cat-dog-v0.zip", 25 | ) 26 | 27 | with zipfile.ZipFile(f"{cwd}/data/test-cat-dog-v0.zip", "r") as zip_ref: 28 | zip_ref.extractall(f"{cwd}/data/") 29 | -------------------------------------------------------------------------------- /tests/autotasks/test_autotasks.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | os.environ["GF_CI"] = "true" 4 | 5 | import warnings 6 | from pathlib import Path 7 | 8 | import ray 9 | from flash.image import ImageClassificationData 10 | 11 | from gradsflow.autotasks import autotask 12 | from gradsflow.data.image import image_dataset_from_directory 13 | from gradsflow.models.model import Model 14 | 15 | warnings.filterwarnings("ignore") 16 | 17 | ray.init(local_mode=True, ignore_reinit_error=True) 18 | 19 | data_dir = Path.cwd() 20 | datamodule = ImageClassificationData.from_folders( 21 | train_folder=f"{data_dir}/data/hymenoptera_data/train/", 22 | val_folder=f"{data_dir}/data/hymenoptera_data/val/", 23 | batch_size=1, 24 | ) 25 | data_dir = Path.cwd() / "data" 26 | 27 | train_data = image_dataset_from_directory(f"{data_dir}/hymenoptera_data/train/", transform=True) 28 | train_dl = train_data.dataloader 29 | 30 | val_data = image_dataset_from_directory(f"{data_dir}/hymenoptera_data/val/", transform=True) 31 | val_dl = val_data.dataloader 32 | 33 | 34 | def test_build_model(): 35 | model = autotask( 36 | task="image-classification", 37 | train_dataloader=train_dl, 38 | val_dataloader=val_dl, 39 | num_classes=2, 40 | timeout=5, 41 | suggested_backbones="ssl_resnet18", 42 | n_trials=1, 43 | ) 44 | kwargs = {"backbone": "ssl_resnet18", "optimizer": "adam", "lr": 1e-1} 45 | model.model = 
model.build_model(kwargs) 46 | assert isinstance(model.model, Model) 47 | 48 | 49 | def test_hp_tune(): 50 | model = autotask( 51 | task="image-classification", 52 | train_dataloader=train_dl, 53 | val_dataloader=val_dl, 54 | num_classes=2, 55 | max_epochs=1, 56 | max_steps=2, 57 | timeout=10, 58 | suggested_backbones="ssl_resnet18", 59 | optimization_metric="val_loss", 60 | n_trials=1, 61 | ) 62 | model.hp_tune(name="pytest-experiment", mode="max", gpu=0) 63 | -------------------------------------------------------------------------------- /tests/autotasks/test_autotrainer.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021 GradsFlow. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | from unittest.mock import MagicMock, Mock, patch 15 | 16 | import pytest 17 | 18 | from gradsflow.autotasks.engine.backend import Backend 19 | 20 | trainer_config = {"show_progress": False} 21 | 22 | 23 | @patch("gradsflow.autotasks.engine.backend.FlashTrainer") 24 | @patch("gradsflow.autotasks.engine.backend.PLTrainer") 25 | def test_optimization_objective(mock_pl_trainer: Mock, mock_fl_trainer: Mock): 26 | dm = MagicMock() 27 | model_builder = MagicMock() 28 | 29 | # backend_type is pl 30 | autotrainer = Backend(dm, model_builder, optimization_metric="val_accuracy", backend="pl") 31 | autotrainer.optimization_objective({}, trainer_config) 32 | assert mock_pl_trainer.called or mock_fl_trainer.called 33 | 34 | # wrong backend_type is passed 35 | with pytest.raises(NotImplementedError): 36 | autotrainer = Backend( 37 | dm, 38 | model_builder, 39 | optimization_metric="val_accuracy", 40 | backend="error", 41 | ) 42 | autotrainer.optimization_objective({}, trainer_config) 43 | -------------------------------------------------------------------------------- /tests/autotasks/test_core_automodel.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021 GradsFlow. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | from pathlib import Path 16 | from unittest.mock import MagicMock, patch 17 | 18 | import pytest 19 | import torch 20 | from flash.image import ImageClassificationData 21 | 22 | from gradsflow.autotasks.engine.automodel import AutoModel 23 | from gradsflow.autotasks.engine.backend import BackendType 24 | 25 | cwd = Path.cwd() 26 | datamodule = ImageClassificationData.from_folders( 27 | train_folder=f"{cwd}/data/hymenoptera_data/train/", val_folder=f"{cwd}/data/hymenoptera_data/val/", batch_size=1 28 | ) 29 | 30 | 31 | @patch.multiple(AutoModel, __abstractmethods__=set()) 32 | def test_auto_model(): 33 | assert AutoModel(datamodule) 34 | 35 | 36 | @patch.multiple(AutoModel, __abstractmethods__=set()) 37 | def test_build_model(): 38 | model = AutoModel(datamodule) 39 | with pytest.raises(NotImplementedError): 40 | model.build_model({"lr": 1}) 41 | 42 | 43 | @patch.multiple(AutoModel, __abstractmethods__=set()) 44 | def test_create_search_space(): 45 | model = AutoModel(datamodule) 46 | with pytest.raises(NotImplementedError): 47 | model._create_search_space() 48 | 49 | 50 | @patch.multiple(AutoModel, __abstractmethods__=set()) 51 | @patch("gradsflow.autotasks.engine.backend.FlashTrainer") 52 | @patch("gradsflow.autotasks.engine.backend.PLTrainer") 53 | def test_objective(mock_pl_trainer, mock_fl_trainer): 54 | optimization_metric = "val_accuracy" 55 | model = AutoModel( 56 | datamodule, 57 | optimization_metric=optimization_metric, 58 | backend=BackendType.lightning.value, 59 | ) 60 | 61 | model.backend.model_builder = MagicMock() 62 | 63 | mock_pl_trainer.callback_metrics = mock_fl_trainer.callback_metrics = {optimization_metric: torch.as_tensor([1])} 64 | 65 | model.backend.optimization_objective({}, {}) 66 | 67 | 68 | @patch.multiple(AutoModel, __abstractmethods__=set()) 69 | @patch("gradsflow.autotasks.engine.automodel.tune.run") 70 | def test_hp_tune( 71 | mock_tune, 72 | ): 73 | automodel = AutoModel(datamodule) 74 | automodel._create_search_space = 
MagicMock() 75 | automodel.build_model = MagicMock() 76 | automodel.hp_tune(gpu=0, cpu=2) 77 | 78 | mock_tune.assert_called() 79 | -------------------------------------------------------------------------------- /tests/autotasks/test_image.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021 GradsFlow. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | import os 16 | 17 | os.environ["GF_CI"] = "true" 18 | 19 | import warnings 20 | from pathlib import Path 21 | 22 | import pytest 23 | import torch 24 | 25 | from gradsflow.autotasks import AutoImageClassifier 26 | from gradsflow.data.image import image_dataset_from_directory 27 | from gradsflow.models.model import Model 28 | 29 | warnings.filterwarnings("ignore") 30 | 31 | data_dir = Path.cwd() / "data" 32 | 33 | train_data = image_dataset_from_directory(f"{data_dir}/hymenoptera_data/train/", transform=True) 34 | train_dl = train_data.dataloader 35 | 36 | val_data = image_dataset_from_directory(f"{data_dir}/hymenoptera_data/val/", transform=True) 37 | val_dl = val_data.dataloader 38 | 39 | 40 | def test_forward(): 41 | model = AutoImageClassifier( 42 | train_dataloader=train_dl, 43 | val_dataloader=val_dl, 44 | num_classes=2, 45 | ) 46 | 47 | with pytest.raises(UserWarning): 48 | model.forward(torch.rand(1, 3, 8, 8)) 49 | 50 | 51 | def test_build_model(): 52 | automodel = AutoImageClassifier( 53 | 
train_dataloader=train_dl, 54 | val_dataloader=val_dl, 55 | num_classes=2, 56 | max_epochs=1, 57 | max_steps=5, 58 | timeout=10, 59 | suggested_backbones="ssl_resnet18", 60 | n_trials=1, 61 | ) 62 | kwargs = {"backbone": "ssl_resnet18", "optimizer": "adam", "lr": 1e-1} 63 | automodel.model = automodel.build_model(kwargs) 64 | assert isinstance(automodel.model, Model) 65 | 66 | 67 | def test_hp_tune(): 68 | model = AutoImageClassifier( 69 | train_dataloader=train_dl, 70 | val_dataloader=val_dl, 71 | num_classes=2, 72 | max_epochs=1, 73 | max_steps=5, 74 | timeout=10, 75 | suggested_backbones="ssl_resnet18", 76 | optimization_metric="train_loss", 77 | n_trials=1, 78 | ) 79 | model.hp_tune(name="pytest-experiment", mode="min", gpu=0) 80 | -------------------------------------------------------------------------------- /tests/autotasks/test_summarization.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021 GradsFlow. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | import os 16 | 17 | os.environ["GF_CI"] = "true" 18 | 19 | from unittest.mock import MagicMock 20 | 21 | from gradsflow.autotasks import AutoSummarization 22 | 23 | 24 | def test_build_model(): 25 | datamodule = MagicMock() 26 | model = AutoSummarization( 27 | datamodule, 28 | max_epochs=1, 29 | timeout=5, 30 | suggested_backbones="sshleifer/distilbart-cnn-12-6", 31 | n_trials=1, 32 | ) 33 | model_confs = { 34 | "backbone": model._DEFAULT_BACKBONES[-1], 35 | "optimizer": "adam", 36 | "lr": 1e-3, 37 | } 38 | model.build_model(model_confs) 39 | -------------------------------------------------------------------------------- /tests/autotasks/test_text.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021 GradsFlow. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | import os 16 | 17 | os.environ["GF_CI"] = "true" 18 | 19 | from unittest.mock import MagicMock 20 | 21 | from gradsflow.autotasks import AutoTextClassifier 22 | 23 | 24 | def test_build_model(): 25 | suggested_conf = dict( 26 | optimizers=["adam"], 27 | lr=(5e-4, 1e-3), 28 | ) 29 | datamodule = MagicMock() 30 | datamodule.num_classes = 2 31 | model = AutoTextClassifier( 32 | datamodule, 33 | num_classes=datamodule.num_classes, 34 | suggested_backbones=["sgugger/tiny-distilbert-classification"], 35 | suggested_conf=suggested_conf, 36 | max_epochs=1, 37 | optimization_metric="val_loss", 38 | timeout=5, 39 | n_trials=1, 40 | ) 41 | 42 | model_confs = { 43 | "backbone": model._DEFAULT_BACKBONES[-1], 44 | "optimizer": "adam", 45 | "lr": 1e-3, 46 | } 47 | model.build_model(model_confs) 48 | -------------------------------------------------------------------------------- /tests/callbacks/test_logger.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021 GradsFlow. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | import os 15 | from pathlib import Path 16 | 17 | import pytest 18 | 19 | from gradsflow import AutoDataset 20 | from gradsflow.callbacks import ( 21 | CometCallback, 22 | CSVLogger, 23 | EmissionTrackerCallback, 24 | ModelCheckpoint, 25 | ) 26 | from gradsflow.data.image import image_dataset_from_directory 27 | from gradsflow.utility.imports import is_installed 28 | from tests.dummies import DummyModel 29 | 30 | data_dir = Path.cwd() 31 | folder = f"{data_dir}/data/test-data-cat-dog-v0/cat-dog/" 32 | data = image_dataset_from_directory(folder, transform=True, ray_data=False) 33 | 34 | 35 | @pytest.fixture 36 | def dummy_model(): 37 | return DummyModel() 38 | 39 | 40 | @pytest.fixture 41 | def auto_dataset(): 42 | return AutoDataset(train_dataloader=data.dataloader, val_dataloader=data.dataloader) 43 | 44 | 45 | def test_csv_logger(dummy_model, auto_dataset): 46 | csv_logger = CSVLogger(filename="test_csv_logger.csv") 47 | dummy_model.compile() 48 | dummy_model.fit(auto_dataset, callbacks=csv_logger) 49 | assert os.path.isfile("test_csv_logger.csv") 50 | 51 | 52 | def test_model_checkpoint(dummy_model, auto_dataset): 53 | ckpt_cb = ModelCheckpoint(filename="model_ckpt", path="test_model_checkpoint_folder") 54 | dummy_model.compile() 55 | dummy_model.fit(auto_dataset, callbacks=ckpt_cb) 56 | assert os.path.exists("test_model_checkpoint_folder") 57 | 58 | 59 | @pytest.mark.skipif(not is_installed("comet_ml"), reason="requires `comet_ml` installed") 60 | def test_comet(dummy_model, auto_dataset): 61 | with pytest.raises(ValueError): 62 | CometCallback() 63 | 64 | comet = CometCallback(offline=True) 65 | dummy_model.compile() 66 | dummy_model.fit(auto_dataset, callbacks=[comet]) 67 | 68 | 69 | @pytest.mark.skipif(not is_installed("codecarbon"), reason="requires `codecarbon` installed") 70 | def test_emission_tracker(dummy_model, auto_dataset): 71 | emission_tracker = EmissionTrackerCallback() 72 | dummy_model.compile() 73 | dummy_model.fit(auto_dataset, 
callbacks=[emission_tracker]) 74 | -------------------------------------------------------------------------------- /tests/callbacks/test_runner.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021 GradsFlow. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | import pytest 15 | 16 | from gradsflow.callbacks import CallbackRunner, TrainEvalCallback 17 | from gradsflow.callbacks.base import Callback 18 | 19 | 20 | def test_init(dummy_model): 21 | assert isinstance(CallbackRunner(dummy_model, "training").callbacks["TrainEvalCallback"], TrainEvalCallback) 22 | with pytest.raises(NotImplementedError): 23 | CallbackRunner(dummy_model, "random") 24 | 25 | 26 | def test_append(dummy_model): 27 | cb = CallbackRunner(dummy_model) 28 | with pytest.raises(NotImplementedError): 29 | cb.append("random") 30 | cb.append("tune_checkpoint") 31 | cb.append(TrainEvalCallback(cb.model)) 32 | assert len(cb.callbacks) == 2 33 | 34 | for cb_name, cb in cb.callbacks.items(): 35 | assert isinstance(cb_name, str) 36 | assert isinstance(cb, Callback) 37 | 38 | 39 | def test_clean(dummy_model): 40 | cb = CallbackRunner(dummy_model, TrainEvalCallback()) 41 | cb.clean(keep="TrainEvalCallback") 42 | assert cb.callbacks.get("TrainEvalCallback") is not None 43 | -------------------------------------------------------------------------------- /tests/callbacks/test_wandb.py: 
-------------------------------------------------------------------------------- 1 | # Copyright (c) 2022 GradsFlow. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | from unittest.mock import Mock, patch 15 | 16 | import pytest 17 | 18 | from gradsflow.callbacks.wandb import WandbCallback 19 | from gradsflow.utility.imports import is_installed 20 | 21 | 22 | @pytest.mark.skipif(not is_installed("wandb"), reason="requires `wandb` installed") 23 | @patch("gradsflow.callbacks.wandb.wandb") 24 | def test_wandbcallback(mock_wandb: Mock, cnn_model, auto_dataset): 25 | model = cnn_model 26 | cb = WandbCallback() 27 | model.compile() 28 | model.fit(auto_dataset, callbacks=cb) 29 | mock_wandb.log.assert_called() 30 | -------------------------------------------------------------------------------- /tests/conftest.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021 GradsFlow. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # Arrange 15 | from pathlib import Path 16 | 17 | import pytest 18 | import timm 19 | from torch import nn 20 | 21 | from gradsflow import AutoDataset, Model 22 | from gradsflow.data import image_dataset_from_directory 23 | from gradsflow.models.tracker import Tracker 24 | 25 | data_dir = Path.cwd() 26 | folder = f"{data_dir}/data/test-data-cat-dog-v0/cat-dog/" 27 | data = image_dataset_from_directory(folder, transform=True, ray_data=False) 28 | 29 | 30 | @pytest.fixture 31 | def auto_dataset(): 32 | return AutoDataset(train_dataloader=data.dataloader, val_dataloader=data.dataloader) 33 | 34 | 35 | @pytest.fixture 36 | def resnet18(): 37 | cnn = timm.create_model("ssl_resnet18", pretrained=False, num_classes=10).eval() 38 | 39 | return cnn 40 | 41 | 42 | @pytest.fixture 43 | def cnn_model(resnet18): 44 | model = Model(resnet18) 45 | model.TEST = True 46 | 47 | return model 48 | 49 | 50 | @pytest.fixture 51 | def tracker(): 52 | return Tracker() 53 | 54 | 55 | @pytest.fixture 56 | def dummy_model(): 57 | """A dummy torch.nn model that adds 1 to the forward input value.""" 58 | 59 | class DummyModel(nn.Module): 60 | def __init__(self): 61 | super().__init__() 62 | 63 | def forward(self, x): 64 | return x + 1 65 | 66 | return DummyModel() 67 | -------------------------------------------------------------------------------- /tests/core/test_base.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021 GradsFlow. All rights reserved. 
2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | -------------------------------------------------------------------------------- /tests/data/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021 GradsFlow. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | -------------------------------------------------------------------------------- /tests/data/test_autodata.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021 GradsFlow. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | from pathlib import Path 15 | 16 | import pytest 17 | import torch 18 | from lightning import fabric 19 | from lightning.fabric import Fabric 20 | from torch.utils.data import DataLoader, TensorDataset 21 | 22 | from gradsflow.data import AutoDataset 23 | 24 | dataset = TensorDataset(torch.randn(8, 1, 32, 32)) 25 | dataloader = DataLoader(dataset) 26 | 27 | from gradsflow.data.image import image_dataset_from_directory 28 | 29 | data_dir = Path.cwd() 30 | folder = f"{data_dir}/data/test-data-cat-dog-v0/cat-dog/" 31 | data = image_dataset_from_directory(folder, transform=True, ray_data=False) 32 | 33 | 34 | def test_auto_dataset(): 35 | with pytest.raises(UserWarning): 36 | AutoDataset() 37 | 38 | 39 | def test_sent_to_device(): 40 | accelerator = Fabric() 41 | autodata = AutoDataset(dataloader) 42 | assert autodata.device_setup_status is None 43 | autodata.setup_data(accelerator) 44 | assert autodata.device_setup_status 45 | 46 | 47 | def test_dataset(): 48 | accelerator = Fabric() 49 | autodata = AutoDataset(train_dataset=data.dataset, val_dataset=data.dataset) 50 | autodata.setup_data(accelerator) 51 | assert isinstance(autodata.train_dataloader, fabric.fabric._FabricDataLoader) 52 | -------------------------------------------------------------------------------- /tests/data/test_common.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021 GradsFlow. All rights reserved. 
# tests/data/test_common.py
# Copyright (c) 2021 GradsFlow. Licensed under the Apache License, Version 2.0.
"""Tests for gradsflow.data.common.random_split_dataset."""
from gradsflow.data.common import random_split_dataset
from gradsflow.data.image import get_fake_data

fake_data = get_fake_data((32, 32))


def test_random_split_dataset():
    ratio = 0.9
    train_split, val_split = random_split_dataset(fake_data.dataset, ratio)
    # The first split receives ~90% of the samples, the second the remainder.
    assert len(train_split) > len(val_split)
    assert len(train_split) == int(len(fake_data.dataset) * ratio)
# tests/data/test_image_data.py
# Copyright (c) 2021 GradsFlow. Licensed under the Apache License, Version 2.0.
"""Tests for gradsflow.data.image.image_dataset_from_directory."""
from pathlib import Path

from gradsflow.data.base import Data
from gradsflow.data.image import image_dataset_from_directory

data_dir = Path.cwd()


def test_image_dataset_from_directory():
    cat_dog_folder = f"{data_dir}/data/test-data-cat-dog-v0/cat-dog/"
    # Loading an on-disk image folder must yield the library's Data container.
    result = image_dataset_from_directory(cat_dog_folder, transform=True, ray_data=False)
    assert isinstance(result, Data)
# tests/data/test_mixins.py
# Copyright (c) 2021 GradsFlow. Licensed under the Apache License, Version 2.0.
"""Tests for gradsflow.data.mixins.DataMixin.send_to_device."""
import pytest
import torch

from gradsflow.data.mixins import DataMixin
from gradsflow.utility import default_device


class DataTest(DataMixin):
    device = default_device()


datamixin = DataTest()


def test_send_to_device():
    # Primitive values pass through unchanged.
    assert datamixin.send_to_device(1) == 1
    assert datamixin.send_to_device(1.5) == 1.5

    # A tensor stays a tensor (possibly moved to another device).
    tensor = torch.randn(4, 1)
    assert isinstance(datamixin.send_to_device(tensor), torch.Tensor)

    # Sequence batches are handled recursively.
    seq_batch = torch.randn(4, 16), [1] * 4
    assert datamixin.send_to_device(seq_batch)

    # Mapping batches are handled recursively too.
    map_batch = {"inputs": torch.randn(4, 16), "targets": [1] * 4}
    assert datamixin.send_to_device(map_batch)

    # Unsupported container types are rejected explicitly.
    with pytest.raises(NotImplementedError):
        datamixin.send_to_device(set(map_batch))
# tests/data/test_ray_dataset.py
# Copyright (c) 2021 GradsFlow. Licensed under the Apache License, Version 2.0.
"""Tests for gradsflow.data.ray_dataset (currently skipped)."""
from pathlib import Path

import pytest
from PIL import Image

from gradsflow.data.ray_dataset import RayDataset, RayImageFolder

data_dir = Path.cwd()


# TODO: remote dataset test
# fix: bare @pytest.mark.skip gave no reason in the test report; add one.
@pytest.mark.skip(reason="remote/ray dataset test not implemented yet (see TODO)")
def test_ray_dataset():
    folder = f"{data_dir}/data/test-data-cat-dog-v0/cat-dog/"

    dataset = RayDataset(folder)

    # The fixture folder holds 8 images and must be iterable.
    assert len(dataset) == 8
    assert next(iter(dataset))

    assert dataset


@pytest.mark.skip(reason="remote/ray dataset test not implemented yet (see TODO)")
def test_ray_image_folder():
    folder = f"{data_dir}/data/test-data-cat-dog-v0/cat-dog/"

    dataset = RayImageFolder(folder)

    # Class discovery from sub-folder names.
    assert dataset.find_classes() == ["cat", "dog"]

    # Iteration yields (PIL image, label string) pairs.
    item = next(iter(dataset))
    assert isinstance(item[0], Image.Image)
    assert isinstance(item[1], str)
# tests/dummies.py
# Copyright (c) 2021 GradsFlow. Licensed under the Apache License, Version 2.0.
"""Lightweight Model double shared across the test-suite."""

import torch

from gradsflow.models import Model


class DummyModel(Model):
    """Minimal Model subclass: constant loss/metrics, no real gradients."""

    def __init__(self):
        super().__init__(torch.nn.Linear(1, 4))

    def backward(self, loss: torch.Tensor):
        # No-op: the dummy never computes real gradients.
        return None

    def train_step(self, batch):
        return {"loss": torch.as_tensor(1), "metrics": {"accuracy": 1}}

    def val_step(self, batch):
        return {"loss": torch.as_tensor(1), "metrics": {"accuracy": 1}}


# --- tests/models/test_exceptions.py -------------------------------------
# Copyright (c) 2021 GradsFlow. Licensed under the Apache License, Version 2.0.
import pytest

from gradsflow.models.exceptions import EpochCancel, FitCancel


def test_exceptions():
    # Both control-flow exceptions must be raisable and catchable as themselves.
    for exc_cls in (EpochCancel, FitCancel):
        with pytest.raises(exc_cls):
            raise exc_cls()
# tests/models/test_model.py
# Copyright (c) 2021 GradsFlow. Licensed under the Apache License, Version 2.0.
"""Tests for gradsflow.models.model.Model: predict, fit, compile, save/load."""
import pytest
import timm
import torch
from torchmetrics.classification import BinaryAccuracy

from gradsflow.callbacks import ModelCheckpoint
from gradsflow.data import AutoDataset
from gradsflow.data.image import get_fake_data
from gradsflow.models.model import Model
from gradsflow.models.tracker import Tracker

image_size = (64, 64)
train_data = get_fake_data(image_size)
val_data = get_fake_data(image_size)

num_classes = train_data.dataset.num_classes
autodataset = AutoDataset(train_data.dataloader, val_data.dataloader, num_classes=num_classes)

cnn = timm.create_model("ssl_resnet18", pretrained=False, num_classes=num_classes).eval()
model = Model(cnn)
model.compile("crossentropyloss", "adam")
model.TEST = True


def test_predict(cnn_model):
    x = torch.randn(1, 3, 64, 64)
    r1 = cnn_model.forward(x)
    r2 = cnn_model(x)
    r3 = cnn_model.predict(x)
    # forward(), __call__() and predict() must agree on the same input.
    assert torch.all(torch.isclose(r1, r2))
    assert torch.all(torch.isclose(r2, r3))
    # fix: the last assert used the module-level `model` instead of the
    # `cnn_model` fixture this test is parameterized with (copy-paste slip).
    assert isinstance(cnn_model.predict(torch.randn(1, 3, 64, 64)), torch.Tensor)


def test_fit(cnn_model):
    cnn_model.compile()
    assert autodataset
    tracker = cnn_model.fit(autodataset, max_epochs=1, steps_per_epoch=1, show_progress=True)
    assert isinstance(tracker, Tracker)

    # Fit again without a validation loader and with a checkpoint callback.
    autodataset2 = AutoDataset(train_data.dataloader, num_classes=num_classes)
    cnn_model.TEST = False
    ckpt_cb = ModelCheckpoint(save_extra=False)
    tracker2 = cnn_model.fit(
        autodataset2,
        max_epochs=1,
        steps_per_epoch=1,
        show_progress=False,
        resume=False,
        callbacks=[ckpt_cb],
    )
    assert isinstance(tracker2, Tracker)


def test_compile():
    def compute_accuracy(*_, **__):
        return 1

    # A bare callable is not a supported metric type.
    with pytest.raises(NotImplementedError):
        model1 = Model(cnn)
        model1.compile("crossentropyloss", "adam", metrics=compute_accuracy)

    # Unknown metric names are rejected.
    with pytest.raises(AssertionError):
        model2 = Model(cnn)
        model2.compile("crossentropyloss", "adam", metrics="random_val")

    model3 = Model(cnn)
    model3.compile("crossentropyloss", "adam", metrics=[BinaryAccuracy()])

    model4 = Model(cnn)
    model4.compile("crossentropyloss", torch.optim.Adam)

    model5 = Model(cnn)
    model5.compile(torch.nn.CrossEntropyLoss, torch.optim.Adam, learning_rate=0.01)
    assert model5.optimizer.param_groups[0]["lr"] == 0.01


def test_set_accelerator(resnet18):
    model = Model(resnet18, precision=16)
    model.compile()
    assert model._accelerator


def test_save_model(tmp_path, resnet18, cnn_model):
    path = f"{tmp_path}/dummy_model.pth"
    # save_extra=True stores a checkpoint dict; save_extra=False the bare module.
    cnn_model.save(path, save_extra=True)
    assert isinstance(torch.load(path), dict)

    cnn_model.save(path, save_extra=False)
    assert isinstance(torch.load(path), type(resnet18))


def test_load_from_checkpoint(tmp_path, cnn_model):
    path = f"{tmp_path}/dummy_model.pth"
    cnn_model.save(path, save_extra=True)
    assert isinstance(torch.load(path), dict)

    offline_model = cnn_model.load_from_checkpoint(path)
    assert isinstance(offline_model, Model)
# tests/models/test_tracker.py
# Copyright (c) 2021 GradsFlow. Licensed under the Apache License, Version 2.0.
"""Tests for the Tracker's reset/mode/logging/table behavior."""
import pytest
from rich.table import Table


def test_reset(tracker):
    tracker.max_epochs = 5
    tracker.reset()
    # reset() restores the default epoch budget.
    assert tracker.max_epochs == 0


def test_mode(tracker):
    # Only "train" and "val" are valid modes; anything else is a KeyError.
    tracker.mode("train")
    tracker.mode("val")
    with pytest.raises(KeyError):
        tracker.mode("test")


def test_track(tracker):
    tracker._append_logs("val", 0.9)
    tracker._append_logs("score", 0.5)


def test_create_table(tracker):
    tracker.track_loss(0.1, "train")
    tracker.track_loss(0.2, "val")
    tracker.track_metrics({"accuracy": 0.9}, mode="train")
    # A populated tracker renders into a rich Table.
    assert isinstance(tracker.create_table(), Table)


def test_get_item(tracker):
    assert tracker["train"] == tracker.mode("train")
    assert isinstance(tracker["metrics"], dict)
    for split in ("train", "val"):
        assert split in tracker["loss"]
# tests/models/test_utils.py
# Copyright (c) 2021 GradsFlow. Licensed under the Apache License, Version 2.0.
# fix: the Apache license boilerplate appeared twice back-to-back in this
# header; a single copy is kept.
"""Tests for the losses/metrics name registries in gradsflow.models.utils."""

from gradsflow.models.utils import available_losses, available_metrics


def test_available_losses():
    losses = available_losses()
    # The registry is a non-empty list of loss names.
    assert isinstance(losses, list)
    assert isinstance(losses[0], str)


def test_available_metrics():
    metrics = available_metrics()
    # The registry is a non-empty list of metric names.
    assert isinstance(metrics, list)
    assert isinstance(metrics[0], str)
# tests/tuner/test_automodel.py
# Copyright (c) 2021 GradsFlow. Licensed under the Apache License, Version 2.0.
"""Tests for AutoModelV2 hyperparameter tuning."""
import os

# Must be set before the gradsflow/ray imports below so CI-mode paths activate.
os.environ["GF_CI"] = "true"

import torch.nn
from ray import tune
from timm import create_model

from gradsflow.data import AutoDataset
from gradsflow.data.image import get_fake_data
from gradsflow.models.constants import LEARNER
from gradsflow.tuner import AutoModelV2 as AutoModel
from gradsflow.tuner import Tuner

image_size = (64, 64)
train_data = get_fake_data(image_size, num_classes=2)
val_data = get_fake_data(image_size, num_classes=2)

num_classes = train_data.dataset.num_classes
autodataset = AutoDataset(train_data.dataloader, val_data.dataloader, num_classes=num_classes)


def _compiled_automodel(learner):
    """Build an AutoModel over `learner` compiled with the standard tiny search space."""
    automodel = AutoModel(learner, optimization_metric="val_loss")
    automodel.compile(
        loss="crossentropyloss",
        optimizer=tune.choice(("adam", "sgd")),
        learning_rate=tune.loguniform(1e-5, 1e-3),
    )
    return automodel


def test_hp_tune():
    tuner = Tuner()
    cnn = create_model("resnet18", pretrained=False, num_classes=num_classes)
    model = _compiled_automodel(cnn)
    # One tiny trial: enough to exercise the full hp_tune pipeline.
    model.hp_tune(
        tuner,
        autodataset,
        n_trials=1,
        epochs=1,
        cpu=0.05,
        gpu=0,
        trainer_config={"steps_per_epoch": 2},
    )


def test_get_learner():
    tuner = Tuner()
    cnn = create_model("resnet18", pretrained=False, num_classes=num_classes)
    complex_cnn = tuner.suggest_complex("learner", cnn)
    automodel = AutoModel(complex_cnn, optimization_metric="val_loss")
    # Resolving the complex-object index must hand back a real torch module.
    learner = automodel._get_learner({LEARNER: 0}, tuner)
    assert isinstance(learner, torch.nn.Module)


def test_compile():
    tuner = Tuner()
    cnn = create_model("resnet18", pretrained=False, num_classes=num_classes)
    complex_cnn = tuner.suggest_complex("learner", cnn)
    # Compiling against a complex (tuner-managed) learner must not raise.
    _compiled_automodel(complex_cnn)
# tests/tuner/test_tuner.py
# Copyright (c) 2021 GradsFlow. Licensed under the Apache License, Version 2.0.
"""Tests for gradsflow.tuner.tuner.ComplexObject and Tuner."""
import pytest
import ray
from ray.tune.sample import Domain

from gradsflow.tuner.tuner import ComplexObject, Tuner

# NOTE(review): shared mutable object — the three tests below rely on running
# in file order (test_append populates it before the others read it).
complex_object = ComplexObject()


def test_append():
    complex_object.append("test_append")
    assert "test_append" in ray.get(complex_object.values)


def test_get_complex_object():
    assert complex_object.get_complex_object(0) == "test_append"


def test_to_choice():
    assert isinstance(complex_object.to_choice(), Domain)


def test_update_search_space():
    tuner = Tuner()
    tuner.update_search_space("test_update_search_space", complex_object)
    assert isinstance(tuner.get_complex_object("test_update_search_space", 0), str)

    # Plain (non-complex) values are rejected with a warning raised as error.
    with pytest.raises(UserWarning):
        tuner.update_search_space("hello", "world")


def test_union():
    base = Tuner()
    other = Tuner()
    other.choice("dropout", 0.1, 0.2, 0.3)
    merged = base.union(other)
    assert merged.value.get("dropout") is not None


def test_merge():
    first = Tuner()
    first.choice("dropout", 0.1, 0.2, 0.3)
    second = Tuner()
    second.choice("layers", 1, 2, 3)
    combined = Tuner.merge(first, second)
    assert "layers" in combined.value


def test_suggest_complex():
    tuner = Tuner()
    tuner.suggest_complex("test_complex", "val1", "val2")
    assert "test_complex" in tuner.value


def test_scalar():
    tuner = Tuner()
    tuner.scalar("a", "b")
    assert tuner.get("a") == "b"


def test_get():
    tuner = Tuner()
    tuner.choice("optimizer", "val1", "val2")
    assert isinstance(tuner.get("optimizer"), Domain)

    tuner.suggest_complex("complex_opt", "val1")
    assert tuner.get("complex_opt").get_complex_object(0) == "val1"

    with pytest.raises(KeyError):
        tuner.get("random_key")
-------------------------------------------------------------------------------- /tests/utility/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021 GradsFlow. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | -------------------------------------------------------------------------------- /tests/utility/test_common.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021 GradsFlow. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
# tests/utility/test_common.py
# Copyright (c) 2021 GradsFlow. Licensed under the Apache License, Version 2.0.
"""Tests for the helpers in gradsflow.utility.common."""
import numpy as np
import torch

from gradsflow.utility.common import (
    GDict,
    default_device,
    filter_list,
    get_file_extension,
    get_files,
    listify,
    module_to_cls_index,
    to_item,
)


def test_create_module_index():
    assert isinstance(module_to_cls_index(torch.optim), dict)


def test_get_files():
    assert len(get_files("./")) > 0


def test_get_file_extension():
    # Only the final suffix counts, even with extra dots in the name.
    assert get_file_extension("image.1.png") == "png"


def test_listify():
    # Table of (input, expected-list) pairs covering every supported type.
    cases = [
        (None, []),
        (1, [1]),
        ((1, 2), [1, 2]),
        ([1], [1]),
        ({"a": 1}, ["a"]),
    ]
    for value, expected in cases:
        assert listify(value) == expected


def test_get_device():
    assert default_device() in ("cpu", "cuda")


def test_to_item():
    # A grad-tracking tensor detaches down to a numpy array.
    tensor = torch.rand(1, 1, requires_grad=True)
    assert isinstance(to_item(tensor), np.ndarray)

    # Container types are preserved while their contents are converted.
    assert isinstance(to_item([torch.rand(10)]), list)
    assert isinstance(to_item((torch.rand(10),)), tuple)

    converted = to_item({"input": torch.rand(10)})
    assert isinstance(converted, dict)
    assert isinstance(converted["input"], np.ndarray)


def test_filter_list():
    arr = [
        "crossentropy",
        "binarycrossentropy",
        "softmax",
        "mae",
    ]
    # A pattern keeps only matches; no pattern keeps everything.
    assert filter_list(arr, ".*entropy") == arr[:2]
    assert filter_list(arr) == arr


def test_gdict():
    gdict: GDict[str, str] = GDict()
    gdict["hi"] = "hello"
    assert gdict["hi"] == "hello"
    key, value = next(iter(gdict.items()))
    assert (key, value) == ("hi", "hello")
    assert gdict.to_dict() == {"hi": "hello"}
# tests/utility/test_data.py
# Copyright (c) 2021 GradsFlow. Licensed under the Apache License, Version 2.0.
"""Tests for gradsflow.utility.download."""
from gradsflow.utility import download


def test_download(tmp_path):
    # fix: the original wrote test_download.txt into the working directory and
    # never removed it; use pytest's tmp_path so the fixture is auto-cleaned.
    path = tmp_path / "test_download.txt"
    path.write_text("Hello")
    # download() on a local path returns the raw bytes of the file.
    assert download(str(path)).lower() == b"hello"