├── .devcontainer └── devcontainer.json ├── .github ├── FUNDING.yml ├── dependabot.yml └── workflows │ ├── pypi.yml │ ├── stale.yaml │ └── tests.yml ├── .gitignore ├── .readthedocs.yaml ├── HALLOFFAME.md ├── LICENSE ├── Makefile ├── README.md ├── docs ├── Makefile ├── conf.py ├── encoders.rst ├── encoders_dpt.rst ├── encoders_timm.rst ├── index.rst ├── insights.rst ├── install.rst ├── logo.png ├── losses.rst ├── make.bat ├── metrics.rst ├── models.rst ├── quickstart.rst └── save_load.rst ├── examples ├── binary_segmentation_buildings.py ├── binary_segmentation_intro.ipynb ├── camvid_segmentation_multiclass.ipynb ├── cars segmentation (camvid).ipynb ├── convert_to_onnx.ipynb ├── dpt_inference_pretrained.ipynb ├── save_load_model_and_share_with_hf_hub.ipynb ├── segformer_inference_pretrained.ipynb └── upernet_inference_pretrained.ipynb ├── licenses ├── LICENSES.md ├── LICENSE_apache.md ├── LICENSE_apple.md └── LICENSE_nvidia.md ├── misc ├── generate_table.py ├── generate_table_timm.py └── generate_test_models.py ├── pics ├── logo-small-h300.png └── logo-small-w300.png ├── pyproject.toml ├── requirements ├── docs.txt ├── minimum.old ├── required.txt └── test.txt ├── scripts └── models-conversions │ ├── dpt-original-to-smp.py │ ├── segformer-original-decoder-to-smp.py │ └── upernet-hf-to-smp.py ├── segmentation_models_pytorch ├── __init__.py ├── __version__.py ├── base │ ├── __init__.py │ ├── heads.py │ ├── hub_mixin.py │ ├── initialization.py │ ├── model.py │ ├── modules.py │ └── utils.py ├── datasets │ ├── __init__.py │ └── oxford_pet.py ├── decoders │ ├── __init__.py │ ├── deeplabv3 │ │ ├── __init__.py │ │ ├── decoder.py │ │ └── model.py │ ├── dpt │ │ ├── __init__.py │ │ ├── decoder.py │ │ └── model.py │ ├── fpn │ │ ├── __init__.py │ │ ├── decoder.py │ │ └── model.py │ ├── linknet │ │ ├── __init__.py │ │ ├── decoder.py │ │ └── model.py │ ├── manet │ │ ├── __init__.py │ │ ├── decoder.py │ │ └── model.py │ ├── pan │ │ ├── __init__.py │ │ ├── decoder.py │ │ └── model.py │ ├── pspnet │ │ ├── __init__.py │ │ ├── decoder.py │ │ └── model.py │ ├── segformer │ │ ├── __init__.py │ │ ├── decoder.py │ │ └── model.py │ ├── unet │ │ ├── __init__.py │ │ ├── decoder.py │ │ └── model.py │ ├── unetplusplus │ │ ├── __init__.py │ │ ├── decoder.py │ │ └── model.py │ └── upernet │ │ ├── __init__.py │ │ ├── decoder.py │ │ └── model.py ├── encoders │ ├── __init__.py │ ├── _base.py │ ├── _dpn.py │ ├── _efficientnet.py │ ├── _inceptionresnetv2.py │ ├── _inceptionv4.py │ ├── _legacy_pretrained_settings.py │ ├── _preprocessing.py │ ├── _senet.py │ ├── _utils.py │ ├── _xception.py │ ├── densenet.py │ ├── dpn.py │ ├── efficientnet.py │ ├── inceptionresnetv2.py │ ├── inceptionv4.py │ ├── mix_transformer.py │ ├── mobilenet.py │ ├── mobileone.py │ ├── resnet.py │ ├── senet.py │ ├── timm_efficientnet.py │ ├── timm_sknet.py │ ├── timm_universal.py │ ├── timm_vit.py │ ├── vgg.py │ └── xception.py ├── losses │ ├── __init__.py │ ├── _functional.py │ ├── constants.py │ ├── dice.py │ ├── focal.py │ ├── jaccard.py │ ├── lovasz.py │ ├── mcc.py │ ├── soft_bce.py │ ├── soft_ce.py │ └── tversky.py ├── metrics │ ├── __init__.py │ └── functional.py └── utils │ ├── __init__.py │ ├── base.py │ ├── functional.py │ ├── losses.py │ ├── meter.py │ ├── metrics.py │ └── train.py └── tests ├── __init__.py ├── base └── test_modules.py ├── conftest.py ├── encoders ├── __init__.py ├── base.py ├── test_batchnorm_deprecation.py ├── test_common.py ├── test_pretrainedmodels_encoders.py ├── test_smp_encoders.py ├── test_timm_ported_encoders.py ├── 
test_timm_universal.py ├── test_timm_vit_encoders.py └── test_torchvision_encoders.py ├── models ├── __init__.py ├── base.py ├── test_deeplab.py ├── test_dpt.py ├── test_fpn.py ├── test_linknet.py ├── test_manet.py ├── test_pan.py ├── test_psp.py ├── test_segformer.py ├── test_unet.py ├── test_unetplusplus.py └── test_upernet.py ├── test_base.py ├── test_losses.py ├── test_preprocessing.py └── utils.py /.devcontainer/devcontainer.json: -------------------------------------------------------------------------------- 1 | { 2 | "image": "mcr.microsoft.com/devcontainers/universal:2", 3 | "features": { 4 | "ghcr.io/devcontainers/features/python:1": { 5 | "version": 3.6 6 | } 7 | } 8 | } 9 | -------------------------------------------------------------------------------- /.github/FUNDING.yml: -------------------------------------------------------------------------------- 1 | # These are supported funding model platforms 2 | 3 | github: # Replace with up to 4 GitHub Sponsors-enabled usernames e.g., [user1, user2] 4 | patreon: # Replace with a single Patreon username 5 | open_collective: # Replace with a single Open Collective username 6 | ko_fi: qubvel 7 | tidelift: # Replace with a single Tidelift platform-name/package-name e.g., npm/babel 8 | community_bridge: # Replace with a single Community Bridge project-name e.g., cloud-foundry 9 | liberapay: qubvel 10 | issuehunt: # Replace with a single IssueHunt username 11 | otechie: # Replace with a single Otechie username 12 | custom: # Replace with up to 4 custom sponsorship URLs e.g., ['link1', 'link2'] 13 | -------------------------------------------------------------------------------- /.github/dependabot.yml: -------------------------------------------------------------------------------- 1 | version: 2 2 | updates: 3 | - package-ecosystem: "github-actions" 4 | directory: "/" 5 | schedule: 6 | interval: "daily" 7 | - package-ecosystem: "pip" 8 | directory: "/requirements" 9 | schedule: 10 | interval: "daily" 11 | groups: 12 | torch: 13 | patterns: 14 | - "torch" 15 | - "torchvision" 16 | ignore: 17 | - dependency-name: "setuptools" 18 | update-types: ["version-update:semver-patch"] 19 | -------------------------------------------------------------------------------- /.github/workflows/pypi.yml: -------------------------------------------------------------------------------- 1 | name: Upload Python Package 2 | 3 | on: 4 | release: 5 | types: [published] 6 | 7 | jobs: 8 | build: 9 | name: build 10 | runs-on: ubuntu-latest 11 | steps: 12 | - name: Clone repo 13 | uses: actions/checkout@v4.2.2 14 | 15 | - name: Set up python 16 | uses: actions/setup-python@v5.6.0 17 | with: 18 | python-version: '3.13' 19 | 20 | - name: Install pip dependencies 21 | run: pip install build 22 | 23 | - name: List pip dependencies 24 | run: pip list 25 | 26 | - name: Build project 27 | run: python3 -m build 28 | 29 | - name: Upload artifacts 30 | uses: actions/upload-artifact@v4.6.2 31 | with: 32 | name: pypi-dist 33 | path: dist/ 34 | 35 | pypi: 36 | name: pypi 37 | needs: 38 | - build 39 | permissions: 40 | id-token: write 41 | runs-on: ubuntu-latest 42 | steps: 43 | - name: Download artifacts 44 | uses: actions/download-artifact@v4.3.0 45 | with: 46 | name: pypi-dist 47 | path: dist/ 48 | 49 | - name: Publish to PyPI 50 | uses: pypa/gh-action-pypi-publish@v1.12.4 51 | -------------------------------------------------------------------------------- /.github/workflows/stale.yaml: -------------------------------------------------------------------------------- 1 | 
name: 'Close stale issues and PRs' 2 | on: 3 | schedule: 4 | - cron: '30 1 * * *' 5 | 6 | jobs: 7 | stale: 8 | runs-on: ubuntu-latest 9 | steps: 10 | - uses: actions/stale@v9 11 | with: 12 | stale-issue-message: 'This issue is stale because it has been open 60 days with no activity. Remove stale label or comment or this will be closed in 7 days.' 13 | stale-pr-message: 'This PR is stale because it has been open 60 days with no activity. Remove stale label or comment or this will be closed in 15 days.' 14 | close-issue-message: 'This issue was closed because it has been stalled for 7 days with no activity.' 15 | close-pr-message: 'This PR was closed because it has been stalled for 15 days with no activity.' 16 | days-before-issue-stale: 60 17 | days-before-pr-stale: 60 18 | days-before-issue-close: 7 19 | days-before-pr-close: 15 20 | operations-per-run: 100 21 | ascending: true 22 | -------------------------------------------------------------------------------- /.github/workflows/tests.yml: -------------------------------------------------------------------------------- 1 | 2 | # This workflow will install Python dependencies, run tests and lint with a variety of Python versions 3 | # For more information see: https://help.github.com/actions/language-and-framework-guides/using-python-with-github-actions 4 | 5 | name: CI 6 | 7 | on: 8 | push: 9 | branches: [ main ] 10 | pull_request: 11 | branches: [ main ] 12 | 13 | jobs: 14 | 15 | style: 16 | runs-on: ubuntu-latest 17 | steps: 18 | - uses: actions/checkout@v4 19 | - name: Set up Python 20 | uses: astral-sh/setup-uv@v5 21 | with: 22 | python-version: "3.11" 23 | 24 | - name: Install dependencies 25 | run: uv pip install -r requirements/test.txt 26 | 27 | # Update output format to enable automatic inline annotations. 
28 | - name: Run Ruff Linter 29 | run: ruff check --output-format=github 30 | - name: Run Ruff Formatter 31 | run: ruff format --check 32 | 33 | test: 34 | strategy: 35 | matrix: 36 | os: [ubuntu-latest, macos-latest, windows-latest] 37 | python-version: ["3.10", "3.11", "3.12"] 38 | runs-on: ${{ matrix.os }} 39 | steps: 40 | - uses: actions/checkout@v4 41 | 42 | - name: Set up Python ${{ matrix.python-version }} 43 | uses: astral-sh/setup-uv@v5 44 | with: 45 | python-version: ${{ matrix.python-version }} 46 | 47 | - name: Install dependencies 48 | run: uv pip install -r requirements/required.txt -r requirements/test.txt 49 | 50 | - name: Show installed packages 51 | run: uv pip list 52 | 53 | - name: Test with PyTest 54 | run: uv run pytest -v -rsx -n 2 --cov=segmentation_models_pytorch --cov-report=xml --cov-config=pyproject.toml --non-marked-only 55 | 56 | - name: Upload coverage reports to Codecov 57 | uses: codecov/codecov-action@v5 58 | with: 59 | token: ${{ secrets.CODECOV_TOKEN }} 60 | slug: qubvel-org/segmentation_models.pytorch 61 | if: matrix.os == 'macos-latest' && matrix.python-version == '3.12' 62 | 63 | test_logits_match: 64 | runs-on: ubuntu-latest 65 | steps: 66 | - uses: actions/checkout@v4 67 | - name: Set up Python 68 | uses: astral-sh/setup-uv@v5 69 | with: 70 | python-version: "3.10" 71 | - name: Install dependencies 72 | run: uv pip install -r requirements/required.txt -r requirements/test.txt 73 | - name: Show installed packages 74 | run: uv pip list 75 | - name: Test with PyTest 76 | run: RUN_SLOW=1 uv run pytest -v -rsx -n 2 -m "logits_match" 77 | 78 | test_torch_compile: 79 | runs-on: ubuntu-latest 80 | steps: 81 | - uses: actions/checkout@v4 82 | - name: Set up Python 83 | uses: astral-sh/setup-uv@v5 84 | with: 85 | python-version: "3.10" 86 | - name: Install dependencies 87 | run: uv pip install -r requirements/required.txt -r requirements/test.txt 88 | - name: Show installed packages 89 | run: uv pip list 90 | - name: Test with PyTest 91 | run: uv run pytest -v -rsx -n 2 -m "compile" 92 | 93 | test_torch_export: 94 | runs-on: ubuntu-latest 95 | steps: 96 | - uses: actions/checkout@v4 97 | - name: Set up Python 98 | uses: astral-sh/setup-uv@v5 99 | with: 100 | python-version: "3.10" 101 | - name: Install dependencies 102 | run: uv pip install -r requirements/required.txt -r requirements/test.txt 103 | - name: Show installed packages 104 | run: uv pip list 105 | - name: Test with PyTest 106 | run: uv run pytest -v -rsx -n 2 -m "torch_export" 107 | 108 | test_torch_script: 109 | runs-on: ubuntu-latest 110 | steps: 111 | - uses: actions/checkout@v4 112 | - name: Set up Python 113 | uses: astral-sh/setup-uv@v5 114 | with: 115 | python-version: "3.10" 116 | - name: Install dependencies 117 | run: uv pip install -r requirements/required.txt -r requirements/test.txt 118 | - name: Show installed packages 119 | run: uv pip list 120 | - name: Test with PyTest 121 | run: uv run pytest -v -rsx -n 2 -m "torch_script" 122 | 123 | minimum: 124 | runs-on: ubuntu-latest 125 | steps: 126 | - uses: actions/checkout@v4 127 | - name: Set up Python 128 | uses: astral-sh/setup-uv@v5 129 | with: 130 | python-version: "3.9" 131 | - name: Install dependencies 132 | run: uv pip install -r requirements/minimum.old -r requirements/test.txt 133 | - name: Show installed packages 134 | run: uv pip list 135 | - name: Test with pytest 136 | run: uv run pytest -v -rsx -n 2 --non-marked-only 137 | -------------------------------------------------------------------------------- /.gitignore: 
-------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | .idea/ 6 | .venv* 7 | examples/images* 8 | examples/annotations* 9 | 10 | # C extensions 11 | *.so 12 | 13 | # Distribution / packaging 14 | .Python 15 | build/ 16 | develop-eggs/ 17 | dist/ 18 | downloads/ 19 | eggs/ 20 | .eggs/ 21 | lib/ 22 | lib64/ 23 | parts/ 24 | sdist/ 25 | var/ 26 | wheels/ 27 | *.egg-info/ 28 | .installed.cfg 29 | *.egg 30 | MANIFEST 31 | 32 | # PyInstaller 33 | # Usually these files are written by a python script from a template 34 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 35 | *.manifest 36 | *.spec 37 | 38 | # Installer logs 39 | pip-log.txt 40 | pip-delete-this-directory.txt 41 | 42 | # Unit test / coverage reports 43 | htmlcov/ 44 | .tox/ 45 | .coverage 46 | .coverage.* 47 | .cache 48 | nosetests.xml 49 | coverage.xml 50 | *.cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | 63 | # Flask stuff: 64 | instance/ 65 | .webassets-cache 66 | 67 | # Scrapy stuff: 68 | .scrapy 69 | 70 | # Sphinx documentation 71 | docs/_build/ 72 | 73 | # PyBuilder 74 | target/ 75 | 76 | # Jupyter Notebook 77 | .ipynb_checkpoints 78 | 79 | # pyenv 80 | .python-version 81 | 82 | # celery beat schedule file 83 | celerybeat-schedule 84 | 85 | # SageMath parsed files 86 | *.sage.py 87 | 88 | # Environments 89 | .env 90 | .venv 91 | env/ 92 | venv/ 93 | ENV/ 94 | env.bak/ 95 | venv.bak/ 96 | .vscode/ 97 | 98 | # Spyder project settings 99 | .spyderproject 100 | .spyproject 101 | 102 | # Rope project settings 103 | .ropeproject 104 | 105 | # mkdocs documentation 106 | /site 107 | 108 | # mypy 109 | .mypy_cache/ 110 | 111 | # ruff 112 | .ruff_cache/ -------------------------------------------------------------------------------- /.readthedocs.yaml: -------------------------------------------------------------------------------- 1 | # Read the Docs configuration file for Sphinx projects 2 | # See https://docs.readthedocs.io/en/stable/config-file/v2.html for details 3 | 4 | # Required 5 | version: 2 6 | 7 | # Set the OS, Python version and other tools you might need 8 | build: 9 | os: ubuntu-22.04 10 | tools: 11 | python: "3.12" 12 | 13 | # Build documentation in the "docs/" directory with Sphinx 14 | sphinx: 15 | configuration: docs/conf.py 16 | 17 | # Optional but recommended, declare the Python requirements required 18 | # to build your documentation 19 | # See https://docs.readthedocs.io/en/stable/guides/reproducible-builds.html 20 | python: 21 | install: 22 | - requirements: requirements/docs.txt 23 | -------------------------------------------------------------------------------- /HALLOFFAME.md: -------------------------------------------------------------------------------- 1 | # Hall of Fame 2 | 3 | `Segmentation Models` package is widely used in the image segmentation competitions. 4 | Here you can find competitions, names of the winners and links to their solutions. 5 | 6 | Please, follow these rules, when adding a solution to the "Hall of Fame": 7 | 8 | 1. Solution should be high rated (e.g. for Kaggle gold or silver medal) 9 | 2. 
There should be a description of the solution (post at the forum / code / blog post / paper / pre-print) 10 | 11 | 12 | ## Kaggle 13 | 14 | ### [Severstal: Steel Defect Detection](https://www.kaggle.com/c/severstal-steel-defect-detection) 15 | 16 | - 1st place. 17 | [Wuxi Jiangsu](https://www.kaggle.com/rguo97), 18 | [Hongbo Zhu](https://www.kaggle.com/zhuhongbo), 19 | [Yizhuo Yu](https://www.kaggle.com/paffpaffyu) 20 | [[description](https://www.kaggle.com/c/severstal-steel-defect-detection/discussion/114254#latest-675874)] 21 | 22 | - 5th place. 23 | [Guanshuo Xu](https://www.kaggle.com/wowfattie) 24 | [[description](https://www.kaggle.com/c/severstal-steel-defect-detection/discussion/117208#latest-675385)] 25 | 26 | - 9th place. 27 | [Jacek Poplawski](https://www.linkedin.com/in/jacekpoplawski/) 28 | [[description](https://www.kaggle.com/c/severstal-steel-defect-detection/discussion/114297#latest-660842)] 29 | 30 | - 10th place. 31 | [Alexey Rozhkov](https://www.linkedin.com/in/alexisrozhkov) 32 | [[description](https://www.kaggle.com/c/severstal-steel-defect-detection/discussion/114465#latest-659615)] 33 | 34 | - 12th place. 35 | [Pavel Iakubovskii](https://www.linkedin.com/in/pavel-iakubovskii/), 36 | [Ilya Dobrynin](https://www.linkedin.com/in/ilya-dobrynin-79a89b106/), 37 | [Denis Kolpakov](https://www.linkedin.com/in/denis-kolpakov-ab3137197/) 38 | [[description](https://www.kaggle.com/c/severstal-steel-defect-detection/discussion/114309#latest-661404)] 39 | 40 | - 31st place. 41 | [Insaf Ashrapov](https://www.linkedin.com/in/iashrapov/), 42 | [Igor Krashenyi](https://www.linkedin.com/in/igor-krashenyi-38b89b98), 43 | [Pavel Pleskov](https://www.linkedin.com/in/ppleskov), 44 | [Anton Zakharenkov](https://www.linkedin.com/in/anton-zakharenkov/), 45 | [Nikolai Popov](https://www.linkedin.com/in/nikolai-popov-b2157370/) 46 | [[description](https://www.kaggle.com/c/severstal-steel-defect-detection/discussion/114383#latest-658438)] 47 | [[code](https://github.com/Diyago/Severstal-Steel-Defect-Detection)] 48 | 49 | - 55th place. 50 | [Karl Hornlund](https://www.linkedin.com/in/karl-hornlund/) 51 | [[description](https://www.kaggle.com/c/severstal-steel-defect-detection/discussion/114410#latest-672682)] 52 | [[code](https://github.com/khornlund/severstal-steel-defect-detection)] 53 | 54 | - Efficiency round 1st place. 55 | [Stefan Stefanov](https://www.linkedin.com/in/stefan-stefanov-63a77b1) 56 | [[description](https://www.kaggle.com/c/severstal-steel-defect-detection/discussion/117486#latest-674229)] 57 | 58 | 59 | ### [Understanding Clouds from Satellite Images](https://www.kaggle.com/c/understanding_cloud_organization) 60 | 61 | - 2nd place. 62 | [Andrey Kiryasov](https://www.kaggle.com/ekydna) 63 | [[description](https://www.kaggle.com/c/understanding_cloud_organization/discussion/118255#latest-678189)] 64 | 65 | - 4th place. 66 | [Ching-Loong Seow](https://www.linkedin.com/in/clseow/) 67 | [[description](https://www.kaggle.com/c/understanding_cloud_organization/discussion/118016#latest-677333)] 68 | 69 | - 34th place. 70 | [Karl Hornlund](https://www.linkedin.com/in/karl-hornlund/) 71 | [[description](https://www.kaggle.com/c/understanding_cloud_organization/discussion/118250#latest-678176)] 72 | [[code](https://github.com/khornlund/understanding-cloud-organization)] 73 | 74 | - 55th place. 
75 | [Pavel Iakubovskii](https://www.linkedin.com/in/pavel-iakubovskii/) 76 | [[description](https://www.kaggle.com/c/understanding_cloud_organization/discussion/118019#latest-678626)] 77 | 78 | ## Other platforms 79 | 80 | ### [MICCAI 2020 TN-SCUI challenge](https://tn-scui2020.grand-challenge.org/Home/) 81 | - 1st place. 82 | [Mingyu Wang](https://github.com/WAMAWAMA) 83 | [[description](https://github.com/WAMAWAMA/TNSCUI2020-Seg-Rank1st)] 84 | [[code](https://github.com/WAMAWAMA/TNSCUI2020-Seg-Rank1st)] 85 | 86 | ### [Open Cities AI Challenge: Segmenting Buildings for Disaster Resilience](https://www.drivendata.org/competitions/60/building-segmentation-disaster-resilience/) 87 | - 1st place. 88 | [Pavel Iakubovskii](https://www.linkedin.com/in/pavel-iakubovskii/). 89 | [[code and description](https://github.com/qubvel/open-cities-challenge)] 90 | 91 | ### [Machine Learning based feature extraction of Electrical Substations from Satellite Data ](https://competitions.codalab.org/competitions/32132#learn_the_details) 92 | 93 | - 3rd place. 94 | 95 | [Aarsh chaube](https://github.com/Aarsh2001) 96 | [[code](https://github.com/Aarsh2001/ML_Challenge_NRSC)] 97 | [[Pre-Print](https://github.com/Aarsh2001/ML_Challenge_NRSC/blob/main/3rd%20Rank%20Submission.pdf)] 98 | 99 | ### [NeurIPS2022 Cell Segmentation Challenge](https://neurips22-cellseg.grand-challenge.org/) 100 | 101 | - 1st place. [Gihun Lee](https://github.com/Lee-Gihun), [Sangmook Kim](https://github.com/ElvinKim), [Joonkee Kim](https://github.com/joonkeekim) 102 | - [[code](https://github.com/Lee-Gihun/MEDIAR)] 103 | [[Paper](https://arxiv.org/abs/2212.03465)] 104 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | The MIT License 2 | 3 | Copyright (c) 2019, Pavel Iakubovskii 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in 13 | all copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN 21 | THE SOFTWARE. 
22 | -------------------------------------------------------------------------------- /Makefile: -------------------------------------------------------------------------------- 1 | .PHONY: test # Declare the 'test' target as phony to avoid conflicts with files named 'test' 2 | 3 | # Variables to store the paths of the python, pip, pytest, and ruff executables 4 | PYTHON := $(shell which python) 5 | PIP := $(shell which pip) 6 | PYTEST := $(shell which pytest) 7 | RUFF := $(shell which ruff) 8 | 9 | # Target to create a Python virtual environment 10 | .venv: 11 | $(PYTHON) -m venv $(shell dirname $(PYTHON)) 12 | 13 | # Target to install development dependencies in the virtual environment 14 | install_dev: .venv 15 | $(PIP) install -e ".[test]" 16 | 17 | # Target to run tests with pytest, using 2 parallel processes and only non-marked tests 18 | test: .venv 19 | $(PYTEST) -v -rsx -n 2 tests/ --non-marked-only 20 | 21 | # Target to run all tests with pytest, including slow tests, using 2 parallel processes 22 | test_all: .venv 23 | RUN_SLOW=1 $(PYTEST) -v -rsx -n 2 tests/ 24 | 25 | # Target to generate a table by running a Python script 26 | table: 27 | $(PYTHON) misc/generate_table.py 28 | 29 | # Target to generate a table for timm by running a Python script 30 | table_timm: 31 | $(PYTHON) misc/generate_table_timm.py 32 | 33 | # Target to fix and format code using ruff 34 | fixup: 35 | $(RUFF) check --fix 36 | $(RUFF) format 37 | 38 | # Target to run code formatting and tests 39 | all: fixup test 40 | -------------------------------------------------------------------------------- /docs/Makefile: -------------------------------------------------------------------------------- 1 | # Minimal makefile for Sphinx documentation 2 | # 3 | 4 | # You can set these variables from the command line, and also 5 | # from the environment for the first two. 6 | SPHINXOPTS ?= 7 | SPHINXBUILD ?= sphinx-build 8 | SOURCEDIR = . 9 | BUILDDIR = build 10 | 11 | # Put it first so that "make" without argument is like "make help". 12 | help: 13 | @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 14 | 15 | .PHONY: help Makefile 16 | 17 | # Catch-all target: route all unknown targets to Sphinx using the new 18 | # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS). 19 | %: Makefile 20 | @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 21 | -------------------------------------------------------------------------------- /docs/conf.py: -------------------------------------------------------------------------------- 1 | # Configuration file for the Sphinx documentation builder. 2 | # 3 | # This file only contains a selection of the most common options. For a full 4 | # list see the documentation: 5 | # https://www.sphinx-doc.org/en/master/usage/configuration.html 6 | 7 | # -- Path setup -------------------------------------------------------------- 8 | 9 | # If extensions (or modules to document with autodoc) are in another directory, 10 | # add these directories to sys.path here. If the directory is relative to the 11 | # documentation root, use os.path.abspath to make it absolute, like shown here. 
12 | # 13 | # import os 14 | # import sys 15 | # sys.path.insert(0, os.path.abspath('.')) 16 | 17 | import sys 18 | import datetime 19 | # import sphinx_rtd_theme 20 | 21 | sys.path.append("..") 22 | 23 | # -- Project information ----------------------------------------------------- 24 | 25 | project = "Segmentation Models" 26 | copyright = "{}, Pavel Iakubovskii".format(datetime.datetime.now().year) 27 | author = "Pavel Iakubovskii" 28 | 29 | 30 | def get_version(): 31 | sys.path.append("../segmentation_models_pytorch") 32 | from __version__ import __version__ as version 33 | 34 | sys.path.pop(-1) 35 | return version 36 | 37 | 38 | version = get_version() 39 | 40 | # -- General configuration --------------------------------------------------- 41 | 42 | # Add any Sphinx extension module names here, as strings. They can be 43 | # extensions coming with Sphinx (named 'sphinx.ext.*') or your custom 44 | # ones. 45 | 46 | extensions = [ 47 | "sphinx.ext.autodoc", 48 | "sphinx.ext.coverage", 49 | "sphinx.ext.napoleon", 50 | "sphinx.ext.viewcode", 51 | "sphinx.ext.mathjax", 52 | "autodocsumm", 53 | ] 54 | 55 | # Add any paths that contain templates here, relative to this directory. 56 | templates_path = ["_templates"] 57 | 58 | # List of patterns, relative to source directory, that match files and 59 | # directories to ignore when looking for source files. 60 | # This pattern also affects html_static_path and html_extra_path. 61 | exclude_patterns = [] 62 | 63 | 64 | # -- Options for HTML output ------------------------------------------------- 65 | 66 | # The theme to use for HTML and HTML Help pages. See the documentation for 67 | # a list of builtin themes. 68 | # 69 | 70 | # html_theme = "sphinx_rtd_theme" 71 | # html_theme_path = [sphinx_rtd_theme.get_html_theme_path()] 72 | 73 | # import karma_sphinx_theme 74 | # html_theme = "karma_sphinx_theme" 75 | 76 | html_theme = "sphinx_book_theme" 77 | 78 | # import catalyst_sphinx_theme 79 | # html_theme = "catalyst_sphinx_theme" 80 | # html_theme_path = [catalyst_sphinx_theme.get_html_theme_path()] 81 | 82 | html_logo = "logo.png" 83 | 84 | # Add any paths that contain custom static files (such as style sheets) here, 85 | # relative to this directory. They are copied after the builtin static files, 86 | # so a file named "default.css" will overwrite the builtin "default.css". 
87 | html_static_path = ["_static"] 88 | 89 | # -- Extension configuration ------------------------------------------------- 90 | 91 | autodoc_inherit_docstrings = False 92 | napoleon_google_docstring = True 93 | napoleon_include_init_with_doc = True 94 | napoleon_numpy_docstring = False 95 | 96 | autodoc_mock_imports = [ 97 | "torch", 98 | "tqdm", 99 | "numpy", 100 | "timm", 101 | "cv2", 102 | "PIL", 103 | "torchvision", 104 | "segmentation_models_pytorch.encoders", 105 | "segmentation_models_pytorch.utils", 106 | # 'segmentation_models_pytorch.base', 107 | ] 108 | 109 | autoclass_content = "both" 110 | autodoc_typehints = "description" 111 | 112 | # --- Work around to make autoclass signatures not (*args, **kwargs) ---------- 113 | 114 | 115 | class FakeSignature: 116 | def __getattribute__(self, *args): 117 | raise ValueError 118 | 119 | 120 | def f(app, obj, bound_method): 121 | if "__new__" in obj.__name__: 122 | obj.__signature__ = FakeSignature() 123 | 124 | 125 | def setup(app): 126 | app.connect("autodoc-before-process-signature", f) 127 | 128 | 129 | # Custom configuration -------------------------------------------------------- 130 | 131 | autodoc_member_order = "bysource" 132 | -------------------------------------------------------------------------------- /docs/index.rst: -------------------------------------------------------------------------------- 1 | .. Segmentation Models documentation master file, created by 2 | sphinx-quickstart on Fri Nov 27 00:00:20 2020. 3 | You can adapt this file completely to your liking, but it should at least 4 | contain the root `toctree` directive. 5 | 6 | Welcome to Segmentation Models's documentation! 7 | =============================================== 8 | 9 | .. toctree:: 10 | :maxdepth: 2 11 | :caption: Contents: 12 | 13 | install 14 | quickstart 15 | models 16 | encoders 17 | encoders_timm 18 | losses 19 | metrics 20 | save_load 21 | insights 22 | 23 | 24 | Indices and tables 25 | ================== 26 | 27 | * :ref:`genindex` 28 | * :ref:`modindex` 29 | * :ref:`search` 30 | -------------------------------------------------------------------------------- /docs/insights.rst: -------------------------------------------------------------------------------- 1 | 💡 Insights 2 | =========== 3 | 4 | 1. Models architecture 5 | ~~~~~~~~~~~~~~~~~~~~~~ 6 | 7 | All segmentation models in SMP (this library short name) are made of: 8 | 9 | - encoder (feature extractor, a.k.a backbone) 10 | - decoder (features fusion block to create segmentation *mask*) 11 | - segmentation head (final head to reduce number of channels from decoder and upsample mask to preserve input-output spatial resolution identity) 12 | - classification head (optional head which build on top of deepest encoder features) 13 | 14 | 15 | 2. Creating your own encoder 16 | ~~~~~~~~~~~~~~~~~~~~~~~~~~~~ 17 | 18 | Encoder is a "classification model" which extract features from image and pass it to decoder. 19 | Each encoder should have following attributes and methods and be inherited from `segmentation_models_pytorch.encoders._base.EncoderMixin` 20 | 21 | .. 
code-block:: python 22 | 23 | class MyEncoder(torch.nn.Module, EncoderMixin): 24 | 25 | def __init__(self, **kwargs): 26 | super().__init__() 27 | 28 | # A number of channels for each encoder feature tensor, list of integers 29 | self._out_channels: List[int] = [3, 16, 64, 128, 256, 512] 30 | 31 | # A number of stages in decoder (in other words number of downsampling operations), integer 32 | # used in the forward pass to reduce the number of returned features 33 | self._depth: int = 5 34 | 35 | # Default number of input channels in first Conv2d layer for encoder (usually 3) 36 | self._in_channels: int = 3 37 | 38 | # Define encoder modules below 39 | ... 40 | 41 | def forward(self, x: torch.Tensor) -> List[torch.Tensor]: 42 | """Produce list of features of different spatial resolutions, each feature is a 4D torch.tensor of 43 | shape NCHW (features should be sorted in descending order according to spatial resolution, starting 44 | with resolution same as input `x` tensor). 45 | 46 | Input: `x` with shape (1, 3, 64, 64) 47 | Output: [f0, f1, f2, f3, f4, f5] - features with corresponding shapes 48 | [(1, 3, 64, 64), (1, 64, 32, 32), (1, 128, 16, 16), (1, 256, 8, 8), 49 | (1, 512, 4, 4), (1, 1024, 2, 2)] (C - dim may differ) 50 | 51 | also should support number of features according to specified depth, e.g. if depth = 5, 52 | number of feature tensors = 6 (one with same resolution as input and 5 downsampled), 53 | depth = 3 -> number of feature tensors = 4 (one with same resolution as input and 3 downsampled). 54 | """ 55 | 56 | return [feat1, feat2, feat3, feat4, feat5, feat6] 57 | 58 | When you write your own Encoder class, register its build parameters: 59 | 60 | .. code-block:: python 61 | 62 | smp.encoders.encoders["my_awesome_encoder"] = { 63 | "encoder": MyEncoder, # encoder class here 64 | "pretrained_settings": { 65 | "imagenet": { 66 | "mean": [0.485, 0.456, 0.406], 67 | "std": [0.229, 0.224, 0.225], 68 | "url": "https://some-url.com/my-model-weights", 69 | "input_space": "RGB", 70 | "input_range": [0, 1], 71 | }, 72 | }, 73 | "params": { 74 | # init params for encoder if any 75 | }, 76 | }, 77 | 78 | Now you can use your encoder: 79 | 80 | .. code-block:: python 81 | 82 | model = smp.Unet(encoder_name="my_awesome_encoder") 83 | 84 | For a better understanding, see more encoder examples in the smp.encoders module. 85 | 86 | .. note:: 87 | 88 | If it works fine, don't forget to contribute your work and make a PR to SMP 😉 89 | 90 | 3. Aux classification output 91 | ~~~~~~~~~~~~~~~~~~~~~~~~~~~~ 92 | 93 | All models support the ``aux_params`` parameter, which is set to ``None`` by default. 94 | If ``aux_params = None``, the classification auxiliary output is not created; otherwise the 95 | model produces not only a ``mask`` but also a ``label`` output with shape ``(N, C)``. 96 | 97 | The classification head consists of the following layers: 98 | 99 | 1. GlobalPooling 100 | 2. Dropout (optional) 101 | 3. Linear 102 | 4. Activation (optional) 103 | 104 | Example: 105 | 106 | .. 
code-block:: python 107 | 108 | aux_params=dict( 109 | pooling='avg', # one of 'avg', 'max' 110 | dropout=0.5, # dropout ratio, default is None 111 | activation='sigmoid', # activation function, default is None 112 | classes=4, # define number of output labels 113 | ) 114 | 115 | model = smp.Unet('resnet34', classes=4, aux_params=aux_params) 116 | mask, label = model(x) 117 | 118 | mask.shape, label.shape 119 | # (N, 4, H, W), (N, 4) 120 | -------------------------------------------------------------------------------- /docs/install.rst: -------------------------------------------------------------------------------- 1 | ⚙️ Installation 2 | =============== 3 | 4 | PyPI version: 5 | 6 | .. code-block:: bash 7 | 8 | $ pip install -U segmentation-models-pytorch 9 | 10 | 11 | Latest version from source: 12 | 13 | .. code-block:: bash 14 | 15 | $ pip install -U git+https://github.com/qubvel/segmentation_models.pytorch -------------------------------------------------------------------------------- /docs/logo.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/qubvel-org/segmentation_models.pytorch/cf50cd082d35763073a296f6ee6378e24938bed8/docs/logo.png -------------------------------------------------------------------------------- /docs/losses.rst: -------------------------------------------------------------------------------- 1 | 📉 Losses 2 | ========= 3 | 4 | Collection of popular semantic segmentation losses. Adapted from 5 | an awesome repo with pytorch utils https://github.com/BloodAxe/pytorch-toolbelt 6 | 7 | Constants 8 | ~~~~~~~~~ 9 | .. automodule:: segmentation_models_pytorch.losses.constants 10 | :members: 11 | 12 | JaccardLoss 13 | ~~~~~~~~~~~ 14 | .. autoclass:: segmentation_models_pytorch.losses.JaccardLoss 15 | 16 | DiceLoss 17 | ~~~~~~~~ 18 | .. autoclass:: segmentation_models_pytorch.losses.DiceLoss 19 | 20 | TverskyLoss 21 | ~~~~~~~~ 22 | .. autoclass:: segmentation_models_pytorch.losses.TverskyLoss 23 | 24 | FocalLoss 25 | ~~~~~~~~~ 26 | .. autoclass:: segmentation_models_pytorch.losses.FocalLoss 27 | 28 | LovaszLoss 29 | ~~~~~~~~~~ 30 | .. autoclass:: segmentation_models_pytorch.losses.LovaszLoss 31 | 32 | SoftBCEWithLogitsLoss 33 | ~~~~~~~~~~~~~~~~~~~~~ 34 | .. autoclass:: segmentation_models_pytorch.losses.SoftBCEWithLogitsLoss 35 | 36 | SoftCrossEntropyLoss 37 | ~~~~~~~~~~~~~~~~~~~~ 38 | .. autoclass:: segmentation_models_pytorch.losses.SoftCrossEntropyLoss 39 | 40 | MCCLoss 41 | ~~~~~~~~~~~~~~~~~~~~ 42 | .. autoclass:: segmentation_models_pytorch.losses.MCCLoss 43 | :members: forward 44 | -------------------------------------------------------------------------------- /docs/make.bat: -------------------------------------------------------------------------------- 1 | @ECHO OFF 2 | 3 | pushd %~dp0 4 | 5 | REM Command file for Sphinx documentation 6 | 7 | if "%SPHINXBUILD%" == "" ( 8 | set SPHINXBUILD=sphinx-build 9 | ) 10 | set SOURCEDIR=source 11 | set BUILDDIR=build 12 | 13 | if "%1" == "" goto help 14 | 15 | %SPHINXBUILD% >NUL 2>NUL 16 | if errorlevel 9009 ( 17 | echo. 18 | echo.The 'sphinx-build' command was not found. Make sure you have Sphinx 19 | echo.installed, then set the SPHINXBUILD environment variable to point 20 | echo.to the full path of the 'sphinx-build' executable. Alternatively you 21 | echo.may add the Sphinx directory to PATH. 22 | echo. 
23 | echo.If you don't have Sphinx installed, grab it from 24 | echo.http://sphinx-doc.org/ 25 | exit /b 1 26 | ) 27 | 28 | %SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O% 29 | goto end 30 | 31 | :help 32 | %SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O% 33 | 34 | :end 35 | popd 36 | -------------------------------------------------------------------------------- /docs/metrics.rst: -------------------------------------------------------------------------------- 1 | 📏 Metrics 2 | ========== 3 | 4 | Functional metrics 5 | ~~~~~~~~~~~~~~~~~~ 6 | .. automodule:: segmentation_models_pytorch.metrics.functional 7 | :members: 8 | :autosummary: 9 | -------------------------------------------------------------------------------- /docs/models.rst: -------------------------------------------------------------------------------- 1 | 🕸️ Segmentation Models 2 | ============================== 3 | 4 | 5 | .. contents:: 6 | :local: 7 | 8 | .. _unet: 9 | 10 | Unet 11 | ~~~~ 12 | .. autoclass:: segmentation_models_pytorch.Unet 13 | 14 | 15 | .. _unetplusplus: 16 | 17 | Unet++ 18 | ~~~~~~ 19 | .. autoclass:: segmentation_models_pytorch.UnetPlusPlus 20 | 21 | 22 | .. _fpn: 23 | 24 | FPN 25 | ~~~ 26 | .. autoclass:: segmentation_models_pytorch.FPN 27 | 28 | 29 | .. _pspnet: 30 | 31 | PSPNet 32 | ~~~~~~ 33 | .. autoclass:: segmentation_models_pytorch.PSPNet 34 | 35 | 36 | .. _deeplabv3: 37 | 38 | DeepLabV3 39 | ~~~~~~~~~ 40 | .. autoclass:: segmentation_models_pytorch.DeepLabV3 41 | 42 | 43 | .. _deeplabv3plus: 44 | 45 | DeepLabV3+ 46 | ~~~~~~~~~~ 47 | .. autoclass:: segmentation_models_pytorch.DeepLabV3Plus 48 | 49 | 50 | .. _linknet: 51 | 52 | Linknet 53 | ~~~~~~~ 54 | .. autoclass:: segmentation_models_pytorch.Linknet 55 | 56 | 57 | .. _manet: 58 | 59 | MAnet 60 | ~~~~~~ 61 | .. autoclass:: segmentation_models_pytorch.MAnet 62 | 63 | 64 | .. _pan: 65 | 66 | PAN 67 | ~~~ 68 | .. autoclass:: segmentation_models_pytorch.PAN 69 | 70 | 71 | .. _upernet: 72 | 73 | UPerNet 74 | ~~~~~~~ 75 | .. autoclass:: segmentation_models_pytorch.UPerNet 76 | 77 | 78 | .. _segformer: 79 | 80 | Segformer 81 | ~~~~~~~~~ 82 | .. autoclass:: segmentation_models_pytorch.Segformer 83 | 84 | 85 | .. _dpt: 86 | 87 | DPT 88 | ~~~ 89 | 90 | .. note:: 91 | 92 | See full list of DPT-compatible timm encoders in :ref:`dpt-encoders`. 93 | 94 | .. note:: 95 | 96 | For some encoders, the model requires ``dynamic_img_size=True`` to be passed in order to work with resolutions different from what the encoder was trained for. 97 | 98 | .. autoclass:: segmentation_models_pytorch.DPT 99 | -------------------------------------------------------------------------------- /docs/quickstart.rst: -------------------------------------------------------------------------------- 1 | 🚀 Quick Start 2 | ============== 3 | 4 | **1. Create segmentation model** 5 | 6 | Segmentation model is just a PyTorch nn.Module, which can be created as easy as: 7 | 8 | .. code-block:: python 9 | 10 | import segmentation_models_pytorch as smp 11 | 12 | model = smp.Unet( 13 | encoder_name="resnet34", # choose encoder, e.g. mobilenet_v2 or efficientnet-b7 14 | encoder_weights="imagenet", # use `imagenet` pre-trained weights for encoder initialization 15 | in_channels=1, # model input channels (1 for gray-scale images, 3 for RGB, etc.) 16 | classes=3, # model output channels (number of classes in your dataset) 17 | ) 18 | 19 | - Check the page with available :doc:`model architectures `. 
20 | - Check the table with :doc:`available ported encoders and their corresponding weights `. 21 | - `Pytorch Image Models (timm) `_ encoders are also supported, check them :doc:`here`. 22 | 23 | Alternatively, you can use the `smp.create_model` function to create a model by name: 24 | 25 | .. code-block:: python 26 | 27 | model = smp.create_model( 28 | arch="fpn", # name of the architecture, e.g. 'Unet'/ 'FPN' / etc. Case INsensitive! 29 | encoder_name="mit_b0", 30 | encoder_weights="imagenet", 31 | in_channels=1, 32 | classes=3, 33 | ) 34 | 35 | 36 | **2. Configure data preprocessing** 37 | 38 | All encoders have pretrained weights. Preparing your data the same way as during weights pre-training may give you better results (higher metric score and faster convergence). But it is relevant only for 1-2-3-channel images and **not necessary** if you train the whole model, not only the decoder. 39 | 40 | .. code-block:: python 41 | 42 | from segmentation_models_pytorch.encoders import get_preprocessing_fn 43 | 44 | preprocess_input = get_preprocessing_fn('resnet18', pretrained='imagenet') 45 | 46 | 47 | **3. Congratulations!** 🎉 48 | 49 | 50 | You are done! Now you can train your model with your favorite framework, or as simply as: 51 | 52 | .. code-block:: python 53 | 54 | for images, gt_masks in dataloader: 55 | 56 | predicted_mask = model(images) 57 | loss = loss_fn(predicted_mask, gt_masks) 58 | 59 | loss.backward() 60 | optimizer.step() 61 | 62 | Check the following examples: 63 | 64 | .. |colab-badge| image:: https://colab.research.google.com/assets/colab-badge.svg 65 | :target: https://colab.research.google.com/github/qubvel/segmentation_models.pytorch/blob/main/examples/binary_segmentation_intro.ipynb 66 | :alt: Open In Colab 67 | 68 | - Finetuning notebook on Oxford Pet dataset with `PyTorch Lightning `_ |colab-badge| 69 | - Finetuning script for cloth segmentation with `PyTorch Lightning `_ 70 | -------------------------------------------------------------------------------- /docs/save_load.rst: -------------------------------------------------------------------------------- 1 | 📂 Saving and Loading 2 | ===================== 3 | 4 | In this section, we will discuss how to save a trained model, push it to the Hugging Face Hub, and load it back for later use. 5 | 6 | Saving and Sharing a Model 7 | -------------------------- 8 | 9 | Once you have trained your model, you can save it using the `.save_pretrained` method. This method saves the model configuration and weights to a directory of your choice. 10 | And, optionally, you can push the model to the Hugging Face Hub by setting the `push_to_hub` parameter to `True`. 11 | 12 | For example: 13 | 14 | .. code:: python 15 | 16 | import segmentation_models_pytorch as smp 17 | 18 | model = smp.Unet('resnet34', encoder_weights='imagenet') 19 | 20 | # After training your model, save it to a directory 21 | model.save_pretrained('./my_model') 22 | 23 | # Or save and push to the Hub simultaneously 24 | model.save_pretrained('username/my-model', push_to_hub=True) 25 | 26 | Loading Trained Model 27 | --------------------- 28 | 29 | Once your model is saved and pushed to the Hub, you can load it back using the `smp.from_pretrained` method. This method allows you to load the model weights and configuration from a directory or directly from the Hub. 30 | 31 | For example: 32 | 33 | .. 
code:: python 34 | 35 | import segmentation_models_pytorch as smp 36 | 37 | # Load the model from the local directory 38 | model = smp.from_pretrained('./my_model') 39 | 40 | # Alternatively, load the model directly from the Hugging Face Hub 41 | model = smp.from_pretrained('username/my-model') 42 | 43 | Loading pre-trained model with different number of classes for fine-tuning: 44 | 45 | .. code:: python 46 | 47 | import segmentation_models_pytorch as smp 48 | 49 | model = smp.from_pretrained('', classes=5, strict=False) 50 | 51 | Saving model Metrics and Dataset Name 52 | ------------------------------------- 53 | 54 | You can simply pass the `metrics` and `dataset` parameters to the `save_pretrained` method to save the model metrics and dataset name in Model Card along with the model configuration and weights. 55 | 56 | For example: 57 | 58 | .. code:: python 59 | 60 | import segmentation_models_pytorch as smp 61 | 62 | model = smp.Unet('resnet34', encoder_weights='imagenet') 63 | 64 | # After training your model, save it to a directory 65 | model.save_pretrained('./my_model', metrics={'accuracy': 0.95}, dataset='my_dataset') 66 | 67 | # Or saved and pushed to the Hub simultaneously 68 | model.save_pretrained('username/my-model', push_to_hub=True, metrics={'accuracy': 0.95}, dataset='my_dataset') 69 | 70 | Saving with preprocessing transform (Albumentations) 71 | ---------------------------------------------------- 72 | 73 | You can save the preprocessing transform along with the model and push it to the Hub. 74 | This can be useful when you want to share the model with the preprocessing transform that was used during training, 75 | to make sure that the inference pipeline is consistent with the training pipeline. 76 | 77 | .. code:: python 78 | 79 | import albumentations as A 80 | import segmentation_models_pytorch as smp 81 | 82 | # Define a preprocessing transform for image that would be used during inference 83 | preprocessing_transform = A.Compose([ 84 | A.Resize(256, 256), 85 | A.Normalize() 86 | ]) 87 | 88 | model = smp.Unet() 89 | 90 | directory_or_repo_on_the_hub = "qubvel-hf/unet-with-transform" # / 91 | 92 | # Save the model and transform (and pus ot hub, if needed) 93 | model.save_pretrained(directory_or_repo_on_the_hub, push_to_hub=True) 94 | preprocessing_transform.save_pretrained(directory_or_repo_on_the_hub, push_to_hub=True) 95 | 96 | # Loading transform and model 97 | restored_model = smp.from_pretrained(directory_or_repo_on_the_hub) 98 | restored_transform = A.Compose.from_pretrained(directory_or_repo_on_the_hub) 99 | 100 | print(restored_transform) 101 | 102 | Conclusion 103 | ---------- 104 | 105 | By following these steps, you can easily save, share, and load your models, facilitating collaboration and reproducibility in your projects. Don't forget to replace the placeholders with your actual model paths and names. 106 | 107 | |colab-badge| 108 | 109 | .. |colab-badge| image:: https://colab.research.google.com/assets/colab-badge.svg 110 | :target: https://colab.research.google.com/github/qubvel/segmentation_models.pytorch/blob/main/examples/binary_segmentation_intro.ipynb 111 | :alt: Open In Colab 112 | 113 | .. 
|colab-badge| image:: https://colab.research.google.com/assets/colab-badge.svg 114 | :target: https://colab.research.google.com/github/qubvel/segmentation_models.pytorch/blob/main/examples/save_load_model_and_share_with_hf_hub.ipynb 115 | :alt: Open In Colab 116 | -------------------------------------------------------------------------------- /licenses/LICENSES.md: -------------------------------------------------------------------------------- 1 | LICENSES for specific files 2 | =========================== 3 | 4 | The majority of the code is licensed under the [MIT License](LICENSE). However, some files are licensed under different terms. Please check each file for file-specific license. 5 | 6 | 7 | **Component-Specific Licenses** 8 | 9 | - NVIDIA License 10 | 11 | * Applies to the Mix Vision Transformer (SegFormer) encoder 12 | * This is for non-commercial use only 13 | * [segmentation_models_pytorch/encoders/mix_transformer.py](https://github.com/qubvel/segmentation_models.pytorch/blob/main/segmentation_models_pytorch/encoders/mix_transformer.py) 14 | * [LICENSE_nvidia](LICENSE_nvidia.md) 15 | 16 | - Apple License 17 | * Applies to the MobileOne encoder 18 | * [segmentation_models_pytorch/encoders/mobileone.py](https://github.com/qubvel/segmentation_models.pytorch/blob/main/segmentation_models_pytorch/encoders/mobileone.py) 19 | * [LICENSE_apple](LICENSE_apple.md) 20 | 21 | - BSD 3-Clause License 22 | * Applies to several encoders and the DeepLabV3 decoder 23 | * [segmentation_models_pytorch/encoders/_dpn.py](https://github.com/qubvel/segmentation_models.pytorch/blob/main/segmentation_models_pytorch/encoders/_dpn.py) 24 | * [segmentation_models_pytorch/encoders/_inceptionresnetv2.py](https://github.com/qubvel/segmentation_models.pytorch/blob/main/segmentation_models_pytorch/encoders/_inceptionresnetv2.py) 25 | * [segmentation_models_pytorch/encoders/_inceptionv4.py](https://github.com/qubvel/segmentation_models.pytorch/blob/main/segmentation_models_pytorch/encoders/_inceptionv4.py) 26 | * [segmentation_models_pytorch/encoders/_senet.py](https://github.com/qubvel/segmentation_models.pytorch/blob/main/segmentation_models_pytorch/encoders/_senet.py) 27 | * [segmentation_models_pytorch/encoders/_xception.py](https://github.com/qubvel/segmentation_models.pytorch/blob/main/segmentation_models_pytorch/encoders/_xception.py) 28 | * [segmentation_models_pytorch/decoders/deeplabv3/decoder.py](https://github.com/qubvel/segmentation_models.pytorch/blob/main/segmentation_models_pytorch/decoders/deeplabv3/decoder.py) 29 | 30 | - Apache-2.0 License 31 | * Applies to the EfficientNet encoder 32 | * [segmentation_models_pytorch/encoders/_efficientnet.py](https://github.com/qubvel/segmentation_models.pytorch/blob/main/segmentation_models_pytorch/encoders/_efficientnet.py) 33 | -------------------------------------------------------------------------------- /licenses/LICENSE_apple.md: -------------------------------------------------------------------------------- 1 | Copyright (C) 2022 Apple Inc. All Rights Reserved. 2 | 3 | IMPORTANT: This Apple software is supplied to you by Apple 4 | Inc. ("Apple") in consideration of your agreement to the following 5 | terms, and your use, installation, modification or redistribution of 6 | this Apple software constitutes acceptance of these terms. If you do 7 | not agree with these terms, please do not use, install, modify or 8 | redistribute this Apple software. 
9 | 10 | In consideration of your agreement to abide by the following terms, and 11 | subject to these terms, Apple grants you a personal, non-exclusive 12 | license, under Apple's copyrights in this original Apple software (the 13 | "Apple Software"), to use, reproduce, modify and redistribute the Apple 14 | Software, with or without modifications, in source and/or binary forms; 15 | provided that if you redistribute the Apple Software in its entirety and 16 | without modifications, you must retain this notice and the following 17 | text and disclaimers in all such redistributions of the Apple Software. 18 | Neither the name, trademarks, service marks or logos of Apple Inc. may 19 | be used to endorse or promote products derived from the Apple Software 20 | without specific prior written permission from Apple. Except as 21 | expressly stated in this notice, no other rights or licenses, express or 22 | implied, are granted by Apple herein, including but not limited to any 23 | patent rights that may be infringed by your derivative works or by other 24 | works in which the Apple Software may be incorporated. 25 | 26 | The Apple Software is provided by Apple on an "AS IS" basis. APPLE 27 | MAKES NO WARRANTIES, EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION 28 | THE IMPLIED WARRANTIES OF NON-INFRINGEMENT, MERCHANTABILITY AND FITNESS 29 | FOR A PARTICULAR PURPOSE, REGARDING THE APPLE SOFTWARE OR ITS USE AND 30 | OPERATION ALONE OR IN COMBINATION WITH YOUR PRODUCTS. 31 | 32 | IN NO EVENT SHALL APPLE BE LIABLE FOR ANY SPECIAL, INDIRECT, INCIDENTAL 33 | OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF 34 | SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS 35 | INTERRUPTION) ARISING IN ANY WAY OUT OF THE USE, REPRODUCTION, 36 | MODIFICATION AND/OR DISTRIBUTION OF THE APPLE SOFTWARE, HOWEVER CAUSED 37 | AND WHETHER UNDER THEORY OF CONTRACT, TORT (INCLUDING NEGLIGENCE), 38 | STRICT LIABILITY OR OTHERWISE, EVEN IF APPLE HAS BEEN ADVISED OF THE 39 | POSSIBILITY OF SUCH DAMAGE. 40 | 41 | ------------------------------------------------------------------------------- 42 | SOFTWARE DISTRIBUTED WITH ML-MobileOne: 43 | 44 | The ML-MobileOne software includes a number of subcomponents with separate 45 | copyright notices and license terms - please see the file ACKNOWLEDGEMENTS. 46 | ------------------------------------------------------------------------------- 47 | -------------------------------------------------------------------------------- /licenses/LICENSE_nvidia.md: -------------------------------------------------------------------------------- 1 | NVIDIA Source Code License for SegFormer 2 | 3 | 1. Definitions 4 | 5 | “Licensor” means any person or entity that distributes its Work. 6 | 7 | “Software” means the original work of authorship made available under this License. 8 | 9 | “Work” means the Software and any additions to or derivative works of the Software that are made available under 10 | this License. 11 | 12 | The terms “reproduce,” “reproduction,” “derivative works,” and “distribution” have the meaning as provided under 13 | U.S. copyright law; provided, however, that for the purposes of this License, derivative works shall not include 14 | works that remain separable from, or merely link (or bind by name) to the interfaces of, the Work. 
15 | 16 | Works, including the Software, are “made available” under this License by including in or with the Work either 17 | (a) a copyright notice referencing the applicability of this License to the Work, or (b) a copy of this License. 18 | 19 | 2. License Grant 20 | 21 | 2.1 Copyright Grant. Subject to the terms and conditions of this License, each Licensor grants to you a perpetual, 22 | worldwide, non-exclusive, royalty-free, copyright license to reproduce, prepare derivative works of, publicly 23 | display, publicly perform, sublicense and distribute its Work and any resulting derivative works in any form. 24 | 25 | 3. Limitations 26 | 27 | 3.1 Redistribution. You may reproduce or distribute the Work only if (a) you do so under this License, (b) you 28 | include a complete copy of this License with your distribution, and (c) you retain without modification any 29 | copyright, patent, trademark, or attribution notices that are present in the Work. 30 | 31 | 3.2 Derivative Works. You may specify that additional or different terms apply to the use, reproduction, and 32 | distribution of your derivative works of the Work (“Your Terms”) only if (a) Your Terms provide that the use 33 | limitation in Section 3.3 applies to your derivative works, and (b) you identify the specific derivative works 34 | that are subject to Your Terms. Notwithstanding Your Terms, this License (including the redistribution 35 | requirements in Section 3.1) will continue to apply to the Work itself. 36 | 37 | 3.3 Use Limitation. The Work and any derivative works thereof only may be used or intended for use 38 | non-commercially. Notwithstanding the foregoing, NVIDIA and its affiliates may use the Work and any derivative 39 | works commercially. As used herein, “non-commercially” means for research or evaluation purposes only. 40 | 41 | 3.4 Patent Claims. If you bring or threaten to bring a patent claim against any Licensor (including any claim, 42 | cross-claim or counterclaim in a lawsuit) to enforce any patents that you allege are infringed by any Work, then 43 | your rights under this License from such Licensor (including the grant in Section 2.1) will terminate immediately. 44 | 45 | 3.5 Trademarks. This License does not grant any rights to use any Licensor’s or its affiliates’ names, logos, 46 | or trademarks, except as necessary to reproduce the notices described in this License. 47 | 48 | 3.6 Termination. If you violate any term of this License, then your rights under this License (including the 49 | grant in Section 2.1) will terminate immediately. 50 | 51 | 4. Disclaimer of Warranty. 52 | 53 | THE WORK IS PROVIDED “AS IS” WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING 54 | WARRANTIES OR CONDITIONS OF M ERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE, TITLE OR NON-INFRINGEMENT. YOU 55 | BEAR THE RISK OF UNDERTAKING ANY ACTIVITIES UNDER THIS LICENSE. 56 | 57 | 5. Limitation of Liability. 
58 | 59 | EXCEPT AS PROHIBITED BY APPLICABLE LAW, IN NO EVENT AND UNDER NO LEGAL THEORY, WHETHER IN TORT (INCLUDING 60 | NEGLIGENCE), CONTRACT, OR OTHERWISE SHALL ANY LICENSOR BE LIABLE TO YOU FOR DAMAGES, INCLUDING ANY DIRECT, 61 | INDIRECT, SPECIAL, INCIDENTAL, OR CONSEQUENTIAL DAMAGES ARISING OUT OF OR RELATED TO THIS LICENSE, THE USE OR 62 | INABILITY TO USE THE WORK (INCLUDING BUT NOT LIMITED TO LOSS OF GOODWILL, BUSINESS INTERRUPTION, LOST PROFITS OR 63 | DATA, COMPUTER FAILURE OR MALFUNCTION, OR ANY OTHER COMM ERCIAL DAMAGES OR LOSSES), EVEN IF THE LICENSOR HAS BEEN 64 | ADVISED OF THE POSSIBILITY OF SUCH DAMAGES. 65 | -------------------------------------------------------------------------------- /misc/generate_table.py: -------------------------------------------------------------------------------- 1 | import os 2 | import segmentation_models_pytorch as smp 3 | 4 | from tqdm import tqdm 5 | 6 | encoders = smp.encoders.encoders 7 | 8 | 9 | WIDTH = 32 10 | COLUMNS = ["Encoder", "Pretrained weights", "Params, M", "Script", "Compile", "Export"] 11 | FILE = "encoders_table.md" 12 | 13 | if os.path.exists(FILE): 14 | os.remove(FILE) 15 | 16 | 17 | def wrap_row(r): 18 | return "|{}|".format(r) 19 | 20 | 21 | header = "|".join([column.ljust(WIDTH, " ") for column in COLUMNS]) 22 | separator = "|".join( 23 | ["-" * WIDTH] + [":" + "-" * (WIDTH - 2) + ":"] * (len(COLUMNS) - 1) 24 | ) 25 | 26 | print(wrap_row(header), file=open(FILE, "a")) 27 | print(wrap_row(separator), file=open(FILE, "a")) 28 | 29 | for encoder_name, encoder in tqdm(encoders.items()): 30 | weights = "
".join(encoder["pretrained_settings"].keys()) 31 | 32 | model = encoder["encoder"](**encoder["params"], depth=5) 33 | 34 | script = "✅" if model._is_torch_scriptable else "❌" 35 | compile = "✅" if model._is_torch_compilable else "❌" 36 | export = "✅" if model._is_torch_exportable else "❌" 37 | 38 | params = sum(p.numel() for p in model.parameters()) 39 | params = str(params // 1000000) + "M" 40 | 41 | row = [encoder_name, weights, params, script, compile, export] 42 | row = [str(r).ljust(WIDTH, " ") for r in row] 43 | row = "|".join(row) 44 | 45 | print(wrap_row(row), file=open(FILE, "a")) 46 | -------------------------------------------------------------------------------- /misc/generate_table_timm.py: -------------------------------------------------------------------------------- 1 | import timm 2 | from tqdm import tqdm 3 | 4 | 5 | def check_features_and_reduction(name): 6 | encoder = timm.create_model(name, features_only=True, pretrained=False) 7 | if not encoder.feature_info.reduction() == [2, 4, 8, 16, 32]: 8 | raise ValueError 9 | 10 | 11 | def has_dilation_support(name): 12 | try: 13 | timm.create_model(name, features_only=True, output_stride=8, pretrained=False) 14 | timm.create_model(name, features_only=True, output_stride=16, pretrained=False) 15 | return True 16 | except Exception: 17 | return False 18 | 19 | 20 | def valid_vit_encoder_for_dpt(name): 21 | if "vit" not in name: 22 | return False 23 | encoder = timm.create_model(name) 24 | feature_info = encoder.feature_info 25 | feature_info_obj = timm.models.FeatureInfo( 26 | feature_info=feature_info, out_indices=[0, 1, 2, 3] 27 | ) 28 | reduction_scales = list(feature_info_obj.reduction()) 29 | 30 | if len(set(reduction_scales)) > 1: 31 | return False 32 | 33 | output_stride = reduction_scales[0] 34 | if bin(output_stride).count("1") != 1: 35 | return False 36 | 37 | return True 38 | 39 | 40 | def make_table(data): 41 | names = data.keys() 42 | max_len1 = max([len(x) for x in names]) + 2 43 | max_len2 = len("support dilation") + 2 44 | max_len3 = len("Supported for DPT") + 2 45 | 46 | l1 = "+" + "-" * max_len1 + "+" + "-" * max_len2 + "+" + "-" * max_len3 + "+\n" 47 | l2 = "+" + "=" * max_len1 + "+" + "=" * max_len2 + "+" + "-" * max_len3 + "+\n" 48 | top = ( 49 | "| " 50 | + "Encoder name".ljust(max_len1 - 2) 51 | + " | " 52 | + "Support dilation".center(max_len2 - 2) 53 | + " | " 54 | + "Supported for DPT".center(max_len3 - 2) 55 | + " |\n" 56 | ) 57 | 58 | table = l1 + top + l2 59 | 60 | for k in sorted(data.keys()): 61 | if "has_dilation" in data[k] and data[k]["has_dilation"]: 62 | support = "✅".center(max_len2 - 3) 63 | 64 | else: 65 | support = " ".center(max_len2 - 2) 66 | 67 | if "supported_only_for_dpt" in data[k]: 68 | supported_for_dpt = "✅".center(max_len3 - 3) 69 | 70 | else: 71 | supported_for_dpt = " ".center(max_len3 - 2) 72 | 73 | table += ( 74 | "| " 75 | + k.ljust(max_len1 - 2) 76 | + " | " 77 | + support 78 | + " | " 79 | + supported_for_dpt 80 | + " |\n" 81 | ) 82 | table += l1 83 | 84 | return table 85 | 86 | 87 | if __name__ == "__main__": 88 | supported_models = {} 89 | 90 | with tqdm(timm.list_models()) as names: 91 | for name in names: 92 | try: 93 | check_features_and_reduction(name) 94 | has_dilation = has_dilation_support(name) 95 | supported_models[name] = dict(has_dilation=has_dilation) 96 | 97 | except Exception: 98 | try: 99 | if valid_vit_encoder_for_dpt(name): 100 | supported_models[name] = dict(supported_only_for_dpt=True) 101 | except Exception: 102 | continue 103 | 104 | table = 
make_table(supported_models) 105 | print(table) 106 | with open("timm_encoders.txt", "w") as f: 107 | print(table, file=f) 108 | print(f"Total encoders: {len(supported_models.keys())}") 109 | -------------------------------------------------------------------------------- /misc/generate_test_models.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import tempfile 4 | import huggingface_hub 5 | import segmentation_models_pytorch as smp 6 | 7 | HUB_REPO = "smp-test-models" 8 | ENCODER_NAME = "tu-resnet18" 9 | 10 | api = huggingface_hub.HfApi(token=os.getenv("HF_TOKEN")) 11 | 12 | 13 | def save_and_push(model, inputs, outputs, model_name, encoder_name): 14 | with tempfile.TemporaryDirectory() as tmpdir: 15 | # save model 16 | model.save_pretrained(f"{tmpdir}") 17 | 18 | # save input and output 19 | torch.save(inputs, f"{tmpdir}/input-tensor.pth") 20 | torch.save(outputs, f"{tmpdir}/output-tensor.pth") 21 | 22 | # create repo 23 | repo_id = f"{HUB_REPO}/{model_name}-{encoder_name}" 24 | if not api.repo_exists(repo_id=repo_id): 25 | api.create_repo(repo_id=repo_id, repo_type="model") 26 | 27 | # upload to hub 28 | api.upload_folder( 29 | folder_path=tmpdir, 30 | repo_id=f"{HUB_REPO}/{model_name}-{encoder_name}", 31 | repo_type="model", 32 | ) 33 | 34 | 35 | for model_name, model_class in smp.MODEL_ARCHITECTURES_MAPPING.items(): 36 | if model_name == "dpt": 37 | encoder_name = "tu-test_vit" 38 | model = smp.DPT( 39 | encoder_name=encoder_name, 40 | decoder_readout="cat", 41 | decoder_intermediate_channels=(16, 32, 64, 64), 42 | decoder_fusion_channels=16, 43 | dynamic_img_size=True, 44 | ) 45 | else: 46 | encoder_name = ENCODER_NAME 47 | model = model_class(encoder_name=encoder_name) 48 | 49 | model = model.eval() 50 | 51 | # generate test sample 52 | torch.manual_seed(423553) 53 | sample = torch.rand(1, 3, 256, 256) 54 | 55 | with torch.no_grad(): 56 | output = model(sample) 57 | 58 | save_and_push(model, sample, output, model_name, encoder_name) 59 | -------------------------------------------------------------------------------- /pics/logo-small-h300.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/qubvel-org/segmentation_models.pytorch/cf50cd082d35763073a296f6ee6378e24938bed8/pics/logo-small-h300.png -------------------------------------------------------------------------------- /pics/logo-small-w300.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/qubvel-org/segmentation_models.pytorch/cf50cd082d35763073a296f6ee6378e24938bed8/pics/logo-small-w300.png -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [build-system] 2 | requires = ['setuptools>=61'] 3 | build-backend = 'setuptools.build_meta' 4 | 5 | [project] 6 | name = 'segmentation_models_pytorch' 7 | description = 'Image segmentation models with pre-trained backbones. PyTorch.' 
8 | readme = 'README.md' 9 | requires-python = '>=3.9' 10 | license = {file = 'LICENSE'} 11 | authors = [{name = 'Pavel Iakubovskii', email = 'qubvel@gmail.com'}] 12 | classifiers = [ 13 | 'License :: OSI Approved :: MIT License', 14 | 'Programming Language :: Python', 15 | 'Programming Language :: Python :: 3', 16 | 'Programming Language :: Python :: Implementation :: CPython', 17 | 'Programming Language :: Python :: Implementation :: PyPy', 18 | ] 19 | dependencies = [ 20 | 'huggingface-hub>=0.24', 21 | 'numpy>=1.19.3', 22 | 'pillow>=8', 23 | 'safetensors>=0.3.1', 24 | 'timm>=0.9', 25 | 'torch>=1.8', 26 | 'torchvision>=0.9', 27 | 'tqdm>=4.42.1', 28 | ] 29 | dynamic = ['version'] 30 | 31 | [project.optional-dependencies] 32 | docs = [ 33 | 'autodocsumm', 34 | 'huggingface-hub', 35 | 'six', 36 | 'sphinx', 37 | 'sphinx-book-theme', 38 | ] 39 | test = [ 40 | 'gitpython', 41 | 'packaging', 42 | 'pytest', 43 | 'pytest-cov', 44 | 'pytest-xdist', 45 | 'ruff>=0.9', 46 | 'setuptools', 47 | ] 48 | 49 | [project.urls] 50 | Homepage = 'https://github.com/qubvel-org/segmentation_models.pytorch' 51 | 52 | [tool.ruff] 53 | extend-include = ['*.ipynb'] 54 | fix = true 55 | 56 | [tool.setuptools.dynamic] 57 | version = {attr = 'segmentation_models_pytorch.__version__.__version__'} 58 | 59 | [tool.setuptools.packages.find] 60 | include = ['segmentation_models_pytorch*'] 61 | 62 | [tool.pytest.ini_options] 63 | markers = [ 64 | "logits_match", 65 | "compile", 66 | "torch_export", 67 | "torch_script", 68 | ] 69 | 70 | [tool.coverage.run] 71 | omit = [ 72 | "segmentation_models_pytorch/utils/*", 73 | "**/convert_*", 74 | ] 75 | -------------------------------------------------------------------------------- /requirements/docs.txt: -------------------------------------------------------------------------------- 1 | autodocsumm==0.2.14 2 | huggingface-hub==0.31.1 3 | six==1.17.0 4 | sphinx==8.2.3 5 | sphinx-book-theme==1.1.4 6 | -------------------------------------------------------------------------------- /requirements/minimum.old: -------------------------------------------------------------------------------- 1 | huggingface-hub==0.24.0 2 | numpy==1.19.3 3 | pillow==8.0.0 4 | safetensors==0.3.1 5 | timm==0.9.0 6 | torch==1.9.0 7 | torchvision==0.10.0 8 | tqdm==4.42.1 9 | Jinja2==3.0.0 10 | -------------------------------------------------------------------------------- /requirements/required.txt: -------------------------------------------------------------------------------- 1 | huggingface_hub==0.31.1 2 | numpy==2.2.4 3 | pillow==11.2.1 4 | safetensors==0.5.3 5 | timm==1.0.15 6 | torch==2.7.0 7 | torchvision==0.22.0 8 | tqdm==4.67.1 9 | -------------------------------------------------------------------------------- /requirements/test.txt: -------------------------------------------------------------------------------- 1 | gitpython==3.1.44 2 | packaging==25.0 3 | pytest==8.3.5 4 | pytest-xdist==3.6.1 5 | pytest-cov==6.1.1 6 | ruff==0.11.9 7 | setuptools==80.4.0 8 | -------------------------------------------------------------------------------- /scripts/models-conversions/dpt-original-to-smp.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import torch 3 | import albumentations as A 4 | import segmentation_models_pytorch as smp 5 | 6 | MODEL_WEIGHTS_PATH = r"dpt_large-ade20k-b12dca68.pt" 7 | HF_HUB_PATH = "qubvel-hf/dpt-large-ade20k" 8 | PUSH_TO_HUB = False 9 | 10 | 11 | def get_transform(): 12 | return A.Compose( 13 | [ 14 | 
A.LongestMaxSize(max_size=480, interpolation=cv2.INTER_CUBIC), 15 | A.Normalize( 16 | mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), max_pixel_value=255.0 17 | ), 18 | # This is not correct transform, ideally image should resized without padding to multiple of 32, 19 | # but we take there is no such transform in albumentations, here is closest one 20 | A.PadIfNeeded( 21 | min_height=None, 22 | min_width=None, 23 | pad_height_divisor=32, 24 | pad_width_divisor=32, 25 | border_mode=cv2.BORDER_CONSTANT, 26 | value=0, 27 | p=1, 28 | ), 29 | ] 30 | ) 31 | 32 | 33 | if __name__ == "__main__": 34 | # fmt: off 35 | smp_model = smp.DPT(encoder_name="tu-vit_large_patch16_384", classes=150, dynamic_img_size=True) 36 | dpt_model_dict = torch.load(MODEL_WEIGHTS_PATH, weights_only=True) 37 | 38 | for layer_index in range(0, 4): 39 | for param in ["running_mean", "running_var", "num_batches_tracked", "weight", "bias"]: 40 | for block_index in [1, 2]: 41 | for bn_index in [1, 2]: 42 | # Assigning weights of 4th fusion layer of original model to 1st layer of SMP DPT model, 43 | # Assigning weights of 3rd fusion layer of original model to 2nd layer of SMP DPT model ... 44 | # and so on ... 45 | # This is because order of calling fusion layers is reversed in original DPT implementation 46 | dpt_model_dict[f"decoder.fusion_blocks.{layer_index}.residual_conv_block{block_index}.batch_norm_{bn_index}.{param}"] = \ 47 | dpt_model_dict.pop(f"scratch.refinenet{4 - layer_index}.resConfUnit{block_index}.bn{bn_index}.{param}") 48 | 49 | if param in ["weight", "bias"]: 50 | if param == "weight": 51 | for block_index in [1, 2]: 52 | for conv_index in [1, 2]: 53 | dpt_model_dict[f"decoder.fusion_blocks.{layer_index}.residual_conv_block{block_index}.conv_{conv_index}.{param}"] = \ 54 | dpt_model_dict.pop(f"scratch.refinenet{4 - layer_index}.resConfUnit{block_index}.conv{conv_index}.{param}") 55 | 56 | dpt_model_dict[f"decoder.reassemble_blocks.{layer_index}.project_to_feature_dim.{param}"] = \ 57 | dpt_model_dict.pop(f"scratch.layer{layer_index + 1}_rn.{param}") 58 | 59 | dpt_model_dict[f"decoder.fusion_blocks.{layer_index}.project.{param}"] = \ 60 | dpt_model_dict.pop(f"scratch.refinenet{4 - layer_index}.out_conv.{param}") 61 | 62 | dpt_model_dict[f"decoder.projection_blocks.{layer_index}.project.0.{param}"] = \ 63 | dpt_model_dict.pop(f"pretrained.act_postprocess{layer_index + 1}.0.project.0.{param}") 64 | 65 | dpt_model_dict[f"decoder.reassemble_blocks.{layer_index}.project_to_out_channel.{param}"] = \ 66 | dpt_model_dict.pop(f"pretrained.act_postprocess{layer_index + 1}.3.{param}") 67 | 68 | if layer_index != 2: 69 | dpt_model_dict[f"decoder.reassemble_blocks.{layer_index}.upsample.{param}"] = \ 70 | dpt_model_dict.pop(f"pretrained.act_postprocess{layer_index + 1}.4.{param}") 71 | 72 | # Changing state dict keys for segmentation head 73 | dpt_model_dict = { 74 | name.replace("scratch.output_conv", "segmentation_head.head"): parameter 75 | for name, parameter in dpt_model_dict.items() 76 | } 77 | 78 | # Changing state dict keys for encoder layers 79 | dpt_model_dict = { 80 | name.replace("pretrained.model", "encoder.model"): parameter 81 | for name, parameter in dpt_model_dict.items() 82 | } 83 | 84 | # Removing keys, value pairs associated with auxiliary head 85 | dpt_model_dict = { 86 | name: parameter 87 | for name, parameter in dpt_model_dict.items() 88 | if not name.startswith("auxlayer") 89 | } 90 | # fmt: on 91 | 92 | smp_model.load_state_dict(dpt_model_dict, strict=True) 93 | 94 | # ------- DO NOT touch this 
section ------- 95 | smp_model.eval() 96 | 97 | input_tensor = torch.ones((1, 3, 384, 384)) 98 | output = smp_model(input_tensor) 99 | 100 | print(output.shape) 101 | print(output[0, 0, :3, :3]) 102 | 103 | expected_slice = torch.tensor( 104 | [ 105 | [3.4243, 3.4553, 3.4863], 106 | [3.3332, 3.2876, 3.2419], 107 | [3.2422, 3.1199, 2.9975], 108 | ] 109 | ) 110 | 111 | torch.testing.assert_close( 112 | output[0, 0, :3, :3], expected_slice, atol=1e-4, rtol=1e-4 113 | ) 114 | 115 | # Saving 116 | transform = get_transform() 117 | 118 | transform.save_pretrained(HF_HUB_PATH) 119 | smp_model.save_pretrained(HF_HUB_PATH, push_to_hub=PUSH_TO_HUB) 120 | 121 | # Re-loading to make sure everything is saved correctly 122 | smp_model = smp.from_pretrained(HF_HUB_PATH) 123 | -------------------------------------------------------------------------------- /scripts/models-conversions/segformer-original-decoder-to-smp.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import argparse 3 | import requests 4 | import numpy as np 5 | import huggingface_hub 6 | import albumentations as A 7 | import matplotlib.pyplot as plt 8 | 9 | from PIL import Image 10 | import segmentation_models_pytorch as smp 11 | 12 | 13 | def convert_state_dict_to_smp(state_dict: dict): 14 | # fmt: off 15 | 16 | if "state_dict" in state_dict: 17 | state_dict = state_dict["state_dict"] 18 | 19 | new_state_dict = {} 20 | 21 | # Map the backbone components to the encoder 22 | keys = list(state_dict.keys()) 23 | for key in keys: 24 | if key.startswith("backbone"): 25 | new_key = key.replace("backbone", "encoder") 26 | new_state_dict[new_key] = state_dict.pop(key) 27 | 28 | 29 | # Map the linear_cX layers to MLP stages 30 | for i in range(4): 31 | base = f"decode_head.linear_c{i+1}.proj" 32 | new_state_dict[f"decoder.mlp_stage.{3-i}.linear.weight"] = state_dict.pop(f"{base}.weight") 33 | new_state_dict[f"decoder.mlp_stage.{3-i}.linear.bias"] = state_dict.pop(f"{base}.bias") 34 | 35 | # Map fuse_stage components 36 | fuse_base = "decode_head.linear_fuse" 37 | fuse_weights = { 38 | "decoder.fuse_stage.0.weight": state_dict.pop(f"{fuse_base}.conv.weight"), 39 | "decoder.fuse_stage.1.weight": state_dict.pop(f"{fuse_base}.bn.weight"), 40 | "decoder.fuse_stage.1.bias": state_dict.pop(f"{fuse_base}.bn.bias"), 41 | "decoder.fuse_stage.1.running_mean": state_dict.pop(f"{fuse_base}.bn.running_mean"), 42 | "decoder.fuse_stage.1.running_var": state_dict.pop(f"{fuse_base}.bn.running_var"), 43 | "decoder.fuse_stage.1.num_batches_tracked": state_dict.pop(f"{fuse_base}.bn.num_batches_tracked"), 44 | } 45 | new_state_dict.update(fuse_weights) 46 | 47 | # Map final layer components 48 | new_state_dict["segmentation_head.0.weight"] = state_dict.pop("decode_head.linear_pred.weight") 49 | new_state_dict["segmentation_head.0.bias"] = state_dict.pop("decode_head.linear_pred.bias") 50 | 51 | del state_dict["decode_head.conv_seg.weight"] 52 | del state_dict["decode_head.conv_seg.bias"] 53 | 54 | assert len(state_dict) == 0, f"Unmapped keys: {state_dict.keys()}" 55 | 56 | # fmt: on 57 | return new_state_dict 58 | 59 | 60 | def get_np_image(): 61 | url = "https://huggingface.co/datasets/hf-internal-testing/fixtures_ade20k/resolve/main/ADE_val_00000001.jpg" 62 | image = Image.open(requests.get(url, stream=True).raw) 63 | return np.array(image) 64 | 65 | 66 | def main(args): 67 | original_checkpoint = torch.load(args.path, map_location="cpu", weights_only=True) 68 | smp_state_dict = 
convert_state_dict_to_smp(original_checkpoint) 69 | 70 | config = original_checkpoint["meta"]["config"] 71 | num_classes = int(config.split("num_classes=")[1].split(",\n")[0]) 72 | decoder_dims = int(config.split("embed_dim=")[1].split(",\n")[0]) 73 | height, width = [ 74 | int(x) for x in config.split("crop_size=(")[1].split("), ")[0].split(",") 75 | ] 76 | model_size = args.path.split("segformer.")[1][:2] 77 | 78 | # Create the model 79 | model = smp.create_model( 80 | in_channels=3, 81 | classes=num_classes, 82 | arch="segformer", 83 | encoder_name=f"mit_{model_size}", 84 | encoder_weights=None, 85 | decoder_segmentation_channels=decoder_dims, 86 | ).eval() 87 | 88 | # Load the converted state dict 89 | model.load_state_dict(smp_state_dict, strict=True) 90 | 91 | # Preprocessing params 92 | preprocessing = A.Compose( 93 | [ 94 | A.Resize(height, width, p=1), 95 | A.Normalize( 96 | mean=[123.675, 116.28, 103.53], 97 | std=[58.395, 57.12, 57.375], 98 | max_pixel_value=1.0, 99 | p=1, 100 | ), 101 | ] 102 | ) 103 | 104 | # Prepare the input 105 | image = get_np_image() 106 | normalized_image = preprocessing(image=image)["image"] 107 | tensor = torch.tensor(normalized_image).permute(2, 0, 1).unsqueeze(0).float() 108 | 109 | # Forward pass 110 | with torch.no_grad(): 111 | mask = model(tensor) 112 | 113 | # Postprocessing 114 | mask = torch.nn.functional.interpolate( 115 | mask, size=(image.shape[0], image.shape[1]), mode="bilinear" 116 | ) 117 | mask = torch.argmax(mask, dim=1) 118 | mask = mask.squeeze().cpu().numpy() 119 | 120 | model_name = args.path.split("/")[-1].replace(".pth", "").replace(".", "-") 121 | 122 | model.save_pretrained(model_name) 123 | preprocessing.save_pretrained(model_name) 124 | 125 | # fmt: off 126 | plt.subplot(121), plt.axis('off'), plt.imshow(image), plt.title('Input Image') 127 | plt.subplot(122), plt.axis('off'), plt.imshow(mask), plt.title('Output Mask') 128 | plt.savefig(f"{model_name}/example_mask.png") 129 | # fmt: on 130 | 131 | if args.push_to_hub: 132 | repo_id = f"smp-hub/{model_name}" 133 | api = huggingface_hub.HfApi() 134 | api.create_repo(repo_id=repo_id, repo_type="model") 135 | api.upload_folder(folder_path=model_name, repo_id=repo_id) 136 | 137 | 138 | if __name__ == "__main__": 139 | parser = argparse.ArgumentParser() 140 | parser.add_argument( 141 | "--path", 142 | type=str, 143 | default="weights/trained_models/segformer.b2.512x512.ade.160k.pth", 144 | ) 145 | parser.add_argument("--push_to_hub", action="store_true") 146 | args = parser.parse_args() 147 | 148 | main(args) 149 | -------------------------------------------------------------------------------- /segmentation_models_pytorch/__init__.py: -------------------------------------------------------------------------------- 1 | from . import datasets 2 | from . import encoders 3 | from . import decoders 4 | from . import losses 5 | from . 
import metrics 6 | 7 | from .decoders.unet import Unet 8 | from .decoders.unetplusplus import UnetPlusPlus 9 | from .decoders.manet import MAnet 10 | from .decoders.linknet import Linknet 11 | from .decoders.fpn import FPN 12 | from .decoders.pspnet import PSPNet 13 | from .decoders.deeplabv3 import DeepLabV3, DeepLabV3Plus 14 | from .decoders.pan import PAN 15 | from .decoders.upernet import UPerNet 16 | from .decoders.segformer import Segformer 17 | from .decoders.dpt import DPT 18 | from .base.hub_mixin import from_pretrained 19 | 20 | from .__version__ import __version__ 21 | 22 | # some private imports for create_model function 23 | from typing import Optional as _Optional 24 | import torch as _torch 25 | 26 | _MODEL_ARCHITECTURES = [ 27 | Unet, 28 | UnetPlusPlus, 29 | MAnet, 30 | Linknet, 31 | FPN, 32 | PSPNet, 33 | DeepLabV3, 34 | DeepLabV3Plus, 35 | PAN, 36 | UPerNet, 37 | Segformer, 38 | DPT, 39 | ] 40 | MODEL_ARCHITECTURES_MAPPING = {a.__name__.lower(): a for a in _MODEL_ARCHITECTURES} 41 | 42 | 43 | def create_model( 44 | arch: str, 45 | encoder_name: str = "resnet34", 46 | encoder_weights: _Optional[str] = "imagenet", 47 | in_channels: int = 3, 48 | classes: int = 1, 49 | **kwargs, 50 | ) -> _torch.nn.Module: 51 | """Models entrypoint, allows to create any model architecture just with 52 | parameters, without using its class 53 | """ 54 | 55 | try: 56 | model_class = MODEL_ARCHITECTURES_MAPPING[arch.lower()] 57 | except KeyError: 58 | raise KeyError( 59 | "Wrong architecture type `{}`. Available options are: {}".format( 60 | arch, list(MODEL_ARCHITECTURES_MAPPING.keys()) 61 | ) 62 | ) 63 | return model_class( 64 | encoder_name=encoder_name, 65 | encoder_weights=encoder_weights, 66 | in_channels=in_channels, 67 | classes=classes, 68 | **kwargs, 69 | ) 70 | 71 | 72 | __all__ = [ 73 | "datasets", 74 | "encoders", 75 | "decoders", 76 | "losses", 77 | "metrics", 78 | "Unet", 79 | "UnetPlusPlus", 80 | "MAnet", 81 | "Linknet", 82 | "FPN", 83 | "PSPNet", 84 | "DeepLabV3", 85 | "DeepLabV3Plus", 86 | "PAN", 87 | "UPerNet", 88 | "Segformer", 89 | "DPT", 90 | "from_pretrained", 91 | "create_model", 92 | "__version__", 93 | ] 94 | -------------------------------------------------------------------------------- /segmentation_models_pytorch/__version__.py: -------------------------------------------------------------------------------- 1 | __version__ = "0.5.0" 2 | -------------------------------------------------------------------------------- /segmentation_models_pytorch/base/__init__.py: -------------------------------------------------------------------------------- 1 | from .model import SegmentationModel 2 | 3 | from .modules import Conv2dReLU, Attention 4 | 5 | from .heads import SegmentationHead, ClassificationHead 6 | 7 | __all__ = [ 8 | "SegmentationModel", 9 | "Conv2dReLU", 10 | "Attention", 11 | "SegmentationHead", 12 | "ClassificationHead", 13 | ] 14 | -------------------------------------------------------------------------------- /segmentation_models_pytorch/base/heads.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | from .modules import Activation 3 | 4 | 5 | class SegmentationHead(nn.Sequential): 6 | def __init__( 7 | self, in_channels, out_channels, kernel_size=3, activation=None, upsampling=1 8 | ): 9 | conv2d = nn.Conv2d( 10 | in_channels, out_channels, kernel_size=kernel_size, padding=kernel_size // 2 11 | ) 12 | upsampling = ( 13 | nn.UpsamplingBilinear2d(scale_factor=upsampling) 14 | if upsampling > 1 15 
| else nn.Identity() 16 | ) 17 | activation = Activation(activation) 18 | super().__init__(conv2d, upsampling, activation) 19 | 20 | 21 | class ClassificationHead(nn.Sequential): 22 | def __init__( 23 | self, in_channels, classes, pooling="avg", dropout=0.2, activation=None 24 | ): 25 | if pooling not in ("max", "avg"): 26 | raise ValueError( 27 | "Pooling should be one of ('max', 'avg'), got {}.".format(pooling) 28 | ) 29 | pool = nn.AdaptiveAvgPool2d(1) if pooling == "avg" else nn.AdaptiveMaxPool2d(1) 30 | flatten = nn.Flatten() 31 | dropout = nn.Dropout(p=dropout, inplace=True) if dropout else nn.Identity() 32 | linear = nn.Linear(in_channels, classes, bias=True) 33 | activation = Activation(activation) 34 | super().__init__(pool, flatten, dropout, linear, activation) 35 | -------------------------------------------------------------------------------- /segmentation_models_pytorch/base/hub_mixin.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import json 3 | from pathlib import Path 4 | from typing import Optional, Union 5 | from functools import wraps 6 | from huggingface_hub import ( 7 | PyTorchModelHubMixin, 8 | ModelCard, 9 | ModelCardData, 10 | hf_hub_download, 11 | ) 12 | 13 | 14 | MODEL_CARD = """ 15 | --- 16 | {{ card_data }} 17 | --- 18 | # {{ model_name }} Model Card 19 | 20 | Table of Contents: 21 | - [Load trained model](#load-trained-model) 22 | - [Model init parameters](#model-init-parameters) 23 | - [Model metrics](#model-metrics) 24 | - [Dataset](#dataset) 25 | 26 | ## Load trained model 27 | ```python 28 | import segmentation_models_pytorch as smp 29 | 30 | model = smp.from_pretrained("") 31 | ``` 32 | 33 | ## Model init parameters 34 | ```python 35 | model_init_params = {{ model_parameters }} 36 | ``` 37 | 38 | ## Model metrics 39 | {{ metrics | default("[More Information Needed]", true) }} 40 | 41 | ## Dataset 42 | Dataset name: {{ dataset | default("[More Information Needed]", true) }} 43 | 44 | ## More Information 45 | - Library: {{ repo_url | default("[More Information Needed]", true) }} 46 | - Docs: {{ docs_url | default("[More Information Needed]", true) }} 47 | 48 | This model has been pushed to the Hub using the [PytorchModelHubMixin](https://huggingface.co/docs/huggingface_hub/package_reference/mixins#huggingface_hub.PyTorchModelHubMixin) 49 | """ 50 | 51 | 52 | def _format_parameters(parameters: dict): 53 | params = {k: v for k, v in parameters.items() if not k.startswith("_")} 54 | params = [ 55 | f'"{k}": {v}' if not isinstance(v, str) else f'"{k}": "{v}"' 56 | for k, v in params.items() 57 | ] 58 | params = ",\n".join([f" {param}" for param in params]) 59 | params = "{\n" + f"{params}" + "\n}" 60 | return params 61 | 62 | 63 | class SMPHubMixin(PyTorchModelHubMixin): 64 | def generate_model_card(self, *args, **kwargs) -> ModelCard: 65 | model_parameters_json = _format_parameters(self.config) 66 | metrics = kwargs.get("metrics", None) 67 | dataset = kwargs.get("dataset", None) 68 | 69 | if metrics is not None: 70 | metrics = json.dumps(metrics, indent=4) 71 | metrics = f"```json\n{metrics}\n```" 72 | 73 | tags = self._hub_mixin_info.model_card_data.get("tags", []) or [] 74 | tags.extend(["segmentation-models-pytorch", "semantic-segmentation", "pytorch"]) 75 | 76 | model_card_data = ModelCardData( 77 | languages=["python"], 78 | library_name="segmentation-models-pytorch", 79 | license="mit", 80 | tags=tags, 81 | pipeline_tag="image-segmentation", 82 | ) 83 | model_card = ModelCard.from_template( 84 | 
card_data=model_card_data, 85 | template_str=MODEL_CARD, 86 | repo_url="https://github.com/qubvel/segmentation_models.pytorch", 87 | docs_url="https://smp.readthedocs.io/en/latest/", 88 | model_parameters=model_parameters_json, 89 | model_name=self.__class__.__name__, 90 | metrics=metrics, 91 | dataset=dataset, 92 | ) 93 | return model_card 94 | 95 | @wraps(PyTorchModelHubMixin.save_pretrained) 96 | def save_pretrained( 97 | self, save_directory: Union[str, Path], *args, **kwargs 98 | ) -> Optional[str]: 99 | model_card_kwargs = kwargs.pop("model_card_kwargs", {}) 100 | if "dataset" in kwargs: 101 | model_card_kwargs["dataset"] = kwargs.pop("dataset") 102 | if "metrics" in kwargs: 103 | model_card_kwargs["metrics"] = kwargs.pop("metrics") 104 | kwargs["model_card_kwargs"] = model_card_kwargs 105 | 106 | # set additional attribute to be able to deserialize the model 107 | self.config["_model_class"] = self.__class__.__name__ 108 | 109 | try: 110 | # call the original save_pretrained 111 | result = super().save_pretrained(save_directory, *args, **kwargs) 112 | finally: 113 | self.config.pop("_model_class", None) 114 | 115 | return result 116 | 117 | @property 118 | @torch.jit.unused 119 | def config(self) -> dict: 120 | return self._hub_mixin_config 121 | 122 | 123 | @wraps(PyTorchModelHubMixin.from_pretrained) 124 | def from_pretrained( 125 | pretrained_model_name_or_path: str, *args, strict: bool = True, **kwargs 126 | ): 127 | config_path = Path(pretrained_model_name_or_path) / "config.json" 128 | if not config_path.exists(): 129 | config_path = hf_hub_download( 130 | pretrained_model_name_or_path, 131 | filename="config.json", 132 | revision=kwargs.get("revision", None), 133 | ) 134 | 135 | with open(config_path, "r") as f: 136 | config = json.load(f) 137 | model_class_name = config.pop("_model_class") 138 | 139 | import segmentation_models_pytorch as smp 140 | 141 | model_class = getattr(smp, model_class_name) 142 | return model_class.from_pretrained( 143 | pretrained_model_name_or_path, *args, **kwargs, strict=strict 144 | ) 145 | 146 | 147 | def supports_config_loading(func): 148 | """Decorator to filter special config kwargs""" 149 | 150 | @wraps(func) 151 | def wrapper(self, *args, **kwargs): 152 | kwargs = {k: v for k, v in kwargs.items() if not k.startswith("_")} 153 | return func(self, *args, **kwargs) 154 | 155 | return wrapper 156 | -------------------------------------------------------------------------------- /segmentation_models_pytorch/base/initialization.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | 3 | 4 | def initialize_decoder(module): 5 | for m in module.modules(): 6 | if isinstance(m, nn.Conv2d): 7 | nn.init.kaiming_uniform_(m.weight, mode="fan_in", nonlinearity="relu") 8 | if m.bias is not None: 9 | nn.init.constant_(m.bias, 0) 10 | 11 | elif isinstance( 12 | m, (nn.BatchNorm2d, nn.LayerNorm, nn.GroupNorm, nn.InstanceNorm2d) 13 | ): 14 | nn.init.constant_(m.weight, 1) 15 | nn.init.constant_(m.bias, 0) 16 | 17 | elif isinstance(m, nn.Linear): 18 | nn.init.xavier_uniform_(m.weight) 19 | if m.bias is not None: 20 | nn.init.constant_(m.bias, 0) 21 | 22 | 23 | def initialize_head(module): 24 | for m in module.modules(): 25 | if isinstance(m, (nn.Linear, nn.Conv2d)): 26 | nn.init.xavier_uniform_(m.weight) 27 | if m.bias is not None: 28 | nn.init.constant_(m.bias, 0) 29 | -------------------------------------------------------------------------------- /segmentation_models_pytorch/base/utils.py: 
-------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | @torch.jit.unused 5 | def is_torch_compiling(): 6 | try: 7 | return torch.compiler.is_compiling() 8 | except Exception: 9 | try: 10 | import torch._dynamo as dynamo # noqa: F401 11 | 12 | return dynamo.is_compiling() 13 | except Exception: 14 | return False 15 | -------------------------------------------------------------------------------- /segmentation_models_pytorch/datasets/__init__.py: -------------------------------------------------------------------------------- 1 | from .oxford_pet import OxfordPetDataset, SimpleOxfordPetDataset 2 | 3 | __all__ = ["OxfordPetDataset", "SimpleOxfordPetDataset"] 4 | -------------------------------------------------------------------------------- /segmentation_models_pytorch/datasets/oxford_pet.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import shutil 4 | import numpy as np 5 | 6 | from PIL import Image 7 | from tqdm import tqdm 8 | from urllib.request import urlretrieve 9 | 10 | 11 | class OxfordPetDataset(torch.utils.data.Dataset): 12 | def __init__(self, root, mode="train", transform=None): 13 | assert mode in {"train", "valid", "test"} 14 | 15 | self.root = root 16 | self.mode = mode 17 | self.transform = transform 18 | 19 | self.images_directory = os.path.join(self.root, "images") 20 | self.masks_directory = os.path.join(self.root, "annotations", "trimaps") 21 | 22 | self.filenames = self._read_split() # read train/valid/test splits 23 | 24 | def __len__(self): 25 | return len(self.filenames) 26 | 27 | def __getitem__(self, idx): 28 | filename = self.filenames[idx] 29 | image_path = os.path.join(self.images_directory, filename + ".jpg") 30 | mask_path = os.path.join(self.masks_directory, filename + ".png") 31 | 32 | image = np.array(Image.open(image_path).convert("RGB")) 33 | 34 | trimap = np.array(Image.open(mask_path)) 35 | mask = self._preprocess_mask(trimap) 36 | 37 | sample = dict(image=image, mask=mask, trimap=trimap) 38 | if self.transform is not None: 39 | sample = self.transform(**sample) 40 | 41 | return sample 42 | 43 | @staticmethod 44 | def _preprocess_mask(mask): 45 | mask = mask.astype(np.float32) 46 | mask[mask == 2.0] = 0.0 47 | mask[(mask == 1.0) | (mask == 3.0)] = 1.0 48 | return mask 49 | 50 | def _read_split(self): 51 | split_filename = "test.txt" if self.mode == "test" else "trainval.txt" 52 | split_filepath = os.path.join(self.root, "annotations", split_filename) 53 | with open(split_filepath) as f: 54 | split_data = f.read().strip("\n").split("\n") 55 | filenames = [x.split(" ")[0] for x in split_data] 56 | if self.mode == "train": # 90% for train 57 | filenames = [x for i, x in enumerate(filenames) if i % 10 != 0] 58 | elif self.mode == "valid": # 10% for validation 59 | filenames = [x for i, x in enumerate(filenames) if i % 10 == 0] 60 | return filenames 61 | 62 | @staticmethod 63 | def download(root): 64 | # load images 65 | filepath = os.path.join(root, "images.tar.gz") 66 | download_url( 67 | url="https://www.robots.ox.ac.uk/~vgg/data/pets/data/images.tar.gz", 68 | filepath=filepath, 69 | ) 70 | extract_archive(filepath) 71 | 72 | # load annotations 73 | filepath = os.path.join(root, "annotations.tar.gz") 74 | download_url( 75 | url="https://www.robots.ox.ac.uk/~vgg/data/pets/data/annotations.tar.gz", 76 | filepath=filepath, 77 | ) 78 | extract_archive(filepath) 79 | 80 | 81 | class SimpleOxfordPetDataset(OxfordPetDataset): 82 | def 
__getitem__(self, *args, **kwargs): 83 | sample = super().__getitem__(*args, **kwargs) 84 | 85 | # resize images 86 | image = np.array( 87 | Image.fromarray(sample["image"]).resize((256, 256), Image.BILINEAR) 88 | ) 89 | mask = np.array( 90 | Image.fromarray(sample["mask"]).resize((256, 256), Image.NEAREST) 91 | ) 92 | trimap = np.array( 93 | Image.fromarray(sample["trimap"]).resize((256, 256), Image.NEAREST) 94 | ) 95 | 96 | # convert to other format HWC -> CHW 97 | sample["image"] = np.moveaxis(image, -1, 0) 98 | sample["mask"] = np.expand_dims(mask, 0) 99 | sample["trimap"] = np.expand_dims(trimap, 0) 100 | 101 | return sample 102 | 103 | 104 | class TqdmUpTo(tqdm): 105 | def update_to(self, b=1, bsize=1, tsize=None): 106 | if tsize is not None: 107 | self.total = tsize 108 | self.update(b * bsize - self.n) 109 | 110 | 111 | def download_url(url, filepath): 112 | directory = os.path.dirname(os.path.abspath(filepath)) 113 | os.makedirs(directory, exist_ok=True) 114 | if os.path.exists(filepath): 115 | return 116 | 117 | with TqdmUpTo( 118 | unit="B", 119 | unit_scale=True, 120 | unit_divisor=1024, 121 | miniters=1, 122 | desc=os.path.basename(filepath), 123 | ) as t: 124 | urlretrieve(url, filename=filepath, reporthook=t.update_to, data=None) 125 | t.total = t.n 126 | 127 | 128 | def extract_archive(filepath): 129 | extract_dir = os.path.dirname(os.path.abspath(filepath)) 130 | dst_dir = os.path.splitext(filepath)[0] 131 | if not os.path.exists(dst_dir): 132 | shutil.unpack_archive(filepath, extract_dir) 133 | -------------------------------------------------------------------------------- /segmentation_models_pytorch/decoders/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/qubvel-org/segmentation_models.pytorch/cf50cd082d35763073a296f6ee6378e24938bed8/segmentation_models_pytorch/decoders/__init__.py -------------------------------------------------------------------------------- /segmentation_models_pytorch/decoders/deeplabv3/__init__.py: -------------------------------------------------------------------------------- 1 | from .model import DeepLabV3, DeepLabV3Plus 2 | 3 | __all__ = ["DeepLabV3", "DeepLabV3Plus"] 4 | -------------------------------------------------------------------------------- /segmentation_models_pytorch/decoders/dpt/__init__.py: -------------------------------------------------------------------------------- 1 | from .model import DPT 2 | 3 | __all__ = ["DPT"] 4 | -------------------------------------------------------------------------------- /segmentation_models_pytorch/decoders/fpn/__init__.py: -------------------------------------------------------------------------------- 1 | from .model import FPN 2 | 3 | __all__ = ["FPN"] 4 | -------------------------------------------------------------------------------- /segmentation_models_pytorch/decoders/fpn/decoder.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | from typing import List, Literal 6 | 7 | 8 | class Conv3x3GNReLU(nn.Module): 9 | def __init__(self, in_channels: int, out_channels: int, upsample: bool = False): 10 | super().__init__() 11 | self.upsample = upsample 12 | self.block = nn.Sequential( 13 | nn.Conv2d( 14 | in_channels, out_channels, (3, 3), stride=1, padding=1, bias=False 15 | ), 16 | nn.GroupNorm(32, out_channels), 17 | nn.ReLU(inplace=True), 18 | ) 19 | 20 | def forward(self, x: torch.Tensor) -> 
torch.Tensor: 21 | x = self.block(x) 22 | if self.upsample: 23 | x = F.interpolate(x, scale_factor=2.0, mode="bilinear", align_corners=True) 24 | return x 25 | 26 | 27 | class FPNBlock(nn.Module): 28 | def __init__( 29 | self, 30 | pyramid_channels: int, 31 | skip_channels: int, 32 | interpolation_mode: str = "nearest", 33 | ): 34 | super().__init__() 35 | self.skip_conv = nn.Conv2d(skip_channels, pyramid_channels, kernel_size=1) 36 | self.interpolation_mode = interpolation_mode 37 | 38 | def forward(self, x: torch.Tensor, skip: torch.Tensor) -> torch.Tensor: 39 | x = F.interpolate(x, scale_factor=2.0, mode=self.interpolation_mode) 40 | skip = self.skip_conv(skip) 41 | x = x + skip 42 | return x 43 | 44 | 45 | class SegmentationBlock(nn.Module): 46 | def __init__(self, in_channels: int, out_channels: int, n_upsamples: int = 0): 47 | super().__init__() 48 | 49 | blocks = [Conv3x3GNReLU(in_channels, out_channels, upsample=bool(n_upsamples))] 50 | 51 | if n_upsamples > 1: 52 | for _ in range(1, n_upsamples): 53 | blocks.append(Conv3x3GNReLU(out_channels, out_channels, upsample=True)) 54 | 55 | self.block = nn.Sequential(*blocks) 56 | 57 | def forward(self, x): 58 | return self.block(x) 59 | 60 | 61 | class MergeBlock(nn.Module): 62 | def __init__(self, policy: Literal["add", "cat"]): 63 | super().__init__() 64 | if policy not in ["add", "cat"]: 65 | raise ValueError( 66 | "`merge_policy` must be one of: ['add', 'cat'], got {}".format(policy) 67 | ) 68 | self.policy = policy 69 | 70 | def forward(self, x: List[torch.Tensor]) -> torch.Tensor: 71 | if self.policy == "add": 72 | output = torch.stack(x).sum(dim=0) 73 | elif self.policy == "cat": 74 | output = torch.cat(x, dim=1) 75 | else: 76 | raise ValueError( 77 | "`merge_policy` must be one of: ['add', 'cat'], got {}".format( 78 | self.policy 79 | ) 80 | ) 81 | return output 82 | 83 | 84 | class FPNDecoder(nn.Module): 85 | def __init__( 86 | self, 87 | encoder_channels: List[int], 88 | encoder_depth: int = 5, 89 | pyramid_channels: int = 256, 90 | segmentation_channels: int = 128, 91 | dropout: float = 0.2, 92 | merge_policy: Literal["add", "cat"] = "add", 93 | interpolation_mode: str = "nearest", 94 | ): 95 | super().__init__() 96 | 97 | self.out_channels = ( 98 | segmentation_channels 99 | if merge_policy == "add" 100 | else segmentation_channels * 4 101 | ) 102 | if encoder_depth < 3: 103 | raise ValueError( 104 | "Encoder depth for FPN decoder cannot be less than 3, got {}.".format( 105 | encoder_depth 106 | ) 107 | ) 108 | 109 | encoder_channels = encoder_channels[::-1] 110 | encoder_channels = encoder_channels[: encoder_depth + 1] 111 | 112 | self.p5 = nn.Conv2d(encoder_channels[0], pyramid_channels, kernel_size=1) 113 | self.p4 = FPNBlock(pyramid_channels, encoder_channels[1], interpolation_mode) 114 | self.p3 = FPNBlock(pyramid_channels, encoder_channels[2], interpolation_mode) 115 | self.p2 = FPNBlock(pyramid_channels, encoder_channels[3], interpolation_mode) 116 | 117 | self.seg_blocks = nn.ModuleList( 118 | [ 119 | SegmentationBlock( 120 | pyramid_channels, segmentation_channels, n_upsamples=n_upsamples 121 | ) 122 | for n_upsamples in [3, 2, 1, 0] 123 | ] 124 | ) 125 | 126 | self.merge = MergeBlock(merge_policy) 127 | self.dropout = nn.Dropout2d(p=dropout, inplace=True) 128 | 129 | def forward(self, features: List[torch.Tensor]) -> torch.Tensor: 130 | c2, c3, c4, c5 = features[-4:] 131 | 132 | p5 = self.p5(c5) 133 | p4 = self.p4(p5, c4) 134 | p3 = self.p3(p4, c3) 135 | p2 = self.p2(p3, c2) 136 | 137 | s5 = self.seg_blocks[0](p5) 
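        # annotation (not in the original source): seg_blocks[0..3] were built with n_upsamples = 3/2/1/0,
        # so p5..p2 are each brought to the same 1/4-input resolution here before being merged below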
138 | s4 = self.seg_blocks[1](p4) 139 | s3 = self.seg_blocks[2](p3) 140 | s2 = self.seg_blocks[3](p2) 141 | 142 | feature_pyramid = [s5, s4, s3, s2] 143 | x = self.merge(feature_pyramid) 144 | x = self.dropout(x) 145 | 146 | return x 147 | -------------------------------------------------------------------------------- /segmentation_models_pytorch/decoders/linknet/__init__.py: -------------------------------------------------------------------------------- 1 | from .model import Linknet 2 | 3 | __all__ = ["Linknet"] 4 | -------------------------------------------------------------------------------- /segmentation_models_pytorch/decoders/linknet/decoder.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | from typing import Any, Dict, List, Optional, Union 5 | from segmentation_models_pytorch.base import modules 6 | 7 | 8 | class TransposeX2(nn.Sequential): 9 | def __init__( 10 | self, 11 | in_channels: int, 12 | out_channels: int, 13 | use_norm: Union[bool, str, Dict[str, Any]] = "batchnorm", 14 | ): 15 | super().__init__() 16 | conv = nn.ConvTranspose2d( 17 | in_channels, out_channels, kernel_size=4, stride=2, padding=1 18 | ) 19 | norm = modules.get_norm_layer(use_norm, out_channels) 20 | activation = nn.ReLU(inplace=True) 21 | super().__init__(conv, norm, activation) 22 | 23 | 24 | class DecoderBlock(nn.Module): 25 | def __init__( 26 | self, 27 | in_channels: int, 28 | out_channels: int, 29 | use_norm: Union[bool, str, Dict[str, Any]] = "batchnorm", 30 | ): 31 | super().__init__() 32 | 33 | self.block = nn.Sequential( 34 | modules.Conv2dReLU( 35 | in_channels, 36 | in_channels // 4, 37 | kernel_size=1, 38 | use_norm=use_norm, 39 | ), 40 | TransposeX2(in_channels // 4, in_channels // 4, use_norm=use_norm), 41 | modules.Conv2dReLU( 42 | in_channels // 4, 43 | out_channels, 44 | kernel_size=1, 45 | use_norm=use_norm, 46 | ), 47 | ) 48 | 49 | def forward( 50 | self, x: torch.Tensor, skip: Optional[torch.Tensor] = None 51 | ) -> torch.Tensor: 52 | x = self.block(x) 53 | if skip is not None: 54 | x = x + skip 55 | return x 56 | 57 | 58 | class LinknetDecoder(nn.Module): 59 | def __init__( 60 | self, 61 | encoder_channels: List[int], 62 | prefinal_channels: int = 32, 63 | n_blocks: int = 5, 64 | use_norm: Union[bool, str, Dict[str, Any]] = "batchnorm", 65 | ): 66 | super().__init__() 67 | 68 | # remove first skip 69 | encoder_channels = encoder_channels[1:] 70 | # reverse channels to start from head of encoder 71 | encoder_channels = encoder_channels[::-1] 72 | 73 | channels = list(encoder_channels) + [prefinal_channels] 74 | 75 | self.blocks = nn.ModuleList( 76 | [ 77 | DecoderBlock( 78 | channels[i], 79 | channels[i + 1], 80 | use_norm=use_norm, 81 | ) 82 | for i in range(n_blocks) 83 | ] 84 | ) 85 | 86 | def forward(self, features: List[torch.Tensor]) -> torch.Tensor: 87 | features = features[1:] # remove first skip 88 | features = features[::-1] # reverse channels to start from head of encoder 89 | 90 | x = features[0] 91 | skips = features[1:] 92 | 93 | for i, decoder_block in enumerate(self.blocks): 94 | skip = skips[i] if i < len(skips) else None 95 | x = decoder_block(x, skip) 96 | 97 | return x 98 | -------------------------------------------------------------------------------- /segmentation_models_pytorch/decoders/manet/__init__.py: -------------------------------------------------------------------------------- 1 | from .model import MAnet 2 | 3 | __all__ = ["MAnet"] 4 | 
-------------------------------------------------------------------------------- /segmentation_models_pytorch/decoders/pan/__init__.py: -------------------------------------------------------------------------------- 1 | from .model import PAN 2 | 3 | __all__ = ["PAN"] 4 | -------------------------------------------------------------------------------- /segmentation_models_pytorch/decoders/pspnet/__init__.py: -------------------------------------------------------------------------------- 1 | from .model import PSPNet 2 | 3 | __all__ = ["PSPNet"] 4 | -------------------------------------------------------------------------------- /segmentation_models_pytorch/decoders/pspnet/decoder.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Dict, List, Tuple, Union 2 | 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | 7 | from segmentation_models_pytorch.base import modules 8 | 9 | 10 | class PSPBlock(nn.Module): 11 | def __init__( 12 | self, 13 | in_channels: int, 14 | out_channels: int, 15 | pool_size: int, 16 | use_norm: Union[bool, str, Dict[str, Any]] = "batchnorm", 17 | ): 18 | super().__init__() 19 | 20 | if pool_size == 1: 21 | use_norm = "identity" # PyTorch does not support BatchNorm for 1x1 shape 22 | 23 | self.pool = nn.Sequential( 24 | nn.AdaptiveAvgPool2d(output_size=(pool_size, pool_size)), 25 | modules.Conv2dReLU( 26 | in_channels, out_channels, kernel_size=1, use_norm=use_norm 27 | ), 28 | ) 29 | 30 | def forward(self, x: torch.Tensor) -> torch.Tensor: 31 | height, width = x.shape[2:] 32 | x = self.pool(x) 33 | x = F.interpolate(x, size=(height, width), mode="bilinear", align_corners=True) 34 | return x 35 | 36 | 37 | class PSPModule(nn.Module): 38 | def __init__( 39 | self, 40 | in_channels: int, 41 | sizes: Tuple[int, ...] 
= (1, 2, 3, 6), 42 | use_norm: Union[bool, str, Dict[str, Any]] = "batchnorm", 43 | ): 44 | super().__init__() 45 | 46 | self.blocks = nn.ModuleList( 47 | [ 48 | PSPBlock( 49 | in_channels, 50 | in_channels // len(sizes), 51 | size, 52 | use_norm=use_norm, 53 | ) 54 | for size in sizes 55 | ] 56 | ) 57 | 58 | def forward(self, x): 59 | xs = [block(x) for block in self.blocks] + [x] 60 | x = torch.cat(xs, dim=1) 61 | return x 62 | 63 | 64 | class PSPDecoder(nn.Module): 65 | def __init__( 66 | self, 67 | encoder_channels: List[int], 68 | use_norm: Union[bool, str, Dict[str, Any]] = "batchnorm", 69 | out_channels: int = 512, 70 | dropout: float = 0.2, 71 | ): 72 | super().__init__() 73 | 74 | self.psp = PSPModule( 75 | in_channels=encoder_channels[-1], 76 | sizes=(1, 2, 3, 6), 77 | use_norm=use_norm, 78 | ) 79 | 80 | self.conv = modules.Conv2dReLU( 81 | in_channels=encoder_channels[-1] * 2, 82 | out_channels=out_channels, 83 | kernel_size=1, 84 | use_norm=use_norm, 85 | ) 86 | 87 | self.dropout = nn.Dropout2d(p=dropout) 88 | 89 | def forward(self, features: List[torch.Tensor]) -> torch.Tensor: 90 | x = features[-1] 91 | x = self.psp(x) 92 | x = self.conv(x) 93 | x = self.dropout(x) 94 | 95 | return x 96 | -------------------------------------------------------------------------------- /segmentation_models_pytorch/decoders/segformer/__init__.py: -------------------------------------------------------------------------------- 1 | from .model import Segformer 2 | 3 | __all__ = ["Segformer"] 4 | -------------------------------------------------------------------------------- /segmentation_models_pytorch/decoders/segformer/decoder.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | from typing import List 6 | from segmentation_models_pytorch.base import modules as md 7 | 8 | 9 | class MLP(nn.Module): 10 | def __init__(self, skip_channels: int, segmentation_channels: int): 11 | super().__init__() 12 | 13 | self.linear = nn.Linear(skip_channels, segmentation_channels) 14 | 15 | def forward(self, x: torch.Tensor): 16 | batch, _, height, width = x.shape 17 | x = x.flatten(2).transpose(1, 2) 18 | x = self.linear(x) 19 | x = x.transpose(1, 2).reshape(batch, -1, height, width) 20 | return x 21 | 22 | 23 | class SegformerDecoder(nn.Module): 24 | def __init__( 25 | self, 26 | encoder_channels: List[int], 27 | encoder_depth: int = 5, 28 | segmentation_channels: int = 256, 29 | ): 30 | super().__init__() 31 | 32 | if encoder_depth < 3: 33 | raise ValueError( 34 | "Encoder depth for Segformer decoder cannot be less than 3, got {}.".format( 35 | encoder_depth 36 | ) 37 | ) 38 | 39 | if encoder_channels[1] == 0: 40 | encoder_channels = [ 41 | channel for index, channel in enumerate(encoder_channels) if index != 1 42 | ] 43 | encoder_channels = encoder_channels[::-1] 44 | 45 | self.mlp_stage = nn.ModuleList( 46 | [MLP(channel, segmentation_channels) for channel in encoder_channels[:-1]] 47 | ) 48 | 49 | self.fuse_stage = md.Conv2dReLU( 50 | in_channels=(len(encoder_channels) - 1) * segmentation_channels, 51 | out_channels=segmentation_channels, 52 | kernel_size=1, 53 | use_norm="batchnorm", 54 | ) 55 | 56 | def forward(self, features: List[torch.Tensor]) -> torch.Tensor: 57 | # Resize all features to the size of the largest feature 58 | target_size = [dim // 4 for dim in features[0].shape[2:]] 59 | 60 | features = features[2:] if features[1].size(1) == 0 else features[1:] 61 | features = 
features[::-1] # reverse channels to start from head of encoder 62 | 63 | resized_features = [] 64 | for i, mlp_layer in enumerate(self.mlp_stage): 65 | feature = mlp_layer(features[i]) 66 | resized_feature = F.interpolate( 67 | feature, size=target_size, mode="bilinear", align_corners=False 68 | ) 69 | resized_features.append(resized_feature) 70 | 71 | output = self.fuse_stage(torch.cat(resized_features, dim=1)) 72 | 73 | return output 74 | -------------------------------------------------------------------------------- /segmentation_models_pytorch/decoders/segformer/model.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Optional, Union, Callable 2 | 3 | from segmentation_models_pytorch.base import ( 4 | ClassificationHead, 5 | SegmentationHead, 6 | SegmentationModel, 7 | ) 8 | from segmentation_models_pytorch.encoders import get_encoder 9 | from segmentation_models_pytorch.base.hub_mixin import supports_config_loading 10 | 11 | from .decoder import SegformerDecoder 12 | 13 | 14 | class Segformer(SegmentationModel): 15 | """Segformer is simple and efficient design for semantic segmentation with Transformers 16 | 17 | Args: 18 | encoder_name: Name of the classification model that will be used as an encoder (a.k.a backbone) 19 | to extract features of different spatial resolution 20 | encoder_depth: A number of stages used in encoder in range [3, 5]. Each stage generate features 21 | two times smaller in spatial dimensions than previous one (e.g. for depth 0 we will have features 22 | with shapes [(N, C, H, W),], for depth 1 - [(N, C, H, W), (N, C, H // 2, W // 2)] and so on). 23 | Default is 5 24 | encoder_weights: One of **None** (random initialization), **"imagenet"** (pre-training on ImageNet) and 25 | other pretrained weights (see table with available weights for each encoder_name) 26 | decoder_segmentation_channels: A number of convolution filters in segmentation blocks, default is 256 27 | in_channels: A number of input channels for the model, default is 3 (RGB images) 28 | classes: A number of classes for output mask (or you can think as a number of channels of output mask) 29 | activation: An activation function to apply after the final convolution layer. 30 | Available options are **"sigmoid"**, **"softmax"**, **"logsoftmax"**, **"tanh"**, **"identity"**, 31 | **callable** and **None**. Default is **None**. 32 | upsampling: A number to upsample the output of the model, default is 4 (same size as input) 33 | aux_params: Dictionary with parameters of the auxiliary output (classification head). Auxiliary output is build 34 | on top of encoder if **aux_params** is not **None** (default). Supported params: 35 | - classes (int): A number of classes 36 | - pooling (str): One of "max", "avg". Default is "avg" 37 | - dropout (float): Dropout factor in [0, 1) 38 | - activation (str): An activation function to apply "sigmoid"/"softmax" 39 | (could be **None** to return logits) 40 | kwargs: Arguments passed to the encoder class ``__init__()`` function. Applies only to ``timm`` models. Keys with ``None`` values are pruned before passing. 41 | 42 | Returns: 43 | ``torch.nn.Module``: **Segformer** 44 | 45 | .. 
_Segformer: 46 | https://arxiv.org/abs/2105.15203 47 | 48 | """ 49 | 50 | @supports_config_loading 51 | def __init__( 52 | self, 53 | encoder_name: str = "resnet34", 54 | encoder_depth: int = 5, 55 | encoder_weights: Optional[str] = "imagenet", 56 | decoder_segmentation_channels: int = 256, 57 | in_channels: int = 3, 58 | classes: int = 1, 59 | activation: Optional[Union[str, Callable]] = None, 60 | upsampling: int = 4, 61 | aux_params: Optional[dict] = None, 62 | **kwargs: dict[str, Any], 63 | ): 64 | super().__init__() 65 | 66 | self.encoder = get_encoder( 67 | encoder_name, 68 | in_channels=in_channels, 69 | depth=encoder_depth, 70 | weights=encoder_weights, 71 | **kwargs, 72 | ) 73 | 74 | self.decoder = SegformerDecoder( 75 | encoder_channels=self.encoder.out_channels, 76 | encoder_depth=encoder_depth, 77 | segmentation_channels=decoder_segmentation_channels, 78 | ) 79 | 80 | self.segmentation_head = SegmentationHead( 81 | in_channels=decoder_segmentation_channels, 82 | out_channels=classes, 83 | activation=activation, 84 | kernel_size=1, 85 | upsampling=upsampling, 86 | ) 87 | 88 | if aux_params is not None: 89 | self.classification_head = ClassificationHead( 90 | in_channels=self.encoder.out_channels[-1], **aux_params 91 | ) 92 | else: 93 | self.classification_head = None 94 | 95 | self.name = "segformer-{}".format(encoder_name) 96 | self.initialize() 97 | -------------------------------------------------------------------------------- /segmentation_models_pytorch/decoders/unet/__init__.py: -------------------------------------------------------------------------------- 1 | from .model import Unet 2 | 3 | __all__ = ["Unet"] 4 | -------------------------------------------------------------------------------- /segmentation_models_pytorch/decoders/unetplusplus/__init__.py: -------------------------------------------------------------------------------- 1 | from .model import UnetPlusPlus 2 | 3 | __all__ = ["UnetPlusPlus"] 4 | -------------------------------------------------------------------------------- /segmentation_models_pytorch/decoders/upernet/__init__.py: -------------------------------------------------------------------------------- 1 | from .model import UPerNet 2 | 3 | __all__ = ["UPerNet"] 4 | -------------------------------------------------------------------------------- /segmentation_models_pytorch/decoders/upernet/model.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Dict, Optional, Union, Callable 2 | 3 | from segmentation_models_pytorch.base import ( 4 | ClassificationHead, 5 | SegmentationHead, 6 | SegmentationModel, 7 | ) 8 | from segmentation_models_pytorch.encoders import get_encoder 9 | from segmentation_models_pytorch.base.hub_mixin import supports_config_loading 10 | 11 | from .decoder import UPerNetDecoder 12 | 13 | 14 | class UPerNet(SegmentationModel): 15 | """UPerNet is a unified perceptual parsing network for image segmentation. 16 | 17 | Args: 18 | encoder_name: Name of the classification model that will be used as an encoder (a.k.a backbone) 19 | to extract features of different spatial resolution 20 | encoder_depth: A number of stages used in encoder in range [3, 5]. Each stage generate features 21 | two times smaller in spatial dimensions than previous one (e.g. for depth 0 we will have features 22 | with shapes [(N, C, H, W),], for depth 1 - [(N, C, H, W), (N, C, H // 2, W // 2)] and so on). 
23 | Default is 5 24 | encoder_weights: One of **None** (random initialization), **"imagenet"** (pre-training on ImageNet) and 25 | other pretrained weights (see table with available weights for each encoder_name) 26 | decoder_pyramid_channels: A number of convolution filters in Feature Pyramid, default is 256 27 | decoder_segmentation_channels: A number of convolution filters in segmentation blocks, default is 64 28 | decoder_use_norm: Specifies normalization between Conv2D and activation. 29 | Accepts the following types: 30 | - **True**: Defaults to `"batchnorm"`. 31 | - **False**: No normalization (`nn.Identity`). 32 | - **str**: Specifies normalization type using default parameters. Available values: 33 | `"batchnorm"`, `"identity"`, `"layernorm"`, `"instancenorm"`, `"inplace"`. 34 | - **dict**: Fully customizable normalization settings. Structure: 35 | ```python 36 | {"type": , **kwargs} 37 | ``` 38 | where `norm_name` corresponds to normalization type (see above), and `kwargs` are passed directly to the normalization layer as defined in PyTorch documentation. 39 | 40 | **Example**: 41 | ```python 42 | use_norm={"type": "layernorm", "eps": 1e-2} 43 | ``` 44 | in_channels: A number of input channels for the model, default is 3 (RGB images) 45 | classes: A number of classes for output mask (or you can think as a number of channels of output mask) 46 | activation: An activation function to apply after the final convolution layer. 47 | Available options are **"sigmoid"**, **"softmax"**, **"logsoftmax"**, **"tanh"**, **"identity"**, 48 | **callable** and **None**. Default is **None**. 49 | aux_params: Dictionary with parameters of the auxiliary output (classification head). Auxiliary output is build 50 | on top of encoder if **aux_params** is not **None** (default). Supported params: 51 | - classes (int): A number of classes 52 | - pooling (str): One of "max", "avg". Default is "avg" 53 | - dropout (float): Dropout factor in [0, 1) 54 | - activation (str): An activation function to apply "sigmoid"/"softmax" 55 | (could be **None** to return logits) 56 | kwargs: Arguments passed to the encoder class ``__init__()`` function. Applies only to ``timm`` models. Keys with ``None`` values are pruned before passing. 57 | 58 | Returns: 59 | ``torch.nn.Module``: **UPerNet** 60 | 61 | .. 
_UPerNet: 62 | https://arxiv.org/abs/1807.10221 63 | 64 | """ 65 | 66 | @supports_config_loading 67 | def __init__( 68 | self, 69 | encoder_name: str = "resnet34", 70 | encoder_depth: int = 5, 71 | encoder_weights: Optional[str] = "imagenet", 72 | decoder_channels: int = 256, 73 | decoder_use_norm: Union[bool, str, Dict[str, Any]] = "batchnorm", 74 | in_channels: int = 3, 75 | classes: int = 1, 76 | activation: Optional[Union[str, Callable]] = None, 77 | upsampling: int = 4, 78 | aux_params: Optional[dict] = None, 79 | **kwargs: dict[str, Any], 80 | ): 81 | super().__init__() 82 | 83 | self.encoder = get_encoder( 84 | encoder_name, 85 | in_channels=in_channels, 86 | depth=encoder_depth, 87 | weights=encoder_weights, 88 | **kwargs, 89 | ) 90 | 91 | self.decoder = UPerNetDecoder( 92 | encoder_channels=self.encoder.out_channels, 93 | encoder_depth=encoder_depth, 94 | decoder_channels=decoder_channels, 95 | use_norm=decoder_use_norm, 96 | ) 97 | 98 | self.segmentation_head = SegmentationHead( 99 | in_channels=decoder_channels, 100 | out_channels=classes, 101 | activation=activation, 102 | kernel_size=1, 103 | upsampling=upsampling, 104 | ) 105 | 106 | if aux_params is not None: 107 | self.classification_head = ClassificationHead( 108 | in_channels=self.encoder.out_channels[-1], **aux_params 109 | ) 110 | else: 111 | self.classification_head = None 112 | 113 | self.name = "upernet-{}".format(encoder_name) 114 | self.initialize() 115 | -------------------------------------------------------------------------------- /segmentation_models_pytorch/encoders/_base.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from typing import Sequence, Dict 3 | 4 | from . import _utils as utils 5 | 6 | 7 | class EncoderMixin: 8 | """Add encoder functionality such as: 9 | - output channels specification of feature tensors (produced by encoder) 10 | - patching first convolution for arbitrary input channels 11 | """ 12 | 13 | _is_torch_scriptable = True 14 | _is_torch_exportable = True 15 | _is_torch_compilable = True 16 | 17 | def __init__(self): 18 | self._depth = 5 19 | self._in_channels = 3 20 | self._output_stride = 32 21 | 22 | @property 23 | def out_channels(self): 24 | """Return channels dimensions for each tensor of forward output of encoder""" 25 | return self._out_channels[: self._depth + 1] 26 | 27 | @property 28 | def output_stride(self): 29 | return min(self._output_stride, 2**self._depth) 30 | 31 | def set_in_channels(self, in_channels, pretrained=True): 32 | """Change first convolution channels""" 33 | if in_channels == 3: 34 | return 35 | 36 | self._in_channels = in_channels 37 | if self._out_channels[0] == 3: 38 | self._out_channels = [in_channels] + self._out_channels[1:] 39 | 40 | utils.patch_first_conv( 41 | model=self, new_in_channels=in_channels, pretrained=pretrained 42 | ) 43 | 44 | def get_stages(self) -> Dict[int, Sequence[torch.nn.Module]]: 45 | """Override it in your implementation, should return a dictionary with keys as 46 | the output stride and values as the list of modules 47 | """ 48 | raise NotImplementedError 49 | 50 | def make_dilated(self, output_stride): 51 | if output_stride not in [8, 16]: 52 | raise ValueError(f"Output stride should be 16 or 8, got {output_stride}.") 53 | 54 | stages = self.get_stages() 55 | for stage_stride, stage_modules in stages.items(): 56 | if stage_stride <= output_stride: 57 | continue 58 | 59 | dilation_rate = stage_stride // output_stride 60 | for module in stage_modules: 61 | 
utils.replace_strides_with_dilation(module, dilation_rate) 62 | -------------------------------------------------------------------------------- /segmentation_models_pytorch/encoders/_preprocessing.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | 4 | def preprocess_input( 5 | x, mean=None, std=None, input_space="RGB", input_range=None, **kwargs 6 | ): 7 | if input_space == "BGR": 8 | x = x[..., ::-1].copy() 9 | 10 | if input_range is not None: 11 | if x.max() > 1 and input_range[1] == 1: 12 | x = x / 255.0 13 | 14 | if mean is not None: 15 | mean = np.array(mean) 16 | x = x - mean 17 | 18 | if std is not None: 19 | std = np.array(std) 20 | x = x / std 21 | 22 | return x 23 | -------------------------------------------------------------------------------- /segmentation_models_pytorch/encoders/_utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | 5 | def patch_first_conv(model, new_in_channels, default_in_channels=3, pretrained=True): 6 | """Change first convolution layer input channels. 7 | In case: 8 | in_channels == 1 or in_channels == 2 -> reuse original weights 9 | in_channels > 3 -> make random kaiming normal initialization 10 | """ 11 | 12 | # get first conv 13 | for module in model.modules(): 14 | if isinstance(module, nn.Conv2d) and module.in_channels == default_in_channels: 15 | break 16 | 17 | weight = module.weight.detach() 18 | module.in_channels = new_in_channels 19 | 20 | if not pretrained: 21 | module.weight = nn.parameter.Parameter( 22 | torch.Tensor( 23 | module.out_channels, 24 | new_in_channels // module.groups, 25 | *module.kernel_size, 26 | ) 27 | ) 28 | module.reset_parameters() 29 | 30 | elif new_in_channels == 1: 31 | new_weight = weight.sum(1, keepdim=True) 32 | module.weight = nn.parameter.Parameter(new_weight) 33 | 34 | else: 35 | new_weight = torch.Tensor( 36 | module.out_channels, new_in_channels // module.groups, *module.kernel_size 37 | ) 38 | 39 | for i in range(new_in_channels): 40 | new_weight[:, i] = weight[:, i % default_in_channels] 41 | 42 | new_weight = new_weight * (default_in_channels / new_in_channels) 43 | module.weight = nn.parameter.Parameter(new_weight) 44 | 45 | 46 | def replace_strides_with_dilation(module, dilation_rate): 47 | """Patch Conv2d modules replacing strides with dilation""" 48 | for mod in module.modules(): 49 | if isinstance(mod, nn.Conv2d): 50 | mod.stride = (1, 1) 51 | mod.dilation = (dilation_rate, dilation_rate) 52 | kh, kw = mod.kernel_size 53 | mod.padding = ((kh // 2) * dilation_rate, (kh // 2) * dilation_rate) 54 | 55 | # Kostyl for EfficientNet 56 | if hasattr(mod, "static_padding"): 57 | mod.static_padding = nn.Identity() 58 | -------------------------------------------------------------------------------- /segmentation_models_pytorch/encoders/inceptionresnetv2.py: -------------------------------------------------------------------------------- 1 | """Each encoder should have following attributes and methods and be inherited from `_base.EncoderMixin` 2 | 3 | Attributes: 4 | 5 | _out_channels (list of int): specify number of channels for each encoder feature tensor 6 | _depth (int): specify number of stages in decoder (in other words number of downsampling operations) 7 | _in_channels (int): default number of input channels in first Conv2d layer for encoder (usually 3) 8 | 9 | Methods: 10 | 11 | forward(self, x: torch.Tensor) 12 | produce list of features of different 
spatial resolutions, each feature is a 4D torch.tensor of 13 | shape NCHW (features should be sorted in descending order according to spatial resolution, starting 14 | with resolution same as input `x` tensor). 15 | 16 | Input: `x` with shape (1, 3, 64, 64) 17 | Output: [f0, f1, f2, f3, f4, f5] - features with corresponding shapes 18 | [(1, 3, 64, 64), (1, 64, 32, 32), (1, 128, 16, 16), (1, 256, 8, 8), 19 | (1, 512, 4, 4), (1, 1024, 2, 2)] (C - dim may differ) 20 | 21 | also should support number of features according to specified depth, e.g. if depth = 5, 22 | number of feature tensors = 6 (one with same resolution as input and 5 downsampled), 23 | depth = 3 -> number of feature tensors = 4 (one with same resolution as input and 3 downsampled). 24 | """ 25 | 26 | import torch 27 | import torch.nn as nn 28 | from typing import List 29 | 30 | from ._base import EncoderMixin 31 | from ._inceptionresnetv2 import InceptionResNetV2 32 | 33 | 34 | class InceptionResNetV2Encoder(InceptionResNetV2, EncoderMixin): 35 | def __init__( 36 | self, 37 | out_channels: List[int], 38 | depth: int = 5, 39 | output_stride: int = 32, 40 | **kwargs, 41 | ): 42 | if depth > 5 or depth < 1: 43 | raise ValueError( 44 | f"{self.__class__.__name__} depth should be in range [1, 5], got {depth}" 45 | ) 46 | 47 | super().__init__(**kwargs) 48 | 49 | self._depth = depth 50 | self._in_channels = 3 51 | self._out_channels = out_channels 52 | self._output_stride = output_stride 53 | 54 | # correct paddings 55 | for m in self.modules(): 56 | if isinstance(m, nn.Conv2d): 57 | if m.kernel_size == (3, 3): 58 | m.padding = (1, 1) 59 | if isinstance(m, nn.MaxPool2d): 60 | m.padding = (1, 1) 61 | 62 | # for torchscript, block8 does not have relu defined 63 | self.block8.relu = nn.Identity() 64 | 65 | # remove linear layers 66 | del self.avgpool_1a 67 | del self.last_linear 68 | 69 | def make_dilated(self, *args, **kwargs): 70 | raise ValueError( 71 | "InceptionResnetV2 encoder does not support dilated mode " 72 | "due to pooling operation for downsampling!" 
73 | ) 74 | 75 | def forward(self, x: torch.Tensor) -> List[torch.Tensor]: 76 | features = [x] 77 | 78 | if self._depth >= 1: 79 | x = self.conv2d_1a(x) 80 | x = self.conv2d_2a(x) 81 | x = self.conv2d_2b(x) 82 | features.append(x) 83 | 84 | if self._depth >= 2: 85 | x = self.maxpool_3a(x) 86 | x = self.conv2d_3b(x) 87 | x = self.conv2d_4a(x) 88 | features.append(x) 89 | 90 | if self._depth >= 3: 91 | x = self.maxpool_5a(x) 92 | x = self.mixed_5b(x) 93 | x = self.repeat(x) 94 | features.append(x) 95 | 96 | if self._depth >= 4: 97 | x = self.mixed_6a(x) 98 | x = self.repeat_1(x) 99 | features.append(x) 100 | 101 | if self._depth >= 5: 102 | x = self.mixed_7a(x) 103 | x = self.repeat_2(x) 104 | x = self.block8(x) 105 | x = self.conv2d_7b(x) 106 | features.append(x) 107 | 108 | return features 109 | 110 | def load_state_dict(self, state_dict, **kwargs): 111 | state_dict.pop("last_linear.bias", None) 112 | state_dict.pop("last_linear.weight", None) 113 | super().load_state_dict(state_dict, **kwargs) 114 | 115 | 116 | inceptionresnetv2_encoders = { 117 | "inceptionresnetv2": { 118 | "encoder": InceptionResNetV2Encoder, 119 | "pretrained_settings": { 120 | "imagenet": { 121 | "repo_id": "smp-hub/inceptionresnetv2.imagenet", 122 | "revision": "120c5afdbb80a1c989db0a7423ebb7a9db9b1e6c", 123 | }, 124 | "imagenet+background": { 125 | "repo_id": "smp-hub/inceptionresnetv2.imagenet-background", 126 | "revision": "3ecf3491658dc0f6a76d69c9d1cb36511b1ee56c", 127 | }, 128 | }, 129 | "params": {"out_channels": [3, 64, 192, 320, 1088, 1536], "num_classes": 1000}, 130 | } 131 | } 132 | -------------------------------------------------------------------------------- /segmentation_models_pytorch/encoders/inceptionv4.py: -------------------------------------------------------------------------------- 1 | """Each encoder should have following attributes and methods and be inherited from `_base.EncoderMixin` 2 | 3 | Attributes: 4 | 5 | _out_channels (list of int): specify number of channels for each encoder feature tensor 6 | _depth (int): specify number of stages in decoder (in other words number of downsampling operations) 7 | _in_channels (int): default number of input channels in first Conv2d layer for encoder (usually 3) 8 | 9 | Methods: 10 | 11 | forward(self, x: torch.Tensor) 12 | produce list of features of different spatial resolutions, each feature is a 4D torch.tensor of 13 | shape NCHW (features should be sorted in descending order according to spatial resolution, starting 14 | with resolution same as input `x` tensor). 15 | 16 | Input: `x` with shape (1, 3, 64, 64) 17 | Output: [f0, f1, f2, f3, f4, f5] - features with corresponding shapes 18 | [(1, 3, 64, 64), (1, 64, 32, 32), (1, 128, 16, 16), (1, 256, 8, 8), 19 | (1, 512, 4, 4), (1, 1024, 2, 2)] (C - dim may differ) 20 | 21 | also should support number of features according to specified depth, e.g. if depth = 5, 22 | number of feature tensors = 6 (one with same resolution as input and 5 downsampled), 23 | depth = 3 -> number of feature tensors = 4 (one with same resolution as input and 3 downsampled). 
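As a concrete sketch of the contract described above (encoder name, depth and input size below are illustrative; ``get_encoder`` is the factory used by the models in this package):

```python
import torch
from segmentation_models_pytorch.encoders import get_encoder

encoder = get_encoder("inceptionv4", in_channels=3, depth=5, weights=None)
x = torch.rand(1, 3, 64, 64)
features = encoder(x)  # list of depth + 1 tensors, highest resolution first
print([f.shape for f in features])
```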
24 | """ 25 | 26 | import torch 27 | import torch.nn as nn 28 | 29 | from typing import List 30 | 31 | from ._base import EncoderMixin 32 | from ._inceptionv4 import InceptionV4 33 | 34 | 35 | class InceptionV4Encoder(InceptionV4, EncoderMixin): 36 | def __init__( 37 | self, 38 | out_channels: List[int], 39 | depth: int = 5, 40 | output_stride: int = 32, 41 | **kwargs, 42 | ): 43 | if depth > 5 or depth < 1: 44 | raise ValueError( 45 | f"{self.__class__.__name__} depth should be in range [1, 5], got {depth}" 46 | ) 47 | super().__init__(**kwargs) 48 | 49 | self._depth = depth 50 | self._in_channels = 3 51 | self._out_channels = out_channels 52 | self._output_stride = output_stride 53 | self._out_indexes = [2, 4, 8, 14, len(self.features) - 1] 54 | 55 | # correct paddings 56 | for m in self.modules(): 57 | if isinstance(m, nn.Conv2d): 58 | if m.kernel_size == (3, 3): 59 | m.padding = (1, 1) 60 | if isinstance(m, nn.MaxPool2d): 61 | m.padding = (1, 1) 62 | 63 | # remove linear layers 64 | del self.last_linear 65 | 66 | def make_dilated(self, *args, **kwargs): 67 | raise ValueError( 68 | "InceptionV4 encoder does not support dilated mode " 69 | "due to pooling operation for downsampling!" 70 | ) 71 | 72 | def forward(self, x: torch.Tensor) -> List[torch.Tensor]: 73 | depth = 0 74 | features = [x] 75 | 76 | for i, module in enumerate(self.features): 77 | x = module(x) 78 | 79 | if i in self._out_indexes: 80 | features.append(x) 81 | depth += 1 82 | 83 | # torchscript does not support break in cycle, so we just 84 | # go over all modules and then slice number of features 85 | if not torch.jit.is_scripting() and depth > self._depth: 86 | break 87 | 88 | features = features[: self._depth + 1] 89 | return features 90 | 91 | def load_state_dict(self, state_dict, **kwargs): 92 | state_dict.pop("last_linear.bias", None) 93 | state_dict.pop("last_linear.weight", None) 94 | super().load_state_dict(state_dict, **kwargs) 95 | 96 | 97 | inceptionv4_encoders = { 98 | "inceptionv4": { 99 | "encoder": InceptionV4Encoder, 100 | "pretrained_settings": { 101 | "imagenet": { 102 | "repo_id": "smp-hub/inceptionv4.imagenet", 103 | "revision": "918fb54f07811d82a4ecde3a51156041d0facba9", 104 | }, 105 | "imagenet+background": { 106 | "repo_id": "smp-hub/inceptionv4.imagenet-background", 107 | "revision": "8c2a48e20d2709ee64f8421c61be309f05bfa536", 108 | }, 109 | }, 110 | "params": { 111 | "out_channels": [3, 64, 192, 384, 1024, 1536], 112 | "num_classes": 1001, 113 | }, 114 | } 115 | } 116 | -------------------------------------------------------------------------------- /segmentation_models_pytorch/encoders/mobilenet.py: -------------------------------------------------------------------------------- 1 | """Each encoder should have following attributes and methods and be inherited from `_base.EncoderMixin` 2 | 3 | Attributes: 4 | 5 | _out_channels (list of int): specify number of channels for each encoder feature tensor 6 | _depth (int): specify number of stages in decoder (in other words number of downsampling operations) 7 | _in_channels (int): default number of input channels in first Conv2d layer for encoder (usually 3) 8 | 9 | Methods: 10 | 11 | forward(self, x: torch.Tensor) 12 | produce list of features of different spatial resolutions, each feature is a 4D torch.tensor of 13 | shape NCHW (features should be sorted in descending order according to spatial resolution, starting 14 | with resolution same as input `x` tensor). 
15 | 16 | Input: `x` with shape (1, 3, 64, 64) 17 | Output: [f0, f1, f2, f3, f4, f5] - features with corresponding shapes 18 | [(1, 3, 64, 64), (1, 64, 32, 32), (1, 128, 16, 16), (1, 256, 8, 8), 19 | (1, 512, 4, 4), (1, 1024, 2, 2)] (C - dim may differ) 20 | 21 | also should support number of features according to specified depth, e.g. if depth = 5, 22 | number of feature tensors = 6 (one with same resolution as input and 5 downsampled), 23 | depth = 3 -> number of feature tensors = 4 (one with same resolution as input and 3 downsampled). 24 | """ 25 | 26 | import torch 27 | import torchvision 28 | from typing import Dict, Sequence, List 29 | 30 | from ._base import EncoderMixin 31 | 32 | 33 | class MobileNetV2Encoder(torchvision.models.MobileNetV2, EncoderMixin): 34 | def __init__( 35 | self, out_channels: List[int], depth: int = 5, output_stride: int = 32, **kwargs 36 | ): 37 | if depth > 5 or depth < 1: 38 | raise ValueError( 39 | f"{self.__class__.__name__} depth should be in range [1, 5], got {depth}" 40 | ) 41 | super().__init__(**kwargs) 42 | 43 | self._depth = depth 44 | self._in_channels = 3 45 | self._out_channels = out_channels 46 | self._output_stride = output_stride 47 | self._out_indexes = [1, 3, 6, 13, len(self.features) - 1] 48 | 49 | del self.classifier 50 | 51 | def get_stages(self) -> Dict[int, Sequence[torch.nn.Module]]: 52 | return { 53 | 16: [self.features[7:14]], 54 | 32: [self.features[14:]], 55 | } 56 | 57 | def forward(self, x: torch.Tensor) -> List[torch.Tensor]: 58 | features = [x] 59 | 60 | depth = 0 61 | for i, module in enumerate(self.features): 62 | x = module(x) 63 | 64 | if i in self._out_indexes: 65 | features.append(x) 66 | depth += 1 67 | 68 | # torchscript does not support break in cycle, so we just 69 | # go over all modules and then slice number of features 70 | if not torch.jit.is_scripting() and depth > self._depth: 71 | break 72 | 73 | features = features[: self._depth + 1] 74 | 75 | return features 76 | 77 | def load_state_dict(self, state_dict, **kwargs): 78 | state_dict.pop("classifier.1.bias", None) 79 | state_dict.pop("classifier.1.weight", None) 80 | super().load_state_dict(state_dict, **kwargs) 81 | 82 | 83 | mobilenet_encoders = { 84 | "mobilenet_v2": { 85 | "encoder": MobileNetV2Encoder, 86 | "pretrained_settings": { 87 | "imagenet": { 88 | "repo_id": "smp-hub/mobilenet_v2.imagenet", 89 | "revision": "e67aa804e17f7b404b629127eabbd224c4e0690b", 90 | } 91 | }, 92 | "params": {"out_channels": [3, 16, 24, 32, 96, 1280]}, 93 | } 94 | } 95 | -------------------------------------------------------------------------------- /segmentation_models_pytorch/encoders/timm_sknet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from typing import Dict, List, Sequence 3 | from timm.models.resnet import ResNet 4 | from timm.models.sknet import SelectiveKernelBottleneck, SelectiveKernelBasic 5 | 6 | from ._base import EncoderMixin 7 | 8 | 9 | class SkNetEncoder(ResNet, EncoderMixin): 10 | def __init__( 11 | self, 12 | out_channels: List[int], 13 | depth: int = 5, 14 | output_stride: int = 32, 15 | **kwargs, 16 | ): 17 | if depth > 5 or depth < 1: 18 | raise ValueError( 19 | f"{self.__class__.__name__} depth should be in range [1, 5], got {depth}" 20 | ) 21 | super().__init__(**kwargs) 22 | 23 | self._depth = depth 24 | self._in_channels = 3 25 | self._out_channels = out_channels 26 | self._output_stride = output_stride 27 | 28 | del self.fc 29 | del self.global_pool 30 | 31 | def get_stages(self) -> 
Dict[int, Sequence[torch.nn.Module]]: 32 | return { 33 | 16: [self.layer3], 34 | 32: [self.layer4], 35 | } 36 | 37 | def forward(self, x: torch.Tensor) -> List[torch.Tensor]: 38 | features = [x] 39 | 40 | if self._depth >= 1: 41 | x = self.conv1(x) 42 | x = self.bn1(x) 43 | x = self.act1(x) 44 | features.append(x) 45 | 46 | if self._depth >= 2: 47 | x = self.maxpool(x) 48 | x = self.layer1(x) 49 | features.append(x) 50 | 51 | if self._depth >= 3: 52 | x = self.layer2(x) 53 | features.append(x) 54 | 55 | if self._depth >= 4: 56 | x = self.layer3(x) 57 | features.append(x) 58 | 59 | if self._depth >= 5: 60 | x = self.layer4(x) 61 | features.append(x) 62 | 63 | return features 64 | 65 | def load_state_dict(self, state_dict, **kwargs): 66 | state_dict.pop("fc.bias", None) 67 | state_dict.pop("fc.weight", None) 68 | super().load_state_dict(state_dict, **kwargs) 69 | 70 | 71 | timm_sknet_encoders = { 72 | "timm-skresnet18": { 73 | "encoder": SkNetEncoder, 74 | "pretrained_settings": { 75 | "imagenet": { 76 | "repo_id": "smp-hub/timm-skresnet18.imagenet", 77 | "revision": "6c97652bb744d89177b68274d2fda3923a7d1f95", 78 | }, 79 | }, 80 | "params": { 81 | "out_channels": [3, 64, 64, 128, 256, 512], 82 | "block": SelectiveKernelBasic, 83 | "layers": [2, 2, 2, 2], 84 | "zero_init_last": False, 85 | "block_args": {"sk_kwargs": {"rd_ratio": 1 / 8, "split_input": True}}, 86 | }, 87 | }, 88 | "timm-skresnet34": { 89 | "encoder": SkNetEncoder, 90 | "pretrained_settings": { 91 | "imagenet": { 92 | "repo_id": "smp-hub/timm-skresnet34.imagenet", 93 | "revision": "2367796924a8182cc835ef6b5dc303917f923f99", 94 | }, 95 | }, 96 | "params": { 97 | "out_channels": [3, 64, 64, 128, 256, 512], 98 | "block": SelectiveKernelBasic, 99 | "layers": [3, 4, 6, 3], 100 | "zero_init_last": False, 101 | "block_args": {"sk_kwargs": {"rd_ratio": 1 / 8, "split_input": True}}, 102 | }, 103 | }, 104 | "timm-skresnext50_32x4d": { 105 | "encoder": SkNetEncoder, 106 | "pretrained_settings": { 107 | "imagenet": { 108 | "repo_id": "smp-hub/timm-skresnext50_32x4d.imagenet", 109 | "revision": "50207e407cc4c6ea9e6872963db6844ca7b7b9de", 110 | }, 111 | }, 112 | "params": { 113 | "out_channels": [3, 64, 256, 512, 1024, 2048], 114 | "block": SelectiveKernelBottleneck, 115 | "layers": [3, 4, 6, 3], 116 | "zero_init_last": False, 117 | "cardinality": 32, 118 | "base_width": 4, 119 | }, 120 | }, 121 | } 122 | -------------------------------------------------------------------------------- /segmentation_models_pytorch/encoders/xception.py: -------------------------------------------------------------------------------- 1 | from typing import List 2 | 3 | from ._base import EncoderMixin 4 | from ._xception import Xception 5 | 6 | 7 | class XceptionEncoder(Xception, EncoderMixin): 8 | def __init__( 9 | self, 10 | out_channels: List[int], 11 | *args, 12 | depth: int = 5, 13 | output_stride: int = 32, 14 | **kwargs, 15 | ): 16 | if depth > 5 or depth < 1: 17 | raise ValueError( 18 | f"{self.__class__.__name__} depth should be in range [1, 5], got {depth}" 19 | ) 20 | super().__init__(*args, **kwargs) 21 | 22 | self._depth = depth 23 | self._in_channels = 3 24 | self._out_channels = out_channels 25 | self._output_stride = output_stride 26 | 27 | # modify padding to maintain output shape 28 | self.conv1.padding = (1, 1) 29 | self.conv2.padding = (1, 1) 30 | 31 | del self.fc 32 | 33 | def make_dilated(self, *args, **kwargs): 34 | raise ValueError( 35 | "Xception encoder does not support dilated mode " 36 | "due to pooling operation for downsampling!" 
37 | ) 38 | 39 | def forward(self, x): 40 | features = [x] 41 | 42 | if self._depth >= 1: 43 | x = self.conv1(x) 44 | x = self.bn1(x) 45 | x = self.relu1(x) 46 | x = self.conv2(x) 47 | x = self.bn2(x) 48 | x = self.relu2(x) 49 | features.append(x) 50 | 51 | if self._depth >= 2: 52 | x = self.block1(x) 53 | features.append(x) 54 | 55 | if self._depth >= 3: 56 | x = self.block2(x) 57 | features.append(x) 58 | 59 | if self._depth >= 4: 60 | x = self.block3(x) 61 | x = self.block4(x) 62 | x = self.block5(x) 63 | x = self.block6(x) 64 | x = self.block7(x) 65 | x = self.block8(x) 66 | x = self.block9(x) 67 | x = self.block10(x) 68 | x = self.block11(x) 69 | features.append(x) 70 | 71 | if self._depth >= 5: 72 | x = self.block12(x) 73 | x = self.conv3(x) 74 | x = self.bn3(x) 75 | x = self.relu3(x) 76 | x = self.conv4(x) 77 | x = self.bn4(x) 78 | features.append(x) 79 | 80 | return features 81 | 82 | def load_state_dict(self, state_dict): 83 | # remove linear 84 | state_dict.pop("fc.bias", None) 85 | state_dict.pop("fc.weight", None) 86 | 87 | super().load_state_dict(state_dict) 88 | 89 | 90 | xception_encoders = { 91 | "xception": { 92 | "encoder": XceptionEncoder, 93 | "pretrained_settings": { 94 | "imagenet": { 95 | "repo_id": "smp-hub/xception.imagenet", 96 | "revision": "01cfaf27c11353b1f0c578e7e26d2c000ea91049", 97 | }, 98 | }, 99 | "params": {"out_channels": [3, 64, 128, 256, 728, 2048]}, 100 | } 101 | } 102 | -------------------------------------------------------------------------------- /segmentation_models_pytorch/losses/__init__.py: -------------------------------------------------------------------------------- 1 | from .constants import BINARY_MODE, MULTICLASS_MODE, MULTILABEL_MODE 2 | 3 | from .jaccard import JaccardLoss 4 | from .dice import DiceLoss 5 | from .focal import FocalLoss 6 | from .lovasz import LovaszLoss 7 | from .soft_bce import SoftBCEWithLogitsLoss 8 | from .soft_ce import SoftCrossEntropyLoss 9 | from .tversky import TverskyLoss 10 | from .mcc import MCCLoss 11 | 12 | __all__ = [ 13 | "BINARY_MODE", 14 | "MULTICLASS_MODE", 15 | "MULTILABEL_MODE", 16 | "JaccardLoss", 17 | "DiceLoss", 18 | "FocalLoss", 19 | "LovaszLoss", 20 | "SoftBCEWithLogitsLoss", 21 | "SoftCrossEntropyLoss", 22 | "TverskyLoss", 23 | "MCCLoss", 24 | ] 25 | -------------------------------------------------------------------------------- /segmentation_models_pytorch/losses/constants.py: -------------------------------------------------------------------------------- 1 | #: Loss binary mode suppose you are solving binary segmentation task. 2 | #: That mean yor have only one class which pixels are labled as **1**, 3 | #: the rest pixels are background and labeled as **0**. 4 | #: Target mask shape - (N, H, W), model output mask shape (N, 1, H, W). 5 | BINARY_MODE: str = "binary" 6 | 7 | #: Loss multiclass mode suppose you are solving multi-**class** segmentation task. 8 | #: That mean you have *C = 1..N* classes which have unique label values, 9 | #: classes are mutually exclusive and all pixels are labeled with theese values. 10 | #: Target mask shape - (N, H, W), model output mask shape (N, C, H, W). 11 | MULTICLASS_MODE: str = "multiclass" 12 | 13 | #: Loss multilabel mode suppose you are solving multi-**label** segmentation task. 14 | #: That mean you have *C = 1..N* classes which pixels are labeled as **1**, 15 | #: classes are not mutually exclusive and each class have its own *channel*, 16 | #: pixels in each channel which are not belong to class labeled as **0**. 
17 | #: Target mask shape - (N, C, H, W), model output mask shape (N, C, H, W). 18 | MULTILABEL_MODE: str = "multilabel" 19 | -------------------------------------------------------------------------------- /segmentation_models_pytorch/losses/dice.py: -------------------------------------------------------------------------------- 1 | from typing import Optional, List 2 | 3 | import torch 4 | import torch.nn.functional as F 5 | from torch.nn.modules.loss import _Loss 6 | from ._functional import soft_dice_score, to_tensor 7 | from .constants import BINARY_MODE, MULTICLASS_MODE, MULTILABEL_MODE 8 | 9 | __all__ = ["DiceLoss"] 10 | 11 | 12 | class DiceLoss(_Loss): 13 | def __init__( 14 | self, 15 | mode: str, 16 | classes: Optional[List[int]] = None, 17 | log_loss: bool = False, 18 | from_logits: bool = True, 19 | smooth: float = 0.0, 20 | ignore_index: Optional[int] = None, 21 | eps: float = 1e-7, 22 | ): 23 | """Dice loss for image segmentation task. 24 | It supports binary, multiclass and multilabel cases 25 | 26 | Args: 27 | mode: Loss mode 'binary', 'multiclass' or 'multilabel' 28 | classes: List of classes that contribute in loss computation. By default, all channels are included. 29 | log_loss: If True, loss computed as `- log(dice_coeff)`, otherwise `1 - dice_coeff` 30 | from_logits: If True, assumes input is raw logits 31 | smooth: Smoothness constant for dice coefficient (a) 32 | ignore_index: Label that indicates ignored pixels (does not contribute to loss) 33 | eps: A small epsilon for numerical stability to avoid zero division error 34 | (denominator will be always greater or equal to eps) 35 | 36 | Shape 37 | - **y_pred** - torch.Tensor of shape (N, C, H, W) 38 | - **y_true** - torch.Tensor of shape (N, H, W) or (N, C, H, W) 39 | 40 | Reference 41 | https://github.com/BloodAxe/pytorch-toolbelt 42 | """ 43 | assert mode in {BINARY_MODE, MULTILABEL_MODE, MULTICLASS_MODE} 44 | super(DiceLoss, self).__init__() 45 | self.mode = mode 46 | if classes is not None: 47 | assert mode != BINARY_MODE, ( 48 | "Masking classes is not supported with mode=binary" 49 | ) 50 | classes = to_tensor(classes, dtype=torch.long) 51 | 52 | self.classes = classes 53 | self.from_logits = from_logits 54 | self.smooth = smooth 55 | self.eps = eps 56 | self.log_loss = log_loss 57 | self.ignore_index = ignore_index 58 | 59 | def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor: 60 | assert y_true.size(0) == y_pred.size(0) 61 | 62 | if self.from_logits: 63 | # Apply activations to get [0..1] class probabilities 64 | # Using Log-Exp as this gives more numerically stable result and does not cause vanishing gradient on 65 | # extreme values 0 and 1 66 | if self.mode == MULTICLASS_MODE: 67 | y_pred = y_pred.log_softmax(dim=1).exp() 68 | else: 69 | y_pred = F.logsigmoid(y_pred).exp() 70 | 71 | bs = y_true.size(0) 72 | num_classes = y_pred.size(1) 73 | dims = (0, 2) 74 | 75 | if self.mode == BINARY_MODE: 76 | y_true = y_true.view(bs, 1, -1) 77 | y_pred = y_pred.view(bs, 1, -1) 78 | 79 | if self.ignore_index is not None: 80 | mask = y_true != self.ignore_index 81 | y_pred = y_pred * mask 82 | y_true = y_true * mask 83 | 84 | if self.mode == MULTICLASS_MODE: 85 | y_true = y_true.view(bs, -1) 86 | y_pred = y_pred.view(bs, num_classes, -1) 87 | 88 | if self.ignore_index is not None: 89 | mask = y_true != self.ignore_index 90 | y_pred = y_pred * mask.unsqueeze(1) 91 | 92 | y_true = F.one_hot( 93 | (y_true * mask).to(torch.long), num_classes 94 | ) # N,H*W -> N,H*W, C 95 | y_true = 
y_true.permute(0, 2, 1) * mask.unsqueeze(1) # N, C, H*W 96 | else: 97 | y_true = F.one_hot(y_true, num_classes) # N,H*W -> N,H*W, C 98 | y_true = y_true.permute(0, 2, 1) # N, C, H*W 99 | 100 | if self.mode == MULTILABEL_MODE: 101 | y_true = y_true.view(bs, num_classes, -1) 102 | y_pred = y_pred.view(bs, num_classes, -1) 103 | 104 | if self.ignore_index is not None: 105 | mask = y_true != self.ignore_index 106 | y_pred = y_pred * mask 107 | y_true = y_true * mask 108 | 109 | scores = self.compute_score( 110 | y_pred, y_true.type_as(y_pred), smooth=self.smooth, eps=self.eps, dims=dims 111 | ) 112 | 113 | if self.log_loss: 114 | loss = -torch.log(scores.clamp_min(self.eps)) 115 | else: 116 | loss = 1.0 - scores 117 | 118 | # Dice loss is undefined for non-empty classes 119 | # So we zero contribution of channel that does not have true pixels 120 | # NOTE: A better workaround would be to use loss term `mean(y_pred)` 121 | # for this case, however it will be a modified jaccard loss 122 | 123 | mask = y_true.sum(dims) > 0 124 | loss *= mask.to(loss.dtype) 125 | 126 | if self.classes is not None: 127 | loss = loss[self.classes] 128 | 129 | return self.aggregate_loss(loss) 130 | 131 | def aggregate_loss(self, loss): 132 | return loss.mean() 133 | 134 | def compute_score( 135 | self, output, target, smooth=0.0, eps=1e-7, dims=None 136 | ) -> torch.Tensor: 137 | return soft_dice_score(output, target, smooth, eps, dims) 138 | -------------------------------------------------------------------------------- /segmentation_models_pytorch/losses/focal.py: -------------------------------------------------------------------------------- 1 | from typing import Optional 2 | from functools import partial 3 | 4 | import torch 5 | from torch.nn.modules.loss import _Loss 6 | from ._functional import focal_loss_with_logits 7 | from .constants import BINARY_MODE, MULTICLASS_MODE, MULTILABEL_MODE 8 | 9 | __all__ = ["FocalLoss"] 10 | 11 | 12 | class FocalLoss(_Loss): 13 | def __init__( 14 | self, 15 | mode: str, 16 | alpha: Optional[float] = None, 17 | gamma: Optional[float] = 2.0, 18 | ignore_index: Optional[int] = None, 19 | reduction: Optional[str] = "mean", 20 | normalized: bool = False, 21 | reduced_threshold: Optional[float] = None, 22 | ): 23 | """Compute Focal loss 24 | 25 | Args: 26 | mode: Loss mode 'binary', 'multiclass' or 'multilabel' 27 | alpha: Prior probability of having positive value in target. 28 | gamma: Power factor for dampening weight (focal strength). 29 | ignore_index: If not None, targets may contain values to be ignored. 30 | Target values equal to ignore_index will be ignored from loss computation. 31 | normalized: Compute normalized focal loss (https://arxiv.org/pdf/1909.07829.pdf). 32 | reduced_threshold: Switch to reduced focal loss. Note, when using this mode you 33 | should use `reduction="sum"`. 
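        A minimal usage sketch (binary mode; the tensor shapes and values below are illustrative):

        ```python
        import torch
        from segmentation_models_pytorch.losses import FocalLoss

        loss_fn = FocalLoss(mode="binary")
        logits = torch.randn(4, 1, 64, 64)                    # raw model outputs
        target = torch.randint(0, 2, (4, 1, 64, 64)).float()  # binary ground truth
        loss = loss_fn(logits, target)
        ```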
34 | 35 | Shape 36 | - **y_pred** - torch.Tensor of shape (N, C, H, W) 37 | - **y_true** - torch.Tensor of shape (N, H, W) or (N, C, H, W) 38 | 39 | Reference 40 | https://github.com/BloodAxe/pytorch-toolbelt 41 | 42 | """ 43 | assert mode in {BINARY_MODE, MULTILABEL_MODE, MULTICLASS_MODE} 44 | super().__init__() 45 | 46 | self.mode = mode 47 | self.ignore_index = ignore_index 48 | self.focal_loss_fn = partial( 49 | focal_loss_with_logits, 50 | alpha=alpha, 51 | gamma=gamma, 52 | reduced_threshold=reduced_threshold, 53 | reduction=reduction, 54 | normalized=normalized, 55 | ) 56 | 57 | def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor: 58 | if self.mode in {BINARY_MODE, MULTILABEL_MODE}: 59 | y_true = y_true.view(-1) 60 | y_pred = y_pred.view(-1) 61 | 62 | if self.ignore_index is not None: 63 | # Filter predictions with ignore label from loss computation 64 | not_ignored = y_true != self.ignore_index 65 | y_pred = y_pred[not_ignored] 66 | y_true = y_true[not_ignored] 67 | 68 | loss = self.focal_loss_fn(y_pred, y_true) 69 | 70 | elif self.mode == MULTICLASS_MODE: 71 | num_classes = y_pred.size(1) 72 | loss = 0 73 | 74 | # Filter anchors with -1 label from loss computation 75 | if self.ignore_index is not None: 76 | not_ignored = y_true != self.ignore_index 77 | 78 | for cls in range(num_classes): 79 | cls_y_true = (y_true == cls).long() 80 | cls_y_pred = y_pred[:, cls, ...] 81 | 82 | if self.ignore_index is not None: 83 | cls_y_true = cls_y_true[not_ignored] 84 | cls_y_pred = cls_y_pred[not_ignored] 85 | 86 | loss += self.focal_loss_fn(cls_y_pred, cls_y_true) 87 | 88 | return loss 89 | -------------------------------------------------------------------------------- /segmentation_models_pytorch/losses/jaccard.py: -------------------------------------------------------------------------------- 1 | from typing import Optional, List 2 | 3 | import torch 4 | import torch.nn.functional as F 5 | from torch.nn.modules.loss import _Loss 6 | from ._functional import soft_jaccard_score, to_tensor 7 | from .constants import BINARY_MODE, MULTICLASS_MODE, MULTILABEL_MODE 8 | 9 | __all__ = ["JaccardLoss"] 10 | 11 | 12 | class JaccardLoss(_Loss): 13 | def __init__( 14 | self, 15 | mode: str, 16 | classes: Optional[List[int]] = None, 17 | log_loss: bool = False, 18 | from_logits: bool = True, 19 | smooth: float = 0.0, 20 | eps: float = 1e-7, 21 | ): 22 | """Jaccard loss for image segmentation task. 23 | It supports binary, multiclass and multilabel cases 24 | 25 | Args: 26 | mode: Loss mode 'binary', 'multiclass' or 'multilabel' 27 | classes: List of classes that contribute in loss computation. By default, all channels are included. 
28 | log_loss: If True, loss computed as `- log(jaccard_coeff)`, otherwise `1 - jaccard_coeff` 29 | from_logits: If True, assumes input is raw logits 30 | smooth: Smoothness constant for dice coefficient 31 | eps: A small epsilon for numerical stability to avoid zero division error 32 | (denominator will be always greater or equal to eps) 33 | 34 | Shape 35 | - **y_pred** - torch.Tensor of shape (N, C, H, W) 36 | - **y_true** - torch.Tensor of shape (N, H, W) or (N, C, H, W) 37 | 38 | Reference 39 | https://github.com/BloodAxe/pytorch-toolbelt 40 | """ 41 | assert mode in {BINARY_MODE, MULTILABEL_MODE, MULTICLASS_MODE} 42 | super(JaccardLoss, self).__init__() 43 | 44 | self.mode = mode 45 | if classes is not None: 46 | assert mode != BINARY_MODE, ( 47 | "Masking classes is not supported with mode=binary" 48 | ) 49 | classes = to_tensor(classes, dtype=torch.long) 50 | 51 | self.classes = classes 52 | self.from_logits = from_logits 53 | self.smooth = smooth 54 | self.eps = eps 55 | self.log_loss = log_loss 56 | 57 | def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor: 58 | assert y_true.size(0) == y_pred.size(0) 59 | 60 | if self.from_logits: 61 | # Apply activations to get [0..1] class probabilities 62 | # Using Log-Exp as this gives more numerically stable result and does not cause vanishing gradient on 63 | # extreme values 0 and 1 64 | if self.mode == MULTICLASS_MODE: 65 | y_pred = y_pred.log_softmax(dim=1).exp() 66 | else: 67 | y_pred = F.logsigmoid(y_pred).exp() 68 | 69 | bs = y_true.size(0) 70 | num_classes = y_pred.size(1) 71 | dims = (0, 2) 72 | 73 | if self.mode == BINARY_MODE: 74 | y_true = y_true.view(bs, 1, -1) 75 | y_pred = y_pred.view(bs, 1, -1) 76 | 77 | if self.mode == MULTICLASS_MODE: 78 | y_true = y_true.view(bs, -1) 79 | y_pred = y_pred.view(bs, num_classes, -1) 80 | 81 | y_true = F.one_hot(y_true, num_classes) # N,H*W -> N,H*W, C 82 | y_true = y_true.permute(0, 2, 1) # H, C, H*W 83 | 84 | if self.mode == MULTILABEL_MODE: 85 | y_true = y_true.view(bs, num_classes, -1) 86 | y_pred = y_pred.view(bs, num_classes, -1) 87 | 88 | scores = soft_jaccard_score( 89 | y_pred, 90 | y_true.type(y_pred.dtype), 91 | smooth=self.smooth, 92 | eps=self.eps, 93 | dims=dims, 94 | ) 95 | 96 | if self.log_loss: 97 | loss = -torch.log(scores.clamp_min(self.eps)) 98 | else: 99 | loss = 1.0 - scores 100 | 101 | # IoU loss is defined for non-empty classes 102 | # So we zero contribution of channel that does not have true pixels 103 | # NOTE: A better workaround would be to use loss term `mean(y_pred)` 104 | # for this case, however it will be a modified jaccard loss 105 | 106 | mask = y_true.sum(dims) > 0 107 | loss *= mask.float() 108 | 109 | if self.classes is not None: 110 | loss = loss[self.classes] 111 | 112 | return loss.mean() 113 | -------------------------------------------------------------------------------- /segmentation_models_pytorch/losses/mcc.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.nn.modules.loss import _Loss 3 | 4 | 5 | class MCCLoss(_Loss): 6 | def __init__(self, eps: float = 1e-5): 7 | """Compute Matthews Correlation Coefficient Loss for image segmentation task. 8 | It only supports binary mode. 
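        The score is computed from soft confusion-matrix counts (each stabilized by ``eps``),
        as implemented in ``forward`` below, and the returned loss is ``1 - MCC``:

            MCC = (TP * TN - FP * FN) / sqrt((TP + FP) * (TP + FN) * (TN + FP) * (TN + FN))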
9 | 10 | Args: 11 | eps (float): Small epsilon to handle situations where all the samples in the dataset belong to one class 12 | 13 | Reference: 14 | https://github.com/kakumarabhishek/MCC-Loss 15 | """ 16 | super().__init__() 17 | self.eps = eps 18 | 19 | def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor: 20 | """Compute MCC loss 21 | 22 | Args: 23 | y_pred (torch.Tensor): model prediction of shape (N, H, W) or (N, 1, H, W) 24 | y_true (torch.Tensor): ground truth labels of shape (N, H, W) or (N, 1, H, W) 25 | 26 | Returns: 27 | torch.Tensor: loss value (1 - mcc) 28 | """ 29 | 30 | bs = y_true.shape[0] 31 | 32 | y_true = y_true.view(bs, 1, -1) 33 | y_pred = y_pred.view(bs, 1, -1) 34 | 35 | tp = torch.sum(torch.mul(y_pred, y_true)) + self.eps 36 | tn = torch.sum(torch.mul((1 - y_pred), (1 - y_true))) + self.eps 37 | fp = torch.sum(torch.mul(y_pred, (1 - y_true))) + self.eps 38 | fn = torch.sum(torch.mul((1 - y_pred), y_true)) + self.eps 39 | 40 | numerator = torch.mul(tp, tn) - torch.mul(fp, fn) 41 | denominator = torch.sqrt( 42 | torch.add(tp, fp) 43 | * torch.add(tp, fn) 44 | * torch.add(tn, fp) 45 | * torch.add(tn, fn) 46 | ) 47 | 48 | mcc = torch.div(numerator.sum(), denominator.sum()) 49 | loss = 1.0 - mcc 50 | 51 | return loss 52 | -------------------------------------------------------------------------------- /segmentation_models_pytorch/losses/soft_bce.py: -------------------------------------------------------------------------------- 1 | from typing import Optional 2 | 3 | import torch 4 | import torch.nn.functional as F 5 | from torch import nn 6 | 7 | __all__ = ["SoftBCEWithLogitsLoss"] 8 | 9 | 10 | class SoftBCEWithLogitsLoss(nn.Module): 11 | __constants__ = [ 12 | "weight", 13 | "pos_weight", 14 | "reduction", 15 | "ignore_index", 16 | "smooth_factor", 17 | ] 18 | 19 | def __init__( 20 | self, 21 | weight: Optional[torch.Tensor] = None, 22 | ignore_index: Optional[int] = -100, 23 | reduction: str = "mean", 24 | smooth_factor: Optional[float] = None, 25 | pos_weight: Optional[torch.Tensor] = None, 26 | ): 27 | """Drop-in replacement for torch.nn.BCEWithLogitsLoss with few additions: ignore_index and label_smoothing 28 | 29 | Args: 30 | ignore_index: Specifies a target value that is ignored and does not contribute to the input gradient. 31 | smooth_factor: Factor to smooth target (e.g. 
if smooth_factor=0.1 then [1, 0, 1] -> [0.9, 0.1, 0.9]) 32 | 33 | Shape 34 | - **y_pred** - torch.Tensor of shape NxCxHxW 35 | - **y_true** - torch.Tensor of shape NxHxW or Nx1xHxW 36 | 37 | Reference 38 | https://github.com/BloodAxe/pytorch-toolbelt 39 | 40 | """ 41 | super().__init__() 42 | self.ignore_index = ignore_index 43 | self.reduction = reduction 44 | self.smooth_factor = smooth_factor 45 | self.register_buffer("weight", weight) 46 | self.register_buffer("pos_weight", pos_weight) 47 | 48 | def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor: 49 | """ 50 | Args: 51 | y_pred: torch.Tensor of shape (N, C, H, W) 52 | y_true: torch.Tensor of shape (N, H, W) or (N, 1, H, W) 53 | 54 | Returns: 55 | loss: torch.Tensor 56 | """ 57 | 58 | if self.smooth_factor is not None: 59 | soft_targets = (1 - y_true) * self.smooth_factor + y_true * ( 60 | 1 - self.smooth_factor 61 | ) 62 | else: 63 | soft_targets = y_true 64 | 65 | loss = F.binary_cross_entropy_with_logits( 66 | y_pred, 67 | soft_targets, 68 | self.weight, 69 | pos_weight=self.pos_weight, 70 | reduction="none", 71 | ) 72 | 73 | if self.ignore_index is not None: 74 | not_ignored_mask = y_true != self.ignore_index 75 | loss *= not_ignored_mask.type_as(loss) 76 | 77 | if self.reduction == "mean": 78 | loss = loss.mean() 79 | 80 | if self.reduction == "sum": 81 | loss = loss.sum() 82 | 83 | return loss 84 | -------------------------------------------------------------------------------- /segmentation_models_pytorch/losses/soft_ce.py: -------------------------------------------------------------------------------- 1 | from typing import Optional 2 | from torch import nn 3 | import torch 4 | import torch.nn.functional as F 5 | from ._functional import label_smoothed_nll_loss 6 | 7 | __all__ = ["SoftCrossEntropyLoss"] 8 | 9 | 10 | class SoftCrossEntropyLoss(nn.Module): 11 | __constants__ = ["reduction", "ignore_index", "smooth_factor"] 12 | 13 | def __init__( 14 | self, 15 | reduction: str = "mean", 16 | smooth_factor: Optional[float] = None, 17 | ignore_index: Optional[int] = -100, 18 | dim: int = 1, 19 | ): 20 | """Drop-in replacement for torch.nn.CrossEntropyLoss with label_smoothing 21 | 22 | Args: 23 | smooth_factor: Factor to smooth target (e.g. 
if smooth_factor=0.1 then [1, 0, 0] -> [0.9, 0.05, 0.05]) 24 | 25 | Shape 26 | - **y_pred** - torch.Tensor of shape (N, C, H, W) 27 | - **y_true** - torch.Tensor of shape (N, H, W) 28 | 29 | Reference 30 | https://github.com/BloodAxe/pytorch-toolbelt 31 | """ 32 | super().__init__() 33 | self.smooth_factor = smooth_factor 34 | self.ignore_index = ignore_index 35 | self.reduction = reduction 36 | self.dim = dim 37 | 38 | def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor: 39 | log_prob = F.log_softmax(y_pred, dim=self.dim) 40 | return label_smoothed_nll_loss( 41 | log_prob, 42 | y_true, 43 | epsilon=self.smooth_factor, 44 | ignore_index=self.ignore_index, 45 | reduction=self.reduction, 46 | dim=self.dim, 47 | ) 48 | -------------------------------------------------------------------------------- /segmentation_models_pytorch/losses/tversky.py: -------------------------------------------------------------------------------- 1 | from typing import List, Optional 2 | 3 | import torch 4 | from ._functional import soft_tversky_score 5 | from .constants import BINARY_MODE, MULTICLASS_MODE, MULTILABEL_MODE 6 | from .dice import DiceLoss 7 | 8 | __all__ = ["TverskyLoss"] 9 | 10 | 11 | class TverskyLoss(DiceLoss): 12 | """Tversky loss for image segmentation task. 13 | Where FP and FN is weighted by alpha and beta params. 14 | With alpha == beta == 0.5, this loss becomes equal DiceLoss. 15 | It supports binary, multiclass and multilabel cases 16 | 17 | Args: 18 | mode: Metric mode {'binary', 'multiclass', 'multilabel'} 19 | classes: Optional list of classes that contribute in loss computation; 20 | By default, all channels are included. 21 | log_loss: If True, loss computed as ``-log(tversky)`` otherwise ``1 - tversky`` 22 | from_logits: If True assumes input is raw logits 23 | smooth: 24 | ignore_index: Label that indicates ignored pixels (does not contribute to loss) 25 | eps: Small epsilon for numerical stability 26 | alpha: Weight constant that penalize model for FPs (False Positives) 27 | beta: Weight constant that penalize model for FNs (False Negatives) 28 | gamma: Constant that squares the error function. 
Defaults to ``1.0`` 29 | 30 | Return: 31 | loss: torch.Tensor 32 | 33 | """ 34 | 35 | def __init__( 36 | self, 37 | mode: str, 38 | classes: List[int] = None, 39 | log_loss: bool = False, 40 | from_logits: bool = True, 41 | smooth: float = 0.0, 42 | ignore_index: Optional[int] = None, 43 | eps: float = 1e-7, 44 | alpha: float = 0.5, 45 | beta: float = 0.5, 46 | gamma: float = 1.0, 47 | ): 48 | assert mode in {BINARY_MODE, MULTILABEL_MODE, MULTICLASS_MODE} 49 | super().__init__( 50 | mode, classes, log_loss, from_logits, smooth, ignore_index, eps 51 | ) 52 | self.alpha = alpha 53 | self.beta = beta 54 | self.gamma = gamma 55 | 56 | def aggregate_loss(self, loss): 57 | return loss.mean() ** self.gamma 58 | 59 | def compute_score( 60 | self, output, target, smooth=0.0, eps=1e-7, dims=None 61 | ) -> torch.Tensor: 62 | return soft_tversky_score( 63 | output, target, self.alpha, self.beta, smooth, eps, dims 64 | ) 65 | -------------------------------------------------------------------------------- /segmentation_models_pytorch/metrics/__init__.py: -------------------------------------------------------------------------------- 1 | # fmt: off 2 | from .functional import ( 3 | get_stats, 4 | fbeta_score, 5 | f1_score, 6 | iou_score, 7 | accuracy, 8 | precision, 9 | recall, 10 | sensitivity, 11 | specificity, 12 | balanced_accuracy, 13 | positive_predictive_value, 14 | negative_predictive_value, 15 | false_negative_rate, 16 | false_positive_rate, 17 | false_discovery_rate, 18 | false_omission_rate, 19 | positive_likelihood_ratio, 20 | negative_likelihood_ratio, 21 | ) 22 | 23 | __all__ = [ 24 | "get_stats", 25 | "fbeta_score", 26 | "f1_score", 27 | "iou_score", 28 | "accuracy", 29 | "precision", 30 | "recall", 31 | "sensitivity", 32 | "specificity", 33 | "balanced_accuracy", 34 | "positive_predictive_value", 35 | "negative_predictive_value", 36 | "false_negative_rate", 37 | "false_positive_rate", 38 | "false_discovery_rate", 39 | "false_omission_rate", 40 | "positive_likelihood_ratio", 41 | "negative_likelihood_ratio", 42 | ] 43 | -------------------------------------------------------------------------------- /segmentation_models_pytorch/utils/__init__.py: -------------------------------------------------------------------------------- 1 | import warnings 2 | 3 | from . import train 4 | from . import losses 5 | from . 
import metrics 6 | 7 | __all__ = ["train", "losses", "metrics"] 8 | 9 | warnings.warn( 10 | "`smp.utils` module is deprecated and will be removed in future releases.", 11 | DeprecationWarning, 12 | ) 13 | -------------------------------------------------------------------------------- /segmentation_models_pytorch/utils/base.py: -------------------------------------------------------------------------------- 1 | import re 2 | import torch.nn as nn 3 | 4 | 5 | class BaseObject(nn.Module): 6 | def __init__(self, name=None): 7 | super().__init__() 8 | self._name = name 9 | 10 | @property 11 | def __name__(self): 12 | if self._name is None: 13 | name = self.__class__.__name__ 14 | s1 = re.sub("(.)([A-Z][a-z]+)", r"\1_\2", name) 15 | return re.sub("([a-z0-9])([A-Z])", r"\1_\2", s1).lower() 16 | else: 17 | return self._name 18 | 19 | 20 | class Metric(BaseObject): 21 | pass 22 | 23 | 24 | class Loss(BaseObject): 25 | def __add__(self, other): 26 | if isinstance(other, Loss): 27 | return SumOfLosses(self, other) 28 | else: 29 | raise ValueError("Loss should be inherited from `Loss` class") 30 | 31 | def __radd__(self, other): 32 | return self.__add__(other) 33 | 34 | def __mul__(self, value): 35 | if isinstance(value, (int, float)): 36 | return MultipliedLoss(self, value) 37 | else: 38 | raise ValueError("Loss should be inherited from `BaseLoss` class") 39 | 40 | def __rmul__(self, other): 41 | return self.__mul__(other) 42 | 43 | 44 | class SumOfLosses(Loss): 45 | def __init__(self, l1, l2): 46 | name = "{} + {}".format(l1.__name__, l2.__name__) 47 | super().__init__(name=name) 48 | self.l1 = l1 49 | self.l2 = l2 50 | 51 | def __call__(self, *inputs): 52 | return self.l1.forward(*inputs) + self.l2.forward(*inputs) 53 | 54 | 55 | class MultipliedLoss(Loss): 56 | def __init__(self, loss, multiplier): 57 | # resolve name 58 | if len(loss.__name__.split("+")) > 1: 59 | name = "{} * ({})".format(multiplier, loss.__name__) 60 | else: 61 | name = "{} * {}".format(multiplier, loss.__name__) 62 | super().__init__(name=name) 63 | self.loss = loss 64 | self.multiplier = multiplier 65 | 66 | def __call__(self, *inputs): 67 | return self.multiplier * self.loss.forward(*inputs) 68 | -------------------------------------------------------------------------------- /segmentation_models_pytorch/utils/functional.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | def _take_channels(*xs, ignore_channels=None): 5 | if ignore_channels is None: 6 | return xs 7 | else: 8 | channels = [ 9 | channel 10 | for channel in range(xs[0].shape[1]) 11 | if channel not in ignore_channels 12 | ] 13 | xs = [ 14 | torch.index_select(x, dim=1, index=torch.tensor(channels).to(x.device)) 15 | for x in xs 16 | ] 17 | return xs 18 | 19 | 20 | def _threshold(x, threshold=None): 21 | if threshold is not None: 22 | return (x > threshold).type(x.dtype) 23 | else: 24 | return x 25 | 26 | 27 | def iou(pr, gt, eps=1e-7, threshold=None, ignore_channels=None): 28 | """Calculate Intersection over Union between ground truth and prediction 29 | Args: 30 | pr (torch.Tensor): predicted tensor 31 | gt (torch.Tensor): ground truth tensor 32 | eps (float): epsilon to avoid zero division 33 | threshold: threshold for outputs binarization 34 | Returns: 35 | float: IoU (Jaccard) score 36 | """ 37 | 38 | pr = _threshold(pr, threshold=threshold) 39 | pr, gt = _take_channels(pr, gt, ignore_channels=ignore_channels) 40 | 41 | intersection = torch.sum(gt * pr) 42 | union = torch.sum(gt) + torch.sum(pr) 
- intersection + eps 43 | return (intersection + eps) / union 44 | 45 | 46 | jaccard = iou 47 | 48 | 49 | def f_score(pr, gt, beta=1, eps=1e-7, threshold=None, ignore_channels=None): 50 | """Calculate F-score between ground truth and prediction 51 | Args: 52 | pr (torch.Tensor): predicted tensor 53 | gt (torch.Tensor): ground truth tensor 54 | beta (float): positive constant 55 | eps (float): epsilon to avoid zero division 56 | threshold: threshold for outputs binarization 57 | Returns: 58 | float: F score 59 | """ 60 | 61 | pr = _threshold(pr, threshold=threshold) 62 | pr, gt = _take_channels(pr, gt, ignore_channels=ignore_channels) 63 | 64 | tp = torch.sum(gt * pr) 65 | fp = torch.sum(pr) - tp 66 | fn = torch.sum(gt) - tp 67 | 68 | score = ((1 + beta**2) * tp + eps) / ((1 + beta**2) * tp + beta**2 * fn + fp + eps) 69 | 70 | return score 71 | 72 | 73 | def accuracy(pr, gt, threshold=0.5, ignore_channels=None): 74 | """Calculate accuracy score between ground truth and prediction 75 | Args: 76 | pr (torch.Tensor): predicted tensor 77 | gt (torch.Tensor): ground truth tensor 78 | eps (float): epsilon to avoid zero division 79 | threshold: threshold for outputs binarization 80 | Returns: 81 | float: precision score 82 | """ 83 | pr = _threshold(pr, threshold=threshold) 84 | pr, gt = _take_channels(pr, gt, ignore_channels=ignore_channels) 85 | 86 | tp = torch.sum(gt == pr, dtype=pr.dtype) 87 | score = tp / gt.view(-1).shape[0] 88 | return score 89 | 90 | 91 | def precision(pr, gt, eps=1e-7, threshold=None, ignore_channels=None): 92 | """Calculate precision score between ground truth and prediction 93 | Args: 94 | pr (torch.Tensor): predicted tensor 95 | gt (torch.Tensor): ground truth tensor 96 | eps (float): epsilon to avoid zero division 97 | threshold: threshold for outputs binarization 98 | Returns: 99 | float: precision score 100 | """ 101 | 102 | pr = _threshold(pr, threshold=threshold) 103 | pr, gt = _take_channels(pr, gt, ignore_channels=ignore_channels) 104 | 105 | tp = torch.sum(gt * pr) 106 | fp = torch.sum(pr) - tp 107 | 108 | score = (tp + eps) / (tp + fp + eps) 109 | 110 | return score 111 | 112 | 113 | def recall(pr, gt, eps=1e-7, threshold=None, ignore_channels=None): 114 | """Calculate Recall between ground truth and prediction 115 | Args: 116 | pr (torch.Tensor): A list of predicted elements 117 | gt (torch.Tensor): A list of elements that are to be predicted 118 | eps (float): epsilon to avoid zero division 119 | threshold: threshold for outputs binarization 120 | Returns: 121 | float: recall score 122 | """ 123 | 124 | pr = _threshold(pr, threshold=threshold) 125 | pr, gt = _take_channels(pr, gt, ignore_channels=ignore_channels) 126 | 127 | tp = torch.sum(gt * pr) 128 | fn = torch.sum(gt) - tp 129 | 130 | score = (tp + eps) / (tp + fn + eps) 131 | 132 | return score 133 | -------------------------------------------------------------------------------- /segmentation_models_pytorch/utils/losses.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | 3 | from . import base 4 | from . 
import functional as F 5 | from ..base.modules import Activation 6 | 7 | 8 | class JaccardLoss(base.Loss): 9 | def __init__(self, eps=1.0, activation=None, ignore_channels=None, **kwargs): 10 | super().__init__(**kwargs) 11 | self.eps = eps 12 | self.activation = Activation(activation) 13 | self.ignore_channels = ignore_channels 14 | 15 | def forward(self, y_pr, y_gt): 16 | y_pr = self.activation(y_pr) 17 | return 1 - F.jaccard( 18 | y_pr, 19 | y_gt, 20 | eps=self.eps, 21 | threshold=None, 22 | ignore_channels=self.ignore_channels, 23 | ) 24 | 25 | 26 | class DiceLoss(base.Loss): 27 | def __init__( 28 | self, eps=1.0, beta=1.0, activation=None, ignore_channels=None, **kwargs 29 | ): 30 | super().__init__(**kwargs) 31 | self.eps = eps 32 | self.beta = beta 33 | self.activation = Activation(activation) 34 | self.ignore_channels = ignore_channels 35 | 36 | def forward(self, y_pr, y_gt): 37 | y_pr = self.activation(y_pr) 38 | return 1 - F.f_score( 39 | y_pr, 40 | y_gt, 41 | beta=self.beta, 42 | eps=self.eps, 43 | threshold=None, 44 | ignore_channels=self.ignore_channels, 45 | ) 46 | 47 | 48 | class L1Loss(nn.L1Loss, base.Loss): 49 | pass 50 | 51 | 52 | class MSELoss(nn.MSELoss, base.Loss): 53 | pass 54 | 55 | 56 | class CrossEntropyLoss(nn.CrossEntropyLoss, base.Loss): 57 | pass 58 | 59 | 60 | class NLLLoss(nn.NLLLoss, base.Loss): 61 | pass 62 | 63 | 64 | class BCELoss(nn.BCELoss, base.Loss): 65 | pass 66 | 67 | 68 | class BCEWithLogitsLoss(nn.BCEWithLogitsLoss, base.Loss): 69 | pass 70 | -------------------------------------------------------------------------------- /segmentation_models_pytorch/utils/meter.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | 4 | class Meter(object): 5 | """Meters provide a way to keep track of important statistics in an online manner. 6 | This class is abstract, but provides a standard interface for all meters to follow. 7 | """ 8 | 9 | def reset(self): 10 | """Reset the meter to default settings.""" 11 | pass 12 | 13 | def add(self, value): 14 | """Log a new value to the meter 15 | Args: 16 | value: Next result to include. 17 | """ 18 | pass 19 | 20 | def value(self): 21 | """Get the value of the meter in the current state.""" 22 | pass 23 | 24 | 25 | class AverageValueMeter(Meter): 26 | def __init__(self): 27 | super(AverageValueMeter, self).__init__() 28 | self.reset() 29 | self.val = 0 30 | 31 | def add(self, value, n=1): 32 | self.val = value 33 | self.sum += value 34 | self.var += value * value 35 | self.n += n 36 | 37 | if self.n == 0: 38 | self.mean, self.std = np.nan, np.nan 39 | elif self.n == 1: 40 | self.mean = 0.0 + self.sum # This is to force a copy in torch/numpy 41 | self.std = np.inf 42 | self.mean_old = self.mean 43 | self.m_s = 0.0 44 | else: 45 | self.mean = self.mean_old + (value - n * self.mean_old) / float(self.n) 46 | self.m_s += (value - self.mean_old) * (value - self.mean) 47 | self.mean_old = self.mean 48 | self.std = np.sqrt(self.m_s / (self.n - 1.0)) 49 | 50 | def value(self): 51 | return self.mean, self.std 52 | 53 | def reset(self): 54 | self.n = 0 55 | self.sum = 0.0 56 | self.var = 0.0 57 | self.val = 0.0 58 | self.mean = np.nan 59 | self.mean_old = 0.0 60 | self.m_s = 0.0 61 | self.std = np.nan 62 | -------------------------------------------------------------------------------- /segmentation_models_pytorch/utils/metrics.py: -------------------------------------------------------------------------------- 1 | from . import base 2 | from . 
import functional as F 3 | from ..base.modules import Activation 4 | 5 | 6 | class IoU(base.Metric): 7 | __name__ = "iou_score" 8 | 9 | def __init__( 10 | self, eps=1e-7, threshold=0.5, activation=None, ignore_channels=None, **kwargs 11 | ): 12 | super().__init__(**kwargs) 13 | self.eps = eps 14 | self.threshold = threshold 15 | self.activation = Activation(activation) 16 | self.ignore_channels = ignore_channels 17 | 18 | def forward(self, y_pr, y_gt): 19 | y_pr = self.activation(y_pr) 20 | return F.iou( 21 | y_pr, 22 | y_gt, 23 | eps=self.eps, 24 | threshold=self.threshold, 25 | ignore_channels=self.ignore_channels, 26 | ) 27 | 28 | 29 | class Fscore(base.Metric): 30 | def __init__( 31 | self, 32 | beta=1, 33 | eps=1e-7, 34 | threshold=0.5, 35 | activation=None, 36 | ignore_channels=None, 37 | **kwargs, 38 | ): 39 | super().__init__(**kwargs) 40 | self.eps = eps 41 | self.beta = beta 42 | self.threshold = threshold 43 | self.activation = Activation(activation) 44 | self.ignore_channels = ignore_channels 45 | 46 | def forward(self, y_pr, y_gt): 47 | y_pr = self.activation(y_pr) 48 | return F.f_score( 49 | y_pr, 50 | y_gt, 51 | eps=self.eps, 52 | beta=self.beta, 53 | threshold=self.threshold, 54 | ignore_channels=self.ignore_channels, 55 | ) 56 | 57 | 58 | class Accuracy(base.Metric): 59 | def __init__(self, threshold=0.5, activation=None, ignore_channels=None, **kwargs): 60 | super().__init__(**kwargs) 61 | self.threshold = threshold 62 | self.activation = Activation(activation) 63 | self.ignore_channels = ignore_channels 64 | 65 | def forward(self, y_pr, y_gt): 66 | y_pr = self.activation(y_pr) 67 | return F.accuracy( 68 | y_pr, y_gt, threshold=self.threshold, ignore_channels=self.ignore_channels 69 | ) 70 | 71 | 72 | class Recall(base.Metric): 73 | def __init__( 74 | self, eps=1e-7, threshold=0.5, activation=None, ignore_channels=None, **kwargs 75 | ): 76 | super().__init__(**kwargs) 77 | self.eps = eps 78 | self.threshold = threshold 79 | self.activation = Activation(activation) 80 | self.ignore_channels = ignore_channels 81 | 82 | def forward(self, y_pr, y_gt): 83 | y_pr = self.activation(y_pr) 84 | return F.recall( 85 | y_pr, 86 | y_gt, 87 | eps=self.eps, 88 | threshold=self.threshold, 89 | ignore_channels=self.ignore_channels, 90 | ) 91 | 92 | 93 | class Precision(base.Metric): 94 | def __init__( 95 | self, eps=1e-7, threshold=0.5, activation=None, ignore_channels=None, **kwargs 96 | ): 97 | super().__init__(**kwargs) 98 | self.eps = eps 99 | self.threshold = threshold 100 | self.activation = Activation(activation) 101 | self.ignore_channels = ignore_channels 102 | 103 | def forward(self, y_pr, y_gt): 104 | y_pr = self.activation(y_pr) 105 | return F.precision( 106 | y_pr, 107 | y_gt, 108 | eps=self.eps, 109 | threshold=self.threshold, 110 | ignore_channels=self.ignore_channels, 111 | ) 112 | -------------------------------------------------------------------------------- /segmentation_models_pytorch/utils/train.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import torch 3 | from tqdm import tqdm as tqdm 4 | from .meter import AverageValueMeter 5 | 6 | 7 | class Epoch: 8 | def __init__(self, model, loss, metrics, stage_name, device="cpu", verbose=True): 9 | self.model = model 10 | self.loss = loss 11 | self.metrics = metrics 12 | self.stage_name = stage_name 13 | self.verbose = verbose 14 | self.device = device 15 | 16 | self._to_device() 17 | 18 | def _to_device(self): 19 | self.model.to(self.device) 20 | 
self.loss.to(self.device) 21 | for metric in self.metrics: 22 | metric.to(self.device) 23 | 24 | def _format_logs(self, logs): 25 | str_logs = ["{} - {:.4}".format(k, v) for k, v in logs.items()] 26 | s = ", ".join(str_logs) 27 | return s 28 | 29 | def batch_update(self, x, y): 30 | raise NotImplementedError 31 | 32 | def on_epoch_start(self): 33 | pass 34 | 35 | def run(self, dataloader): 36 | self.on_epoch_start() 37 | 38 | logs = {} 39 | loss_meter = AverageValueMeter() 40 | metrics_meters = { 41 | metric.__name__: AverageValueMeter() for metric in self.metrics 42 | } 43 | 44 | with tqdm( 45 | dataloader, 46 | desc=self.stage_name, 47 | file=sys.stdout, 48 | disable=not (self.verbose), 49 | ) as iterator: 50 | for x, y in iterator: 51 | x, y = x.to(self.device), y.to(self.device) 52 | loss, y_pred = self.batch_update(x, y) 53 | 54 | # update loss logs 55 | loss_value = loss.cpu().detach().numpy() 56 | loss_meter.add(loss_value) 57 | loss_logs = {self.loss.__name__: loss_meter.mean} 58 | logs.update(loss_logs) 59 | 60 | # update metrics logs 61 | for metric_fn in self.metrics: 62 | metric_value = metric_fn(y_pred, y).cpu().detach().numpy() 63 | metrics_meters[metric_fn.__name__].add(metric_value) 64 | metrics_logs = {k: v.mean for k, v in metrics_meters.items()} 65 | logs.update(metrics_logs) 66 | 67 | if self.verbose: 68 | s = self._format_logs(logs) 69 | iterator.set_postfix_str(s) 70 | 71 | return logs 72 | 73 | 74 | class TrainEpoch(Epoch): 75 | def __init__(self, model, loss, metrics, optimizer, device="cpu", verbose=True): 76 | super().__init__( 77 | model=model, 78 | loss=loss, 79 | metrics=metrics, 80 | stage_name="train", 81 | device=device, 82 | verbose=verbose, 83 | ) 84 | self.optimizer = optimizer 85 | 86 | def on_epoch_start(self): 87 | self.model.train() 88 | 89 | def batch_update(self, x, y): 90 | self.optimizer.zero_grad() 91 | prediction = self.model.forward(x) 92 | loss = self.loss(prediction, y) 93 | loss.backward() 94 | self.optimizer.step() 95 | return loss, prediction 96 | 97 | 98 | class ValidEpoch(Epoch): 99 | def __init__(self, model, loss, metrics, device="cpu", verbose=True): 100 | super().__init__( 101 | model=model, 102 | loss=loss, 103 | metrics=metrics, 104 | stage_name="valid", 105 | device=device, 106 | verbose=verbose, 107 | ) 108 | 109 | def on_epoch_start(self): 110 | self.model.eval() 111 | 112 | def batch_update(self, x, y): 113 | with torch.no_grad(): 114 | prediction = self.model.forward(x) 115 | loss = self.loss(prediction, y) 116 | return loss, prediction 117 | -------------------------------------------------------------------------------- /tests/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/qubvel-org/segmentation_models.pytorch/cf50cd082d35763073a296f6ee6378e24938bed8/tests/__init__.py -------------------------------------------------------------------------------- /tests/base/test_modules.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | from torch import nn 3 | from segmentation_models_pytorch.base.modules import Conv2dReLU 4 | 5 | 6 | def test_conv2drelu_batchnorm(): 7 | module = Conv2dReLU(3, 16, kernel_size=3, padding=1, use_norm="batchnorm") 8 | 9 | assert isinstance(module[0], nn.Conv2d) 10 | assert isinstance(module[1], nn.BatchNorm2d) 11 | assert isinstance(module[2], nn.ReLU) 12 | 13 | 14 | def test_conv2drelu_batchnorm_with_keywords(): 15 | module = Conv2dReLU( 16 | 3, 17 | 16, 18 | 
kernel_size=3, 19 | padding=1, 20 | use_norm={"type": "batchnorm", "momentum": 1e-4, "affine": False}, 21 | ) 22 | 23 | assert isinstance(module[0], nn.Conv2d) 24 | assert isinstance(module[1], nn.BatchNorm2d) 25 | assert module[1].momentum == 1e-4 and module[1].affine is False 26 | assert isinstance(module[2], nn.ReLU) 27 | 28 | 29 | def test_conv2drelu_identity(): 30 | module = Conv2dReLU(3, 16, kernel_size=3, padding=1, use_norm="identity") 31 | 32 | assert isinstance(module[0], nn.Conv2d) 33 | assert isinstance(module[1], nn.Identity) 34 | assert isinstance(module[2], nn.ReLU) 35 | 36 | 37 | def test_conv2drelu_layernorm(): 38 | module = Conv2dReLU(3, 16, kernel_size=3, padding=1, use_norm="layernorm") 39 | 40 | assert isinstance(module[0], nn.Conv2d) 41 | assert isinstance(module[1], nn.LayerNorm) 42 | assert isinstance(module[2], nn.ReLU) 43 | 44 | 45 | def test_conv2drelu_instancenorm(): 46 | module = Conv2dReLU(3, 16, kernel_size=3, padding=1, use_norm="instancenorm") 47 | 48 | assert isinstance(module[0], nn.Conv2d) 49 | assert isinstance(module[1], nn.InstanceNorm2d) 50 | assert isinstance(module[2], nn.ReLU) 51 | 52 | 53 | def test_conv2drelu_inplace(): 54 | try: 55 | from inplace_abn import InPlaceABN 56 | except ImportError: 57 | pytest.skip("InPlaceABN is not installed") 58 | 59 | module = Conv2dReLU(3, 16, kernel_size=3, padding=1, use_norm="inplace") 60 | 61 | assert len(module) == 3 62 | assert isinstance(module[0], nn.Conv2d) 63 | assert isinstance(module[1], InPlaceABN) 64 | assert isinstance(module[2], nn.Identity) 65 | -------------------------------------------------------------------------------- /tests/conftest.py: -------------------------------------------------------------------------------- 1 | def pytest_addoption(parser): 2 | parser.addoption( 3 | "--non-marked-only", action="store_true", help="Run only non-marked tests" 4 | ) 5 | 6 | 7 | def pytest_collection_modifyitems(config, items): 8 | if config.getoption("--non-marked-only"): 9 | non_marked_items = [] 10 | for item in items: 11 | # Check if the test has no marks 12 | if not item.own_markers: 13 | non_marked_items.append(item) 14 | 15 | # Update the test collection to only include non-marked tests 16 | items[:] = non_marked_items 17 | -------------------------------------------------------------------------------- /tests/encoders/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/qubvel-org/segmentation_models.pytorch/cf50cd082d35763073a296f6ee6378e24938bed8/tests/encoders/__init__.py -------------------------------------------------------------------------------- /tests/encoders/test_batchnorm_deprecation.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | import torch 4 | 5 | import segmentation_models_pytorch as smp 6 | from tests.utils import check_two_models_strictly_equal 7 | 8 | 9 | @pytest.mark.parametrize("model_name", ["unet", "unetplusplus", "linknet", "manet"]) 10 | @pytest.mark.parametrize("decoder_option", [True, False, "inplace"]) 11 | def test_seg_models_before_after_use_norm(model_name, decoder_option): 12 | torch.manual_seed(42) 13 | with pytest.warns(DeprecationWarning): 14 | model_decoder_batchnorm = smp.create_model( 15 | model_name, 16 | "mobilenet_v2", 17 | encoder_weights=None, 18 | decoder_use_batchnorm=decoder_option, 19 | ) 20 | model_decoder_norm = smp.create_model( 21 | model_name, 22 | "mobilenet_v2", 23 | encoder_weights=None, 24 | 
decoder_use_norm=decoder_option, 25 | ) 26 | 27 | model_decoder_norm.load_state_dict(model_decoder_batchnorm.state_dict()) 28 | 29 | check_two_models_strictly_equal( 30 | model_decoder_batchnorm, model_decoder_norm, torch.rand(1, 3, 224, 224) 31 | ) 32 | 33 | 34 | @pytest.mark.parametrize("decoder_option", [True, False, "inplace"]) 35 | def test_pspnet_before_after_use_norm(decoder_option): 36 | torch.manual_seed(42) 37 | with pytest.warns(DeprecationWarning): 38 | model_decoder_batchnorm = smp.create_model( 39 | "pspnet", 40 | "mobilenet_v2", 41 | encoder_weights=None, 42 | psp_use_batchnorm=decoder_option, 43 | ) 44 | model_decoder_norm = smp.create_model( 45 | "pspnet", 46 | "mobilenet_v2", 47 | encoder_weights=None, 48 | decoder_use_norm=decoder_option, 49 | ) 50 | model_decoder_norm.load_state_dict(model_decoder_batchnorm.state_dict()) 51 | 52 | check_two_models_strictly_equal( 53 | model_decoder_batchnorm, model_decoder_norm, torch.rand(1, 3, 224, 224) 54 | ) 55 | -------------------------------------------------------------------------------- /tests/encoders/test_common.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | import segmentation_models_pytorch as smp 3 | from tests.utils import slow_test 4 | 5 | 6 | @pytest.mark.parametrize( 7 | "encoder_name_and_weights", 8 | [ 9 | ("resnet18", "imagenet"), 10 | ], 11 | ) 12 | @slow_test 13 | def test_load_encoder_from_hub(encoder_name_and_weights): 14 | encoder_name, weights = encoder_name_and_weights 15 | smp.encoders.get_encoder(encoder_name, weights=weights) 16 | -------------------------------------------------------------------------------- /tests/encoders/test_pretrainedmodels_encoders.py: -------------------------------------------------------------------------------- 1 | import segmentation_models_pytorch as smp 2 | 3 | from tests.encoders import base 4 | from tests.utils import RUN_ALL_ENCODERS 5 | 6 | 7 | class TestDPNEncoder(base.BaseEncoderTester): 8 | encoder_names = ( 9 | ["dpn68"] 10 | if not RUN_ALL_ENCODERS 11 | else ["dpn68", "dpn68b", "dpn92", "dpn98", "dpn107", "dpn131"] 12 | ) 13 | files_for_diff = ["encoders/dpn.py"] 14 | 15 | def get_tiny_encoder(self): 16 | params = { 17 | "stage_idxs": [2, 3, 4, 6], 18 | "out_channels": [3, 2, 70, 134, 262, 518], 19 | "groups": 2, 20 | "inc_sec": (2, 2, 2, 2), 21 | "k_r": 2, 22 | "k_sec": (1, 1, 1, 1), 23 | "num_classes": 1000, 24 | "num_init_features": 2, 25 | "small": True, 26 | "test_time_pool": True, 27 | } 28 | return smp.encoders.dpn.DPNEncoder(**params) 29 | 30 | 31 | class TestInceptionResNetV2Encoder(base.BaseEncoderTester): 32 | encoder_names = ( 33 | ["inceptionresnetv2"] if not RUN_ALL_ENCODERS else ["inceptionresnetv2"] 34 | ) 35 | files_for_diff = ["encoders/inceptionresnetv2.py"] 36 | supports_dilated = False 37 | 38 | 39 | class TestInceptionV4Encoder(base.BaseEncoderTester): 40 | encoder_names = ["inceptionv4"] if not RUN_ALL_ENCODERS else ["inceptionv4"] 41 | files_for_diff = ["encoders/inceptionv4.py"] 42 | supports_dilated = False 43 | 44 | 45 | class TestSeNetEncoder(base.BaseEncoderTester): 46 | encoder_names = ( 47 | ["se_resnet50"] 48 | if not RUN_ALL_ENCODERS 49 | else [ 50 | "se_resnet50", 51 | "se_resnet101", 52 | "se_resnet152", 53 | "se_resnext50_32x4d", 54 | "se_resnext101_32x4d", 55 | # "senet154", # extra large model 56 | ] 57 | ) 58 | files_for_diff = ["encoders/senet.py"] 59 | 60 | def get_tiny_encoder(self): 61 | params = { 62 | "out_channels": [3, 2, 256, 512, 1024, 2048], 63 | "block": 
smp.encoders.senet.SEResNetBottleneck, 64 | "layers": [1, 1, 1, 1], 65 | "downsample_kernel_size": 1, 66 | "downsample_padding": 0, 67 | "dropout_p": None, 68 | "groups": 1, 69 | "inplanes": 2, 70 | "input_3x3": False, 71 | "num_classes": 1000, 72 | "reduction": 2, 73 | } 74 | return smp.encoders.senet.SENetEncoder(**params) 75 | 76 | 77 | class TestXceptionEncoder(base.BaseEncoderTester): 78 | supports_dilated = False 79 | encoder_names = ["xception"] if not RUN_ALL_ENCODERS else ["xception"] 80 | files_for_diff = ["encoders/xception.py"] 81 | -------------------------------------------------------------------------------- /tests/encoders/test_smp_encoders.py: -------------------------------------------------------------------------------- 1 | import segmentation_models_pytorch as smp 2 | from functools import partial 3 | 4 | from tests.encoders import base 5 | from tests.utils import RUN_ALL_ENCODERS 6 | 7 | 8 | class TestMobileoneEncoder(base.BaseEncoderTester): 9 | encoder_names = ( 10 | ["mobileone_s0"] 11 | if not RUN_ALL_ENCODERS 12 | else [ 13 | "mobileone_s0", 14 | "mobileone_s1", 15 | "mobileone_s2", 16 | "mobileone_s3", 17 | "mobileone_s4", 18 | ] 19 | ) 20 | files_for_diff = ["encoders/mobileone.py"] 21 | 22 | 23 | class TestMixTransformerEncoder(base.BaseEncoderTester): 24 | encoder_names = ( 25 | ["mit_b0"] 26 | if not RUN_ALL_ENCODERS 27 | else ["mit_b0", "mit_b1", "mit_b2", "mit_b3", "mit_b4", "mit_b5"] 28 | ) 29 | files_for_diff = ["encoders/mix_transformer.py"] 30 | 31 | def get_tiny_encoder(self): 32 | params = { 33 | "out_channels": [3, 0, 4, 4, 4, 4], 34 | "patch_size": 4, 35 | "embed_dims": [4, 4, 4, 4], 36 | "num_heads": [1, 1, 1, 1], 37 | "mlp_ratios": [1, 1, 1, 1], 38 | "qkv_bias": True, 39 | "norm_layer": partial(smp.encoders.mix_transformer.LayerNorm, eps=1e-6), 40 | "depths": [1, 1, 1, 1], 41 | "sr_ratios": [8, 4, 2, 1], 42 | "drop_rate": 0.0, 43 | "drop_path_rate": 0.1, 44 | } 45 | 46 | return smp.encoders.mix_transformer.MixVisionTransformerEncoder(**params) 47 | 48 | 49 | class TestEfficientNetEncoder(base.BaseEncoderTester): 50 | encoder_names = ( 51 | ["efficientnet-b0"] 52 | if not RUN_ALL_ENCODERS 53 | else [ 54 | "efficientnet-b0", 55 | "efficientnet-b1", 56 | "efficientnet-b2", 57 | "efficientnet-b3", 58 | "efficientnet-b4", 59 | "efficientnet-b5", 60 | "efficientnet-b6", 61 | # "efficientnet-b7", # extra large model 62 | ] 63 | ) 64 | files_for_diff = ["encoders/efficientnet.py"] 65 | -------------------------------------------------------------------------------- /tests/encoders/test_timm_ported_encoders.py: -------------------------------------------------------------------------------- 1 | from tests.encoders import base 2 | from tests.utils import RUN_ALL_ENCODERS 3 | 4 | 5 | class TestTimmEfficientNetEncoder(base.BaseEncoderTester): 6 | encoder_names = ( 7 | ["timm-efficientnet-b0"] 8 | if not RUN_ALL_ENCODERS 9 | else [ 10 | "timm-efficientnet-b0", 11 | "timm-efficientnet-b1", 12 | "timm-efficientnet-b2", 13 | "timm-efficientnet-b3", 14 | "timm-efficientnet-b4", 15 | "timm-efficientnet-b5", 16 | "timm-efficientnet-b6", 17 | "timm-efficientnet-b7", 18 | "timm-efficientnet-b8", 19 | "timm-efficientnet-l2", 20 | "timm-tf_efficientnet_lite0", 21 | "timm-tf_efficientnet_lite1", 22 | "timm-tf_efficientnet_lite2", 23 | "timm-tf_efficientnet_lite3", 24 | "timm-tf_efficientnet_lite4", 25 | ] 26 | ) 27 | files_for_diff = ["encoders/timm_efficientnet.py"] 28 | 29 | 30 | class TestTimmGERNetEncoder(base.BaseEncoderTester): 31 | encoder_names = ( 32 | 
["timm-gernet_s"] 33 | if not RUN_ALL_ENCODERS 34 | else ["timm-gernet_s", "timm-gernet_m", "timm-gernet_l"] 35 | ) 36 | 37 | def test_compile(self): 38 | self.skipTest("Test to be removed") 39 | 40 | 41 | class TestTimmMobileNetV3Encoder(base.BaseEncoderTester): 42 | encoder_names = ( 43 | ["timm-mobilenetv3_small_100"] 44 | if not RUN_ALL_ENCODERS 45 | else [ 46 | "timm-mobilenetv3_large_075", 47 | "timm-mobilenetv3_large_100", 48 | "timm-mobilenetv3_large_minimal_100", 49 | "timm-mobilenetv3_small_075", 50 | "timm-mobilenetv3_small_100", 51 | "timm-mobilenetv3_small_minimal_100", 52 | ] 53 | ) 54 | 55 | def test_compile(self): 56 | self.skipTest("Test to be removed") 57 | 58 | 59 | class TestTimmRegNetEncoder(base.BaseEncoderTester): 60 | encoder_names = ( 61 | ["timm-regnetx_002", "timm-regnety_002"] 62 | if not RUN_ALL_ENCODERS 63 | else [ 64 | "timm-regnetx_002", 65 | "timm-regnetx_004", 66 | "timm-regnetx_006", 67 | "timm-regnetx_008", 68 | "timm-regnetx_016", 69 | "timm-regnetx_032", 70 | "timm-regnetx_040", 71 | "timm-regnetx_064", 72 | "timm-regnetx_080", 73 | "timm-regnetx_120", 74 | "timm-regnetx_160", 75 | "timm-regnetx_320", 76 | "timm-regnety_002", 77 | "timm-regnety_004", 78 | "timm-regnety_006", 79 | "timm-regnety_008", 80 | "timm-regnety_016", 81 | "timm-regnety_032", 82 | "timm-regnety_040", 83 | "timm-regnety_064", 84 | "timm-regnety_080", 85 | "timm-regnety_120", 86 | "timm-regnety_160", 87 | "timm-regnety_320", 88 | ] 89 | ) 90 | 91 | def test_compile(self): 92 | self.skipTest("Test to be removed") 93 | 94 | 95 | class TestTimmRes2NetEncoder(base.BaseEncoderTester): 96 | encoder_names = ( 97 | ["timm-res2net50_26w_4s"] 98 | if not RUN_ALL_ENCODERS 99 | else [ 100 | "timm-res2net50_26w_4s", 101 | "timm-res2net101_26w_4s", 102 | "timm-res2net50_26w_6s", 103 | "timm-res2net50_26w_8s", 104 | "timm-res2net50_48w_2s", 105 | "timm-res2net50_14w_8s", 106 | "timm-res2next50", 107 | ] 108 | ) 109 | 110 | def test_compile(self): 111 | self.skipTest("Test to be removed") 112 | 113 | 114 | class TestTimmResnestEncoder(base.BaseEncoderTester): 115 | default_batch_size = 2 116 | encoder_names = ( 117 | ["timm-resnest14d"] 118 | if not RUN_ALL_ENCODERS 119 | else [ 120 | "timm-resnest14d", 121 | "timm-resnest26d", 122 | "timm-resnest50d", 123 | "timm-resnest101e", 124 | "timm-resnest200e", 125 | "timm-resnest269e", 126 | "timm-resnest50d_4s2x40d", 127 | "timm-resnest50d_1s4x24d", 128 | ] 129 | ) 130 | 131 | def test_compile(self): 132 | self.skipTest("Test to be removed") 133 | 134 | 135 | class TestTimmSkNetEncoder(base.BaseEncoderTester): 136 | default_batch_size = 2 137 | encoder_names = ( 138 | ["timm-skresnet18"] 139 | if not RUN_ALL_ENCODERS 140 | else [ 141 | "timm-skresnet18", 142 | "timm-skresnet34", 143 | "timm-skresnext50_32x4d", 144 | ] 145 | ) 146 | files_for_diff = ["encoders/timm_sknet.py"] 147 | -------------------------------------------------------------------------------- /tests/encoders/test_timm_universal.py: -------------------------------------------------------------------------------- 1 | from tests.encoders import base 2 | from tests.utils import has_timm_test_models 3 | 4 | # check if timm >= 1.0.12 5 | timm_encoders = [ 6 | "tu-resnet18", # for timm universal traditional-like encoder 7 | "tu-convnext_atto", # for timm universal transformer-like encoder 8 | "tu-darknet17", # for timm universal vgg-like encoder 9 | ] 10 | 11 | if has_timm_test_models: 12 | timm_encoders.insert(0, "tu-test_resnet.r160_in1k") 13 | 14 | 15 | class 
TestTimmUniversalEncoder(base.BaseEncoderTester): 16 | encoder_names = timm_encoders 17 | files_for_diff = ["encoders/timm_universal.py"] 18 | -------------------------------------------------------------------------------- /tests/encoders/test_torchvision_encoders.py: -------------------------------------------------------------------------------- 1 | import segmentation_models_pytorch as smp 2 | 3 | from tests.encoders import base 4 | from tests.utils import RUN_ALL_ENCODERS 5 | 6 | 7 | class TestResNetEncoder(base.BaseEncoderTester): 8 | encoder_names = ( 9 | ["resnet18"] 10 | if not RUN_ALL_ENCODERS 11 | else [ 12 | "resnet18", 13 | "resnet34", 14 | "resnet50", 15 | "resnet101", 16 | "resnet152", 17 | "resnext50_32x4d", 18 | "resnext101_32x4d", 19 | "resnext101_32x8d", 20 | "resnext101_32x16d", 21 | "resnext101_32x32d", 22 | "resnext101_32x48d", 23 | ] 24 | ) 25 | files_for_diff = ["encoders/resnet.py"] 26 | 27 | def get_tiny_encoder(self): 28 | params = { 29 | "out_channels": [3, 64, 64, 128, 256, 512], 30 | "block": smp.encoders.resnet.BasicBlock, 31 | "layers": [1, 1, 1, 1], 32 | } 33 | return smp.encoders.resnet.ResNetEncoder(**params) 34 | 35 | 36 | class TestDenseNetEncoder(base.BaseEncoderTester): 37 | supports_dilated = False 38 | encoder_names = ( 39 | ["densenet121"] 40 | if not RUN_ALL_ENCODERS 41 | else ["densenet121", "densenet169", "densenet161"] 42 | ) 43 | files_for_diff = ["encoders/densenet.py"] 44 | 45 | def get_tiny_encoder(self): 46 | params = { 47 | "out_channels": [3, 2, 3, 2, 2, 2], 48 | "num_init_features": 2, 49 | "growth_rate": 1, 50 | "block_config": (1, 1, 1, 1), 51 | } 52 | return smp.encoders.densenet.DenseNetEncoder(**params) 53 | 54 | 55 | class TestMobileNetEncoder(base.BaseEncoderTester): 56 | encoder_names = ["mobilenet_v2"] if not RUN_ALL_ENCODERS else ["mobilenet_v2"] 57 | files_for_diff = ["encoders/mobilenet.py"] 58 | 59 | 60 | class TestVggEncoder(base.BaseEncoderTester): 61 | supports_dilated = False 62 | encoder_names = ( 63 | ["vgg11"] 64 | if not RUN_ALL_ENCODERS 65 | else [ 66 | "vgg11", 67 | "vgg11_bn", 68 | "vgg13", 69 | "vgg13_bn", 70 | "vgg16", 71 | "vgg16_bn", 72 | "vgg19", 73 | "vgg19_bn", 74 | ] 75 | ) 76 | files_for_diff = ["encoders/vgg.py"] 77 | 78 | def get_tiny_encoder(self): 79 | params = { 80 | "out_channels": [4, 4, 4, 4, 4, 4], 81 | "config": [4, "M", 4, "M", 4, "M", 4, "M", 4, "M"], 82 | "batch_norm": False, 83 | } 84 | return smp.encoders.vgg.VGGEncoder(**params) 85 | -------------------------------------------------------------------------------- /tests/models/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/qubvel-org/segmentation_models.pytorch/cf50cd082d35763073a296f6ee6378e24938bed8/tests/models/__init__.py -------------------------------------------------------------------------------- /tests/models/test_deeplab.py: -------------------------------------------------------------------------------- 1 | from tests.models import base 2 | 3 | 4 | class TestDeeplabV3Model(base.BaseModelTester): 5 | test_model_type = "deeplabv3" 6 | files_for_diff = [r"decoders/deeplabv3/", r"base/"] 7 | 8 | default_batch_size = 2 9 | 10 | 11 | class TestDeeplabV3PlusModel(base.BaseModelTester): 12 | test_model_type = "deeplabv3plus" 13 | files_for_diff = [r"decoders/deeplabv3plus/", r"base/"] 14 | 15 | default_batch_size = 2 16 | -------------------------------------------------------------------------------- /tests/models/test_dpt.py: 
-------------------------------------------------------------------------------- 1 | import pytest 2 | import inspect 3 | import torch 4 | import segmentation_models_pytorch as smp 5 | 6 | from tests.models import base 7 | from tests.utils import ( 8 | slow_test, 9 | default_device, 10 | requires_torch_greater_or_equal, 11 | ) 12 | 13 | 14 | class TestDPTModel(base.BaseModelTester): 15 | test_encoder_name = "tu-vit_tiny_patch16_224" 16 | files_for_diff = [r"decoders/dpt/", r"base/"] 17 | 18 | default_height = 224 19 | default_width = 224 20 | 21 | # should be overridden 22 | test_model_type = "dpt" 23 | 24 | compile_dynamic = False 25 | 26 | @property 27 | def decoder_channels(self): 28 | signature = inspect.signature(self.model_class) 29 | return signature.parameters["decoder_intermediate_channels"].default 30 | 31 | @property 32 | def hub_checkpoint(self): 33 | return "smp-test-models/dpt-tu-test_vit" 34 | 35 | @slow_test 36 | @requires_torch_greater_or_equal("2.0.1") 37 | @pytest.mark.logits_match 38 | def test_load_pretrained(self): 39 | hub_checkpoint = "smp-hub/dpt-large-ade20k" 40 | 41 | model = smp.from_pretrained(hub_checkpoint) 42 | model = model.eval().to(default_device) 43 | 44 | input_tensor = torch.ones((1, 3, 384, 384)) 45 | input_tensor = input_tensor.to(default_device) 46 | 47 | expected_logits_slice = torch.tensor( 48 | [3.4166, 3.4422, 3.4677, 3.2784, 3.0880, 2.9497] 49 | ) 50 | with torch.inference_mode(): 51 | output = model(input_tensor) 52 | 53 | resulted_logits_slice = output[0, 0, 0, 0:6].cpu() 54 | 55 | self.assertEqual(expected_logits_slice.shape, resulted_logits_slice.shape) 56 | is_close = torch.allclose( 57 | expected_logits_slice, resulted_logits_slice, atol=5e-2 58 | ) 59 | max_diff = torch.max(torch.abs(expected_logits_slice - resulted_logits_slice)) 60 | self.assertTrue(is_close, f"Max diff: {max_diff}") 61 | -------------------------------------------------------------------------------- /tests/models/test_fpn.py: -------------------------------------------------------------------------------- 1 | import segmentation_models_pytorch as smp 2 | 3 | from tests.models import base 4 | 5 | 6 | class TestFpnModel(base.BaseModelTester): 7 | test_model_type = "fpn" 8 | files_for_diff = [r"decoders/fpn/", r"base/"] 9 | 10 | def test_interpolation(self): 11 | # test bilinear 12 | model_1 = smp.create_model( 13 | self.test_model_type, 14 | self.test_encoder_name, 15 | decoder_interpolation="bilinear", 16 | ) 17 | assert model_1.decoder.p2.interpolation_mode == "bilinear" 18 | assert model_1.decoder.p3.interpolation_mode == "bilinear" 19 | assert model_1.decoder.p4.interpolation_mode == "bilinear" 20 | 21 | # test bicubic 22 | model_2 = smp.create_model( 23 | self.test_model_type, 24 | self.test_encoder_name, 25 | decoder_interpolation="bicubic", 26 | ) 27 | assert model_2.decoder.p2.interpolation_mode == "bicubic" 28 | assert model_2.decoder.p3.interpolation_mode == "bicubic" 29 | assert model_2.decoder.p4.interpolation_mode == "bicubic" 30 | -------------------------------------------------------------------------------- /tests/models/test_linknet.py: -------------------------------------------------------------------------------- 1 | from tests.models import base 2 | 3 | 4 | class TestLinknetModel(base.BaseModelTester): 5 | test_model_type = "linknet" 6 | files_for_diff = [r"decoders/linknet/", r"base/"] 7 | -------------------------------------------------------------------------------- /tests/models/test_manet.py: 
-------------------------------------------------------------------------------- 1 | import segmentation_models_pytorch as smp 2 | 3 | from tests.models import base 4 | 5 | 6 | class TestManetModel(base.BaseModelTester): 7 | test_model_type = "manet" 8 | files_for_diff = [r"decoders/manet/", r"base/"] 9 | 10 | def test_interpolation(self): 11 | # test bilinear 12 | model_1 = smp.create_model( 13 | self.test_model_type, 14 | self.test_encoder_name, 15 | decoder_interpolation="bilinear", 16 | ) 17 | for block in model_1.decoder.blocks: 18 | assert block.interpolation_mode == "bilinear" 19 | 20 | # test bicubic 21 | model_2 = smp.create_model( 22 | self.test_model_type, 23 | self.test_encoder_name, 24 | decoder_interpolation="bicubic", 25 | ) 26 | for block in model_2.decoder.blocks: 27 | assert block.interpolation_mode == "bicubic" 28 | -------------------------------------------------------------------------------- /tests/models/test_pan.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | import segmentation_models_pytorch as smp 3 | 4 | from tests.models import base 5 | 6 | 7 | class TestPanModel(base.BaseModelTester): 8 | test_model_type = "pan" 9 | files_for_diff = [r"decoders/pan/", r"base/"] 10 | 11 | default_batch_size = 2 12 | default_height = 128 13 | default_width = 128 14 | 15 | def test_interpolation(self): 16 | # test bilinear 17 | model_1 = smp.create_model( 18 | self.test_model_type, 19 | self.test_encoder_name, 20 | decoder_interpolation="bilinear", 21 | ) 22 | assert model_1.decoder.gau1.interpolation_mode == "bilinear" 23 | assert model_1.decoder.gau1.align_corners is True 24 | assert model_1.decoder.gau2.interpolation_mode == "bilinear" 25 | assert model_1.decoder.gau2.align_corners is True 26 | assert model_1.decoder.gau3.interpolation_mode == "bilinear" 27 | assert model_1.decoder.gau3.align_corners is True 28 | 29 | # test bicubic 30 | model_2 = smp.create_model( 31 | self.test_model_type, 32 | self.test_encoder_name, 33 | decoder_interpolation="bicubic", 34 | ) 35 | assert model_2.decoder.gau1.interpolation_mode == "bicubic" 36 | assert model_2.decoder.gau1.align_corners is None 37 | assert model_2.decoder.gau2.interpolation_mode == "bicubic" 38 | assert model_2.decoder.gau2.align_corners is None 39 | assert model_2.decoder.gau3.interpolation_mode == "bicubic" 40 | assert model_2.decoder.gau3.align_corners is None 41 | 42 | with pytest.warns(DeprecationWarning): 43 | smp.create_model( 44 | self.test_model_type, 45 | self.test_encoder_name, 46 | upscale_mode="bicubic", 47 | ) 48 | assert model_2.decoder.gau1.interpolation_mode == "bicubic" 49 | -------------------------------------------------------------------------------- /tests/models/test_psp.py: -------------------------------------------------------------------------------- 1 | from tests.models import base 2 | 3 | 4 | class TestPspModel(base.BaseModelTester): 5 | test_model_type = "pspnet" 6 | files_for_diff = [r"decoders/pspnet/", r"base/"] 7 | 8 | default_batch_size = 2 9 | -------------------------------------------------------------------------------- /tests/models/test_segformer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import pytest 3 | import segmentation_models_pytorch as smp 4 | 5 | from tests.models import base 6 | from tests.utils import slow_test, default_device, requires_torch_greater_or_equal 7 | 8 | 9 | class TestSegformerModel(base.BaseModelTester): 10 | test_model_type = 
"segformer" 11 | files_for_diff = [r"decoders/segformer/", r"base/"] 12 | 13 | @slow_test 14 | @requires_torch_greater_or_equal("2.0.1") 15 | @pytest.mark.logits_match 16 | def test_load_pretrained(self): 17 | hub_checkpoint = "smp-hub/segformer-b0-512x512-ade-160k" 18 | 19 | model = smp.from_pretrained(hub_checkpoint) 20 | model = model.eval().to(default_device) 21 | 22 | sample = torch.ones([1, 3, 512, 512]).to(default_device) 23 | 24 | with torch.inference_mode(): 25 | output = model(sample) 26 | 27 | self.assertEqual(output.shape, (1, 150, 512, 512)) 28 | 29 | expected_logits_slice = torch.tensor( 30 | [-4.4172, -4.4723, -4.5273, -4.5824, -4.6375, -4.7157] 31 | ) 32 | resulted_logits_slice = output[0, 0, 256, :6].cpu() 33 | is_equal = torch.allclose( 34 | expected_logits_slice, resulted_logits_slice, atol=1e-2 35 | ) 36 | max_diff = torch.max(torch.abs(expected_logits_slice - resulted_logits_slice)) 37 | self.assertTrue( 38 | is_equal, 39 | f"Expected logits slice and resulted logits slice are not equal.\n" 40 | f"Max diff: {max_diff}\n" 41 | f"Expected: {expected_logits_slice}\n" 42 | f"Resulted: {resulted_logits_slice}\n", 43 | ) 44 | -------------------------------------------------------------------------------- /tests/models/test_unet.py: -------------------------------------------------------------------------------- 1 | import segmentation_models_pytorch as smp 2 | from tests.models import base 3 | 4 | 5 | class TestUnetModel(base.BaseModelTester): 6 | test_model_type = "unet" 7 | files_for_diff = [r"decoders/unet/", r"base/"] 8 | 9 | def test_interpolation(self): 10 | # test bilinear 11 | model_1 = smp.create_model( 12 | self.test_model_type, 13 | self.test_encoder_name, 14 | decoder_interpolation="bilinear", 15 | ) 16 | for block in model_1.decoder.blocks: 17 | assert block.interpolation_mode == "bilinear" 18 | 19 | # test bicubic 20 | model_2 = smp.create_model( 21 | self.test_model_type, 22 | self.test_encoder_name, 23 | decoder_interpolation="bicubic", 24 | ) 25 | for block in model_2.decoder.blocks: 26 | assert block.interpolation_mode == "bicubic" 27 | -------------------------------------------------------------------------------- /tests/models/test_unetplusplus.py: -------------------------------------------------------------------------------- 1 | import segmentation_models_pytorch as smp 2 | 3 | from tests.models import base 4 | 5 | 6 | class TestUnetPlusPlusModel(base.BaseModelTester): 7 | test_model_type = "unetplusplus" 8 | files_for_diff = [r"decoders/unetplusplus/", r"base/"] 9 | 10 | def test_interpolation(self): 11 | # test bilinear 12 | model_1 = smp.create_model( 13 | self.test_model_type, 14 | self.test_encoder_name, 15 | decoder_interpolation="bilinear", 16 | ) 17 | is_tested = False 18 | for module in model_1.decoder.modules(): 19 | if module.__class__.__name__ == "DecoderBlock": 20 | assert module.interpolation_mode == "bilinear" 21 | is_tested = True 22 | assert is_tested 23 | 24 | # test bicubic 25 | model_2 = smp.create_model( 26 | self.test_model_type, 27 | self.test_encoder_name, 28 | decoder_interpolation="bicubic", 29 | ) 30 | is_tested = False 31 | for module in model_2.decoder.modules(): 32 | if module.__class__.__name__ == "DecoderBlock": 33 | assert module.interpolation_mode == "bicubic" 34 | is_tested = True 35 | assert is_tested 36 | -------------------------------------------------------------------------------- /tests/models/test_upernet.py: -------------------------------------------------------------------------------- 1 | import pytest 2 
| 3 | from tests.models import base 4 | 5 | 6 | class TestUperNetModel(base.BaseModelTester): 7 | test_model_type = "upernet" 8 | files_for_diff = [r"decoders/upernet/", r"base/"] 9 | 10 | default_batch_size = 2 11 | 12 | @pytest.mark.torch_export 13 | def test_torch_export(self): 14 | super().test_torch_export(eps=1e-3) 15 | -------------------------------------------------------------------------------- /tests/test_base.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import tempfile 3 | import segmentation_models_pytorch as smp 4 | 5 | import pytest 6 | 7 | 8 | def test_from_pretrained_with_mismatched_keys(): 9 | original_model = smp.Unet(classes=1) 10 | 11 | with tempfile.TemporaryDirectory() as temp_dir: 12 | original_model.save_pretrained(temp_dir) 13 | 14 | # we should catch the warning here and check that the mismatched keys are reported 15 | with pytest.warns(UserWarning): 16 | restored_model = smp.from_pretrained(temp_dir, classes=2, strict=False) 17 | 18 | assert restored_model.segmentation_head[0].out_channels == 2 19 | 20 | # verify all the weights are the same except the mismatched ones 21 | original_state_dict = original_model.state_dict() 22 | restored_state_dict = restored_model.state_dict() 23 | 24 | expected_mismatched_keys = [ 25 | "segmentation_head.0.weight", 26 | "segmentation_head.0.bias", 27 | ] 28 | mismatched_keys = [] 29 | for key in original_state_dict: 30 | if key not in expected_mismatched_keys: 31 | assert torch.allclose(original_state_dict[key], restored_state_dict[key]) 32 | else: 33 | mismatched_keys.append(key) 34 | 35 | assert len(mismatched_keys) == 2 36 | assert sorted(mismatched_keys) == sorted(expected_mismatched_keys) 37 | -------------------------------------------------------------------------------- /tests/test_preprocessing.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | import segmentation_models_pytorch as smp # noqa 4 | 5 | 6 | def _test_preprocessing(inp, out, **params): 7 | preprocessed_output = smp.encoders.preprocess_input(inp, **params) 8 | assert np.allclose(preprocessed_output, out) 9 | 10 | 11 | def test_mean(): 12 | inp = np.ones((32, 32, 3)) 13 | out = np.zeros((32, 32, 3)) 14 | mean = (1, 1, 1) 15 | _test_preprocessing(inp, out, mean=mean) 16 | 17 | 18 | def test_std(): 19 | inp = np.ones((32, 32, 3)) * 255 20 | out = np.ones((32, 32, 3)) 21 | std = (255, 255, 255) 22 | _test_preprocessing(inp, out, std=std) 23 | 24 | 25 | def test_input_range(): 26 | inp = np.ones((32, 32, 3)) 27 | out = np.ones((32, 32, 3)) 28 | _test_preprocessing(inp, out, input_range=(0, 1)) 29 | _test_preprocessing(inp * 255, out, input_range=(0, 1)) 30 | _test_preprocessing(inp * 255, out * 255, input_range=(0, 255)) 31 | 32 | 33 | def test_input_space(): 34 | inp = np.stack([np.ones((32, 32)), np.zeros((32, 32))], axis=-1) 35 | out = np.stack([np.zeros((32, 32)), np.ones((32, 32))], axis=-1) 36 | _test_preprocessing(inp, out, input_space="BGR") 37 | 38 | 39 | def test_preprocessing_params(): 40 | # check default encoder params 41 | params = smp.encoders.get_preprocessing_params("resnet18") 42 | assert params["mean"] == [0.485, 0.456, 0.406] 43 | assert params["std"] == [0.229, 0.224, 0.225] 44 | assert params["input_range"] == [0, 1] 45 | assert params["input_space"] == "RGB" 46 | 47 | # check timm params 48 | params = smp.encoders.get_preprocessing_params("tu-resnet18") 49 | assert params["mean"] == [0.485, 0.456, 0.406] 50 | assert params["std"] == 
[0.229, 0.224, 0.225] 51 | assert params["input_range"] == [0, 1] 52 | assert params["input_space"] == "RGB" 53 | -------------------------------------------------------------------------------- /tests/utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import re 3 | import timm 4 | import torch 5 | import unittest 6 | 7 | from git import Repo 8 | from typing import List 9 | from packaging.version import Version 10 | 11 | 12 | has_timm_test_models = Version(timm.__version__) >= Version("1.0.12") 13 | default_device = "cuda" if torch.cuda.is_available() else "cpu" 14 | 15 | YES_LIST = ["true", "1", "y", "yes"] 16 | RUN_ALL_ENCODERS = os.getenv("RUN_ALL_ENCODERS", "false").lower() in YES_LIST 17 | RUN_SLOW = os.getenv("RUN_SLOW", "false").lower() in YES_LIST 18 | RUN_ALL = os.getenv("RUN_ALL", "false").lower() in YES_LIST 19 | 20 | 21 | def slow_test(test_case): 22 | """ 23 | Decorator marking a test as slow. 24 | 25 | Slow tests are skipped by default. Set the RUN_SLOW environment variable to a truthy value to run them. 26 | 27 | """ 28 | return unittest.skipUnless(RUN_SLOW, "test is slow")(test_case) 29 | 30 | 31 | def requires_timm_greater_or_equal(version: str): 32 | timm_version = Version(timm.__version__) 33 | provided_version = Version(version) 34 | return unittest.skipUnless( 35 | timm_version >= provided_version, 36 | f"timm version {timm_version} is less than {provided_version}", 37 | ) 38 | 39 | 40 | def requires_torch_greater_or_equal(version: str): 41 | torch_version = Version(torch.__version__) 42 | provided_version = Version(version) 43 | return unittest.skipUnless( 44 | torch_version >= provided_version, 45 | f"torch version {torch_version} is less than {provided_version}", 46 | ) 47 | 48 | 49 | def check_run_test_on_diff_or_main(filepath_patterns: List[str]): 50 | if RUN_ALL: 51 | return True 52 | 53 | try: 54 | repo = Repo(".") 55 | current_branch = repo.active_branch.name 56 | diff_files = repo.git.diff("main", name_only=True).splitlines() 57 | 58 | except Exception: 59 | return True 60 | 61 | if current_branch == "main": 62 | return True 63 | 64 | for pattern in filepath_patterns: 65 | for file_path in diff_files: 66 | if re.search(pattern, file_path): 67 | return True 68 | 69 | return False 70 | 71 | 72 | def check_two_models_strictly_equal( 73 | model_a: torch.nn.Module, model_b: torch.nn.Module, input_data: torch.Tensor 74 | ) -> None: 75 | for (k1, v1), (k2, v2) in zip( 76 | model_a.state_dict().items(), model_b.state_dict().items() 77 | ): 78 | assert k1 == k2, f"Key mismatch: {k1} != {k2}" 79 | torch.testing.assert_close( 80 | v1, v2, msg=f"Tensor mismatch at key '{k1}':\n{v1} !=\n{v2}" 81 | ) 82 | 83 | model_a.eval() 84 | model_b.eval() 85 | with torch.inference_mode(): 86 | output_a = model_a(input_data) 87 | output_b = model_b(input_data) 88 | 89 | torch.testing.assert_close(output_a, output_b) 90 | --------------------------------------------------------------------------------
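
A minimal usage sketch of the utils package dumped above (losses, metrics, meter, train); this snippet is not a repository file. It assumes train_loader and valid_loader already exist as torch DataLoaders yielding (image, mask) batches, and the encoder name and learning rate are illustrative placeholders rather than values taken from the repository:

import torch
import segmentation_models_pytorch as smp
from segmentation_models_pytorch.utils import losses, metrics, train

# model + loss + metric objects; DiceLoss wraps 1 - F.f_score and IoU wraps F.iou from utils/functional.py
model = smp.Unet("resnet18", encoder_weights=None, classes=1)
loss = losses.DiceLoss(activation="sigmoid")
metric_list = [metrics.IoU(threshold=0.5)]
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

# TrainEpoch / ValidEpoch (utils/train.py) run one pass over a dataloader and average logs with AverageValueMeter
train_epoch = train.TrainEpoch(model, loss=loss, metrics=metric_list, optimizer=optimizer, device="cpu")
valid_epoch = train.ValidEpoch(model, loss=loss, metrics=metric_list, device="cpu")

for _ in range(2):
    train_logs = train_epoch.run(train_loader)  # train_loader: assumed DataLoader of (image, mask) batches
    valid_logs = valid_epoch.run(valid_loader)  # valid_loader: assumed DataLoader of (image, mask) batches
    print(train_logs, valid_logs)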