├── .coveragerc
├── .deepsource.toml
├── .github
├── CODEOWNERS
├── FUNDING.yml
├── ISSUE_TEMPLATE
│ ├── bug_report.md
│ └── feature_request.md
├── labeler.yml
├── pull_request_template.md
└── workflows
│ ├── labeler.yml
│ ├── latest-changes.yml.off
│ ├── main.yml
│ ├── markdown_links.yml
│ ├── mkdocs_ci.yml.off
│ ├── publish-to-pypi.yml
│ ├── stale.yml
│ └── welcome.yml
├── .gitignore
├── .pep8speaks.yml
├── .pre-commit-config.yaml
├── .readthedocs.yaml
├── CITATION.cff
├── CODE_OF_CONDUCT.md
├── CONTRIBUTING.md
├── LICENSE
├── MANIFEST.in
├── Makefile
├── README.md
├── docs
├── 404.md
├── CHANGELOG.md
├── CNAME
├── examples
├── gradsflow
│ ├── autotasks
│ │ ├── autotasks.md
│ │ └── engine.md
│ ├── callbacks.md
│ ├── core.md
│ ├── data.md
│ ├── models
│ │ ├── base.md
│ │ ├── model.md
│ │ ├── tracker.md
│ │ └── utils.md
│ ├── tuner.md
│ └── utility.md
├── index.md
├── overrides
│ └── main.html
└── requirements.txt
├── examples
├── nbs
│ ├── 01-ImageClassification.ipynb
│ ├── 02-TextClassification.ipynb
│ ├── 03-TextSummarization.ipynb
│ ├── 04-RayDataset.ipynb
│ ├── 05-model_fit.ipynb
│ ├── 06-AutoModel_fit.ipynb
│ ├── 2021-10-3-huggingface-training.ipynb
│ └── Pix2Pix_explained_with_code.ipynb
└── src
│ ├── models
│ └── hello_world.py
│ └── tasks
│ ├── image_classifier.py
│ └── text_classification.py
├── gradsflow
├── __init__.py
├── autotasks
│ ├── __init__.py
│ ├── autoclassification
│ │ ├── __init__.py
│ │ ├── image.py
│ │ └── text
│ │ │ ├── __init__.py
│ │ │ └── text.py
│ ├── autosummarization.py
│ ├── autotasks.py
│ └── engine
│ │ ├── __init__.py
│ │ ├── autoclassifier.py
│ │ ├── automodel.py
│ │ └── backend.py
├── callbacks
│ ├── __init__.py
│ ├── base.py
│ ├── comet.py
│ ├── gpu.py
│ ├── logger.py
│ ├── progress.py
│ ├── raytune.py
│ ├── runner.py
│ ├── tensorboard.py
│ ├── training.py
│ └── wandb.py
├── core
│ ├── __init__.py
│ ├── base.py
│ └── metrics.py
├── data
│ ├── __init__.py
│ ├── autodata.py
│ ├── base.py
│ ├── common.py
│ ├── image.py
│ ├── mixins.py
│ └── ray_dataset.py
├── models
│ ├── __init__.py
│ ├── base.py
│ ├── constants.py
│ ├── exceptions.py
│ ├── model.py
│ ├── tracker.py
│ └── utils.py
├── tuner
│ ├── __init__.py
│ ├── automodel.py
│ └── tuner.py
└── utility
│ ├── __init__.py
│ ├── common.py
│ ├── data.py
│ └── imports.py
├── mkdocs.yml
├── pyproject.toml
├── setup.cfg
├── setup.py
├── sonar-project.properties
└── tests
├── __init__.py
├── __main__.py
├── autotasks
├── test_autotasks.py
├── test_autotrainer.py
├── test_core_automodel.py
├── test_image.py
├── test_summarization.py
└── test_text.py
├── callbacks
├── test_logger.py
├── test_runner.py
└── test_wandb.py
├── conftest.py
├── core
└── test_base.py
├── data
├── __init__.py
├── test_autodata.py
├── test_common.py
├── test_image_data.py
├── test_mixins.py
└── test_ray_dataset.py
├── dummies.py
├── models
├── test_exceptions.py
├── test_model.py
├── test_tracker.py
└── test_utils.py
├── tuner
├── test_automodel.py
└── test_tuner.py
└── utility
├── __init__.py
├── test_common.py
└── test_data.py
/.coveragerc:
--------------------------------------------------------------------------------
1 | [report]
2 | exclude_lines =
3 | pragma: no cover
4 | def __repr__
5 | if _environ.get("GF_CI")
6 | if self.debug:
7 | if settings.DEBUG
8 | raise AssertionError
9 | raise NotImplementedError
10 | if 0:
11 | if __name__ == .__main__.:
12 |
13 | omit =
14 | tests/__init__.py
15 |
--------------------------------------------------------------------------------
/.deepsource.toml:
--------------------------------------------------------------------------------
1 | version = 1
2 |
3 | test_patterns = ["tests/**"]
4 |
5 | exclude_patterns = [
6 | "tests/**",
7 | "examples/**"
8 | ]
9 |
10 | [[analyzers]]
11 | name = "python"
12 | enabled = true
13 |
14 | [analyzers.meta]
15 | runtime_version = "3.x.x"
16 |
17 | [[transformers]]
18 | name = "black"
19 | enabled = true
20 |
21 | [[transformers]]
22 | name = "isort"
23 | enabled = true
24 |
--------------------------------------------------------------------------------
/.github/CODEOWNERS:
--------------------------------------------------------------------------------
1 | # Lines starting with '#' are comments.
2 | # Each line is a file pattern followed by one or more owners.
3 |
4 | # More details are here: https://help.github.com/articles/about-codeowners/
5 |
6 | # The '*' pattern is global owners.
7 |
8 | # Order is important. The last matching pattern has the most precedence.
9 | # The folders are ordered as follows:
10 |
11 | # In each subsection folders are ordered first by depth, then alphabetically.
12 | # This should make it easy to add new rules without breaking existing ones.
13 |
14 | # Global rule:
15 | * @aniketmaurya
16 |
17 | # tests
18 | /tests/** @aniketmaurya
19 |
--------------------------------------------------------------------------------
/.github/FUNDING.yml:
--------------------------------------------------------------------------------
1 | # These are supported funding model platforms
2 |
3 | github: [aniketmaurya]
4 | patreon: # Replace with a single Patreon username
5 | open_collective: # Replace with a single Open Collective username
6 | ko_fi: aniketmaurya
7 | tidelift: # Replace with a single Tidelift platform-name/package-name e.g., npm/babel
8 | community_bridge: # Replace with a single Community Bridge project-name e.g., cloud-foundry
9 | liberapay: # Replace with a single Liberapay username
10 | issuehunt: # gradsflow
11 | otechie: # Replace with a single Otechie username
12 | custom: # Replace with up to 4 custom sponsorship URLs e.g., ['link1', 'link2']
13 |
--------------------------------------------------------------------------------
/.github/ISSUE_TEMPLATE/bug_report.md:
--------------------------------------------------------------------------------
1 | ---
2 | name: 🐛 Bug report
3 | about: Create a report to help us improve
4 | title: ''
5 | labels: ''
6 | assignees: ''
7 |
8 | ---
9 |
10 |
11 | #### Bug description
12 |
13 |
14 | #### Expected result
15 |
16 |
17 | #### Actual result
18 |
19 |
20 | #### Steps to reproduce
21 |
22 |
23 | 1.
24 | 2.
25 | 3.
26 | #### Context
27 |
28 |
29 |
30 | #### Your Environment
31 |
32 |
33 | * Version used:
34 | * Operating System and version:
35 | * Link to your fork:
36 |
--------------------------------------------------------------------------------
/.github/ISSUE_TEMPLATE/feature_request.md:
--------------------------------------------------------------------------------
1 | ---
2 | name: 🚀 Feature request
3 | about: Suggest an idea for this project
4 | title: ''
5 | labels: ''
6 | assignees: ''
7 |
8 | ---
9 |
10 | #### Is your feature request related to a problem? Please describe.
11 | A clear and concise description of what the problem is. Ex. I'm always frustrated when [...]
12 |
13 | #### Describe the solution you'd like
14 | A clear and concise description of what you want to happen.
15 |
16 | #### Describe alternatives you've considered
17 | A clear and concise description of any alternative solutions or features you've considered.
18 |
19 | #### Additional context
20 | Add any other context or screenshots about the feature request here.
21 |
--------------------------------------------------------------------------------
/.github/labeler.yml:
--------------------------------------------------------------------------------
1 | # Add 'docs' to any changes within 'docs' folder or any subfolders
2 | documentation:
3 | - docs/**/*
4 |
5 | example:
6 | - examples/**/*
7 |
8 | test:
9 | - tests/**/*
10 |
11 | CI:
12 | - .github/**/*
13 | - "*.yaml"
14 | - "*.yml"
15 |
--------------------------------------------------------------------------------
/.github/pull_request_template.md:
--------------------------------------------------------------------------------
1 | #### Changes
2 |
3 |
4 |
5 |
6 | Fixes # (issue)
7 |
8 |
9 | #### Type of change
10 |
11 | - [ ] 📚 Documentation Update
12 | - [ ] 🧪 Tests Cases
13 | - [ ] 🐞 Bug fix (non-breaking change which fixes an issue)
14 | - [ ] 🔬 New feature (non-breaking change which adds functionality)
15 | - [ ] 🚨 Breaking change (fix or feature that would cause existing functionality to not work as expected)
16 | - [ ] 📝 This change requires a documentation update
17 |
18 |
19 | #### Checklist
20 |
21 | - [ ] My code follows the style guidelines of this project
22 | - [ ] I have performed a self-review of my own code
23 | - [ ] I have commented my code, particularly in hard-to-understand areas
24 | - [ ] I have made corresponding changes to the documentation
25 | - [ ] My changes generate no new warnings
26 | - [ ] I have updated the CHANGELOG (docs/CHANGELOG.md) in case of a major change
27 |
--------------------------------------------------------------------------------
/.github/workflows/labeler.yml:
--------------------------------------------------------------------------------
1 | name: "Pull Request Labeler"
2 | on:
3 | - pull_request_target
4 |
5 | jobs:
6 | triage:
7 | runs-on: ubuntu-latest
8 | steps:
9 | - uses: actions/labeler@v3
10 | with:
11 | repo-token: "${{ secrets.GITHUB_TOKEN }}"
12 |
--------------------------------------------------------------------------------
/.github/workflows/latest-changes.yml.off:
--------------------------------------------------------------------------------
1 | name: Latest Changes
2 |
3 | on:
4 | pull_request_target:
5 | branches:
6 | - main
7 | types:
8 | - closed
9 | # For manually triggering it
10 | workflow_dispatch:
11 | inputs:
12 | number:
13 | description: PR number
14 | required: true
15 |
16 | jobs:
17 | latest-changes:
18 | runs-on: ubuntu-latest
19 | steps:
20 | - uses: actions/checkout@v2
21 | with:
22 | token: ${{ secrets.ACTIONS_TOKEN }}
23 | - uses: docker://tiangolo/latest-changes:0.0.3
24 | with:
25 | token: ${{ secrets.GITHUB_TOKEN }}
26 | latest_changes_file: docs/CHANGELOG.md
27 | latest_changes_header: '## 0.0.3\n'
28 | debug_logs: true
29 |
--------------------------------------------------------------------------------
/.github/workflows/main.yml:
--------------------------------------------------------------------------------
1 | name: pytest
2 | on:
3 | push:
4 | branches: [ main ]
5 | pull_request:
6 | branches: [ main ]
7 |
8 |
9 | jobs:
10 | pytest:
11 | runs-on: ${{ matrix.os }}
12 | timeout-minutes: 15
13 | strategy:
14 | matrix:
15 | os: [ ubuntu-latest, macos-latest ]
16 | python-version: ["3.9", "3.10"]
17 | include:
18 | - os: ubuntu-latest
19 | path: ~/.cache/pip
20 | - os: macos-latest
21 | path: ~/Library/Caches/pip
22 | env:
23 | OS: ${{ matrix.os }}
24 | PYTHON: '3.10'
25 |
26 |
27 | steps:
28 | - uses: actions/checkout@v2
29 | with:
30 | fetch-depth: 0 # Shallow clones should be disabled for a better relevancy of analysis
31 |
32 | - name: Set up Python ${{ matrix.python-version }}
33 | uses: actions/setup-python@v2
34 | with:
35 | python-version: ${{ matrix.python-version }}
36 |
37 | - name: Cache pip
38 | uses: actions/cache@v2
39 | with:
40 | path: ${{ matrix.path }}
41 | key: ${{ runner.os }}-pip-${{ hashFiles('**/requirements.txt') }}
42 | restore-keys: |
43 | ${{ runner.os }}-pip-
44 | ${{ runner.os }}-
45 |
46 | - name: Install dependencies
47 | run: |
48 | python --version
49 | pip --version
50 | python -m pip install --upgrade pip build
51 | pip install -e '.[dev,test]'
52 | pip list
53 | shell: bash
54 |
55 | - name: Prepare Test
56 | run: |
57 | python tests # download test data
58 |
59 | - name: Run Test with Coverage
60 | run: |
61 | coverage erase
62 | coverage run -m pytest
63 |
64 | - name: Generate Coverage Report
65 | run: |
66 | coverage report -m -i
67 | coverage xml -i
68 |
69 | - name: Upload Coverage to Codecov
70 | if: runner.os != 'macOS'
71 | uses: codecov/codecov-action@v1
72 | with:
73 | token: ${{ secrets.CODECOV_TOKEN }}
74 | file: ./coverage.xml
75 | flags: unittests
76 | env_vars: OS,PYTHON
77 | name: codecov-umbrella
78 | fail_ci_if_error: false
79 |
80 | # - name: SonarCloud Scan
81 | # if: runner.os != 'macOS' && env.SONAR_TOKEN != null
82 | # uses: SonarSource/sonarcloud-github-action@master
83 | # env:
84 | # GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} # Needed to get PR information, if any
85 | # SONAR_TOKEN: ${{ secrets.SONAR_TOKEN }}
86 |
--------------------------------------------------------------------------------
/.github/workflows/markdown_links.yml:
--------------------------------------------------------------------------------
1 | name: Check Markdown links
2 |
3 | on: push
4 |
5 | jobs:
6 | markdown-link-check:
7 | runs-on: ubuntu-latest
8 | steps:
9 | - uses: actions/checkout@master
10 | - uses: gaurav-nelson/github-action-markdown-link-check@v1
11 |
--------------------------------------------------------------------------------
/.github/workflows/mkdocs_ci.yml.off:
--------------------------------------------------------------------------------
1 | name: MkDocs
2 | on:
3 | push:
4 | branches:
5 | - master
6 | - main
7 | jobs:
8 | deploy:
9 | runs-on: ubuntu-latest
10 | steps:
11 | - uses: actions/checkout@v2
12 | - uses: actions/setup-python@v2
13 | with:
14 | python-version: 3.x
15 | - run: pip install -r docs/requirements.txt
16 | - run: mkdocs gh-deploy --force
17 |
--------------------------------------------------------------------------------
/.github/workflows/publish-to-pypi.yml:
--------------------------------------------------------------------------------
1 | name: Publish Python 🐍 distributions 📦 to PyPI and TestPyPI
2 |
3 | on: push
4 |
5 | jobs:
6 | build-n-publish:
7 | name: Build and publish Python 🐍 distributions 📦 to PyPI and TestPyPI
8 | runs-on: ubuntu-latest
9 |
10 | steps:
11 | - uses: actions/checkout@master
12 | - name: Set up Python 3.10
13 | uses: actions/setup-python@v1
14 | with:
15 | python-version: "3.10"
16 |
17 | - name: Install pypa/build
18 | run: >-
19 | python -m
20 | pip install
21 | build
22 | --user
23 | - name: Build a binary wheel and a source tarball
24 | run: >-
25 | python -m build
26 |
27 | # - name: Publish distribution 📦 to Test PyPI
28 | # if: startsWith(github.ref, 'refs/tags')
29 | # uses: pypa/gh-action-pypi-publish@master
30 | # with:
31 | # password: ${{ secrets.TEST_PYPI_API_TOKEN }}
32 | # repository_url: https://test.pypi.org/legacy/
33 |
34 | - name: Publish distribution 📦 to PyPI
35 | if: startsWith(github.ref, 'refs/tags')
36 | uses: pypa/gh-action-pypi-publish@master
37 | with:
38 | password: ${{ secrets.PYPI_API_TOKEN }}
39 |
--------------------------------------------------------------------------------
/.github/workflows/stale.yml:
--------------------------------------------------------------------------------
1 | name: Mark stale issues and pull requests
2 |
3 | on:
4 | schedule:
5 | - cron: "30 1 * * *"
6 |
7 | jobs:
8 | stale:
9 |
10 | runs-on: ubuntu-latest
11 |
12 | steps:
13 | - uses: actions/stale@v3
14 | with:
15 | repo-token: ${{ secrets.GITHUB_TOKEN }}
16 | stale-issue-message: 'Stale issue message'
17 | stale-pr-message: 'Stale pull request message'
18 | stale-issue-label: 'no-issue-activity'
19 | stale-pr-label: 'no-pr-activity'
20 |
--------------------------------------------------------------------------------
/.github/workflows/welcome.yml:
--------------------------------------------------------------------------------
1 | name: Greet New Contributors
2 |
3 | on: [pull_request_target, issues]
4 |
5 | jobs:
6 | greeting:
7 | runs-on: ubuntu-latest
8 | steps:
9 | - uses: actions/first-interaction@v1
10 | with:
11 | repo-token: ${{ secrets.GITHUB_TOKEN }}
12 | issue-message: "👋 @${{github.actor}}! Thank you for opening your first issue in this repo. We are so happy that you have decided to contribute and value your contribution. Please read these materials before proceeding: [Contributing Guide](https://github.com/gradsflow/gradsflow/blob/master/CONTRIBUTING.md) and [Code of Conduct](https://github.com/gradsflow/gradsflow/blob/master/CODE_OF_CONDUCT.md)."
13 | pr-message: "👋 @${{github.actor}}! Thank you for opening your first pull request in this repo. We are so happy that you have decided to contribute and value your contribution. Please read these materials before proceeding: [Contributing Guide](https://github.com/gradsflow/gradsflow/blob/master/CONTRIBUTING.md) and [Code of Conduct](https://github.com/gradsflow/gradsflow/blob/master/CODE_OF_CONDUCT.md)."
14 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | pip-wheel-metadata/
24 | share/python-wheels/
25 | *.egg-info/
26 | .installed.cfg
27 | *.egg
28 | MANIFEST
29 |
30 | # PyInstaller
31 | # Usually these files are written by a python script from a template
32 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
33 | *.manifest
34 | *.spec
35 |
36 | # Installer logs
37 | pip-log.txt
38 | pip-delete-this-directory.txt
39 |
40 | # Unit test / coverage reports
41 | htmlcov/
42 | .tox/
43 | .nox/
44 | .coverage
45 | .coverage.*
46 | .cache
47 | nosetests.xml
48 | coverage.xml
49 | *.cover
50 | *.py,cover
51 | .hypothesis/
52 | .pytest_cache/
53 |
54 | # Translations
55 | *.mo
56 | *.pot
57 |
58 | # Django stuff:
59 | *.log
60 | local_settings.py
61 | db.sqlite3
62 | db.sqlite3-journal
63 |
64 | # Flask stuff:
65 | instance/
66 | .webassets-cache
67 |
68 | # Scrapy stuff:
69 | .scrapy
70 |
71 | # Sphinx documentation
72 | docs/_build/
73 |
74 | # PyBuilder
75 | target/
76 |
77 | # Jupyter Notebook
78 | .ipynb_checkpoints
79 |
80 | # IPython
81 | profile_default/
82 | ipython_config.py
83 |
84 | # pyenv
85 | .python-version
86 |
87 | # pipenv
88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
91 | # install all needed dependencies.
92 | #Pipfile.lock
93 |
94 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow
95 | __pypackages__/
96 |
97 | # Celery stuff
98 | celerybeat-schedule
99 | celerybeat.pid
100 |
101 | # SageMath parsed files
102 | *.sage.py
103 |
104 | # Environments
105 | .env
106 | .venv
107 | env/
108 | venv/
109 | ENV/
110 | env.bak/
111 | venv.bak/
112 |
113 | # Spyder project settings
114 | .spyderproject
115 | .spyproject
116 |
117 | # Rope project settings
118 | .ropeproject
119 |
120 | # mkdocs documentation
121 | /site
122 |
123 | # mypy
124 | .mypy_cache/
125 | .dmypy.json
126 | dmypy.json
127 |
128 | # Pyre type checker
129 | .pyre/
130 |
131 | .idea/
132 | lightning_logs/
133 | *.zip
134 | *.jpg
135 | *.jpeg
136 | *.png
137 | *.gif
138 |
139 | /data
140 | .vscode/
141 |
--------------------------------------------------------------------------------
/.pep8speaks.yml:
--------------------------------------------------------------------------------
1 | # File : .pep8speaks.yml
2 |
3 | scanner:
4 | diff_only: True # If False, the entire file touched by the Pull Request is scanned for errors. If True, only the diff is scanned.
5 | linter: pycodestyle # Other option is flake8
6 |
7 | pycodestyle: # Same as scanner.linter value. Other option is flake8
8 | max-line-length: 100 # Default is 79 in PEP 8
9 | ignore: # Errors and warnings to ignore
10 | - W504 # line break after binary operator
11 | - E402 # module level import not at top of file
12 | - E731 # do not assign a lambda expression, use a def
13 | - C406 # Unnecessary list literal - rewrite as a dict literal.
14 | - E741 # ambiguous variable name
15 |
16 | no_blank_comment: True # If True, no comment is made on PR without any errors.
17 | descending_issues_order: False # If True, PEP 8 issues in message will be displayed in descending order of line numbers in the file
18 |
19 | message: # Customize the comment made by the bot
20 | opened: # Messages when a new PR is submitted
21 | header: "Hello @{name}! Thanks for opening this PR. "
22 | # The keyword {name} is converted into the author's username
23 | footer: "Do see the [Hitchhiker's guide to code style](https://goo.gl/hqbW4r)"
24 | # The messages can be written as they would over GitHub
25 | updated: # Messages when new commits are added to the PR
26 | header: "Hello @{name}! Thanks for updating this PR. "
27 | footer: "" # Why comment the link to the style guide every time? :)
28 | no_errors: "There are currently no PEP 8 issues detected in this Pull Request. Cheers! :beers: "
29 |
--------------------------------------------------------------------------------
/.pre-commit-config.yaml:
--------------------------------------------------------------------------------
1 | # See https://pre-commit.com for more information
2 | # See https://pre-commit.com/hooks.html for more hooks
3 | default_language_version:
4 | python: python3.10
5 |
6 | ci:
7 | autofix_prs: true
8 | autoupdate_commit_msg: '[pre-commit.ci] pre-commit suggestions'
9 | autoupdate_schedule: quarterly
10 | # submodules: true
11 |
12 | repos:
13 | - repo: https://github.com/pre-commit/pre-commit-hooks
14 | rev: v5.0.0
15 | hooks:
16 | - id: trailing-whitespace
17 | - id: end-of-file-fixer
18 | - id: check-added-large-files
19 |
20 | - repo: https://github.com/psf/black
21 | rev: 24.8.0
22 | hooks:
23 | - id: black
24 | name: "Black: The uncompromising Python code formatter"
25 |
26 | - repo: https://github.com/PyCQA/isort
27 | rev: 5.13.2
28 | hooks:
29 | - id: isort
30 | name: "Sort Imports"
31 | args: [ "--profile black" ]
32 |
33 | - repo: https://github.com/codespell-project/codespell
34 | rev: v2.3.0
35 | hooks:
36 | - id: codespell
37 | args:
38 | - --ignore-words-list
39 | - "ans,hist"
40 | - --skip
41 | - "*.bib,*.ipynb"
42 |
43 | - repo: https://github.com/asottile/pyupgrade
44 | rev: v3.17.0
45 | hooks:
46 | - id: pyupgrade
47 | args: [ --py36-plus ]
48 |
49 | - repo: https://github.com/PyCQA/bandit
50 | rev: 1.7.10
51 | hooks:
52 | - id: bandit
53 | language_version: python3
54 | exclude: tests/
55 | args:
56 | - -s
57 | - "B404,B602,B603,B607,B101"
58 |
59 | - repo: local
60 | hooks:
61 | - id: clean
62 | name: clean
63 | entry: make
64 | args: [ "clean" ]
65 | language: system
66 | pass_filenames: false
67 |
68 |
69 | - repo: https://github.com/kynan/nbstripout
70 | rev: 0.7.1
71 | hooks:
72 | - id: nbstripout
73 | args:
74 | - --max-size=500k
75 | - --drop-empty-cells
76 |
--------------------------------------------------------------------------------
/.readthedocs.yaml:
--------------------------------------------------------------------------------
1 | # .readthedocs.yaml
2 | # Read the Docs configuration file
3 | # See https://docs.readthedocs.io/en/stable/config-file/v2.html for details
4 |
5 | # Required
6 | version: 2
7 |
8 | mkdocs:
9 | configuration: mkdocs.yml
10 |
11 | # Optionally set the version of Python and requirements required to build your docs
12 | python:
13 | version: 3.8
14 | install:
15 | - requirements: docs/requirements.txt
16 |
--------------------------------------------------------------------------------
/CITATION.cff:
--------------------------------------------------------------------------------
1 | cff-version: 1.2.0
2 | message: "If you use GradsFlow, please cite it as below."
3 | authors:
4 | - family-names: "Maurya"
5 | given-names: "Aniket"
6 | orcid: "https://orcid.org/0000-0002-0202-4810"
7 | title: "gradsflow"
8 | doi: 10.5281/zenodo.5245150
9 | date-released: 2021-08-24
10 | url: "https://github.com/gradsflow/gradsflow"
11 |
--------------------------------------------------------------------------------
/CODE_OF_CONDUCT.md:
--------------------------------------------------------------------------------
1 | # Contributor Covenant Code of Conduct
2 |
3 | ## Our Pledge
4 |
5 | We as members, contributors, and leaders pledge to make participation in our
6 | community a harassment-free experience for everyone, regardless of age, body
7 | size, visible or invisible disability, ethnicity, sex characteristics, gender
8 | identity and expression, level of experience, education, socio-economic status,
9 | nationality, personal appearance, race, religion, or sexual identity
10 | and orientation.
11 |
12 | We pledge to act and interact in ways that contribute to an open, welcoming,
13 | diverse, inclusive, and healthy community.
14 |
15 | ## Our Standards
16 |
17 | Examples of behavior that contributes to a positive environment for our
18 | community include:
19 |
20 | * Demonstrating empathy and kindness toward other people
21 | * Being respectful of differing opinions, viewpoints, and experiences
22 | * Giving and gracefully accepting constructive feedback
23 | * Accepting responsibility and apologizing to those affected by our mistakes,
24 | and learning from the experience
25 | * Focusing on what is best not just for us as individuals, but for the
26 | overall community
27 |
28 | Examples of unacceptable behavior include:
29 |
30 | * The use of sexualized language or imagery, and sexual attention or
31 | advances of any kind
32 | * Trolling, insulting or derogatory comments, and personal or political attacks
33 | * Public or private harassment
34 | * Publishing others' private information, such as a physical or email
35 | address, without their explicit permission
36 | * Other conduct which could reasonably be considered inappropriate in a
37 | professional setting
38 |
39 | ## Enforcement Responsibilities
40 |
41 | Community leaders are responsible for clarifying and enforcing our standards of
42 | acceptable behavior and will take appropriate and fair corrective action in
43 | response to any behavior that they deem inappropriate, threatening, offensive,
44 | or harmful.
45 |
46 | Community leaders have the right and responsibility to remove, edit, or reject
47 | comments, commits, code, wiki edits, issues, and other contributions that are
48 | not aligned to this Code of Conduct, and will communicate reasons for moderation
49 | decisions when appropriate.
50 |
51 | ## Scope
52 |
53 | This Code of Conduct applies within all community spaces, and also applies when
54 | an individual is officially representing the community in public spaces.
55 | Examples of representing our community include using an official e-mail address,
56 | posting via an official social media account, or acting as an appointed
57 | representative at an online or offline event.
58 |
59 | ## Enforcement
60 |
61 | Instances of abusive, harassing, or otherwise unacceptable behavior may be
62 | reported to the community leaders responsible for enforcement at
63 | hello@gradsflow.com.
64 | All complaints will be reviewed and investigated promptly and fairly.
65 |
66 | All community leaders are obligated to respect the privacy and security of the
67 | reporter of any incident.
68 |
69 | ## Enforcement Guidelines
70 |
71 | Community leaders will follow these Community Impact Guidelines in determining
72 | the consequences for any action they deem in violation of this Code of Conduct:
73 |
74 | ### 1. Correction
75 |
76 | **Community Impact**: Use of inappropriate language or other behavior deemed
77 | unprofessional or unwelcome in the community.
78 |
79 | **Consequence**: A private, written warning from community leaders, providing
80 | clarity around the nature of the violation and an explanation of why the
81 | behavior was inappropriate. A public apology may be requested.
82 |
83 | ### 2. Warning
84 |
85 | **Community Impact**: A violation through a single incident or series
86 | of actions.
87 |
88 | **Consequence**: A warning with consequences for continued behavior. No
89 | interaction with the people involved, including unsolicited interaction with
90 | those enforcing the Code of Conduct, for a specified period of time. This
91 | includes avoiding interactions in community spaces as well as external channels
92 | like social media. Violating these terms may lead to a temporary or
93 | permanent ban.
94 |
95 | ### 3. Temporary Ban
96 |
97 | **Community Impact**: A serious violation of community standards, including
98 | sustained inappropriate behavior.
99 |
100 | **Consequence**: A temporary ban from any sort of interaction or public
101 | communication with the community for a specified period of time. No public or
102 | private interaction with the people involved, including unsolicited interaction
103 | with those enforcing the Code of Conduct, is allowed during this period.
104 | Violating these terms may lead to a permanent ban.
105 |
106 | ### 4. Permanent Ban
107 |
108 | **Community Impact**: Demonstrating a pattern of violation of community
109 | standards, including sustained inappropriate behavior, harassment of an
110 | individual, or aggression toward or disparagement of classes of individuals.
111 |
112 | **Consequence**: A permanent ban from any sort of public interaction within
113 | the community.
114 |
115 | ## Attribution
116 |
117 | This Code of Conduct is adapted from the [Contributor Covenant][homepage],
118 | version 2.0, available at
119 | https://www.contributor-covenant.org/version/2/0/code_of_conduct.html.
120 |
121 | Community Impact Guidelines were inspired by [Mozilla's code of conduct
122 | enforcement ladder](https://github.com/mozilla/diversity).
123 |
124 | [homepage]: https://www.contributor-covenant.org
125 |
126 | For answers to common questions about this code of conduct, see the FAQ at
127 | https://www.contributor-covenant.org/faq. Translations are available at
128 | https://www.contributor-covenant.org/translations.
129 |
--------------------------------------------------------------------------------
/CONTRIBUTING.md:
--------------------------------------------------------------------------------
1 | # Contributing guidelines
2 |
3 | 👍🎉 First off, thanks for taking the time to contribute! 🎉👍
4 |
5 | The following is a set of guidelines for contributing to Gradsflow and its packages,
6 | which are hosted in the GradsFlow Organization on GitHub.
7 | These are mostly guidelines, not rules.
8 | Use your best judgment, and feel free to propose changes to this document in a pull request.
9 |
10 | We welcome any kind of contribution to our software, from a simple comment or question to a
11 | full-fledged [pull request](https://help.github.com/articles/about-pull-requests/).
12 | Please read and follow our [Code of Conduct](CODE_OF_CONDUCT.md).
13 |
14 | A contribution can be one of the following cases:
15 |
16 | 1. you have a question;
17 | 1. you think you may have found a bug (including unexpected behavior);
18 | 1. you want to make some kind of change to the code base
19 | (e.g. to fix a bug, to add a new feature, to update documentation);
20 | 1. you want to make a new release of the code base.
21 |
22 | The sections below outline the steps in each case.
23 |
24 | ## You have a question
25 |
26 | 1. Use the search functionality [here](https://github.com/gradsflow/gradsflow/issues) to see if someone already
27 | filed the same issue or check out [Docs](https://docs.gradsflow.com).
28 | 2. If your issue search did not yield any relevant results, make a new issue.
29 | 3. Apply the "Question" label; apply other labels when relevant.
30 | 4. You can join our Slack group as well.
31 |
32 | ## You think you may have found a bug
33 |
34 | 1. use the search functionality [here](https://github.com/gradsflow/gradsflow/issues) to see if someone already filed the same issue;
35 | 1. if your issue search did not yield any relevant results, make a new issue, making sure to provide enough information to the rest of the community to understand the cause and context of the problem. Depending on the issue, you may want to include:
36 | - the [SHA hashcode](https://help.github.com/articles/autolinked-references-and-urls/#commit-shas) of the commit that is causing your problem;
37 | - some identifying information (name and version number) for dependencies you're using;
38 | - information about the operating system;
39 | 1. apply relevant labels to the newly created issue.
40 |
41 | ## You want to make some kind of change to the code base
42 |
43 | 1. (**important**) announce your plan to the rest of the community *before you start working*. This announcement should be in the form of a (new) issue;
44 | 1. (**important**) wait until some kind of consensus is reached about your idea being a good idea;
45 | 1. if needed, fork the repository to your own Github profile and create your own feature branch off of the latest master commit. While working on your feature branch, make sure to stay up to date with the master branch by pulling in changes, possibly from the 'upstream' repository (follow the instructions [here](https://help.github.com/articles/configuring-a-remote-for-a-fork/) and [here](https://help.github.com/articles/syncing-a-fork/));
46 | 1. make sure the existing tests still work by running ``pytest``;
47 | 1. add your own tests (if necessary);
48 | 1. update or expand the documentation;
49 | 1. update the `docs/CHANGELOG.md` file with your change;
50 | 1. push your feature branch to (your fork of) the https://github.com/gradsflow/gradsflow repository on GitHub;
51 | 1. create the pull request, e.g. following the instructions [here](https://help.github.com/articles/creating-a-pull-request/).
52 |
53 | In case you feel like you've made a valuable contribution, but you don't know how to write or run tests for it, or how to generate the documentation: don't let this discourage you from making the pull request; we can help you! Just go ahead and submit the pull request, but keep in mind that you might be asked to append additional commits to your pull request.
54 |
--------------------------------------------------------------------------------
/MANIFEST.in:
--------------------------------------------------------------------------------
1 | include pyproject.toml
2 | include setup.cfg
3 | include LICENSE
4 | include CONTRIBUTING.md
5 | include README.md
6 | recursive-exclude ** __pycache__/
7 | prune tests/
8 |
--------------------------------------------------------------------------------
/Makefile:
--------------------------------------------------------------------------------
1 | build-docs:
2 | cp README.md docs/index.md
3 |
4 | docsserve:
5 | mkdocs serve --dirtyreload --livereload
6 |
7 | test:
8 | python tests/__init__.py
9 | pytest
10 |
11 | coverage: ## Run tests with coverage
12 | coverage erase
13 | coverage run -m pytest
14 | coverage report -m
15 | coverage xml
16 |
17 | clean:
18 | rm -rf dist
19 | find . -type f -name "*.DS_Store" -ls -delete
20 | find . | grep -E "(__pycache__|\.pyc|\.pyo)" | xargs rm -rf
21 | find . | grep -E ".pytest_cache" | xargs rm -rf
22 | find . | grep -E ".ipynb_checkpoints" | xargs rm -rf
23 | rm -f .coverage
24 |
25 | style:
26 | black .
27 | isort --profile black .
28 |
29 | build: clean
30 | python -m build
31 |
32 | test_pypi: build
33 | twine upload -r testpypi dist/*
34 |
35 | pypi: build
36 | twine upload dist/*
37 |
38 | push:
39 | git push && git push --tags
40 |
41 | install: style clean
42 | flit install --deps none
43 |
--------------------------------------------------------------------------------
/docs/404.md:
--------------------------------------------------------------------------------
1 | # Oops! The page you are looking for does not exist.
2 |
--------------------------------------------------------------------------------
/docs/CNAME:
--------------------------------------------------------------------------------
1 | docs.gradsflow.com
2 |
--------------------------------------------------------------------------------
/docs/examples:
--------------------------------------------------------------------------------
1 | ../examples
--------------------------------------------------------------------------------
/docs/gradsflow/autotasks/autotasks.md:
--------------------------------------------------------------------------------
1 | ::: gradsflow.autotasks.autotasks
2 |
3 | ---
4 |
5 | ::: gradsflow.autotasks.autoclassification
6 |
7 | ---
8 |
9 | ::: gradsflow.autotasks.autosummarization
10 |
--------------------------------------------------------------------------------
/docs/gradsflow/autotasks/engine.md:
--------------------------------------------------------------------------------
1 | ::: gradsflow.autotasks.engine.backend
2 |
3 | ---
4 |
5 | ::: gradsflow.autotasks.engine.automodel
6 |
7 | ---
8 |
9 | ::: gradsflow.autotasks.engine.autoclassifier
10 |
--------------------------------------------------------------------------------
/docs/gradsflow/callbacks.md:
--------------------------------------------------------------------------------
1 | ::: gradsflow.callbacks.wandb.WandbCallback
2 |
3 | ---
4 |
5 | ::: gradsflow.callbacks.gpu.EmissionTrackerCallback
6 |
7 | ---
8 |
9 | ::: gradsflow.callbacks.comet.CometCallback
10 |
11 | ---
12 |
--------------------------------------------------------------------------------
/docs/gradsflow/core.md:
--------------------------------------------------------------------------------
1 | Core Building blocks
2 |
3 | ::: gradsflow.core.base.BaseAutoModel
4 |
--------------------------------------------------------------------------------
/docs/gradsflow/data.md:
--------------------------------------------------------------------------------
1 | ::: gradsflow.data.ray_dataset
2 |
3 | ---
4 |
5 | ::: gradsflow.data.image
6 |
7 | ---
8 |
9 | ::: gradsflow.data.autodata.AutoDataset
10 |
11 | ---
12 |
13 | ::: gradsflow.data.common
14 |
--------------------------------------------------------------------------------
/docs/gradsflow/models/base.md:
--------------------------------------------------------------------------------
1 | ::: gradsflow.models.base.BaseModel
2 |
--------------------------------------------------------------------------------
/docs/gradsflow/models/model.md:
--------------------------------------------------------------------------------
1 | ::: gradsflow.models.model.Model
2 |
--------------------------------------------------------------------------------
/docs/gradsflow/models/tracker.md:
--------------------------------------------------------------------------------
1 | ::: gradsflow.models.tracker.Tracker
2 |
--------------------------------------------------------------------------------
/docs/gradsflow/models/utils.md:
--------------------------------------------------------------------------------
1 | ::: gradsflow.models.utils.available_losses
2 |
3 | ---
4 |
5 | ::: gradsflow.models.utils.available_metrics
6 |
--------------------------------------------------------------------------------
/docs/gradsflow/tuner.md:
--------------------------------------------------------------------------------
1 | ::: gradsflow.tuner
2 |
--------------------------------------------------------------------------------
/docs/gradsflow/utility.md:
--------------------------------------------------------------------------------
1 | ::: gradsflow.utility
2 |
--------------------------------------------------------------------------------
/docs/overrides/main.html:
--------------------------------------------------------------------------------
1 | {% extends "base.html" %}
2 |
3 | {% block extrahead %}
4 | {% set title = config.site_name %}
5 | {% if page and page.meta and page.meta.title %}
6 | {% set title = title ~ " - " ~ page.meta.title %}
7 | {% elif page and page.title and not page.is_homepage %}
8 | {% set title = title ~ " - " ~ page.title | striptags %}
9 | {% endif %}
10 |
11 |
12 |
An open-source AutoML Library based on PyTorch
13 |
14 |
15 |
16 |
17 |
18 |
19 |
20 |
21 |
22 |
23 |
24 |
25 |
26 |
27 |
28 |
29 | {% endblock %}
30 |
31 | {% block outdated %}
32 | You're not viewing the latest version.
33 |
34 | Click here to go to latest.
35 |
36 | {% endblock %}
37 |
38 |
39 | {% set extracopyright %}
40 | | Apache License 2.0
41 | {% endset %}
42 |
--------------------------------------------------------------------------------
/docs/requirements.txt:
--------------------------------------------------------------------------------
1 | # git+https://github.com/gradsflow/gradsflow@main
2 | mkdocs==1.4.2
3 | mkdocstrings==0.20.0
4 | mkdocs-material==8.5.11
5 | mkdocstrings-python==0.8.3
6 | mkdocs-material-extensions==1.1.1
7 | mkdocs-git-revision-date-localized-plugin==1.1.0
8 | mkdocs-macros-plugin==0.7.0
9 | mkdocs-autorefs==0.4.1
10 | mkdocs-jupyter==0.22.0
11 | tags-macros-plugin @ git+https://github.com/jldiaz/mkdocs-plugin-tags.git@d26e2f124e4f3471639d426459e281080988fe7a
12 | mkdocs-meta-descriptions-plugin==2.2.0
13 | jupyter_contrib_nbextensions
14 | comet_ml
15 | lightning-flash[image,text]>=0.5.1
16 | wandb
17 | tensorboard
18 |
--------------------------------------------------------------------------------
/examples/nbs/03-TextSummarization.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "markdown",
5 | "metadata": {
6 | "id": "xYgiH2TkkX7x",
7 | "pycharm": {
8 | "name": "#%% md\n"
9 | }
10 | },
11 | "source": [
12 | "
\n",
13 | "\n",
14 | "`!pip install lightning-flash`\n",
15 | "\n",
16 | "`!pip install -U gradsflow`"
17 | ]
18 | },
19 | {
20 | "cell_type": "code",
21 | "execution_count": null,
22 | "metadata": {
23 | "id": "2_posF7Rj8sH",
24 | "pycharm": {
25 | "name": "#%%\n"
26 | }
27 | },
28 | "outputs": [],
29 | "source": [
30 | "from flash.core.data.utils import download_data\n",
31 | "from flash.text import SummarizationData, SummarizationTask\n",
32 | "\n",
33 | "# 1. Download the data\n",
34 | "download_data(\"https://pl-flash-data.s3.amazonaws.com/xsum.zip\", \"data/\")\n",
35 | "\n",
36 | "# 2. Load the data\n",
37 | "datamodule = SummarizationData.from_csv(\n",
38 | " \"input\",\n",
39 | " \"target\",\n",
40 | " train_file=\"data/xsum/train.csv\",\n",
41 | " val_file=\"data/xsum/valid.csv\",\n",
42 | " test_file=\"data/xsum/test.csv\",\n",
43 | " batch_size=4,\n",
44 | ")"
45 | ]
46 | },
47 | {
48 | "cell_type": "code",
49 | "execution_count": null,
50 | "metadata": {
51 | "id": "YCYFfKhDkVVK",
52 | "pycharm": {
53 | "name": "#%%\n"
54 | }
55 | },
56 | "outputs": [],
57 | "source": [
58 | "from gradsflow import AutoSummarization\n",
59 | "\n",
60 | "suggested_conf = dict(\n",
61 | " optimizers=[\"adam\"],\n",
62 | " lr=(5e-4, 1e-3),\n",
63 | ")\n",
64 | "\n",
65 | "model = AutoSummarization(\n",
66 | " datamodule,\n",
67 | " suggested_backbones=\"sshleifer/distilbart-cnn-12-6\",\n",
68 | " suggested_conf=suggested_conf,\n",
69 | " max_epochs=1,\n",
70 | " optimization_metric=\"train_loss\",\n",
71 | " timeout=600,\n",
72 | ")\n",
73 | "\n",
74 | "print(\"AutoSummarization initialised!\")\n",
75 | "model.hp_tune()"
76 | ]
77 | }
78 | ],
79 | "metadata": {
80 | "colab": {
81 | "name": "03-TextSummarization.ipynb",
82 | "provenance": []
83 | },
84 | "kernelspec": {
85 | "display_name": "Python 3",
86 | "language": "python",
87 | "name": "python3"
88 | },
89 | "language_info": {
90 | "codemirror_mode": {
91 | "name": "ipython",
92 | "version": 3
93 | },
94 | "file_extension": ".py",
95 | "mimetype": "text/x-python",
96 | "name": "python",
97 | "nbconvert_exporter": "python",
98 | "pygments_lexer": "ipython3",
99 | "version": "3.7.10"
100 | }
101 | },
102 | "nbformat": 4,
103 | "nbformat_minor": 4
104 | }
105 |
--------------------------------------------------------------------------------
/examples/nbs/05-model_fit.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "markdown",
5 | "id": "0",
6 | "metadata": {},
7 | "source": [
8 | "
"
9 | ]
10 | },
11 | {
12 | "cell_type": "code",
13 | "execution_count": null,
14 | "id": "1",
15 | "metadata": {},
16 | "outputs": [],
17 | "source": [
18 | "import torchvision\n",
19 | "from timm import create_model\n",
20 | "from torch.utils.data import DataLoader\n",
21 | "from torchvision import transforms as T\n",
22 | "\n",
23 | "from gradsflow import AutoDataset, Model\n",
24 | "from gradsflow.callbacks import (\n",
25 | " CSVLogger,\n",
26 | " EmissionTrackerCallback,\n",
27 | " ModelCheckpoint,\n",
28 | ")\n",
29 | "from gradsflow.data.common import random_split_dataset"
30 | ]
31 | },
32 | {
33 | "cell_type": "code",
34 | "execution_count": null,
35 | "id": "2",
36 | "metadata": {},
37 | "outputs": [
38 | {
39 | "name": "stdout",
40 | "output_type": "stream",
41 | "text": [
42 | "Files already downloaded and verified\n"
43 | ]
44 | }
45 | ],
46 | "source": [
47 | "# Replace dataloaders with your custom dataset and you are all set to train your model\n",
48 | "image_size = (64, 64)\n",
49 | "batch_size = 4\n",
50 | "\n",
51 | "to_rgb = lambda x: x.convert(\"RGB\")\n",
52 | "\n",
53 | "augs = T.Compose([to_rgb, T.AutoAugment(), T.Resize(image_size), T.ToTensor()])\n",
54 | "data = torchvision.datasets.Caltech101(\"~/\", download=True, transform=augs)\n",
55 | "train_data, val_data = random_split_dataset(data, 0.99)\n",
56 | "train_dl = DataLoader(train_data, batch_size=batch_size)\n",
57 | "val_dl = DataLoader(val_data, batch_size=batch_size)\n",
58 | "num_classes = len(data.categories)"
59 | ]
60 | },
61 | {
62 | "cell_type": "code",
63 | "execution_count": null,
64 | "id": "3",
65 | "metadata": {},
66 | "outputs": [
67 | {
68 | "name": "stderr",
69 | "output_type": "stream",
70 | "text": [
71 | "CODECARBON : No CPU tracking mode found. Falling back on CPU constant mode.\n",
72 | "/Users/aniket/miniconda3/envs/am/lib/python3.9/site-packages/apscheduler/util.py:95: PytzUsageWarning: The zone attribute is specific to pytz's interface; please migrate to a new time zone provider. For more details on how to do so, see https://pytz-deprecation-shim.readthedocs.io/en/latest/migration.html\n",
73 | " if obj.zone == 'local':\n",
74 | "/Users/aniket/miniconda3/envs/am/lib/python3.9/site-packages/apscheduler/triggers/interval.py:66: PytzUsageWarning: The normalize method is no longer necessary, as this time zone supports the fold attribute (PEP 495). For more details on migrating to a PEP 495-compliant implementation, see https://pytz-deprecation-shim.readthedocs.io/en/latest/migration.html\n",
75 | " return self.timezone.normalize(next_fire_time)\n"
76 | ]
77 | }
78 | ],
79 | "source": [
80 | "cbs = [\n",
81 | " CSVLogger(\n",
82 | " verbose=True,\n",
83 | " ),\n",
84 | " ModelCheckpoint(),\n",
85 | " EmissionTrackerCallback(),\n",
86 | " # CometCallback(offline=True),\n",
87 | "]"
88 | ]
89 | },
90 | {
91 | "cell_type": "code",
92 | "execution_count": null,
93 | "id": "4",
94 | "metadata": {},
95 | "outputs": [
96 | {
97 | "data": {
98 | "application/vnd.jupyter.widget-view+json": {
99 | "model_id": "e4eb63f7dc584cfcb22d92f15c96bbf3",
100 | "version_major": 2,
101 | "version_minor": 0
102 | },
103 | "text/plain": [
104 | "Output()"
105 | ]
106 | },
107 | "metadata": {},
108 | "output_type": "display_data"
109 | }
110 | ],
111 | "source": [
112 | "autodataset = AutoDataset(train_dl, val_dl, num_classes=num_classes)\n",
113 | "cnn = create_model(\"resnet18\", pretrained=False, num_classes=num_classes)\n",
114 | "\n",
115 | "model = Model(cnn)\n",
116 | "\n",
117 | "model.compile(\"crossentropyloss\", \"adam\", metrics=[\"accuracy\"])\n",
118 | "model.fit(autodataset, max_epochs=10, steps_per_epoch=10, callbacks=cbs)"
119 | ]
120 | }
121 | ],
122 | "metadata": {
123 | "interpreter": {
124 | "hash": "e2d961b663a5ae03743cd178a74853be9b21def56a249d21ac1502fcfb05a9ce"
125 | },
126 | "kernelspec": {
127 | "display_name": "Python 3 (ipykernel)",
128 | "language": "python",
129 | "name": "python3"
130 | },
131 | "language_info": {
132 | "codemirror_mode": {
133 | "name": "ipython",
134 | "version": 3
135 | },
136 | "file_extension": ".py",
137 | "mimetype": "text/x-python",
138 | "name": "python",
139 | "nbconvert_exporter": "python",
140 | "pygments_lexer": "ipython3",
141 | "version": "3.9.9"
142 | },
143 | "toc": {
144 | "base_numbering": 1,
145 | "nav_menu": {},
146 | "number_sections": true,
147 | "sideBar": true,
148 | "skip_h1_title": false,
149 | "title_cell": "Table of Contents",
150 | "title_sidebar": "Contents",
151 | "toc_cell": false,
152 | "toc_position": {},
153 | "toc_section_display": true,
154 | "toc_window_display": false
155 | }
156 | },
157 | "nbformat": 4,
158 | "nbformat_minor": 5
159 | }
160 |
--------------------------------------------------------------------------------
/examples/src/models/hello_world.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2021 GradsFlow. All rights reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # Source code inspired from https://pytorch.org/tutorials/beginner/blitz/cifar10_tutorial.html
15 | import matplotlib.pyplot as plt
16 | import numpy as np
17 | import torch
18 | import torch.nn.functional as F
19 | import torch.optim as optim
20 | import torchvision
21 | import torchvision.transforms as transforms
22 | from torch import nn
23 | from torch.utils.data import DataLoader
24 | from torchmetrics.classification import MulticlassAccuracy
25 |
26 | from gradsflow import AutoDataset, Model
27 | from gradsflow.callbacks import CSVLogger, ModelCheckpoint
28 |
# Replace dataloaders with your custom dataset, and you are all set to train your model
# NOTE(review): image_size is declared but not referenced below in this file — confirm it is intentional.
image_size = (64, 64)
batch_size = 4

# ToTensor scales pixels to [0, 1]; Normalize with mean=std=0.5 then maps them to [-1, 1].
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

# Downloads CIFAR-10 to ~/data on first run (network I/O).
trainset = torchvision.datasets.CIFAR10(root="~/data", train=True, download=True, transform=transform)
train_dl = torch.utils.data.DataLoader(trainset, batch_size=batch_size, shuffle=True, num_workers=2)

# Validation split: the CIFAR-10 test set, served in a fixed order (shuffle=False).
testset = torchvision.datasets.CIFAR10(root="~/data", train=False, download=True, transform=transform)
val_dl = torch.utils.data.DataLoader(testset, batch_size=batch_size, shuffle=False, num_workers=2)
num_classes = len(trainset.classes)
# Callbacks passed to model.fit below; optional trackers are left commented out as opt-in examples.
cbs = [
    CSVLogger(
        verbose=True,
    ),
    ModelCheckpoint(),
    # EmissionTrackerCallback(),
    # CometCallback(offline=True),
    # WandbCallback(),
]
50 |
51 |
def imshow(img):
    """Display a CHW image tensor with matplotlib, undoing the [-1, 1] normalization."""
    unnormalized = img / 2 + 0.5  # invert Normalize(mean=0.5, std=0.5)
    pixels = unnormalized.numpy().transpose((1, 2, 0))  # CHW -> HWC for matplotlib
    plt.imshow(pixels)
    plt.show()
57 |
58 |
class Net(nn.Module):
    """A small LeNet-style CNN for 3-channel 32x32 images (e.g. CIFAR-10).

    Two conv+max-pool stages followed by three fully connected layers.
    The flattened feature size ``16 * 5 * 5`` assumes 32x32 inputs
    (32 -> conv5 -> 28 -> pool -> 14 -> conv5 -> 10 -> pool -> 5).

    Args:
        num_classes: size of the final output layer. Defaults to 10 (the
            CIFAR-10 class count), which preserves the original behavior;
            pass a different value to reuse the network on other datasets.
    """

    def __init__(self, num_classes: int = 10):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, num_classes)

    def forward(self, x):
        """Return raw class logits for a batch of shape (N, 3, 32, 32)."""
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = torch.flatten(x, 1)  # flatten all dimensions except batch
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x
77 |
78 |
79 | if __name__ == "__main__":
80 | autodataset = AutoDataset(train_dl, val_dl, num_classes=num_classes)
81 | net = Net()
82 | model = Model(net)
83 | criterion = nn.CrossEntropyLoss()
84 |
85 | model.compile(
86 | criterion,
87 | optim.SGD,
88 | optimizer_config={"momentum": 0.9},
89 | learning_rate=0.001,
90 | metrics=[MulticlassAccuracy(autodataset.num_classes)],
91 | )
92 | model.fit(autodataset, max_epochs=2, callbacks=cbs)
93 |
94 | dataiter = iter(val_dl)
95 | images, labels = next(dataiter)
96 |
97 | # print images
98 | # imshow(torchvision.utils.make_grid(images))
99 | print("GroundTruth: ", " ".join(f"{trainset.classes[labels[j]]:5s}" for j in range(4)))
100 |
101 | outputs = net(images)
102 | _, predicted = torch.max(outputs, 1)
103 |
104 | print("Predicted: ", " ".join(f"{trainset.classes[predicted[j]]:5s}" for j in range(4)))
105 |
--------------------------------------------------------------------------------
/examples/src/tasks/image_classifier.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2021 GradsFlow. All rights reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | import torchvision
16 | from torch.utils.data import DataLoader
17 | from torchvision import transforms as T
18 |
19 | from gradsflow import AutoImageClassifier
20 | from gradsflow.data.common import random_split_dataset
21 |
# Replace dataloaders with your custom dataset and you are all set to train your model
image_size = (64, 64)
batch_size = 4


def to_rgb(x):
    """Convert a PIL image to RGB mode (PEP 8 discourages assigning a lambda to a name)."""
    return x.convert("RGB")
27 |
# TODO: Add argument parser
if __name__ == "__main__":
    # Force RGB first, then augment, resize to 64x64, and convert to tensor.
    augs = T.Compose([to_rgb, T.AutoAugment(), T.Resize(image_size), T.ToTensor()])
    data = torchvision.datasets.CIFAR10("~/data", download=True, transform=augs)
    # NOTE(review): 0.01 presumably puts 1% of the data in train_data (a quick smoke-test split) —
    # confirm against random_split_dataset's semantics.
    train_data, val_data = random_split_dataset(data, 0.01)
    train_dl = DataLoader(train_data, batch_size=batch_size)
    val_dl = DataLoader(val_data, batch_size=batch_size)

    num_classes = len(data.classes)

    # Tiny search budget (1 step, 1 trial) — this is a wiring example, not a real tuning run.
    model = AutoImageClassifier(
        train_dataloader=train_dl,
        val_dataloader=val_dl,
        num_classes=num_classes,
        max_epochs=5,
        optimization_metric="train_loss",
        max_steps=1,
        n_trials=1,
    )
    print("AutoImageClassifier initialised!")

    model.hp_tune()
50 |
--------------------------------------------------------------------------------
/examples/src/tasks/text_classification.py:
--------------------------------------------------------------------------------
1 | from flash.core.data.utils import download_data
2 | from flash.text import TextClassificationData
3 |
4 | from gradsflow import AutoTextClassifier
5 |
6 | download_data("https://pl-flash-data.s3.amazonaws.com/imdb.zip", "./data/")
7 |
8 | print("Creating datamodule...")
9 | datamodule = TextClassificationData.from_csv(
10 | "review", "sentiment", train_file="data/imdb/train.csv", val_file="data/imdb/valid.csv", batch_size=4
11 | )
12 |
13 | suggested_conf = dict(
14 | optimizer=["adam", "adamw", "sgd"],
15 | lr=(5e-4, 1e-3),
16 | )
17 |
18 | model = AutoTextClassifier(
19 | datamodule,
20 | suggested_backbones=["prajjwal1/bert-tiny"],
21 | suggested_conf=suggested_conf,
22 | max_epochs=1,
23 | optimization_metric="val_accuracy",
24 | n_trials=1,
25 | )
26 |
27 | print("AutoTextClassifier initialised!")
28 | model.hp_tune(finetune=True)
29 |
--------------------------------------------------------------------------------
/gradsflow/__init__.py:
--------------------------------------------------------------------------------
1 | """An open-source AutoML Library based on PyTorch"""
2 |
3 | # Copyright (c) 2021 GradsFlow. All rights reserved.
4 | #
5 | # Licensed under the Apache License, Version 2.0 (the "License");
6 | # you may not use this file except in compliance with the License.
7 | # You may obtain a copy of the License at
8 | #
9 | # http://www.apache.org/licenses/LICENSE-2.0
10 | #
11 | # Unless required by applicable law or agreed to in writing, software
12 | # distributed under the License is distributed on an "AS IS" BASIS,
13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | # See the License for the specific language governing permissions and
15 | # limitations under the License.
16 |
17 | import logging
18 | from os import environ as _environ
19 |
20 | from gradsflow.autotasks.autoclassification.image import AutoImageClassifier
21 | from gradsflow.autotasks.autoclassification.text import AutoTextClassifier
22 | from gradsflow.autotasks.autosummarization import AutoSummarization
23 | from gradsflow.autotasks.autotasks import autotask, available_tasks
24 | from gradsflow.autotasks.engine.automodel import AutoModel
25 | from gradsflow.data import AutoDataset
26 | from gradsflow.models.model import Model
27 | from gradsflow.tuner.automodel import AutoModelV2
28 | from gradsflow.tuner.tuner import Tuner
29 |
30 | __version__ = "0.0.8.post1"
31 | logging.basicConfig(level=_environ.get("LOG_LEVEL", "WARNING"))
32 |
--------------------------------------------------------------------------------
/gradsflow/autotasks/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2021 GradsFlow. All rights reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | from .autoclassification.image import AutoImageClassifier
16 | from .autoclassification.text import AutoTextClassifier
17 | from .autosummarization import AutoSummarization
18 | from .autotasks import autotask, available_tasks
19 |
--------------------------------------------------------------------------------
/gradsflow/autotasks/autoclassification/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2021 GradsFlow. All rights reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
--------------------------------------------------------------------------------
/gradsflow/autotasks/autoclassification/image.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2021 GradsFlow. All rights reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | import timm
16 |
17 | from gradsflow.autotasks.engine.autoclassifier import AutoClassifier
18 | from gradsflow.autotasks.engine.backend import BackendType
19 | from gradsflow.models.model import Model
20 |
21 |
# noinspection PyTypeChecker
class AutoImageClassifier(AutoClassifier):
    """
    Automatically search for an image-classification model.

    Candidate networks are created from `timm` backbones and trained with the
    gradsflow backend; hyperparameters are explored via `hp_tune()`.

    Args:
        datamodule Optional[DataModule]: PL Lightning DataModule with `num_classes` property.
        train_dataloader Optional[DataLoader]: torch dataloader
        val_dataloader Optional[DataLoader]: torch dataloader
        num_classes Optional[int]: number of classes
        max_epochs [int]: default=10.
        n_trials [int]: default=100.
        optimization_metric [Optional[str]]: defaults None
        suggested_backbones Union[List, str, None]: defaults None
        suggested_conf [Optional[dict] = None]: Trial suggestions for the optimizer,
            learning rate, and the remaining hyperparameters.
        timeout [int]: Hyperparameter search stops after this many seconds.
        backend_type Optional[str]: Training loop code. Defaults to None.

    Examples:
        ```python
        from flash.core.data.utils import download_data
        from flash.image import ImageClassificationData

        from gradsflow import AutoImageClassifier

        # 1. Create the DataModule
        download_data("https://pl-flash-data.s3.amazonaws.com/hymenoptera_data.zip", "./data")

        datamodule = ImageClassificationData.from_folders(
            train_folder="data/hymenoptera_data/train/",
            val_folder="data/hymenoptera_data/val/",
        )

        model = AutoImageClassifier(datamodule,
                                    max_epochs=10,
                                    optimization_metric="val_accuracy",
                                    timeout=300)
        model.hp_tune()
        ```
    """

    _DEFAULT_BACKBONES = ["ssl_resnet18", "ssl_resnet50"]

    def __init__(self, *args, **kwargs):
        # Image classification always trains with the gradsflow backend.
        super().__init__(*args, **kwargs, backend=BackendType.gf.value)

    def build_model(self, config: dict) -> Model:
        """Build an image classifier from a `ray.tune` hyperparameter config
        (or from `_search_space` dictionary arguments).

        Expected config keys:
            backbone [str]: timm backbone name - resnet18, resnet50, ...
                (see `timm.list_models()` for the full catalogue)
            optimizer [str]: PyTorch optimizer name. Check `AutoImageClassification._OPTIMIZER_INDEX`
            lr [float]: learning rate for the model.
        """
        network = timm.create_model(config["backbone"], pretrained=True, num_classes=self.num_classes)
        image_model = Model(network)
        image_model.compile(
            loss="crossentropyloss",
            optimizer=config["optimizer"],
            learning_rate=config["lr"],
            metrics="accuracy",
        )
        return image_model
88 |
--------------------------------------------------------------------------------
/gradsflow/autotasks/autoclassification/text/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2021 GradsFlow. All rights reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | from .text import AutoTextClassifier
15 |
--------------------------------------------------------------------------------
/gradsflow/autotasks/autoclassification/text/text.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2021 GradsFlow. All rights reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 |
16 | import logging
17 |
18 | import torch
19 |
20 | from gradsflow.autotasks.engine.autoclassifier import AutoClassifier
21 |
22 | logger = logging.getLogger(__name__)
23 |
24 |
25 | # noinspection PyTypeChecker
# noinspection PyTypeChecker
class AutoTextClassifier(AutoClassifier):
    """
    Automatically find Text Classification Model

    Examples:
        ```python
        from gradsflow import AutoTextClassifier

        from flash.core.data.utils import download_data
        from flash.text import TextClassificationData

        download_data("https://pl-flash-data.s3.amazonaws.com/imdb.zip", "./data/")
        datamodule = TextClassificationData.from_csv(
            "review",
            "sentiment",
            train_file="data/imdb/train.csv",
            val_file="data/imdb/valid.csv",
            batch_size=4,
        )

        model = AutoTextClassifier(datamodule,
                                   suggested_backbones=['sgugger/tiny-distilbert-classification'],
                                   max_epochs=10,
                                   optimization_metric="val_accuracy",
                                   timeout=300)
        model.hp_tune()
        ```

    Arguments:
        datamodule Optional[DataModule]: PL Lightning DataModule with `num_classes` property.
        train_dataloader Optional[DataLoader]: torch dataloader
        val_dataloader Optional[DataLoader]: torch dataloader
        num_classes Optional[int]: number of classes
        max_epochs [int]: default=10.
        n_trials [int]: default=100.
        optimization_metric [Optional[str]]: defaults None
        suggested_backbones Union[List, str, None]: defaults None
        suggested_conf [Optional[dict] = None]: This sets Trial suggestions for optimizer,
            learning rate, and all the hyperparameters.
        timeout [int]: Hyperparameter search will stop after timeout.
    """

    # Hugging Face model names used when no `suggested_backbones` is given.
    _DEFAULT_BACKBONES = [
        "distilbert-base-uncased-finetuned-sst-2-english",
        "sgugger/tiny-distilbert-classification",
    ]

    def __init__(self, *args, max_steps=-1, **kwargs):
        # max_steps defaults to -1 here (presumably "no step cap") — TODO confirm
        # against the trainer's max_steps semantics.
        super().__init__(*args, max_steps=max_steps, **kwargs)
        # Dataset metadata may expose the class count under either key.
        # NOTE(review): a "num_labels" value of 0 would also trigger the fallback.
        meta = self.auto_dataset.meta
        self.num_classes = meta.get("num_labels") or meta.get("num_classes")
        logger.debug(f"num_classes = {self.num_classes}")

    def build_model(self, config: dict) -> torch.nn.Module:
        """Build a TextClassifier model from `ray.tune` hyperparameter configs
        or via _search_space dictionary arguments.

        Arguments:
            config: trial hyperparameters. Expected keys:
                backbone (str): Hugging Face text-classification backbone name,
                    e.g. ``sgugger/tiny-distilbert-classification``
                    (check Lightning-Flash `TextClassifier` for supported models)
                optimizer (str): optimizer name, resolved via `self._OPTIMIZER_INDEX`.
                lr (float): learning rate for the model.
        """
        from flash.text.classification import TextClassifier
        from torchmetrics import Accuracy

        backbone = config["backbone"]
        optimizer = config["optimizer"]
        learning_rate = config["lr"]

        return TextClassifier(
            self.num_classes,
            backbone=backbone,
            optimizer=self._OPTIMIZER_INDEX[optimizer],
            learning_rate=learning_rate,
            metrics=Accuracy(num_classes=self.num_classes),
        )
104 |
--------------------------------------------------------------------------------
/gradsflow/autotasks/autosummarization.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2021 GradsFlow. All rights reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | import torch
16 |
17 | from gradsflow.autotasks.engine.autoclassifier import AutoClassifier
18 |
19 |
20 | # noinspection PyTypeChecker
# noinspection PyTypeChecker
class AutoSummarization(AutoClassifier):
    """
    Automatically finds Text Summarization Model

    Args:
        datamodule Optional[DataModule]: PL Lightning DataModule with `num_classes` property.
        train_dataloader Optional[DataLoader]: torch dataloader
        val_dataloader Optional[DataLoader]: torch dataloader
        num_classes Optional[int]: number of classes
        max_epochs [int]: default=10.
        n_trials [int]: default=100.
        optimization_metric [Optional[str]]: defaults None
        suggested_backbones Union[List, str, None]: defaults None
        suggested_conf [Optional[dict] = None]: This sets Trial suggestions for optimizer,
            learning rate, and all the hyperparameters.
        timeout [int]: Hyperparameter search will stop after timeout.

    Examples:
        ```python
        from gradsflow import AutoSummarization

        from flash.core.data.utils import download_data
        from flash.text import SummarizationData, SummarizationTask

        # 1. Download the data
        download_data("https://pl-flash-data.s3.amazonaws.com/xsum.zip", "data/")
        # 2. Load the data
        datamodule = SummarizationData.from_csv(
            "input",
            "target",
            train_file="data/xsum/train.csv",
            val_file="data/xsum/valid.csv",
            test_file="data/xsum/test.csv",
        )

        model = AutoSummarization(datamodule,
                                  max_epochs=10,
                                  optimization_metric="val_accuracy",
                                  timeout=300)
        model.hp_tune()
        ```
    """

    # Hugging Face model names used when no `suggested_backbones` is given.
    _DEFAULT_BACKBONES = [
        "sshleifer/distilbart-cnn-12-6",
        "sshleifer/distilbart-xsum-12-3",
    ]

    def build_model(self, config: dict) -> torch.nn.Module:
        """Build a SummarizationTask model from `ray.tune` hyperparameter configs
        or via _search_space dictionary arguments.

        Arguments:
            config: trial hyperparameters. Expected keys:
                backbone (str): summarization backbone name -
                    sshleifer/distilbart-cnn-12-6, sshleifer/distilbart-xsum-12-3, ...
                    (check Lightning-Flash `SummarizationTask` for supported models)
                optimizer (str): optimizer name, resolved via `self._OPTIMIZER_INDEX`.
                lr (float): learning rate for the model.
        """
        from flash.text.seq2seq import SummarizationTask

        backbone = config["backbone"]
        optimizer = config["optimizer"]
        learning_rate = config["lr"]

        return SummarizationTask(
            backbone=backbone,
            optimizer=self._OPTIMIZER_INDEX[optimizer],
            learning_rate=learning_rate,
        )
92 |
--------------------------------------------------------------------------------
/gradsflow/autotasks/autotasks.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2021 GradsFlow. All rights reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | from typing import List, Optional, Union
16 |
17 | from torch.utils.data import DataLoader
18 |
19 | from gradsflow.utility.imports import is_installed
20 |
21 | from .autoclassification.image import AutoImageClassifier
22 | from .autoclassification.text import AutoTextClassifier
23 | from .autosummarization import AutoSummarization
24 |
25 | if is_installed("pytorch_lightning"):
26 | import pytorch_lightning as pl
27 |
# Registry mapping a task identifier to its AutoModel implementation.
# `autotask()` resolves tasks from here; `available_tasks()` lists the keys.
SUPPORTED_TASKS = {
    "image-classification": AutoImageClassifier,
    "text-classification": AutoTextClassifier,
    "summarization": AutoSummarization,
}
33 |
34 |
def available_tasks() -> List[str]:
    """Return the task identifiers that `autotask` accepts (keys of ``SUPPORTED_TASKS``)."""
    return [*SUPPORTED_TASKS]
38 |
39 |
def autotask(
    datamodule: Optional["pl.LightningDataModule"] = None,
    train_dataloader: Optional[DataLoader] = None,
    val_dataloader: Optional[DataLoader] = None,
    num_classes: Optional[int] = None,
    task: Optional[str] = None,
    data_type: Optional[str] = None,
    max_epochs: int = 10,
    max_steps: int = 10,
    n_trials: int = 100,
    optimization_metric: Optional[str] = None,
    suggested_backbones: Union[List, str, None] = None,
    suggested_conf: Optional[dict] = None,
    timeout: int = 600,
    prune: bool = True,
):
    """Build the `AutoModel` implementation registered for ``task`` in ``SUPPORTED_TASKS``.

    Args:
        datamodule Optional[DataModule]: PL Lightning DataModule with `num_classes` property.
        train_dataloader Optional[DataLoader]: torch dataloader
        val_dataloader Optional[DataLoader]: torch dataloader
        num_classes Optional[int]: number of classes
        task Optional[str]: type of task. Check available autotasks with `available_tasks()`.
        data_type Optional[str]: default=None. type of data - image, text or infer.
            NOTE(review): currently only checked for presence — the task is never
            inferred from it, so ``task`` must still name a supported task.
        max_epochs [int]: default=10.
        max_steps [int]: default=10. Forwarded to the task constructor.
        n_trials [int]: default=100.
        optimization_metric [Optional[str]]: defaults None
        suggested_backbones Union[List, str, None]: defaults None
        suggested_conf [Optional[dict] = None]: This sets Trial suggestions for optimizer,
            learning rate, and all the hyperparameters.
        timeout [int]: Hyperparameter search will stop after timeout.
        prune [bool]: default=True. Forwarded to the task constructor.

    Raises:
        UserWarning: if neither ``task`` nor ``data_type`` is provided, or if
            ``task`` is not one of ``SUPPORTED_TASKS``.

    Returns:
        Implementation of `AutoModel` for the task type.
    """
    if not (task or data_type):
        raise UserWarning("either task or data_type must be set!")

    if task not in SUPPORTED_TASKS:
        raise UserWarning(f"Unknown task {task}, available autotasks are {list(SUPPORTED_TASKS.keys())}")

    targeted_task = SUPPORTED_TASKS[task]

    return targeted_task(
        datamodule=datamodule,
        train_dataloader=train_dataloader,
        val_dataloader=val_dataloader,
        num_classes=num_classes,
        max_epochs=max_epochs,
        max_steps=max_steps,
        n_trials=n_trials,
        optimization_metric=optimization_metric,
        suggested_backbones=suggested_backbones,
        suggested_conf=suggested_conf,
        timeout=timeout,
        prune=prune,
    )
98 |
--------------------------------------------------------------------------------
/gradsflow/autotasks/engine/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2021 GradsFlow. All rights reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
--------------------------------------------------------------------------------
/gradsflow/autotasks/engine/autoclassifier.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2021 GradsFlow. All rights reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | from abc import abstractmethod
16 | from typing import Dict, List, Optional, Union
17 |
18 | import torch
19 | from ray import tune
20 | from torch.utils.data import DataLoader
21 |
22 | from gradsflow.autotasks.engine.automodel import AutoModel
23 | from gradsflow.utility.common import listify
24 | from gradsflow.utility.imports import is_installed
25 |
26 | pl = None
27 | if is_installed("pytorch_lightning"):
28 | import pytorch_lightning as pl
29 |
30 |
class AutoClassifier(AutoModel):
    """Implements `AutoModel` for classification autotasks.

    Concrete tasks override `_DEFAULT_BACKBONES` and implement `build_model`;
    this base class threads the data/trainer arguments through to
    `AutoModel.__init__` and defines the shared hyperparameter search space
    (backbone, learning rate, optimizer).
    """

    # Fallback backbone suggestions; concrete subclasses override this.
    _DEFAULT_BACKBONES = []

    def __init__(
        self,
        datamodule: Optional["pl.LightningDataModule"] = None,
        train_dataloader: Optional[DataLoader] = None,
        val_dataloader: Optional[DataLoader] = None,
        num_classes: Optional[int] = None,
        max_epochs: int = 10,
        max_steps: int = 10,
        n_trials: int = 100,
        optimization_metric: Optional[str] = None,
        suggested_backbones: Union[List, str, None] = None,
        suggested_conf: Optional[dict] = None,
        timeout: int = 600,
        prune: bool = True,
        backend: Optional[str] = None,
    ):
        super().__init__(
            datamodule=datamodule,
            train_dataloader=train_dataloader,
            val_dataloader=val_dataloader,
            num_classes=num_classes,
            max_epochs=max_epochs,
            max_steps=max_steps,
            optimization_metric=optimization_metric,
            n_trials=n_trials,
            suggested_conf=suggested_conf,
            timeout=timeout,
            prune=prune,
            backend=backend,
        )

        # Normalize the backbone suggestions into a list; fall back to the
        # task's defaults when nothing was provided.
        if isinstance(suggested_backbones, (str, list, tuple)):
            self.suggested_backbones = listify(suggested_backbones)
        elif suggested_backbones is None:
            self.suggested_backbones = self._DEFAULT_BACKBONES
        else:
            # NOTE(review): raising UserWarning (a Warning subclass) as an
            # exception is unusual; kept because callers may catch it.
            raise UserWarning("Invalid suggested_backbone type!")

        self.num_classes = num_classes

    def forward(self, x):
        # `self.model` only becomes available after a successful `hp_tune()` run.
        if not self.model:
            raise UserWarning("model not initialized yet, run `hp_tune()` first.")
        return self.model(x)

    # noinspection PyTypeChecker
    def _create_search_space(self) -> Dict[str, str]:
        """Create hyperparameter config from `ray.tune`

        Returns:
            key-value pair of `ray.tune` _search_space with keys
            "backbone", "lr" and "optimizer".
        """
        trial_backbone = tune.choice(self.suggested_backbones)
        # Learning rate is sampled on a log scale between the suggested bounds.
        trial_lr = tune.loguniform(*self.suggested_lr)
        trial_optimizer = tune.choice(self.suggested_optimizers)
        hparams = {
            "backbone": trial_backbone,
            "lr": trial_lr,
            "optimizer": trial_optimizer,
        }
        return hparams

    @abstractmethod
    def build_model(self, config: dict) -> torch.nn.Module:
        """Every Task implementing AutoClassifier has to implement a
        build model method that can build `torch.nn.Module` from dictionary config
        and return the model.
        """
104 |
--------------------------------------------------------------------------------
/gradsflow/autotasks/engine/backend.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2021 GradsFlow. All rights reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | import logging
16 | import typing
17 | from enum import Enum
18 | from typing import Callable, Dict, Optional
19 |
20 | import torch
21 |
22 | from gradsflow.callbacks import report_checkpoint_callback
23 | from gradsflow.data import AutoDataset
24 | from gradsflow.utility.common import module_to_cls_index
25 | from gradsflow.utility.imports import is_installed
26 |
27 | if typing.TYPE_CHECKING:
28 | import lightning as L
29 |
30 | if is_installed("lightning-flash"):
31 | from flash import Task
32 | from flash import Trainer as FlashTrainer
33 | from lightning import Trainer as PLTrainer
34 | else:
35 | FlashTrainer = None
36 | PLTrainer = None
37 |
38 | logger = logging.getLogger("core.backend")
39 |
40 |
class BackendType(Enum):
    """Identifiers for the available training backends.

    Duplicate values create Enum aliases: ``torch`` is an alias of ``gf`` and
    ``default`` is an alias of ``lightning``.
    """

    # TODO: Remove the `torch` alias.
    lightning = "lightning"
    gf = "gf"
    torch = "gf"  # alias of `gf`
    default = "lightning"  # alias of `lightning`
47 |
48 |
class Backend:
    """Dispatches a hyperparameter-search trial to the selected training backend
    (gradsflow's own `Model.fit` or a Lightning/Flash trainer).

    Args:
        autodataset: data wrapper consumed by the objective functions.
        model_builder: callable that builds a model from a hyperparameter config.
        optimization_metric: name of the metric the tuner optimizes.
        max_epochs: maximum training epochs per trial.
        max_steps: optional cap on training steps per trial.
        backend: one of the `BackendType` values; defaults to "lightning".
    """

    # Maps optimizer name (str) -> `torch.optim` class.
    _OPTIMIZER_INDEX = module_to_cls_index(torch.optim, True)

    def __init__(
        self,
        autodataset: AutoDataset,
        model_builder: Callable,
        optimization_metric: Optional[str],
        max_epochs: int = 10,
        max_steps: Optional[int] = None,
        backend: Optional[str] = None,
    ):
        self.model_builder = model_builder
        self.backend_type = (backend or BackendType.default.value).lower()
        self.autodataset = autodataset
        self.optimization_metric = optimization_metric
        self.max_epochs = max_epochs
        self.max_steps = max_steps

    def _gf_objective(self, search_space: Dict, trainer_config: Dict, **_):
        """Run one trial with the gradsflow backend and return the fit tracker.

        NOTE(review): if `trainer_config` contains "callback_runner" it is both
        consumed for `callbacks` and still forwarded via `**trainer_config` —
        confirm `Model.fit` tolerates the extra keyword.
        """
        autodataset = self.autodataset
        model = self.model_builder(search_space)
        tracker = model.fit(
            autodataset=autodataset,
            steps_per_epoch=self.max_steps,
            callbacks=trainer_config.get("callback_runner", ("tune_checkpoint", "tune_report")),
            show_progress=False,
            **trainer_config,
        )
        return tracker

    # noinspection PyTypeChecker
    def _lightning_objective(
        self, config: Dict, trainer_config: Dict, gpu: Optional[float] = 0, finetune: bool = False
    ):
        """Run one trial with a Lightning/Flash trainer and return the optimized metric value.

        NOTE(review): `gpu` is accepted but unused here; device selection is
        delegated to `accelerator="auto"` / `devices="auto"`.
        """
        # With a step budget, validate shortly before the budget is exhausted;
        # otherwise validate once per epoch (1.0).
        val_check_interval = 1.0
        if self.max_steps:
            val_check_interval = max(self.max_steps - 1, 1.0)

        datamodule = self.autodataset.datamodule
        model = self.model_builder(config)

        # Flash `Task`s need the Flash trainer; plain modules use PL's trainer.
        trainer_cls = FlashTrainer if isinstance(model, Task) else PLTrainer

        trainer: "L.Trainer" = trainer_cls(
            logger=True,
            accelerator="auto",
            devices="auto",
            max_epochs=self.max_epochs,
            max_steps=self.max_steps,
            callbacks=[report_checkpoint_callback()],
            val_check_interval=val_check_interval,
            **trainer_config,
        )

        hparams = dict(model=model.hparams)
        trainer.logger.log_hyperparams(hparams)
        if finetune:
            trainer.finetune(model, datamodule=datamodule)
        else:
            trainer.fit(model, datamodule=datamodule)

        logger.debug(trainer.callback_metrics)
        return trainer.callback_metrics[self.optimization_metric].item()

    def optimization_objective(
        self, config: dict, trainer_config: dict, finetune: bool = False, gpu: Optional[float] = 0.0
    ):
        """
        Defines lightning_objective function which is used by tuner to minimize/maximize the metric.

        Args:
            config dict: key value pair of hyperparameters.
            trainer_config dict: configurations passed directly to Lightning Trainer.
            finetune bool: when True the lightning backend calls `trainer.finetune`
                instead of `trainer.fit`.
            gpu Optional[float]: GPU per trial

        Raises:
            NotImplementedError: when `backend_type` matches no known backend.
        """
        if self.backend_type == BackendType.lightning.value:
            return self._lightning_objective(config, trainer_config=trainer_config, gpu=gpu, finetune=finetune)

        if self.backend_type in (BackendType.gf.value,):
            return self._gf_objective(config, trainer_config=trainer_config, gpu=gpu, finetune=finetune)

        raise NotImplementedError(f"Trainer not implemented for backend_type: {self.backend_type}")
132 |
--------------------------------------------------------------------------------
/gradsflow/callbacks/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2021 GradsFlow. All rights reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | from .comet import CometCallback
15 | from .gpu import EmissionTrackerCallback
16 | from .logger import CSVLogger
17 | from .progress import ProgressCallback
18 | from .raytune import report_checkpoint_callback
19 | from .runner import CallbackRunner
20 | from .tensorboard import TensorboardCallback
21 | from .training import ModelCheckpoint, TrainEvalCallback
22 | from .wandb import WandbCallback
23 |
--------------------------------------------------------------------------------
/gradsflow/callbacks/base.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2021 GradsFlow. All rights reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | import typing
16 | from abc import ABC
17 | from typing import Callable, Optional
18 |
19 | if typing.TYPE_CHECKING:
20 | from gradsflow.models.model import Model
21 |
22 |
def dummy(x=None, **__):
    """No-op placeholder callable: echoes back its first argument (``None`` by default)."""
    return x


# TODO: set self.MODE in each callback stage train | val
class Callback(ABC):
    """Base class for callbacks. Subclasses override the hook methods to react
    to events fired during the model training cycle."""

    _events = ("forward", "step", "train_epoch", "val_epoch", "epoch", "fit")
    _name: str = "Callback"

    def __init__(self, model: Optional["Model"] = None):
        self.model = model

    @property
    def name(self) -> str:
        return self._name

    def with_event(self, event_type: str, func: Callable, exception, final_fn: Callable = dummy):
        """Run ``func`` bracketed by its ``on_<event>_start``/``on_<event>_end`` hooks.

        If ``func`` raises ``exception``, the ``on_<event>_cancel`` hook fires and
        the error is swallowed; the end hook and ``final_fn`` run in every case.
        Inspired from FastAI.
        Ref: https://github.com/fastai/fastai/blob/6e44b354f4d12bdfa2c9530f38f851c54a05764d/fastai/learner.py#L162
        """
        assert event_type in self._events, f"event_type is {event_type} but should be {self._events}"

        def fire(stage: str) -> None:
            # Hooks are looked up by name, e.g. "on_fit_start".
            getattr(self, f"on_{event_type}_{stage}")()

        try:
            fire("start")
            func()
        except exception:
            fire("cancel")
        fire("end")
        final_fn()

    def on_fit_start(self):
        """Called on each `model.fit(...)`"""

    def on_fit_end(self):
        """Called after `model.fit(...)`"""

    def on_fit_cancel(self):
        """Called after `model.fit(...)` is cancelled"""

    def on_train_epoch_start(self):
        """Called on start of training epoch"""

    def on_train_epoch_end(self, *args, **kwargs):
        """Called after end of training epoch"""

    def on_train_epoch_cancel(self):
        """Called after training epoch is cancelled"""

    def on_val_epoch_start(self):
        """Called on start of validation epoch"""

    def on_val_epoch_end(self, *args, **kwargs):
        """Called after validation epoch ends"""

    def on_val_epoch_cancel(self):
        """Called after validation epoch is cancelled"""

    def on_train_step_start(self):
        """Called before `train_step`"""

    def on_train_step_end(self, *args, **kwargs):
        """Called after training step"""

    def on_train_step_cancel(self):
        """Called after training step is cancelled"""

    def on_val_step_start(self):
        """Called on validation step"""

    def on_val_step_end(self, *args, **kwargs):
        """Called after validation step"""

    def on_val_step_cancel(self):
        """Called after validation step is cancelled"""

    def on_epoch_start(self):
        """Called before each epoch"""

    def on_epoch_end(self):
        """Called after each epoch"""

    def on_epoch_cancel(self):
        """Called after epoch is cancelled"""

    def on_forward_start(self):
        """Called before model.forward(...)"""

    def on_forward_end(self):
        """Called after model.forward(...)"""

    def on_forward_cancel(self):
        """Called after model.forward(...) is cancelled"""

    def clean(self):
        """Clean up"""
130 |
--------------------------------------------------------------------------------
/gradsflow/callbacks/comet.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2022 GradsFlow. All rights reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | import os
16 |
17 | os.environ["COMET_DISABLE_AUTO_LOGGING"] = "1"
18 | from typing import TYPE_CHECKING, Optional
19 |
20 | BaseExperiment = None
21 | if TYPE_CHECKING:
22 | from comet_ml import BaseExperiment
23 |
24 | from gradsflow.callbacks.base import Callback
25 | from gradsflow.utility.imports import requires
26 |
27 | CURRENT_FILE = os.path.dirname(os.path.realpath(__file__))
28 |
29 |
class CometCallback(Callback):
    """
    [Comet](https://www.comet.ml/) Logging callback.
    This callback requires `comet-ml` to be pre-installed (`pip install comet-ml`).
    Automatically log your Experiment to Comet logging platform. You need to provide API key either by setting
    environment variable `COMET_API_KEY` or directly pass as an argument to the callback.
    Checkout the documentation for more examples.

    Args:
        project_name: Name of the Project
        workspace: Comet workspace the project belongs to.
        experiment_id: resume an existing experiment with this id instead of
            creating a new one.
        api_key: project API key. NOTE(review): the default reads COMET_API_KEY
            once, at module import time — later changes to the env var are not
            picked up.
        code_file: source file logged to the experiment on fit start.
        offline: log experiment offline
        **kwargs: forwarded to the underlying comet_ml experiment class.
    """

    def __init__(
        self,
        project_name: str = "awesome-project",
        workspace: Optional[str] = None,
        experiment_id: Optional[str] = None,
        api_key: Optional[str] = os.environ.get("COMET_API_KEY"),
        code_file: str = CURRENT_FILE,
        offline: bool = False,
        **kwargs,
    ):
        # `model` is attached later by the callback runner, not at construction.
        super().__init__(
            model=None,
        )
        self._code_file = code_file
        self._experiment_id = experiment_id
        self.experiment = self._create_experiment(
            project_name=project_name,
            workspace=workspace,
            api_key=api_key,
            offline=offline,
            experiment_id=experiment_id,
            **kwargs,
        )
        # Prefixes used when logging step metrics to Comet.
        self._train_prefix = "train"
        self._val_prefix = "val"

    @staticmethod
    @requires("comet_ml", "CometCallback requires comet_ml to be installed!")
    def _create_experiment(
        project_name: str,
        workspace: str,
        offline: bool = False,
        api_key: Optional[str] = None,
        experiment_id: Optional[str] = None,
        **kwargs,
    ) -> "BaseExperiment":
        """Instantiate the appropriate comet_ml experiment class.

        Picks Offline vs. online and Existing vs. new based on the
        `offline` flag and whether an `experiment_id` is given.
        """
        from comet_ml import (
            ExistingExperiment,
            ExistingOfflineExperiment,
            Experiment,
            OfflineExperiment,
        )

        if offline:
            if experiment_id:
                experiment = ExistingOfflineExperiment(
                    project_name=project_name, workspace=workspace, previous_experiment=experiment_id, **kwargs
                )
            else:
                experiment = OfflineExperiment(project_name=project_name, workspace=workspace, **kwargs)
        else:
            if experiment_id:
                experiment = ExistingExperiment(
                    project_name=project_name,
                    workspace=workspace,
                    api_key=api_key,
                    previous_experiment=experiment_id,
                    **kwargs,
                )
            else:
                experiment = Experiment(project_name=project_name, workspace=workspace, api_key=api_key, **kwargs)
        return experiment

    def on_fit_start(self):
        # Log the model graph and the configured source file once per fit.
        self.experiment.set_model_graph(self.model.learner)
        self.experiment.log_code(self._code_file)

    def on_train_epoch_start(
        self,
    ):
        # Switch Comet's logging context to "train".
        self.experiment.train()

    def on_val_epoch_start(
        self,
    ):
        # Switch Comet's logging context to "validate".
        self.experiment.validate()

    def _step(self, prefix: str, outputs: dict):
        """Log per-step loss and metrics under the given prefix ("train"/"val")."""
        step = self.model.tracker.mode(prefix).steps
        loss = outputs["loss"].item()
        self.experiment.log_metrics(outputs.get("metrics", {}), step=step, prefix=prefix)
        self.experiment.log_metric(f"{prefix}_step_loss", loss, step=step)

    def on_train_step_end(self, outputs: dict = None, **_):
        self._step(prefix=self._train_prefix, outputs=outputs)

    def on_val_step_end(self, outputs: dict = None, **_):
        self._step(prefix=self._val_prefix, outputs=outputs)

    def on_epoch_end(self):
        """Log epoch-level train/val loss and metrics from the model tracker."""
        epoch = self.model.tracker.current_epoch
        train_loss = self.model.tracker.train_loss
        train_metrics = self.model.tracker.train_metrics
        val_loss = self.model.tracker.val_loss
        val_metrics = self.model.tracker.val_metrics

        self.experiment.train()
        self.experiment.log_metric("train_epoch_loss", train_loss, epoch=epoch)
        self.experiment.log_metrics(train_metrics, epoch=epoch, prefix=self._train_prefix)

        self.experiment.validate()
        self.experiment.log_metric("val_epoch_loss", val_loss, epoch=epoch)
        self.experiment.log_metrics(val_metrics, epoch=epoch, prefix=self._val_prefix)
        self.experiment.log_epoch_end(epoch)
147 | self.experiment.log_epoch_end(epoch)
148 |
--------------------------------------------------------------------------------
/gradsflow/callbacks/gpu.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2021 GradsFlow. All rights reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 |
16 | import logging
17 |
18 | from gradsflow.callbacks.base import Callback
19 | from gradsflow.utility.imports import requires
20 |
21 | logger = logging.getLogger(__name__)
22 |
23 |
class EmissionTrackerCallback(Callback):
    """
    Tracks the carbon emissions produced by deep neural networks using
    [CodeCarbon](https://github.com/mlco2/codecarbon). To use this callback first install codecarbon using
    `pip install codecarbon`.
    For offline use, you must specify the [country code](https://github.com/mlco2/codecarbon#offline-mode).

    Args:
        offline: whether to use internet connection or not. You will have to provide the country code `country_iso_code` for offline use.
        **kwargs: passed directly to codecarbon class.
    """

    _name = "EmissionTrackerCallback"

    @requires("codecarbon", "install codecarbon to use EmissionTrackerCallback")
    def __init__(self, offline: bool = False, **kwargs):
        from codecarbon import EmissionsTracker, OfflineEmissionsTracker

        # Offline tracking requires an explicit country code; online mode geolocates itself.
        tracker_cls = OfflineEmissionsTracker if offline else EmissionsTracker
        self._emission_tracker = tracker_cls(**kwargs)
        self._emission_tracker.start()

        super().__init__(model=None)

    def on_fit_end(self):
        # stop() returns the estimated total emissions for the tracked run.
        emissions: float = self._emission_tracker.stop()
        logger.info(f"Emissions: {emissions} kg")
53 |
--------------------------------------------------------------------------------
/gradsflow/callbacks/logger.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2022 GradsFlow. All rights reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
import logging
import os
from pathlib import Path
from typing import Optional

import pandas as pd

from gradsflow.callbacks.base import Callback
from gradsflow.utility.common import to_item
23 |
24 | logger = logging.getLogger(__name__)
25 |
26 |
class CSVLogger(Callback):
    """
    Saves Model training metrics as CSV
    Args:
        filename: filename of the csv
        path: folder path location of the csv. Defaults to the current working
            directory, resolved when the callback is created.
        verbose: Whether to show output
    """

    _name = "CSVLogger"

    def __init__(self, filename: str = "./experiment.csv", path: Optional[str] = None, verbose: bool = False):
        super().__init__(model=None)
        self.filename = filename
        # Resolve the cwd at call time — the old `path: str = os.getcwd()` default
        # captured the working directory once, at module import.
        self.path = path or os.getcwd()
        self._dst = Path(self.path) / Path(filename)
        self._logs = []
        self.verbose = verbose

    def on_epoch_end(self):
        """Append this epoch's losses/metrics to the in-memory log and rewrite the CSV."""
        tracker = self.model.tracker
        # Flatten AverageMeter values to plain numbers with mode-specific prefixes.
        train_metrics = to_item({"train/" + k: v.avg for k, v in tracker.train_metrics.items()})
        val_metrics = to_item({"val/" + k: v.avg for k, v in tracker.val_metrics.items()})

        data = {
            "epoch": tracker.current_epoch,
            "train_loss": tracker.train_loss,
            "val_loss": tracker.val_loss,
            **train_metrics,
            **val_metrics,
        }
        if self.verbose:
            logger.info(f"verbose csv_logger on_epoch_end: {data}")
        self._logs.append(data)
        # Rewrite the full file each epoch so the CSV on disk is always complete.
        pd.DataFrame(self._logs).to_csv(self._dst, index=False)
64 |
--------------------------------------------------------------------------------
/gradsflow/callbacks/progress.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2021 GradsFlow. All rights reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | from typing import Dict, Optional
15 |
16 | from rich.progress import BarColumn, Progress, RenderableColumn, TimeRemainingColumn
17 |
18 | from gradsflow.callbacks.base import Callback
19 |
20 |
class ProgressCallback(Callback):
    """Rich-based progress bars plus a live metrics table for the fit/train/val loops."""

    _name = "ProgressCallback"

    def __init__(self, model, progress_kwargs: Optional[Dict] = None):
        super().__init__(model)
        extra_kwargs = progress_kwargs or {}
        tracker = self.model.tracker
        self.bar_column = BarColumn()
        # The table column renders the tracker's current metrics next to the bars.
        self.table_column = RenderableColumn(tracker.create_table())

        self.progress = Progress(
            "[progress.description]{task.description}",
            self.bar_column,
            "[progress.percentage]{task.percentage:>3.0f}%",
            TimeRemainingColumn(),
            self.table_column,
            **extra_kwargs,
        )
        # Expose the progress object so other components can render through it.
        tracker.progress = self.progress
        self.fit_prog = None
        self.train_prog_bar = None
        self.val_prog_bar = None

    def _refresh_table(self):
        """Re-render the metrics table from the tracker's latest values."""
        self.table_column.renderable = self.model.tracker.create_table()

    def on_fit_start(self):
        self.progress.start()
        tracker = self.model.tracker
        self.fit_prog = self.progress.add_task(
            "[red]Progress", total=tracker.max_epochs, completed=tracker.current_epoch
        )

    def on_fit_end(self):
        self.progress.stop()

    def on_epoch_end(self):
        self.progress.update(self.fit_prog, advance=1)

    def on_train_epoch_start(self):
        total_batches = self.model.autodataset.dataloader_length["train"]
        self.train_prog_bar = self.progress.add_task("[green]Learning", total=total_batches)

    def on_train_epoch_end(self, *args, **kwargs):
        self.progress.remove_task(self.train_prog_bar)
        self._refresh_table()

    def on_train_step_end(self, *args, **kwargs):
        self.progress.update(self.train_prog_bar, advance=1)
        self._refresh_table()

    def on_val_epoch_start(self):
        total_batches = self.model.autodataset.dataloader_length["val"]
        # No validation dataloader -> no validation bar.
        if total_batches is not None:
            self.val_prog_bar = self.progress.add_task("[blue]Validating...", total=total_batches)

    def on_val_epoch_end(self, *args, **kwargs):
        if self.model.autodataset.val_dataloader:
            self._refresh_table()
            self.progress.remove_task(self.val_prog_bar)

    def on_val_step_end(self, *args, **kwargs):
        self.progress.update(self.val_prog_bar, advance=1)
        self._refresh_table()

    def clean(self):
        self.progress.stop()
87 |
--------------------------------------------------------------------------------
/gradsflow/callbacks/raytune.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2021 GradsFlow. All rights reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | import os
15 | from typing import Optional
16 |
17 | from ray import tune
18 |
19 | from gradsflow.callbacks.base import Callback
20 |
21 | _METRICS = {
22 | "val_accuracy": "val_accuracy",
23 | "train_accuracy": "train_accuracy",
24 | }
25 |
26 |
def report_checkpoint_callback(metrics: Optional[dict] = None, filename: Optional[str] = None):
    """Build a Ray Tune ``TuneReportCheckpointCallback`` that fires on validation end."""
    from ray.tune.integration.pytorch_lightning import TuneReportCheckpointCallback

    return TuneReportCheckpointCallback(
        metrics=metrics or _METRICS,
        filename=filename or "filename",
        on="validation_end",
    )
35 |
36 |
class TorchTuneCheckpointCallback(Callback):
    """Writes a model checkpoint into the Ray Tune checkpoint directory every epoch."""

    _name = "TorchTuneCheckpointCallback"

    def on_epoch_end(self):
        current_epoch = self.model.tracker.current_epoch
        # Ray Tune manages the per-epoch checkpoint directory lifecycle.
        with tune.checkpoint_dir(current_epoch) as checkpoint_dir:
            self.model.save(os.path.join(checkpoint_dir, "filename"))
46 |
47 |
class TorchTuneReport(Callback):
    """Reports epoch losses and averaged metrics to Ray Tune via `tune.report`."""

    _name = "TorchTuneReport"

    def on_epoch_end(self):
        tracker = self.model.tracker
        # BUG FIX: the metric sources were swapped — `train_*` keys were built from
        # `tracker.val.metrics` and `val_*` keys from `tracker.train.metrics`,
        # so Ray Tune received mislabeled metrics.
        train_metrics = {"train_" + k.lower(): v.avg for k, v in tracker.train.metrics.items()}
        val_metrics = {"val_" + k.lower(): v.avg for k, v in tracker.val.metrics.items()}

        tune.report(
            val_loss=tracker.val_loss,
            train_loss=tracker.train_loss,
            **train_metrics,
            **val_metrics,
        )
61 |
--------------------------------------------------------------------------------
/gradsflow/callbacks/runner.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2021 GradsFlow. All rights reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | import typing
15 | from collections import OrderedDict
16 | from typing import Any, Dict, List, Optional, Union
17 |
18 | from gradsflow.callbacks.base import Callback
19 | from gradsflow.callbacks.progress import ProgressCallback
20 | from gradsflow.callbacks.raytune import TorchTuneCheckpointCallback, TorchTuneReport
21 | from gradsflow.callbacks.training import TrainEvalCallback
22 | from gradsflow.utility import listify
23 |
24 | if typing.TYPE_CHECKING:
25 | from gradsflow.models.model import Model
26 |
27 |
class CallbackRunner(Callback):
    """Fans every callback hook out to each registered `Callback` in insertion order.

    Args:
        model: The model whose lifecycle events are broadcast.
        *callbacks: Callback instances or registered string names
            (see `available_callbacks`).
    """

    _name: str = "CallbackRunner"
    # Registry of callbacks that can be requested by string name.
    _AVAILABLE_CALLBACKS: Dict[str, Any] = {
        "training": TrainEvalCallback,
        "tune_checkpoint": TorchTuneCheckpointCallback,
        "tune_report": TorchTuneReport,
        "progress": ProgressCallback,
    }

    def __init__(self, model: "Model", *callbacks: Union[str, Callback]):
        super().__init__(model)
        # Ordered so hooks fire in registration order.
        # (Removed a dead `self.callbacks = []` assignment that was immediately
        # overwritten by the OrderedDict.)
        self.callbacks: Dict[str, Callback] = OrderedDict()
        for callback in callbacks:
            self.append(callback)

    # skipcq: W0212
    def append(self, callback: Union[str, Callback]):
        """Register a callback by name or instance, keyed by its `_name`.

        Raises:
            NotImplementedError: If `callback` is an unknown name or an
                unsupported type (previously unsupported types were silently ignored).
        """
        if isinstance(callback, str):
            try:
                callback_fn: Callback = self._AVAILABLE_CALLBACKS[callback](model=self.model)
            except KeyError as err:
                # Chain the original KeyError for easier debugging.
                raise NotImplementedError(f"callback is not implemented {callback}") from err
            self.callbacks[callback_fn._name] = callback_fn
        elif isinstance(callback, Callback):
            callback.model = self.model
            self.callbacks[callback._name] = callback
        else:
            raise NotImplementedError(f"callback is not implemented {callback}")

    def available_callbacks(self):
        """Return the string names that `append` accepts."""
        return list(self._AVAILABLE_CALLBACKS.keys())

    def on_train_epoch_end(self, *args, **kwargs):
        for callback in self.callbacks.values():
            callback.on_train_epoch_end(*args, **kwargs)

    def on_train_epoch_start(self):
        for callback in self.callbacks.values():
            callback.on_train_epoch_start()

    def on_fit_start(self):
        for callback in self.callbacks.values():
            callback.on_fit_start()

    def on_fit_end(self):
        for callback in self.callbacks.values():
            callback.on_fit_end()

    def on_val_epoch_start(self):
        for callback in self.callbacks.values():
            callback.on_val_epoch_start()

    def on_val_epoch_end(self, *args, **kwargs):
        for callback in self.callbacks.values():
            callback.on_val_epoch_end(*args, **kwargs)

    def on_train_step_start(self):
        for callback in self.callbacks.values():
            callback.on_train_step_start()

    def on_train_step_end(self, *args, **kwargs):
        for callback in self.callbacks.values():
            callback.on_train_step_end(*args, **kwargs)

    def on_val_step_start(self):
        for callback in self.callbacks.values():
            callback.on_val_step_start()

    def on_val_step_end(self, *args, **kwargs):
        for callback in self.callbacks.values():
            callback.on_val_step_end(*args, **kwargs)

    def on_epoch_start(self):
        for callback in self.callbacks.values():
            callback.on_epoch_start()

    def on_epoch_end(self):
        for callback in self.callbacks.values():
            callback.on_epoch_end()

    def on_forward_start(self):
        for callback in self.callbacks.values():
            callback.on_forward_start()

    def on_forward_end(self):
        for callback in self.callbacks.values():
            callback.on_forward_end()

    def clean(self, keep: Optional[Union[List[str], str]] = None):
        """Remove all the callbacks except callback names provided in keep"""
        for callback in self.callbacks.values():
            callback.clean()
        for key in set(self.callbacks) - set(listify(keep)):
            self.callbacks.pop(key)
127 |
--------------------------------------------------------------------------------
/gradsflow/callbacks/training.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2021 GradsFlow. All rights reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | import os
15 | from pathlib import Path
16 | from typing import Optional
17 |
18 | from gradsflow.callbacks.base import Callback
19 |
20 |
class TrainEvalCallback(Callback):
    """Default training-loop callback: runs auto-optimization and tracks loss/metrics."""

    _name = "TrainEvalCallback"

    def _track(self, outputs: dict, mode: str):
        """Record step loss and metrics on the tracker for the given mode (train/val)."""
        tracker = self.model.tracker
        tracker.track_loss(outputs["loss"], mode=mode)
        tracker.track_metrics(outputs.get("metrics", {}), mode=mode)

    def on_train_step_start(self):
        # Clear gradients accumulated from the previous step.
        self.model.optimizer.zero_grad()

    def on_train_step_end(self, *args, outputs: dict = None, **kwargs):
        # ----- AUTO OPTIMIZATION -----
        if not self.model.disable_auto_optimization:
            self.model.backward(outputs["loss"])
            self.model.optimizer.step()

        # ----- METRIC UPDATES -----
        self._track(outputs, mode="train")

    def on_val_step_end(self, *args, outputs: dict = None, **kwargs):
        # ----- METRIC UPDATES -----
        self._track(outputs, mode="val")

    def on_train_epoch_start(self):
        self.model.train()
        self.model.metrics.reset()
        self.model.tracker.train.reset()

    def on_val_epoch_start(self):
        self.model.eval()
        self.model.metrics.reset()
        self.model.tracker.val.reset()
57 |
58 |
class ModelCheckpoint(Callback):
    """
    Saves Model checkpoint
    Args:
        filename: name of checkpoint
        path: folder path location of the model checkpoint. Defaults to the current
            working directory, resolved when the callback is created. Will create
            the folder (including missing parents) if it does not exist.
        save_extra: whether to save extra details like tracker
    """

    _name = "ModelCheckpoint"

    def __init__(self, filename: Optional[str] = None, path: Optional[str] = None, save_extra: bool = False):
        super().__init__(model=None)
        # Resolve the cwd at call time — the old `path: str = os.getcwd()` default
        # captured the working directory once, at module import.
        self.path = Path(path or os.getcwd())
        # parents=True so a nested, not-yet-existing folder works as the docstring promises.
        self.path.mkdir(parents=True, exist_ok=True)
        self._dst = self.path / Path(filename or "model")
        self.save_extra = save_extra

    def on_epoch_end(self):
        epoch = self.model.tracker.current_epoch
        # One checkpoint file per epoch, e.g. "<path>/model_epoch=3_.pt".
        path = f"{self._dst}_epoch={epoch}_.pt"
        self.model.save(path, save_extra=self.save_extra)
83 |
--------------------------------------------------------------------------------
/gradsflow/callbacks/wandb.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2022 GradsFlow. All rights reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | import logging
16 | import os
17 | from typing import Dict, List, Optional
18 |
19 | from gradsflow.callbacks.base import Callback
20 | from gradsflow.utility.imports import is_installed, requires
21 |
22 | logger = logging.getLogger(__name__)
23 |
24 | if is_installed("wandb"):
25 | import wandb
26 |
27 | CURRENT_FILE = os.path.dirname(os.path.realpath(__file__))
28 |
29 |
def define_metrics(default_step_metric: str = "global_step"):
    """Register wandb axes/summaries: epoch-level metrics plot against "epoch",
    everything else against `default_step_metric`."""
    min_max_def: Dict[str, List[str]] = {
        "min": ["train/step_loss", "train/epoch_loss", "val/epoch_loss"],
        "max": ["train/acc*", "val/acc*"],
    }
    for summary, metric_names in min_max_def.items():
        for metric_name in metric_names:
            if "epoch" in metric_name or "val" in metric_name:
                wandb.define_metric(metric_name, summary=summary, step_metric="epoch")
    wandb.define_metric("*", step_metric=default_step_metric)
40 |
41 |
class WandbCallback(Callback):
    """
    [Weights & Biases](https://www.wandb.com/) Logging callback. To use this callback `pip install wandb`.
    Any metric that contains `epoch` will be plotted with `epoch` and all the other metrics will be plotted against
    `global_step` which is total training steps. You can change the default axis by providing `default_step_metric`.

    Args:
        log_model: Whether to upload model artifact to Wandb
        code_file: path of the code you want to upload as artifact to Wandb
        default_step_metric: Metrics will be plotted against the `default_step_metric`. Default value is `global_step`.

    ```python
    from gradsflow.callbacks import WandbCallback
    from timm import create_model

    cnn = create_model("resnet18", pretrained=False, num_classes=1)
    model = Model(cnn)
    model.compile()
    cb = WandbCallback()
    autodataset = None  # create your dataset
    model.fit(autodataset, callbacks=cb)
    ```
    """

    # Consistency fix: sibling callbacks all declare an explicit `_name`,
    # which CallbackRunner uses as the registry key.
    _name = "WandbCallback"

    @requires("wandb", "WandbCallback requires wandb to be installed!")
    def __init__(self, log_model: bool = False, code_file: Optional[str] = None, default_step_metric="global_step"):
        super().__init__()
        if wandb.run is None:
            # BUG FIX: the two adjacent string literals were joined without a separator,
            # logging "...WandbCallback()Calling wandb.init()".
            logger.warning("wandb.init() was not called before initializing WandbCallback(). Calling wandb.init()")
            wandb.init()
        self._code_file = code_file
        self._train_prefix = "train"
        self._val_prefix = "val"
        self._log_model = log_model
        self._setup(default_step_metric)

    def _setup(self, default_step_metric):
        # Register the custom x-axes before any metric is logged.
        define_metrics(default_step_metric)

    def on_fit_start(self):
        """Upload the model and/or code file as wandb artifacts, if requested."""
        if self._log_model:
            wandb.log_artifact(self.model.learner)
        if self._code_file:
            wandb.log_artifact(self._code_file)

    def _apply_prefix(self, data: dict, prefix: str):
        """Prefix every key of `data` with "<prefix>/"."""
        return {f"{prefix}/{k}": v for k, v in data.items()}

    def on_train_step_end(self, outputs: dict = None, **_):
        # Consistency fix: use the configured prefix rather than a hard-coded "train".
        prefix = self._train_prefix
        global_step = self.model.tracker.global_step
        loss = outputs["loss"].item()
        # log train step loss
        wandb.log({f"{prefix}/step_loss": loss, "train_step": global_step}, commit=False)

        # log train step metrics
        metrics = self._apply_prefix(outputs.get("metrics", {}), prefix)
        wandb.log(metrics, commit=False)

        # https://docs.wandb.ai/guides/track/log#how-do-i-use-custom-x-axes
        wandb.log({"global_step": global_step})

    def on_epoch_end(self):
        tracker = self.model.tracker
        epoch = tracker.current_epoch

        train_metrics = self._apply_prefix(tracker.train_metrics.to_dict(), prefix=self._train_prefix)
        val_metrics = self._apply_prefix(tracker.val_metrics.to_dict(), prefix=self._val_prefix)
        train_metrics.update({"epoch": epoch})
        val_metrics.update({"epoch": epoch})

        wandb.log({"train/epoch_loss": tracker.train_loss, "epoch": epoch}, commit=False)
        wandb.log({"val/epoch_loss": tracker.val_loss, "epoch": epoch}, commit=False)
        wandb.log(train_metrics, commit=False)
        wandb.log(val_metrics, commit=False)
        # The empty log flushes everything buffered with commit=False in one step.
        wandb.log({})
123 |
--------------------------------------------------------------------------------
/gradsflow/core/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2021 GradsFlow. All rights reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Core Building blocks for Auto Tasks"""
16 |
--------------------------------------------------------------------------------
/gradsflow/core/base.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2021 GradsFlow. All rights reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | #
15 | from abc import ABC, abstractmethod
16 | from dataclasses import asdict, dataclass
17 | from typing import Dict, Optional
18 |
19 | import numpy as np
20 | import torch
21 |
22 | from gradsflow.utility.common import AverageMeter, GDict, module_to_cls_index
23 |
24 |
class BaseAutoModel(ABC):
    """
    The main class for AutoML which consists everything required for training a model -
    data, model and trainer.
    """

    # Maps optimizer names to their `torch.optim` classes for search-space lookups.
    _OPTIMIZER_INDEX = module_to_cls_index(torch.optim, True)

    @abstractmethod
    def _create_search_space(self) -> Dict[str, str]:
        """creates search space"""
        raise NotImplementedError

    @abstractmethod
    def build_model(self, search_space: dict):
        """Build model from dictionary _search_space"""
        raise NotImplementedError
42 |
43 |
@dataclass(init=False)
class TrackingValues:
    """Running values (loss, metrics, step counts) tracked for one mode over an epoch."""

    loss: Optional[AverageMeter] = None  # Average loss in a single Epoch
    steps: Optional[int] = None  # Step per epoch
    step_loss: Optional[float] = None
    metrics: Optional[Dict[str, AverageMeter]] = None  # Average value in a single Epoch

    def __init__(self):
        self.metrics = GDict()
        self.loss = AverageMeter(name="loss")

    def update_loss(self, loss: float):
        assert isinstance(loss, (int, float, np.ndarray)), f"loss must be int | float | np.ndarray but got {type(loss)}"
        self.step_loss = loss
        self.loss.update(loss)

    def update_metrics(self, metrics: Dict[str, float]):
        """Update `TrackingValues` metrics. mode can be train or val"""
        # Lazily create a meter per metric name, then fold in the new value.
        for name, value in metrics.items():
            if name not in self.metrics:
                self.metrics[name] = AverageMeter(name=name)
            self.metrics[name].update(value)

    def to_dict(self) -> dict:
        return asdict(self)

    def reset(self):
        """Values are Reset on start of each `on_*_epoch_start`"""
        self.loss.reset()
        for meter in self.metrics.values():
            meter.reset()
78 |
--------------------------------------------------------------------------------
/gradsflow/core/metrics.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2022 GradsFlow. All rights reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | from typing import Dict, Union
15 |
16 | import torch
17 | import torchmetrics
18 | from torchmetrics import Metric, MetricCollection
19 |
20 | from gradsflow.utility.common import module_to_cls_index
21 |
_tm_classes = module_to_cls_index(torchmetrics, lower_key=False)

# Keep only public torchmetrics classes (names beginning with an ASCII capital),
# indexed by their lower-cased name.
metrics_classes: Dict[str, Metric] = {k.lower(): v for k, v in _tm_classes.items() if "A" <= k[0] <= "Z"}
26 |
27 |
class MetricsContainer:
    """Owns a `torchmetrics.MetricCollection`, keeping it on the configured device."""

    def __init__(self, device):
        self._device = device
        self._metrics: MetricCollection = MetricCollection([])

    @property
    def metrics(self):
        return self._metrics

    def compile_metrics(self, *metrics: Union[str, Metric]) -> None:
        """Initialize metrics collection and add provided `*metrics` to the container."""
        # Start from an empty collection if anything was compiled previously.
        if len(self._metrics) > 0:
            self._metrics = MetricCollection([])
        self.add_metrics(*metrics)

    def add_metrics(self, *metrics: Union[str, Metric]) -> None:
        """Resolve string names to torchmetrics classes and register each metric."""
        for m in metrics:
            if isinstance(m, Metric):
                metric_obj = m
            elif isinstance(m, str):
                metric_cls = metrics_classes.get(m)
                assert (
                    metric_cls is not None
                ), f"metrics {m} is not available! Available metrics are {tuple(metrics_classes.keys())}"
                metric_obj = metric_cls()
            else:
                raise NotImplementedError(f"metrics not implemented for {m}! Please see `torchmetrics`.")
            self._metrics.add_metrics(metric_obj)
        # Move the whole collection onto the target device after registration.
        self._metrics.to(self._device)

    def _update(self, preds, target):
        """Iteratively update all the `torchmetrics` value"""
        self._metrics.update(preds, target)

    def compute(self):
        return self._metrics.compute()

    def calculate_metrics(self, preds, target) -> Dict[str, torch.Tensor]:
        """Iteratively update the compiled metrics and return the new computed values"""
        self._update(preds, target)
        return self.compute()

    def reset(self):
        """Reset the values of each of the compiled metrics"""
        self._metrics.reset()
73 |
--------------------------------------------------------------------------------
/gradsflow/data/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2021 GradsFlow. All rights reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | from .autodata import AutoDataset
15 | from .common import random_split_dataset
16 | from .image import get_augmentations, get_fake_data, image_dataset_from_directory
17 |
--------------------------------------------------------------------------------
/gradsflow/data/base.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2021 GradsFlow. All rights reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | import dataclasses
15 | from typing import Union
16 |
17 | from torch.utils.data import DataLoader, Dataset
18 |
19 | from gradsflow.data.ray_dataset import RayDataset
20 |
21 |
@dataclasses.dataclass(init=False)
class Data:
    """Bundle pairing a `DataLoader` with the dataset it draws from."""

    dataloader: DataLoader
    dataset: Union[RayDataset, Dataset]
26 |
27 |
class BaseAutoDataset:
    """Holds the datasets, dataloaders and metadata shared by AutoDataset implementations."""

    def __init__(self):
        # Arbitrary metadata about the dataset (e.g. shapes, class names).
        self.meta = {}
        self.datamodule = None
        self.train_dataset = None
        self.val_dataset = None
        self.num_classes = None
        # Dataloaders and their lengths are populated lazily by subclasses.
        self._train_dataloader = None
        self._val_dataloader = None
        self._train_dataloader_length = None
        self._val_dataloader_length = None
39 |
--------------------------------------------------------------------------------
/gradsflow/data/common.py:
--------------------------------------------------------------------------------
1 | """Provide some common functionalities/utilities for Datasets"""
2 |
3 | # Copyright (c) 2021 GradsFlow. All rights reserved.
4 | #
5 | # Licensed under the Apache License, Version 2.0 (the "License");
6 | # you may not use this file except in compliance with the License.
7 | # You may obtain a copy of the License at
8 | #
9 | # http://www.apache.org/licenses/LICENSE-2.0
10 | #
11 | # Unless required by applicable law or agreed to in writing, software
12 | # distributed under the License is distributed on an "AS IS" BASIS,
13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | # See the License for the specific language governing permissions and
15 | # limitations under the License.
16 | from typing import List
17 |
18 | from torch.utils.data import Dataset, random_split
19 |
20 |
def random_split_dataset(data: Dataset, pct=0.9) -> List[Dataset]:
    """
    Randomly partition `data` into two subsets.

    Args:
        data: pytorch Dataset object with `__len__` implementation.
        pct: fraction of samples placed in the first subset (the second gets the rest).

    Returns:
        A list of two `Dataset` (Subset) objects.
    """
    total = len(data)
    first = int(total * pct)  # truncates, so the second split absorbs the remainder
    return random_split(data, (first, total - first))
32 |
--------------------------------------------------------------------------------
/gradsflow/data/image.py:
--------------------------------------------------------------------------------
1 | """Data loader for image dataset"""
2 |
3 | # Copyright (c) 2021 GradsFlow. All rights reserved.
4 | #
5 | # Licensed under the Apache License, Version 2.0 (the "License");
6 | # you may not use this file except in compliance with the License.
7 | # You may obtain a copy of the License at
8 | #
9 | # http://www.apache.org/licenses/LICENSE-2.0
10 | #
11 | # Unless required by applicable law or agreed to in writing, software
12 | # distributed under the License is distributed on an "AS IS" BASIS,
13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | # See the License for the specific language governing permissions and
15 | # limitations under the License.
16 | import logging
17 | import os
18 | from pathlib import Path
19 | from typing import List, Optional, Tuple, Union
20 |
21 | from torch.utils.data import DataLoader
22 | from torchvision import transforms as T
23 | from torchvision.datasets import ImageFolder
24 |
25 | from gradsflow.data.base import Data
26 | from gradsflow.data.ray_dataset import RayImageFolder
27 |
28 | logger = logging.getLogger("data.image")
29 |
30 |
def get_augmentations(image_size: tuple = (224, 224), auto_augment_policy: bool = True):
    """Build a torchvision pipeline: resize, optional AutoAugment, then tensor conversion.

    Args:
        image_size: target (H, W) for `T.Resize`.
        auto_augment_policy: when True, insert `T.AutoAugment` between resize and to-tensor.

    Returns:
        A `T.Compose` transform.
    """
    steps = [T.Resize(image_size)]
    if auto_augment_policy:
        steps.append(T.AutoAugment())
    steps.append(T.ToTensor())
    return T.Compose(steps)
37 |
38 |
def image_dataset_from_directory(
    directory: Union[List[str], Path, str],
    transform=None,
    image_size=(224, 224),
    batch_size: int = 1,
    shuffle: bool = False,
    pin_memory: bool = True,
    num_workers: Optional[int] = None,
    ray_data: bool = False,
) -> Data:
    """
    Create Dataset and Dataloader for an image folder dataset (`root/<class>/<image>` layout).
    Args:
        directory: root folder (or list of folders) of the dataset.
        transform: callable applied to each image; pass `True` to use
            `get_augmentations(image_size)` as the default pipeline.
        image_size: target size — used only when `transform is True`.
        batch_size: samples per batch.
        shuffle: whether the dataloader shuffles samples.
        pin_memory: forwarded to `DataLoader`.
        num_workers: dataloader workers; defaults to `os.cpu_count()` when falsy.
        ray_data: if True, build a `RayImageFolder` instead of torchvision's `ImageFolder`.

    Returns:
        A `Data` object holding both the dataset and the dataloader.
    """
    data = Data()
    num_workers = num_workers or os.cpu_count()
    # `transform is True` is a sentinel meaning "use the default augmentations"
    if transform is True:
        transform = get_augmentations(image_size)
    if ray_data:
        # NOTE(review): RayImageFolder is an IterableDataset; shuffle=True would fail
        # inside DataLoader for that case — confirm callers never combine the two.
        data.dataset = RayImageFolder(directory, transform=transform)
    else:
        data.dataset = ImageFolder(directory, transform=transform)
    logger.info("ds created")
    data.dataloader = DataLoader(
        data.dataset,
        batch_size=batch_size,
        pin_memory=pin_memory,
        shuffle=shuffle,
        num_workers=num_workers,
    )
    return data
80 |
81 |
def get_fake_data(
    image_size: Tuple[int, int],
    num_classes: int = 10,
    batch_size: int = 1,
    pin_memory: bool = False,
    shuffle: bool = True,
    num_workers: int = 0,
    size: int = 100,
) -> Data:
    """Create a random image classification dataset (`torchvision.datasets.FakeData`).

    Useful for tests and smoke runs where no real data is available.

    Args:
        image_size: spatial size (H, W); the channel dimension is fixed at 3.
        num_classes: number of target classes.
        batch_size: samples per batch.
        pin_memory: forwarded to `DataLoader`.
        shuffle: whether the dataloader shuffles samples.
        num_workers: dataloader workers.
        size: number of fake samples to generate (was previously hard-coded to 100).

    Returns:
        A `Data` object holding both the dataset and the dataloader.
    """
    # local import keeps torchvision's dataset module off the import path until needed
    from torchvision.datasets import FakeData

    data = Data()

    transform = get_augmentations(
        image_size=image_size,
    )
    data.dataset = FakeData(size=size, image_size=[3, *image_size], num_classes=num_classes, transform=transform)
    data.dataloader = DataLoader(
        data.dataset,
        batch_size=batch_size,
        pin_memory=pin_memory,
        shuffle=shuffle,
        num_workers=num_workers,
    )

    return data
102 |
--------------------------------------------------------------------------------
/gradsflow/data/mixins.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2021 GradsFlow. All rights reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | from typing import Dict, List, Tuple, Union
15 |
16 | import torch
17 |
18 | from gradsflow.utility import default_device
19 |
20 |
class DataMixin:
    """Mixin providing batch unpacking and device transfer for training loops."""

    INPUT_KEY = 0  # other common value - inputs, images, text
    OUTPUT_KEY = 1  # other common values - target, ground
    # resolved once at class-definition time; a device change afterwards is not picked up
    device = default_device()

    def fetch_inputs(self, data: Union[List, Dict]):
        """Return the model inputs from a batch, indexed by `INPUT_KEY`."""
        return data[self.INPUT_KEY]

    def fetch_target(self, data: Union[List, Dict]):
        """Return the targets from a batch, indexed by `OUTPUT_KEY`."""
        return data[self.OUTPUT_KEY]

    @classmethod
    def send_to_device(cls, data: Union[List, Dict, Tuple, torch.Tensor, int, float]):
        """Recursively move `data` to `cls.device`.

        Scalars and strings pass through unchanged; lists and tuples are both
        returned as lists; dicts are rebuilt with converted values.

        Raises:
            NotImplementedError: for any unsupported data type.
        """
        if isinstance(data, (int, float, str)):
            return data

        if isinstance(data, torch.Tensor):
            return data.to(cls.device)

        if isinstance(data, (list, tuple)):
            # NOTE: a tuple input comes back as a list
            return list(map(cls.send_to_device, data))
        if isinstance(data, dict):
            return {k: cls.send_to_device(v) for k, v in data.items()}
        raise NotImplementedError(
            f"send_to_device is not implemented for data of type {type(data)}! Please raise an issue/pr"
        )
48 |
--------------------------------------------------------------------------------
/gradsflow/data/ray_dataset.py:
--------------------------------------------------------------------------------
1 | """Mimics torch.data.Dataset for ray.data integration"""
2 |
3 | # Copyright (c) 2021 GradsFlow. All rights reserved.
4 | #
5 | # Licensed under the Apache License, Version 2.0 (the "License");
6 | # you may not use this file except in compliance with the License.
7 | # You may obtain a copy of the License at
8 | #
9 | # http://www.apache.org/licenses/LICENSE-2.0
10 | #
11 | # Unless required by applicable law or agreed to in writing, software
12 | # distributed under the License is distributed on an "AS IS" BASIS,
13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | # See the License for the specific language governing permissions and
15 | # limitations under the License.
16 |
17 | import io
18 | from typing import Callable, List, Union
19 |
20 | import ray
21 | from PIL import Image
22 | from torch.utils.data import IterableDataset
23 |
24 |
class RayDataset(IterableDataset):
    """`IterableDataset` backed by `ray.data` binary files.

    Files are read with `include_paths=True`, so each row is a
    (file_path, raw_bytes) pair.
    """

    def __init__(self, path: Union[List[str], str]):
        self.path = path
        self.ds = ray.data.read_binary_files(path, include_paths=True)

    def __iter__(self):
        # yields (path, bytes) rows straight from the ray Dataset
        return self.ds.iter_rows()

    def __len__(self):
        # NOTE(review): `input_files` queries the underlying ray Dataset on every
        # call — confirm this is cheap enough for repeated len() calls.
        return len(self.input_files)

    def map_(self, func, *args, **kwargs) -> None:
        """Inplace Map for ray.data
        Time complexity: O(dataset size / parallelism)

        See https://docs.ray.io/en/latest/data/dataset.html#transforming-datasets"""
        self.ds = self.ds.map(func, *args, **kwargs)

    def map_batch_(self, func, batch_size: int = 2, **kwargs) -> None:
        """Inplace Map for ray.data
        Time complexity: O(dataset size / parallelism)
        See https://docs.ray.io/en/latest/data/dataset.html#transforming-datasets"""
        self.ds = self.ds.map_batches(func, batch_size=batch_size, **kwargs)

    @property
    def input_files(self):
        """Source file paths backing the ray Dataset."""
        return self.ds.input_files()
52 |
53 |
class RayImageFolder(RayDataset):
    """Read image datasets
    ```
    root/dog/xxx.png
    root/dog/xxy.png
    root/dog/[...]/xxz.png

    root/cat/123.png
    root/cat/nsdf3.png
    root/cat/[...]/asd932_.png
    ```
    """

    def __init__(self, path, transform: Union[Callable, None] = None):
        # transform: optional callable applied to each decoded PIL image
        super().__init__(path)
        self.transform = transform

    @staticmethod
    def file_to_class(files: Union[str, List]):
        """Derive the class label from the immediate parent-directory name.

        Accepts a single path (returns one label) or a list of paths
        (returns a list of labels).

        NOTE(review): assumes "/"-separated paths — confirm ray always yields
        POSIX-style paths, otherwise this breaks on backslash separators.
        """
        file_list = []
        if isinstance(files, (tuple, list)):
            for file in files:
                file_list.append(file.split("/")[-2])
            return file_list
        return files.split("/")[-2]

    def find_classes(self) -> List[str]:
        """Return the sorted, de-duplicated class names found in the input files."""
        files = self.input_files
        return sorted(list(set(map(self.file_to_class, files))))

    def __iter__(self):
        # each row is (path, raw_bytes): decode the bytes as an image and
        # derive the target label from the path
        for data in self.ds.iter_rows():
            x = Image.open(io.BytesIO(data[1]))
            target = self.file_to_class(data[0])
            if self.transform:
                x = self.transform(x)
            yield x, target
91 |
--------------------------------------------------------------------------------
/gradsflow/models/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2021 GradsFlow. All rights reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | from .model import Model
15 | from .utils import available_losses, available_metrics
16 |
--------------------------------------------------------------------------------
/gradsflow/models/constants.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2021 GradsFlow. All rights reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
LEARNER = "learner"  # NOTE(review): appears to be a lookup/registry key for the learner object — confirm at call sites
15 |
--------------------------------------------------------------------------------
/gradsflow/models/exceptions.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2021 GradsFlow. All rights reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | import logging
15 |
16 | logger = logging.getLogger(__name__)
17 |
18 |
class EpochCancel(Exception):
    """Control-flow exception signalling that the current epoch was cancelled."""

    def __init__(self):
        super().__init__()
        logger.info("epoch cancelled")
23 |
24 |
class FitCancel(Exception):
    """Control-flow exception signalling that `model.fit` was cancelled."""

    def __init__(self):
        super().__init__()
        logger.info("model.fit cancelled")
29 |
--------------------------------------------------------------------------------
/gradsflow/models/tracker.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2021 GradsFlow. All rights reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | import logging
15 | from dataclasses import dataclass
16 | from typing import Dict, List, Optional
17 |
18 | from rich import box
19 | from rich.table import Table
20 |
21 | from gradsflow.core.base import TrackingValues
22 | from gradsflow.utility.common import GDict, to_item
23 |
24 | logger = logging.getLogger(__name__)
25 |
26 |
@dataclass(init=False)
class BaseTracker:
    """Plain tracking state shared by `Tracker`: epoch/step counters plus
    train/val `TrackingValues`.

    NOTE(review): `train` and `val` defaults are single `TrackingValues`
    instances created once at class-definition time, so they are shared across
    every instance until `Tracker.reset()` replaces them — confirm this sharing
    is intended (`field(default_factory=...)` would avoid it, but with
    `init=False` no generated `__init__` would invoke the factory).
    """

    global_step: int = 0  # Global training steps
    max_epochs: int = 0
    current_epoch: int = 0  # current train current_epoch
    steps_per_epoch: Optional[int] = None
    train: TrackingValues = TrackingValues()
    val: TrackingValues = TrackingValues()
35 |
36 |
class Tracker(BaseTracker):
    """
    Tracks loss, accuracy and model weights during model.fit()
    """

    def __init__(self):
        # NOTE(review): these assign onto the class-level `TrackingValues`
        # defaults inherited from BaseTracker (shared across instances until
        # `reset()` installs fresh objects) — confirm intended.
        self.train.metrics = GDict()
        self.val.metrics = GDict()
        # per-step records of the form {"current_epoch": int, "<mode>/<name>": value}
        self.logs: List[Dict] = []

    def __getitem__(self, key: str):  # skipcq: PYL-R1705
        """
        1. key= `train | val` then return respective `TrackingValues` object
        2. key=`metrics` then return a dictionary of metrics
        3. key=`loss` then return a dictionary of losses
        Args:
            key: train, val, metrics or loss

        Returns:
            `TrackingValues` or a Dictionary

        Raises:
            KeyError: for any other key.
        """
        if key in ("train", "val"):
            return self.mode(key)
        elif key == "metrics":
            return {"train": self.train_metrics, "val": self.val_metrics}
        elif key == "loss":
            return {"train": self.train_loss, "val": self.val_loss}

        raise KeyError(f"key {key} is not implemented!")

    @property
    def train_loss(self):
        """Average training loss (`train.loss.avg`)."""
        return self.train.loss.avg

    @property
    def val_loss(self):
        """Average validation loss (`val.loss.avg`)."""
        return self.val.loss.avg

    @property
    def train_metrics(self) -> GDict:
        """Training metrics; values expose `.avg` (see `create_table`)."""
        return self.train.metrics

    @property
    def val_metrics(self) -> GDict:
        """Validation metrics; values expose `.avg` (see `create_table`)."""
        return self.val.metrics

    def mode(self, mode) -> TrackingValues:
        """Return the `TrackingValues` for `mode` ("train" or "val").

        Raises:
            KeyError: for any other mode string.
        """
        if mode == "train":
            return self.train
        if mode == "val":
            return self.val

        raise KeyError(f"mode {mode} is not implemented!")

    def _append_logs(self, key, value):
        """Append Key Value pairs to `Tracker.logs`"""
        # TODO: accept a list of keys and values as well.
        epoch = self.current_epoch
        data = {"current_epoch": epoch, key: to_item(value)}
        self.logs.append(data)

    def track_loss(self, loss: float, mode: str):
        """Tracks loss by adding to `Tracker.logs` and maintaining average loss in a single Epoch with `TrackingValues`.
        Update loss with `TrackingValues.update_loss(loss)` which is called with `TrainEvalCallback` at `*_step_end`.
        Args:
            loss: Step Loss
            mode: can be train | val
        """
        loss = to_item(loss)
        value_tracker = self.mode(mode)
        value_tracker.update_loss(loss)
        # logged under "<mode>/loss" to keep train and val entries distinct
        key = mode + "/" + "loss"
        self._append_logs(key, loss)

    def track_metrics(self, metric: Dict[str, float], mode: str):
        """Tracks metrics by adding to `Tracker.logs` and maintaining average metric in a single Epoch with `TrackingValues`.
        Update metrics with `TrackingValues.update_metrics(metrics)` which is called with `TrainEvalCallback` at `*_step_end`.
        Args:
            metric: Step metric
            mode: can be train | val
        """
        value_tracker = self.mode(mode)

        # Track values that averages with epoch
        value_tracker.update_metrics(metric)

        # _append_logs value for each step in a dict
        for k, v in metric.items():
            k = mode + "/" + k
            self._append_logs(k, v)

    def create_table(self) -> Table:
        """Render the current epoch's loss and metric averages as a one-row rich `Table`."""
        headings = ["i", "train/loss"]
        row = [self.current_epoch, self.train_loss]

        # show validation loss only once it has actually been computed
        if self.val.loss.computed:
            headings.append("val/loss")
            row.append(self.val_loss)

        for metric_name, value in self.train_metrics.items():
            headings.append("train/" + metric_name)
            row.append(value.avg)

        for metric_name, value in self.val_metrics.items():
            headings.append("val/" + metric_name)
            row.append(value.avg)

        # floats formatted to 3 decimals; the leading space in the format leaves room for a sign
        row = list(map(lambda x: f"{x: .3f}" if isinstance(x, float) else str(x), row))
        table = Table(*headings, expand=True, box=box.SIMPLE)
        table.add_row(*row)
        return table

    def reset(self):
        """Resets epochs, logs and train & val `TrackingValues`."""
        logger.debug("Reset Tracker")
        self.max_epochs = 0
        self.current_epoch = 0
        self.steps_per_epoch = None
        self.train = TrackingValues()
        self.val = TrackingValues()
        self.logs = []
158 |
--------------------------------------------------------------------------------
/gradsflow/models/utils.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2021 GradsFlow. All rights reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | from typing import Callable, Dict, List, Optional, Union
15 |
16 | import numpy as np
17 | import torch
18 | import torchmetrics
19 | from torch import nn
20 | from torchmetrics import Metric
21 |
22 | from gradsflow.utility.common import filter_list, module_to_cls_index
23 |
# Scalar-like values accepted across the models API.
SCALAR = Union[torch.Tensor, np.float64, float, int]
# Name -> class indexes built from torch.nn (lower-cased keys) and torchmetrics (original keys).
_nn_classes = module_to_cls_index(nn)
_tm_classes = module_to_cls_index(torchmetrics, lower_key=False)

# Every torch.nn class whose lower-cased name contains "loss" (crossentropyloss, mseloss, ...).
losses: Dict[str, Callable] = {k: v for k, v in _nn_classes.items() if "loss" in k}

# Keep only names starting with an ASCII uppercase letter (65..90 == 'A'..'Z'),
# then lower-case the keys for case-insensitive lookup.
metrics: Dict[str, Metric] = {k: v for k, v in _tm_classes.items() if 65 <= ord(k[0]) <= 90}
metrics = {k.lower(): v for k, v in metrics.items()}
32 |
33 |
def available_losses(pattern: Optional[str] = None) -> List[str]:
    """Get available loss functions
    ```python
    >> available_losses()
    >> # crossentropy, binarycrossentropy, mae, ...

    # Filter available losses with regex pattern
    >> available_losses("m.e")
    >> # ["mae", "mse"]
    ```
    """
    loss_keys = list(losses.keys())
    return filter_list(loss_keys, pattern)
47 |
48 |
def available_metrics(pattern: Optional[str] = None) -> List[str]:
    """Get available Metrics
    ```python
    >> available_metrics()
    >> # accuracy, F1, RMSE, ...

    # Filter available metrics with regex pattern
    >> available_metrics("acc.*")
    >> # ["accuracy"]
    ```
    """
    return filter_list(list(metrics.keys()), pattern)
62 |
--------------------------------------------------------------------------------
/gradsflow/tuner/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2021 GradsFlow. All rights reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | from .automodel import AutoModelV2
16 | from .tuner import Tuner
17 |
--------------------------------------------------------------------------------
/gradsflow/tuner/tuner.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2021 GradsFlow. All rights reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | from typing import Any, Dict, List, Sequence, Union
15 |
16 | import ray
17 | from ray import tune
18 | from ray.tune.sample import Domain
19 |
20 |
class ComplexObject:
    """Class to store and retrieve large size objects and convert it to `ray.tune.Domain`.
    Objects will be stored with `ray.put` and retrieved with `ray.get`.
    """

    def __init__(self):
        # ray object references (handles returned by `ray.put`), addressed by index
        self.values: List[Any] = []

    def __len__(self):
        return len(self.values)

    def append(self, value: Any):
        """Put `value` into the ray object store and keep its reference."""
        self.values.append(ray.put(value))

    def get_complex_object(self, idx):
        """Materialize and return the stored object at index `idx`."""
        return ray.get(self.values[idx])

    def to_choice(self) -> Domain:
        """converts to ray.tune Domain"""
        # the search happens over indices; resolve back with `get_complex_object`
        indices = tuple(range(len(self.values)))
        return tune.choice(indices)
42 |
43 |
class Tuner:
    """
    Supports `ray.tune` methods and provide an easy way for tuning large size complex objects like Models.
    """

    def __init__(self):
        # hyperparameter name -> `tune.Domain` (or a raw scalar registered via `scalar()`)
        self._search_space: Dict[str, Domain] = {}
        # hyperparameter name -> registered `ComplexObject`
        self._complex_objects: Dict[str, ComplexObject] = {}

    def update_search_space(self, k: str, v: Union[Domain, ComplexObject]):
        """
        Update search space with value `ray.tune(...)` or `gradsflow.tuner.ComplexObject`
        Args:
            k: hyperparameter name
            v: hyperparameter value - `ray.tune(...)` or `gradsflow.tuner.ComplexObject` object.

        Raises:
            UserWarning: when `v` is neither a `Domain` nor a `ComplexObject`.
        """
        if isinstance(v, Domain):
            self._search_space[k] = v
        elif isinstance(v, ComplexObject):
            # complex objects are searched by index; the object itself is kept separately
            # (a redundant always-true assert on `v`'s type was removed here)
            self._search_space[k] = v.to_choice()
            self._complex_objects[k] = v
        else:
            # NOTE: raising a warning *category* is unusual, but kept so callers
            # that catch UserWarning keep working.
            raise UserWarning(f"Tuner Search space doesn't support {type(v)}")
        assert isinstance(
            self._search_space[k], Domain
        ), "search space should only contain object of type `tune.Domain`"

    def suggest_complex(self, key: str, *values: Sequence) -> ComplexObject:
        """
        Use this method when you want to search models or any large object.
        It will also update search space with the provided key.
        Args:
            key: hyperparameter name
            *values: values for the hyperparameter

        Returns:
            `ComplexObject`
        """
        complex_object = ComplexObject()
        for v in values:
            complex_object.append(v)

        object_choice = complex_object.to_choice()
        self._search_space[key] = object_choice
        self._complex_objects[key] = complex_object
        return complex_object

    def scalar(self, key: str, value):
        """This sets a scalar value and will not be used for tuning"""
        # NOTE: stores a raw value (not a tune.Domain), unlike every other registration
        self._search_space[key] = value

    def choice(self, key: str, *values) -> Domain:
        """Tune for categorical values"""
        x = tune.choice(values)
        self._search_space[key] = x
        return x

    def loguniform(self, key: str, lower: float, upper: float, base: float = 10) -> Domain:
        """Tune over a log-uniform distribution in [lower, upper] with the given base."""
        x = tune.loguniform(lower, upper, base)
        self._search_space[key] = x
        return x

    def union(self, tuner: "Tuner") -> "Tuner":
        """Inplace Merge of two Tuners"""
        self._search_space.update(tuner._search_space)
        self._complex_objects.update(tuner._complex_objects)
        return self

    @staticmethod
    def merge(*tuners: "Tuner") -> "Tuner":
        """Merge any number of Tuners into a new one; later tuners overwrite duplicate keys."""
        final_tuner = Tuner()
        for t in tuners:
            final_tuner.union(t)
        return final_tuner

    @property
    def value(self) -> Dict[str, Domain]:
        """The raw search-space dictionary."""
        return self._search_space

    def get_complex_object(self, key: str, idx: int):
        """Get registered complex object value from key at given index"""
        return self._complex_objects[key].get_complex_object(idx)

    def get(self, key: str) -> Union[Domain, ComplexObject]:
        """Return the `ComplexObject` (preferred) or `Domain` registered for `key`.

        Raises:
            KeyError: when `key` was never registered.
        """
        if key in self._complex_objects:
            return self._complex_objects[key]
        if key in self._search_space:
            return self._search_space[key]
        raise KeyError(f"key={key} is not available in tuner! Available={tuple(self._search_space.keys())}")
134 |
--------------------------------------------------------------------------------
/gradsflow/utility/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2021 GradsFlow. All rights reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | from .common import AverageMeter, default_device, listify
15 | from .data import download
16 |
--------------------------------------------------------------------------------
/gradsflow/utility/common.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2021 GradsFlow. All rights reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | import dataclasses
15 | import inspect
16 | import logging
17 | import os
18 | import re
19 | import sys
20 | import warnings
21 | from glob import glob
22 | from pathlib import Path
23 | from typing import Any, Dict, List, Optional, Union
24 |
25 | import numpy as np
26 | import torch
27 |
28 |
def get_file_extension(path: str) -> str:
    """Return the text after the last "." in the file name.

    A name with no dot is returned unchanged (matching the historical behavior).
    """
    # rpartition yields ('', '', name) when there is no dot, so [-1] is always right
    return os.path.basename(path).rpartition(".")[-1]
32 |
33 |
def get_files(folder: str):
    """Recursively collect every path (files and directories) under `folder`."""
    pattern = str(Path(folder) / "**" / "*")
    return glob(pattern, recursive=True)
38 |
39 |
def default_device():
    """Return "cuda" when a CUDA device is available, otherwise "cpu"."""
    if torch.cuda.is_available():
        return "cuda"
    return "cpu"
42 |
43 |
def module_to_cls_index(module, lower_key: bool = True) -> dict:
    """Index the classes of `module` by name.

    Args:
        module: an imported module object.
        lower_key: when True (default), keys are lower-cased class names.

    Returns:
        Dict mapping class name -> class object.
    """
    members = inspect.getmembers(sys.modules[module.__name__], inspect.isclass)
    if lower_key:
        return {name.lower(): cls for name, cls in members}
    return {name: cls for name, cls in members}
54 |
55 |
def listify(item: Any) -> List:
    """Convert any scalar value into list.

    Note: falsy inputs (None, 0, "", empty containers) all collapse to [].
    """
    if not item:
        return []
    if isinstance(item, list):
        return item
    if isinstance(item, (int, float, str)):
        return [item]
    # tuples, sets and arbitrary iterables all go through list();
    # anything non-iterable is wrapped as a single-element list
    try:
        return list(item)
    except TypeError:
        return [item]
70 |
71 |
72 | # ref: https://github.com/rwightman/pytorch-image-models/blob/b544ad4d3fcd02057ab9f43b118290f2a089566f/timm/utils/metrics.py#L7
@dataclasses.dataclass(init=False)
class AverageMeter:
    """Computes and stores the average and current value.
    `val` is the running value, `avg` is the average value over an epoch.
    """

    avg: Optional[float] = 0

    def __init__(self, name=None):
        # optional label for the tracked quantity
        self.name = name
        # flips to True on the first `update` call
        self.computed = False
        self.val: Optional[float] = 0
        self.sum: Optional[float] = 0
        self.count: Optional[int] = 0
        self.avg = 0

    def reset(self):
        """Clear all running statistics back to zero."""
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        """Updates the average meter value with new data. It also converts `torch.Tensor` to primitive datatype."""
        self.computed = True
        self.val = to_item(val)
        self.sum += self.val * n
        self.count += n
        self.avg = self.sum / self.count
103 |
104 |
def to_item(data: Any) -> Union[int, float, str, np.ndarray, Dict]:
    """
    Recursively converts any contained `torch.Tensor` into cpu numpy format.
    Args:
        data: torch.Tensor contained in any Iterable or Dictionary.

    Returns:
        `data` with every tensor replaced by its detached cpu numpy equivalent;
        primitives pass through, lists/tuples keep their type, dicts are rebuilt.
    """

    if isinstance(data, (int, float, str)):
        return data
    if isinstance(data, (list, tuple)):
        return type(data)(map(to_item, data))
    if isinstance(data, dict):
        return {k: to_item(v) for k, v in data.items()}

    if torch.is_tensor(data):
        if data.requires_grad:
            data = data.detach()
        # bug fix: return here so the "didn't convert" log below no longer
        # fires on every successful tensor conversion
        return data.cpu().numpy()

    logging.info("to_item didn't convert any value.")
    return data
126 |
127 |
def filter_list(arr: List[str], pattern: Optional[str] = None) -> List[str]:
    """Filter a list of strings with given pattern
    ```python
    >> arr = ['crossentropy', 'binarycrossentropy', 'softmax', 'mae',]
    >> filter_list(arr, ".*entropy*")
    >> # ["crossentropy", "binarycrossentropy"]
    ```
    """
    if pattern is None:
        return arr

    # match anchors at the start of each string, mirroring re.match semantics
    matcher = re.compile(pattern).match
    return list(filter(matcher, arr))
141 |
142 |
class GDict(dict):
    """`dict` subclass whose `to_dict` expands dataclass values into plain dicts."""

    def to_dict(self):
        """Return a shallow copy with each dataclass value replaced by `dataclasses.asdict(value)`.

        Non-dataclass values are kept as-is; the original mapping is untouched.
        """
        snapshot = self.copy()
        for key, value in snapshot.items():
            try:
                snapshot[key] = dataclasses.asdict(value)
            except TypeError:
                # value is not a dataclass instance -- leave it unchanged
                continue
        return snapshot
153 |
--------------------------------------------------------------------------------
/gradsflow/utility/data.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2021 GradsFlow. All rights reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | from smart_open import smart_open
16 |
17 |
def download(path):
    """Return the full contents of `path`, which may live on a local
    filesystem or on any cloud storage backend supported by `smart_open`."""
    with smart_open(path) as remote_file:
        content = remote_file.read()
    return content
22 |
--------------------------------------------------------------------------------
/gradsflow/utility/imports.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2021 GradsFlow. All rights reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
import functools
import importlib
import importlib.util
from typing import Optional
17 |
18 |
def is_installed(module_name: str) -> bool:
    """Return True when `module_name` can be imported, without importing it.

    Args:
        module_name: importable name, possibly dotted (``"pkg.sub"``).

    Notes:
        - `find_spec` raises `ModuleNotFoundError` when the parent package of
          a dotted name is missing — treat that as "not installed".
        - `AttributeError` is kept for environments where `importlib.util`
          was not loaded alongside `importlib`.
    """
    try:
        return importlib.util.find_spec(module_name) is not None
    except (AttributeError, ModuleNotFoundError):
        return False
24 |
25 |
def requires(package_name: str, err_msg: Optional[str] = None):
    """Decorator factory guarding a function behind an installed package.

    Args:
        package_name: module that must be importable for the wrapped
            function to run.
        err_msg: custom error text; a default message is built when None.

    Raises:
        ModuleNotFoundError: at call time, if the package is missing.
    """

    def inner_fn(func):
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            message = err_msg if err_msg else f"{package_name} Module must be installed to use!"
            if not is_installed(package_name):
                raise ModuleNotFoundError(message)
            return func(*args, **kwargs)

        return wrapper

    return inner_fn
38 |
--------------------------------------------------------------------------------
/mkdocs.yml:
--------------------------------------------------------------------------------
1 | site_name: GradsFlow
2 | site_description: "An open-source AutoML Library based on PyTorch"
3 | site_author: Aniket Maurya
4 | copyright: 'Copyright © 2023 GradsFlow'
5 |
6 | #banner_url: https://ik.imagekit.io/gradsflow/logo/v2/gf-logo-geadsflow-orange-white-bg_y5v0LNPvdeM.png
7 | repo_url: https://github.com/gradsflow/gradsflow/
8 | repo_name: gradsflow/gradsflow
9 |
10 | theme:
11 | name: material
12 | custom_dir: docs/overrides
13 | palette:
14 | - scheme: default
15 | primary: deep orange
16 | accent: indigo
17 | toggle:
18 | icon: material/weather-sunny
19 | name: Switch to dark mode
20 |
21 | - scheme: slate
22 | primary: deep orange
23 | accent: indigo
24 | toggle:
25 | icon: material/weather-night
26 | name: Switch to light mode
27 |
28 | logo: https://ik.imagekit.io/gradsflow/logo/v2/gf-logo-gflow-black_9ZE8jOuKXTm.svg?updatedAt=1633488021249
29 | favicon: https://ik.imagekit.io/gradsflow/logo/v2/gf-logo-gflow-white_vCxfpINvg.svg
30 | features:
31 | - search.suggest
32 | - search.highlight
33 |
34 | # Necessary for search to work properly
35 | include_search_page: false
36 | search_index_only: true
37 |
markdown_extensions:
- meta
- pymdownx.highlight
- pymdownx.superfences
- pymdownx.details
- admonition
45 | - pymdownx.emoji:
46 | emoji_index: !!python/name:materialx.emoji.twemoji
47 | emoji_generator: !!python/name:materialx.emoji.to_svg
48 | - toc:
49 | permalink: true
50 |
51 | plugins:
52 | - git-revision-date-localized
53 | - search
54 | - autorefs
55 | - mkdocs-jupyter
56 | - mkdocstrings:
57 | default_handler: python
58 | handlers:
59 | python:
60 | rendering:
61 | show_source: false
62 |
63 | extra:
64 | homepage: https://docs.gradsflow.com
65 |
66 | nav:
67 | - Intro: 'index.md'
68 | - Examples:
69 | - Auto Image Classification: 'examples/nbs/01-ImageClassification.ipynb'
70 | - Auto Text Classification: 'examples/nbs/02-TextClassification.ipynb'
71 | - Auto Summarization: 'examples/nbs/03-TextSummarization.ipynb'
72 | - Remote Dataset Loading: 'examples/nbs/04-RayDataset.ipynb'
73 | - Model Training: 'examples/nbs/05-model_fit.ipynb'
74 | - AutoModel - HyperParameter Search: 'examples/nbs/06-AutoModel_fit.ipynb'
75 | - Pix2Pix GAN Code Explanation: 'examples/nbs/Pix2Pix_explained_with_code.ipynb'
76 | - 🤗 HuggingFace Training Example: 'examples/nbs/2021-10-3-huggingface-training.ipynb'
77 | - API References:
78 | - Model:
79 | - gradsflow/models/base.md
80 | - gradsflow/models/model.md
81 | - gradsflow/models/tracker.md
82 | - gradsflow/models/utils.md
83 | - Tuner: gradsflow/tuner
84 | - AutoTasks:
85 | - gradsflow/autotasks/autotasks.md
86 | - gradsflow/autotasks/engine.md
87 | - Data: gradsflow/data
88 | - Callbacks: gradsflow/callbacks
89 | - Core: gradsflow/core
- utility: gradsflow/utility.md
91 | - Release Notes: 'CHANGELOG.md'
92 |
--------------------------------------------------------------------------------
/pyproject.toml:
--------------------------------------------------------------------------------
1 | [build-system]
2 | requires = [
3 | "setuptools >= 40.9.0",
4 | "wheel",
5 | ]
6 | build-backend = "setuptools.build_meta"
7 |
8 |
9 | [tool.isort]
10 | profile = "black"
11 |
12 | [tool.black]
13 | line_length = 120
14 |
15 |
16 | [tool.pytest.ini_options]
17 | norecursedirs = ["tests/autotasks", "tests/tuner"]
18 |
--------------------------------------------------------------------------------
/setup.cfg:
--------------------------------------------------------------------------------
1 | [metadata]
2 | name = gradsflow
3 | version = attr: gradsflow.__version__
4 | author = Aniket Maurya
5 | author_email = aniket@gradsflow.com
6 | description = An open-source AutoML Library based on PyTorch
long_description = file: README.md
8 | long_description_content_type = text/markdown
9 | url = https://github.com/gradsflow/gradsflow
10 | project_urls =
11 | Bug Tracker = https://github.com/gradsflow/gradsflow/issues
12 | Documentation = https://docs.gradsflow.com
13 | Source Code = https://github.com/gradsflow/gradsflow
14 |
15 | classifiers =
16 | ; How mature is this project? Common values are
17 | ; 3 - Alpha, 4 - Beta, 5 - Production/Stable
18 | Development Status :: 4 - Beta
19 | Intended Audience :: Developers
20 | Programming Language :: Python :: 3
21 | License :: OSI Approved :: Apache Software License
22 | Operating System :: OS Independent
23 |
24 | keywords = AutoML, Pytorch, Deep Learning
25 |
26 | [options]
27 | packages = find:
28 | python_requires = >=3.8
29 | install_requires =
30 | torch >=1.13.1
31 | torchvision
32 | ray[default,tune] >=2.2.0
33 | timm>=0.6.12
34 | rich>=13.3.1
35 | smart_open >=5.1,<=5.2.1
36 | torchmetrics >=0.11.1
37 | lightning >=1.9.2
38 |
39 | [options.extras_require]
40 | dev = codecarbon >=1.2.0; wandb; tensorboard
41 | test = pytest; coverage; pytest-sugar; pytest-randomly
42 |
43 | [options.packages.find] #optional
44 | exclude=tests, docs, examples
45 |
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2021 GradsFlow. All rights reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | from setuptools import setup
15 |
# All packaging metadata lives in setup.cfg; this stub only invokes setuptools.
if __name__ == "__main__":
    setup()
18 |
--------------------------------------------------------------------------------
/sonar-project.properties:
--------------------------------------------------------------------------------
1 | sonar.projectKey=gradsflow_gradsflow
2 | sonar.organization=gflow
3 |
4 | # This is the name and version displayed in the SonarCloud UI.
5 | #sonar.projectName=gradsflow
6 | #sonar.projectVersion=1.0
7 |
8 | # Path is relative to the sonar-project.properties file. Replace "\" by "/" on Windows.
9 | sonar.sources=gradsflow
10 |
11 | # Encoding of the source code. Default is default system encoding
12 | sonar.sourceEncoding=UTF-8
13 |
14 | #---- Language properties ----
15 | sonar.language=py
16 | sonar.python.version=3
17 | sonar.python.coverage.reportPaths=coverage.xml
18 | sonar.tests=tests
19 | sonar.exclusions=tests/**
20 |
--------------------------------------------------------------------------------
/tests/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2021 GradsFlow. All rights reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | import os
16 | import sys
17 | import warnings
18 |
# Make the package importable when tests run from an un-installed source
# checkout: fall back to adding the repo root to sys.path.
try:
    import gradsflow
except ImportError:
    sys.path.append("./")

# Signal library code that it is running under CI/tests.
os.environ["GF_CI"] = "true"

# Third-party libraries are noisy under pytest; silence warnings globally.
warnings.filterwarnings("ignore")
27 |
--------------------------------------------------------------------------------
/tests/__main__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2021 GradsFlow. All rights reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | import urllib.request
16 | import zipfile
17 | from pathlib import Path
18 |
# Download and unpack the cat-dog test fixture into ./data.
cwd = Path.cwd()
(Path.cwd() / "data").mkdir(exist_ok=True)

# NOTE(review): requires network access and re-downloads the archive on every
# run — no skip-if-exists caching.
urllib.request.urlretrieve(
    "https://github.com/gradsflow/test-data/archive/refs/tags/cat-dog-v0.zip",
    f"{cwd}/data/test-cat-dog-v0.zip",
)

with zipfile.ZipFile(f"{cwd}/data/test-cat-dog-v0.zip", "r") as zip_ref:
    zip_ref.extractall(f"{cwd}/data/")
29 |
--------------------------------------------------------------------------------
/tests/autotasks/test_autotasks.py:
--------------------------------------------------------------------------------
import os

# Must be set before any gradsflow import so library code sees CI mode.
os.environ["GF_CI"] = "true"

import warnings
from pathlib import Path

import ray
from flash.image import ImageClassificationData

from gradsflow.autotasks import autotask
from gradsflow.data.image import image_dataset_from_directory
from gradsflow.models.model import Model

warnings.filterwarnings("ignore")

# local_mode runs Ray tasks in-process: deterministic and debuggable in tests.
ray.init(local_mode=True, ignore_reinit_error=True)

data_dir = Path.cwd()
# Flash datamodule over the hymenoptera fixture (downloaded by tests/__main__.py).
datamodule = ImageClassificationData.from_folders(
    train_folder=f"{data_dir}/data/hymenoptera_data/train/",
    val_folder=f"{data_dir}/data/hymenoptera_data/val/",
    batch_size=1,
)
data_dir = Path.cwd() / "data"

# Plain gradsflow dataloaders over the same image folders.
train_data = image_dataset_from_directory(f"{data_dir}/hymenoptera_data/train/", transform=True)
train_dl = train_data.dataloader

val_data = image_dataset_from_directory(f"{data_dir}/hymenoptera_data/val/", transform=True)
val_dl = val_data.dataloader
32 |
33 |
def test_build_model():
    """`autotask` should produce an object whose `build_model` yields a gradsflow `Model`."""
    automodel = autotask(
        task="image-classification",
        train_dataloader=train_dl,
        val_dataloader=val_dl,
        num_classes=2,
        timeout=5,
        suggested_backbones="ssl_resnet18",
        n_trials=1,
    )
    hparams = {"backbone": "ssl_resnet18", "optimizer": "adam", "lr": 1e-1}
    automodel.model = automodel.build_model(hparams)
    assert isinstance(automodel.model, Model)
47 |
48 |
def test_hp_tune():
    """Run one short end-to-end hyper-parameter search trial."""
    automodel = autotask(
        task="image-classification",
        train_dataloader=train_dl,
        val_dataloader=val_dl,
        num_classes=2,
        max_epochs=1,
        max_steps=2,
        timeout=10,
        suggested_backbones="ssl_resnet18",
        optimization_metric="val_loss",
        n_trials=1,
    )
    # Fix: `val_loss` must be minimised; `mode="max"` would reward the worst
    # trial (test_image.py uses mode="min" with a loss metric for the same API).
    automodel.hp_tune(name="pytest-experiment", mode="min", gpu=0)
63 |
--------------------------------------------------------------------------------
/tests/autotasks/test_autotrainer.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2021 GradsFlow. All rights reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | from unittest.mock import MagicMock, Mock, patch
15 |
16 | import pytest
17 |
18 | from gradsflow.autotasks.engine.backend import Backend
19 |
20 | trainer_config = {"show_progress": False}
21 |
22 |
@patch("gradsflow.autotasks.engine.backend.FlashTrainer")
@patch("gradsflow.autotasks.engine.backend.PLTrainer")
def test_optimization_objective(mock_pl_trainer: Mock, mock_fl_trainer: Mock):
    """`Backend.optimization_objective` should drive a (mocked) trainer and reject unknown backends."""
    # Mock args are injected bottom-up: the innermost @patch (PLTrainer) comes first.
    dm = MagicMock()
    model_builder = MagicMock()

    # backend_type is pl
    autotrainer = Backend(dm, model_builder, optimization_metric="val_accuracy", backend="pl")
    autotrainer.optimization_objective({}, trainer_config)
    assert mock_pl_trainer.called or mock_fl_trainer.called

    # wrong backend_type is passed
    with pytest.raises(NotImplementedError):
        autotrainer = Backend(
            dm,
            model_builder,
            optimization_metric="val_accuracy",
            backend="error",
        )
        autotrainer.optimization_objective({}, trainer_config)
43 |
--------------------------------------------------------------------------------
/tests/autotasks/test_core_automodel.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2021 GradsFlow. All rights reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | from pathlib import Path
16 | from unittest.mock import MagicMock, patch
17 |
18 | import pytest
19 | import torch
20 | from flash.image import ImageClassificationData
21 |
22 | from gradsflow.autotasks.engine.automodel import AutoModel
23 | from gradsflow.autotasks.engine.backend import BackendType
24 |
# Fixture datamodule over the hymenoptera folders fetched by tests/__main__.py.
cwd = Path.cwd()
datamodule = ImageClassificationData.from_folders(
    train_folder=f"{cwd}/data/hymenoptera_data/train/", val_folder=f"{cwd}/data/hymenoptera_data/val/", batch_size=1
)
29 |
30 |
@patch.multiple(AutoModel, __abstractmethods__=set())
def test_auto_model():
    """A de-abstracted AutoModel should instantiate and be truthy."""
    automodel = AutoModel(datamodule)
    assert automodel
34 |
35 |
@patch.multiple(AutoModel, __abstractmethods__=set())
def test_build_model():
    """The abstract base leaves `build_model` unimplemented."""
    automodel = AutoModel(datamodule)
    with pytest.raises(NotImplementedError):
        automodel.build_model({"lr": 1})
41 |
42 |
@patch.multiple(AutoModel, __abstractmethods__=set())
def test_create_search_space():
    """The abstract base leaves `_create_search_space` unimplemented."""
    automodel = AutoModel(datamodule)
    with pytest.raises(NotImplementedError):
        automodel._create_search_space()
48 |
49 |
@patch.multiple(AutoModel, __abstractmethods__=set())
@patch("gradsflow.autotasks.engine.backend.FlashTrainer")
@patch("gradsflow.autotasks.engine.backend.PLTrainer")
def test_objective(mock_pl_trainer, mock_fl_trainer):
    """The backend's optimization objective should run against mocked trainers."""
    optimization_metric = "val_accuracy"
    model = AutoModel(
        datamodule,
        optimization_metric=optimization_metric,
        backend=BackendType.lightning.value,
    )

    model.backend.model_builder = MagicMock()

    # Both mocked trainers expose the metric the objective reads back.
    mock_pl_trainer.callback_metrics = mock_fl_trainer.callback_metrics = {optimization_metric: torch.as_tensor([1])}

    model.backend.optimization_objective({}, {})
66 |
67 |
@patch.multiple(AutoModel, __abstractmethods__=set())
@patch("gradsflow.autotasks.engine.automodel.tune.run")
def test_hp_tune(mock_tune):
    """`hp_tune` must delegate the search to `ray.tune.run`."""
    automodel = AutoModel(datamodule)
    automodel._create_search_space = MagicMock()
    automodel.build_model = MagicMock()
    automodel.hp_tune(gpu=0, cpu=2)

    mock_tune.assert_called()
79 |
--------------------------------------------------------------------------------
/tests/autotasks/test_image.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2021 GradsFlow. All rights reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
import os

# Must be set before any gradsflow import so library code sees CI mode.
os.environ["GF_CI"] = "true"

import warnings
from pathlib import Path

import pytest
import torch

from gradsflow.autotasks import AutoImageClassifier
from gradsflow.data.image import image_dataset_from_directory
from gradsflow.models.model import Model

warnings.filterwarnings("ignore")

data_dir = Path.cwd() / "data"

# Dataloaders over the hymenoptera fixture (downloaded by tests/__main__.py).
train_data = image_dataset_from_directory(f"{data_dir}/hymenoptera_data/train/", transform=True)
train_dl = train_data.dataloader

val_data = image_dataset_from_directory(f"{data_dir}/hymenoptera_data/val/", transform=True)
val_dl = val_data.dataloader
38 |
39 |
def test_forward():
    """Calling `forward` before any model has been built should raise."""
    automodel = AutoImageClassifier(
        train_dataloader=train_dl,
        val_dataloader=val_dl,
        num_classes=2,
    )

    with pytest.raises(UserWarning):
        automodel.forward(torch.rand(1, 3, 8, 8))
49 |
50 |
def test_build_model():
    """`build_model` on an AutoImageClassifier should return a gradsflow `Model`."""
    automodel = AutoImageClassifier(
        train_dataloader=train_dl,
        val_dataloader=val_dl,
        num_classes=2,
        max_epochs=1,
        max_steps=5,
        timeout=10,
        suggested_backbones="ssl_resnet18",
        n_trials=1,
    )
    hparams = {"backbone": "ssl_resnet18", "optimizer": "adam", "lr": 1e-1}
    automodel.model = automodel.build_model(hparams)
    assert isinstance(automodel.model, Model)
65 |
66 |
def test_hp_tune():
    """Run one short hyper-parameter search trial, minimising train loss."""
    automodel = AutoImageClassifier(
        train_dataloader=train_dl,
        val_dataloader=val_dl,
        num_classes=2,
        max_epochs=1,
        max_steps=5,
        timeout=10,
        suggested_backbones="ssl_resnet18",
        optimization_metric="train_loss",
        n_trials=1,
    )
    automodel.hp_tune(name="pytest-experiment", mode="min", gpu=0)
80 |
--------------------------------------------------------------------------------
/tests/autotasks/test_summarization.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2021 GradsFlow. All rights reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | import os
16 |
17 | os.environ["GF_CI"] = "true"
18 |
19 | from unittest.mock import MagicMock
20 |
21 | from gradsflow.autotasks import AutoSummarization
22 |
23 |
def test_build_model():
    """AutoSummarization should build a model from a sampled config (datamodule mocked)."""
    datamodule = MagicMock()
    automodel = AutoSummarization(
        datamodule,
        max_epochs=1,
        timeout=5,
        suggested_backbones="sshleifer/distilbart-cnn-12-6",
        n_trials=1,
    )
    config = {
        "backbone": automodel._DEFAULT_BACKBONES[-1],
        "optimizer": "adam",
        "lr": 1e-3,
    }
    automodel.build_model(config)
39 |
--------------------------------------------------------------------------------
/tests/autotasks/test_text.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2021 GradsFlow. All rights reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | import os
16 |
17 | os.environ["GF_CI"] = "true"
18 |
19 | from unittest.mock import MagicMock
20 |
21 | from gradsflow.autotasks import AutoTextClassifier
22 |
23 |
def test_build_model():
    """AutoTextClassifier should build a model from a sampled config (datamodule mocked)."""
    suggested_conf = {
        "optimizers": ["adam"],
        "lr": (5e-4, 1e-3),
    }
    datamodule = MagicMock()
    datamodule.num_classes = 2
    automodel = AutoTextClassifier(
        datamodule,
        num_classes=datamodule.num_classes,
        suggested_backbones=["sgugger/tiny-distilbert-classification"],
        suggested_conf=suggested_conf,
        max_epochs=1,
        optimization_metric="val_loss",
        timeout=5,
        n_trials=1,
    )

    config = {
        "backbone": automodel._DEFAULT_BACKBONES[-1],
        "optimizer": "adam",
        "lr": 1e-3,
    }
    automodel.build_model(config)
48 |
--------------------------------------------------------------------------------
/tests/callbacks/test_logger.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2021 GradsFlow. All rights reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | import os
15 | from pathlib import Path
16 |
17 | import pytest
18 |
19 | from gradsflow import AutoDataset
20 | from gradsflow.callbacks import (
21 | CometCallback,
22 | CSVLogger,
23 | EmissionTrackerCallback,
24 | ModelCheckpoint,
25 | )
26 | from gradsflow.data.image import image_dataset_from_directory
27 | from gradsflow.utility.imports import is_installed
28 | from tests.dummies import DummyModel
29 |
# Shared image fixture: the cat-dog folder extracted by tests/__main__.py.
data_dir = Path.cwd()
folder = f"{data_dir}/data/test-data-cat-dog-v0/cat-dog/"
data = image_dataset_from_directory(folder, transform=True, ray_data=False)
33 |
34 |
@pytest.fixture
def dummy_model():
    """Provide a fresh DummyModel instance for each test."""
    model = DummyModel()
    return model
38 |
39 |
@pytest.fixture
def auto_dataset():
    """Wrap the shared dataloader into an AutoDataset (same loader for train and val)."""
    dataset = AutoDataset(train_dataloader=data.dataloader, val_dataloader=data.dataloader)
    return dataset
43 |
44 |
def test_csv_logger(dummy_model, auto_dataset):
    """Fitting with a CSVLogger callback must write the csv file to disk."""
    logger_cb = CSVLogger(filename="test_csv_logger.csv")
    dummy_model.compile()
    dummy_model.fit(auto_dataset, callbacks=logger_cb)
    assert os.path.isfile("test_csv_logger.csv")
50 |
51 |
def test_model_checkpoint(dummy_model, auto_dataset):
    """Fitting with ModelCheckpoint must create the checkpoint folder."""
    checkpoint_cb = ModelCheckpoint(filename="model_ckpt", path="test_model_checkpoint_folder")
    dummy_model.compile()
    dummy_model.fit(auto_dataset, callbacks=checkpoint_cb)
    assert os.path.exists("test_model_checkpoint_folder")
57 |
58 |
@pytest.mark.skipif(not is_installed("comet_ml"), reason="requires `comet_ml` installed")
def test_comet(dummy_model, auto_dataset):
    """CometCallback requires credentials unless offline mode is requested."""
    with pytest.raises(ValueError):
        CometCallback()

    offline_cb = CometCallback(offline=True)
    dummy_model.compile()
    dummy_model.fit(auto_dataset, callbacks=[offline_cb])
67 |
68 |
@pytest.mark.skipif(not is_installed("codecarbon"), reason="requires `codecarbon` installed")
def test_emission_tracker(dummy_model, auto_dataset):
    """Training should run cleanly with the codecarbon emission tracker attached."""
    tracker_cb = EmissionTrackerCallback()
    dummy_model.compile()
    dummy_model.fit(auto_dataset, callbacks=[tracker_cb])
74 |
--------------------------------------------------------------------------------
/tests/callbacks/test_runner.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2021 GradsFlow. All rights reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | import pytest
15 |
16 | from gradsflow.callbacks import CallbackRunner, TrainEvalCallback
17 | from gradsflow.callbacks.base import Callback
18 |
19 |
def test_init(dummy_model):
    """The "training" preset registers TrainEvalCallback; unknown names must fail."""
    runner = CallbackRunner(dummy_model, "training")
    assert isinstance(runner.callbacks["TrainEvalCallback"], TrainEvalCallback)
    with pytest.raises(NotImplementedError):
        CallbackRunner(dummy_model, "random")
24 |
25 |
def test_append(dummy_model):
    """`append` accepts known names or Callback instances and rejects unknown strings."""
    runner = CallbackRunner(dummy_model)
    with pytest.raises(NotImplementedError):
        runner.append("random")
    runner.append("tune_checkpoint")
    runner.append(TrainEvalCallback(runner.model))
    assert len(runner.callbacks) == 2

    # The registry maps callback name -> Callback instance.
    for name, callback in runner.callbacks.items():
        assert isinstance(name, str)
        assert isinstance(callback, Callback)
37 |
38 |
def test_clean(dummy_model):
    """`clean(keep=...)` must retain the named callback in the registry."""
    runner = CallbackRunner(dummy_model, TrainEvalCallback())
    runner.clean(keep="TrainEvalCallback")
    assert runner.callbacks.get("TrainEvalCallback") is not None
43 |
--------------------------------------------------------------------------------
/tests/callbacks/test_wandb.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2022 GradsFlow. All rights reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | from unittest.mock import Mock, patch
15 |
16 | import pytest
17 |
18 | from gradsflow.callbacks.wandb import WandbCallback
19 | from gradsflow.utility.imports import is_installed
20 |
21 |
@pytest.mark.skipif(not is_installed("wandb"), reason="requires `wandb` installed")
@patch("gradsflow.callbacks.wandb.wandb")
def test_wandbcallback(mock_wandb: Mock, cnn_model, auto_dataset):
    """Fitting with WandbCallback should emit metrics through `wandb.log` (mocked)."""
    callback = WandbCallback()
    cnn_model.compile()
    cnn_model.fit(auto_dataset, callbacks=callback)
    mock_wandb.log.assert_called()
30 |
--------------------------------------------------------------------------------
/tests/conftest.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2021 GradsFlow. All rights reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # Arrange
15 | from pathlib import Path
16 |
17 | import pytest
18 | import timm
19 | from torch import nn
20 |
21 | from gradsflow import AutoDataset, Model
22 | from gradsflow.data import image_dataset_from_directory
23 | from gradsflow.models.tracker import Tracker
24 |
# Shared image fixture data: the cat-dog folder extracted by tests/__main__.py.
data_dir = Path.cwd()
folder = f"{data_dir}/data/test-data-cat-dog-v0/cat-dog/"
data = image_dataset_from_directory(folder, transform=True, ray_data=False)
28 |
29 |
@pytest.fixture
def auto_dataset():
    """AutoDataset built from the shared dataloader (reused for train and val)."""
    dataset = AutoDataset(train_dataloader=data.dataloader, val_dataloader=data.dataloader)
    return dataset
33 |
34 |
@pytest.fixture
def resnet18():
    """An untrained 10-class ssl_resnet18 backbone in eval mode."""
    backbone = timm.create_model("ssl_resnet18", pretrained=False, num_classes=10)
    return backbone.eval()
40 |
41 |
@pytest.fixture
def cnn_model(resnet18):
    """Wrap the resnet18 backbone into a gradsflow Model flagged for test mode."""
    wrapped = Model(resnet18)
    wrapped.TEST = True
    return wrapped
48 |
49 |
@pytest.fixture
def tracker():
    """A fresh metrics Tracker per test."""
    return Tracker()
53 |
54 |
@pytest.fixture
def dummy_model():
    """A dummy torch.nn model that adds 1 to the forward input value."""

    class DummyModel(nn.Module):
        def forward(self, x):
            # nn.Module's default __init__ suffices; forward just shifts the input.
            return x + 1

    return DummyModel()
67 |
--------------------------------------------------------------------------------
/tests/core/test_base.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2021 GradsFlow. All rights reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
--------------------------------------------------------------------------------
/tests/data/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2021 GradsFlow. All rights reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
--------------------------------------------------------------------------------
/tests/data/test_autodata.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2021 GradsFlow. All rights reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | from pathlib import Path
15 |
16 | import pytest
17 | import torch
18 | from lightning import fabric
19 | from lightning.fabric import Fabric
20 | from torch.utils.data import DataLoader, TensorDataset
21 |
22 | from gradsflow.data import AutoDataset
23 |
# Keep the late gradsflow import at the top of this setup block instead of
# sandwiched between statements (was a PEP8 E402-style violation).
from gradsflow.data.image import image_dataset_from_directory

# Tiny in-memory tensor dataset used by the device-setup tests.
dataset = TensorDataset(torch.randn(8, 1, 32, 32))
dataloader = DataLoader(dataset)

# Shared on-disk cat-dog dataset used by the dataset-wrapping test.
data_dir = Path.cwd()
folder = f"{data_dir}/data/test-data-cat-dog-v0/cat-dog/"
data = image_dataset_from_directory(folder, transform=True, ray_data=False)
32 |
33 |
def test_auto_dataset():
    """Constructing an AutoDataset with no data at all should warn the user."""
    with pytest.raises(UserWarning):
        AutoDataset()
37 |
38 |
def test_sent_to_device():
    """setup_data should flip the device-setup flag from None to truthy."""
    autodata = AutoDataset(dataloader)
    assert autodata.device_setup_status is None
    autodata.setup_data(Fabric())
    assert autodata.device_setup_status
45 |
46 |
def test_dataset():
    """Datasets passed directly should be wrapped into Fabric dataloaders."""
    autodata = AutoDataset(train_dataset=data.dataset, val_dataset=data.dataset)
    autodata.setup_data(Fabric())
    assert isinstance(autodata.train_dataloader, fabric.fabric._FabricDataLoader)
52 |
--------------------------------------------------------------------------------
/tests/data/test_common.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2021 GradsFlow. All rights reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | from gradsflow.data.common import random_split_dataset
15 | from gradsflow.data.image import get_fake_data
16 |
# Shared synthetic dataset (32x32 fake images) built once for the module.
fake_data = get_fake_data((32, 32))
18 |
19 |
def test_random_split_dataset():
    """A 90/10 split should yield a larger first piece of the expected size."""
    first, second = random_split_dataset(fake_data.dataset, 0.9)
    expected_len = int(len(fake_data.dataset) * 0.9)
    assert len(first) == expected_len
    assert len(first) > len(second)
24 |
--------------------------------------------------------------------------------
/tests/data/test_image_data.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2021 GradsFlow. All rights reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | from pathlib import Path
15 |
16 | from gradsflow.data.base import Data
17 | from gradsflow.data.image import image_dataset_from_directory
18 |
19 | data_dir = Path.cwd()
20 |
21 |
def test_image_dataset_from_directory():
    """Loading the cat-dog folder returns a gradsflow Data container."""
    path = f"{data_dir}/data/test-data-cat-dog-v0/cat-dog/"
    result = image_dataset_from_directory(path, transform=True, ray_data=False)
    assert isinstance(result, Data)
26 |
--------------------------------------------------------------------------------
/tests/data/test_mixins.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2021 GradsFlow. All rights reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | import pytest
15 | import torch
16 |
17 | from gradsflow.data.mixins import DataMixin
18 | from gradsflow.utility import default_device
19 |
20 |
class DataTest(DataMixin):
    # Minimal concrete DataMixin host: supplies only the `device` attribute
    # that send_to_device needs.
    device = default_device()


# Shared instance used by all tests in this module.
datamixin = DataTest()
26 |
27 |
def test_send_to_device():
    """send_to_device passes through primitives and handles tensors, sequences and dicts."""
    # primitives are returned unchanged
    for value in (1, 1.5):
        assert datamixin.send_to_device(value) == value

    # a bare tensor stays a tensor
    tensor = torch.randn(4, 1)
    assert isinstance(datamixin.send_to_device(tensor), torch.Tensor)

    # tuple/list batches are supported
    assert datamixin.send_to_device((torch.randn(4, 16), [1] * 4))

    # dict batches are supported
    mapping = {"inputs": torch.randn(4, 16), "targets": [1] * 4}
    assert datamixin.send_to_device(mapping)

    # unsupported container types raise
    with pytest.raises(NotImplementedError):
        datamixin.send_to_device(set(mapping))
48 |
--------------------------------------------------------------------------------
/tests/data/test_ray_dataset.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2021 GradsFlow. All rights reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | from pathlib import Path
15 |
16 | import pytest
17 | from PIL import Image
18 |
19 | from gradsflow.data.ray_dataset import RayDataset, RayImageFolder
20 |
21 | data_dir = Path.cwd()
22 |
23 |
# TODO: remote dataset test
@pytest.mark.skip
def test_ray_dataset():
    """RayDataset over the cat-dog folder should expose 8 iterable items."""
    dataset = RayDataset(f"{data_dir}/data/test-data-cat-dog-v0/cat-dog/")

    assert dataset
    assert len(dataset) == 8
    assert next(iter(dataset))
35 |
36 |
@pytest.mark.skip
def test_ray_image_folder():
    """RayImageFolder discovers classes and yields (PIL image, label) pairs."""
    dataset = RayImageFolder(f"{data_dir}/data/test-data-cat-dog-v0/cat-dog/")

    # class discovery
    assert dataset.find_classes() == ["cat", "dog"]

    # iteration yields an image and its string label
    image, label = next(iter(dataset))
    assert isinstance(image, Image.Image)
    assert isinstance(label, str)
50 |
--------------------------------------------------------------------------------
/tests/dummies.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2021 GradsFlow. All rights reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | import torch
16 |
17 | from gradsflow.models import Model
18 |
19 |
class DummyModel(Model):
    """Minimal Model stub: constant loss/accuracy and a no-op backward pass."""

    def __init__(self):
        super().__init__(torch.nn.Linear(1, 4))

    def backward(self, loss: torch.Tensor):
        # gradients are irrelevant for these tests
        return None

    def train_step(self, batch):
        return self._fixed_result()

    def val_step(self, batch):
        return self._fixed_result()

    @staticmethod
    def _fixed_result():
        # same constant payload for both train and val steps
        return {"loss": torch.as_tensor(1), "metrics": {"accuracy": 1}}
33 |
--------------------------------------------------------------------------------
/tests/models/test_exceptions.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2021 GradsFlow. All rights reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | import pytest
15 |
16 | from gradsflow.models.exceptions import EpochCancel, FitCancel
17 |
18 |
def test_exceptions():
    """Both custom cancellation exceptions should be raisable and catchable."""
    for exc_type in (EpochCancel, FitCancel):
        with pytest.raises(exc_type):
            raise exc_type()
25 |
--------------------------------------------------------------------------------
/tests/models/test_model.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2021 GradsFlow. All rights reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | import pytest
15 | import timm
16 | import torch
17 | from torchmetrics.classification import BinaryAccuracy
18 |
19 | from gradsflow.callbacks import ModelCheckpoint
20 | from gradsflow.data import AutoDataset
21 | from gradsflow.data.image import get_fake_data
22 | from gradsflow.models.model import Model
23 | from gradsflow.models.tracker import Tracker
24 |
# Shared fake image data and AutoDataset for the whole module.
image_size = (64, 64)
train_data = get_fake_data(image_size)
val_data = get_fake_data(image_size)

num_classes = train_data.dataset.num_classes
autodataset = AutoDataset(train_data.dataloader, val_data.dataloader, num_classes=num_classes)

# Module-level compiled model, shared by tests that don't need a fresh fixture.
cnn = timm.create_model("ssl_resnet18", pretrained=False, num_classes=num_classes).eval()
model = Model(cnn)
model.compile("crossentropyloss", "adam")
model.TEST = True  # keeps training loops short during tests
36 |
37 |
def test_predict(cnn_model):
    """forward, __call__ and predict must agree on the same input."""
    x = torch.randn(1, 3, 64, 64)
    r1 = cnn_model.forward(x)
    r2 = cnn_model(x)
    r3 = cnn_model.predict(x)
    assert torch.all(torch.isclose(r1, r2))
    assert torch.all(torch.isclose(r2, r3))
    # consistency fix: use the fixture here too (previously this line referenced
    # the module-level `model`, silently bypassing the fixture under test)
    assert isinstance(cnn_model.predict(torch.randn(1, 3, 64, 64)), torch.Tensor)
46 |
47 |
def test_fit(cnn_model):
    """fit should run with and without a val loader and return a Tracker."""
    cnn_model.compile()
    assert autodataset
    result = cnn_model.fit(autodataset, max_epochs=1, steps_per_epoch=1, show_progress=True)
    assert isinstance(result, Tracker)

    # second run: train-only data, checkpoint callback, progress bar off
    train_only = AutoDataset(train_data.dataloader, num_classes=num_classes)
    cnn_model.TEST = False
    result2 = cnn_model.fit(
        train_only,
        max_epochs=1,
        steps_per_epoch=1,
        show_progress=False,
        resume=False,
        callbacks=[ModelCheckpoint(save_extra=False)],
    )
    assert isinstance(result2, Tracker)
66 |
67 |
def test_compile():
    """compile should validate metrics and accept losses/optimizers in several forms."""

    def compute_accuracy(*_, **__):
        return 1

    # a plain callable is not a valid metric
    with pytest.raises(NotImplementedError):
        Model(cnn).compile("crossentropyloss", "adam", metrics=compute_accuracy)

    # an unknown metric name is rejected
    with pytest.raises(AssertionError):
        Model(cnn).compile("crossentropyloss", "adam", metrics="random_val")

    # torchmetrics instances are accepted
    Model(cnn).compile("crossentropyloss", "adam", metrics=[BinaryAccuracy()])

    # the optimizer may be a class rather than a string
    Model(cnn).compile("crossentropyloss", torch.optim.Adam)

    # the loss may be a class too, and learning_rate is honoured
    lr_model = Model(cnn)
    lr_model.compile(torch.nn.CrossEntropyLoss, torch.optim.Adam, learning_rate=0.01)
    assert lr_model.optimizer.param_groups[0]["lr"] == 0.01
89 |
90 |
def test_set_accelerator(resnet18):
    """Compiling a half-precision model should still configure an accelerator."""
    half_model = Model(resnet18, precision=16)
    half_model.compile()
    assert half_model._accelerator
95 |
96 |
def test_save_model(tmp_path, resnet18, cnn_model):
    """save() stores a dict with extras, or the raw learner without."""
    target = f"{tmp_path}/dummy_model.pth"

    cnn_model.save(target, save_extra=True)
    assert isinstance(torch.load(target), dict)

    cnn_model.save(target, save_extra=False)
    assert isinstance(torch.load(target), type(resnet18))
104 |
105 |
def test_load_from_checkpoint(tmp_path, cnn_model):
    """A checkpoint saved with extras can be reloaded into a Model."""
    ckpt = f"{tmp_path}/dummy_model.pth"
    cnn_model.save(ckpt, save_extra=True)
    assert isinstance(torch.load(ckpt), dict)

    restored = cnn_model.load_from_checkpoint(ckpt)
    assert isinstance(restored, Model)
113 |
--------------------------------------------------------------------------------
/tests/models/test_tracker.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2021 GradsFlow. All rights reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | import pytest
15 | from rich.table import Table
16 |
17 |
def test_reset(tracker):
    """reset() zeroes tracked configuration such as max_epochs."""
    tracker.max_epochs = 5
    tracker.reset()
    assert tracker.max_epochs == 0
22 |
23 |
def test_mode(tracker):
    """Only 'train' and 'val' are valid tracker modes; anything else raises."""
    for valid in ("train", "val"):
        tracker.mode(valid)
    with pytest.raises(KeyError):
        tracker.mode("test")
29 |
30 |
def test_track(tracker):
    """_append_logs accepts arbitrary key/value pairs without raising."""
    for key, value in (("val", 0.9), ("score", 0.5)):
        tracker._append_logs(key, value)
34 |
35 |
def test_create_table(tracker):
    """After tracking losses and metrics, create_table renders a rich Table."""
    tracker.track_loss(0.1, "train")
    tracker.track_loss(0.2, "val")
    tracker.track_metrics({"accuracy": 0.9}, mode="train")
    assert isinstance(tracker.create_table(), Table)
42 |
43 |
def test_get_item(tracker):
    """__getitem__ mirrors mode() and exposes the metrics/loss dictionaries."""
    assert tracker["train"] == tracker.mode("train")
    assert isinstance(tracker["metrics"], dict)
    for split in ("train", "val"):
        assert split in tracker["loss"]
49 |
--------------------------------------------------------------------------------
/tests/models/test_utils.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2021 GradsFlow. All rights reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
26 |
27 | from gradsflow.models.utils import available_losses, available_metrics
28 |
29 |
def test_available_losses():
    """available_losses returns a list of loss-name strings."""
    losses = available_losses()
    assert isinstance(losses, list)
    assert isinstance(losses[0], str)
33 |
34 |
def test_available_metrics():
    """available_metrics returns a list of metric-name strings."""
    metrics = available_metrics()
    assert isinstance(metrics, list)
    assert isinstance(metrics[0], str)
38 |
--------------------------------------------------------------------------------
/tests/tuner/test_automodel.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2021 GradsFlow. All rights reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
import os

# Must be set before importing gradsflow so CI-specific code paths activate.
os.environ["GF_CI"] = "true"

import torch.nn
from ray import tune
from timm import create_model

from gradsflow.data import AutoDataset
from gradsflow.data.image import get_fake_data
from gradsflow.models.constants import LEARNER
from gradsflow.tuner import AutoModelV2 as AutoModel
from gradsflow.tuner import Tuner

# Shared fake image data and AutoDataset for all tuner tests.
image_size = (64, 64)
train_data = get_fake_data(image_size, num_classes=2)
val_data = get_fake_data(image_size, num_classes=2)

num_classes = train_data.dataset.num_classes
autodataset = AutoDataset(train_data.dataloader, val_data.dataloader, num_classes=num_classes)
34 |
35 |
def test_hp_tune():
    """End-to-end hyper-parameter search: a single short trial should complete."""
    tuner = Tuner()
    backbone = create_model("resnet18", pretrained=False, num_classes=num_classes)

    automodel = AutoModel(backbone, optimization_metric="val_loss")
    automodel.compile(
        loss="crossentropyloss",
        optimizer=tune.choice(("adam", "sgd")),
        learning_rate=tune.loguniform(1e-5, 1e-3),
    )

    automodel.hp_tune(
        tuner,
        autodataset,
        n_trials=1,
        epochs=1,
        cpu=0.05,
        gpu=0,
        trainer_config={"steps_per_epoch": 2},
    )
54 |
55 |
def test_get_learner():
    """_get_learner resolves a complex-object index back to a torch module."""
    tuner = Tuner()
    backbone = create_model("resnet18", pretrained=False, num_classes=num_classes)
    automodel = AutoModel(tuner.suggest_complex("learner", backbone), optimization_metric="val_loss")
    learner = automodel._get_learner({LEARNER: 0}, tuner)
    assert isinstance(learner, torch.nn.Module)
64 |
65 |
def test_compile():
    """compile accepts ray.tune search spaces for optimizer and learning rate."""
    tuner = Tuner()
    backbone = create_model("resnet18", pretrained=False, num_classes=num_classes)

    automodel = AutoModel(tuner.suggest_complex("learner", backbone), optimization_metric="val_loss")
    automodel.compile(
        loss="crossentropyloss",
        optimizer=tune.choice(("adam", "sgd")),
        learning_rate=tune.loguniform(1e-5, 1e-3),
    )
75 |
--------------------------------------------------------------------------------
/tests/tuner/test_tuner.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2021 GradsFlow. All rights reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | import pytest
15 | import ray
16 | from ray.tune.sample import Domain
17 |
18 | from gradsflow.tuner.tuner import ComplexObject, Tuner
19 |
# Shared ComplexObject mutated by the ordered tests below (test_append runs first).
complex_object = ComplexObject()
21 |
22 |
def test_append():
    """Appended values become visible via the ray-backed values list."""
    complex_object.append("test_append")
    assert "test_append" in ray.get(complex_object.values)
26 |
27 |
def test_get_complex_object():
    """Index 0 returns the value appended by the previous test (order-dependent)."""
    assert complex_object.get_complex_object(0) == "test_append"
30 |
31 |
def test_to_choice():
    """to_choice converts the stored values into a ray.tune Domain."""
    assert isinstance(complex_object.to_choice(), Domain)
34 |
35 |
def test_update_search_space():
    """Registering a ComplexObject works; plain values are rejected with a warning."""
    tuner = Tuner()
    tuner.update_search_space("test_update_search_space", complex_object)
    assert isinstance(tuner.get_complex_object("test_update_search_space", 0), str)

    with pytest.raises(UserWarning):
        tuner.update_search_space("hello", "world")
43 |
44 |
def test_union():
    """union() pulls another tuner's search space into a new tuner."""
    base = Tuner()
    other = Tuner()
    other.choice("dropout", 0.1, 0.2, 0.3)
    merged = base.union(other)
    assert merged.value.get("dropout") is not None
51 |
52 |
def test_merge():
    """Tuner.merge combines the search spaces of both tuners."""
    first = Tuner()
    first.choice("dropout", 0.1, 0.2, 0.3)
    second = Tuner()
    second.choice("layers", 1, 2, 3)
    combined = Tuner.merge(first, second)
    assert "layers" in combined.value
60 |
61 |
def test_suggest_complex():
    """suggest_complex registers the key in the tuner's search space."""
    tuner = Tuner()
    tuner.suggest_complex("test_complex", "val1", "val2")
    assert "test_complex" in tuner.value
66 |
67 |
def test_scalar():
    """scalar stores a fixed (non-searchable) value retrievable via get()."""
    tuner = Tuner()
    tuner.scalar("a", "b")
    assert tuner.get("a") == "b"
72 |
73 |
def test_get():
    """get() returns Domains for choices, complex objects by key, and raises on unknown keys."""
    tuner = Tuner()

    tuner.choice("optimizer", "val1", "val2")
    assert isinstance(tuner.get("optimizer"), Domain)

    tuner.suggest_complex("complex_opt", "val1")
    assert tuner.get("complex_opt").get_complex_object(0) == "val1"

    with pytest.raises(KeyError):
        tuner.get("random_key")
84 |
--------------------------------------------------------------------------------
/tests/utility/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2021 GradsFlow. All rights reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
--------------------------------------------------------------------------------
/tests/utility/test_common.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2021 GradsFlow. All rights reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | import numpy as np
15 | import torch
16 |
17 | from gradsflow.utility.common import (
18 | GDict,
19 | default_device,
20 | filter_list,
21 | get_file_extension,
22 | get_files,
23 | listify,
24 | module_to_cls_index,
25 | to_item,
26 | )
27 |
28 |
def test_create_module_index():
    """module_to_cls_index maps a module's classes into a dict."""
    assert isinstance(module_to_cls_index(torch.optim), dict)
31 |
32 |
def test_get_files():
    """The current directory should contain at least one file."""
    assert len(get_files("./")) > 0
35 |
36 |
def test_get_file_extension():
    """Only the final suffix is returned, even for dotted filenames."""
    assert get_file_extension("image.1.png") == "png"
39 |
40 |
def test_listify():
    """listify normalises scalars, sequences and mappings to a list."""
    cases = [
        (None, []),
        (1, [1]),
        ((1, 2), [1, 2]),
        ([1], [1]),
        ({"a": 1}, ["a"]),  # dicts listify to their keys
    ]
    for value, expected in cases:
        assert listify(value) == expected
47 |
48 |
def test_get_device():
    """default_device resolves to either CPU or CUDA."""
    assert default_device() in {"cpu", "cuda"}
51 |
52 |
def test_to_item():
    """to_item detaches tensors to numpy and recurses through containers."""
    tensor = torch.rand(1, 1, requires_grad=True)
    assert isinstance(to_item(tensor), np.ndarray)

    # container types are preserved
    assert isinstance(to_item([torch.rand(10)]), list)
    assert isinstance(to_item((torch.rand(10),)), tuple)

    converted = to_item({"input": torch.rand(10)})
    assert isinstance(converted, dict)
    assert isinstance(converted["input"], np.ndarray)
66 |
67 |
def test_filter_list():
    """filter_list keeps regex matches; no pattern means identity."""
    names = ["crossentropy", "binarycrossentropy", "softmax", "mae"]
    assert filter_list(names, ".*entropy") == names[:2]
    assert filter_list(names) == names
77 |
78 |
def test_gdict():
    """GDict behaves like a dict and converts back with to_dict()."""
    gdict: GDict[str, str] = GDict()
    gdict["hi"] = "hello"
    assert gdict["hi"] == "hello"
    key, value = next(iter(gdict.items()))
    assert key == "hi"
    assert value == "hello"
    assert gdict.to_dict() == {"hi": "hello"}
86 |
--------------------------------------------------------------------------------
/tests/utility/test_data.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2021 GradsFlow. All rights reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
import os

from gradsflow.utility import download
15 |
16 |
def test_download():
    """download() on a local path should return the file's raw bytes."""
    path = "test_download.txt"
    try:
        with open(path, "w") as fw:
            fw.write("Hello")
        assert download(path).lower() == b"hello"
    finally:
        # fix: don't leave the fixture file behind in the working directory
        os.remove(path)
21 |
--------------------------------------------------------------------------------