├── .codecov.yml ├── .github ├── CODEOWNERS ├── ISSUE_TEMPLATE │ ├── bug_report.md │ └── feature_request.md ├── PULL_REQUEST_TEMPLATE.md ├── dependabot.yml ├── stale.yml └── workflows │ ├── ci_install-pkg.yml │ └── ci_testing.yml ├── .gitignore ├── .pre-commit-config.yaml ├── LICENSE ├── MANIFEST.in ├── README.md ├── _config.yml ├── assets ├── cassava_images.jpg ├── cassava_metrics.png ├── herbarium_sample-imgs.jpg ├── herbarium_training-metrics.png ├── imet_sample-imgs.png ├── imet_training-cls-spl-100.png ├── imet_training-cls-spl-10k.png ├── plants_sample-images.jpg └── plants_training-metrics.png ├── kaggle_imgclassif ├── __init__.py ├── birdclef │ ├── __init__.py │ └── data.py ├── cassava │ ├── __init__.py │ ├── data.py │ └── models.py ├── imet_collect │ ├── __init__.py │ ├── data.py │ └── models.py └── plant_pathology │ ├── __init__.py │ ├── augment.py │ ├── data.py │ └── models.py ├── notebooks ├── Cassava-Leaf-with-Flash.ipynb ├── Cassava-Leaf-with-Lightning.ipynb ├── Herbarium-with-Flash-EfficientNet.ipynb ├── Plant-Pathology-with-Flash.ipynb ├── Plant-Pathology-with-Lightning.ipynb ├── Plant-Pathology-with-Lightning_standalone.ipynb ├── iMet-with-Lightning-and-ViT.ipynb └── iMet-with-Lightning.ipynb ├── pyproject.toml ├── scripts ├── birdclef_convert-spectrograms.py ├── herbarium_train-model.py ├── imet_create-dataset-subset.py └── plant-pathology_train-model.py ├── setup.cfg ├── setup.py ├── streamlit-app.py └── tests ├── __init__.py ├── _data ├── cassava │ ├── train.csv │ └── train_images │ │ ├── 218377.jpg │ │ ├── 6477704.jpg │ │ └── 7635457.jpg ├── imet-collect │ ├── label_map.csv │ ├── test │ │ └── test │ │ │ ├── 023c01465d76f827ca9620667f7de487.jpg │ │ │ ├── 02ca3baa47d2737b7796ae6bca32aa1d.jpg │ │ │ └── 050266ba8ff68b14fd17d4b05707ff19.jpg │ ├── train-1 │ │ └── train-1 │ │ │ ├── 09fe6ff247881b37779bcb386c26d7bb.png │ │ │ ├── 0d5b8274de10cd73836c858c101266ea.png │ │ │ ├── 11a87738861970a67249592db12f2da1.png │ │ │ ├── 12c80004e34f9102cad72c7312133529.png │ │ │ ├── 14f3fa3b620d46be00696eacda9df583.png │ │ │ ├── 1cc66a822733a3c3a1ce66fe4be60a6f.png │ │ │ └── 258e4a904729119efd85faaba80c965a.png │ └── train-from-kaggle.csv └── plant-pathology │ ├── test_images │ └── 8a0d7cad7053f18d.jpg │ ├── train.csv │ └── train_images │ ├── 800113bb65efe69e.jpg │ ├── 8002cb321f8bfcdf.jpg │ ├── 800f85dc5f407aef.jpg │ ├── 8a0be55d81f4bf0c.jpg │ ├── 8a1a97abda0b4a7a.jpg │ ├── 8a2d598f2ec436e6.jpg │ └── 8a954b82bf81f2bc.jpg ├── birdclef └── __init__.py ├── cassava ├── __init__.py ├── test_data.py └── test_models.py ├── imet_collect ├── __init__.py ├── test_data.py └── test_models.py └── plant_pathology ├── __init__.py ├── test_augment.py ├── test_data.py └── test_models.py /.codecov.yml: -------------------------------------------------------------------------------- 1 | # see https://docs.codecov.io/docs/codecov-yaml 2 | # Validation check: 3 | # $ curl --data-binary @.codecov.yml https://codecov.io/validate 4 | 5 | # https://docs.codecov.io/docs/codecovyml-reference 6 | codecov: 7 | bot: "codecov-io" 8 | strict_yaml_branch: "yaml-config" 9 | require_ci_to_pass: yes 10 | notify: 11 | # after_n_builds: 2 12 | wait_for_ci: yes 13 | 14 | coverage: 15 | precision: 0 # 2 = xx.xx%, 0 = xx% 16 | round: nearest # how coverage is rounded: down/up/nearest 17 | range: 40...100 # custom range of coverage colors from red -> yellow -> green 18 | status: 19 | project: 20 | default: 21 | # basic 22 | target: auto 23 | threshold: 0% 24 | base: auto 25 | # advanced settings 26 | if_ci_failed: error #success, 
failure, error, ignore 27 | informational: true 28 | only_pulls: true 29 | patch: off 30 | changes: false 31 | 32 | parsers: 33 | gcov: 34 | branch_detection: 35 | conditional: true 36 | loop: true 37 | macro: false 38 | method: false 39 | javascript: 40 | enable_partials: false 41 | 42 | comment: 43 | layout: header, diff 44 | require_changes: true 45 | behavior: default # update if exists else create new 46 | # branches: * 47 | -------------------------------------------------------------------------------- /.github/CODEOWNERS: -------------------------------------------------------------------------------- 1 | # Each line is a file pattern followed by one or more owners. 2 | 3 | # These owners will be the default owners for everything in the repo. 4 | # Unless a later match takes precedence, @global-owner1 and @global-owner2 5 | # will be requested for review when someone opens a pull request. 6 | 7 | * @borda 8 | 9 | # CI/CD and configs 10 | /.github/ @borda 11 | *.yml @borda 12 | require*.txt @borda 13 | 14 | # Docs 15 | /docs/ @borda 16 | /.github/*.md @borda 17 | /.github/ISSUE_TEMPLATE/ @borda 18 | 19 | /.github/CODEOWNERS @borda 20 | /README.md @borda 21 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/bug_report.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: Bug report 3 | about: Create a report to help us improve 4 | title: '' 5 | labels: bug / fix, help wanted 6 | assignees: '' 7 | --- 8 | 9 | ## 🐛 Bug 10 | 11 | 12 | 13 | ### To Reproduce 14 | 15 | Steps to reproduce the behavior: 16 | 17 | 1. Go to '...' 18 | 2. Run '....' 19 | 3. Scroll down to '....' 20 | 4. See error 21 | 22 | 23 | 24 | ### Expected behavior 25 | 26 | 27 | 28 | ### Additional context 29 | 30 | 31 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/feature_request.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: Feature request 3 | about: Suggest an idea for this project 4 | title: '' 5 | labels: enhancement, help wanted 6 | assignees: '' 7 | --- 8 | 9 | ## 🚀 Feature 10 | 11 | 12 | 13 | ### Motivation 14 | 15 | 16 | 17 | ### Pitch 18 | 19 | 20 | 21 | ### Alternatives 22 | 23 | 24 | 25 | ### Additional context 26 | 27 | 28 | -------------------------------------------------------------------------------- /.github/PULL_REQUEST_TEMPLATE.md: -------------------------------------------------------------------------------- 1 | # Before submitting 2 | 3 | - [ ] Was this discussed/approved via a Github issue? (no need for typos and docs improvements) 4 | - [ ] Did you make sure to update the docs? 5 | - [ ] Did you write any new necessary tests? 6 | 7 | ## What does this PR do? 8 | 9 | Fixes # (issue). 10 | 11 | ## PR review 12 | 13 | Anyone in the community is free to review the PR once the tests have passed. 14 | If we didn't discuss your PR in Github issues there's a high chance it will not be merged. 15 | 16 | ## Did you have fun? 
17 | 18 | Make sure you had fun coding 🙃 19 | -------------------------------------------------------------------------------- /.github/dependabot.yml: -------------------------------------------------------------------------------- 1 | # Basic dependabot.yml file with minimum configuration for two package managers 2 | 3 | version: 2 4 | updates: 5 | # # Enable version updates for python 6 | # - package-ecosystem: "pip" 7 | # # Look for a `requirements` in the `root` directory 8 | # directory: "/" 9 | # # Check for updates once a week 10 | # schedule: 11 | # interval: "monthly" 12 | # # Labels on pull requests for version updates only 13 | # labels: ["CI / tests"] 14 | # pull-request-branch-name: 15 | # # Separate sections of the branch name with a hyphen for example, `dependabot-npm_and_yarn-next_js-acorn-6.4.1` 16 | # separator: "-" 17 | # # Allow up to 5 open pull requests for pip dependencies 18 | # open-pull-requests-limit: 5 19 | # reviewers: 20 | # - "borda" 21 | 22 | # Enable version updates for GitHub Actions 23 | - package-ecosystem: "github-actions" 24 | directory: "/" 25 | # Check for updates once a week 26 | schedule: 27 | interval: "monthly" 28 | # Labels on pull requests for version updates only 29 | labels: ["CI / tests"] 30 | pull-request-branch-name: 31 | # Separate sections of the branch name with a hyphen for example, `dependabot-npm_and_yarn-next_js-acorn-6.4.1` 32 | separator: "-" 33 | # Allow up to 5 open pull requests for GitHub Actions 34 | open-pull-requests-limit: 5 35 | reviewers: 36 | - "borda" 37 | -------------------------------------------------------------------------------- /.github/stale.yml: -------------------------------------------------------------------------------- 1 | # https://github.com/marketplace/stale 2 | 3 | # Number of days of inactivity before an issue becomes stale 4 | daysUntilStale: 45 5 | # Number of days of inactivity before a stale issue is closed 6 | daysUntilClose: 7 7 | # Issues with these labels will never be considered stale 8 | exemptLabels: 9 | - pinned 10 | - security 11 | # Label to use when marking an issue as stale 12 | staleLabel: won't fix 13 | # Comment to post when marking an issue as stale. Set to `false` to disable 14 | markComment: > 15 | This issue has been automatically marked as stale because it has not had 16 | recent activity. It will be closed if no further activity occurs. Thank you 17 | for your contributions. 18 | # Comment to post when closing a stale issue. 
Set to `false` to disable 19 | closeComment: false 20 | -------------------------------------------------------------------------------- /.github/workflows/ci_install-pkg.yml: -------------------------------------------------------------------------------- 1 | name: Install package 2 | 3 | # see: https://help.github.com/en/actions/reference/events-that-trigger-workflows 4 | on: # Trigger the workflow on push or pull request, but only for the main branch 5 | push: 6 | branches: [main] 7 | pull_request: {} 8 | 9 | jobs: 10 | pkg-check: 11 | runs-on: ubuntu-22.04 12 | steps: 13 | - uses: actions/checkout@master 14 | - uses: actions/setup-python@v5 15 | with: 16 | python-version: "3.9" 17 | 18 | - name: Check package 19 | run: | 20 | pip install -U check-manifest setuptools 21 | check-manifest 22 | python setup.py check --metadata --strict 23 | 24 | - name: Create package 25 | run: | 26 | pip install --upgrade setuptools wheel 27 | python setup.py sdist bdist_wheel 28 | 29 | - name: Verify package 30 | run: | 31 | pip install -q "twine==6.1.*" 32 | twine check dist/* 33 | 34 | pkg-install: 35 | runs-on: ${{ matrix.os }} 36 | strategy: 37 | fail-fast: false 38 | matrix: 39 | os: ["ubuntu-22.04", "macOS-13", "windows-2022"] 40 | python-version: ["3.9"] # because of Kaggle 41 | steps: 42 | - uses: actions/checkout@master 43 | - uses: actions/setup-python@v5 44 | with: 45 | python-version: ${{ matrix.python-version }} 46 | 47 | - name: Create package 48 | run: | 49 | pip install -U setuptools wheel 50 | python setup.py sdist bdist_wheel 51 | 52 | - name: Try installing 53 | working-directory: dist 54 | run: | 55 | ls 56 | pip install $(python -c "import glob ; print(' '.join(glob.glob('*.whl')))") 57 | pip show kaggle-image-classification 58 | python -c "from kaggle_imgclassif import plant_pathology ; print(plant_pathology.__version__)" 59 | python -c "from kaggle_imgclassif import imet_collect ; print(imet_collect.__version__)" 60 | python -c "from kaggle_imgclassif import cassava ; print(cassava.__version__)" 61 | python -c "from kaggle_imgclassif import birdclef ; print(birdclef.__version__)" 62 | 63 | install-guardian: 64 | runs-on: ubuntu-latest 65 | needs: [pkg-install, pkg-check] 66 | if: always() 67 | steps: 68 | - run: echo "${{ needs.pkg-install.result }}" 69 | - name: failing... 70 | if: needs.pkg-install.result == 'failure' 71 | run: exit 1 72 | - name: cancelled or skipped... 
73 | if: contains(fromJSON('["cancelled", "skipped"]'), needs.pkg-install.result) 74 | timeout-minutes: 1 75 | run: sleep 90 76 | -------------------------------------------------------------------------------- /.github/workflows/ci_testing.yml: -------------------------------------------------------------------------------- 1 | name: CI complete testing 2 | 3 | # see: https://help.github.com/en/actions/reference/events-that-trigger-workflows 4 | on: # Trigger the workflow on push or pull request, but only for the main branch 5 | push: {} 6 | pull_request: 7 | branches: [main] 8 | 9 | jobs: 10 | pytester: 11 | runs-on: ${{ matrix.os }} 12 | strategy: 13 | fail-fast: false 14 | matrix: 15 | os: ["ubuntu-22.04"] # macOS-11, windows-2019 16 | python-version: ["3.8"] 17 | package: ["plant_pathology", "imet_collect", "cassava", "birdclef"] 18 | # Timeout: https://stackoverflow.com/a/59076067/4521646 19 | timeout-minutes: 25 20 | steps: 21 | - uses: actions/checkout@v4 22 | - name: Set up Python ${{ matrix.python-version }} 23 | uses: actions/setup-python@v5 24 | with: 25 | python-version: ${{ matrix.python-version }} 26 | #cache: "pip" 27 | 28 | - name: Install dependencies 29 | run: | 30 | sudo apt install -y libsndfile1 31 | pip install ".[test,${{ matrix.package }}]" --find-links https://download.pytorch.org/whl/cpu/torch_stable.html 32 | pip list 33 | shell: bash 34 | 35 | - name: Tests 36 | run: | 37 | python -m pytest kaggle_imgclassif tests/${{ matrix.package }} -v --cov=kaggle_imgclassif 38 | 39 | - name: Statistics 40 | run: | 41 | coverage report 42 | coverage xml 43 | 44 | - name: Upload coverage to Codecov 45 | uses: codecov/codecov-action@v5 46 | if: always() 47 | # see: https://github.com/actions/toolkit/issues/399 48 | continue-on-error: true 49 | with: 50 | token: ${{ secrets.CODECOV_TOKEN }} 51 | file: coverage.xml 52 | flags: cpu,pytest,python${{ matrix.python-version }},${{ matrix.package }} 53 | fail_ci_if_error: false 54 | 55 | 56 | testing-guardian: 57 | runs-on: ubuntu-latest 58 | needs: pytester 59 | if: always() 60 | steps: 61 | - run: echo "${{ needs.pytester.result }}" 62 | - name: failing... 63 | if: needs.pytester.result == 'failure' 64 | run: exit 1 65 | - name: cancelled or skipped... 66 | if: contains(fromJSON('["cancelled", "skipped"]'), needs.pytester.result) 67 | timeout-minutes: 1 68 | run: sleep 90 69 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .env 106 | .venv 107 | env/ 108 | venv/ 109 | ENV/ 110 | env.bak/ 111 | venv.bak/ 112 | 113 | # Spyder project settings 114 | .spyderproject 115 | .spyproject 116 | 117 | # Rope project settings 118 | .ropeproject 119 | 120 | # mkdocs documentation 121 | /site 122 | 123 | # mypy 124 | .mypy_cache/ 125 | .dmypy.json 126 | dmypy.json 127 | 128 | # Pyre type checker 129 | .pyre/ 130 | -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | default_language_version: 2 | python: python3 3 | 4 | ci: 5 | autofix_prs: true 6 | autoupdate_commit_msg: '[pre-commit.ci] pre-commit suggestions' 7 | autoupdate_schedule: quarterly 8 | # submodules: true 9 | 10 | repos: 11 | - repo: https://github.com/pre-commit/pre-commit-hooks 12 | rev: v5.0.0 13 | hooks: 14 | - id: end-of-file-fixer 15 | - id: trailing-whitespace 16 | - id: check-case-conflict 17 | - id: check-json 18 | - id: check-yaml 19 | - id: check-toml 20 | - id: check-added-large-files 21 | exclude: .*\.ipynb 22 | args: ['--maxkb=250', '--enforce-all'] 23 | - id: check-docstring-first 24 | - id: detect-private-key 25 | 26 | - repo: https://github.com/executablebooks/mdformat 27 | rev: 0.7.22 28 | hooks: 29 | - id: mdformat 30 | args: ['--number'] 31 | additional_dependencies: 32 | - mdformat-gfm 33 | - mdformat-black 34 | - mdformat_frontmatter 35 | 36 | - repo: https://github.com/astral-sh/ruff-pre-commit 37 | rev: v0.11.4 38 | hooks: 39 | # try to fix what is possible 40 | - id: ruff 41 | args: ["--fix"] 42 | # perform formatting updates 43 | - id: ruff-format 44 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2021 Jirka Borovec 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, 
including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /MANIFEST.in: -------------------------------------------------------------------------------- 1 | # Manifest syntax https://docs.python.org/2/distutils/sourcedist.html 2 | graft wheelhouse 3 | 4 | recursive-exclude __pycache__ *.py[cod] *.orig 5 | 6 | # Include the README and CHANGELOG 7 | include *.md 8 | recursive-include assets *.png *.jpg 9 | 10 | # Include the license file 11 | include LICENSE 12 | 13 | exclude *.sh 14 | exclude *.toml 15 | exclude *.svg 16 | exclude *-app.py 17 | 18 | # Exclude build configs 19 | exclude *.yml 20 | exclude *.yaml 21 | 22 | prune .git 23 | prune .github 24 | prune notebook* 25 | prune scripts* 26 | prune temp* 27 | prune test* 28 | prune docs* 29 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Kaggle: Image classification challenges 2 | 3 | ![CI complete testing](https://github.com/Borda/kaggle_image-classify/workflows/CI%20complete%20testing/badge.svg?branch=main&event=push) 4 | [![codecov](https://codecov.io/gh/Borda/kaggle_image-classify/branch/main/graph/badge.svg?token=5t1Aj5BIyS)](https://codecov.io/gh/Borda/kaggle_image-classify) 5 | [![pre-commit.ci status](https://results.pre-commit.ci/badge/github/Borda/kaggle_image-classify/main.svg)](https://results.pre-commit.ci/latest/github/Borda/kaggle_image-classify/main) 6 | 7 | ## Experimentation 8 | 9 | ### install this tooling 10 | 11 | A simple way to use these basic functions: 12 | 13 | ```bash 14 | ! pip install https://github.com/Borda/kaggle_image-classify/archive/main.zip 15 | ``` 16 | 17 | ## Kaggle: [Herbarium 2022](https://www.kaggle.com/competitions/herbarium-2022-fgvc9) 18 | 19 | The Herbarium 2022: Flora of North America dataset comprises 1.05 M images of 15,501 vascular plants, which constitute more than 90% of the taxa documented in North America. The provided dataset is constrained to include only vascular land plants (lycophytes, ferns, gymnosperms, and flowering plants) and it has a long-tail distribution. The number of images per taxon is as few as seven and as many as 100 images, although more images are available.
20 | 21 | ![Sample images](./assets/herbarium_sample-imgs.jpg) 22 | 23 | ### run notebooks in Kaggle 24 | 25 | - [🌿Herbarium: EDA 🔎 & baseline Flash⚡EfficientNet](https://www.kaggle.com/code/jirkaborovec/herbarium-eda-baseline-flash-efficientnet) 26 | - [🌿Herbarium: Lightning⚡Flash (inference)](https://www.kaggle.com/code/jirkaborovec/herbarium-lightning-flash-inference) 27 | 28 | ### run notebooks in Colab 29 | 30 | - [🌿Herbarium with Lit⚡Flash & EfficientNet](https://colab.research.google.com/github/Borda/kaggle_image-classify/blob/main/notebooks/Herbarium-with-Flash-EfficientNet.ipynb) 31 | 32 | ### some results 33 | 34 | Training progress with EffNet-b3 trained for 10 epochs: 35 | 36 | ![Training process](./assets/herbarium_training-metrics.png) 37 | 38 | ## Kaggle: [Plant Pathology 2021 - FGVC8](https://www.kaggle.com/c/plant-pathology-2021-fgvc8) 39 | 40 | Foliar (leaf) diseases pose a major threat to the overall productivity and quality of apple orchards. 41 | The current process for disease diagnosis in apple orchards is based on manual scouting by humans, which is time-consuming and expensive. 42 | 43 | The main objective of the competition is to develop machine learning-based models to accurately classify a given leaf image from the test dataset into a particular disease category, and to identify an individual disease from multiple disease symptoms on a single leaf image. 44 | 45 | ![Sample images](./assets/plants_sample-images.jpg) 46 | 47 | ### run notebooks in Kaggle 48 | 49 | - [Plant Pathology with Flash](https://www.kaggle.com/jirkaborovec/plant-pathology-with-pytorch-lightning-flash) 50 | - [Plant Pathology with Lightning ⚡](https://www.kaggle.com/jirkaborovec/plant-pathology-with-lightning) 51 | - [Plant Pathology with Lightning [predictions]](https://www.kaggle.com/jirkaborovec/plant-pathology-with-lightning-predictions) 52 | 53 | ### run notebooks in Colab 54 | 55 | - [Plant pathology with Lightning](https://colab.research.google.com/github/Borda/kaggle_image-classify/blob/main/notebooks/Plant-Pathology-with-Lightning.ipynb) 56 | - [Plant pathology with Lightning - StandAlone](https://colab.research.google.com/github/Borda/kaggle_image-classify/blob/main/notebooks/Plant-Pathology-with-Lightning_standalone.ipynb) (without this package) 57 | - [Plant pathology with Flash](https://colab.research.google.com/github/Borda/kaggle_image-classify/blob/main/notebooks/Plant-Pathology-with-Flash.ipynb) 58 | 59 | I would recommend uploading the dataset to your personal gDrive and then connecting the gDrive in the notebooks, which saves you a lot of time re-uploading the dataset whenever your Colab is reset... :\] 60 | 61 | ### some results 62 | 63 | Training progress with ResNet50 trained for 10 epochs, reaching over 96% validation accuracy: 64 | 65 | ![Training process](./assets/plants_training-metrics.png) 66 | 67 | ### More reading 68 | 69 | - [Practical Lightning Tips to Rank on Kaggle Image Challenges](https://devblog.pytorchlightning.ai/practical-tips-to-rank-on-kaggle-image-challenges-with-lightning-242e2e533429) 70 | 71 | ## Kaggle: [iMet Collection 2021 x AIC - FGVC8](https://www.kaggle.com/c/imet-2021-fgvc8) 72 | 73 | The online cataloguing information is generated by subject matter experts and includes a wide range of data. These include, but are not limited to: multiple object classifications, artist, title, period, date, medium, culture, size, provenance, geographic location, and other related museum objects within The Met’s collection.
74 | Adding fine-grained attributes to aid in the visual understanding of the museum objects will make it possible to search for visually related objects. 75 | 76 | ![Sample images](./assets/imet_sample-imgs.png) 77 | 78 | ### run notebooks in Kaggle 79 | 80 | - [iMet Collection with Lightning ⚡](https://www.kaggle.com/jirkaborovec/imet-with-lightning) 81 | 82 | ### run notebooks in Colab 83 | 84 | - [iMet Collection with Lightning with ResNet50](https://colab.research.google.com/github/Borda/kaggle_image-classify/blob/main/notebooks/iMet-with-Lightning.ipynb) 85 | - [iMet Collection with Lightning and VisionTransformers from TIMM](https://colab.research.google.com/github/Borda/kaggle_image-classify/blob/main/notebooks/iMet-with-Lightning-and-ViT.ipynb) 86 | 87 | I would recommend uploading the dataset to your personal gDrive and then connecting the gDrive in the notebooks, which saves you a lot of time re-uploading the dataset whenever your Colab is reset... :\] 88 | 89 | ### some results 90 | 91 | Training progress with ResNet50 trained for 35 epochs on the subset of labels with more than 100 samples: 92 | 93 | ![training on 100 samples per class](./assets/imet_training-cls-spl-100.png) 94 | 95 | ## Kaggle: [Cassava Leaf Disease Classification](https://www.kaggle.com/c/cassava-leaf-disease-classification/overview) 96 | 97 | The task is to classify each cassava image into one of five categories, indicating either a plant with a certain kind of disease or a healthy leaf. 98 | 99 | The organizers introduced a dataset of 21,367 labeled images collected during a regular survey in Uganda. Most images were crowd-sourced from farmers taking photos of their gardens, and annotated by experts at the National Crops Resources Research Institute (NaCRRI) in collaboration with the AI lab at Makerere University, Kampala. 100 | 101 | ![Sample images](./assets/cassava_images.jpg) 102 | 103 | ### run notebooks in Colab 104 | 105 | - [Cassava with Lightning](https://colab.research.google.com/github/Borda/kaggle_image-classify/blob/main/notebooks/Cassava_with_Lightning.ipynb) 106 | - [Cassava with Flash](https://colab.research.google.com/github/Borda/kaggle_image-classify/blob/main/notebooks/Cassava_with_Flash.ipynb) 107 | 108 | I would recommend uploading the dataset to your personal gDrive and then connecting the gDrive in the notebooks, which saves you a lot of time re-uploading the dataset whenever your Colab is reset...
:\] 109 | 110 | ### some results 111 | 112 | Training progress with ResNet50 with training for 10 epochs: 113 | 114 | ![Training process](./assets/cassava_metrics.png) 115 | -------------------------------------------------------------------------------- /_config.yml: -------------------------------------------------------------------------------- 1 | theme: jekyll-theme-cayman 2 | -------------------------------------------------------------------------------- /assets/cassava_images.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Borda/kaggle_image-classify/a2ce04098b04aa7a7f4c8e0e9a8005ee60e951b6/assets/cassava_images.jpg -------------------------------------------------------------------------------- /assets/cassava_metrics.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Borda/kaggle_image-classify/a2ce04098b04aa7a7f4c8e0e9a8005ee60e951b6/assets/cassava_metrics.png -------------------------------------------------------------------------------- /assets/herbarium_sample-imgs.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Borda/kaggle_image-classify/a2ce04098b04aa7a7f4c8e0e9a8005ee60e951b6/assets/herbarium_sample-imgs.jpg -------------------------------------------------------------------------------- /assets/herbarium_training-metrics.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Borda/kaggle_image-classify/a2ce04098b04aa7a7f4c8e0e9a8005ee60e951b6/assets/herbarium_training-metrics.png -------------------------------------------------------------------------------- /assets/imet_sample-imgs.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Borda/kaggle_image-classify/a2ce04098b04aa7a7f4c8e0e9a8005ee60e951b6/assets/imet_sample-imgs.png -------------------------------------------------------------------------------- /assets/imet_training-cls-spl-100.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Borda/kaggle_image-classify/a2ce04098b04aa7a7f4c8e0e9a8005ee60e951b6/assets/imet_training-cls-spl-100.png -------------------------------------------------------------------------------- /assets/imet_training-cls-spl-10k.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Borda/kaggle_image-classify/a2ce04098b04aa7a7f4c8e0e9a8005ee60e951b6/assets/imet_training-cls-spl-10k.png -------------------------------------------------------------------------------- /assets/plants_sample-images.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Borda/kaggle_image-classify/a2ce04098b04aa7a7f4c8e0e9a8005ee60e951b6/assets/plants_sample-images.jpg -------------------------------------------------------------------------------- /assets/plants_training-metrics.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Borda/kaggle_image-classify/a2ce04098b04aa7a7f4c8e0e9a8005ee60e951b6/assets/plants_training-metrics.png -------------------------------------------------------------------------------- /kaggle_imgclassif/__init__.py: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/Borda/kaggle_image-classify/a2ce04098b04aa7a7f4c8e0e9a8005ee60e951b6/kaggle_imgclassif/__init__.py -------------------------------------------------------------------------------- /kaggle_imgclassif/birdclef/__init__.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | __version__ = "0.1.0dev" 4 | __docs__ = "Tooling for Kaggle BirdCLEF" 5 | __author__ = "Jiri Borovec" 6 | __author_email__ = "jirka@pytorchlightning.ai" 7 | 8 | _PATH_PACKAGE = os.path.realpath(os.path.dirname(__file__)) 9 | _PATH_PROJECT = os.path.dirname(_PATH_PACKAGE) 10 | -------------------------------------------------------------------------------- /kaggle_imgclassif/birdclef/data.py: -------------------------------------------------------------------------------- 1 | import os 2 | from math import ceil 3 | 4 | import matplotlib.pyplot as plt 5 | import numpy as np 6 | from PIL import Image 7 | 8 | try: 9 | import librosa 10 | import noisereduce 11 | except ImportError: 12 | noisereduce, librosa = None, None 13 | 14 | SPECTROGRAM_PARAMS = dict( 15 | sample_rate=32_000, hop_length=640, n_fft=800, n_mels=128, fmin=20, fmax=16_000, win_length=512 16 | ) 17 | PCEN_PARAS = dict( 18 | time_constant=0.06, 19 | eps=1e-6, 20 | gain=0.8, 21 | power=0.25, 22 | bias=10, 23 | ) 24 | SPECTROGRAM_RANGE = (-80, 0) 25 | 26 | 27 | def create_spectrogram( 28 | fname: str, 29 | reduce_noise: bool = False, 30 | frame_size: int = 5, 31 | frame_step: int = 2, 32 | spec_params: dict = SPECTROGRAM_PARAMS, 33 | ) -> list: 34 | waveform, sample_rate = librosa.core.load(fname, sr=spec_params["sample_rate"], mono=True) 35 | if reduce_noise: 36 | waveform = noisereduce.reduce_noise( 37 | y=waveform, 38 | sr=sample_rate, 39 | time_constant_s=float(frame_size), 40 | time_mask_smooth_ms=250, 41 | n_fft=spec_params["n_fft"], 42 | use_tqdm=False, 43 | n_jobs=2, 44 | ) 45 | 46 | frames = cut_frames(waveform, sample_rate, frame_size, frame_step) 47 | spectrograms = [] 48 | for frm in frames: 49 | sg = librosa.feature.melspectrogram( 50 | y=frm, 51 | sr=sample_rate, 52 | n_fft=spec_params["n_fft"], 53 | win_length=spec_params["win_length"], 54 | hop_length=spec_params["hop_length"], 55 | n_mels=spec_params["n_mels"], 56 | fmin=spec_params["fmin"], 57 | fmax=spec_params["fmax"], 58 | power=1, 59 | ) 60 | # sg = librosa.pcen(sg, sr=sample_rate, hop_length=spec_params["hop_length"], **PCEN_PARAS) 61 | sg = librosa.amplitude_to_db(sg, ref=np.max) 62 | spectrograms.append(np.nan_to_num(sg)) 63 | return spectrograms 64 | 65 | 66 | def cut_frames( 67 | waveform, 68 | sample_rate: int, 69 | frame_size: int = 5, 70 | frame_step: int = 2, 71 | min_frame_fraction: float = 0.2, 72 | ): 73 | step = int(frame_step * sample_rate) 74 | size = int(frame_size * sample_rate) 75 | count = ceil((len(waveform) - size) / float(step)) 76 | frames = [] 77 | for i in range(max(1, count)): 78 | begin = i * step 79 | end = begin + size 80 | frame = waveform[begin:end] 81 | if len(frame) < size: 82 | if i == 0: 83 | rep = round(float(size) / len(frame)) 84 | frame = frame.repeat(int(rep)) 85 | elif len(frame) < (size * min_frame_fraction): 86 | continue 87 | else: 88 | frame = waveform[-size:] 89 | frames.append(frame) 90 | return frames 91 | 92 | 93 | def convert_and_export( 94 | fn: str, 95 | path_in: str, 96 | path_out: str, 97 | reduce_noise: bool = False, 98 | frame_size: int = 5, 99 | frame_step: int = 2, 100 | 
img_extension: str = ".png", 101 | img_size: int = 512, 102 | ) -> None: 103 | path_audio = os.path.join(path_in, fn) 104 | try: 105 | sgs = create_spectrogram( 106 | path_audio, 107 | reduce_noise=reduce_noise, 108 | frame_size=frame_size, 109 | frame_step=frame_step, 110 | ) 111 | except Exception as ex: 112 | print(f"Failed conversion for audio: {path_audio}\n with {ex}") 113 | return 114 | if not sgs: 115 | print(f"Too short audio for: {path_audio} with ") 116 | return 117 | path_npz = os.path.join(path_out, fn + ".npz") 118 | os.makedirs(os.path.dirname(path_npz), exist_ok=True) 119 | np.savez_compressed(path_npz, np.array(sgs, dtype=np.float16)) 120 | for i, sg in enumerate(sgs): 121 | path_img = os.path.join(path_out, fn + f".{i:03}" + img_extension) 122 | try: 123 | if img_extension == ".png": 124 | sg = (sg - SPECTROGRAM_RANGE[0]) / float(SPECTROGRAM_RANGE[1] - SPECTROGRAM_RANGE[0]) 125 | sg = np.clip(sg, a_min=0, a_max=1) * 255 126 | img = Image.fromarray(sg.astype(np.uint8)) 127 | if img_size: 128 | img = img.resize((img_size, img_size)) 129 | img.save(path_img) 130 | else: 131 | plt.imsave(path_img, sg, vmin=SPECTROGRAM_RANGE[0], vmax=SPECTROGRAM_RANGE[1]) 132 | except Exception as ex: 133 | print(f"Failed exporting for image: {path_img}\n with {ex}") 134 | continue 135 | -------------------------------------------------------------------------------- /kaggle_imgclassif/cassava/__init__.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | __version__ = "0.1.1" 4 | __docs__ = "Tooling for Kaggle Cassava Leaf Disease Classification" 5 | __author__ = "Jiri Borovec" 6 | __author_email__ = "jirka@pytorchlightning.ai" 7 | 8 | _PATH_PACKAGE = os.path.realpath(os.path.dirname(__file__)) 9 | _PATH_PROJECT = os.path.dirname(_PATH_PACKAGE) 10 | -------------------------------------------------------------------------------- /kaggle_imgclassif/cassava/data.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import multiprocessing as mproc 3 | import os 4 | from math import ceil 5 | 6 | import matplotlib.pylab as plt 7 | import pandas as pd 8 | from PIL import Image 9 | from pytorch_lightning import LightningDataModule 10 | from torch.utils.data import DataLoader, Dataset 11 | from torchvision import transforms as T 12 | 13 | TRAIN_TRANSFORM = T.Compose([ 14 | T.Resize(512), 15 | T.RandomPerspective(), 16 | T.RandomResizedCrop(224), 17 | T.RandomHorizontalFlip(), 18 | T.RandomVerticalFlip(), 19 | T.ToTensor(), 20 | # T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]), 21 | T.Normalize([0.431, 0.498, 0.313], [0.237, 0.239, 0.227]), 22 | ]) 23 | 24 | VALID_TRANSFORM = T.Compose([ 25 | T.Resize(256), 26 | T.CenterCrop(224), 27 | T.ToTensor(), 28 | # T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]), 29 | T.Normalize([0.431, 0.498, 0.313], [0.237, 0.239, 0.227]), 30 | ]) 31 | 32 | 33 | class CassavaDataset(Dataset): 34 | def __init__( 35 | self, 36 | path_csv: str = "/content/train.csv", 37 | path_img_dir: str = "/content/train_images/", 38 | transforms=None, 39 | mode: str = "train", 40 | split: float = 0.8, 41 | ): 42 | self.path_img_dir = path_img_dir 43 | self.transforms = transforms 44 | self.mode = mode 45 | 46 | self.data = pd.read_csv(path_csv) 47 | # shuffle data 48 | self.data = self.data.sample(frac=1, random_state=42).reset_index(drop=True) 49 | 50 | # split dataset 51 | assert 0.0 <= split <= 1.0 52 | frac = int(ceil(split * len(self.data))) 53 | self.data = 
self.data[:frac] if mode == "train" else self.data[frac:] 54 | self.img_names = list(self.data["image_id"]) 55 | self.labels = list(self.data["label"]) 56 | 57 | def __getitem__(self, idx: int) -> tuple: 58 | img_path = os.path.join(self.path_img_dir, self.img_names[idx]) 59 | assert os.path.isfile(img_path) 60 | label = self.labels[idx] 61 | img = plt.imread(img_path) 62 | 63 | # augmentation 64 | if self.transforms: 65 | img = self.transforms(Image.fromarray(img)) 66 | return img, label 67 | 68 | def __len__(self) -> int: 69 | return len(self.data) 70 | 71 | 72 | class CassavaDataModule(LightningDataModule): 73 | def __init__( 74 | self, 75 | path_csv: str = "/content/train.csv", 76 | path_img_dir: str = "/content/train_images/", 77 | train_augment=TRAIN_TRANSFORM, 78 | valid_augment=VALID_TRANSFORM, 79 | batch_size: int = 128, 80 | split: float = 0.8, 81 | ): 82 | super().__init__() 83 | self.path_csv = path_csv 84 | self.path_img_dir = path_img_dir 85 | self.train_augment = train_augment 86 | self.valid_augment = valid_augment 87 | self.batch_size = batch_size 88 | self.split = split 89 | 90 | def prepare_data(self): 91 | pass 92 | 93 | def setup(self, stage=None): 94 | self.train_dataset = CassavaDataset( 95 | self.path_csv, 96 | self.path_img_dir, 97 | split=self.split, 98 | mode="train", 99 | transforms=self.train_augment, 100 | ) 101 | logging.info(f"training dataset: {len(self.train_dataset)}") 102 | self.valid_dataset = CassavaDataset( 103 | self.path_csv, 104 | self.path_img_dir, 105 | split=self.split, 106 | mode="valid", 107 | transforms=self.valid_augment, 108 | ) 109 | logging.info(f"validation dataset: {len(self.valid_dataset)}") 110 | 111 | def train_dataloader(self): 112 | return DataLoader( 113 | self.train_dataset, 114 | batch_size=self.batch_size, 115 | num_workers=mproc.cpu_count(), 116 | shuffle=True, 117 | ) 118 | 119 | def val_dataloader(self): 120 | return DataLoader( 121 | self.valid_dataset, 122 | batch_size=self.batch_size, 123 | num_workers=mproc.cpu_count(), 124 | shuffle=False, 125 | ) 126 | 127 | def test_dataloader(self): 128 | pass 129 | -------------------------------------------------------------------------------- /kaggle_imgclassif/cassava/models.py: -------------------------------------------------------------------------------- 1 | from typing import Union 2 | 3 | import timm 4 | import torch 5 | from pytorch_lightning import LightningModule 6 | from torch import nn 7 | from torch.nn import functional as F 8 | from torchmetrics import Accuracy, F1Score 9 | 10 | 11 | class LitCassava(LightningModule): 12 | """Basic Cassava model. 
13 | 14 | >>> model = LitCassava("resnet18") 15 | """ 16 | 17 | def __init__(self, model: Union[str, nn.Module], num_classes: int = 5, lr: float = 1e-4): 18 | super().__init__() 19 | if isinstance(model, str): 20 | self.model = timm.create_model(model, pretrained=True, num_classes=num_classes) 21 | else: 22 | self.model = model 23 | self.accuracy = Accuracy() 24 | self.f1_score = F1Score(num_classes) 25 | self.learn_rate = lr 26 | self.loss_fn = F.cross_entropy 27 | 28 | def forward(self, x): 29 | return F.softmax(self.model(x)) 30 | 31 | def training_step(self, batch, batch_idx): 32 | x, y = batch 33 | y_hat = self(x) 34 | loss = self.loss_fn(y_hat, y) 35 | self.log("train_loss", loss, prog_bar=True) 36 | return loss 37 | 38 | def validation_step(self, batch, batch_idx): 39 | x, y = batch 40 | y_hat = self(x) 41 | loss = self.loss_fn(y_hat, y) 42 | self.log("valid_loss", loss, prog_bar=False) 43 | self.log("valid_acc", self.accuracy(y_hat, y), prog_bar=True) 44 | self.log("valid_f1", self.f1_score(y_hat, y), prog_bar=True) 45 | 46 | def configure_optimizers(self): 47 | optimizer = torch.optim.AdamW(self.model.parameters(), lr=self.learn_rate) 48 | scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, self.trainer.max_epochs, 0) 49 | return [optimizer], [scheduler] 50 | -------------------------------------------------------------------------------- /kaggle_imgclassif/imet_collect/__init__.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | __version__ = "0.2.1" 4 | __docs__ = "Tooling for Kaggle ..." 5 | __author__ = "Jiri Borovec" 6 | __author_email__ = "jirka@pytorchlightning.ai" 7 | 8 | _PATH_PACKAGE = os.path.realpath(os.path.dirname(__file__)) 9 | _PATH_PROJECT = os.path.dirname(_PATH_PACKAGE) 10 | -------------------------------------------------------------------------------- /kaggle_imgclassif/imet_collect/data.py: -------------------------------------------------------------------------------- 1 | import glob 2 | import itertools 3 | import logging 4 | import multiprocessing as mproc 5 | import os 6 | from typing import Dict, List, Optional, Sequence, Tuple, Union 7 | 8 | import numpy as np 9 | import pandas as pd 10 | import torch 11 | import tqdm 12 | from joblib import Parallel, delayed 13 | from PIL import Image 14 | from pytorch_lightning import LightningDataModule 15 | from torch import Tensor 16 | from torch.utils.data import DataLoader, Dataset 17 | from torchvision import transforms as T 18 | 19 | try: 20 | import cv2 21 | except ImportError: 22 | cv2 = None 23 | 24 | # ImageFile.LOAD_TRUNCATED_IMAGES = True 25 | 26 | #: default training augmentation 27 | TORCHVISION_TRAIN_TRANSFORM = T.Compose([ 28 | T.Resize(size=256, interpolation=Image.BILINEAR), 29 | T.RandomRotation(degrees=25), 30 | T.RandomPerspective(distortion_scale=0.2), 31 | T.RandomResizedCrop(size=224), 32 | # T.RandomHorizontalFlip(p=0.5), 33 | T.RandomVerticalFlip(p=0.5), 34 | # T.ColorJitter(brightness=0.05, contrast=0.05, saturation=0.05, hue=0.05), 35 | T.ToTensor(), 36 | T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]), 37 | ]) 38 | #: default validation augmentation 39 | TORCHVISION_VALID_TRANSFORM = T.Compose([ 40 | T.Resize(size=256, interpolation=Image.BILINEAR), 41 | T.CenterCrop(size=224), 42 | T.ToTensor(), 43 | T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]), 44 | ]) 45 | 46 | 47 | def load_image(path_img: str) -> Image.Image: 48 | try: 49 | return Image.open(path_img) 50 | except AttributeError: 51 | img = 
cv2.imread(path_img) 52 | if img.ndim == 3 and img.shape[-1] == 3: 53 | img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) 54 | return Image.fromarray(img) 55 | 56 | 57 | def get_nb_pixels(img_path: str): 58 | try: 59 | img = load_image(img_path) 60 | return np.prod(img.size) 61 | except Exception: 62 | return 0 63 | 64 | 65 | class IMetDataset(Dataset): 66 | """The ful dataset with one-hot encoding for multi-label case.""" 67 | 68 | IMAGE_SIZE_LIMIT = 1000 69 | COL_LABELS = "attribute_ids" 70 | COL_IMAGES = "id" 71 | 72 | def __init__( 73 | self, 74 | df_data: Union[str, pd.DataFrame] = "train-from-kaggle.csv", 75 | path_img_dir: str = "train-1/train-1", 76 | transforms=None, 77 | mode: str = "train", 78 | split: float = 0.8, 79 | uq_labels: Tuple[str] = None, 80 | random_state: Optional[int] = None, 81 | check_imgs: bool = True, 82 | ): 83 | self.path_img_dir = path_img_dir 84 | self.transforms = transforms 85 | self.mode = mode 86 | self._img_names = None 87 | self._raw_labels = None 88 | 89 | # set or load the config table 90 | if isinstance(df_data, pd.DataFrame): 91 | self.data = df_data 92 | elif isinstance(df_data, str): 93 | assert os.path.isfile(df_data), f"missing file: {df_data}" 94 | self.data = pd.read_csv(df_data) 95 | else: 96 | raise ValueError(f"unrecognised input for DataFrame/CSV: {df_data}") 97 | 98 | # take over existing table or load from file 99 | if uq_labels: 100 | self.labels_unique = tuple(uq_labels) 101 | else: 102 | labels_all = list(itertools.chain(*[lbs.split(" ") for lbs in self.raw_labels])) 103 | # labels_all = [int(lb) for lb in labels_all] 104 | self.labels_unique = tuple(sorted(set(labels_all))) 105 | self.labels_lut = {lb: i for i, lb in enumerate(self.labels_unique)} 106 | self.num_classes = len(self.labels_unique) 107 | 108 | # filter/drop too small images 109 | if check_imgs: 110 | with Parallel(n_jobs=mproc.cpu_count()) as parallel: 111 | self.data["pixels"] = parallel( 112 | delayed(get_nb_pixels)(os.path.join(self.path_img_dir, im)) for im in self.img_names 113 | ) 114 | nb_small_imgs = sum(self.data["pixels"] < self.IMAGE_SIZE_LIMIT) 115 | if nb_small_imgs: 116 | logging.warning(f"found and dropped {nb_small_imgs} too small or invalid images :/") 117 | self.data = self.data[self.data["pixels"] >= self.IMAGE_SIZE_LIMIT] 118 | # shuffle data 119 | if random_state is not None: 120 | self.data = self.data.sample(frac=1, random_state=random_state).reset_index(drop=True) 121 | 122 | # split dataset 123 | assert 0.0 <= split <= 1.0, f"split {split} is out of range" 124 | frac = int(split * len(self.data)) 125 | self.data = self.data[:frac] if mode == "train" else self.data[frac:] 126 | # need to reset after another split since it cached 127 | self._img_names = None 128 | self._raw_labels = None 129 | self.labels = self._prepare_labels() 130 | 131 | @property 132 | def img_names(self): 133 | if not self._img_names: 134 | self._img_names = [f"{n}.png" if "." 
not in n else n for n in self.data[self.COL_IMAGES]] 135 | return self._img_names 136 | 137 | @property 138 | def raw_labels(self): 139 | if not self._raw_labels: 140 | self._raw_labels = list(self.data[self.COL_LABELS]) 141 | return self._raw_labels 142 | 143 | def _prepare_labels(self) -> list: 144 | return [torch.tensor(self.to_binary_encoding(lb)) if lb else None for lb in self.raw_labels] 145 | 146 | def to_binary_encoding(self, labels: str) -> tuple: 147 | # processed with encoding 148 | encode = [0] * len(self.labels_unique) 149 | for lb in labels.split(" "): 150 | encode[self.labels_lut[lb]] = 1 151 | return tuple(encode) 152 | 153 | def __getitem__(self, idx: int) -> tuple: 154 | img_name = self.img_names[idx] 155 | img_path = os.path.join(self.path_img_dir, img_name) 156 | assert os.path.isfile(img_path) 157 | label = self.labels[idx] 158 | # todo: find some faster way, do conversion only if needed; im.mode not in ("L", "RGB") 159 | img = load_image(img_path).convert("RGB") 160 | 161 | # augmentation 162 | if self.transforms: 163 | img = self.transforms(img) 164 | 165 | # in case of predictions, return image name as label 166 | label = label if label is not None else img_name 167 | return img, label 168 | 169 | def __len__(self) -> int: 170 | return len(self.data) 171 | 172 | 173 | class IMetDM(LightningDataModule): 174 | IMAGE_EXTENSIONS = (".png", ".jpg", ".jpeg") 175 | 176 | def __init__( 177 | self, 178 | base_path: str, 179 | path_csv: str = "train-from-kaggle.csv", 180 | batch_size: int = 128, 181 | num_workers: int = None, 182 | train_transforms=TORCHVISION_TRAIN_TRANSFORM, 183 | valid_transforms=TORCHVISION_VALID_TRANSFORM, 184 | split: float = 0.8, 185 | random_state: Optional[int] = None, 186 | ): 187 | super().__init__() 188 | # path configurations 189 | assert os.path.isdir(base_path), f"missing folder: {base_path}" 190 | self.train_dir = os.path.join(base_path, "train-1/train-1") 191 | self.test_dir = os.path.join(base_path, "test/test") 192 | 193 | if not os.path.isfile(path_csv): 194 | path_csv = os.path.join(base_path, path_csv) 195 | assert os.path.isfile(path_csv), f"missing table: {path_csv}" 196 | self.path_csv = path_csv 197 | 198 | self.train_transforms = train_transforms 199 | self.valid_transforms = valid_transforms 200 | 201 | # other configs 202 | self.batch_size = batch_size 203 | self.split = split 204 | self.random_state = random_state 205 | self.num_workers = num_workers if num_workers is not None else mproc.cpu_count() 206 | self.labels_unique: Sequence = ... 207 | self.lut_label: Dict = ... 208 | self.label_histogram: Tensor = ... 209 | 210 | # need to be filled in setup() 211 | self.train_dataset = None 212 | self.valid_dataset = None 213 | self.test_table = [] 214 | self.test_dataset = None 215 | 216 | def prepare_data(self): 217 | pass 218 | 219 | @property 220 | def num_classes(self) -> int: 221 | return len(self.labels_unique) 222 | 223 | @staticmethod 224 | def binary_mapping( 225 | encoding: Tensor, 226 | lut_label: Dict[int, str], 227 | thr: float = 0.5, 228 | label_required: bool = True, 229 | ) -> Union[str, List[str]]: 230 | """Convert Model outputs to string labels. 
231 | 232 | Args: 233 | encoding: one-hot encoding 234 | lut_label: look-up-table with labels 235 | thr: threshold for label binarization 236 | label_required: if it is required to return any label and no label is above `thr`, use argmax 237 | """ 238 | assert lut_label 239 | # on case it is not one hot encoding but single label 240 | if encoding.nelement() == 1: 241 | return lut_label[encoding[0]] 242 | labels = [lut_label[i] for i, s in enumerate(encoding) if s >= thr] 243 | # in case no reached threshold then take max 244 | if not labels and label_required: 245 | idx = torch.argmax(encoding).item() 246 | labels = [lut_label[idx]] 247 | return sorted(labels) 248 | 249 | def binary_encoding_to_labels( 250 | self, 251 | encoding: Tensor, 252 | thr: float = 0.5, 253 | with_sigm: bool = True, 254 | label_required: bool = True, 255 | ) -> Union[str, List[str]]: 256 | """Convert Model outputs to string labels. 257 | 258 | Args: 259 | encoding: one-hot encoding 260 | thr: threshold for label binarization 261 | with_sigm: apply sigmoid to convert to probabilities 262 | label_required: if it is required to return any label and no label is above `thr`, use argmax 263 | """ 264 | if with_sigm: 265 | encoding = torch.sigmoid(encoding) 266 | return self.binary_mapping(encoding, self.lut_label, thr=thr, label_required=label_required) 267 | 268 | def setup(self, *_, **__) -> None: 269 | """Prepare datasets.""" 270 | pbar = tqdm.tqdm(total=4) 271 | assert os.path.isdir(self.train_dir), f"missing folder: {self.train_dir}" 272 | ds = IMetDataset(self.path_csv, self.train_dir, mode="train", split=1.0) 273 | self.labels_unique = ds.labels_unique 274 | self.lut_label = dict(enumerate(self.labels_unique)) 275 | pbar.update() 276 | 277 | ds_defaults = dict( 278 | df_data=ds.data, 279 | path_img_dir=self.train_dir, 280 | split=self.split, 281 | uq_labels=self.labels_unique, 282 | check_imgs=False, 283 | random_state=self.random_state, 284 | ) 285 | self.train_dataset = IMetDataset(**ds_defaults, mode="train", transforms=self.train_transforms) 286 | logging.info(f"training dataset: {len(self.train_dataset)}") 287 | pbar.update() 288 | self.valid_dataset = IMetDataset(**ds_defaults, mode="valid", transforms=self.valid_transforms) 289 | logging.info(f"validation dataset: {len(self.valid_dataset)}") 290 | pbar.update() 291 | 292 | if not os.path.isdir(self.test_dir): 293 | return 294 | ls_images = glob.glob(os.path.join(self.test_dir, "*.*")) 295 | ls_images = [os.path.basename(p) for p in ls_images if os.path.splitext(p)[-1] in self.IMAGE_EXTENSIONS] 296 | self.test_table = [{"id": n, "attribute_ids": ""} for n in ls_images] 297 | self.test_dataset = IMetDataset( 298 | df_data=pd.DataFrame(self.test_table), 299 | path_img_dir=self.test_dir, 300 | split=0, 301 | uq_labels=self.labels_unique, 302 | mode="test", 303 | transforms=self.valid_transforms, 304 | ) 305 | logging.info(f"test dataset: {len(self.test_dataset)}") 306 | pbar.update() 307 | 308 | def train_dataloader(self) -> DataLoader: 309 | return DataLoader( 310 | self.train_dataset, 311 | batch_size=self.batch_size, 312 | num_workers=self.num_workers, 313 | shuffle=True, 314 | ) 315 | 316 | def val_dataloader(self) -> DataLoader: 317 | return DataLoader( 318 | self.valid_dataset, 319 | batch_size=self.batch_size, 320 | num_workers=self.num_workers, 321 | shuffle=False, 322 | ) 323 | 324 | def test_dataloader(self) -> Optional[DataLoader]: 325 | if self.test_dataset: 326 | return DataLoader( 327 | self.test_dataset, 328 | batch_size=self.batch_size, 329 | 
num_workers=self.num_workers, 330 | shuffle=False, 331 | ) 332 | logging.warning("no testing images found") 333 | return None 334 | -------------------------------------------------------------------------------- /kaggle_imgclassif/imet_collect/models.py: -------------------------------------------------------------------------------- 1 | from typing import Optional, Union 2 | 3 | import timm 4 | import torch 5 | from pytorch_lightning import LightningModule 6 | from torch import Tensor, nn 7 | from torch.nn import functional as F 8 | from torchmetrics import Accuracy, F1Score, Precision 9 | 10 | 11 | class LitMet(LightningModule): 12 | """This model is meant and tested to be used together with ... 13 | 14 | >>> model = LitMet("resnet18", num_classes=125) 15 | """ 16 | 17 | def __init__( 18 | self, 19 | model: Union[str, nn.Module], 20 | num_classes: int, 21 | lr: float = 1e-4, 22 | augmentations: Optional[nn.Module] = None, 23 | ): 24 | super().__init__() 25 | if isinstance(model, str): 26 | self.name = model 27 | self.model = timm.create_model(model, pretrained=True, num_classes=num_classes) 28 | else: 29 | self.model = model 30 | self.name = model.__class__.__name__ 31 | self.num_classes = num_classes 32 | self.train_accuracy = Accuracy() 33 | _metrics_extra_args = dict(num_classes=self.num_classes, average="weighted") 34 | self.train_precision = Precision(**_metrics_extra_args) 35 | self.train_f1_score = F1Score(**_metrics_extra_args) 36 | self.val_accuracy = Accuracy() 37 | self.val_precision = Precision(**_metrics_extra_args) 38 | self.val_f1_score = F1Score(**_metrics_extra_args) 39 | self.learning_rate = lr 40 | self.aug = augmentations 41 | 42 | def forward(self, x: Tensor) -> Tensor: 43 | return self.model(x) 44 | 45 | def compute_loss(self, y_hat: Tensor, y: Tensor): 46 | return F.binary_cross_entropy_with_logits(y_hat, y.to(y_hat.dtype)) 47 | 48 | def training_step(self, batch, batch_idx): 49 | x, y = batch 50 | if self.aug: 51 | x = self.aug(x) # => batched augmentations 52 | y_hat = self(x) 53 | loss = self.compute_loss(y_hat, y) 54 | y_prob = torch.sigmoid(y_hat) 55 | self.log("train_loss", loss, prog_bar=False) 56 | self.log("train_acc", self.train_accuracy(y_prob, y), prog_bar=False) 57 | self.log("train_prec", self.train_precision(y_prob, y), prog_bar=False) 58 | self.log("train_f1", self.train_f1_score(y_prob, y), prog_bar=True) 59 | return loss 60 | 61 | def validation_step(self, batch, batch_idx): 62 | x, y = batch 63 | y_hat = self(x) 64 | loss = self.compute_loss(y_hat, y) 65 | y_prob = torch.sigmoid(y_hat) 66 | self.log("valid_loss", loss, prog_bar=False) 67 | self.log("valid_acc", self.val_accuracy(y_prob, y), prog_bar=True) 68 | self.log("valid_prec", self.val_precision(y_prob, y), prog_bar=True) 69 | self.log("valid_f1", self.val_f1_score(y_prob, y), prog_bar=True) 70 | 71 | def configure_optimizers(self): 72 | optimizer = torch.optim.AdamW(self.model.parameters(), lr=self.learning_rate) 73 | scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, self.trainer.max_epochs, 0) 74 | return [optimizer], [scheduler] 75 | -------------------------------------------------------------------------------- /kaggle_imgclassif/plant_pathology/__init__.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | __version__ = "0.5.0" 4 | __docs__ = "Tooling for Kaggle Plant Pathology" 5 | __author__ = "Jiri Borovec" 6 | __author_email__ = "jirka@pytorchlightning.ai" 7 | 8 | _PATH_PACKAGE = 
os.path.realpath(os.path.dirname(__file__)) 9 | _PATH_PROJECT = os.path.dirname(_PATH_PACKAGE) 10 | 11 | #: computed color mean from given dataset 12 | DATASET_IMAGE_MEAN = (0.48690377, 0.62658835, 0.4078062) 13 | #: computed color STD from given dataset 14 | DATASET_IMAGE_STD = (0.18142496, 0.15883319, 0.19026241) 15 | -------------------------------------------------------------------------------- /kaggle_imgclassif/plant_pathology/augment.py: -------------------------------------------------------------------------------- 1 | """Module to perform efficient preprocess and data augmentation.""" 2 | 3 | from typing import Tuple 4 | 5 | import numpy as np 6 | import torch 7 | import torch.nn as nn 8 | 9 | # Define the augmentations pipeline 10 | from PIL import Image 11 | from torch import Tensor 12 | from torchvision import transforms as T 13 | 14 | try: 15 | from kornia import augmentation, geometry, image_to_tensor 16 | except ImportError: 17 | augmentation, geometry, image_to_tensor = None, None, None 18 | 19 | from kaggle_imgclassif.plant_pathology import DATASET_IMAGE_MEAN, DATASET_IMAGE_STD 20 | 21 | #: default training augmentation 22 | TORCHVISION_TRAIN_TRANSFORM = T.Compose([ 23 | T.Resize(size=512, interpolation=Image.BILINEAR), 24 | T.RandomRotation(degrees=30), 25 | T.RandomPerspective(distortion_scale=0.4), 26 | T.RandomResizedCrop(size=224), 27 | T.RandomHorizontalFlip(p=0.5), 28 | T.RandomVerticalFlip(p=0.5), 29 | # T.ColorJitter(brightness=0.05, contrast=0.05, saturation=0.05, hue=0.05), 30 | T.ToTensor(), 31 | # T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]), 32 | T.Normalize(DATASET_IMAGE_MEAN, DATASET_IMAGE_STD), # custom 33 | ]) 34 | #: default validation augmentation 35 | TORCHVISION_VALID_TRANSFORM = T.Compose([ 36 | T.Resize(size=256, interpolation=Image.BILINEAR), 37 | T.CenterCrop(size=224), 38 | T.ToTensor(), 39 | # T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]), 40 | T.Normalize(DATASET_IMAGE_MEAN, DATASET_IMAGE_STD), # custom 41 | ]) 42 | 43 | 44 | class Resize(nn.Module): 45 | def __init__(self, size: int): 46 | super().__init__() 47 | self.size = size 48 | 49 | def forward(self, x): 50 | return geometry.resize(x[None], self.size)[0] 51 | 52 | 53 | class LitPreprocess(nn.Module): 54 | """Applies the processing to the image in the worker before collate.""" 55 | 56 | def __init__(self, img_size: Tuple[int, int]): 57 | super().__init__() 58 | if isinstance(img_size, int): 59 | img_size = (img_size, img_size) 60 | self.preprocess = nn.Sequential( 61 | # K.augmentation.RandomResizedCrop((224, 224)), 62 | Resize(img_size), # use this better to see whole image 63 | augmentation.Normalize(Tensor(DATASET_IMAGE_MEAN), Tensor(DATASET_IMAGE_STD)), 64 | ) 65 | 66 | @torch.no_grad() 67 | def forward(self, x: Tensor) -> Tensor: 68 | x = image_to_tensor(np.array(x)).float() / 255.0 69 | assert len(x.shape) == 3, x.shape 70 | out = self.preprocess(x) 71 | return out[0] 72 | 73 | 74 | class LitAugmenter(nn.Module): 75 | """Applies random augmentation to a batch of images.""" 76 | 77 | def __init__(self, viz: bool = False): 78 | super().__init__() 79 | self.viz = viz 80 | self.augmentations = nn.Sequential( 81 | augmentation.RandomRotation(degrees=30.0), 82 | augmentation.RandomPerspective(distortion_scale=0.4), 83 | augmentation.RandomResizedCrop((224, 224)), 84 | augmentation.RandomHorizontalFlip(p=0.5), 85 | augmentation.RandomVerticalFlip(p=0.5), 86 | # K.augmentation.GaussianBlur((3, 3), (0.1, 2.0), p=1.0), 87 | # K.augmentation.ColorJitter(0.01, 0.01, 
0.01, 0.01, p=0.25), 88 | ) 89 | self.denorm = augmentation.Denormalize(Tensor(DATASET_IMAGE_MEAN), Tensor(DATASET_IMAGE_STD)) 90 | 91 | @torch.no_grad() 92 | def forward(self, x: Tensor) -> Tensor: 93 | assert len(x.shape) == 4, x.shape 94 | out = x 95 | # idx = torch.randperm(len(self.geometric))[0] # OneOf 96 | # out = self.geometric[idx](x) 97 | out = self.augmentations(out) 98 | if self.viz: 99 | out = self.denorm(out) 100 | return out 101 | 102 | 103 | #: Kornia default augmentations 104 | if augmentation: 105 | KORNIA_TRAIN_TRANSFORM = LitPreprocess(512) 106 | KORNIA_VALID_TRANSFORM = LitPreprocess(224) 107 | -------------------------------------------------------------------------------- /kaggle_imgclassif/plant_pathology/data.py: -------------------------------------------------------------------------------- 1 | import glob 2 | import itertools 3 | import logging 4 | import os 5 | from typing import Dict, List, Optional, Sequence, Tuple, Type, Union 6 | from warnings import warn 7 | 8 | import matplotlib.pyplot as plt 9 | import pandas as pd 10 | import torch 11 | from PIL import Image 12 | from pytorch_lightning import LightningDataModule 13 | from torch import Tensor 14 | from torch.utils.data import DataLoader, Dataset 15 | 16 | try: 17 | from torchsampler import ImbalancedDatasetSampler 18 | except ImportError: 19 | ImbalancedDatasetSampler = None 20 | 21 | try: 22 | from kaggle_imgclassif.plant_pathology.augment import KORNIA_TRAIN_TRANSFORM, KORNIA_VALID_TRANSFORM 23 | except ImportError: 24 | KORNIA_TRAIN_TRANSFORM, KORNIA_VALID_TRANSFORM = None, None 25 | 26 | IMAGE_EXTENSIONS = (".png", ".jpg", ".jpeg") 27 | 28 | 29 | class PlantPathologyDataset(Dataset): 30 | """The ful dataset with one-hot encoding for multi-label case.""" 31 | 32 | def __init__( 33 | self, 34 | df_data: Union[str, pd.DataFrame] = "train.csv", 35 | path_img_dir: str = "train_images", 36 | transforms=None, 37 | mode: str = "train", 38 | split: float = 0.8, 39 | uq_labels: Tuple[str] = None, 40 | random_state=42, 41 | ): 42 | self.path_img_dir = path_img_dir 43 | self.transforms = transforms 44 | self.mode = mode 45 | 46 | # set or load the config table 47 | if isinstance(df_data, pd.DataFrame): 48 | self.data = df_data 49 | elif isinstance(df_data, str): 50 | assert os.path.isfile(df_data), f"missing file: {df_data}" 51 | self.data = pd.read_csv(df_data) 52 | else: 53 | raise ValueError(f"unrecognised input for DataFrame/CSV: {df_data}") 54 | 55 | # take over existing table or load from file 56 | if uq_labels: 57 | self.labels_unique = tuple(uq_labels) 58 | else: 59 | labels_all = list(itertools.chain(*[lbs.split(" ") for lbs in self.raw_labels])) 60 | self.labels_unique = tuple(sorted(set(labels_all))) 61 | self.labels_lut = {lb: i for i, lb in enumerate(self.labels_unique)} 62 | self.num_classes = len(self.labels_unique) 63 | 64 | # shuffle data 65 | self.data = self.data.sample(frac=1, random_state=random_state).reset_index(drop=True) 66 | 67 | # split dataset 68 | assert 0.0 <= split <= 1.0, f"split {split} is out of range" 69 | frac = int(split * len(self.data)) 70 | self.data = self.data[:frac] if mode == "train" else self.data[frac:] 71 | self.img_names = list(self.data["image"]) 72 | self.labels = self._prepare_labels() 73 | # compute importance order 74 | self.label_importance_index = [] 75 | 76 | @property 77 | def raw_labels(self): 78 | return list(self.data["labels"]) 79 | 80 | def _prepare_labels(self) -> list: 81 | return [torch.tensor(self.to_binary_encoding(lb)) if lb else None for lb in 
self.raw_labels] 82 | 83 | @property 84 | def label_histogram(self) -> Tensor: 85 | lb_stack = torch.tensor(list(map(tuple, self.labels))) 86 | return torch.sum(lb_stack, dim=0) 87 | 88 | def to_binary_encoding(self, labels: str) -> tuple: 89 | # processed with encoding 90 | one_hot = [0] * len(self.labels_unique) 91 | for lb in labels.split(" "): 92 | one_hot[self.labels_lut[lb]] = 1 93 | return tuple(one_hot) 94 | 95 | def __getitem__(self, idx: int) -> tuple: 96 | img_name = self.img_names[idx] 97 | img_path = os.path.join(self.path_img_dir, img_name) 98 | assert os.path.isfile(img_path) 99 | label = self.labels[idx] 100 | img = plt.imread(img_path) 101 | 102 | # augmentation 103 | if self.transforms: 104 | img = self.transforms(Image.fromarray(img)) 105 | # in case of predictions, return image name as label 106 | label = label if label is not None else img_name 107 | return img, label 108 | 109 | def __len__(self) -> int: 110 | return len(self.data) 111 | 112 | def get_sample_pseudo_label(self, idx: int): 113 | if not self.label_importance_index: 114 | idx_nb = list(enumerate(self.label_histogram)) 115 | idx_nb = sorted(idx_nb, key=lambda x: x[1]) 116 | self.label_importance_index = [i[0] for i in idx_nb] 117 | binary = self.labels[idx] 118 | # take the less occurred label, not the tuple combination as combination does not matter too much 119 | for i in self.label_importance_index: 120 | if binary[i]: 121 | return i 122 | # this is a failer... 123 | return tuple(binary.numpy()) 124 | 125 | def get_sample_pseudo_labels(self, *_): 126 | return [self.get_sample_pseudo_label(i) for i in range(len(self))] 127 | 128 | 129 | class PlantPathologySimpleDataset(PlantPathologyDataset): 130 | """Simplified version; we keep only complex label for multi-label cases and the true label for all others.""" 131 | 132 | def _translate_labels(self, lb): 133 | if lb is None: 134 | return None 135 | lb = self.labels_lut["complex"] if torch.sum(lb) > 1 else torch.argmax(lb) 136 | return int(lb) 137 | 138 | def _prepare_labels(self) -> list: 139 | labels = super()._prepare_labels() 140 | return list(map(self._translate_labels, labels)) 141 | 142 | @property 143 | def label_histogram(self) -> Tensor: 144 | if not isinstance(self.labels, Tensor): 145 | self.labels = torch.tensor(self.labels) 146 | return torch.bincount(self.labels) 147 | 148 | def get_sample_pseudo_label(self, idx: int): 149 | return self.labels[idx] 150 | 151 | 152 | class PlantPathologyDM(LightningDataModule): 153 | labels_unique: Sequence 154 | lut_label: Dict 155 | label_histogram: Tensor 156 | 157 | def __init__( 158 | self, 159 | path_csv: str = "train.csv", 160 | base_path: str = ".", 161 | batch_size: int = 32, 162 | num_workers: int = None, 163 | simple: bool = False, 164 | train_transforms=None, 165 | valid_transforms=None, 166 | split: float = 0.8, 167 | balancing: bool = False, 168 | ): 169 | super().__init__() 170 | # path configurations 171 | assert os.path.isdir(base_path), f"missing folder: {base_path}" 172 | self.train_dir = os.path.join(base_path, "train_images") 173 | self.test_dir = os.path.join(base_path, "test_images") 174 | 175 | if not os.path.isfile(path_csv): 176 | path_csv = os.path.join(base_path, path_csv) 177 | assert os.path.isfile(path_csv), f"missing table: {path_csv}" 178 | self.path_csv = path_csv 179 | 180 | # other configs 181 | self.batch_size = batch_size 182 | self.split = split 183 | self.num_workers = num_workers if num_workers is not None else os.cpu_count() 184 | self.balancing = balancing 185 | 186 | 
# need to be filled in setup() 187 | self.train_dataset = None 188 | self.valid_dataset = None 189 | self.test_table = [] 190 | self.test_dataset = None 191 | self.train_transforms = train_transforms or KORNIA_TRAIN_TRANSFORM 192 | self.valid_transforms = valid_transforms or KORNIA_VALID_TRANSFORM 193 | self.dataset_cls: Type = PlantPathologySimpleDataset if simple else PlantPathologyDataset 194 | 195 | def prepare_data(self): 196 | pass 197 | 198 | @property 199 | def num_classes(self) -> int: 200 | assert self.train_dataset 201 | assert self.valid_dataset 202 | return max(self.train_dataset.num_classes, self.valid_dataset.num_classes) 203 | 204 | @staticmethod 205 | def binary_mapping( 206 | encoding: Tensor, 207 | lut_label: Dict[int, str], 208 | thr: float = 0.5, 209 | label_required: bool = True, 210 | ) -> Union[str, List[str]]: 211 | """Convert Model outputs to string labels. 212 | 213 | Args: 214 | encoding: one-hot encoding 215 | lut_label: look-up-table with labels 216 | thr: threshold for label binarization 217 | label_required: if it is required to return any label and no label is above `thr`, use argmax 218 | """ 219 | assert lut_label 220 | # in case it is not a one-hot encoding but a single label 221 | if encoding.nelement() == 1: 222 | return lut_label[encoding[0]] 223 | labels = [lut_label[i] for i, s in enumerate(encoding) if s >= thr] 224 | # in case no label reached the threshold, take the max 225 | if not labels and label_required: 226 | idx = torch.argmax(encoding).item() 227 | labels = [lut_label[idx]] 228 | return sorted(labels) 229 | 230 | def binary_encoding_to_labels( 231 | self, 232 | encoding: Tensor, 233 | thr: float = 0.5, 234 | label_required: bool = True, 235 | ) -> Union[str, List[str]]: 236 | """Convert Model outputs to string labels. 
237 | 238 | Args: 239 | encoding: one-hot encoding 240 | thr: threshold for label binarization 241 | label_required: if it is required to return any label and no label is above `thr`, use argmax 242 | """ 243 | return self.binary_mapping(encoding, self.lut_label, thr=thr, label_required=label_required) 244 | 245 | def setup(self, *_, **__) -> None: 246 | """Prepare datasets.""" 247 | assert os.path.isdir(self.train_dir), f"missing folder: {self.train_dir}" 248 | ds = self.dataset_cls(self.path_csv, self.train_dir, mode="train", split=1.0) 249 | self.labels_unique = ds.labels_unique 250 | self.label_histogram = ds.label_histogram 251 | self.lut_label = dict(enumerate(self.labels_unique)) 252 | 253 | ds_defaults = dict( 254 | df_data=self.path_csv, 255 | path_img_dir=self.train_dir, 256 | split=self.split, 257 | uq_labels=self.labels_unique, 258 | ) 259 | self.train_dataset = self.dataset_cls(**ds_defaults, mode="train", transforms=self.train_transforms) 260 | logging.info(f"training dataset: {len(self.train_dataset)}") 261 | self.valid_dataset = self.dataset_cls(**ds_defaults, mode="valid", transforms=self.valid_transforms) 262 | logging.info(f"validation dataset: {len(self.valid_dataset)}") 263 | 264 | if not os.path.isdir(self.test_dir): 265 | return 266 | ls_images = glob.glob(os.path.join(self.test_dir, "*.*")) 267 | ls_images = [os.path.basename(p) for p in ls_images if os.path.splitext(p)[-1] in IMAGE_EXTENSIONS] 268 | self.test_table = [dict(image=n, labels="") for n in ls_images] 269 | self.test_dataset = self.dataset_cls( 270 | df_data=pd.DataFrame(self.test_table), 271 | path_img_dir=self.test_dir, 272 | split=0, 273 | uq_labels=self.labels_unique, 274 | mode="test", 275 | transforms=self.valid_transforms, 276 | ) 277 | logging.info(f"test dataset: {len(self.test_dataset)}") 278 | 279 | def _dataloader_extra_args(self, dataset: PlantPathologyDataset) -> dict: 280 | dl_kwargs = dict(shuffle=True) 281 | # if you ask and you have it 282 | if self.balancing and ImbalancedDatasetSampler: 283 | dl_kwargs = dict( 284 | shuffle=False, 285 | sampler=ImbalancedDatasetSampler( 286 | dataset=dataset, 287 | callback_get_label=self.dataset_cls.get_sample_pseudo_labels, 288 | ), 289 | ) 290 | elif self.balancing: 291 | warn("You have asked for `ImbalancedDatasetSampler` but you do not have it installed.") 292 | return dl_kwargs 293 | 294 | def train_dataloader(self) -> DataLoader: 295 | dl_kwargs = self._dataloader_extra_args(self.train_dataset) 296 | return DataLoader( 297 | self.train_dataset, 298 | batch_size=self.batch_size, 299 | num_workers=self.num_workers, 300 | **dl_kwargs, 301 | ) 302 | 303 | def val_dataloader(self) -> DataLoader: 304 | return DataLoader( 305 | self.valid_dataset, 306 | batch_size=self.batch_size, 307 | num_workers=self.num_workers, 308 | shuffle=False, 309 | ) 310 | 311 | def test_dataloader(self) -> Optional[DataLoader]: 312 | if self.test_dataset: 313 | return DataLoader( 314 | self.test_dataset, 315 | batch_size=self.batch_size, 316 | num_workers=0, 317 | shuffle=False, 318 | ) 319 | logging.warning("no testing images found") 320 | return None 321 | -------------------------------------------------------------------------------- /kaggle_imgclassif/plant_pathology/models.py: -------------------------------------------------------------------------------- 1 | from typing import Optional, Union 2 | 3 | import timm 4 | import torch 5 | from pytorch_lightning import LightningModule 6 | from torch import Tensor, nn 7 | from torch.nn import functional as F 8 | from 
torchmetrics import Accuracy, F1Score, Precision 9 | 10 | 11 | class LitPlantPathology(LightningModule): 12 | """This model is meant and tested to be used together with `PlantPathologySimpleDataset` 13 | 14 | >>> model = LitPlantPathology() 15 | """ 16 | 17 | def __init__( 18 | self, 19 | model: Union[nn.Module, str] = "resnet50", 20 | num_classes: int = 6, 21 | lr: float = 1e-4, 22 | augmentations: Optional[nn.Module] = None, 23 | ): 24 | super().__init__() 25 | if isinstance(model, str): 26 | self.arch = model 27 | self.model = timm.create_model(model, pretrained=True, num_classes=num_classes) 28 | else: 29 | self.model = model 30 | self.arch = model.__class__.__name__ 31 | self.num_classes = num_classes 32 | self.train_accuracy = Accuracy() 33 | self.train_precision = Precision(**self._metrics_extra_args) 34 | self.train_f1_score = F1Score(**self._metrics_extra_args) 35 | self.val_accuracy = Accuracy() 36 | self.val_precision = Precision(**self._metrics_extra_args) 37 | self.val_f1_score = F1Score(**self._metrics_extra_args) 38 | self.learning_rate = lr 39 | self.aug = augmentations 40 | 41 | @property 42 | def _metrics_extra_args(self): 43 | return dict() 44 | 45 | def forward(self, x: Tensor) -> Tensor: 46 | return F.softmax(self.model(x)) 47 | 48 | def compute_loss(self, y_hat: Tensor, y: Tensor): 49 | return F.cross_entropy(y_hat, y) 50 | 51 | def training_step(self, batch, batch_idx): 52 | x, y = batch 53 | if self.aug: 54 | x = self.aug(x) # => batched augmentations 55 | y_hat = self(x) 56 | loss = self.compute_loss(y_hat, y) 57 | self.log("train_loss", loss, prog_bar=False) 58 | self.log("train_acc", self.train_accuracy(y_hat, y), prog_bar=False) 59 | self.log("train_prec", self.train_precision(y_hat, y), prog_bar=False) 60 | self.log("train_f1", self.train_f1_score(y_hat, y), prog_bar=True) 61 | return loss 62 | 63 | def validation_step(self, batch, batch_idx): 64 | x, y = batch 65 | y_hat = self(x) 66 | loss = self.compute_loss(y_hat, y) 67 | self.log("valid_loss", loss, prog_bar=False) 68 | self.log("valid_acc", self.val_accuracy(y_hat, y), prog_bar=True) 69 | self.log("valid_prec", self.val_precision(y_hat, y), prog_bar=True) 70 | self.log("valid_f1", self.val_f1_score(y_hat, y), prog_bar=True) 71 | 72 | def configure_optimizers(self): 73 | optimizer = torch.optim.AdamW(self.model.parameters(), lr=self.learning_rate) 74 | scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, self.trainer.max_epochs, 0) 75 | return [optimizer], [scheduler] 76 | 77 | 78 | class MultiPlantPathology(LitPlantPathology): 79 | """This model is meant and tested to be used together with `PlantPathologyDataset` 80 | 81 | >>> model = MultiPlantPathology() 82 | """ 83 | 84 | @property 85 | def _metrics_extra_args(self): 86 | return dict(num_classes=self.num_classes, average="weighted") 87 | 88 | def forward(self, x: Tensor) -> Tensor: 89 | return torch.sigmoid(self.model(x)) 90 | 91 | def compute_loss(self, y_hat: Tensor, y: Tensor): 92 | return F.binary_cross_entropy_with_logits(y_hat, y.to(y_hat.dtype)) 93 | -------------------------------------------------------------------------------- /notebooks/Cassava-Leaf-with-Flash.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": { 6 | "id": "q3WKUi9scrf-" 7 | }, 8 | "source": [ 9 | "# Kaggle: [Cassava Leaf Disease Classification](https://www.kaggle.com/c/cassava-leaf-disease-classification/overview)\n" 10 | ] 11 | }, 12 | { 13 | 
"cell_type": "markdown", 14 | "metadata": { 15 | "id": "7GMm-s-hf2S1" 16 | }, 17 | "source": [ 18 | "## Setup environment \n", 19 | "\n", 20 | "- connect the gDrive with dataset\n", 21 | "- extract data to local\n", 22 | "- install pytorch lightning" 23 | ] 24 | }, 25 | { 26 | "cell_type": "code", 27 | "execution_count": null, 28 | "metadata": { 29 | "colab": { 30 | "base_uri": "https://localhost:8080/" 31 | }, 32 | "id": "LvB1eeLVcxWx", 33 | "outputId": "d6c16fd3-3bd5-43fa-ea49-61e075e8c7cf" 34 | }, 35 | "outputs": [], 36 | "source": [ 37 | "from google.colab import drive\n", 38 | "\n", 39 | "# connect to my gDrive\n", 40 | "drive.mount(\"/content/gdrive\")" 41 | ] 42 | }, 43 | { 44 | "cell_type": "code", 45 | "execution_count": null, 46 | "metadata": { 47 | "id": "yHoSvpVfdKw8" 48 | }, 49 | "outputs": [], 50 | "source": [ 51 | "# copy the dataset to local drive\n", 52 | "! cp /content/gdrive/MyDrive/Data/cassava-leaf-disease-classification.zip ." 53 | ] 54 | }, 55 | { 56 | "cell_type": "code", 57 | "execution_count": null, 58 | "metadata": { 59 | "colab": { 60 | "base_uri": "https://localhost:8080/" 61 | }, 62 | "id": "o3FgFFDjcLv3", 63 | "outputId": "c237031a-1b30-4f01-9c14-008b43f99f86" 64 | }, 65 | "outputs": [], 66 | "source": [ 67 | "# extract dataset to the drive\n", 68 | "! unzip -q cassava-leaf-disease-classification.zip\n", 69 | "! ls -l" 70 | ] 71 | }, 72 | { 73 | "cell_type": "code", 74 | "execution_count": null, 75 | "metadata": { 76 | "colab": { 77 | "base_uri": "https://localhost:8080/" 78 | }, 79 | "id": "PtHzviVIeJso", 80 | "outputId": "eb35b978-370a-4223-b382-b704150b04d6" 81 | }, 82 | "outputs": [], 83 | "source": [ 84 | "! pip install \"pytorch-lightning==1.2.0rc0\" \"lightning-bolts==0.3.2rc1\" \"lightning-flash==0.2.2rc2\" \"torchtext==0.5\" -q\n", 85 | "\n", 86 | "# import os\n", 87 | "# os.kill(os.getpid(), 9)\n", 88 | "! pip list | grep torch" 89 | ] 90 | }, 91 | { 92 | "cell_type": "code", 93 | "execution_count": null, 94 | "metadata": { 95 | "colab": { 96 | "base_uri": "https://localhost:8080/" 97 | }, 98 | "id": "1RyRQneVMu2p", 99 | "outputId": "2d2429cf-f7b5-44ec-ffa9-9384da174e6c" 100 | }, 101 | "outputs": [], 102 | "source": [ 103 | "! 
nvidia-smi" 104 | ] 105 | }, 106 | { 107 | "cell_type": "markdown", 108 | "metadata": { 109 | "id": "7aYRczoogb0a" 110 | }, 111 | "source": [ 112 | "## Data exploration" 113 | ] 114 | }, 115 | { 116 | "cell_type": "code", 117 | "execution_count": null, 118 | "metadata": { 119 | "colab": { 120 | "base_uri": "https://localhost:8080/" 121 | }, 122 | "id": "P9UFmundgfh0", 123 | "outputId": "cdfd48a9-f323-448e-c011-5d6b9276768a" 124 | }, 125 | "outputs": [], 126 | "source": [ 127 | "%matplotlib inline\n", 128 | "\n", 129 | "import json\n", 130 | "from pprint import pprint\n", 131 | "\n", 132 | "import pandas as pd\n", 133 | "\n", 134 | "path_csv = \"/content/train.csv\"\n", 135 | "train_data = pd.read_csv(path_csv)\n", 136 | "print(train_data.head())\n", 137 | "\n", 138 | "label_mapping = json.load(open(\"/content/label_num_to_disease_map.json\"))\n", 139 | "label_mapping = {int(k): v for k, v in label_mapping.items()}\n", 140 | "pprint(label_mapping)" 141 | ] 142 | }, 143 | { 144 | "cell_type": "code", 145 | "execution_count": null, 146 | "metadata": { 147 | "colab": { 148 | "base_uri": "https://localhost:8080/", 149 | "height": 296 150 | }, 151 | "id": "ieqZuAoQgkHc", 152 | "outputId": "e5812c7c-0f67-49e5-a33f-4ea03bb49d4a" 153 | }, 154 | "outputs": [], 155 | "source": [ 156 | "import numpy as np\n", 157 | "import seaborn as sns\n", 158 | "\n", 159 | "lb_hist = dict(zip(range(10), np.bincount(train_data[\"label\"])))\n", 160 | "pprint(lb_hist)\n", 161 | "\n", 162 | "ax = sns.countplot(y=train_data[\"label\"].map(label_mapping), orient=\"v\")\n", 163 | "ax.grid()" 164 | ] 165 | }, 166 | { 167 | "cell_type": "code", 168 | "execution_count": null, 169 | "metadata": { 170 | "colab": { 171 | "base_uri": "https://localhost:8080/", 172 | "height": 706 173 | }, 174 | "id": "rUliH6KijHa9", 175 | "outputId": "4f639d77-7a97-4aa3-f531-44807bf3c7a7" 176 | }, 177 | "outputs": [], 178 | "source": [ 179 | "import matplotlib.pyplot as plt\n", 180 | "\n", 181 | "fig, axarr = plt.subplots(nrows=4, ncols=5, figsize=(16, 10))\n", 182 | "for lb, df_ in train_data.groupby(\"label\"):\n", 183 | " img_names = list(df_[\"image_id\"])\n", 184 | " for i in range(4):\n", 185 | " img_name = img_names[i]\n", 186 | " img = plt.imread(f\"/content/train_images/{img_name}\")\n", 187 | " axarr[i, lb].imshow(img)\n", 188 | " axarr[i, lb].set_title(f\"label: {lb} & image: {img_name}\")\n", 189 | " axarr[i, lb].set_xticks([])\n", 190 | " axarr[i, lb].set_yticks([])\n", 191 | "fig.tight_layout()" 192 | ] 193 | }, 194 | { 195 | "cell_type": "markdown", 196 | "metadata": { 197 | "id": "R9D4KE1Z2q9i" 198 | }, 199 | "source": [ 200 | "## Dataset adjustment" 201 | ] 202 | }, 203 | { 204 | "cell_type": "code", 205 | "execution_count": null, 206 | "metadata": { 207 | "colab": { 208 | "base_uri": "https://localhost:8080/" 209 | }, 210 | "id": "Odif1NWVmkeX", 211 | "outputId": "662c190a-7ea8-4f84-d617-2b9bd68dfcfe" 212 | }, 213 | "outputs": [], 214 | "source": [ 215 | "import os\n", 216 | "import shutil\n", 217 | "\n", 218 | "import pandas as pd\n", 219 | "import tqdm\n", 220 | "\n", 221 | "path_csv = \"/content/train.csv\"\n", 222 | "data = pd.read_csv(path_csv)\n", 223 | "# shuffle data\n", 224 | "data = data.sample(frac=1, random_state=42).reset_index(drop=True)\n", 225 | "\n", 226 | "frac = int(0.8 * len(data))\n", 227 | "train = data[:frac]\n", 228 | "valid = data[frac:]\n", 229 | "\n", 230 | "# crating train and valid folder\n", 231 | "for folder, df in [(\"train\", train), (\"valid\", valid)]:\n", 232 | " folder = 
os.path.join(\"/content/dataset\", folder)\n", 233 | " os.makedirs(folder, exist_ok=True)\n", 234 | " # triage images per class / label\n", 235 | " for _, row in tqdm.tqdm(df.iterrows()):\n", 236 | " img_name, lb = row[\"image_id\"], row[\"label\"]\n", 237 | " folder_lb = os.path.join(folder, str(lb))\n", 238 | " # create folder for label if it is missing\n", 239 | " if not os.path.isdir(folder_lb):\n", 240 | " os.mkdir(folder_lb)\n", 241 | " shutil.copy(os.path.join(\"/content/train_images\", img_name), os.path.join(folder_lb, img_name))\n", 242 | "\n", 243 | "! ls -l /content/dataset/train\n", 244 | "! ls -l /content/dataset/valid" 245 | ] 246 | }, 247 | { 248 | "cell_type": "markdown", 249 | "metadata": { 250 | "id": "jklJWxh1wiFn" 251 | }, 252 | "source": [ 253 | "## Flash finetuning" 254 | ] 255 | }, 256 | { 257 | "cell_type": "code", 258 | "execution_count": null, 259 | "metadata": { 260 | "colab": { 261 | "base_uri": "https://localhost:8080/" 262 | }, 263 | "id": "s5HFEi7A7wh4", 264 | "outputId": "d6800563-c29d-4fb8-e08b-a7dea7afa930" 265 | }, 266 | "outputs": [], 267 | "source": [ 268 | "import multiprocessing as mproc\n", 269 | "\n", 270 | "import flash\n", 271 | "import torch\n", 272 | "from flash.core.data import download_data\n", 273 | "from flash.core.finetuning import FreezeUnfreeze\n", 274 | "from flash.vision import ImageClassificationData, ImageClassifier\n", 275 | "\n", 276 | "# 2. Load the data\n", 277 | "datamodule = ImageClassificationData.from_folders(\n", 278 | " train_folder=\"/content/dataset/train/\",\n", 279 | " valid_folder=\"/content/dataset/valid/\",\n", 280 | " batch_size=128,\n", 281 | " num_workers=mproc.cpu_count(),\n", 282 | ")\n", 283 | "\n", 284 | "# 3. Build the model\n", 285 | "model = ImageClassifier(\n", 286 | " backbone=\"resnet34\",\n", 287 | " optimizer=torch.optim.Adam,\n", 288 | " num_classes=datamodule.num_classes,\n", 289 | ")" 290 | ] 291 | }, 292 | { 293 | "cell_type": "code", 294 | "execution_count": null, 295 | "metadata": {}, 296 | "outputs": [], 297 | "source": [ 298 | "# 4. Create the trainer. Run twice on data\n", 299 | "trainer = flash.Trainer(\n", 300 | " gpus=1,\n", 301 | " max_epochs=3,\n", 302 | " precision=16,\n", 303 | " val_check_interval=0.5,\n", 304 | " progress_bar_refresh_rate=1,\n", 305 | ")\n", 306 | "\n", 307 | "# 5. Train the model\n", 308 | "trainer.finetune(model, datamodule=datamodule, strategy=FreezeUnfreeze(unfreeze_epoch=1))\n", 309 | "\n", 310 | "# 7. 
Save it!\n", 311 | "trainer.save_checkpoint(\"image_classification_model.pt\")" 312 | ] 313 | }, 314 | { 315 | "cell_type": "code", 316 | "execution_count": null, 317 | "metadata": {}, 318 | "outputs": [], 319 | "source": [ 320 | "# Start tensorboard.\n", 321 | "%load_ext tensorboard\n", 322 | "%tensorboard --logdir lightning_logs/" 323 | ] 324 | } 325 | ], 326 | "metadata": { 327 | "accelerator": "GPU", 328 | "colab": { 329 | "collapsed_sections": [], 330 | "name": "Kaggle: Cassava with Flash.ipynb", 331 | "provenance": [] 332 | }, 333 | "kernelspec": { 334 | "display_name": "Python 3", 335 | "language": "python", 336 | "name": "python3" 337 | }, 338 | "language_info": { 339 | "codemirror_mode": { 340 | "name": "ipython", 341 | "version": 3 342 | }, 343 | "file_extension": ".py", 344 | "mimetype": "text/x-python", 345 | "name": "python", 346 | "nbconvert_exporter": "python", 347 | "pygments_lexer": "ipython3", 348 | "version": "3.7.3" 349 | }, 350 | "widgets": { 351 | "application/vnd.jupyter.widget-state+json": { 352 | "05daa3a8a4494329ac5a21a31dc556c0": { 353 | "model_module": "@jupyter-widgets/controls", 354 | "model_name": "FloatProgressModel", 355 | "state": { 356 | "_dom_classes": [], 357 | "_model_module": "@jupyter-widgets/controls", 358 | "_model_module_version": "1.5.0", 359 | "_model_name": "FloatProgressModel", 360 | "_view_count": null, 361 | "_view_module": "@jupyter-widgets/controls", 362 | "_view_module_version": "1.5.0", 363 | "_view_name": "ProgressView", 364 | "bar_style": "info", 365 | "description": "Validating: 100%", 366 | "description_tooltip": null, 367 | "layout": "IPY_MODEL_ce299120c5b042dc8d7989dc6e83998c", 368 | "max": 1, 369 | "min": 0, 370 | "orientation": "horizontal", 371 | "style": "IPY_MODEL_b6f95b26648a4d6d88e6afefb6ca9abd", 372 | "value": 1 373 | } 374 | }, 375 | "13c5862559f54acea759078b34f48c4a": { 376 | "model_module": "@jupyter-widgets/controls", 377 | "model_name": "FloatProgressModel", 378 | "state": { 379 | "_dom_classes": [], 380 | "_model_module": "@jupyter-widgets/controls", 381 | "_model_module_version": "1.5.0", 382 | "_model_name": "FloatProgressModel", 383 | "_view_count": null, 384 | "_view_module": "@jupyter-widgets/controls", 385 | "_view_module_version": "1.5.0", 386 | "_view_name": "ProgressView", 387 | "bar_style": "info", 388 | "description": "Validating: 100%", 389 | "description_tooltip": null, 390 | "layout": "IPY_MODEL_c60b79d3b9a9437a84ffbd7761802972", 391 | "max": 1, 392 | "min": 0, 393 | "orientation": "horizontal", 394 | "style": "IPY_MODEL_6ec7cec91bae4ed9bb7cfad9d724f9eb", 395 | "value": 1 396 | } 397 | }, 398 | "1b5b6a2845a943ec95fff5c7916a3b00": { 399 | "model_module": "@jupyter-widgets/controls", 400 | "model_name": "DescriptionStyleModel", 401 | "state": { 402 | "_model_module": "@jupyter-widgets/controls", 403 | "_model_module_version": "1.5.0", 404 | "_model_name": "DescriptionStyleModel", 405 | "_view_count": null, 406 | "_view_module": "@jupyter-widgets/base", 407 | "_view_module_version": "1.2.0", 408 | "_view_name": "StyleView", 409 | "description_width": "" 410 | } 411 | }, 412 | "1d162408abec4ec0abba2491c141cb77": { 413 | "model_module": "@jupyter-widgets/controls", 414 | "model_name": "HTMLModel", 415 | "state": { 416 | "_dom_classes": [], 417 | "_model_module": "@jupyter-widgets/controls", 418 | "_model_module_version": "1.5.0", 419 | "_model_name": "HTMLModel", 420 | "_view_count": null, 421 | "_view_module": "@jupyter-widgets/controls", 422 | "_view_module_version": "1.5.0", 423 | "_view_name": "HTMLView", 
424 | "description": "", 425 | "description_tooltip": null, 426 | "layout": "IPY_MODEL_a083d62410f34d3995b9d0ab36f09c47", 427 | "placeholder": "​", 428 | "style": "IPY_MODEL_6d14aa01bae045758efeef9adeaeda75", 429 | "value": " 34/34 [00:43<00:00, 1.19s/it]" 430 | } 431 | }, 432 | "221689944bf445beaed9aabab4fbefbd": { 433 | "model_module": "@jupyter-widgets/controls", 434 | "model_name": "DescriptionStyleModel", 435 | "state": { 436 | "_model_module": "@jupyter-widgets/controls", 437 | "_model_module_version": "1.5.0", 438 | "_model_name": "DescriptionStyleModel", 439 | "_view_count": null, 440 | "_view_module": "@jupyter-widgets/base", 441 | "_view_module_version": "1.2.0", 442 | "_view_name": "StyleView", 443 | "description_width": "" 444 | } 445 | }, 446 | "2355d488135b4618bd399b1a5a5684bc": { 447 | "model_module": "@jupyter-widgets/base", 448 | "model_name": "LayoutModel", 449 | "state": { 450 | "_model_module": "@jupyter-widgets/base", 451 | "_model_module_version": "1.2.0", 452 | "_model_name": "LayoutModel", 453 | "_view_count": null, 454 | "_view_module": "@jupyter-widgets/base", 455 | "_view_module_version": "1.2.0", 456 | "_view_name": "LayoutView", 457 | "align_content": null, 458 | "align_items": null, 459 | "align_self": null, 460 | "border": null, 461 | "bottom": null, 462 | "display": null, 463 | "flex": "2", 464 | "flex_flow": null, 465 | "grid_area": null, 466 | "grid_auto_columns": null, 467 | "grid_auto_flow": null, 468 | "grid_auto_rows": null, 469 | "grid_column": null, 470 | "grid_gap": null, 471 | "grid_row": null, 472 | "grid_template_areas": null, 473 | "grid_template_columns": null, 474 | "grid_template_rows": null, 475 | "height": null, 476 | "justify_content": null, 477 | "justify_items": null, 478 | "left": null, 479 | "margin": null, 480 | "max_height": null, 481 | "max_width": null, 482 | "min_height": null, 483 | "min_width": null, 484 | "object_fit": null, 485 | "object_position": null, 486 | "order": null, 487 | "overflow": null, 488 | "overflow_x": null, 489 | "overflow_y": null, 490 | "padding": null, 491 | "right": null, 492 | "top": null, 493 | "visibility": null, 494 | "width": null 495 | } 496 | }, 497 | "271df06f7b0541f79de174e3ae0b770a": { 498 | "model_module": "@jupyter-widgets/base", 499 | "model_name": "LayoutModel", 500 | "state": { 501 | "_model_module": "@jupyter-widgets/base", 502 | "_model_module_version": "1.2.0", 503 | "_model_name": "LayoutModel", 504 | "_view_count": null, 505 | "_view_module": "@jupyter-widgets/base", 506 | "_view_module_version": "1.2.0", 507 | "_view_name": "LayoutView", 508 | "align_content": null, 509 | "align_items": null, 510 | "align_self": null, 511 | "border": null, 512 | "bottom": null, 513 | "display": null, 514 | "flex": null, 515 | "flex_flow": null, 516 | "grid_area": null, 517 | "grid_auto_columns": null, 518 | "grid_auto_flow": null, 519 | "grid_auto_rows": null, 520 | "grid_column": null, 521 | "grid_gap": null, 522 | "grid_row": null, 523 | "grid_template_areas": null, 524 | "grid_template_columns": null, 525 | "grid_template_rows": null, 526 | "height": null, 527 | "justify_content": null, 528 | "justify_items": null, 529 | "left": null, 530 | "margin": null, 531 | "max_height": null, 532 | "max_width": null, 533 | "min_height": null, 534 | "min_width": null, 535 | "object_fit": null, 536 | "object_position": null, 537 | "order": null, 538 | "overflow": null, 539 | "overflow_x": null, 540 | "overflow_y": null, 541 | "padding": null, 542 | "right": null, 543 | "top": null, 544 | "visibility": null, 545 
| "width": null 546 | } 547 | }, 548 | "308a38ce67584a78a01b7064f2afc5f8": { 549 | "model_module": "@jupyter-widgets/base", 550 | "model_name": "LayoutModel", 551 | "state": { 552 | "_model_module": "@jupyter-widgets/base", 553 | "_model_module_version": "1.2.0", 554 | "_model_name": "LayoutModel", 555 | "_view_count": null, 556 | "_view_module": "@jupyter-widgets/base", 557 | "_view_module_version": "1.2.0", 558 | "_view_name": "LayoutView", 559 | "align_content": null, 560 | "align_items": null, 561 | "align_self": null, 562 | "border": null, 563 | "bottom": null, 564 | "display": null, 565 | "flex": null, 566 | "flex_flow": null, 567 | "grid_area": null, 568 | "grid_auto_columns": null, 569 | "grid_auto_flow": null, 570 | "grid_auto_rows": null, 571 | "grid_column": null, 572 | "grid_gap": null, 573 | "grid_row": null, 574 | "grid_template_areas": null, 575 | "grid_template_columns": null, 576 | "grid_template_rows": null, 577 | "height": null, 578 | "justify_content": null, 579 | "justify_items": null, 580 | "left": null, 581 | "margin": null, 582 | "max_height": null, 583 | "max_width": null, 584 | "min_height": null, 585 | "min_width": null, 586 | "object_fit": null, 587 | "object_position": null, 588 | "order": null, 589 | "overflow": null, 590 | "overflow_x": null, 591 | "overflow_y": null, 592 | "padding": null, 593 | "right": null, 594 | "top": null, 595 | "visibility": null, 596 | "width": null 597 | } 598 | }, 599 | "3612a8fc4dbf4e30852b1121180958a6": { 600 | "model_module": "@jupyter-widgets/controls", 601 | "model_name": "FloatProgressModel", 602 | "state": { 603 | "_dom_classes": [], 604 | "_model_module": "@jupyter-widgets/controls", 605 | "_model_module_version": "1.5.0", 606 | "_model_name": "FloatProgressModel", 607 | "_view_count": null, 608 | "_view_module": "@jupyter-widgets/controls", 609 | "_view_module_version": "1.5.0", 610 | "_view_name": "ProgressView", 611 | "bar_style": "success", 612 | "description": "Epoch 2: 100%", 613 | "description_tooltip": null, 614 | "layout": "IPY_MODEL_b8df3931eeb34b19b079cb5d10e9076e", 615 | "max": 201, 616 | "min": 0, 617 | "orientation": "horizontal", 618 | "style": "IPY_MODEL_5953bf14438a46d5ae9d306a1d4f678b", 619 | "value": 201 620 | } 621 | }, 622 | "38b6f596d2d243a09bfd96341f08670a": { 623 | "model_module": "@jupyter-widgets/base", 624 | "model_name": "LayoutModel", 625 | "state": { 626 | "_model_module": "@jupyter-widgets/base", 627 | "_model_module_version": "1.2.0", 628 | "_model_name": "LayoutModel", 629 | "_view_count": null, 630 | "_view_module": "@jupyter-widgets/base", 631 | "_view_module_version": "1.2.0", 632 | "_view_name": "LayoutView", 633 | "align_content": null, 634 | "align_items": null, 635 | "align_self": null, 636 | "border": null, 637 | "bottom": null, 638 | "display": null, 639 | "flex": null, 640 | "flex_flow": null, 641 | "grid_area": null, 642 | "grid_auto_columns": null, 643 | "grid_auto_flow": null, 644 | "grid_auto_rows": null, 645 | "grid_column": null, 646 | "grid_gap": null, 647 | "grid_row": null, 648 | "grid_template_areas": null, 649 | "grid_template_columns": null, 650 | "grid_template_rows": null, 651 | "height": null, 652 | "justify_content": null, 653 | "justify_items": null, 654 | "left": null, 655 | "margin": null, 656 | "max_height": null, 657 | "max_width": null, 658 | "min_height": null, 659 | "min_width": null, 660 | "object_fit": null, 661 | "object_position": null, 662 | "order": null, 663 | "overflow": null, 664 | "overflow_x": null, 665 | "overflow_y": null, 666 | "padding": null, 
667 | "right": null, 668 | "top": null, 669 | "visibility": null, 670 | "width": null 671 | } 672 | }, 673 | "3c9430d77fe54280b2c5389157fd0527": { 674 | "model_module": "@jupyter-widgets/controls", 675 | "model_name": "HTMLModel", 676 | "state": { 677 | "_dom_classes": [], 678 | "_model_module": "@jupyter-widgets/controls", 679 | "_model_module_version": "1.5.0", 680 | "_model_name": "HTMLModel", 681 | "_view_count": null, 682 | "_view_module": "@jupyter-widgets/controls", 683 | "_view_module_version": "1.5.0", 684 | "_view_name": "HTMLView", 685 | "description": "", 686 | "description_tooltip": null, 687 | "layout": "IPY_MODEL_e088f1f357b448059cbfb2fd081c124d", 688 | "placeholder": "​", 689 | "style": "IPY_MODEL_ea515854fbd444ad9c10faeaad0db444", 690 | "value": " 201/201 [04:15<00:00, 1.27s/it, loss=1.11, v_num=1, val_accuracy=0.785, val_cross_entropy=1.12, train_accuracy_step=0.844, train_cross_entropy_step=1.07, train_accuracy_epoch=0.723, train_cross_entropy_epoch=1.18]" 691 | } 692 | }, 693 | "3d3896fd7fa848e8afd940c1adfab572": { 694 | "model_module": "@jupyter-widgets/base", 695 | "model_name": "LayoutModel", 696 | "state": { 697 | "_model_module": "@jupyter-widgets/base", 698 | "_model_module_version": "1.2.0", 699 | "_model_name": "LayoutModel", 700 | "_view_count": null, 701 | "_view_module": "@jupyter-widgets/base", 702 | "_view_module_version": "1.2.0", 703 | "_view_name": "LayoutView", 704 | "align_content": null, 705 | "align_items": null, 706 | "align_self": null, 707 | "border": null, 708 | "bottom": null, 709 | "display": "inline-flex", 710 | "flex": null, 711 | "flex_flow": "row wrap", 712 | "grid_area": null, 713 | "grid_auto_columns": null, 714 | "grid_auto_flow": null, 715 | "grid_auto_rows": null, 716 | "grid_column": null, 717 | "grid_gap": null, 718 | "grid_row": null, 719 | "grid_template_areas": null, 720 | "grid_template_columns": null, 721 | "grid_template_rows": null, 722 | "height": null, 723 | "justify_content": null, 724 | "justify_items": null, 725 | "left": null, 726 | "margin": null, 727 | "max_height": null, 728 | "max_width": null, 729 | "min_height": null, 730 | "min_width": null, 731 | "object_fit": null, 732 | "object_position": null, 733 | "order": null, 734 | "overflow": null, 735 | "overflow_x": null, 736 | "overflow_y": null, 737 | "padding": null, 738 | "right": null, 739 | "top": null, 740 | "visibility": null, 741 | "width": "100%" 742 | } 743 | }, 744 | "430a1bf095bc4ba9aae67e08e007fec8": { 745 | "model_module": "@jupyter-widgets/controls", 746 | "model_name": "DescriptionStyleModel", 747 | "state": { 748 | "_model_module": "@jupyter-widgets/controls", 749 | "_model_module_version": "1.5.0", 750 | "_model_name": "DescriptionStyleModel", 751 | "_view_count": null, 752 | "_view_module": "@jupyter-widgets/base", 753 | "_view_module_version": "1.2.0", 754 | "_view_name": "StyleView", 755 | "description_width": "" 756 | } 757 | }, 758 | "453d1a6d694e4e16aecf667138617c8d": { 759 | "model_module": "@jupyter-widgets/controls", 760 | "model_name": "DescriptionStyleModel", 761 | "state": { 762 | "_model_module": "@jupyter-widgets/controls", 763 | "_model_module_version": "1.5.0", 764 | "_model_name": "DescriptionStyleModel", 765 | "_view_count": null, 766 | "_view_module": "@jupyter-widgets/base", 767 | "_view_module_version": "1.2.0", 768 | "_view_name": "StyleView", 769 | "description_width": "" 770 | } 771 | }, 772 | "500225961bb14c25995ccdc159a9da31": { 773 | "model_module": "@jupyter-widgets/controls", 774 | "model_name": "HBoxModel", 775 | 
"state": { 776 | "_dom_classes": [], 777 | "_model_module": "@jupyter-widgets/controls", 778 | "_model_module_version": "1.5.0", 779 | "_model_name": "HBoxModel", 780 | "_view_count": null, 781 | "_view_module": "@jupyter-widgets/controls", 782 | "_view_module_version": "1.5.0", 783 | "_view_name": "HBoxView", 784 | "box_style": "", 785 | "children": [ 786 | "IPY_MODEL_84363d70ad5b46a2b5057d303ae959a6", 787 | "IPY_MODEL_ca047f0d00424b9c92a7575b73760f63" 788 | ], 789 | "layout": "IPY_MODEL_f374fd5878c247338ada58c76ea6a4da" 790 | } 791 | }, 792 | "5953bf14438a46d5ae9d306a1d4f678b": { 793 | "model_module": "@jupyter-widgets/controls", 794 | "model_name": "ProgressStyleModel", 795 | "state": { 796 | "_model_module": "@jupyter-widgets/controls", 797 | "_model_module_version": "1.5.0", 798 | "_model_name": "ProgressStyleModel", 799 | "_view_count": null, 800 | "_view_module": "@jupyter-widgets/base", 801 | "_view_module_version": "1.2.0", 802 | "_view_name": "StyleView", 803 | "bar_color": null, 804 | "description_width": "initial" 805 | } 806 | }, 807 | "5f0a0965e6004d03a28116f991c9de99": { 808 | "model_module": "@jupyter-widgets/controls", 809 | "model_name": "FloatProgressModel", 810 | "state": { 811 | "_dom_classes": [], 812 | "_model_module": "@jupyter-widgets/controls", 813 | "_model_module_version": "1.5.0", 814 | "_model_name": "FloatProgressModel", 815 | "_view_count": null, 816 | "_view_module": "@jupyter-widgets/controls", 817 | "_view_module_version": "1.5.0", 818 | "_view_name": "ProgressView", 819 | "bar_style": "info", 820 | "description": "Validation sanity check: 100%", 821 | "description_tooltip": null, 822 | "layout": "IPY_MODEL_2355d488135b4618bd399b1a5a5684bc", 823 | "max": 1, 824 | "min": 0, 825 | "orientation": "horizontal", 826 | "style": "IPY_MODEL_7b9da95737e24056b473bf655572922c", 827 | "value": 1 828 | } 829 | }, 830 | "63757390ba8f499484f441dafadd11f9": { 831 | "model_module": "@jupyter-widgets/controls", 832 | "model_name": "ProgressStyleModel", 833 | "state": { 834 | "_model_module": "@jupyter-widgets/controls", 835 | "_model_module_version": "1.5.0", 836 | "_model_name": "ProgressStyleModel", 837 | "_view_count": null, 838 | "_view_module": "@jupyter-widgets/base", 839 | "_view_module_version": "1.2.0", 840 | "_view_name": "StyleView", 841 | "bar_color": null, 842 | "description_width": "initial" 843 | } 844 | }, 845 | "6d14aa01bae045758efeef9adeaeda75": { 846 | "model_module": "@jupyter-widgets/controls", 847 | "model_name": "DescriptionStyleModel", 848 | "state": { 849 | "_model_module": "@jupyter-widgets/controls", 850 | "_model_module_version": "1.5.0", 851 | "_model_name": "DescriptionStyleModel", 852 | "_view_count": null, 853 | "_view_module": "@jupyter-widgets/base", 854 | "_view_module_version": "1.2.0", 855 | "_view_name": "StyleView", 856 | "description_width": "" 857 | } 858 | }, 859 | "6eb3cc718cd44ee4bbc1a9447bcbeb27": { 860 | "model_module": "@jupyter-widgets/controls", 861 | "model_name": "FloatProgressModel", 862 | "state": { 863 | "_dom_classes": [], 864 | "_model_module": "@jupyter-widgets/controls", 865 | "_model_module_version": "1.5.0", 866 | "_model_name": "FloatProgressModel", 867 | "_view_count": null, 868 | "_view_module": "@jupyter-widgets/controls", 869 | "_view_module_version": "1.5.0", 870 | "_view_name": "ProgressView", 871 | "bar_style": "info", 872 | "description": "Validating: 100%", 873 | "description_tooltip": null, 874 | "layout": "IPY_MODEL_e3d1641fd1bb45c2b91a598f9875ddf5", 875 | "max": 1, 876 | "min": 0, 877 | "orientation": 
"horizontal", 878 | "style": "IPY_MODEL_fcdc7f0094274ddf93776dc81726318d", 879 | "value": 1 880 | } 881 | }, 882 | "6ec7cec91bae4ed9bb7cfad9d724f9eb": { 883 | "model_module": "@jupyter-widgets/controls", 884 | "model_name": "ProgressStyleModel", 885 | "state": { 886 | "_model_module": "@jupyter-widgets/controls", 887 | "_model_module_version": "1.5.0", 888 | "_model_name": "ProgressStyleModel", 889 | "_view_count": null, 890 | "_view_module": "@jupyter-widgets/base", 891 | "_view_module_version": "1.2.0", 892 | "_view_name": "StyleView", 893 | "bar_color": null, 894 | "description_width": "initial" 895 | } 896 | }, 897 | "77c71962f28c4f1f940a7c21c940fdaa": { 898 | "model_module": "@jupyter-widgets/base", 899 | "model_name": "LayoutModel", 900 | "state": { 901 | "_model_module": "@jupyter-widgets/base", 902 | "_model_module_version": "1.2.0", 903 | "_model_name": "LayoutModel", 904 | "_view_count": null, 905 | "_view_module": "@jupyter-widgets/base", 906 | "_view_module_version": "1.2.0", 907 | "_view_name": "LayoutView", 908 | "align_content": null, 909 | "align_items": null, 910 | "align_self": null, 911 | "border": null, 912 | "bottom": null, 913 | "display": null, 914 | "flex": "2", 915 | "flex_flow": null, 916 | "grid_area": null, 917 | "grid_auto_columns": null, 918 | "grid_auto_flow": null, 919 | "grid_auto_rows": null, 920 | "grid_column": null, 921 | "grid_gap": null, 922 | "grid_row": null, 923 | "grid_template_areas": null, 924 | "grid_template_columns": null, 925 | "grid_template_rows": null, 926 | "height": null, 927 | "justify_content": null, 928 | "justify_items": null, 929 | "left": null, 930 | "margin": null, 931 | "max_height": null, 932 | "max_width": null, 933 | "min_height": null, 934 | "min_width": null, 935 | "object_fit": null, 936 | "object_position": null, 937 | "order": null, 938 | "overflow": null, 939 | "overflow_x": null, 940 | "overflow_y": null, 941 | "padding": null, 942 | "right": null, 943 | "top": null, 944 | "visibility": null, 945 | "width": null 946 | } 947 | }, 948 | "782a51d3fbb0454ba98096c2ad965e69": { 949 | "model_module": "@jupyter-widgets/controls", 950 | "model_name": "HBoxModel", 951 | "state": { 952 | "_dom_classes": [], 953 | "_model_module": "@jupyter-widgets/controls", 954 | "_model_module_version": "1.5.0", 955 | "_model_name": "HBoxModel", 956 | "_view_count": null, 957 | "_view_module": "@jupyter-widgets/controls", 958 | "_view_module_version": "1.5.0", 959 | "_view_name": "HBoxView", 960 | "box_style": "", 961 | "children": [ 962 | "IPY_MODEL_3612a8fc4dbf4e30852b1121180958a6", 963 | "IPY_MODEL_3c9430d77fe54280b2c5389157fd0527" 964 | ], 965 | "layout": "IPY_MODEL_eabe50a8db2745cda833c6bbc3f6428b" 966 | } 967 | }, 968 | "7a23774d08984ee98daa4850d445fb70": { 969 | "model_module": "@jupyter-widgets/controls", 970 | "model_name": "HTMLModel", 971 | "state": { 972 | "_dom_classes": [], 973 | "_model_module": "@jupyter-widgets/controls", 974 | "_model_module_version": "1.5.0", 975 | "_model_name": "HTMLModel", 976 | "_view_count": null, 977 | "_view_module": "@jupyter-widgets/controls", 978 | "_view_module_version": "1.5.0", 979 | "_view_name": "HTMLView", 980 | "description": "", 981 | "description_tooltip": null, 982 | "layout": "IPY_MODEL_308a38ce67584a78a01b7064f2afc5f8", 983 | "placeholder": "​", 984 | "style": "IPY_MODEL_d6402be0c4a54ac7b92d1af1b72010cf", 985 | "value": " 34/34 [00:48<00:00, 1.16s/it]" 986 | } 987 | }, 988 | "7b9da95737e24056b473bf655572922c": { 989 | "model_module": "@jupyter-widgets/controls", 990 | "model_name": 
"ProgressStyleModel", 991 | "state": { 992 | "_model_module": "@jupyter-widgets/controls", 993 | "_model_module_version": "1.5.0", 994 | "_model_name": "ProgressStyleModel", 995 | "_view_count": null, 996 | "_view_module": "@jupyter-widgets/base", 997 | "_view_module_version": "1.2.0", 998 | "_view_name": "StyleView", 999 | "bar_color": null, 1000 | "description_width": "initial" 1001 | } 1002 | }, 1003 | "837af53c23504230be8046d0e2b72010": { 1004 | "model_module": "@jupyter-widgets/controls", 1005 | "model_name": "FloatProgressModel", 1006 | "state": { 1007 | "_dom_classes": [], 1008 | "_model_module": "@jupyter-widgets/controls", 1009 | "_model_module_version": "1.5.0", 1010 | "_model_name": "FloatProgressModel", 1011 | "_view_count": null, 1012 | "_view_module": "@jupyter-widgets/controls", 1013 | "_view_module_version": "1.5.0", 1014 | "_view_name": "ProgressView", 1015 | "bar_style": "info", 1016 | "description": "Validating: 100%", 1017 | "description_tooltip": null, 1018 | "layout": "IPY_MODEL_df81ecfbacdb4cc194cdfeb913c58afd", 1019 | "max": 1, 1020 | "min": 0, 1021 | "orientation": "horizontal", 1022 | "style": "IPY_MODEL_a11abe00213d4c68805f56c47e79a640", 1023 | "value": 1 1024 | } 1025 | }, 1026 | "84363d70ad5b46a2b5057d303ae959a6": { 1027 | "model_module": "@jupyter-widgets/controls", 1028 | "model_name": "FloatProgressModel", 1029 | "state": { 1030 | "_dom_classes": [], 1031 | "_model_module": "@jupyter-widgets/controls", 1032 | "_model_module_version": "1.5.0", 1033 | "_model_name": "FloatProgressModel", 1034 | "_view_count": null, 1035 | "_view_module": "@jupyter-widgets/controls", 1036 | "_view_module_version": "1.5.0", 1037 | "_view_name": "ProgressView", 1038 | "bar_style": "info", 1039 | "description": "Validating: 100%", 1040 | "description_tooltip": null, 1041 | "layout": "IPY_MODEL_77c71962f28c4f1f940a7c21c940fdaa", 1042 | "max": 1, 1043 | "min": 0, 1044 | "orientation": "horizontal", 1045 | "style": "IPY_MODEL_c98945a2d2e947ccb3659f61875d7a2e", 1046 | "value": 1 1047 | } 1048 | }, 1049 | "9082e599bc044e6190b2bf236ce680a3": { 1050 | "model_module": "@jupyter-widgets/controls", 1051 | "model_name": "FloatProgressModel", 1052 | "state": { 1053 | "_dom_classes": [], 1054 | "_model_module": "@jupyter-widgets/controls", 1055 | "_model_module_version": "1.5.0", 1056 | "_model_name": "FloatProgressModel", 1057 | "_view_count": null, 1058 | "_view_module": "@jupyter-widgets/controls", 1059 | "_view_module_version": "1.5.0", 1060 | "_view_name": "ProgressView", 1061 | "bar_style": "info", 1062 | "description": "Validating: 100%", 1063 | "description_tooltip": null, 1064 | "layout": "IPY_MODEL_93fdf41486ab4eef9d3da5c99a695ecb", 1065 | "max": 1, 1066 | "min": 0, 1067 | "orientation": "horizontal", 1068 | "style": "IPY_MODEL_63757390ba8f499484f441dafadd11f9", 1069 | "value": 1 1070 | } 1071 | }, 1072 | "93fdf41486ab4eef9d3da5c99a695ecb": { 1073 | "model_module": "@jupyter-widgets/base", 1074 | "model_name": "LayoutModel", 1075 | "state": { 1076 | "_model_module": "@jupyter-widgets/base", 1077 | "_model_module_version": "1.2.0", 1078 | "_model_name": "LayoutModel", 1079 | "_view_count": null, 1080 | "_view_module": "@jupyter-widgets/base", 1081 | "_view_module_version": "1.2.0", 1082 | "_view_name": "LayoutView", 1083 | "align_content": null, 1084 | "align_items": null, 1085 | "align_self": null, 1086 | "border": null, 1087 | "bottom": null, 1088 | "display": null, 1089 | "flex": "2", 1090 | "flex_flow": null, 1091 | "grid_area": null, 1092 | "grid_auto_columns": null, 1093 | 
"grid_auto_flow": null, 1094 | "grid_auto_rows": null, 1095 | "grid_column": null, 1096 | "grid_gap": null, 1097 | "grid_row": null, 1098 | "grid_template_areas": null, 1099 | "grid_template_columns": null, 1100 | "grid_template_rows": null, 1101 | "height": null, 1102 | "justify_content": null, 1103 | "justify_items": null, 1104 | "left": null, 1105 | "margin": null, 1106 | "max_height": null, 1107 | "max_width": null, 1108 | "min_height": null, 1109 | "min_width": null, 1110 | "object_fit": null, 1111 | "object_position": null, 1112 | "order": null, 1113 | "overflow": null, 1114 | "overflow_x": null, 1115 | "overflow_y": null, 1116 | "padding": null, 1117 | "right": null, 1118 | "top": null, 1119 | "visibility": null, 1120 | "width": null 1121 | } 1122 | }, 1123 | "96cb795d7c1049c5a8ccf8da49a0a78f": { 1124 | "model_module": "@jupyter-widgets/controls", 1125 | "model_name": "HBoxModel", 1126 | "state": { 1127 | "_dom_classes": [], 1128 | "_model_module": "@jupyter-widgets/controls", 1129 | "_model_module_version": "1.5.0", 1130 | "_model_name": "HBoxModel", 1131 | "_view_count": null, 1132 | "_view_module": "@jupyter-widgets/controls", 1133 | "_view_module_version": "1.5.0", 1134 | "_view_name": "HBoxView", 1135 | "box_style": "", 1136 | "children": [ 1137 | "IPY_MODEL_837af53c23504230be8046d0e2b72010", 1138 | "IPY_MODEL_7a23774d08984ee98daa4850d445fb70" 1139 | ], 1140 | "layout": "IPY_MODEL_d0a02bde6c1248a684be20dfcb6427b9" 1141 | } 1142 | }, 1143 | "9e035b983cf245878e3dd701f8b79c3b": { 1144 | "model_module": "@jupyter-widgets/base", 1145 | "model_name": "LayoutModel", 1146 | "state": { 1147 | "_model_module": "@jupyter-widgets/base", 1148 | "_model_module_version": "1.2.0", 1149 | "_model_name": "LayoutModel", 1150 | "_view_count": null, 1151 | "_view_module": "@jupyter-widgets/base", 1152 | "_view_module_version": "1.2.0", 1153 | "_view_name": "LayoutView", 1154 | "align_content": null, 1155 | "align_items": null, 1156 | "align_self": null, 1157 | "border": null, 1158 | "bottom": null, 1159 | "display": null, 1160 | "flex": null, 1161 | "flex_flow": null, 1162 | "grid_area": null, 1163 | "grid_auto_columns": null, 1164 | "grid_auto_flow": null, 1165 | "grid_auto_rows": null, 1166 | "grid_column": null, 1167 | "grid_gap": null, 1168 | "grid_row": null, 1169 | "grid_template_areas": null, 1170 | "grid_template_columns": null, 1171 | "grid_template_rows": null, 1172 | "height": null, 1173 | "justify_content": null, 1174 | "justify_items": null, 1175 | "left": null, 1176 | "margin": null, 1177 | "max_height": null, 1178 | "max_width": null, 1179 | "min_height": null, 1180 | "min_width": null, 1181 | "object_fit": null, 1182 | "object_position": null, 1183 | "order": null, 1184 | "overflow": null, 1185 | "overflow_x": null, 1186 | "overflow_y": null, 1187 | "padding": null, 1188 | "right": null, 1189 | "top": null, 1190 | "visibility": null, 1191 | "width": null 1192 | } 1193 | }, 1194 | "a083d62410f34d3995b9d0ab36f09c47": { 1195 | "model_module": "@jupyter-widgets/base", 1196 | "model_name": "LayoutModel", 1197 | "state": { 1198 | "_model_module": "@jupyter-widgets/base", 1199 | "_model_module_version": "1.2.0", 1200 | "_model_name": "LayoutModel", 1201 | "_view_count": null, 1202 | "_view_module": "@jupyter-widgets/base", 1203 | "_view_module_version": "1.2.0", 1204 | "_view_name": "LayoutView", 1205 | "align_content": null, 1206 | "align_items": null, 1207 | "align_self": null, 1208 | "border": null, 1209 | "bottom": null, 1210 | "display": null, 1211 | "flex": null, 1212 | "flex_flow": 
null, 1213 | "grid_area": null, 1214 | "grid_auto_columns": null, 1215 | "grid_auto_flow": null, 1216 | "grid_auto_rows": null, 1217 | "grid_column": null, 1218 | "grid_gap": null, 1219 | "grid_row": null, 1220 | "grid_template_areas": null, 1221 | "grid_template_columns": null, 1222 | "grid_template_rows": null, 1223 | "height": null, 1224 | "justify_content": null, 1225 | "justify_items": null, 1226 | "left": null, 1227 | "margin": null, 1228 | "max_height": null, 1229 | "max_width": null, 1230 | "min_height": null, 1231 | "min_width": null, 1232 | "object_fit": null, 1233 | "object_position": null, 1234 | "order": null, 1235 | "overflow": null, 1236 | "overflow_x": null, 1237 | "overflow_y": null, 1238 | "padding": null, 1239 | "right": null, 1240 | "top": null, 1241 | "visibility": null, 1242 | "width": null 1243 | } 1244 | }, 1245 | "a11abe00213d4c68805f56c47e79a640": { 1246 | "model_module": "@jupyter-widgets/controls", 1247 | "model_name": "ProgressStyleModel", 1248 | "state": { 1249 | "_model_module": "@jupyter-widgets/controls", 1250 | "_model_module_version": "1.5.0", 1251 | "_model_name": "ProgressStyleModel", 1252 | "_view_count": null, 1253 | "_view_module": "@jupyter-widgets/base", 1254 | "_view_module_version": "1.2.0", 1255 | "_view_name": "StyleView", 1256 | "bar_color": null, 1257 | "description_width": "initial" 1258 | } 1259 | }, 1260 | "a26348e6a0b44970b55bd5131ace7561": { 1261 | "model_module": "@jupyter-widgets/controls", 1262 | "model_name": "HTMLModel", 1263 | "state": { 1264 | "_dom_classes": [], 1265 | "_model_module": "@jupyter-widgets/controls", 1266 | "_model_module_version": "1.5.0", 1267 | "_model_name": "HTMLModel", 1268 | "_view_count": null, 1269 | "_view_module": "@jupyter-widgets/controls", 1270 | "_view_module_version": "1.5.0", 1271 | "_view_name": "HTMLView", 1272 | "description": "", 1273 | "description_tooltip": null, 1274 | "layout": "IPY_MODEL_38b6f596d2d243a09bfd96341f08670a", 1275 | "placeholder": "​", 1276 | "style": "IPY_MODEL_453d1a6d694e4e16aecf667138617c8d", 1277 | "value": " 34/34 [00:47<00:00, 1.26s/it]" 1278 | } 1279 | }, 1280 | "a70c774f906a4875abc81841fcfc77c8": { 1281 | "model_module": "@jupyter-widgets/controls", 1282 | "model_name": "HTMLModel", 1283 | "state": { 1284 | "_dom_classes": [], 1285 | "_model_module": "@jupyter-widgets/controls", 1286 | "_model_module_version": "1.5.0", 1287 | "_model_name": "HTMLModel", 1288 | "_view_count": null, 1289 | "_view_module": "@jupyter-widgets/controls", 1290 | "_view_module_version": "1.5.0", 1291 | "_view_name": "HTMLView", 1292 | "description": "", 1293 | "description_tooltip": null, 1294 | "layout": "IPY_MODEL_b46fa21a16754d72bc86d829c6782dfc", 1295 | "placeholder": "​", 1296 | "style": "IPY_MODEL_430a1bf095bc4ba9aae67e08e007fec8", 1297 | "value": " 34/34 [00:44<00:00, 1.18s/it]" 1298 | } 1299 | }, 1300 | "a86c90b9d4a045a9aeba91215ddaa505": { 1301 | "model_module": "@jupyter-widgets/controls", 1302 | "model_name": "HBoxModel", 1303 | "state": { 1304 | "_dom_classes": [], 1305 | "_model_module": "@jupyter-widgets/controls", 1306 | "_model_module_version": "1.5.0", 1307 | "_model_name": "HBoxModel", 1308 | "_view_count": null, 1309 | "_view_module": "@jupyter-widgets/controls", 1310 | "_view_module_version": "1.5.0", 1311 | "_view_name": "HBoxView", 1312 | "box_style": "", 1313 | "children": [ 1314 | "IPY_MODEL_5f0a0965e6004d03a28116f991c9de99", 1315 | "IPY_MODEL_aec28fd642964750ac31b7ea9dc5f242" 1316 | ], 1317 | "layout": "IPY_MODEL_3d3896fd7fa848e8afd940c1adfab572" 1318 | } 1319 | }, 
1320 | "aa59a77d23a144a6a2f5bd1501aee54b": { 1321 | "model_module": "@jupyter-widgets/controls", 1322 | "model_name": "HBoxModel", 1323 | "state": { 1324 | "_dom_classes": [], 1325 | "_model_module": "@jupyter-widgets/controls", 1326 | "_model_module_version": "1.5.0", 1327 | "_model_name": "HBoxModel", 1328 | "_view_count": null, 1329 | "_view_module": "@jupyter-widgets/controls", 1330 | "_view_module_version": "1.5.0", 1331 | "_view_name": "HBoxView", 1332 | "box_style": "", 1333 | "children": [ 1334 | "IPY_MODEL_6eb3cc718cd44ee4bbc1a9447bcbeb27", 1335 | "IPY_MODEL_cc93f8cebae24650be23092c89666f08" 1336 | ], 1337 | "layout": "IPY_MODEL_f328795ccb1a48f38a5d261bdbd02623" 1338 | } 1339 | }, 1340 | "aec28fd642964750ac31b7ea9dc5f242": { 1341 | "model_module": "@jupyter-widgets/controls", 1342 | "model_name": "HTMLModel", 1343 | "state": { 1344 | "_dom_classes": [], 1345 | "_model_module": "@jupyter-widgets/controls", 1346 | "_model_module_version": "1.5.0", 1347 | "_model_name": "HTMLModel", 1348 | "_view_count": null, 1349 | "_view_module": "@jupyter-widgets/controls", 1350 | "_view_module_version": "1.5.0", 1351 | "_view_name": "HTMLView", 1352 | "description": "", 1353 | "description_tooltip": null, 1354 | "layout": "IPY_MODEL_271df06f7b0541f79de174e3ae0b770a", 1355 | "placeholder": "​", 1356 | "style": "IPY_MODEL_ef574dcb4bc04fcf94ceddb3f4dbd4a3", 1357 | "value": " 2/2 [00:07<00:00, 2.96s/it]" 1358 | } 1359 | }, 1360 | "af7af99cfee74f69b03f2b32276c94be": { 1361 | "model_module": "@jupyter-widgets/controls", 1362 | "model_name": "HBoxModel", 1363 | "state": { 1364 | "_dom_classes": [], 1365 | "_model_module": "@jupyter-widgets/controls", 1366 | "_model_module_version": "1.5.0", 1367 | "_model_name": "HBoxModel", 1368 | "_view_count": null, 1369 | "_view_module": "@jupyter-widgets/controls", 1370 | "_view_module_version": "1.5.0", 1371 | "_view_name": "HBoxView", 1372 | "box_style": "", 1373 | "children": [ 1374 | "IPY_MODEL_9082e599bc044e6190b2bf236ce680a3", 1375 | "IPY_MODEL_1d162408abec4ec0abba2491c141cb77" 1376 | ], 1377 | "layout": "IPY_MODEL_c1388acf869b45ce91a0c21a0c1ab24d" 1378 | } 1379 | }, 1380 | "b46fa21a16754d72bc86d829c6782dfc": { 1381 | "model_module": "@jupyter-widgets/base", 1382 | "model_name": "LayoutModel", 1383 | "state": { 1384 | "_model_module": "@jupyter-widgets/base", 1385 | "_model_module_version": "1.2.0", 1386 | "_model_name": "LayoutModel", 1387 | "_view_count": null, 1388 | "_view_module": "@jupyter-widgets/base", 1389 | "_view_module_version": "1.2.0", 1390 | "_view_name": "LayoutView", 1391 | "align_content": null, 1392 | "align_items": null, 1393 | "align_self": null, 1394 | "border": null, 1395 | "bottom": null, 1396 | "display": null, 1397 | "flex": null, 1398 | "flex_flow": null, 1399 | "grid_area": null, 1400 | "grid_auto_columns": null, 1401 | "grid_auto_flow": null, 1402 | "grid_auto_rows": null, 1403 | "grid_column": null, 1404 | "grid_gap": null, 1405 | "grid_row": null, 1406 | "grid_template_areas": null, 1407 | "grid_template_columns": null, 1408 | "grid_template_rows": null, 1409 | "height": null, 1410 | "justify_content": null, 1411 | "justify_items": null, 1412 | "left": null, 1413 | "margin": null, 1414 | "max_height": null, 1415 | "max_width": null, 1416 | "min_height": null, 1417 | "min_width": null, 1418 | "object_fit": null, 1419 | "object_position": null, 1420 | "order": null, 1421 | "overflow": null, 1422 | "overflow_x": null, 1423 | "overflow_y": null, 1424 | "padding": null, 1425 | "right": null, 1426 | "top": null, 1427 | "visibility": 
null, 1428 | "width": null 1429 | } 1430 | }, 1431 | "b6f95b26648a4d6d88e6afefb6ca9abd": { 1432 | "model_module": "@jupyter-widgets/controls", 1433 | "model_name": "ProgressStyleModel", 1434 | "state": { 1435 | "_model_module": "@jupyter-widgets/controls", 1436 | "_model_module_version": "1.5.0", 1437 | "_model_name": "ProgressStyleModel", 1438 | "_view_count": null, 1439 | "_view_module": "@jupyter-widgets/base", 1440 | "_view_module_version": "1.2.0", 1441 | "_view_name": "StyleView", 1442 | "bar_color": null, 1443 | "description_width": "initial" 1444 | } 1445 | }, 1446 | "b8df3931eeb34b19b079cb5d10e9076e": { 1447 | "model_module": "@jupyter-widgets/base", 1448 | "model_name": "LayoutModel", 1449 | "state": { 1450 | "_model_module": "@jupyter-widgets/base", 1451 | "_model_module_version": "1.2.0", 1452 | "_model_name": "LayoutModel", 1453 | "_view_count": null, 1454 | "_view_module": "@jupyter-widgets/base", 1455 | "_view_module_version": "1.2.0", 1456 | "_view_name": "LayoutView", 1457 | "align_content": null, 1458 | "align_items": null, 1459 | "align_self": null, 1460 | "border": null, 1461 | "bottom": null, 1462 | "display": null, 1463 | "flex": "2", 1464 | "flex_flow": null, 1465 | "grid_area": null, 1466 | "grid_auto_columns": null, 1467 | "grid_auto_flow": null, 1468 | "grid_auto_rows": null, 1469 | "grid_column": null, 1470 | "grid_gap": null, 1471 | "grid_row": null, 1472 | "grid_template_areas": null, 1473 | "grid_template_columns": null, 1474 | "grid_template_rows": null, 1475 | "height": null, 1476 | "justify_content": null, 1477 | "justify_items": null, 1478 | "left": null, 1479 | "margin": null, 1480 | "max_height": null, 1481 | "max_width": null, 1482 | "min_height": null, 1483 | "min_width": null, 1484 | "object_fit": null, 1485 | "object_position": null, 1486 | "order": null, 1487 | "overflow": null, 1488 | "overflow_x": null, 1489 | "overflow_y": null, 1490 | "padding": null, 1491 | "right": null, 1492 | "top": null, 1493 | "visibility": null, 1494 | "width": null 1495 | } 1496 | }, 1497 | "c1388acf869b45ce91a0c21a0c1ab24d": { 1498 | "model_module": "@jupyter-widgets/base", 1499 | "model_name": "LayoutModel", 1500 | "state": { 1501 | "_model_module": "@jupyter-widgets/base", 1502 | "_model_module_version": "1.2.0", 1503 | "_model_name": "LayoutModel", 1504 | "_view_count": null, 1505 | "_view_module": "@jupyter-widgets/base", 1506 | "_view_module_version": "1.2.0", 1507 | "_view_name": "LayoutView", 1508 | "align_content": null, 1509 | "align_items": null, 1510 | "align_self": null, 1511 | "border": null, 1512 | "bottom": null, 1513 | "display": "inline-flex", 1514 | "flex": null, 1515 | "flex_flow": "row wrap", 1516 | "grid_area": null, 1517 | "grid_auto_columns": null, 1518 | "grid_auto_flow": null, 1519 | "grid_auto_rows": null, 1520 | "grid_column": null, 1521 | "grid_gap": null, 1522 | "grid_row": null, 1523 | "grid_template_areas": null, 1524 | "grid_template_columns": null, 1525 | "grid_template_rows": null, 1526 | "height": null, 1527 | "justify_content": null, 1528 | "justify_items": null, 1529 | "left": null, 1530 | "margin": null, 1531 | "max_height": null, 1532 | "max_width": null, 1533 | "min_height": null, 1534 | "min_width": null, 1535 | "object_fit": null, 1536 | "object_position": null, 1537 | "order": null, 1538 | "overflow": null, 1539 | "overflow_x": null, 1540 | "overflow_y": null, 1541 | "padding": null, 1542 | "right": null, 1543 | "top": null, 1544 | "visibility": null, 1545 | "width": "100%" 1546 | } 1547 | }, 1548 | 
"c60b79d3b9a9437a84ffbd7761802972": { 1549 | "model_module": "@jupyter-widgets/base", 1550 | "model_name": "LayoutModel", 1551 | "state": { 1552 | "_model_module": "@jupyter-widgets/base", 1553 | "_model_module_version": "1.2.0", 1554 | "_model_name": "LayoutModel", 1555 | "_view_count": null, 1556 | "_view_module": "@jupyter-widgets/base", 1557 | "_view_module_version": "1.2.0", 1558 | "_view_name": "LayoutView", 1559 | "align_content": null, 1560 | "align_items": null, 1561 | "align_self": null, 1562 | "border": null, 1563 | "bottom": null, 1564 | "display": null, 1565 | "flex": "2", 1566 | "flex_flow": null, 1567 | "grid_area": null, 1568 | "grid_auto_columns": null, 1569 | "grid_auto_flow": null, 1570 | "grid_auto_rows": null, 1571 | "grid_column": null, 1572 | "grid_gap": null, 1573 | "grid_row": null, 1574 | "grid_template_areas": null, 1575 | "grid_template_columns": null, 1576 | "grid_template_rows": null, 1577 | "height": null, 1578 | "justify_content": null, 1579 | "justify_items": null, 1580 | "left": null, 1581 | "margin": null, 1582 | "max_height": null, 1583 | "max_width": null, 1584 | "min_height": null, 1585 | "min_width": null, 1586 | "object_fit": null, 1587 | "object_position": null, 1588 | "order": null, 1589 | "overflow": null, 1590 | "overflow_x": null, 1591 | "overflow_y": null, 1592 | "padding": null, 1593 | "right": null, 1594 | "top": null, 1595 | "visibility": null, 1596 | "width": null 1597 | } 1598 | }, 1599 | "c98945a2d2e947ccb3659f61875d7a2e": { 1600 | "model_module": "@jupyter-widgets/controls", 1601 | "model_name": "ProgressStyleModel", 1602 | "state": { 1603 | "_model_module": "@jupyter-widgets/controls", 1604 | "_model_module_version": "1.5.0", 1605 | "_model_name": "ProgressStyleModel", 1606 | "_view_count": null, 1607 | "_view_module": "@jupyter-widgets/base", 1608 | "_view_module_version": "1.2.0", 1609 | "_view_name": "StyleView", 1610 | "bar_color": null, 1611 | "description_width": "initial" 1612 | } 1613 | }, 1614 | "ca047f0d00424b9c92a7575b73760f63": { 1615 | "model_module": "@jupyter-widgets/controls", 1616 | "model_name": "HTMLModel", 1617 | "state": { 1618 | "_dom_classes": [], 1619 | "_model_module": "@jupyter-widgets/controls", 1620 | "_model_module_version": "1.5.0", 1621 | "_model_name": "HTMLModel", 1622 | "_view_count": null, 1623 | "_view_module": "@jupyter-widgets/controls", 1624 | "_view_module_version": "1.5.0", 1625 | "_view_name": "HTMLView", 1626 | "description": "", 1627 | "description_tooltip": null, 1628 | "layout": "IPY_MODEL_9e035b983cf245878e3dd701f8b79c3b", 1629 | "placeholder": "​", 1630 | "style": "IPY_MODEL_221689944bf445beaed9aabab4fbefbd", 1631 | "value": " 34/34 [00:44<00:00, 1.15s/it]" 1632 | } 1633 | }, 1634 | "cc93f8cebae24650be23092c89666f08": { 1635 | "model_module": "@jupyter-widgets/controls", 1636 | "model_name": "HTMLModel", 1637 | "state": { 1638 | "_dom_classes": [], 1639 | "_model_module": "@jupyter-widgets/controls", 1640 | "_model_module_version": "1.5.0", 1641 | "_model_name": "HTMLModel", 1642 | "_view_count": null, 1643 | "_view_module": "@jupyter-widgets/controls", 1644 | "_view_module_version": "1.5.0", 1645 | "_view_name": "HTMLView", 1646 | "description": "", 1647 | "description_tooltip": null, 1648 | "layout": "IPY_MODEL_d9f3e40901784bf49bd27d36a67d0aec", 1649 | "placeholder": "​", 1650 | "style": "IPY_MODEL_1b5b6a2845a943ec95fff5c7916a3b00", 1651 | "value": " 34/34 [00:47<00:00, 1.27s/it]" 1652 | } 1653 | }, 1654 | "ce299120c5b042dc8d7989dc6e83998c": { 1655 | "model_module": 
"@jupyter-widgets/base", 1656 | "model_name": "LayoutModel", 1657 | "state": { 1658 | "_model_module": "@jupyter-widgets/base", 1659 | "_model_module_version": "1.2.0", 1660 | "_model_name": "LayoutModel", 1661 | "_view_count": null, 1662 | "_view_module": "@jupyter-widgets/base", 1663 | "_view_module_version": "1.2.0", 1664 | "_view_name": "LayoutView", 1665 | "align_content": null, 1666 | "align_items": null, 1667 | "align_self": null, 1668 | "border": null, 1669 | "bottom": null, 1670 | "display": null, 1671 | "flex": "2", 1672 | "flex_flow": null, 1673 | "grid_area": null, 1674 | "grid_auto_columns": null, 1675 | "grid_auto_flow": null, 1676 | "grid_auto_rows": null, 1677 | "grid_column": null, 1678 | "grid_gap": null, 1679 | "grid_row": null, 1680 | "grid_template_areas": null, 1681 | "grid_template_columns": null, 1682 | "grid_template_rows": null, 1683 | "height": null, 1684 | "justify_content": null, 1685 | "justify_items": null, 1686 | "left": null, 1687 | "margin": null, 1688 | "max_height": null, 1689 | "max_width": null, 1690 | "min_height": null, 1691 | "min_width": null, 1692 | "object_fit": null, 1693 | "object_position": null, 1694 | "order": null, 1695 | "overflow": null, 1696 | "overflow_x": null, 1697 | "overflow_y": null, 1698 | "padding": null, 1699 | "right": null, 1700 | "top": null, 1701 | "visibility": null, 1702 | "width": null 1703 | } 1704 | }, 1705 | "d0a02bde6c1248a684be20dfcb6427b9": { 1706 | "model_module": "@jupyter-widgets/base", 1707 | "model_name": "LayoutModel", 1708 | "state": { 1709 | "_model_module": "@jupyter-widgets/base", 1710 | "_model_module_version": "1.2.0", 1711 | "_model_name": "LayoutModel", 1712 | "_view_count": null, 1713 | "_view_module": "@jupyter-widgets/base", 1714 | "_view_module_version": "1.2.0", 1715 | "_view_name": "LayoutView", 1716 | "align_content": null, 1717 | "align_items": null, 1718 | "align_self": null, 1719 | "border": null, 1720 | "bottom": null, 1721 | "display": "inline-flex", 1722 | "flex": null, 1723 | "flex_flow": "row wrap", 1724 | "grid_area": null, 1725 | "grid_auto_columns": null, 1726 | "grid_auto_flow": null, 1727 | "grid_auto_rows": null, 1728 | "grid_column": null, 1729 | "grid_gap": null, 1730 | "grid_row": null, 1731 | "grid_template_areas": null, 1732 | "grid_template_columns": null, 1733 | "grid_template_rows": null, 1734 | "height": null, 1735 | "justify_content": null, 1736 | "justify_items": null, 1737 | "left": null, 1738 | "margin": null, 1739 | "max_height": null, 1740 | "max_width": null, 1741 | "min_height": null, 1742 | "min_width": null, 1743 | "object_fit": null, 1744 | "object_position": null, 1745 | "order": null, 1746 | "overflow": null, 1747 | "overflow_x": null, 1748 | "overflow_y": null, 1749 | "padding": null, 1750 | "right": null, 1751 | "top": null, 1752 | "visibility": null, 1753 | "width": "100%" 1754 | } 1755 | }, 1756 | "d6402be0c4a54ac7b92d1af1b72010cf": { 1757 | "model_module": "@jupyter-widgets/controls", 1758 | "model_name": "DescriptionStyleModel", 1759 | "state": { 1760 | "_model_module": "@jupyter-widgets/controls", 1761 | "_model_module_version": "1.5.0", 1762 | "_model_name": "DescriptionStyleModel", 1763 | "_view_count": null, 1764 | "_view_module": "@jupyter-widgets/base", 1765 | "_view_module_version": "1.2.0", 1766 | "_view_name": "StyleView", 1767 | "description_width": "" 1768 | } 1769 | }, 1770 | "d9f3e40901784bf49bd27d36a67d0aec": { 1771 | "model_module": "@jupyter-widgets/base", 1772 | "model_name": "LayoutModel", 1773 | "state": { 1774 | "_model_module": 
"@jupyter-widgets/base", 1775 | "_model_module_version": "1.2.0", 1776 | "_model_name": "LayoutModel", 1777 | "_view_count": null, 1778 | "_view_module": "@jupyter-widgets/base", 1779 | "_view_module_version": "1.2.0", 1780 | "_view_name": "LayoutView", 1781 | "align_content": null, 1782 | "align_items": null, 1783 | "align_self": null, 1784 | "border": null, 1785 | "bottom": null, 1786 | "display": null, 1787 | "flex": null, 1788 | "flex_flow": null, 1789 | "grid_area": null, 1790 | "grid_auto_columns": null, 1791 | "grid_auto_flow": null, 1792 | "grid_auto_rows": null, 1793 | "grid_column": null, 1794 | "grid_gap": null, 1795 | "grid_row": null, 1796 | "grid_template_areas": null, 1797 | "grid_template_columns": null, 1798 | "grid_template_rows": null, 1799 | "height": null, 1800 | "justify_content": null, 1801 | "justify_items": null, 1802 | "left": null, 1803 | "margin": null, 1804 | "max_height": null, 1805 | "max_width": null, 1806 | "min_height": null, 1807 | "min_width": null, 1808 | "object_fit": null, 1809 | "object_position": null, 1810 | "order": null, 1811 | "overflow": null, 1812 | "overflow_x": null, 1813 | "overflow_y": null, 1814 | "padding": null, 1815 | "right": null, 1816 | "top": null, 1817 | "visibility": null, 1818 | "width": null 1819 | } 1820 | }, 1821 | "de09bcb035f94ad3a308cd5039bc0f03": { 1822 | "model_module": "@jupyter-widgets/controls", 1823 | "model_name": "HBoxModel", 1824 | "state": { 1825 | "_dom_classes": [], 1826 | "_model_module": "@jupyter-widgets/controls", 1827 | "_model_module_version": "1.5.0", 1828 | "_model_name": "HBoxModel", 1829 | "_view_count": null, 1830 | "_view_module": "@jupyter-widgets/controls", 1831 | "_view_module_version": "1.5.0", 1832 | "_view_name": "HBoxView", 1833 | "box_style": "", 1834 | "children": [ 1835 | "IPY_MODEL_13c5862559f54acea759078b34f48c4a", 1836 | "IPY_MODEL_a26348e6a0b44970b55bd5131ace7561" 1837 | ], 1838 | "layout": "IPY_MODEL_f1428d0ac3034701b0577b2b54814504" 1839 | } 1840 | }, 1841 | "df81ecfbacdb4cc194cdfeb913c58afd": { 1842 | "model_module": "@jupyter-widgets/base", 1843 | "model_name": "LayoutModel", 1844 | "state": { 1845 | "_model_module": "@jupyter-widgets/base", 1846 | "_model_module_version": "1.2.0", 1847 | "_model_name": "LayoutModel", 1848 | "_view_count": null, 1849 | "_view_module": "@jupyter-widgets/base", 1850 | "_view_module_version": "1.2.0", 1851 | "_view_name": "LayoutView", 1852 | "align_content": null, 1853 | "align_items": null, 1854 | "align_self": null, 1855 | "border": null, 1856 | "bottom": null, 1857 | "display": null, 1858 | "flex": "2", 1859 | "flex_flow": null, 1860 | "grid_area": null, 1861 | "grid_auto_columns": null, 1862 | "grid_auto_flow": null, 1863 | "grid_auto_rows": null, 1864 | "grid_column": null, 1865 | "grid_gap": null, 1866 | "grid_row": null, 1867 | "grid_template_areas": null, 1868 | "grid_template_columns": null, 1869 | "grid_template_rows": null, 1870 | "height": null, 1871 | "justify_content": null, 1872 | "justify_items": null, 1873 | "left": null, 1874 | "margin": null, 1875 | "max_height": null, 1876 | "max_width": null, 1877 | "min_height": null, 1878 | "min_width": null, 1879 | "object_fit": null, 1880 | "object_position": null, 1881 | "order": null, 1882 | "overflow": null, 1883 | "overflow_x": null, 1884 | "overflow_y": null, 1885 | "padding": null, 1886 | "right": null, 1887 | "top": null, 1888 | "visibility": null, 1889 | "width": null 1890 | } 1891 | }, 1892 | "e088f1f357b448059cbfb2fd081c124d": { 1893 | "model_module": "@jupyter-widgets/base", 1894 | 
"model_name": "LayoutModel", 1895 | "state": { 1896 | "_model_module": "@jupyter-widgets/base", 1897 | "_model_module_version": "1.2.0", 1898 | "_model_name": "LayoutModel", 1899 | "_view_count": null, 1900 | "_view_module": "@jupyter-widgets/base", 1901 | "_view_module_version": "1.2.0", 1902 | "_view_name": "LayoutView", 1903 | "align_content": null, 1904 | "align_items": null, 1905 | "align_self": null, 1906 | "border": null, 1907 | "bottom": null, 1908 | "display": null, 1909 | "flex": null, 1910 | "flex_flow": null, 1911 | "grid_area": null, 1912 | "grid_auto_columns": null, 1913 | "grid_auto_flow": null, 1914 | "grid_auto_rows": null, 1915 | "grid_column": null, 1916 | "grid_gap": null, 1917 | "grid_row": null, 1918 | "grid_template_areas": null, 1919 | "grid_template_columns": null, 1920 | "grid_template_rows": null, 1921 | "height": null, 1922 | "justify_content": null, 1923 | "justify_items": null, 1924 | "left": null, 1925 | "margin": null, 1926 | "max_height": null, 1927 | "max_width": null, 1928 | "min_height": null, 1929 | "min_width": null, 1930 | "object_fit": null, 1931 | "object_position": null, 1932 | "order": null, 1933 | "overflow": null, 1934 | "overflow_x": null, 1935 | "overflow_y": null, 1936 | "padding": null, 1937 | "right": null, 1938 | "top": null, 1939 | "visibility": null, 1940 | "width": null 1941 | } 1942 | }, 1943 | "e3d1641fd1bb45c2b91a598f9875ddf5": { 1944 | "model_module": "@jupyter-widgets/base", 1945 | "model_name": "LayoutModel", 1946 | "state": { 1947 | "_model_module": "@jupyter-widgets/base", 1948 | "_model_module_version": "1.2.0", 1949 | "_model_name": "LayoutModel", 1950 | "_view_count": null, 1951 | "_view_module": "@jupyter-widgets/base", 1952 | "_view_module_version": "1.2.0", 1953 | "_view_name": "LayoutView", 1954 | "align_content": null, 1955 | "align_items": null, 1956 | "align_self": null, 1957 | "border": null, 1958 | "bottom": null, 1959 | "display": null, 1960 | "flex": "2", 1961 | "flex_flow": null, 1962 | "grid_area": null, 1963 | "grid_auto_columns": null, 1964 | "grid_auto_flow": null, 1965 | "grid_auto_rows": null, 1966 | "grid_column": null, 1967 | "grid_gap": null, 1968 | "grid_row": null, 1969 | "grid_template_areas": null, 1970 | "grid_template_columns": null, 1971 | "grid_template_rows": null, 1972 | "height": null, 1973 | "justify_content": null, 1974 | "justify_items": null, 1975 | "left": null, 1976 | "margin": null, 1977 | "max_height": null, 1978 | "max_width": null, 1979 | "min_height": null, 1980 | "min_width": null, 1981 | "object_fit": null, 1982 | "object_position": null, 1983 | "order": null, 1984 | "overflow": null, 1985 | "overflow_x": null, 1986 | "overflow_y": null, 1987 | "padding": null, 1988 | "right": null, 1989 | "top": null, 1990 | "visibility": null, 1991 | "width": null 1992 | } 1993 | }, 1994 | "ea515854fbd444ad9c10faeaad0db444": { 1995 | "model_module": "@jupyter-widgets/controls", 1996 | "model_name": "DescriptionStyleModel", 1997 | "state": { 1998 | "_model_module": "@jupyter-widgets/controls", 1999 | "_model_module_version": "1.5.0", 2000 | "_model_name": "DescriptionStyleModel", 2001 | "_view_count": null, 2002 | "_view_module": "@jupyter-widgets/base", 2003 | "_view_module_version": "1.2.0", 2004 | "_view_name": "StyleView", 2005 | "description_width": "" 2006 | } 2007 | }, 2008 | "eabe50a8db2745cda833c6bbc3f6428b": { 2009 | "model_module": "@jupyter-widgets/base", 2010 | "model_name": "LayoutModel", 2011 | "state": { 2012 | "_model_module": "@jupyter-widgets/base", 2013 | 
"_model_module_version": "1.2.0", 2014 | "_model_name": "LayoutModel", 2015 | "_view_count": null, 2016 | "_view_module": "@jupyter-widgets/base", 2017 | "_view_module_version": "1.2.0", 2018 | "_view_name": "LayoutView", 2019 | "align_content": null, 2020 | "align_items": null, 2021 | "align_self": null, 2022 | "border": null, 2023 | "bottom": null, 2024 | "display": "inline-flex", 2025 | "flex": null, 2026 | "flex_flow": "row wrap", 2027 | "grid_area": null, 2028 | "grid_auto_columns": null, 2029 | "grid_auto_flow": null, 2030 | "grid_auto_rows": null, 2031 | "grid_column": null, 2032 | "grid_gap": null, 2033 | "grid_row": null, 2034 | "grid_template_areas": null, 2035 | "grid_template_columns": null, 2036 | "grid_template_rows": null, 2037 | "height": null, 2038 | "justify_content": null, 2039 | "justify_items": null, 2040 | "left": null, 2041 | "margin": null, 2042 | "max_height": null, 2043 | "max_width": null, 2044 | "min_height": null, 2045 | "min_width": null, 2046 | "object_fit": null, 2047 | "object_position": null, 2048 | "order": null, 2049 | "overflow": null, 2050 | "overflow_x": null, 2051 | "overflow_y": null, 2052 | "padding": null, 2053 | "right": null, 2054 | "top": null, 2055 | "visibility": null, 2056 | "width": "100%" 2057 | } 2058 | }, 2059 | "ebc1d5f7397c4f30951367350704eb1e": { 2060 | "model_module": "@jupyter-widgets/controls", 2061 | "model_name": "HBoxModel", 2062 | "state": { 2063 | "_dom_classes": [], 2064 | "_model_module": "@jupyter-widgets/controls", 2065 | "_model_module_version": "1.5.0", 2066 | "_model_name": "HBoxModel", 2067 | "_view_count": null, 2068 | "_view_module": "@jupyter-widgets/controls", 2069 | "_view_module_version": "1.5.0", 2070 | "_view_name": "HBoxView", 2071 | "box_style": "", 2072 | "children": [ 2073 | "IPY_MODEL_05daa3a8a4494329ac5a21a31dc556c0", 2074 | "IPY_MODEL_a70c774f906a4875abc81841fcfc77c8" 2075 | ], 2076 | "layout": "IPY_MODEL_fd1fc78d047e4a7eb6aa827f9ebc9a9d" 2077 | } 2078 | }, 2079 | "ef574dcb4bc04fcf94ceddb3f4dbd4a3": { 2080 | "model_module": "@jupyter-widgets/controls", 2081 | "model_name": "DescriptionStyleModel", 2082 | "state": { 2083 | "_model_module": "@jupyter-widgets/controls", 2084 | "_model_module_version": "1.5.0", 2085 | "_model_name": "DescriptionStyleModel", 2086 | "_view_count": null, 2087 | "_view_module": "@jupyter-widgets/base", 2088 | "_view_module_version": "1.2.0", 2089 | "_view_name": "StyleView", 2090 | "description_width": "" 2091 | } 2092 | }, 2093 | "f1428d0ac3034701b0577b2b54814504": { 2094 | "model_module": "@jupyter-widgets/base", 2095 | "model_name": "LayoutModel", 2096 | "state": { 2097 | "_model_module": "@jupyter-widgets/base", 2098 | "_model_module_version": "1.2.0", 2099 | "_model_name": "LayoutModel", 2100 | "_view_count": null, 2101 | "_view_module": "@jupyter-widgets/base", 2102 | "_view_module_version": "1.2.0", 2103 | "_view_name": "LayoutView", 2104 | "align_content": null, 2105 | "align_items": null, 2106 | "align_self": null, 2107 | "border": null, 2108 | "bottom": null, 2109 | "display": "inline-flex", 2110 | "flex": null, 2111 | "flex_flow": "row wrap", 2112 | "grid_area": null, 2113 | "grid_auto_columns": null, 2114 | "grid_auto_flow": null, 2115 | "grid_auto_rows": null, 2116 | "grid_column": null, 2117 | "grid_gap": null, 2118 | "grid_row": null, 2119 | "grid_template_areas": null, 2120 | "grid_template_columns": null, 2121 | "grid_template_rows": null, 2122 | "height": null, 2123 | "justify_content": null, 2124 | "justify_items": null, 2125 | "left": null, 2126 | "margin": 
null, 2127 | "max_height": null, 2128 | "max_width": null, 2129 | "min_height": null, 2130 | "min_width": null, 2131 | "object_fit": null, 2132 | "object_position": null, 2133 | "order": null, 2134 | "overflow": null, 2135 | "overflow_x": null, 2136 | "overflow_y": null, 2137 | "padding": null, 2138 | "right": null, 2139 | "top": null, 2140 | "visibility": null, 2141 | "width": "100%" 2142 | } 2143 | }, 2144 | "f328795ccb1a48f38a5d261bdbd02623": { 2145 | "model_module": "@jupyter-widgets/base", 2146 | "model_name": "LayoutModel", 2147 | "state": { 2148 | "_model_module": "@jupyter-widgets/base", 2149 | "_model_module_version": "1.2.0", 2150 | "_model_name": "LayoutModel", 2151 | "_view_count": null, 2152 | "_view_module": "@jupyter-widgets/base", 2153 | "_view_module_version": "1.2.0", 2154 | "_view_name": "LayoutView", 2155 | "align_content": null, 2156 | "align_items": null, 2157 | "align_self": null, 2158 | "border": null, 2159 | "bottom": null, 2160 | "display": "inline-flex", 2161 | "flex": null, 2162 | "flex_flow": "row wrap", 2163 | "grid_area": null, 2164 | "grid_auto_columns": null, 2165 | "grid_auto_flow": null, 2166 | "grid_auto_rows": null, 2167 | "grid_column": null, 2168 | "grid_gap": null, 2169 | "grid_row": null, 2170 | "grid_template_areas": null, 2171 | "grid_template_columns": null, 2172 | "grid_template_rows": null, 2173 | "height": null, 2174 | "justify_content": null, 2175 | "justify_items": null, 2176 | "left": null, 2177 | "margin": null, 2178 | "max_height": null, 2179 | "max_width": null, 2180 | "min_height": null, 2181 | "min_width": null, 2182 | "object_fit": null, 2183 | "object_position": null, 2184 | "order": null, 2185 | "overflow": null, 2186 | "overflow_x": null, 2187 | "overflow_y": null, 2188 | "padding": null, 2189 | "right": null, 2190 | "top": null, 2191 | "visibility": null, 2192 | "width": "100%" 2193 | } 2194 | }, 2195 | "f374fd5878c247338ada58c76ea6a4da": { 2196 | "model_module": "@jupyter-widgets/base", 2197 | "model_name": "LayoutModel", 2198 | "state": { 2199 | "_model_module": "@jupyter-widgets/base", 2200 | "_model_module_version": "1.2.0", 2201 | "_model_name": "LayoutModel", 2202 | "_view_count": null, 2203 | "_view_module": "@jupyter-widgets/base", 2204 | "_view_module_version": "1.2.0", 2205 | "_view_name": "LayoutView", 2206 | "align_content": null, 2207 | "align_items": null, 2208 | "align_self": null, 2209 | "border": null, 2210 | "bottom": null, 2211 | "display": "inline-flex", 2212 | "flex": null, 2213 | "flex_flow": "row wrap", 2214 | "grid_area": null, 2215 | "grid_auto_columns": null, 2216 | "grid_auto_flow": null, 2217 | "grid_auto_rows": null, 2218 | "grid_column": null, 2219 | "grid_gap": null, 2220 | "grid_row": null, 2221 | "grid_template_areas": null, 2222 | "grid_template_columns": null, 2223 | "grid_template_rows": null, 2224 | "height": null, 2225 | "justify_content": null, 2226 | "justify_items": null, 2227 | "left": null, 2228 | "margin": null, 2229 | "max_height": null, 2230 | "max_width": null, 2231 | "min_height": null, 2232 | "min_width": null, 2233 | "object_fit": null, 2234 | "object_position": null, 2235 | "order": null, 2236 | "overflow": null, 2237 | "overflow_x": null, 2238 | "overflow_y": null, 2239 | "padding": null, 2240 | "right": null, 2241 | "top": null, 2242 | "visibility": null, 2243 | "width": "100%" 2244 | } 2245 | }, 2246 | "fcdc7f0094274ddf93776dc81726318d": { 2247 | "model_module": "@jupyter-widgets/controls", 2248 | "model_name": "ProgressStyleModel", 2249 | "state": { 2250 | "_model_module": 
"@jupyter-widgets/controls", 2251 | "_model_module_version": "1.5.0", 2252 | "_model_name": "ProgressStyleModel", 2253 | "_view_count": null, 2254 | "_view_module": "@jupyter-widgets/base", 2255 | "_view_module_version": "1.2.0", 2256 | "_view_name": "StyleView", 2257 | "bar_color": null, 2258 | "description_width": "initial" 2259 | } 2260 | }, 2261 | "fd1fc78d047e4a7eb6aa827f9ebc9a9d": { 2262 | "model_module": "@jupyter-widgets/base", 2263 | "model_name": "LayoutModel", 2264 | "state": { 2265 | "_model_module": "@jupyter-widgets/base", 2266 | "_model_module_version": "1.2.0", 2267 | "_model_name": "LayoutModel", 2268 | "_view_count": null, 2269 | "_view_module": "@jupyter-widgets/base", 2270 | "_view_module_version": "1.2.0", 2271 | "_view_name": "LayoutView", 2272 | "align_content": null, 2273 | "align_items": null, 2274 | "align_self": null, 2275 | "border": null, 2276 | "bottom": null, 2277 | "display": "inline-flex", 2278 | "flex": null, 2279 | "flex_flow": "row wrap", 2280 | "grid_area": null, 2281 | "grid_auto_columns": null, 2282 | "grid_auto_flow": null, 2283 | "grid_auto_rows": null, 2284 | "grid_column": null, 2285 | "grid_gap": null, 2286 | "grid_row": null, 2287 | "grid_template_areas": null, 2288 | "grid_template_columns": null, 2289 | "grid_template_rows": null, 2290 | "height": null, 2291 | "justify_content": null, 2292 | "justify_items": null, 2293 | "left": null, 2294 | "margin": null, 2295 | "max_height": null, 2296 | "max_width": null, 2297 | "min_height": null, 2298 | "min_width": null, 2299 | "object_fit": null, 2300 | "object_position": null, 2301 | "order": null, 2302 | "overflow": null, 2303 | "overflow_x": null, 2304 | "overflow_y": null, 2305 | "padding": null, 2306 | "right": null, 2307 | "top": null, 2308 | "visibility": null, 2309 | "width": "100%" 2310 | } 2311 | } 2312 | } 2313 | } 2314 | }, 2315 | "nbformat": 4, 2316 | "nbformat_minor": 1 2317 | } 2318 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [build-system] 2 | requires = ["setuptools", "wheel"] 3 | build-backend = "setuptools.build_meta" 4 | 5 | [tool.pytest.ini_options] 6 | norecursedirs = [ 7 | ".git", 8 | ".github", 9 | "dist", 10 | "build", 11 | "docs", 12 | ] 13 | addopts = [ 14 | "--strict-markers", 15 | "--doctest-modules", 16 | "--color=yes", 17 | "--disable-pytest-warnings", 18 | ] 19 | filterwarnings = [ 20 | "error::FutureWarning", 21 | ] 22 | xfail_strict = true 23 | junit_duration_report = "call" 24 | 25 | [tool.coverage.report] 26 | exclude_lines = [ 27 | "pragma: no cover", 28 | "pass", 29 | ] 30 | 31 | [tool.ruff] 32 | target-version = "py38" 33 | line-length = 120 34 | 35 | # Unlike Flake8, default to a complexity level of 10. 36 | lint.mccabe.max-complexity = 10 37 | # Use Google-style docstrings. 
38 | lint.pydocstyle.convention = "google" 39 | format.preview = true 40 | lint.select = [ 41 | "E", 42 | "F", # see: https://pypi.org/project/pyflakes 43 | "I", # see: https://pypi.org/project/isort 44 | "S", # see: https://pypi.org/project/flake8-bandit 45 | "UP", # see: https://docs.astral.sh/ruff/rules/#pyupgrade-up 46 | # "D", # see: https://pypi.org/project/pydocstyle 47 | "W", # see: https://pypi.org/project/pycodestyle 48 | ] 49 | lint.extend-select = [ 50 | # "C4", # see: https://pypi.org/project/flake8-comprehensions 51 | "PLE", # see: https://pypi.org/project/pylint/ 52 | "PT", # see: https://pypi.org/project/flake8-pytest-style 53 | "RET", # see: https://pypi.org/project/flake8-return 54 | "RUF100", # alternative to yesqa 55 | "SIM", # see: https://pypi.org/project/flake8-simplify 56 | ] 57 | lint.ignore = [ 58 | "S101", # todo: Use of `assert` detected 59 | ] 60 | [tool.ruff.lint.per-file-ignores] 61 | "setup.py" = ["D100", "SIM115"] 62 | "notebooks/**" = [ 63 | "E501", "F401", "F821", 64 | "SIM115", # todo 65 | ] 66 | "scripts/**" = [ 67 | "S", "D" 68 | ] 69 | "tests/**" = [ 70 | "S", "D" 71 | ] 72 | -------------------------------------------------------------------------------- /scripts/birdclef_convert-spectrograms.py: -------------------------------------------------------------------------------- 1 | import glob 2 | import os 3 | from functools import partial 4 | 5 | import fire 6 | import matplotlib.pyplot as plt 7 | import numpy as np 8 | import pandas as pd 9 | from joblib import Parallel, delayed 10 | from tqdm.auto import tqdm 11 | 12 | from kaggle_imgclassif.birdclef.data import convert_and_export 13 | 14 | 15 | def _color_means(img_path): 16 | img = plt.imread(img_path) 17 | if np.max(img) > 1.5: 18 | img = img / 255.0 19 | clr_mean = np.mean(img) if img.ndim == 2 else {i: np.mean(img[..., i]) for i in range(3)} 20 | clr_std = np.std(img) if img.ndim == 2 else {i: np.std(img[..., i]) for i in range(3)} 21 | return clr_mean, clr_std 22 | 23 | 24 | def main( 25 | path_dataset: str, reduce_noise: bool = False, img_extension: str = ".png", img_size: int = 256, n_jobs: int = 12 26 | ): 27 | train_meta = pd.read_csv(os.path.join(path_dataset, "train_metadata.csv")).sample(frac=1) 28 | print(train_meta.head()) 29 | 30 | _convert_and_export = partial( 31 | convert_and_export, 32 | path_in=os.path.join(path_dataset, "train_audio"), 33 | path_out=os.path.join(path_dataset, "train_images"), 34 | reduce_noise=reduce_noise, 35 | img_extension=img_extension, 36 | img_size=img_size, 37 | ) 38 | 39 | _ = Parallel(n_jobs=n_jobs)(delayed(_convert_and_export)(fn) for fn in tqdm(train_meta["filename"])) 40 | # _= list(map(_convert_and_export, tqdm(train_meta["filename"]))) 41 | 42 | images = glob.glob(os.path.join(path_dataset, "train_images", "*", "*" + img_extension)) 43 | clr_mean_std = Parallel(n_jobs=n_jobs)(delayed(_color_means)(fn) for fn in tqdm(images)) 44 | img_color_mean = pd.DataFrame([c[0] for c in clr_mean_std]).describe() 45 | print(img_color_mean.T) 46 | img_color_std = pd.DataFrame([c[1] for c in clr_mean_std]).describe() 47 | print(img_color_std.T) 48 | img_color_mean = list(img_color_mean.T["mean"]) 49 | img_color_std = list(img_color_std.T["mean"]) 50 | print(f"MEAN: {img_color_mean}\n STD: {img_color_std}") 51 | 52 | 53 | if __name__ == "__main__": 54 | fire.Fire(main) 55 | -------------------------------------------------------------------------------- /scripts/herbarium_train-model.py:
-------------------------------------------------------------------------------- 1 | """Sample execution on A100. 2 | 3 | >> python3 scripts/herbarium_train-model.py --gpus 6 --max_epochs 30 --val_split 0.05 \ 4 | --learning_rate 0.01 --model_backbone convnext_base_384_in22ft1k --image_size 384 --model_pretrained True \ 5 | --batch_size 72 --label_smoothing None --accumulate_grad_batches=12 6 | 7 | >> python3 scripts/herbarium_train-model.py --gpus 6 --max_epochs 30 --val_split 0.05 \ 8 | --learning_rate 0.001 --model_backbone dm_nfnet_f3 --image_size 416 --model_pretrained True \ 9 | --batch_size 18 --accumulate_grad_batches=48 10 | """ 11 | 12 | import json 13 | import os 14 | from dataclasses import dataclass 15 | from typing import Any, Callable, Dict, Optional, Tuple 16 | 17 | import fire 18 | import flash 19 | import pandas as pd 20 | import torch 21 | from flash.core.data.io.input_transform import InputTransform 22 | from flash.image import ImageClassificationData, ImageClassifier 23 | from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint, StochasticWeightAveraging 24 | from pytorch_lightning.loggers import WandbLogger 25 | from sklearn.model_selection import train_test_split 26 | from timm.loss import AsymmetricLossSingleLabel, LabelSmoothingCrossEntropy 27 | from torchmetrics import F1Score 28 | from torchvision import transforms as T 29 | 30 | 31 | @dataclass 32 | class ImageClassificationInputTransform(InputTransform): 33 | image_size: Tuple[int, int] = (224, 224) 34 | color_mean: Tuple[float, float, float] = (0.781, 0.759, 0.710) 35 | color_std: Tuple[float, float, float] = (0.241, 0.245, 0.249) 36 | 37 | def input_per_sample_transform(self): 38 | return T.Compose([ 39 | T.ToTensor(), 40 | T.Resize(self.image_size), 41 | T.Normalize(self.color_mean, self.color_std), 42 | ]) 43 | 44 | def train_input_per_sample_transform(self): 45 | return T.Compose([ 46 | T.TrivialAugmentWide(), 47 | T.RandomPosterize(bits=6), 48 | T.RandomEqualize(), 49 | T.ToTensor(), 50 | T.Resize(self.image_size), 51 | T.RandomHorizontalFlip(), 52 | # T.ColorJitter(brightness=0.2, hue=0.1), 53 | T.RandomAutocontrast(), 54 | T.RandomAdjustSharpness(sharpness_factor=2), 55 | T.GaussianBlur(kernel_size=(5, 9), sigma=(0.1, 5)), 56 | T.RandomAffine(degrees=10, scale=(0.9, 1.1), translate=(0.1, 0.1)), 57 | # T.RandomPerspective(distortion_scale=0.1), 58 | T.Normalize(self.color_mean, self.color_std), 59 | ]) 60 | 61 | def target_per_sample_transform(self) -> Callable: 62 | return torch.as_tensor 63 | 64 | 65 | def load_df_train(dataset_dir: str) -> pd.DataFrame: 66 | with open(os.path.join(dataset_dir, "train_metadata.json")) as fp: 67 | train_data = json.load(fp) 68 | train_annotations = pd.DataFrame(train_data["annotations"]) 69 | train_images = pd.DataFrame(train_data["images"]).set_index("image_id") 70 | train_categories = pd.DataFrame(train_data["categories"]).set_index("category_id") 71 | train_institutions = pd.DataFrame(train_data["institutions"]).set_index("institution_id") 72 | df_train = pd.merge(train_annotations, train_images, how="left", right_index=True, left_on="image_id") 73 | df_train = pd.merge(df_train, train_categories, how="left", right_index=True, left_on="category_id") 74 | df_train = pd.merge(df_train, train_institutions, how="left", right_index=True, left_on="institution_id") 75 | df_train["file_name"] = df_train["file_name"].apply(lambda p: os.path.join("train_images", p)) 76 | return df_train 77 | 78 | 79 | def append_predictions(df_train: pd.DataFrame, path_csv: str = 
None) -> pd.DataFrame: 80 | if not path_csv: 81 | return df_train 82 | if not os.path.isfile(path_csv): 83 | raise FileNotFoundError(f"Missing predictions: {path_csv}") 84 | df_preds = pd.read_csv(path_csv) 85 | df_preds["file_name"] = df_preds["file_name"].apply(lambda p: os.path.join("test_images", p)) 86 | df_train = df_train.append(df_preds) 87 | return df_train 88 | 89 | 90 | def inference( 91 | model, df_test: pd.DataFrame, dataset_dir: str, image_size: int, batch_size: int, gpus: int = 0 92 | ) -> pd.DataFrame: 93 | print(f"inference for {len(df_test)} images") 94 | print(df_test.head()) 95 | 96 | datamodule = ImageClassificationData.from_data_frame( 97 | input_field="file_name", 98 | # target_fields="category_id", 99 | predict_data_frame=df_test, 100 | # for simplicity take just fraction of the data 101 | # predict_data_frame=test_images[:len(test_images) // 100], 102 | predict_images_root=os.path.join(dataset_dir, "test_images"), 103 | predict_transform=ImageClassificationInputTransform, 104 | batch_size=batch_size, 105 | transform_kwargs={"image_size": (image_size, image_size)}, 106 | num_workers=batch_size, 107 | ) 108 | 109 | trainer = flash.Trainer(gpus=min(gpus, 1)) 110 | 111 | predictions = [] 112 | for lbs in trainer.predict(model, datamodule=datamodule, output="labels"): 113 | # lbs = [torch.argmax(p["preds"].float()).item() for p in preds] 114 | predictions += lbs 115 | 116 | print(f"Predictions: {len(predictions)} & Test images: {len(df_test)}") 117 | df_test["category_id"] = predictions 118 | return df_test 119 | 120 | 121 | def main( 122 | dataset_dir: str = "/home/jirka/Datasets/herbarium-2022-fgvc9", 123 | checkpoints_dir: str = "/home/jirka/Workspace/checkpoints_herbarium-flash", 124 | predict_csv: str = None, 125 | model_backbone: str = "efficientnet_b3", 126 | model_pretrained: bool = False, 127 | image_size: int = 320, 128 | optimizer: str = "AdamW", 129 | lr_scheduler: Optional[str] = None, 130 | learning_rate: float = 5e-3, 131 | label_smoothing: float = 0.01, 132 | batch_size: int = 24, 133 | max_epochs: int = 20, 134 | gpus: int = 1, 135 | val_split: float = 0.1, 136 | early_stopping: Optional[float] = None, 137 | swa: Optional[float] = None, 138 | num_workers: int = None, 139 | run_inference: bool = True, 140 | **trainer_kwargs: Dict[str, Any], 141 | ) -> None: 142 | print(f"Additional Trainer args: {trainer_kwargs}") 143 | df_train = load_df_train(dataset_dir) 144 | 145 | with open(os.path.join(dataset_dir, "test_metadata.json")) as fp: 146 | test_data = json.load(fp) 147 | df_test = pd.DataFrame(test_data).set_index("image_id") 148 | 149 | # ToDo 150 | # df_counts = df_train.groupby("category_id").size() 151 | # labels = list(df_counts.index) 152 | # sampler = WeightedRandomSampler(torch.from_numpy(1.
/ df_counts.values), len(df_counts)) 153 | 154 | df_train, df_val = train_test_split(df_train, test_size=val_split, stratify=df_train["category_id"].tolist()) 155 | # noisy predictions shall not be in validation 156 | df_train = append_predictions(df_train, path_csv=predict_csv) 157 | 158 | datamodule = ImageClassificationData.from_data_frame( 159 | input_field="file_name", 160 | target_fields="category_id", 161 | # for simplicity take just half of the data 162 | # train_data_frame=df_train[:len(df_train) // 2], 163 | train_data_frame=df_train, 164 | train_images_root=dataset_dir, 165 | val_data_frame=df_val, 166 | val_images_root=dataset_dir, 167 | transform=ImageClassificationInputTransform, 168 | transform_kwargs={"image_size": (image_size, image_size)}, 169 | batch_size=batch_size, 170 | num_workers=num_workers if num_workers else min(batch_size, int(os.cpu_count() / gpus)), 171 | # sampler=sampler, 172 | ) 173 | 174 | loss = LabelSmoothingCrossEntropy(label_smoothing) if label_smoothing else AsymmetricLossSingleLabel() 175 | 176 | model = ImageClassifier( 177 | backbone=model_backbone, 178 | metrics=F1Score(num_classes=datamodule.num_classes, average="macro"), 179 | pretrained=model_pretrained, 180 | loss_fn=loss, 181 | optimizer=optimizer, 182 | learning_rate=learning_rate, 183 | lr_scheduler=lr_scheduler, 184 | num_classes=datamodule.num_classes, 185 | ) 186 | 187 | # Trainer Args 188 | logger = WandbLogger(project="Flash_tract-image-segmentation") 189 | log_id = str(logger.experiment.id) 190 | monitor = "val_f1score" 191 | cbs = [ModelCheckpoint(dirpath=checkpoints_dir, filename=f"{log_id}", monitor=monitor, mode="max", verbose=True)] 192 | if early_stopping is not None: 193 | cbs.append(EarlyStopping(monitor=monitor, min_delta=early_stopping, mode="max", verbose=True)) 194 | if isinstance(swa, float): 195 | cbs.append(StochasticWeightAveraging(swa_epoch_start=swa)) 196 | 197 | trainer_flags = dict( 198 | callbacks=cbs, 199 | max_epochs=max_epochs, 200 | precision="bf16" if gpus else 32, 201 | gpus=gpus, 202 | accelerator="ddp" if gpus > 1 else None, 203 | logger=logger, 204 | gradient_clip_val=1e-2, 205 | ) 206 | trainer_flags.update(trainer_kwargs) 207 | trainer = flash.Trainer(**trainer_flags) 208 | 209 | # Train the model 210 | # trainer.finetune(model, datamodule=datamodule, strategy="no_freeze") 211 | trainer.finetune(model, datamodule=datamodule, strategy=("freeze_unfreeze", 2)) 212 | 213 | # Save the model! 
214 | checkpoint_name = f"herbarium-classif-{log_id}_{model_backbone}-{image_size}px.pt" 215 | trainer.save_checkpoint(os.path.join(checkpoints_dir, checkpoint_name)) 216 | 217 | if run_inference and trainer.is_global_zero: 218 | df_preds = inference( 219 | model, df_test, dataset_dir=dataset_dir, image_size=image_size, batch_size=batch_size, gpus=gpus 220 | ) 221 | preds_name = f"predictions_herbarium-{log_id}_{model_backbone}-{image_size}.csv" 222 | df_preds.to_csv(os.path.join(checkpoints_dir, preds_name)) 223 | submission = pd.DataFrame({"Id": df_preds.index, "Predicted": df_preds["category_id"]}).set_index("Id") 224 | submission_name = f"submission_herbarium-{log_id}_{model_backbone}-{image_size}.csv" 225 | submission.to_csv(os.path.join(checkpoints_dir, submission_name)) 226 | 227 | 228 | if __name__ == "__main__": 229 | fire.Fire(main) 230 | -------------------------------------------------------------------------------- /scripts/imet_create-dataset-subset.py: -------------------------------------------------------------------------------- 1 | """Create a subset with more frequent labels. 2 | 3 | > python scripts/imet_create-dataset-subset.py TEMP/train-from-kaggle.csv --count_thr 1500 4 | """ 5 | 6 | import itertools 7 | import os 8 | 9 | import fire 10 | import numpy as np 11 | import pandas as pd 12 | 13 | 14 | def main(path_csv: str = "train-from-kaggle.csv", col_labels: str = "attribute_ids", count_thr: int = 1000): 15 | print(f"Loading: {path_csv}") 16 | df_train = pd.read_csv(path_csv) 17 | print(f"Samples: {len(df_train)}") 18 | labels_all = list(itertools.chain(*[[int(lb) for lb in lbs.split(" ")] for lbs in df_train[col_labels]])) 19 | lb_hist = dict(zip(range(max(labels_all) + 1), np.bincount(labels_all))) 20 | print(f"Filter: {count_thr}") 21 | df_hist = pd.DataFrame([dict(lb=lb, count=count) for lb, count in lb_hist.items() if count > count_thr]).set_index( 22 | "lb" 23 | ) 24 | print(f"Reductions: {len(lb_hist)} >> {len(df_hist)}") 25 | 26 | allowed_lbs = set(list(df_hist.index)) 27 | df_train[col_labels] = [ 28 | " ".join([lb for lb in lbs.split() if int(lb) in allowed_lbs]) for lbs in df_train[col_labels] 29 | ] 30 | df_train[col_labels].replace("", np.nan, inplace=True) 31 | df_train.dropna(subset=[col_labels], inplace=True) 32 | print(f"Samples: {len(df_train)}") 33 | name_csv, _ = os.path.splitext(os.path.basename(path_csv)) 34 | path_csv = os.path.join(os.path.dirname(path_csv), f"{name_csv}_min-lb-sample-{count_thr}.csv") 35 | df_train.to_csv(path_csv) 36 | 37 | labels_all = list(itertools.chain(*[[int(lb) for lb in lbs.split(" ")] for lbs in df_train[col_labels]])) 38 | print(f"sanity check - nb labels: {len(set(labels_all))}") 39 | 40 | 41 | if __name__ == "__main__": 42 | fire.Fire(main) 43 | -------------------------------------------------------------------------------- /scripts/plant-pathology_train-model.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # 3 | # # Kaggle: Plant Pathology 2021 - FGVC8 4 | # 5 | # > python plant-pathology_train-model.py \ 6 | # --model.model resnet34 \ 7 | # --data.base_path /mnt/69B27B700DDA7D73/Datasets/plant-pathology-2021-640px/ 8 | # 9 | 10 | import torch 11 | from pytorch_lightning.utilities.cli import LightningCLI 12 | 13 | from kaggle_imgclassif.plant_pathology.data import PlantPathologyDM 14 | from kaggle_imgclassif.plant_pathology.models import MultiPlantPathology 15 | 16 | TRAINER_DEFAULTS = dict( 17 | gpus=1, 18 | max_epochs=25, 19 | precision=16, 20 |
accumulate_grad_batches=10, 21 | val_check_interval=0.5, 22 | progress_bar_refresh_rate=1, 23 | weights_summary="top", 24 | auto_scale_batch_size="binsearch", 25 | ) 26 | 27 | 28 | class TuneFitCLI(LightningCLI): 29 | def before_fit(self) -> None: 30 | """Implement to run some code before fit is started.""" 31 | res = self.trainer.tune(**self.fit_kwargs, scale_batch_size_kwargs=dict(max_trials=5)) 32 | self.instantiate_classes() 33 | torch.cuda.empty_cache() 34 | self.datamodule.batch_size = int(res["scale_batch_size"] * 0.9) 35 | 36 | 37 | if __name__ == "__main__": 38 | cli = TuneFitCLI( 39 | model_class=MultiPlantPathology, 40 | datamodule_class=PlantPathologyDM, 41 | trainer_defaults=TRAINER_DEFAULTS, 42 | seed_everything_default=42, 43 | ) 44 | -------------------------------------------------------------------------------- /setup.cfg: -------------------------------------------------------------------------------- 1 | [metadata] 2 | name = kaggle-image-classification 3 | version = 2022.05 4 | author = Jiri Borovec 5 | author_email = jirka@pytorchlightning.ai 6 | url = https://github.com/Borda/kaggle_image-classify 7 | description = Tooling for Kaggle image classification challenges 8 | description-file = README.md 9 | long_description = file: README.md, LICENSE 10 | long_description_content_type = text/markdown 11 | keywords = image, classification, kaggle, challenge 12 | license = BSD 3-Clause License 13 | license_file = LICENSE 14 | classifiers = 15 | Environment :: Console 16 | Natural Language :: English 17 | # How mature is this project? Common values are 18 | # 3 - Alpha, 4 - Beta, 5 - Production/Stable 19 | Development Status :: 3 - Alpha 20 | # Indicate who your project is intended for 21 | Intended Audience :: Developers 22 | Topic :: Scientific/Engineering :: Artificial Intelligence 23 | Topic :: Scientific/Engineering :: Image Recognition 24 | Topic :: Scientific/Engineering :: Information Analysis 25 | # Pick your license as you wish 26 | # 'License :: OSI Approved :: BSD License', 27 | Operating System :: OS Independent 28 | # Specify the Python versions you support here. In particular, ensure 29 | # that you indicate whether you support Python 2, Python 3 or both. 
30 | Programming Language :: Python :: 3 31 | 32 | [options] 33 | python_requires = >=3.8 34 | zip_safe = False 35 | include_package_data = True 36 | packages = find: 37 | install_requires = 38 | Pillow >=8.2 39 | torch >=1.8.1, <2.0 40 | torchmetrics >=0.7.0, <0.11.0 41 | pytorch-lightning >=1.5.0, <2.0 42 | torchvision 43 | timm >=0.5 44 | pandas 45 | matplotlib 46 | scikit-learn >=1.0 47 | seaborn 48 | joblib 49 | tqdm 50 | fire 51 | 52 | ;[options.package_data] 53 | ;* = *.txt, *.rst 54 | ;hello = *.msg 55 | 56 | ;[options.entry_points] 57 | ;console_scripts = 58 | ; executable-name = my_package.module:function 59 | 60 | [options.extras_require] 61 | app = 62 | streamlit 63 | gdown 64 | test = 65 | codecov >=2.1 66 | pytest >=6.0 67 | pytest-cov >2.10 68 | twine >=4.0 69 | plant_pathology = 70 | kornia >=0.5.2 71 | imet_collect = 72 | opencv-python 73 | birdclef = 74 | lightning-flash[audio] 75 | noisereduce 76 | librosa 77 | 78 | [options.packages.find] 79 | exclude = 80 | docs* 81 | notebooks* 82 | tests* 83 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup 2 | 3 | if __name__ == "__main__": 4 | setup() 5 | -------------------------------------------------------------------------------- /streamlit-app.py: -------------------------------------------------------------------------------- 1 | """Simple StreamLit app for plant classification. 2 | 3 | >> streamlit run streamlit-app.py 4 | """ 5 | 6 | import os 7 | 8 | import gdown 9 | import numpy as np 10 | import streamlit as st 11 | import torch 12 | from PIL import Image 13 | 14 | from kaggle_imgclassif.plant_pathology.augment import TORCHVISION_VALID_TRANSFORM 15 | from kaggle_imgclassif.plant_pathology.data import PlantPathologyDM 16 | from kaggle_imgclassif.plant_pathology.models import LitPlantPathology, MultiPlantPathology 17 | 18 | MODEL_PATH_GDRIVE = "https://drive.google.com/uc?id=1bynbFW0FpIt7fnqzImu2UIM1PHb9-yjw" 19 | MODEL_PATH_LOCAL = "fgvc8_resnet50.pt" 20 | UNIQUE_LABELS = ("scab", "rust", "complex", "frog_eye_leaf_spot", "powdery_mildew", "cider_apple_rust", "healthy") 21 | LUT_LABELS = dict(enumerate(sorted(UNIQUE_LABELS))) 22 | 23 | 24 | @st.cache(allow_output_mutation=True) 25 | def get_model(model_path: str = MODEL_PATH_LOCAL) -> LitPlantPathology: 26 | if not os.path.isfile(model_path): 27 | # download models if it missing locally 28 | gdown.download(MODEL_PATH_GDRIVE, model_path, quiet=False) 29 | 30 | net = torch.load(model_path) 31 | model = MultiPlantPathology(model=net) 32 | return model.eval() 33 | 34 | 35 | def process_image( 36 | model: LitPlantPathology, 37 | img_path: str = "tests/_data/plant-pathology/test_images/8a0d7cad7053f18d.jpg", 38 | streamlit_app: bool = False, 39 | ): 40 | if not img_path: 41 | return 42 | 43 | img = Image.open(img_path) 44 | if streamlit_app: 45 | st.image(img) 46 | 47 | img = TORCHVISION_VALID_TRANSFORM(img) 48 | 49 | with torch.no_grad(): 50 | encode = model(img.unsqueeze(0))[0] 51 | # process classification outputs 52 | binary = np.round(encode.detach().numpy(), decimals=2) 53 | labels = PlantPathologyDM.binary_mapping(encode, LUT_LABELS) 54 | 55 | if streamlit_app: 56 | st.write(", ".join(labels)) 57 | else: 58 | print(f"Binary: {binary} >> {labels}") 59 | 60 | 61 | if __name__ == "__main__": 62 | st.set_option("deprecation.showfileUploaderEncoding", False) 63 | 64 | # Upload an image and set some options for demo 
purposes 65 | st.header("Plant Pathology Demo") 66 | img_file = st.sidebar.file_uploader(label="Upload an image", type=["png", "jpg"]) 67 | 68 | # load model and ideally use cache version to speedup 69 | model = get_model() 70 | 71 | # run the app 72 | process_image(model, img_file, streamlit_app=True) 73 | # process_image(model) # dry rn with locals 74 | -------------------------------------------------------------------------------- /tests/__init__.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | _ROOT_TESTS = os.path.dirname(__file__) 4 | _ROOT_DATA = os.path.join(_ROOT_TESTS, "_data") 5 | -------------------------------------------------------------------------------- /tests/_data/cassava/train.csv: -------------------------------------------------------------------------------- 1 | image_id,label 2 | 218377.jpg,1 3 | 6477704.jpg,3 4 | 7635457.jpg,0 5 | -------------------------------------------------------------------------------- /tests/_data/cassava/train_images/218377.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Borda/kaggle_image-classify/a2ce04098b04aa7a7f4c8e0e9a8005ee60e951b6/tests/_data/cassava/train_images/218377.jpg -------------------------------------------------------------------------------- /tests/_data/cassava/train_images/6477704.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Borda/kaggle_image-classify/a2ce04098b04aa7a7f4c8e0e9a8005ee60e951b6/tests/_data/cassava/train_images/6477704.jpg -------------------------------------------------------------------------------- /tests/_data/cassava/train_images/7635457.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Borda/kaggle_image-classify/a2ce04098b04aa7a7f4c8e0e9a8005ee60e951b6/tests/_data/cassava/train_images/7635457.jpg -------------------------------------------------------------------------------- /tests/_data/imet-collect/test/test/023c01465d76f827ca9620667f7de487.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Borda/kaggle_image-classify/a2ce04098b04aa7a7f4c8e0e9a8005ee60e951b6/tests/_data/imet-collect/test/test/023c01465d76f827ca9620667f7de487.jpg -------------------------------------------------------------------------------- /tests/_data/imet-collect/test/test/02ca3baa47d2737b7796ae6bca32aa1d.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Borda/kaggle_image-classify/a2ce04098b04aa7a7f4c8e0e9a8005ee60e951b6/tests/_data/imet-collect/test/test/02ca3baa47d2737b7796ae6bca32aa1d.jpg -------------------------------------------------------------------------------- /tests/_data/imet-collect/test/test/050266ba8ff68b14fd17d4b05707ff19.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Borda/kaggle_image-classify/a2ce04098b04aa7a7f4c8e0e9a8005ee60e951b6/tests/_data/imet-collect/test/test/050266ba8ff68b14fd17d4b05707ff19.jpg -------------------------------------------------------------------------------- /tests/_data/imet-collect/train-1/train-1/09fe6ff247881b37779bcb386c26d7bb.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/Borda/kaggle_image-classify/a2ce04098b04aa7a7f4c8e0e9a8005ee60e951b6/tests/_data/imet-collect/train-1/train-1/09fe6ff247881b37779bcb386c26d7bb.png -------------------------------------------------------------------------------- /tests/_data/imet-collect/train-1/train-1/0d5b8274de10cd73836c858c101266ea.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Borda/kaggle_image-classify/a2ce04098b04aa7a7f4c8e0e9a8005ee60e951b6/tests/_data/imet-collect/train-1/train-1/0d5b8274de10cd73836c858c101266ea.png -------------------------------------------------------------------------------- /tests/_data/imet-collect/train-1/train-1/11a87738861970a67249592db12f2da1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Borda/kaggle_image-classify/a2ce04098b04aa7a7f4c8e0e9a8005ee60e951b6/tests/_data/imet-collect/train-1/train-1/11a87738861970a67249592db12f2da1.png -------------------------------------------------------------------------------- /tests/_data/imet-collect/train-1/train-1/12c80004e34f9102cad72c7312133529.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Borda/kaggle_image-classify/a2ce04098b04aa7a7f4c8e0e9a8005ee60e951b6/tests/_data/imet-collect/train-1/train-1/12c80004e34f9102cad72c7312133529.png -------------------------------------------------------------------------------- /tests/_data/imet-collect/train-1/train-1/14f3fa3b620d46be00696eacda9df583.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Borda/kaggle_image-classify/a2ce04098b04aa7a7f4c8e0e9a8005ee60e951b6/tests/_data/imet-collect/train-1/train-1/14f3fa3b620d46be00696eacda9df583.png -------------------------------------------------------------------------------- /tests/_data/imet-collect/train-1/train-1/1cc66a822733a3c3a1ce66fe4be60a6f.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Borda/kaggle_image-classify/a2ce04098b04aa7a7f4c8e0e9a8005ee60e951b6/tests/_data/imet-collect/train-1/train-1/1cc66a822733a3c3a1ce66fe4be60a6f.png -------------------------------------------------------------------------------- /tests/_data/imet-collect/train-1/train-1/258e4a904729119efd85faaba80c965a.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Borda/kaggle_image-classify/a2ce04098b04aa7a7f4c8e0e9a8005ee60e951b6/tests/_data/imet-collect/train-1/train-1/258e4a904729119efd85faaba80c965a.png -------------------------------------------------------------------------------- /tests/_data/imet-collect/train-from-kaggle.csv: -------------------------------------------------------------------------------- 1 | id,attribute_ids 2 | 1cc66a822733a3c3a1ce66fe4be60a6f,124 2362 782 96 3 | 09fe6ff247881b37779bcb386c26d7bb,3192 3465 3193 233 783 4 | 11a87738861970a67249592db12f2da1,3334 507 262 2281 784 5 | 0d5b8274de10cd73836c858c101266ea,3465 370 946 783 6 | 12c80004e34f9102cad72c7312133529,2941 3235 233 1660 784 7 | 258e4a904729119efd85faaba80c965a,341 2362 784 8 | 14f3fa3b620d46be00696eacda9df583,507 792 9 | -------------------------------------------------------------------------------- /tests/_data/plant-pathology/test_images/8a0d7cad7053f18d.jpg: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/Borda/kaggle_image-classify/a2ce04098b04aa7a7f4c8e0e9a8005ee60e951b6/tests/_data/plant-pathology/test_images/8a0d7cad7053f18d.jpg -------------------------------------------------------------------------------- /tests/_data/plant-pathology/train.csv: -------------------------------------------------------------------------------- 1 | image,labels 2 | 800113bb65efe69e.jpg,healthy 3 | 8002cb321f8bfcdf.jpg,scab frog_eye_leaf_spot complex 4 | 800f85dc5f407aef.jpg,cider_apple_rust 5 | 8a0be55d81f4bf0c.jpg,healthy 6 | 8a1a97abda0b4a7a.jpg,frog_eye_leaf_spot 7 | 8a2d598f2ec436e6.jpg,powdery_mildew 8 | 8a954b82bf81f2bc.jpg,rust complex 9 | -------------------------------------------------------------------------------- /tests/_data/plant-pathology/train_images/800113bb65efe69e.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Borda/kaggle_image-classify/a2ce04098b04aa7a7f4c8e0e9a8005ee60e951b6/tests/_data/plant-pathology/train_images/800113bb65efe69e.jpg -------------------------------------------------------------------------------- /tests/_data/plant-pathology/train_images/8002cb321f8bfcdf.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Borda/kaggle_image-classify/a2ce04098b04aa7a7f4c8e0e9a8005ee60e951b6/tests/_data/plant-pathology/train_images/8002cb321f8bfcdf.jpg -------------------------------------------------------------------------------- /tests/_data/plant-pathology/train_images/800f85dc5f407aef.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Borda/kaggle_image-classify/a2ce04098b04aa7a7f4c8e0e9a8005ee60e951b6/tests/_data/plant-pathology/train_images/800f85dc5f407aef.jpg -------------------------------------------------------------------------------- /tests/_data/plant-pathology/train_images/8a0be55d81f4bf0c.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Borda/kaggle_image-classify/a2ce04098b04aa7a7f4c8e0e9a8005ee60e951b6/tests/_data/plant-pathology/train_images/8a0be55d81f4bf0c.jpg -------------------------------------------------------------------------------- /tests/_data/plant-pathology/train_images/8a1a97abda0b4a7a.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Borda/kaggle_image-classify/a2ce04098b04aa7a7f4c8e0e9a8005ee60e951b6/tests/_data/plant-pathology/train_images/8a1a97abda0b4a7a.jpg -------------------------------------------------------------------------------- /tests/_data/plant-pathology/train_images/8a2d598f2ec436e6.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Borda/kaggle_image-classify/a2ce04098b04aa7a7f4c8e0e9a8005ee60e951b6/tests/_data/plant-pathology/train_images/8a2d598f2ec436e6.jpg -------------------------------------------------------------------------------- /tests/_data/plant-pathology/train_images/8a954b82bf81f2bc.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Borda/kaggle_image-classify/a2ce04098b04aa7a7f4c8e0e9a8005ee60e951b6/tests/_data/plant-pathology/train_images/8a954b82bf81f2bc.jpg 
-------------------------------------------------------------------------------- /tests/birdclef/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Borda/kaggle_image-classify/a2ce04098b04aa7a7f4c8e0e9a8005ee60e951b6/tests/birdclef/__init__.py -------------------------------------------------------------------------------- /tests/cassava/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Borda/kaggle_image-classify/a2ce04098b04aa7a7f4c8e0e9a8005ee60e951b6/tests/cassava/__init__.py -------------------------------------------------------------------------------- /tests/cassava/test_data.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import numpy 4 | 5 | from kaggle_imgclassif.cassava.data import CassavaDataModule, CassavaDataset 6 | from tests import _ROOT_DATA 7 | 8 | PATH_DATA = os.path.join(_ROOT_DATA, "cassava") 9 | 10 | 11 | def test_dataset(path_data=PATH_DATA): 12 | dataset = CassavaDataset( 13 | path_csv=os.path.join(path_data, "train.csv"), 14 | path_img_dir=os.path.join(path_data, "train_images"), 15 | ) 16 | img, lb = dataset[0] 17 | assert isinstance(img, numpy.ndarray) 18 | 19 | 20 | def test_datamodule(path_data=PATH_DATA): 21 | dm = CassavaDataModule( 22 | path_csv=os.path.join(path_data, "train.csv"), 23 | path_img_dir=os.path.join(path_data, "train_images"), 24 | ) 25 | dm.setup() 26 | 27 | for imgs, lbs in dm.train_dataloader(): 28 | assert len(imgs) 29 | assert len(lbs) 30 | break 31 | -------------------------------------------------------------------------------- /tests/cassava/test_models.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | from pytorch_lightning import Trainer 4 | 5 | from kaggle_imgclassif.cassava.data import CassavaDataModule 6 | from kaggle_imgclassif.cassava.models import LitCassava 7 | from tests import _ROOT_DATA 8 | 9 | PATH_DATA = os.path.join(_ROOT_DATA, "cassava") 10 | 11 | 12 | def test_devel_run(tmpdir, path_data=PATH_DATA): 13 | """Sample fast dev run...""" 14 | dm = CassavaDataModule( 15 | path_csv=os.path.join(path_data, "train.csv"), 16 | path_img_dir=os.path.join(path_data, "train_images"), 17 | batch_size=1, 18 | split=0.6, 19 | ) 20 | model = LitCassava(model="resnet18") 21 | 22 | trainer = Trainer( 23 | default_root_dir=tmpdir, 24 | fast_dev_run=True, 25 | gpus=0, 26 | ) 27 | dm.setup() 28 | trainer.fit(model, datamodule=dm) 29 | -------------------------------------------------------------------------------- /tests/imet_collect/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Borda/kaggle_image-classify/a2ce04098b04aa7a7f4c8e0e9a8005ee60e951b6/tests/imet_collect/__init__.py -------------------------------------------------------------------------------- /tests/imet_collect/test_data.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import pytest 4 | import torch 5 | from PIL import Image 6 | from torch import tensor 7 | 8 | from kaggle_imgclassif.imet_collect.data import IMetDataset, IMetDM 9 | from tests import _ROOT_DATA 10 | 11 | PATH_DATA = os.path.join(_ROOT_DATA, "imet-collect") 12 | _TEST_IMAGE_NAMES = ( 13 | "1cc66a822733a3c3a1ce66fe4be60a6f", 14 | "09fe6ff247881b37779bcb386c26d7bb", 15 | "258e4a904729119efd85faaba80c965a", 16 
| "11a87738861970a67249592db12f2da1", 17 | "12c80004e34f9102cad72c7312133529", 18 | "0d5b8274de10cd73836c858c101266ea", 19 | "14f3fa3b620d46be00696eacda9df583", 20 | ) 21 | _TEST_UNIQUE_LABELS = ( 22 | "124", 23 | "1660", 24 | "2281", 25 | "233", 26 | "2362", 27 | "262", 28 | "2941", 29 | "3192", 30 | "3193", 31 | "3235", 32 | "3334", 33 | "341", 34 | "3465", 35 | "370", 36 | "507", 37 | "782", 38 | "783", 39 | "784", 40 | "792", 41 | "946", 42 | "96", 43 | ) 44 | _TEST_LABELS_BINARY = [ 45 | tensor([1, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1]), 46 | tensor([0, 0, 0, 1, 0, 0, 0, 1, 1, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 0]), 47 | tensor([0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 1, 0, 0, 0]), 48 | tensor([0, 0, 1, 0, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 1, 0, 0, 0]), 49 | tensor([0, 1, 0, 1, 0, 0, 1, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0]), 50 | tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 1, 0, 0, 1, 0]), 51 | tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0]), 52 | ] 53 | 54 | 55 | @pytest.mark.parametrize("phase", ["train", "valid"]) 56 | def test_dataset(phase, path_data=PATH_DATA): 57 | dataset = IMetDataset( 58 | df_data=os.path.join(path_data, "train-from-kaggle.csv"), 59 | path_img_dir=os.path.join(path_data, "train-1", "train-1"), 60 | split=1.0 if phase == "train" else 0.0, 61 | mode=phase, 62 | random_state=42, 63 | ) 64 | assert len(dataset) == 7 65 | img, _ = dataset[0] 66 | assert isinstance(img, Image.Image) 67 | _img_names = [os.path.splitext(im)[0] for im in dataset.img_names] 68 | assert tuple(_img_names) == tuple(dataset.data["id"]) == _TEST_IMAGE_NAMES 69 | assert dataset.labels_unique == _TEST_UNIQUE_LABELS 70 | lbs = [tensor(dataset[i][1]) for i in range(len(dataset))] 71 | # mm = lambda lb: np.array([i for i, l in enumerate(lb) if l]) 72 | # lb_names = [np.array(dataset.labels_unique)[mm(lb)] for lb in lbs] 73 | assert all(torch.equal(a, b) for a, b in zip(_TEST_LABELS_BINARY, lbs)) 74 | 75 | 76 | def test_datamodule(path_data=PATH_DATA): 77 | dm = IMetDM( 78 | path_csv="train-from-kaggle.csv", 79 | base_path=path_data, 80 | batch_size=2, 81 | split=0.6, 82 | ) 83 | dm.setup() 84 | assert dm.num_classes == len(_TEST_UNIQUE_LABELS) 85 | assert dm.labels_unique == _TEST_UNIQUE_LABELS 86 | assert len(dm.lut_label) == len(_TEST_UNIQUE_LABELS) 87 | # assert isinstance(dm.label_histogram, Tensor) 88 | 89 | for imgs, lbs in dm.train_dataloader(): 90 | assert len(imgs) 91 | assert len(lbs) 92 | break 93 | 94 | for imgs, lbs in dm.val_dataloader(): 95 | assert len(imgs) 96 | assert len(lbs) 97 | break 98 | 99 | for imgs, names in dm.test_dataloader(): 100 | assert len(imgs) 101 | assert len(names) 102 | assert isinstance(names[0], str) 103 | break 104 | -------------------------------------------------------------------------------- /tests/imet_collect/test_models.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import timm 4 | from pytorch_lightning import Trainer 5 | 6 | from kaggle_imgclassif.imet_collect.data import IMetDM 7 | from kaggle_imgclassif.imet_collect.models import LitMet 8 | from tests import _ROOT_DATA 9 | 10 | PATH_DATA = os.path.join(_ROOT_DATA, "imet-collect") 11 | 12 | 13 | def test_create_model(): 14 | net = timm.create_model("resnet34", pretrained=False, num_classes=5) 15 | LitMet(model=net, num_classes=5) 16 | 17 | 18 | def test_devel_run(tmpdir, path_data=PATH_DATA): 19 | """Sample fast dev run...""" 20 | dm = IMetDM( 21 | 
path_csv=os.path.join(path_data, "train-from-kaggle.csv"), 22 | base_path=path_data, 23 | batch_size=2, 24 | split=0.6, 25 | ) 26 | dm.setup() 27 | net = timm.create_model("resnet18", num_classes=dm.num_classes) 28 | model = LitMet(model=net, num_classes=dm.num_classes) 29 | 30 | # smoke run 31 | trainer = Trainer( 32 | default_root_dir=tmpdir, 33 | fast_dev_run=True, 34 | ) 35 | trainer.fit(model, datamodule=dm) 36 | 37 | # test predictions 38 | for imgs, names in dm.test_dataloader(): 39 | encode = model(imgs) 40 | # it has only batch size 1 41 | for oh, name in zip(encode, names): 42 | dm.binary_encoding_to_labels(oh) 43 | -------------------------------------------------------------------------------- /tests/plant_pathology/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Borda/kaggle_image-classify/a2ce04098b04aa7a7f4c8e0e9a8005ee60e951b6/tests/plant_pathology/__init__.py -------------------------------------------------------------------------------- /tests/plant_pathology/test_augment.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | import torch 3 | 4 | from kaggle_imgclassif.plant_pathology.augment import LitAugmenter 5 | 6 | 7 | @pytest.mark.parametrize("img_shape", [(1, 3, 192, 192), (2, 3, 224, 224)]) 8 | def test_augmenter(img_shape): 9 | B, C, H, W = img_shape 10 | img = torch.rand(img_shape) 11 | aug = LitAugmenter() 12 | 13 | assert aug(img).shape == (B, C, 224, 224) 14 | -------------------------------------------------------------------------------- /tests/plant_pathology/test_data.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import numpy 4 | import pytest 5 | from torch import Tensor 6 | 7 | from kaggle_imgclassif.plant_pathology.data import PlantPathologyDataset, PlantPathologyDM, PlantPathologySimpleDataset 8 | from tests import _ROOT_DATA 9 | 10 | PATH_DATA = os.path.join(_ROOT_DATA, "plant-pathology") 11 | 12 | _TEST_IMAGE_NAMES = ( 13 | "800113bb65efe69e.jpg", 14 | "8002cb321f8bfcdf.jpg", 15 | "8a2d598f2ec436e6.jpg", 16 | "800f85dc5f407aef.jpg", 17 | "8a1a97abda0b4a7a.jpg", 18 | "8a0be55d81f4bf0c.jpg", 19 | "8a954b82bf81f2bc.jpg", 20 | ) 21 | _TEST_UNIQUE_LABELS = ( 22 | "cider_apple_rust", 23 | "complex", 24 | "frog_eye_leaf_spot", 25 | "healthy", 26 | "powdery_mildew", 27 | "rust", 28 | "scab", 29 | ) 30 | _TEST_LABELS_BINARY = [ 31 | [0, 0, 0, 1, 0, 0, 0], 32 | [0, 1, 1, 0, 0, 0, 1], 33 | [0, 0, 0, 0, 1, 0, 0], 34 | [1, 0, 0, 0, 0, 0, 0], 35 | [0, 0, 1, 0, 0, 0, 0], 36 | [0, 0, 0, 1, 0, 0, 0], 37 | [0, 1, 0, 0, 0, 1, 0], 38 | ] 39 | 40 | 41 | @pytest.mark.parametrize( 42 | ("data_cls", "labels"), 43 | [ 44 | (PlantPathologyDataset, _TEST_LABELS_BINARY), 45 | (PlantPathologySimpleDataset, [3, 1, 4, 0, 2, 3, 1]), 46 | ], 47 | ) 48 | @pytest.mark.parametrize("phase", ["train", "valid"]) 49 | def test_dataset(data_cls, labels, phase, path_data=PATH_DATA): 50 | dataset = data_cls( 51 | df_data=os.path.join(path_data, "train.csv"), 52 | path_img_dir=os.path.join(path_data, "train_images"), 53 | split=1.0 if phase == "train" else 0.0, 54 | mode=phase, 55 | ) 56 | assert len(dataset) == 7 57 | img, _ = dataset[0] 58 | assert isinstance(img, numpy.ndarray) 59 | assert _TEST_IMAGE_NAMES == tuple(dataset.img_names) == tuple(dataset.data["image"]) 60 | assert dataset.labels_unique == _TEST_UNIQUE_LABELS 61 | lbs = [dataset[i][1] for i in range(len(dataset))] 62 | if 
isinstance(lbs[0], Tensor): 63 | lbs = [list(lb.numpy()) for lb in lbs] 64 | # mm = lambda lb: np.array([i for i, l in enumerate(lb) if l]) 65 | # lb_names = [np.array(dataset.labels_unique)[mm(lb)] for lb in lbs] 66 | assert labels == lbs 67 | 68 | 69 | @pytest.mark.parametrize("simple", [True, False]) 70 | @pytest.mark.parametrize("balance", [True, False]) 71 | def test_datamodule(simple, balance, path_data=PATH_DATA): 72 | dm = PlantPathologyDM( 73 | path_csv="train.csv", 74 | base_path=path_data, 75 | simple=simple, 76 | split=0.6, 77 | balancing=balance, 78 | ) 79 | dm.setup() 80 | assert dm.num_classes > 0 81 | assert dm.labels_unique 82 | assert dm.lut_label 83 | assert isinstance(dm.label_histogram, Tensor) 84 | 85 | for imgs, lbs in dm.train_dataloader(): 86 | assert len(imgs) 87 | assert len(lbs) 88 | break 89 | 90 | for imgs, lbs in dm.val_dataloader(): 91 | assert len(imgs) 92 | assert len(lbs) 93 | break 94 | 95 | for imgs, names in dm.test_dataloader(): 96 | assert len(imgs) 97 | assert len(names) 98 | assert isinstance(names[0], str) 99 | break 100 | -------------------------------------------------------------------------------- /tests/plant_pathology/test_models.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import pytest 4 | from pytorch_lightning import Trainer 5 | 6 | from kaggle_imgclassif.plant_pathology.data import PlantPathologyDM 7 | from kaggle_imgclassif.plant_pathology.models import LitPlantPathology, MultiPlantPathology 8 | from tests import _ROOT_DATA 9 | 10 | PATH_DATA = os.path.join(_ROOT_DATA, "plant-pathology") 11 | 12 | 13 | @pytest.mark.parametrize("model_cls", [LitPlantPathology, MultiPlantPathology]) 14 | def test_create_model(model_cls, net: str = "resnet18"): 15 | model_cls(model=net) 16 | 17 | 18 | @pytest.mark.parametrize( 19 | ("ds_simple", "model_cls"), 20 | [ 21 | (True, LitPlantPathology), 22 | (False, MultiPlantPathology), 23 | ], 24 | ) 25 | def test_devel_run(tmpdir, ds_simple, model_cls, path_data=PATH_DATA): 26 | """Sample fast dev run...""" 27 | dm = PlantPathologyDM( 28 | path_csv=os.path.join(path_data, "train.csv"), 29 | base_path=path_data, 30 | simple=ds_simple, 31 | batch_size=2, 32 | split=0.6, 33 | ) 34 | dm.setup() 35 | model = model_cls(model="resnet18", num_classes=dm.num_classes) 36 | 37 | # smoke run 38 | trainer = Trainer( 39 | default_root_dir=tmpdir, 40 | fast_dev_run=True, 41 | ) 42 | trainer.fit(model, datamodule=dm) 43 | 44 | # test predictions 45 | for imgs, names in dm.test_dataloader(): 46 | encode = model(imgs) 47 | # it has only batch size 1 48 | for oh, name in zip(encode, names): 49 | dm.binary_encoding_to_labels(oh) 50 | --------------------------------------------------------------------------------
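A minimal sketch (not part of the repository) of how the test suite above could be invoked programmatically; it assumes the package and its test dependencies (pytest, torch, pytorch_lightning, timm) are already installed, e.g. via pip install -e .

import pytest

if __name__ == "__main__":
    # run all tests under tests/ with verbose output; pytest.main returns the exit code
    raise SystemExit(pytest.main(["tests/", "-v"]))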