├── .dockerignore ├── .editorconfig ├── .gitattributes ├── .github ├── FUNDING.yml ├── ISSUE_TEMPLATE │ ├── bug_report.md │ └── feature_request.md └── workflows │ ├── close_inactive_issues.yml │ ├── lint_python.yml │ ├── publish_docker.yml │ ├── publish_pypi.yml │ ├── test_install.yml │ └── windows_installer.yml ├── .gitignore ├── .markdownlint.yaml ├── .python-version ├── CITATION.cff ├── Dockerfile ├── Dockerfile_nvidia_cuda_cudnn_gpu ├── LICENSE.txt ├── MANIFEST.in ├── README.md ├── USAGE.md ├── _build-exe.ps1 ├── _modpath.iss ├── _setup.iss ├── docker-compose.yml ├── examples ├── animal-1.jpg ├── animal-1.out.png ├── animal-2.jpg ├── animal-2.out.png ├── animal-3.jpg ├── animal-3.out.png ├── anime-girl-1.jpg ├── anime-girl-1.out.png ├── anime-girl-2.jpg ├── anime-girl-2.out.png ├── anime-girl-3.jpg ├── anime-girl-3.out.png ├── car-1.jpg ├── car-1.out.png ├── car-2.jpg ├── car-2.out.png ├── car-3.jpg ├── car-3.out.png ├── food-1.jpg ├── food-1.out.alpha.jpg ├── food-1.out.jpg ├── girl-1.jpg ├── girl-1.out.png ├── girl-2.jpg ├── girl-2.out.png ├── girl-3.jpg ├── girl-3.out.png ├── plants-1.jpg └── plants-1.out.png ├── onnxruntime-installation-matrix.png ├── pyproject.toml ├── pytest.ini ├── rembg.ipynb ├── rembg.py ├── rembg.spec ├── rembg ├── __init__.py ├── _version.py ├── bg.py ├── cli.py ├── commands │ ├── __init__.py │ ├── b_command.py │ ├── d_command.py │ ├── i_command.py │ ├── p_command.py │ └── s_command.py ├── session_factory.py └── sessions │ ├── __init__.py │ ├── base.py │ ├── birefnet_cod.py │ ├── birefnet_dis.py │ ├── birefnet_general.py │ ├── birefnet_general_lite.py │ ├── birefnet_hrsod.py │ ├── birefnet_massive.py │ ├── birefnet_portrait.py │ ├── bria_rmbg.py │ ├── dis_anime.py │ ├── dis_general_use.py │ ├── sam.py │ ├── silueta.py │ ├── u2net.py │ ├── u2net_cloth_seg.py │ ├── u2net_custom.py │ ├── u2net_human_seg.py │ └── u2netp.py ├── setup.cfg ├── setup.py ├── tests ├── fixtures │ ├── anime-girl-1.jpg │ ├── car-1.jpg │ ├── cloth-1.jpg │ └── 
plants-1.jpg ├── results │ ├── anime-girl-1.birefnet-cod.png │ ├── anime-girl-1.birefnet-dis.png │ ├── anime-girl-1.birefnet-general-lite.png │ ├── anime-girl-1.birefnet-general.png │ ├── anime-girl-1.birefnet-hrsod.png │ ├── anime-girl-1.birefnet-massive.png │ ├── anime-girl-1.birefnet-portrait.png │ ├── anime-girl-1.isnet-anime.png │ ├── anime-girl-1.isnet-general-use.png │ ├── anime-girl-1.sam.png │ ├── anime-girl-1.silueta.png │ ├── anime-girl-1.u2net.png │ ├── anime-girl-1.u2net_cloth_seg.png │ ├── anime-girl-1.u2net_human_seg.png │ ├── anime-girl-1.u2netp.png │ ├── car-1.birefnet-cod.png │ ├── car-1.birefnet-dis.png │ ├── car-1.birefnet-general-lite.png │ ├── car-1.birefnet-general.png │ ├── car-1.birefnet-hrsod.png │ ├── car-1.birefnet-massive.png │ ├── car-1.birefnet-portrait.png │ ├── car-1.isnet-anime.png │ ├── car-1.isnet-general-use.png │ ├── car-1.sam.png │ ├── car-1.silueta.png │ ├── car-1.u2net.png │ ├── car-1.u2net_cloth_seg.png │ ├── car-1.u2net_human_seg.png │ ├── car-1.u2netp.png │ ├── cloth-1.birefnet-cod.png │ ├── cloth-1.birefnet-dis.png │ ├── cloth-1.birefnet-general-lite.png │ ├── cloth-1.birefnet-general.png │ ├── cloth-1.birefnet-hrsod.png │ ├── cloth-1.birefnet-massive.png │ ├── cloth-1.birefnet-portrait.png │ ├── cloth-1.isnet-anime.png │ ├── cloth-1.isnet-general-use.png │ ├── cloth-1.sam.png │ ├── cloth-1.silueta.png │ ├── cloth-1.u2net.png │ ├── cloth-1.u2net_cloth_seg.png │ ├── cloth-1.u2net_human_seg.png │ ├── cloth-1.u2netp.png │ ├── plants-1.birefnet-cod.png │ ├── plants-1.birefnet-dis.png │ ├── plants-1.birefnet-general-lite.png │ ├── plants-1.birefnet-general.png │ ├── plants-1.birefnet-hrsod.png │ ├── plants-1.birefnet-massive.png │ ├── plants-1.birefnet-portrait.png │ ├── plants-1.isnet-anime.png │ ├── plants-1.isnet-general-use.png │ ├── plants-1.sam.png │ ├── plants-1.silueta.png │ ├── plants-1.u2net.png │ ├── plants-1.u2net_cloth_seg.png │ ├── plants-1.u2net_human_seg.png │ └── plants-1.u2netp.png └── test_remove.py └── 
versioneer.py /.dockerignore: -------------------------------------------------------------------------------- 1 | * 2 | !rembg 3 | !setup.py 4 | !setup.cfg 5 | !requirements.txt 6 | !requirements-cpu.txt 7 | !requirements-gpu.txt 8 | !versioneer.py 9 | !README.md 10 | .env 11 | -------------------------------------------------------------------------------- /.editorconfig: -------------------------------------------------------------------------------- 1 | # https://editorconfig.org/ 2 | 3 | root = true 4 | 5 | [*] 6 | indent_style = space 7 | indent_size = 4 8 | insert_final_newline = true 9 | trim_trailing_whitespace = true 10 | end_of_line = lf 11 | charset = utf-8 12 | -------------------------------------------------------------------------------- /.gitattributes: -------------------------------------------------------------------------------- 1 | rembg/_version.py export-subst 2 | -------------------------------------------------------------------------------- /.github/FUNDING.yml: -------------------------------------------------------------------------------- 1 | github: [danielgatis] 2 | custom: ["https://www.buymeacoffee.com/danielgatis"] 3 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/bug_report.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: Bug report 3 | about: Create a report to help us improve 4 | title: "[BUG] ..." 5 | labels: bug 6 | assignees: "" 7 | --- 8 | 9 | **Describe the bug** 10 | A clear and concise description of what the bug is. 11 | 12 | **To Reproduce** 13 | Steps to reproduce the behavior: 14 | 15 | 1. Go to '...' 16 | 2. Click on '....' 17 | 3. Scroll down to '....' 18 | 4. See error 19 | 20 | **Expected behavior** 21 | A clear and concise description of what you expected to happen. 22 | 23 | **Images** 24 | Input images to reproduce. 
25 | 26 | **OS Version:** 27 | iOS 22 28 | 29 | **Rembg version:** 30 | v2.0.21 31 | 32 | **Additional context** 33 | Add any other context about the problem here. 34 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/feature_request.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: Feature request 3 | about: Suggest an idea for this project 4 | title: "[FEATURE] ..." 5 | labels: enhancement 6 | assignees: "" 7 | --- 8 | 9 | **Is your feature request related to a problem? Please describe.** 10 | A clear and concise description of what the problem is. Ex. I'm always frustrated when [...] 11 | 12 | **Describe the solution you'd like** 13 | A clear and concise description of what you want to happen. 14 | 15 | **Describe alternatives you've considered** 16 | A clear and concise description of any alternative solutions or features you've considered. 17 | 18 | **Additional context** 19 | Add any other context or screenshots about the feature request here. 20 | -------------------------------------------------------------------------------- /.github/workflows/close_inactive_issues.yml: -------------------------------------------------------------------------------- 1 | name: Close inactive issues 2 | 3 | on: 4 | schedule: 5 | - cron: "30 1 * * *" 6 | 7 | jobs: 8 | close_inactive_issues: 9 | runs-on: ubuntu-latest 10 | permissions: 11 | issues: write 12 | pull-requests: write 13 | steps: 14 | - uses: actions/stale@v9 15 | with: 16 | days-before-issue-stale: 30 17 | days-before-issue-close: 14 18 | stale-issue-label: "stale" 19 | stale-issue-message: "This issue is stale because it has been open for 30 days with no activity." 20 | close-issue-message: "This issue was closed because it has been inactive for 14 days since being marked as stale." 
21 | days-before-pr-stale: -1 22 | days-before-pr-close: -1 23 | repo-token: ${{ secrets.GITHUB_TOKEN }} 24 | -------------------------------------------------------------------------------- /.github/workflows/lint_python.yml: -------------------------------------------------------------------------------- 1 | name: Lint 2 | 3 | on: [pull_request, push] 4 | 5 | jobs: 6 | lint_python: 7 | runs-on: ubuntu-latest 8 | steps: 9 | - uses: actions/checkout@v4 10 | - uses: actions/setup-python@v5 11 | - name: Install dependencies 12 | run: pip install .[cpu,cli,dev] 13 | - run: mypy --install-types --non-interactive --ignore-missing-imports ./rembg 14 | - run: bandit --recursive --skip B101,B104,B310,B311,B303,B110 --exclude ./rembg/_version.py ./rembg 15 | - run: black --force-exclude rembg/_version.py --check --diff ./rembg 16 | - run: flake8 ./rembg --count --ignore=B008,C901,E203,E266,E731,F401,F811,F841,W503,E501,E402 --show-source --statistics --exclude ./rembg/_version.py 17 | - run: isort --check-only --profile black ./rembg 18 | -------------------------------------------------------------------------------- /.github/workflows/publish_docker.yml: -------------------------------------------------------------------------------- 1 | name: Publish Docker image 2 | 3 | on: 4 | push: 5 | tags: 6 | - "v*.*.*" 7 | 8 | jobs: 9 | publish_docker: 10 | name: Push Docker image to Docker Hub 11 | runs-on: ubuntu-24.04 12 | steps: 13 | - name: Checkout 14 | uses: actions/checkout@v4 15 | 16 | - name: Docker meta 17 | id: meta 18 | uses: docker/metadata-action@v5 19 | with: 20 | # list of Docker images to use as base name for tags 21 | images: | 22 | ${{ secrets.DOCKER_HUB_USERNAME }}/rembg 23 | # generate Docker tags based on the following events/attributes 24 | tags: | 25 | type=ref,event=branch 26 | type=ref,event=branch 27 | type=ref,event=pr 28 | type=semver,pattern={{version}} 29 | type=semver,pattern={{major}}.{{minor}} 30 | type=semver,pattern={{major}} 31 | type=sha 32 | 
33 | - name: Set up QEMU 34 | uses: docker/setup-qemu-action@v3 35 | 36 | - name: Set up Docker Buildx 37 | uses: docker/setup-buildx-action@v3 38 | 39 | - name: Login to Docker Hub 40 | uses: docker/login-action@v3 41 | with: 42 | username: ${{ secrets.DOCKER_HUB_USERNAME }} 43 | password: ${{ secrets.DOCKER_HUB_ACCESS_TOKEN }} 44 | 45 | - name: Build and push 46 | uses: docker/build-push-action@v6 47 | with: 48 | context: . 49 | platforms: linux/amd64 50 | push: ${{ github.event_name != 'pull_request' }} 51 | tags: ${{ steps.meta.outputs.tags }} 52 | labels: ${{ steps.meta.outputs.labels }} 53 | cache-from: type=registry,ref=${{ secrets.DOCKER_HUB_USERNAME }}/rembg:buildcache 54 | cache-to: type=registry,ref=${{ secrets.DOCKER_HUB_USERNAME }}/rembg:buildcache,mode=max 55 | -------------------------------------------------------------------------------- /.github/workflows/publish_pypi.yml: -------------------------------------------------------------------------------- 1 | name: Publish to Pypi 2 | 3 | on: 4 | push: 5 | tags: 6 | - "v*.*.*" 7 | 8 | jobs: 9 | publish_pypi: 10 | runs-on: ubuntu-latest 11 | steps: 12 | - uses: actions/checkout@v4 13 | - uses: actions/setup-python@v5 14 | - name: Install dependencies 15 | run: pip install .[cpu,cli,dev] 16 | - name: Builds and uploads to PyPI 17 | run: | 18 | python3 setup.py sdist bdist_wheel 19 | python3 -m twine upload dist/* 20 | env: 21 | TWINE_USERNAME: __token__ 22 | TWINE_PASSWORD: ${{ secrets.PIPY_PASSWORD }} 23 | -------------------------------------------------------------------------------- /.github/workflows/test_install.yml: -------------------------------------------------------------------------------- 1 | name: Test installation 2 | 3 | on: [push] 4 | 5 | jobs: 6 | test_install: 7 | runs-on: ubuntu-latest 8 | strategy: 9 | matrix: 10 | python-version: ["3.10", "3.11", "3.12", "3.13"] 11 | 12 | steps: 13 | - uses: actions/checkout@v4 14 | - name: Set up Python ${{ matrix.python-version }} 15 | uses: 
actions/setup-python@v5 16 | with: 17 | python-version: ${{ matrix.python-version }} 18 | - name: Install dependencies 19 | run: pip install .[cpu,cli,dev] 20 | - name: Test installation with pytest 21 | run: | 22 | attempt=0 23 | until rembg d || [ $attempt -eq 5 ]; do 24 | attempt=$((attempt+1)) 25 | echo "Attempt $attempt to download the models..." 26 | done 27 | if [ $attempt -eq 5 ]; then 28 | echo "downloading the models failed 5 times, exiting..." 29 | exit 1 30 | fi 31 | pytest 32 | -------------------------------------------------------------------------------- /.github/workflows/windows_installer.yml: -------------------------------------------------------------------------------- 1 | name: Build Windows Installer 2 | 3 | on: 4 | push: 5 | tags: 6 | - "v*.*.*" 7 | jobs: 8 | windows_installer: 9 | name: Build the Inno Setup Installer 10 | runs-on: windows-latest 11 | steps: 12 | - uses: actions/setup-python@v5 13 | - uses: actions/checkout@v4 14 | - shell: pwsh 15 | run: ./_build-exe.ps1 16 | - name: Compile .ISS to .EXE Installer 17 | uses: Minionguyjpro/Inno-Setup-Action@v1.2.2 18 | with: 19 | path: _setup.iss 20 | options: /O+ 21 | - name: Upload binaries to release 22 | uses: svenstaro/upload-release-action@v2 23 | with: 24 | repo_token: ${{ secrets.GITHUB_TOKEN }} 25 | file: dist/rembg-cli-installer.exe 26 | asset_name: rembg-cli-installer.exe 27 | tag: ${{ github.ref }} 28 | overwrite: true 29 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # general things to ignore 2 | build/ 3 | dist/ 4 | .venv/ 5 | .direnv/ 6 | *.egg-info/ 7 | *.egg 8 | *.py[cod] 9 | __pycache__/ 10 | *.so 11 | *~≈ 12 | .env 13 | .envrc 14 | .idea 15 | .pytest_cache 16 | 17 | # due to using tox and pytest 18 | .tox 19 | .cache 20 | .mypy_cache 21 | -------------------------------------------------------------------------------- /.markdownlint.yaml: 
-------------------------------------------------------------------------------- 1 | --- 2 | default: true 3 | MD013: false # line-length 4 | MD033: false # no-inline-html 5 | -------------------------------------------------------------------------------- /.python-version: -------------------------------------------------------------------------------- 1 | 3.12.4 2 | -------------------------------------------------------------------------------- /CITATION.cff: -------------------------------------------------------------------------------- 1 | cff-version: 1.2.0 2 | title: rembg 3 | message: Rembg is a tool to remove images background 4 | type: software 5 | authors: 6 | - given-names: Daniel 7 | family-names: Gatis 8 | email: danielgatis@gmail.com 9 | identifiers: 10 | - type: url 11 | value: 'https://github.com/danielgatis' 12 | repository-code: 'https://github.com/danielgatis/rembg' 13 | url: 'https://github.com/danielgatis/rembg' 14 | abstract: Rembg is a tool to remove images background. 15 | license: MIT 16 | commit: 9079508935ae55d6eefa0fd75f870599640e8593 17 | version: 2.0.66 18 | date-released: '2025-02-21' 19 | 20 | -------------------------------------------------------------------------------- /Dockerfile: -------------------------------------------------------------------------------- 1 | FROM python:3.10-slim 2 | 3 | WORKDIR /rembg 4 | 5 | RUN pip install --upgrade pip 6 | 7 | RUN apt-get update && apt-get install -y curl && apt-get clean && rm -rf /var/lib/apt/lists/* 8 | 9 | COPY . . 
10 | 11 | RUN python -m pip install ".[cpu,cli]" 12 | RUN rembg d u2net 13 | 14 | EXPOSE 7000 15 | ENTRYPOINT ["rembg"] 16 | CMD ["--help"] 17 | -------------------------------------------------------------------------------- /Dockerfile_nvidia_cuda_cudnn_gpu: -------------------------------------------------------------------------------- 1 | FROM nvidia/cuda:12.4.1-cudnn-devel-ubuntu22.04 2 | 3 | WORKDIR /rembg 4 | 5 | RUN apt-get update && apt-get install -y --no-install-recommends python3-pip python-is-python3 curl && apt-get clean && rm -rf /var/lib/apt/lists/* 6 | 7 | COPY . . 8 | 9 | RUN python -m pip install ".[gpu,cli]" --break-system-packages 10 | RUN rembg d u2net 11 | 12 | EXPOSE 7000 13 | ENTRYPOINT ["rembg"] 14 | CMD ["--help"] 15 | -------------------------------------------------------------------------------- /LICENSE.txt: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2020 Daniel Gatis 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /MANIFEST.in: -------------------------------------------------------------------------------- 1 | include MANIFEST.in 2 | include LICENSE.txt 3 | include README.md 4 | include setup.py 5 | include pyproject.toml 6 | include requirements.txt 7 | include requirements-gpu.txt 8 | 9 | include versioneer.py 10 | include rembg/_version.py 11 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Rembg 2 | 3 | [![Downloads](https://img.shields.io/pypi/dm/rembg.svg)](https://img.shields.io/pypi/dm/rembg.svg) 4 | [![License](https://img.shields.io/badge/License-MIT-blue.svg)](https://img.shields.io/badge/License-MIT-blue.svg) 5 | [![Hugging Face Spaces](https://img.shields.io/badge/🤗%20Hugging%20Face-Spaces-blue)](https://huggingface.co/spaces/KenjieDec/RemBG) 6 | [![Streamlit App](https://img.shields.io/badge/🎈%20Streamlit%20Community-Cloud-blue)](https://bgremoval.streamlit.app/) 7 | [![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/danielgatis/rembg/blob/main/rembg.ipynb) 8 | 9 | 10 | Rembg is a tool to remove images background. 11 | 12 |

13 | example car-1 14 | example car-1.out 15 | example car-2 16 | example car-2.out 17 | example car-3 18 | example car-3.out 19 |

20 | 21 |

22 | example animal-1 23 | example animal-1.out 24 | example animal-2 25 | example animal-2.out 26 | example animal-3 27 | example animal-3.out 28 |

29 | 30 |

31 | example girl-1 32 | example girl-1.out 33 | example girl-2 34 | example girl-2.out 35 | example girl-3 36 | example girl-3.out 37 |

38 | 39 |

40 | example anime-girl-1 41 | example anime-girl-1.out 42 | example anime-girl-2 43 | example anime-girl-2.out 44 | example anime-girl-3 45 | example anime-girl-3.out 46 |

47 | 48 | **If this project has helped you, please consider making a [donation](https://www.buymeacoffee.com/danielgatis).** 49 | 50 | ## Sponsors 51 | 52 | 53 | 54 | 59 | 69 | 70 | 71 | 76 | 85 | 86 |
55 | 56 | withoutBG API Logo 57 | 58 | 60 | withoutBG API 61 |
62 | https://withoutbg.com 63 |
64 |

65 | High-quality background removal API at affordable rates 66 |
67 |

68 |
72 | 73 | Unsplash 74 | 75 | 77 | PhotoRoom Remove Background API 78 |
79 | https://photoroom.com/api 80 |
81 |

82 | Fast and accurate background remover API
83 |

84 |
87 | 88 | ## Requirements 89 | 90 | ```text 91 | python: >=3.10, <3.14 92 | ``` 93 | 94 | ## Installation 95 | 96 | If you have `onnxruntime` already installed, just install `rembg`: 97 | 98 | ```bash 99 | pip install rembg # for library 100 | pip install "rembg[cli]" # for library + cli 101 | ``` 102 | 103 | Otherwise, install `rembg` with explicit CPU/GPU support. 104 | 105 | ### CPU support: 106 | 107 | ```bash 108 | pip install rembg[cpu] # for library 109 | pip install "rembg[cpu,cli]" # for library + cli 110 | ``` 111 | 112 | ### GPU support (NVidia/Cuda): 113 | 114 | First of all, you need to check if your system supports the `onnxruntime-gpu`. 115 | 116 | Go to [onnxruntime.ai](https://onnxruntime.ai/getting-started) and check the installation matrix. 117 | 118 |

119 | onnxruntime-installation-matrix 120 |

121 | 122 | If yes, just run: 123 | 124 | ```bash 125 | pip install "rembg[gpu]" # for library 126 | pip install "rembg[gpu,cli]" # for library + cli 127 | ``` 128 | 129 | Nvidia GPU may require onnxruntime-gpu, cuda, and cudnn-devel. [#668](https://github.com/danielgatis/rembg/issues/668#issuecomment-2689830314) . If rembg[gpu] doesn't work and you can't install cuda or cudnn-devel, use rembg[cpu] and onnxruntime instead. 130 | 131 | ### GPU support (AMD/ROCM): 132 | 133 | ROCM support requires the `onnxruntime-rocm` package. Install it following 134 | [AMD's documentation](https://rocm.docs.amd.com/projects/radeon/en/latest/docs/install/native_linux/install-onnx.html). 135 | 136 | If `onnxruntime-rocm` is installed and working, install the `rembg[rocm]` 137 | version of rembg: 138 | 139 | ```bash 140 | pip install "rembg[rocm]" # for library 141 | pip install "rembg[rocm,cli]" # for library + cli 142 | ``` 143 | 144 | ## Usage as a cli 145 | 146 | After the installation step you can use rembg just typing `rembg` in your terminal window. 147 | 148 | The `rembg` command has 4 subcommands, one for each input type: 149 | 150 | - `i` for files 151 | - `p` for folders 152 | - `s` for http server 153 | - `b` for RGB24 pixel binary stream 154 | 155 | You can get help about the main command using: 156 | 157 | ```shell 158 | rembg --help 159 | ``` 160 | 161 | You can also get help about each subcommand using: 162 | 163 | ```shell 164 | rembg <COMMAND> --help 165 | ``` 166 | 167 | ### rembg `i` 168 | 169 | Used when input and output are files. 
170 | 171 | Remove the background from a remote image 172 | 173 | ```shell 174 | curl -s http://input.png | rembg i > output.png 175 | ``` 176 | 177 | Remove the background from a local file 178 | 179 | ```shell 180 | rembg i path/to/input.png path/to/output.png 181 | ``` 182 | 183 | Remove the background specifying a model 184 | 185 | ```shell 186 | rembg i -m u2netp path/to/input.png path/to/output.png 187 | ``` 188 | 189 | Remove the background returning only the mask 190 | 191 | ```shell 192 | rembg i -om path/to/input.png path/to/output.png 193 | ``` 194 | 195 | Remove the background applying alpha matting 196 | 197 | ```shell 198 | rembg i -a path/to/input.png path/to/output.png 199 | ``` 200 | 201 | Passing extra parameters 202 | 203 | ```shell 204 | SAM example 205 | 206 | rembg i -m sam -x '{ "sam_prompt": [{"type": "point", "data": [724, 740], "label": 1}] }' examples/plants-1.jpg examples/plants-1.out.png 207 | ``` 208 | 209 | ```shell 210 | Custom model example 211 | 212 | rembg i -m u2net_custom -x '{"model_path": "~/.u2net/u2net.onnx"}' path/to/input.png path/to/output.png 213 | ``` 214 | 215 | ### rembg `p` 216 | 217 | Used when input and output are folders. 218 | 219 | Remove the background from all images in a folder 220 | 221 | ```shell 222 | rembg p path/to/input path/to/output 223 | ``` 224 | 225 | Same as before, but watching for new/changed files to process 226 | 227 | ```shell 228 | rembg p -w path/to/input path/to/output 229 | ``` 230 | 231 | ### rembg `s` 232 | 233 | Used to start http server. 234 | 235 | ```shell 236 | rembg s --host 0.0.0.0 --port 7000 --log_level info 237 | ``` 238 | 239 | To see the complete endpoints documentation, go to: `http://localhost:7000/api`. 
240 | 241 | Remove the background from an image url 242 | 243 | ```shell 244 | curl -s "http://localhost:7000/api/remove?url=http://input.png" -o output.png 245 | ``` 246 | 247 | Remove the background from an uploaded image 248 | 249 | ```shell 250 | curl -s -F file=@/path/to/input.jpg "http://localhost:7000/api/remove" -o output.png 251 | ``` 252 | 253 | ### rembg `b` 254 | 255 | Process a sequence of RGB24 images from stdin. This is intended to be used with another program, such as FFMPEG, that outputs RGB24 pixel data to stdout, which is piped into the stdin of this program, although nothing prevents you from manually typing in images at stdin. 256 | 257 | ```shell 258 | rembg b image_width image_height -o output_specifier 259 | ``` 260 | 261 | Arguments: 262 | 263 | - image_width : width of input image(s) 264 | - image_height : height of input image(s) 265 | - output_specifier: printf-style specifier for output filenames, for example if `output-%03u.png`, then output files will be named `output-000.png`, `output-001.png`, `output-002.png`, etc. Output files will be saved in PNG format regardless of the extension specified. You can omit it to write results to stdout. 266 | 267 | Example usage with FFMPEG: 268 | 269 | ```shell 270 | ffmpeg -i input.mp4 -ss 10 -an -f rawvideo -pix_fmt rgb24 pipe:1 | rembg b 1280 720 -o folder/output-%03u.png 271 | ``` 272 | 273 | The width and height values must match the dimension of output images from FFMPEG. Note for FFMPEG, the "`-an -f rawvideo -pix_fmt rgb24 pipe:1`" part is required for the whole thing to work. 
274 | 275 | ## Usage as a library 276 | 277 | Input and output as bytes 278 | 279 | ```python 280 | from rembg import remove 281 | 282 | input_path = 'input.png' 283 | output_path = 'output.png' 284 | 285 | with open(input_path, 'rb') as i: 286 | with open(output_path, 'wb') as o: 287 | input = i.read() 288 | output = remove(input) 289 | o.write(output) 290 | ``` 291 | 292 | Input and output as a PIL image 293 | 294 | ```python 295 | from rembg import remove 296 | from PIL import Image 297 | 298 | input_path = 'input.png' 299 | output_path = 'output.png' 300 | 301 | input = Image.open(input_path) 302 | output = remove(input) 303 | output.save(output_path) 304 | ``` 305 | 306 | Input and output as a numpy array 307 | 308 | ```python 309 | from rembg import remove 310 | import cv2 311 | 312 | input_path = 'input.png' 313 | output_path = 'output.png' 314 | 315 | input = cv2.imread(input_path) 316 | output = remove(input) 317 | cv2.imwrite(output_path, output) 318 | ``` 319 | 320 | Force output as bytes 321 | 322 | ```python 323 | from rembg import remove 324 | 325 | input_path = 'input.png' 326 | output_path = 'output.png' 327 | 328 | with open(input_path, 'rb') as i: 329 | with open(output_path, 'wb') as o: 330 | input = i.read() 331 | output = remove(input, force_return_bytes=True) 332 | o.write(output) 333 | ``` 334 | 335 | How to iterate over files in a performant way 336 | 337 | ```python 338 | from pathlib import Path 339 | from rembg import remove, new_session 340 | 341 | session = new_session() 342 | 343 | for file in Path('path/to/folder').glob('*.png'): 344 | input_path = str(file) 345 | output_path = str(file.parent / (file.stem + ".out.png")) 346 | 347 | with open(input_path, 'rb') as i: 348 | with open(output_path, 'wb') as o: 349 | input = i.read() 350 | output = remove(input, session=session) 351 | o.write(output) 352 | ``` 353 | 354 | To see a full list of examples on how to use rembg, go to the [examples](USAGE.md) page. 
355 | 356 | ## Usage as a docker 357 | 358 | ### Only CPU 359 | 360 | Just replace the `rembg` command with `docker run danielgatis/rembg`. 361 | 362 | Try this: 363 | 364 | ```shell 365 | docker run -v path/to/input:/rembg danielgatis/rembg i input.png path/to/output/output.png 366 | ``` 367 | 368 | ### Nvidia CUDA Hardware Acceleration 369 | 370 | Requirement: using CUDA in Docker requires the **NVIDIA Container Toolkit** to be installed on your **host**. [NVIDIA Container Toolkit Install Guide](https://docs.nvidia.com/datacenter/cloud-native/container-toolkit/latest/install-guide.html) 371 | 372 | **Nvidia CUDA Hardware Acceleration** needs cudnn-devel, so you need to build the docker image yourself. [#668](https://github.com/danielgatis/rembg/issues/668#issuecomment-2689914205) 373 | 374 | Here is an example that shows how to build an image and name it *rembg-nvidia-cuda-cudnn-gpu*: 375 | ```shell 376 | docker build -t rembg-nvidia-cuda-cudnn-gpu -f Dockerfile_nvidia_cuda_cudnn_gpu . 377 | ``` 378 | Be aware: it takes about 11GB of disk space (the CPU version only takes about 1.6GB). Models are not included. 379 | 380 | After you build the image, run it like this as a cli 381 | ```shell 382 | sudo docker run --rm -it --gpus all -v /dev/dri:/dev/dri -v $PWD:/rembg rembg-nvidia-cuda-cudnn-gpu i -m birefnet-general input.png output.png 383 | ``` 384 | 385 | - Trick 1: You can also build your own nvidia-cuda-cudnn-gpu image and install rembg[gpu,cli] in it. 386 | - Trick 2: Use the parameter `-v /somewhereYouStoresModelFiles/:/root/.u2net` to download/store model files outside of the docker image. You can even comment out the line `RUN rembg d u2net` so that no models are downloaded when building the image; you can then download the specific model you want, even without the default u2net model. 387 | 388 | ## Models 389 | 390 | All models are downloaded and saved in the user home folder in the `.u2net` directory. 
391 | 392 | The available models are: 393 | 394 | - u2net ([download](https://github.com/danielgatis/rembg/releases/download/v0.0.0/u2net.onnx), [source](https://github.com/xuebinqin/U-2-Net)): A pre-trained model for general use cases. 395 | - u2netp ([download](https://github.com/danielgatis/rembg/releases/download/v0.0.0/u2netp.onnx), [source](https://github.com/xuebinqin/U-2-Net)): A lightweight version of u2net model. 396 | - u2net_human_seg ([download](https://github.com/danielgatis/rembg/releases/download/v0.0.0/u2net_human_seg.onnx), [source](https://github.com/xuebinqin/U-2-Net)): A pre-trained model for human segmentation. 397 | - u2net_cloth_seg ([download](https://github.com/danielgatis/rembg/releases/download/v0.0.0/u2net_cloth_seg.onnx), [source](https://github.com/levindabhi/cloth-segmentation)): A pre-trained model for Cloths Parsing from human portrait. Here clothes are parsed into 3 category: Upper body, Lower body and Full body. 398 | - silueta ([download](https://github.com/danielgatis/rembg/releases/download/v0.0.0/silueta.onnx), [source](https://github.com/xuebinqin/U-2-Net/issues/295)): Same as u2net but the size is reduced to 43Mb. 399 | - isnet-general-use ([download](https://github.com/danielgatis/rembg/releases/download/v0.0.0/isnet-general-use.onnx), [source](https://github.com/xuebinqin/DIS)): A new pre-trained model for general use cases. 400 | - isnet-anime ([download](https://github.com/danielgatis/rembg/releases/download/v0.0.0/isnet-anime.onnx), [source](https://github.com/SkyTNT/anime-segmentation)): A high-accuracy segmentation for anime character. 401 | - sam ([download encoder](https://github.com/danielgatis/rembg/releases/download/v0.0.0/vit_b-encoder-quant.onnx), [download decoder](https://github.com/danielgatis/rembg/releases/download/v0.0.0/vit_b-decoder-quant.onnx), [source](https://github.com/facebookresearch/segment-anything)): A pre-trained model for any use cases. 
402 | - birefnet-general ([download](https://github.com/danielgatis/rembg/releases/download/v0.0.0/BiRefNet-general-epoch_244.onnx), [source](https://github.com/ZhengPeng7/BiRefNet)): A pre-trained model for general use cases. 403 | - birefnet-general-lite ([download](https://github.com/danielgatis/rembg/releases/download/v0.0.0/BiRefNet-general-bb_swin_v1_tiny-epoch_232.onnx), [source](https://github.com/ZhengPeng7/BiRefNet)): A light pre-trained model for general use cases. 404 | - birefnet-portrait ([download](https://github.com/danielgatis/rembg/releases/download/v0.0.0/BiRefNet-portrait-epoch_150.onnx), [source](https://github.com/ZhengPeng7/BiRefNet)): A pre-trained model for human portraits. 405 | - birefnet-dis ([download](https://github.com/danielgatis/rembg/releases/download/v0.0.0/BiRefNet-DIS-epoch_590.onnx), [source](https://github.com/ZhengPeng7/BiRefNet)): A pre-trained model for dichotomous image segmentation (DIS). 406 | - birefnet-hrsod ([download](https://github.com/danielgatis/rembg/releases/download/v0.0.0/BiRefNet-HRSOD_DHU-epoch_115.onnx), [source](https://github.com/ZhengPeng7/BiRefNet)): A pre-trained model for high-resolution salient object detection (HRSOD). 407 | - birefnet-cod ([download](https://github.com/danielgatis/rembg/releases/download/v0.0.0/BiRefNet-COD-epoch_125.onnx), [source](https://github.com/ZhengPeng7/BiRefNet)): A pre-trained model for concealed object detection (COD). 408 | - birefnet-massive ([download](https://github.com/danielgatis/rembg/releases/download/v0.0.0/BiRefNet-massive-TR_DIS5K_TR_TEs-epoch_420.onnx), [source](https://github.com/ZhengPeng7/BiRefNet)): A pre-trained model with massive dataset. 
409 | 410 | ### How to train your own model 411 | 412 | If you need more fine-tuned models, try this: 413 | 414 | 415 | ## Some video tutorials 416 | 417 | - 418 | - 419 | - 420 | - 421 | 422 | ## References 423 | 424 | - 425 | - 426 | - 427 | 428 | ## FAQ 429 | 430 | ### When will this library provide support for Python version 3.xx? 431 | 432 | This library directly depends on the [onnxruntime](https://pypi.org/project/onnxruntime) library. Therefore, we can only update the Python version when [onnxruntime](https://pypi.org/project/onnxruntime) provides support for that specific version. 433 | 434 | ## Buy me a coffee 435 | 436 | Liked some of my work? Buy me a coffee (or more likely a beer) 437 | 438 | Buy Me A Coffee 439 | 440 | ## Star History 441 | 442 | [![Star History Chart](https://api.star-history.com/svg?repos=danielgatis/rembg&type=Date)](https://star-history.com/#danielgatis/rembg&Date) 443 | 444 | ## License 445 | 446 | Copyright (c) 2020-present [Daniel Gatis](https://github.com/danielgatis) 447 | 448 | Licensed under [MIT License](./LICENSE.txt) 449 | -------------------------------------------------------------------------------- /USAGE.md: -------------------------------------------------------------------------------- 1 | # How to use the remove function 2 | 3 | ## Load the Image 4 | 5 | ```python 6 | from PIL import Image 7 | from rembg import new_session, remove 8 | 9 | input_path = 'input.png' 10 | output_path = 'output.png' 11 | 12 | input = Image.open(input_path) 13 | ``` 14 | 15 | ## Removing the background 16 | 17 | ### Without additional arguments 18 | 19 | This defaults to the `u2net` model. 20 | 21 | ```python 22 | output = remove(input) 23 | output.save(output_path) 24 | ``` 25 | 26 | ### With a specific model 27 | 28 | You can use the `new_session` function to create a session with a specific model.
29 | 30 | ```python 31 | model_name = "isnet-general-use" 32 | session = new_session(model_name) 33 | output = remove(input, session=session) 34 | ``` 35 | 36 | ### For processing multiple image files 37 | 38 | By default, `remove` initialises a new session on every call. This can be a large bottleneck if you're having to process multiple images. Initialise a session and pass it in to the `remove` function for fast multi-image support. 39 | 40 | ```python 41 | model_name = "u2net" 42 | rembg_session = new_session(model_name) 43 | for img in images: 44 | output = remove(img, session=rembg_session) 45 | ``` 46 | 47 | ### With alpha matting 48 | 49 | Alpha matting is a post processing step that can be used to improve the quality of the output. 50 | 51 | ```python 52 | output = remove(input, alpha_matting=True, alpha_matting_foreground_threshold=270,alpha_matting_background_threshold=20, alpha_matting_erode_size=11) 53 | ``` 54 | 55 | ### Only mask 56 | 57 | If you only want the mask, you can use the `only_mask` argument. 58 | 59 | ```python 60 | output = remove(input, only_mask=True) 61 | ``` 62 | 63 | ### With post processing 64 | 65 | You can use the `post_process_mask` argument to post process the mask to get better results. 66 | 67 | ```python 68 | output = remove(input, post_process_mask=True) 69 | ``` 70 | 71 | ### Replacing the background color 72 | 73 | You can use the `bgcolor` argument to replace the background color. 74 | 75 | ```python 76 | output = remove(input, bgcolor=(255, 255, 255, 255)) 77 | ``` 78 | 79 | ### Using input points 80 | 81 | You can use the `input_points` and `input_labels` arguments to specify the points that should be used for the masks. This only works with the `sam` model.
82 | 83 | ```python 84 | import numpy as np 85 | # Define the points and labels 86 | # The points are defined as [y, x] 87 | input_points = np.array([[400, 350], [700, 400], [200, 400]]) 88 | input_labels = np.array([1, 1, 2]) 89 | 90 | output = remove(input, session=session, input_points=input_points, input_labels=input_labels) 91 | ``` 92 | 93 | ## Save the image 94 | 95 | ```python 96 | output.save(output_path) 97 | ``` 98 | -------------------------------------------------------------------------------- /_build-exe.ps1: -------------------------------------------------------------------------------- 1 | # Install required packages 2 | pip install pyinstaller 3 | pip install -e ".[cli]" 4 | 5 | # Create PyInstaller spec file with specified data collections 6 | # pyi-makespec --collect-data=gradio_client --collect-data=gradio rembg.py 7 | 8 | # Run PyInstaller with the generated spec file 9 | pyinstaller rembg.spec 10 | -------------------------------------------------------------------------------- /_modpath.iss: -------------------------------------------------------------------------------- 1 | // ---------------------------------------------------------------------------- 2 | // 3 | // Inno Setup Ver: 5.4.2 4 | // Script Version: 1.4.2 5 | // Author: Jared Breland 6 | // Homepage: http://www.legroom.net/software 7 | // License: GNU Lesser General Public License (LGPL), version 3 8 | // http://www.gnu.org/licenses/lgpl.html 9 | // 10 | // Script Function: 11 | // Allow modification of environmental path directly from Inno Setup installers 12 | // 13 | // Instructions: 14 | // Copy modpath.iss to the same directory as your setup script 15 | // 16 | // Add this statement to your [Setup] section 17 | // ChangesEnvironment=true 18 | // 19 | // Add this statement to your [Tasks] section 20 | // You can change the Description or Flags 21 | // You can change the Name, but it must match the ModPathName setting below 22 | // Name: modifypath; Description: &Add
// Add every directory returned by ModPathDir() to the PATH when installing,
// or remove it again when running as the uninstaller.
//
// Relies on two names supplied by the including script:
//   ModPathType - 'system' selects the HKEY_LOCAL_MACHINE environment key;
//                 any other value falls back to the current user's key.
//   ModPathDir  - returns the directories to add or remove.
//
// When UsingWinNT() is true the registry 'Path' value is rewritten; otherwise
// C:\AUTOEXEC.BAT is edited instead.
procedure ModPath();
var
    oldpath: String;
    newpath: String;
    updatepath: Boolean;
    pathArr: TArrayOfString;
    aExecFile: String;
    aExecArr: TArrayOfString;
    i, d: Integer;
    pathdir: TArrayOfString;
    regroot: Integer;
    regpath: String;

begin
    // Get constants from main script and adjust behavior accordingly
    // ModPathType MUST be 'system' or 'user'; force 'user' if invalid
    if ModPathType = 'system' then begin
        regroot := HKEY_LOCAL_MACHINE;
        regpath := 'SYSTEM\CurrentControlSet\Control\Session Manager\Environment';
    end else begin
        regroot := HKEY_CURRENT_USER;
        regpath := 'Environment';
    end;

    // Get array of new directories and act on each individually
    pathdir := ModPathDir();
    for d := 0 to GetArrayLength(pathdir)-1 do begin
        updatepath := true;

        // Modify WinNT path
        if UsingWinNT() = true then begin

            // Get current path, split into an array
            RegQueryStringValue(regroot, regpath, 'Path', oldpath);
            // A trailing ';' guarantees the last entry is also terminated
            oldpath := oldpath + ';';
            i := 0;

            while (Pos(';', oldpath) > 0) do begin
                SetArrayLength(pathArr, i+1);
                pathArr[i] := Copy(oldpath, 0, Pos(';', oldpath)-1);
                oldpath := Copy(oldpath, Pos(';', oldpath)+1, Length(oldpath));
                i := i + 1;

                // Check if current directory matches app dir
                if pathdir[d] = pathArr[i-1] then begin
                    // if uninstalling, remove dir from path
                    if IsUninstaller() = true then begin
                        continue;
                    // if installing, flag that dir already exists in path
                    end else begin
                        updatepath := false;
                    end;
                end;

                // Add current directory to new path
                if i = 1 then begin
                    newpath := pathArr[i-1];
                end else begin
                    newpath := newpath + ';' + pathArr[i-1];
                end;
            end;

            // Append app dir to path if not already included
            if (IsUninstaller() = false) AND (updatepath = true) then
                newpath := newpath + ';' + pathdir[d];

            // Write new path
            RegWriteStringValue(regroot, regpath, 'Path', newpath);

        // Modify Win9x path
        end else begin

            // Convert to shortened dirname
            pathdir[d] := GetShortName(pathdir[d]);

            // If autoexec.bat exists, check if app dir already exists in path
            aExecFile := 'C:\AUTOEXEC.BAT';
            if FileExists(aExecFile) then begin
                LoadStringsFromFile(aExecFile, aExecArr);
                for i := 0 to GetArrayLength(aExecArr)-1 do begin
                    if IsUninstaller() = false then begin
                        // If app dir already exists while installing, skip add.
                        // The begin/end keeps the break conditional, so every
                        // line is examined until a match is found (previously
                        // the loop always stopped after the first line).
                        if (Pos(pathdir[d], aExecArr[i]) > 0) then begin
                            updatepath := false;
                            break;
                        end;
                    end else begin
                        // If app dir exists and = what we originally set, then delete at uninstall
                        if aExecArr[i] = 'SET PATH=%PATH%;' + pathdir[d] then
                            aExecArr[i] := '';
                    end;
                end;
            end;

            // If app dir not found, or autoexec.bat didn't exist, then (create and) append to current path
            if (IsUninstaller() = false) AND (updatepath = true) then begin
                SaveStringToFile(aExecFile, #13#10 + 'SET PATH=%PATH%;' + pathdir[d], True);

            // If uninstalling, write the full autoexec out
            end else begin
                SaveStringsToFile(aExecFile, aExecArr, False);
            end;
        end;
    end;
end;
// Tell setup whether a restart is needed: only when the PATH task named by
// ModPathName was selected and UsingWinNT() reports a non-NT system.
function NeedRestart(): Boolean;
begin
    Result := IsTaskSelected(ModPathName) and not UsingWinNT();
end;
// Return the directories that ModPath() should add to (or remove from) the
// PATH: a single entry, the expanded {app} installation directory.
function ModPathDir(): TArrayOfString;
begin
    // Statement separator added; the original line had no terminating ';'.
    setArrayLength(Result, 1);
    Result[0] := ExpandConstant('{app}');
end;
https://raw.githubusercontent.com/danielgatis/rembg/bc1436cad8dd2c94aa396604f9afdc2dde95cf55/examples/animal-2.out.png -------------------------------------------------------------------------------- /examples/animal-3.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/danielgatis/rembg/bc1436cad8dd2c94aa396604f9afdc2dde95cf55/examples/animal-3.jpg -------------------------------------------------------------------------------- /examples/animal-3.out.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/danielgatis/rembg/bc1436cad8dd2c94aa396604f9afdc2dde95cf55/examples/animal-3.out.png -------------------------------------------------------------------------------- /examples/anime-girl-1.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/danielgatis/rembg/bc1436cad8dd2c94aa396604f9afdc2dde95cf55/examples/anime-girl-1.jpg -------------------------------------------------------------------------------- /examples/anime-girl-1.out.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/danielgatis/rembg/bc1436cad8dd2c94aa396604f9afdc2dde95cf55/examples/anime-girl-1.out.png -------------------------------------------------------------------------------- /examples/anime-girl-2.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/danielgatis/rembg/bc1436cad8dd2c94aa396604f9afdc2dde95cf55/examples/anime-girl-2.jpg -------------------------------------------------------------------------------- /examples/anime-girl-2.out.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/danielgatis/rembg/bc1436cad8dd2c94aa396604f9afdc2dde95cf55/examples/anime-girl-2.out.png 
-------------------------------------------------------------------------------- /examples/anime-girl-3.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/danielgatis/rembg/bc1436cad8dd2c94aa396604f9afdc2dde95cf55/examples/anime-girl-3.jpg -------------------------------------------------------------------------------- /examples/anime-girl-3.out.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/danielgatis/rembg/bc1436cad8dd2c94aa396604f9afdc2dde95cf55/examples/anime-girl-3.out.png -------------------------------------------------------------------------------- /examples/car-1.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/danielgatis/rembg/bc1436cad8dd2c94aa396604f9afdc2dde95cf55/examples/car-1.jpg -------------------------------------------------------------------------------- /examples/car-1.out.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/danielgatis/rembg/bc1436cad8dd2c94aa396604f9afdc2dde95cf55/examples/car-1.out.png -------------------------------------------------------------------------------- /examples/car-2.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/danielgatis/rembg/bc1436cad8dd2c94aa396604f9afdc2dde95cf55/examples/car-2.jpg -------------------------------------------------------------------------------- /examples/car-2.out.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/danielgatis/rembg/bc1436cad8dd2c94aa396604f9afdc2dde95cf55/examples/car-2.out.png -------------------------------------------------------------------------------- /examples/car-3.jpg: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/danielgatis/rembg/bc1436cad8dd2c94aa396604f9afdc2dde95cf55/examples/car-3.jpg -------------------------------------------------------------------------------- /examples/car-3.out.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/danielgatis/rembg/bc1436cad8dd2c94aa396604f9afdc2dde95cf55/examples/car-3.out.png -------------------------------------------------------------------------------- /examples/food-1.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/danielgatis/rembg/bc1436cad8dd2c94aa396604f9afdc2dde95cf55/examples/food-1.jpg -------------------------------------------------------------------------------- /examples/food-1.out.alpha.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/danielgatis/rembg/bc1436cad8dd2c94aa396604f9afdc2dde95cf55/examples/food-1.out.alpha.jpg -------------------------------------------------------------------------------- /examples/food-1.out.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/danielgatis/rembg/bc1436cad8dd2c94aa396604f9afdc2dde95cf55/examples/food-1.out.jpg -------------------------------------------------------------------------------- /examples/girl-1.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/danielgatis/rembg/bc1436cad8dd2c94aa396604f9afdc2dde95cf55/examples/girl-1.jpg -------------------------------------------------------------------------------- /examples/girl-1.out.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/danielgatis/rembg/bc1436cad8dd2c94aa396604f9afdc2dde95cf55/examples/girl-1.out.png -------------------------------------------------------------------------------- /examples/girl-2.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/danielgatis/rembg/bc1436cad8dd2c94aa396604f9afdc2dde95cf55/examples/girl-2.jpg -------------------------------------------------------------------------------- /examples/girl-2.out.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/danielgatis/rembg/bc1436cad8dd2c94aa396604f9afdc2dde95cf55/examples/girl-2.out.png -------------------------------------------------------------------------------- /examples/girl-3.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/danielgatis/rembg/bc1436cad8dd2c94aa396604f9afdc2dde95cf55/examples/girl-3.jpg -------------------------------------------------------------------------------- /examples/girl-3.out.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/danielgatis/rembg/bc1436cad8dd2c94aa396604f9afdc2dde95cf55/examples/girl-3.out.png -------------------------------------------------------------------------------- /examples/plants-1.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/danielgatis/rembg/bc1436cad8dd2c94aa396604f9afdc2dde95cf55/examples/plants-1.jpg -------------------------------------------------------------------------------- /examples/plants-1.out.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/danielgatis/rembg/bc1436cad8dd2c94aa396604f9afdc2dde95cf55/examples/plants-1.out.png 
# -*- mode: python ; coding: utf-8 -*-
# PyInstaller build specification for the `rembg` command-line executable.
# NOTE(review): Analysis, PYZ, EXE and COLLECT are not imported in this file;
# presumably PyInstaller injects them when it executes the spec - confirm.
from PyInstaller.utils.hooks import collect_data_files

# Data files of these packages are collected and passed to Analysis below.
datas = []
datas += collect_data_files('gradio_client')
datas += collect_data_files('gradio')
datas += collect_data_files('onnxruntime')

# Analyse the rembg.py entry script together with the collected data files.
a = Analysis(
    ['rembg.py'],
    pathex=[],
    binaries=[],
    datas=datas,
    hiddenimports=[],
    hookspath=[],
    hooksconfig={},
    runtime_hooks=[],
    excludes=[],
    noarchive=False,
    # NOTE(review): 'py' presumably makes PyInstaller collect gradio as
    # source files - confirm against the PyInstaller documentation.
    module_collection_mode={
        'gradio': 'py',
    },
)
pyz = PYZ(a.pure)

# Console executable named 'rembg'; binaries are excluded here and handed to
# COLLECT below instead.
exe = EXE(
    pyz,
    a.scripts,
    [],
    exclude_binaries=True,
    name='rembg',
    debug=False,
    bootloader_ignore_signals=False,
    strip=False,
    upx=True,
    console=True,
    disable_windowed_traceback=False,
    argv_emulation=False,
    target_arch=None,
    codesign_identity=None,
    entitlements_file=None,
)
# Gather the executable, binaries and data files under the name 'rembg'.
coll = COLLECT(
    exe,
    a.binaries,
    a.datas,
    strip=False,
    upx=True,
    upx_exclude=[],
    name='rembg',
)
def get_keywords():
    """Return the version keywords as a dict.

    The returned mapping has the keys ``refnames``, ``full`` and ``date``.
    """
    # git substitutes these strings during git-archive, and
    # setup.py/versioneer.py greps for the variable names, so every
    # assignment has to stay on its own line.
    git_refnames = " (HEAD -> main)"
    git_full = "bc1436cad8dd2c94aa396604f9afdc2dde95cf55"
    git_date = "2025-05-17 18:48:36 -0300"
    return {"refnames": git_refnames, "full": git_full, "date": git_date}
def versions_from_parentdir(parentdir_prefix, root, verbose):
    """Derive the version from the name of an enclosing directory.

    Unpacked source tarballs usually live in a directory named
    ``<prefix><version>``. ``root`` itself and up to two of its ancestors
    are inspected; the first directory whose name starts with
    ``parentdir_prefix`` supplies the version (the rest of the name).

    Raises NotThisMethod when none of the inspected names carry the prefix.
    """
    inspected = []
    candidate = root
    remaining_levels = 3

    while remaining_levels > 0:
        remaining_levels -= 1
        name = os.path.basename(candidate)
        if name.startswith(parentdir_prefix):
            version = name[len(parentdir_prefix) :]
            return {
                "version": version,
                "full-revisionid": None,
                "dirty": False,
                "error": None,
                "date": None,
            }
        inspected.append(candidate)
        # Step up one directory level for the next attempt.
        candidate = os.path.dirname(candidate)

    if verbose:
        print(
            "Tried directories %s but none started with prefix %s"
            % (str(inspected), parentdir_prefix)
        )
    raise NotThisMethod("rootdir doesn't start with parentdir_prefix")
@register_vcs_handler("git", "keywords")
def git_versions_from_keywords(keywords, tag_prefix, verbose):
    """Build the version dict from expanded git-archive keywords.

    ``keywords`` must contain ``refnames`` and ``full`` and may contain
    ``date``. The first ref (in sorted order) that starts with
    ``tag_prefix`` and is followed by a digit becomes the version; if there
    is none, the version is ``0+unknown`` with an error note.

    Raises NotThisMethod when ``refnames`` is missing or still holds the
    unexpanded ``$Format`` placeholder.
    """
    if "refnames" not in keywords:
        raise NotThisMethod("Short version file found")

    date = keywords.get("date")
    if date is not None:
        # Earlier lines may contain GPG signature information, so only the
        # final line is treated as the datestamp.
        last_line = date.splitlines()[-1]
        # "%ci" output is only ISO-8601-like: turn the first space into "T"
        # and drop the next one to make it compliant. "%cI" would do this
        # directly but only exists since git-2.2.0.
        date = last_line.strip().replace(" ", "T", 1).replace(" ", "", 1)

    refnames = keywords["refnames"].strip()
    if refnames.startswith("$Format"):
        if verbose:
            print("keywords are unexpanded, not using")
        raise NotThisMethod("unexpanded keywords, not a git-archive tarball")

    refs = set()
    for raw_ref in refnames.strip("()").split(","):
        refs.add(raw_ref.strip())

    # Since git-1.8.3 tags are listed as "tag: foo-1.0"; prefer those.
    tag_marker = "tag: "
    tags = {ref[len(tag_marker) :] for ref in refs if ref.startswith(tag_marker)}
    if not tags:
        # Either git < 1.8.3 or there are no tags at all. Heuristic: version
        # tags contain a digit, which filters out branch names such as
        # "release", "stabilization", "HEAD" and "master".
        tags = {ref for ref in refs if re.search(r"\d", ref)}
        if verbose:
            print("discarding '%s', no digits" % ",".join(refs - tags))
    if verbose:
        print("likely tags: %s" % ",".join(sorted(tags)))

    # Sorted order prefers e.g. "2.0" over "2.0rc1".
    for ref in sorted(tags):
        if not ref.startswith(tag_prefix):
            continue
        candidate = ref[len(tag_prefix) :]
        # Skip refs equal to the bare prefix or not followed by a number
        # (mostly relevant when the prefix is '').
        if not re.match(r"\d", candidate):
            continue
        if verbose:
            print("picking %s" % candidate)
        return {
            "version": candidate,
            "full-revisionid": keywords["full"].strip(),
            "dirty": False,
            "error": None,
            "date": date,
        }

    # No suitable tag: report "0+unknown" but keep the full revision id.
    if verbose:
        print("no suitable tags, using unknown + full revision id")
    return {
        "version": "0+unknown",
        "full-revisionid": keywords["full"].strip(),
        "dirty": False,
        "error": "no suitable tags",
        "date": None,
    }
241 | 242 | This only gets called if the git-archive 'subst' keywords were *not* 243 | expanded, and _version.py hasn't already been rewritten with a short 244 | version string, meaning we're inside a checked out source tree. 245 | """ 246 | GITS = ["git"] 247 | TAG_PREFIX_REGEX = "*" 248 | if sys.platform == "win32": 249 | GITS = ["git.cmd", "git.exe"] 250 | TAG_PREFIX_REGEX = r"\*" 251 | 252 | _, rc = runner(GITS, ["rev-parse", "--git-dir"], cwd=root, hide_stderr=True) 253 | if rc != 0: 254 | if verbose: 255 | print("Directory %s not under git control" % root) 256 | raise NotThisMethod("'git rev-parse --git-dir' returned error") 257 | 258 | # if there is a tag matching tag_prefix, this yields TAG-NUM-gHEX[-dirty] 259 | # if there isn't one, this yields HEX[-dirty] (no NUM) 260 | describe_out, rc = runner( 261 | GITS, 262 | [ 263 | "describe", 264 | "--tags", 265 | "--dirty", 266 | "--always", 267 | "--long", 268 | "--match", 269 | "%s%s" % (tag_prefix, TAG_PREFIX_REGEX), 270 | ], 271 | cwd=root, 272 | ) 273 | # --long was added in git-1.5.5 274 | if describe_out is None: 275 | raise NotThisMethod("'git describe' failed") 276 | describe_out = describe_out.strip() 277 | full_out, rc = runner(GITS, ["rev-parse", "HEAD"], cwd=root) 278 | if full_out is None: 279 | raise NotThisMethod("'git rev-parse' failed") 280 | full_out = full_out.strip() 281 | 282 | pieces = {} 283 | pieces["long"] = full_out 284 | pieces["short"] = full_out[:7] # maybe improved later 285 | pieces["error"] = None 286 | 287 | branch_name, rc = runner(GITS, ["rev-parse", "--abbrev-ref", "HEAD"], cwd=root) 288 | # --abbrev-ref was added in git-1.6.3 289 | if rc != 0 or branch_name is None: 290 | raise NotThisMethod("'git rev-parse --abbrev-ref' returned error") 291 | branch_name = branch_name.strip() 292 | 293 | if branch_name == "HEAD": 294 | # If we aren't exactly on a branch, pick a branch which represents 295 | # the current commit. If all else fails, we are on a branchless 296 | # commit. 
297 | branches, rc = runner(GITS, ["branch", "--contains"], cwd=root) 298 | # --contains was added in git-1.5.4 299 | if rc != 0 or branches is None: 300 | raise NotThisMethod("'git branch --contains' returned error") 301 | branches = branches.split("\n") 302 | 303 | # Remove the first line if we're running detached 304 | if "(" in branches[0]: 305 | branches.pop(0) 306 | 307 | # Strip off the leading "* " from the list of branches. 308 | branches = [branch[2:] for branch in branches] 309 | if "master" in branches: 310 | branch_name = "master" 311 | elif not branches: 312 | branch_name = None 313 | else: 314 | # Pick the first branch that is returned. Good or bad. 315 | branch_name = branches[0] 316 | 317 | pieces["branch"] = branch_name 318 | 319 | # parse describe_out. It will be like TAG-NUM-gHEX[-dirty] or HEX[-dirty] 320 | # TAG might have hyphens. 321 | git_describe = describe_out 322 | 323 | # look for -dirty suffix 324 | dirty = git_describe.endswith("-dirty") 325 | pieces["dirty"] = dirty 326 | if dirty: 327 | git_describe = git_describe[: git_describe.rindex("-dirty")] 328 | 329 | # now we have TAG-NUM-gHEX or HEX 330 | 331 | if "-" in git_describe: 332 | # TAG-NUM-gHEX 333 | mo = re.search(r"^(.+)-(\d+)-g([0-9a-f]+)$", git_describe) 334 | if not mo: 335 | # unparsable. Maybe git-describe is misbehaving? 
def plus_or_dot(pieces):
    """Return a + if we don't already have one, else return a .

    A PEP 440 version may contain only one "+" (it introduces the local
    version segment), so when the closest tag already has one, further
    local-version parts must be joined with "." instead.

    ``pieces["closest-tag"]`` may be missing or ``None`` (no tag found);
    both are treated as "no + yet" and yield "+".
    """
    # dict.get's default only applies to a missing key; an explicit None
    # value would otherwise reach the `in` test and raise TypeError.
    closest_tag = pieces.get("closest-tag") or ""
    return "." if "+" in closest_tag else "+"
def pep440_split_post(ver):
    """Split pep440 version string at the post-release segment.

    Returns a ``(release, post)`` tuple: the text before ".post" and the
    post-release number as an int. A bare ".post" with no digits counts as
    post-release 0. ``post`` is ``None`` (not -1) when the string does not
    contain exactly one ".post" marker; callers test it with ``is not None``.
    """
    parts = ver.split(".post")
    if len(parts) != 2:
        # No post-release segment (or an unexpected repeated marker).
        return parts[0], None
    return parts[0], int(parts[1] or 0)
def render_pep440_post(pieces):
    """TAG[.postDISTANCE[.dev0]+gHEX] .

    ".dev0" marks a dirty tree. It sorts backwards, so a dirty tree looks
    "older" than the matching clean one -- but dirty trees should not be
    released anyway.

    Exceptions:
    1: no tags. 0.postDISTANCE[.dev0]+gHEX
    """
    tag = pieces["closest-tag"]
    dev_marker = ".dev0" if pieces["dirty"] else ""

    if not tag:
        # exception #1: nothing to anchor on, start from "0"
        return "".join(
            ["0.post%d" % pieces["distance"], dev_marker, "+g%s" % pieces["short"]]
        )

    if not (pieces["distance"] or pieces["dirty"]):
        # exactly on a clean tag: the tag is the version
        return tag

    return "".join(
        [
            tag,
            ".post%d" % pieces["distance"],
            dev_marker,
            plus_or_dot(pieces),
            "g%s" % pieces["short"],
        ]
    )
def render_git_describe(pieces):
    """TAG[-DISTANCE-gHEX][-dirty].

    Mirrors 'git describe --tags --dirty --always'.

    Exceptions:
    1: no tags. HEX[-dirty] (note: no 'g' prefix)
    """
    tag = pieces["closest-tag"]
    if not tag:
        # exception #1: only the short hash is available
        described = pieces["short"]
    elif pieces["distance"]:
        described = tag + "-%d-g%s" % (pieces["distance"], pieces["short"])
    else:
        described = tag

    if pieces["dirty"]:
        return described + "-dirty"
    return described
def render(pieces, style):
    """Render the given version pieces into the requested style.

    Returns a dict with "version", "full-revisionid", "dirty", "error" and
    "date". If ``pieces["error"]`` is set, the version is reported as
    "unknown" and no renderer runs. An empty style or "default" means
    "pep440"; any other unrecognised style raises ValueError.
    """
    if pieces["error"]:
        return {
            "version": "unknown",
            "full-revisionid": pieces.get("long"),
            "dirty": None,
            "error": pieces["error"],
            "date": None,
        }

    # Ordered (style name, renderer) pairs, checked top to bottom.
    style_table = (
        ("pep440", render_pep440),
        ("pep440-branch", render_pep440_branch),
        ("pep440-pre", render_pep440_pre),
        ("pep440-post", render_pep440_post),
        ("pep440-post-branch", render_pep440_post_branch),
        ("pep440-old", render_pep440_old),
        ("git-describe", render_git_describe),
        ("git-describe-long", render_git_describe_long),
    )

    wanted = style if style and style != "default" else "pep440"  # the default

    for name, renderer in style_table:
        if wanted == name:
            break
    else:
        raise ValueError("unknown style '%s'" % wanted)

    return {
        "version": renderer(pieces),
        "full-revisionid": pieces["long"],
        "dirty": pieces["dirty"],
        "error": None,
        "date": pieces.get("date"),
    }
Some 632 | # py2exe/bbfreeze/non-CPython implementations don't do __file__, in which 633 | # case we can only use expanded keywords. 634 | 635 | cfg = get_config() 636 | verbose = cfg.verbose 637 | 638 | try: 639 | return git_versions_from_keywords(get_keywords(), cfg.tag_prefix, verbose) 640 | except NotThisMethod: 641 | pass 642 | 643 | try: 644 | root = os.path.realpath(__file__) 645 | # versionfile_source is the relative path from the top of the source 646 | # tree (where the .git directory might live) to this file. Invert 647 | # this to find the root from __file__. 648 | for _ in cfg.versionfile_source.split("/"): 649 | root = os.path.dirname(root) 650 | except NameError: 651 | return { 652 | "version": "0+unknown", 653 | "full-revisionid": None, 654 | "dirty": None, 655 | "error": "unable to find root of source tree", 656 | "date": None, 657 | } 658 | 659 | try: 660 | pieces = git_pieces_from_vcs(cfg.tag_prefix, root, verbose) 661 | return render(pieces, cfg.style) 662 | except NotThisMethod: 663 | pass 664 | 665 | try: 666 | if cfg.parentdir_prefix: 667 | return versions_from_parentdir(cfg.parentdir_prefix, root, verbose) 668 | except NotThisMethod: 669 | pass 670 | 671 | return { 672 | "version": "0+unknown", 673 | "full-revisionid": None, 674 | "dirty": None, 675 | "error": "unable to compute version", 676 | "date": None, 677 | } 678 | -------------------------------------------------------------------------------- /rembg/bg.py: -------------------------------------------------------------------------------- 1 | import io 2 | import sys 3 | from enum import Enum 4 | from typing import Any, List, Optional, Tuple, Union, cast 5 | 6 | import numpy as np 7 | import onnxruntime as ort 8 | from cv2 import ( 9 | BORDER_DEFAULT, 10 | MORPH_ELLIPSE, 11 | MORPH_OPEN, 12 | GaussianBlur, 13 | getStructuringElement, 14 | morphologyEx, 15 | ) 16 | from PIL import Image, ImageOps 17 | from PIL.Image import Image as PILImage 18 | from pymatting.alpha.estimate_alpha_cf 
import estimate_alpha_cf 19 | from pymatting.foreground.estimate_foreground_ml import estimate_foreground_ml 20 | from pymatting.util.util import stack_images 21 | from scipy.ndimage import binary_erosion 22 | 23 | from .session_factory import new_session 24 | from .sessions import sessions, sessions_names 25 | from .sessions.base import BaseSession 26 | 27 | ort.set_default_logger_severity(3) 28 | 29 | kernel = getStructuringElement(MORPH_ELLIPSE, (3, 3)) 30 | 31 | 32 | class ReturnType(Enum): 33 | BYTES = 0 34 | PILLOW = 1 35 | NDARRAY = 2 36 | 37 | 38 | def alpha_matting_cutout( 39 | img: PILImage, 40 | mask: PILImage, 41 | foreground_threshold: int, 42 | background_threshold: int, 43 | erode_structure_size: int, 44 | ) -> PILImage: 45 | """ 46 | Perform alpha matting on an image using a given mask and threshold values. 47 | 48 | This function takes a PIL image `img` and a PIL image `mask` as input, along with 49 | the `foreground_threshold` and `background_threshold` values used to determine 50 | foreground and background pixels. The `erode_structure_size` parameter specifies 51 | the size of the erosion structure to be applied to the mask. 52 | 53 | The function returns a PIL image representing the cutout of the foreground object 54 | from the original image. 
55 | """ 56 | if img.mode == "RGBA" or img.mode == "CMYK": 57 | img = img.convert("RGB") 58 | 59 | img_array = np.asarray(img) 60 | mask_array = np.asarray(mask) 61 | 62 | is_foreground = mask_array > foreground_threshold 63 | is_background = mask_array < background_threshold 64 | 65 | structure = None 66 | if erode_structure_size > 0: 67 | structure = np.ones( 68 | (erode_structure_size, erode_structure_size), dtype=np.uint8 69 | ) 70 | 71 | is_foreground = binary_erosion(is_foreground, structure=structure) 72 | is_background = binary_erosion(is_background, structure=structure, border_value=1) 73 | 74 | trimap = np.full(mask_array.shape, dtype=np.uint8, fill_value=128) 75 | trimap[is_foreground] = 255 76 | trimap[is_background] = 0 77 | 78 | img_normalized = img_array / 255.0 79 | trimap_normalized = trimap / 255.0 80 | 81 | alpha = estimate_alpha_cf(img_normalized, trimap_normalized) 82 | foreground = estimate_foreground_ml(img_normalized, alpha) 83 | cutout = stack_images(foreground, alpha) 84 | 85 | cutout = np.clip(cutout * 255, 0, 255).astype(np.uint8) 86 | cutout = Image.fromarray(cutout) 87 | 88 | return cutout 89 | 90 | 91 | def naive_cutout(img: PILImage, mask: PILImage) -> PILImage: 92 | """ 93 | Perform a simple cutout operation on an image using a mask. 94 | 95 | This function takes a PIL image `img` and a PIL image `mask` as input. 96 | It uses the mask to create a new image where the pixels from `img` are 97 | cut out based on the mask. 98 | 99 | The function returns a PIL image representing the cutout of the original 100 | image using the mask. 101 | """ 102 | empty = Image.new("RGBA", (img.size), 0) 103 | cutout = Image.composite(img, empty, mask) 104 | return cutout 105 | 106 | 107 | def putalpha_cutout(img: PILImage, mask: PILImage) -> PILImage: 108 | """ 109 | Apply the specified mask to the image as an alpha cutout. 110 | 111 | Args: 112 | img (PILImage): The image to be modified. 113 | mask (PILImage): The mask to be applied. 
def get_concat_v_multi(imgs: List[PILImage]) -> PILImage:
    """
    Concatenate multiple images vertically.

    The images are stacked top to bottom in list order by repeatedly calling
    ``get_concat_v``. The input list is left untouched.

    Args:
        imgs (List[PILImage]): The list of images to be concatenated.

    Returns:
        PILImage: The concatenated image (the sole element itself when the
        list holds exactly one image).

    Raises:
        IndexError: If ``imgs`` is empty.
    """
    # Index/slice instead of pop(0) so the caller's list is not mutated.
    pivot = imgs[0]
    for im in imgs[1:]:
        pivot = get_concat_v(pivot, im)
    return pivot
def remove(
    data: Union[bytes, PILImage, np.ndarray],
    alpha_matting: bool = False,
    alpha_matting_foreground_threshold: int = 240,
    alpha_matting_background_threshold: int = 10,
    alpha_matting_erode_size: int = 10,
    session: Optional[BaseSession] = None,
    only_mask: bool = False,
    post_process_mask: bool = False,
    bgcolor: Optional[Tuple[int, int, int, int]] = None,
    force_return_bytes: bool = False,
    *args: Optional[Any],
    **kwargs: Optional[Any],
) -> Union[bytes, PILImage, np.ndarray]:
    """
    Remove the background from an input image.

    The input is turned into a PIL image, its EXIF orientation is applied,
    and ``session.predict`` is asked for the masks. Each mask becomes a
    cutout (alpha matting, put-alpha or plain composite), the cutouts are
    stacked vertically, and an optional background color is composited
    behind the result. The output type mirrors the input type unless
    ``force_return_bytes`` is set, in which case PNG bytes are returned.

    Parameters:
        data (Union[bytes, PILImage, np.ndarray]): The input image data. With
            force_return_bytes, any other object accepted by io.BytesIO is
            also read as an encoded image.
        alpha_matting (bool, optional): Flag indicating whether to use alpha matting. Defaults to False.
        alpha_matting_foreground_threshold (int, optional): Foreground threshold for alpha matting. Defaults to 240.
        alpha_matting_background_threshold (int, optional): Background threshold for alpha matting. Defaults to 10.
        alpha_matting_erode_size (int, optional): Erosion size for alpha matting. Defaults to 10.
        session (Optional[BaseSession], optional): Session used for prediction. When None, a new 'u2net' session is created. Defaults to None.
        only_mask (bool, optional): Flag indicating whether to return only the masks. Defaults to False.
        post_process_mask (bool, optional): Flag indicating whether to post-process the masks. Defaults to False.
        bgcolor (Optional[Tuple[int, int, int, int]], optional): Background color for the cutout image; ignored with only_mask. Defaults to None.
        force_return_bytes (bool, optional): Flag indicating whether to return the cutout image as PNG bytes. Defaults to False.
        *args (Optional[Any]): Additional positional arguments, forwarded to new_session and session.predict.
        **kwargs (Optional[Any]): Additional keyword arguments, forwarded to new_session and session.predict. The key 'putalpha' is consumed here and selects putalpha_cutout over naive_cutout.

    Returns:
        Union[bytes, PILImage, np.ndarray]: The cutout image with the background removed.

    Raises:
        ValueError: If the input type is unsupported and force_return_bytes is False.
    """
    if isinstance(data, PILImage):
        # Checked before the bytes branch: a PIL image combined with
        # force_return_bytes used to be handed to io.BytesIO and fail with
        # TypeError instead of being processed and returned as bytes.
        return_type = ReturnType.BYTES if force_return_bytes else ReturnType.PILLOW
        img = cast(PILImage, data)
    elif isinstance(data, bytes) or force_return_bytes:
        return_type = ReturnType.BYTES
        img = cast(PILImage, Image.open(io.BytesIO(cast(bytes, data))))
    elif isinstance(data, np.ndarray):
        return_type = ReturnType.NDARRAY
        img = cast(PILImage, Image.fromarray(data))
    else:
        raise ValueError(
            "Input type {} is not supported. Try using force_return_bytes=True to force python bytes output".format(
                type(data)
            )
        )

    # Must be removed before kwargs are forwarded to the session.
    putalpha = kwargs.pop("putalpha", False)

    # Fix image orientation
    img = fix_image_orientation(img)

    if session is None:
        session = new_session("u2net", *args, **kwargs)

    def plain_cutout(mask: PILImage) -> PILImage:
        # Non-matting cutout; also the fallback when alpha matting fails.
        if putalpha:
            return putalpha_cutout(img, mask)
        return naive_cutout(img, mask)

    masks = session.predict(img, *args, **kwargs)
    cutouts = []

    for mask in masks:
        if post_process_mask:
            mask = Image.fromarray(post_process(np.array(mask)))

        if only_mask:
            cutout = mask
        elif alpha_matting:
            try:
                cutout = alpha_matting_cutout(
                    img,
                    mask,
                    alpha_matting_foreground_threshold,
                    alpha_matting_background_threshold,
                    alpha_matting_erode_size,
                )
            except ValueError:
                cutout = plain_cutout(mask)
        else:
            cutout = plain_cutout(mask)

        cutouts.append(cutout)

    # No masks at all: fall back to the (orientation-fixed) input image.
    cutout = get_concat_v_multi(cutouts) if cutouts else img

    if bgcolor is not None and not only_mask:
        cutout = apply_background_color(cutout, bgcolor)

    if ReturnType.PILLOW == return_type:
        return cutout

    if ReturnType.NDARRAY == return_type:
        return np.asarray(cutout)

    bio = io.BytesIO()
    cutout.save(bio, "PNG")
    return bio.getvalue()
33 | show_default=True, 34 | help="use alpha matting", 35 | ) 36 | @click.option( 37 | "-af", 38 | "--alpha-matting-foreground-threshold", 39 | default=240, 40 | type=int, 41 | show_default=True, 42 | help="trimap fg threshold", 43 | ) 44 | @click.option( 45 | "-ab", 46 | "--alpha-matting-background-threshold", 47 | default=10, 48 | type=int, 49 | show_default=True, 50 | help="trimap bg threshold", 51 | ) 52 | @click.option( 53 | "-ae", 54 | "--alpha-matting-erode-size", 55 | default=10, 56 | type=int, 57 | show_default=True, 58 | help="erode size", 59 | ) 60 | @click.option( 61 | "-om", 62 | "--only-mask", 63 | is_flag=True, 64 | show_default=True, 65 | help="output only the mask", 66 | ) 67 | @click.option( 68 | "-ppm", 69 | "--post-process-mask", 70 | is_flag=True, 71 | show_default=True, 72 | help="post process the mask", 73 | ) 74 | @click.option( 75 | "-bgc", 76 | "--bgcolor", 77 | default=(0, 0, 0, 0), 78 | type=(int, int, int, int), 79 | nargs=4, 80 | help="Background color (R G B A) to replace the removed background with", 81 | ) 82 | @click.option("-x", "--extras", type=str) 83 | @click.option( 84 | "-o", 85 | "--output_specifier", 86 | type=str, 87 | help="printf-style specifier for output filenames (e.g. 'output-%d.png'))", 88 | ) 89 | @click.argument( 90 | "image_width", 91 | type=int, 92 | ) 93 | @click.argument( 94 | "image_height", 95 | type=int, 96 | ) 97 | def b_command( 98 | model: str, 99 | extras: str, 100 | image_width: int, 101 | image_height: int, 102 | output_specifier: str, 103 | **kwargs 104 | ) -> None: 105 | """ 106 | Command-line interface for processing images by removing the background using a specified model and generating a mask. 107 | 108 | This CLI command takes several options and arguments to configure the background removal process and save the processed images. 109 | 110 | Parameters: 111 | model (str): The name of the model to use for background removal. 
112 | extras (str): Additional options in JSON format that can be passed to customize the background removal process. 113 | image_width (int): The width of the input images in pixels. 114 | image_height (int): The height of the input images in pixels. 115 | output_specifier (str): A printf-style specifier for the output filenames. If specified, the processed images will be saved to the specified output directory with filenames generated using the specifier. 116 | **kwargs: Additional keyword arguments that can be used to customize the background removal process. 117 | 118 | Returns: 119 | None 120 | """ 121 | if extras: 122 | try: 123 | kwargs.update(json.loads(extras)) 124 | except Exception: 125 | raise click.BadParameter("extras must be a valid JSON string") 126 | 127 | session = new_session(model, **kwargs) 128 | bytes_per_img = image_width * image_height * 3 129 | 130 | if output_specifier: 131 | output_dir = os.path.dirname( 132 | os.path.abspath(os.path.expanduser(output_specifier)) 133 | ) 134 | 135 | if not os.path.isdir(output_dir): 136 | os.makedirs(output_dir, exist_ok=True) 137 | 138 | def img_to_byte_array(img: PIL.Image.Image) -> bytes: 139 | buff = io.BytesIO() 140 | img.save(buff, format="PNG") 141 | return buff.getvalue() 142 | 143 | async def connect_stdin_stdout(): 144 | loop = asyncio.get_event_loop() 145 | reader = asyncio.StreamReader() 146 | protocol = asyncio.StreamReaderProtocol(reader) 147 | 148 | await loop.connect_read_pipe(lambda: protocol, sys.stdin) 149 | w_transport, w_protocol = await loop.connect_write_pipe( 150 | asyncio.streams.FlowControlMixin, sys.stdout 151 | ) 152 | 153 | writer = asyncio.StreamWriter(w_transport, w_protocol, reader, loop) 154 | return reader, writer 155 | 156 | async def main(): 157 | reader, writer = await connect_stdin_stdout() 158 | 159 | idx = 0 160 | while True: 161 | try: 162 | img_bytes = await reader.readexactly(bytes_per_img) 163 | if not img_bytes: 164 | break 165 | 166 | img = 
@click.command(  # type: ignore
    name="i",
    help="for a file as input",
)
@click.option(
    "-m",
    "--model",
    default="u2net",
    type=click.Choice(sessions_names),
    show_default=True,
    show_choices=True,
    help="model name",
)
@click.option(
    "-a",
    "--alpha-matting",
    is_flag=True,
    show_default=True,
    help="use alpha matting",
)
@click.option(
    "-af",
    "--alpha-matting-foreground-threshold",
    default=240,
    type=int,
    show_default=True,
    help="trimap fg threshold",
)
@click.option(
    "-ab",
    "--alpha-matting-background-threshold",
    default=10,
    type=int,
    show_default=True,
    help="trimap bg threshold",
)
@click.option(
    "-ae",
    "--alpha-matting-erode-size",
    default=10,
    type=int,
    show_default=True,
    help="erode size",
)
@click.option(
    "-om",
    "--only-mask",
    is_flag=True,
    show_default=True,
    help="output only the mask",
)
@click.option(
    "-ppm",
    "--post-process-mask",
    is_flag=True,
    show_default=True,
    help="post process the mask",
)
@click.option(
    "-bgc",
    "--bgcolor",
    default=(0, 0, 0, 0),
    type=(int, int, int, int),
    nargs=4,
    help="Background color (R G B A) to replace the removed background with",
)
@click.option("-x", "--extras", type=str)
@click.argument(
    "input", default=(None if sys.stdin.isatty() else "-"), type=click.File("rb")
)
@click.argument(
    "output",
    default=(None if sys.stdin.isatty() else "-"),
    type=click.File("wb", lazy=True),
)
def i_command(model: str, extras: str, input: IO, output: IO, **kwargs) -> None:
    """
    Remove the background from a single input file and write the result.

    The whole input stream is read, passed through `remove` with a session
    created for `model`, and the returned bytes are written to `output`.

    Parameters:
        model (str): The name of the model to use for image processing.
        extras (str): Optional JSON object whose keys are merged into the
            keyword arguments forwarded to `new_session` and `remove`.
        input: Binary file object to read the image from.
        output: Binary file object to write the processed image to.
        **kwargs: Remaining command line options, forwarded unchanged.

    Raises:
        click.BadParameter: If `extras` is given but cannot be parsed as JSON
            or cannot be merged into the keyword arguments.

    Returns:
        None
    """
    # Only parse when the option was supplied: json.loads(None) raises, and a
    # malformed value should be reported rather than silently ignored
    # (same contract as the sibling `b` command).
    if extras:
        try:
            kwargs.update(json.loads(extras))
        except (TypeError, ValueError) as err:
            raise click.BadParameter("extras must be a valid JSON string") from err

    session = new_session(model, **kwargs)
    output.write(remove(input.read(), session=session, **kwargs))
the mask", 74 | ) 75 | @click.option( 76 | "-w", 77 | "--watch", 78 | default=False, 79 | is_flag=True, 80 | show_default=True, 81 | help="watches a folder for changes", 82 | ) 83 | @click.option( 84 | "-d", 85 | "--delete_input", 86 | default=False, 87 | is_flag=True, 88 | show_default=True, 89 | help="delete input file after processing", 90 | ) 91 | @click.option( 92 | "-bgc", 93 | "--bgcolor", 94 | default=(0, 0, 0, 0), 95 | type=(int, int, int, int), 96 | nargs=4, 97 | help="Background color (R G B A) to replace the removed background with", 98 | ) 99 | @click.option("-x", "--extras", type=str) 100 | @click.argument( 101 | "input", 102 | type=click.Path( 103 | exists=True, 104 | path_type=pathlib.Path, 105 | file_okay=False, 106 | dir_okay=True, 107 | readable=True, 108 | ), 109 | ) 110 | @click.argument( 111 | "output", 112 | type=click.Path( 113 | exists=False, 114 | path_type=pathlib.Path, 115 | file_okay=False, 116 | dir_okay=True, 117 | writable=True, 118 | ), 119 | ) 120 | def p_command( 121 | model: str, 122 | extras: str, 123 | input: pathlib.Path, 124 | output: pathlib.Path, 125 | watch: bool, 126 | delete_input: bool, 127 | **kwargs, 128 | ) -> None: 129 | """ 130 | Command-line interface (CLI) program for performing background removal on images in a folder. 131 | 132 | This program takes a folder as input and uses a specified model to remove the background from the images in the folder. 133 | It provides various options for configuration, such as choosing the model, enabling alpha matting, setting trimap thresholds, erode size, etc. 134 | Additional options include outputting only the mask and post-processing the mask. 135 | The program can also watch the input folder for changes and automatically process new images. 136 | The resulting images with the background removed are saved in the specified output folder. 137 | 138 | Parameters: 139 | model (str): The name of the model to use for background removal. 
140 | extras (str): Additional options in JSON format. 141 | input (pathlib.Path): The path to the input folder. 142 | output (pathlib.Path): The path to the output folder. 143 | watch (bool): Whether to watch the input folder for changes. 144 | delete_input (bool): Whether to delete the input file after processing. 145 | **kwargs: Additional keyword arguments. 146 | 147 | Returns: 148 | None 149 | """ 150 | try: 151 | kwargs.update(json.loads(extras)) 152 | except Exception: 153 | pass 154 | 155 | session = new_session(model, **kwargs) 156 | 157 | def process(each_input: pathlib.Path) -> None: 158 | try: 159 | mimetype = filetype.guess(each_input) 160 | if mimetype is None: 161 | return 162 | if mimetype.mime.find("image") < 0: 163 | return 164 | 165 | each_output = (output / each_input.name).with_suffix(".png") 166 | each_output.parents[0].mkdir(parents=True, exist_ok=True) 167 | 168 | if not each_output.exists(): 169 | each_output.write_bytes( 170 | cast( 171 | bytes, 172 | remove(each_input.read_bytes(), session=session, **kwargs), 173 | ) 174 | ) 175 | 176 | if watch: 177 | print( 178 | f"processed: {each_input.absolute()} -> {each_output.absolute()}" 179 | ) 180 | 181 | if delete_input: 182 | each_input.unlink() 183 | 184 | except Exception as e: 185 | print(e) 186 | 187 | inputs = list(input.glob("**/*")) 188 | if not watch: 189 | inputs_tqdm = tqdm(inputs) 190 | 191 | for each_input in inputs_tqdm: 192 | if not each_input.is_dir(): 193 | process(each_input) 194 | 195 | if watch: 196 | should_watch = True 197 | observer = Observer() 198 | 199 | class EventHandler(FileSystemEventHandler): 200 | def on_any_event(self, event: FileSystemEvent) -> None: 201 | src_path = cast(str, event.src_path) 202 | if ( 203 | not ( 204 | event.is_directory or event.event_type in ["deleted", "closed"] 205 | ) 206 | and pathlib.Path(src_path).exists() 207 | ): 208 | if src_path.endswith("stop.txt"): 209 | nonlocal should_watch 210 | should_watch = False 211 | 
@click.command(  # type: ignore
    name="s",
    help="for a http server",
)
@click.option(
    "-p",
    "--port",
    default=7000,
    type=int,
    show_default=True,
    help="port",
)
@click.option(
    "-h",
    "--host",
    default="0.0.0.0",
    type=str,
    show_default=True,
    help="host",
)
@click.option(
    "-l",
    "--log_level",
    default="info",
    type=str,
    show_default=True,
    help="log level",
)
@click.option(
    "-t",
    "--threads",
    default=None,
    type=int,
    show_default=True,
    help="number of worker threads",
)
def s_command(port: int, host: str, log_level: str, threads: int) -> None:
    """
    Run the HTTP server exposing background removal as an API and a web UI.

    A FastAPI application is built with GET and POST `/api/remove` endpoints,
    a Gradio interface is mounted at `/`, and the result is served by uvicorn.

    Parameters:
        port (int): TCP port uvicorn listens on.
        host (str): Interface uvicorn binds to.
        log_level (str): Log level passed to uvicorn.
        threads (int): If not None, capacity of the default worker-thread
            limiter installed at application startup.

    Returns:
        None
    """
    # One model session per model name, shared by all API requests.
    sessions: dict[str, BaseSession] = {}
    tags_metadata = [
        {
            "name": "Background Removal",
            "description": "Endpoints that perform background removal with different image sources.",
            "externalDocs": {
                "description": "GitHub Source",
                "url": "https://github.com/danielgatis/rembg",
            },
        },
    ]
    app = FastAPI(
        title="Rembg",
        description="Rembg is a tool to remove images background. That is it.",
        version=get_versions()["version"],
        contact={
            "name": "Daniel Gatis",
            "url": "https://github.com/danielgatis",
            "email": "danielgatis@gmail.com",
        },
        license_info={
            "name": "MIT License",
            "url": "https://github.com/danielgatis/rembg/blob/main/LICENSE.txt",
        },
        openapi_tags=tags_metadata,
        docs_url="/api",
    )

    app.add_middleware(
        CORSMiddleware,
        allow_credentials=True,
        allow_origins=["*"],
        allow_methods=["*"],
        allow_headers=["*"],
    )

    # Options shared by the GET endpoint, all read from the query string.
    class CommonQueryParams:
        def __init__(
            self,
            model: str = Query(
                description="Model to use when processing image",
                regex=r"(" + "|".join(sessions_names) + ")",
                default="u2net",
            ),
            a: bool = Query(default=False, description="Enable Alpha Matting"),
            af: int = Query(
                default=240,
                ge=0,
                le=255,
                description="Alpha Matting (Foreground Threshold)",
            ),
            ab: int = Query(
                default=10,
                ge=0,
                le=255,
                description="Alpha Matting (Background Threshold)",
            ),
            ae: int = Query(
                default=10, ge=0, description="Alpha Matting (Erode Structure Size)"
            ),
            om: bool = Query(default=False, description="Only Mask"),
            ppm: bool = Query(default=False, description="Post Process Mask"),
            bgc: Optional[str] = Query(default=None, description="Background Color"),
            extras: Optional[str] = Query(
                default=None, description="Extra parameters as JSON"
            ),
        ):
            self.model = model
            self.a = a
            self.af = af
            self.ab = ab
            self.ae = ae
            self.om = om
            self.ppm = ppm
            self.extras = extras
            # "r,g,b,a" -> (r, g, b, a); an empty or missing value means no colour.
            self.bgc = (
                cast(Tuple[int, int, int, int], tuple(map(int, bgc.split(","))))
                if bgc
                else None
            )

    # Options shared by the POST endpoint, read from the form body.
    # NOTE(review): `bgc` and `extras` are still declared with Query, so they
    # are read from the query string - confirm whether that is intended.
    class CommonQueryPostParams:
        def __init__(
            self,
            model: str = Form(
                description="Model to use when processing image",
                regex=r"(" + "|".join(sessions_names) + ")",
                default="u2net",
            ),
            a: bool = Form(default=False, description="Enable Alpha Matting"),
            af: int = Form(
                default=240,
                ge=0,
                le=255,
                description="Alpha Matting (Foreground Threshold)",
            ),
            ab: int = Form(
                default=10,
                ge=0,
                le=255,
                description="Alpha Matting (Background Threshold)",
            ),
            ae: int = Form(
                default=10, ge=0, description="Alpha Matting (Erode Structure Size)"
            ),
            om: bool = Form(default=False, description="Only Mask"),
            ppm: bool = Form(default=False, description="Post Process Mask"),
            bgc: Optional[str] = Query(default=None, description="Background Color"),
            extras: Optional[str] = Query(
                default=None, description="Extra parameters as JSON"
            ),
        ):
            self.model = model
            self.a = a
            self.af = af
            self.ab = ab
            self.ae = ae
            self.om = om
            self.ppm = ppm
            self.extras = extras
            self.bgc = (
                cast(Tuple[int, int, int, int], tuple(map(int, bgc.split(","))))
                if bgc
                else None
            )

    def im_without_bg(content: bytes, commons: CommonQueryParams) -> Response:
        # Run background removal on `content` and wrap the result as a PNG response.
        kwargs = {}

        if commons.extras:
            try:
                kwargs.update(json.loads(commons.extras))
            except Exception:
                # Best effort: unparsable extras are ignored, not rejected.
                pass

        # Build a session only on a cache miss. Passing new_session(...) straight
        # to setdefault would construct a fresh session on every request, because
        # the default argument is evaluated before the lookup happens.
        session = sessions.get(commons.model)
        if session is None:
            session = sessions.setdefault(
                commons.model, new_session(commons.model, **kwargs)
            )

        return Response(
            remove(
                content,
                session=session,
                alpha_matting=commons.a,
                alpha_matting_foreground_threshold=commons.af,
                alpha_matting_background_threshold=commons.ab,
                alpha_matting_erode_size=commons.ae,
                only_mask=commons.om,
                post_process_mask=commons.ppm,
                bgcolor=commons.bgc,
                **kwargs,
            ),
            media_type="image/png",
        )

    @app.on_event("startup")
    def startup():
        try:
            webbrowser.open(f"http://localhost:{port}")
        except Exception:
            # Opening a browser is a convenience only; never block startup on it.
            pass

        if threads is not None:
            from anyio import CapacityLimiter
            from anyio.lowlevel import RunVar

            RunVar("_default_thread_limiter").set(CapacityLimiter(threads))

    @app.get(
        path="/api/remove",
        tags=["Background Removal"],
        summary="Remove from URL",
        description="Removes the background from an image obtained by retrieving an URL.",
    )
    async def get_index(
        url: str = Query(
            default=..., description="URL of the image that has to be processed."
        ),
        commons: CommonQueryParams = Depends(),
    ):
        # Fetch the remote image, then process it off the event loop.
        async with aiohttp.ClientSession() as session:
            async with session.get(url) as response:
                file = await response.read()
                return await asyncify(im_without_bg)(file, commons)

    @app.post(
        path="/api/remove",
        tags=["Background Removal"],
        summary="Remove from Stream",
        description="Removes the background from an image sent within the request itself.",
    )
    async def post_index(
        file: bytes = File(
            default=...,
            description="Image file (byte stream) that has to be processed.",
        ),
        commons: CommonQueryPostParams = Depends(),
    ):
        return await asyncify(im_without_bg)(file, commons)  # type: ignore

    def gr_app(app):
        # Mount the Gradio UI at "/" on top of the FastAPI app and return it.
        def inference(input_path, model, *args):
            # NOTE(review): every call writes to the same relative path while
            # the interface allows 3 concurrent calls - confirm this is safe.
            output_path = "output.png"
            a, af, ab, ae, om, ppm, cmd_args = args

            kwargs = {
                "alpha_matting": a,
                "alpha_matting_foreground_threshold": af,
                "alpha_matting_background_threshold": ab,
                "alpha_matting_erode_size": ae,
                "only_mask": om,
                "post_process_mask": ppm,
            }

            if cmd_args:
                kwargs.update(json.loads(cmd_args))
            kwargs["session"] = new_session(model, **kwargs)

            with open(input_path, "rb") as i:
                with open(output_path, "wb") as o:
                    input = i.read()
                    output = remove(input, **kwargs)
                    o.write(output)
            return os.path.join(output_path)

        interface = gr.Interface(
            inference,
            [
                gr.components.Image(type="filepath", label="Input"),
                gr.components.Dropdown(sessions_names, value="u2net", label="Models"),
                gr.components.Checkbox(value=True, label="Alpha matting"),
                gr.components.Slider(
                    value=240, minimum=0, maximum=255, label="Foreground threshold"
                ),
                gr.components.Slider(
                    value=10, minimum=0, maximum=255, label="Background threshold"
                ),
                gr.components.Slider(
                    value=40, minimum=0, maximum=255, label="Erosion size"
                ),
                gr.components.Checkbox(value=False, label="Only mask"),
                gr.components.Checkbox(value=True, label="Post process mask"),
                gr.components.Textbox(label="Arguments"),
            ],
            gr.components.Image(type="filepath", label="Output"),
            concurrency_limit=3,
            analytics_enabled=False,
        )

        app = gr.mount_gradio_app(app, interface, path="/")
        return app

    # The wildcard bind address is not browsable, so print localhost instead.
    display_host = "localhost" if host == "0.0.0.0" else host
    print(
        f"To access the API documentation, go to http://{display_host}:{port}/api"
    )
    print(f"To access the UI, go to http://{display_host}:{port}")

    uvicorn.run(gr_app(app), host=host, port=port, log_level=log_level)
24 | 25 | Raises: 26 | ValueError: If no session class with the given `model_name` is found. 27 | 28 | Returns: 29 | BaseSession: The created session object. 30 | """ 31 | session_class: Optional[Type[BaseSession]] = None 32 | 33 | for sc in sessions_class: 34 | if sc.name() == model_name: 35 | session_class = sc 36 | break 37 | 38 | if session_class is None: 39 | raise ValueError(f"No session class found for model '{model_name}'") 40 | 41 | sess_opts = ort.SessionOptions() 42 | 43 | if "OMP_NUM_THREADS" in os.environ: 44 | threads = int(os.environ["OMP_NUM_THREADS"]) 45 | sess_opts.inter_op_num_threads = threads 46 | sess_opts.intra_op_num_threads = threads 47 | 48 | return session_class(model_name, sess_opts, *args, **kwargs) 49 | -------------------------------------------------------------------------------- /rembg/sessions/__init__.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | from typing import Dict, List 4 | 5 | from .base import BaseSession 6 | 7 | sessions: Dict[str, type[BaseSession]] = {} 8 | 9 | from .birefnet_general import BiRefNetSessionGeneral 10 | 11 | sessions[BiRefNetSessionGeneral.name()] = BiRefNetSessionGeneral 12 | 13 | from .birefnet_general_lite import BiRefNetSessionGeneralLite 14 | 15 | sessions[BiRefNetSessionGeneralLite.name()] = BiRefNetSessionGeneralLite 16 | 17 | from .birefnet_portrait import BiRefNetSessionPortrait 18 | 19 | sessions[BiRefNetSessionPortrait.name()] = BiRefNetSessionPortrait 20 | 21 | from .birefnet_dis import BiRefNetSessionDIS 22 | 23 | sessions[BiRefNetSessionDIS.name()] = BiRefNetSessionDIS 24 | 25 | from .birefnet_hrsod import BiRefNetSessionHRSOD 26 | 27 | sessions[BiRefNetSessionHRSOD.name()] = BiRefNetSessionHRSOD 28 | 29 | from .birefnet_cod import BiRefNetSessionCOD 30 | 31 | sessions[BiRefNetSessionCOD.name()] = BiRefNetSessionCOD 32 | 33 | from .birefnet_massive import BiRefNetSessionMassive 34 | 35 | 
class BaseSession:
    """Shared scaffolding for model sessions backed by an ONNX Runtime inference session."""

    def __init__(self, model_name: str, sess_opts: ort.SessionOptions, *args, **kwargs):
        """Record the model name and open the inference session for this class's model file."""
        self.model_name = model_name

        # CPU is always the fallback; a GPU provider is put in front of it
        # when the runtime reports a GPU device and the provider is available.
        device = ort.get_device()
        providers = ["CPUExecutionProvider"]
        if device == "GPU" and "CUDAExecutionProvider" in ort.get_available_providers():
            providers.insert(0, "CUDAExecutionProvider")
        elif (
            device.startswith("GPU")
            and "ROCMExecutionProvider" in ort.get_available_providers()
        ):
            providers.insert(0, "ROCMExecutionProvider")

        model_path = str(type(self).download_models(*args, **kwargs))
        self.inner_session = ort.InferenceSession(
            model_path,
            sess_options=sess_opts,
            providers=providers,
        )

    def normalize(
        self,
        img: PILImage,
        mean: Tuple[float, float, float],
        std: Tuple[float, float, float],
        size: Tuple[int, int],
        *args,
        **kwargs
    ) -> Dict[str, np.ndarray]:
        """
        Turn an image into the input feed for the inner session.

        The image is converted to RGB, resized to `size`, scaled by its own
        maximum value, standardised per channel with `mean` and `std`, moved to
        channel-first layout and given a leading batch axis.

        Returns:
            Dict[str, np.ndarray]: The float32 array keyed by the name of the
            session's first input.
        """
        resized = img.convert("RGB").resize(size, Image.Resampling.LANCZOS)

        pixels = np.array(resized)
        # Scale by the brightest value; the floor avoids dividing by zero.
        pixels = pixels / max(np.max(pixels), 1e-6)

        standardized = np.zeros((pixels.shape[0], pixels.shape[1], 3))
        for channel in range(3):
            standardized[:, :, channel] = (
                pixels[:, :, channel] - mean[channel]
            ) / std[channel]

        channel_first = standardized.transpose((2, 0, 1))
        input_name = self.inner_session.get_inputs()[0].name

        return {input_name: np.expand_dims(channel_first, 0).astype(np.float32)}

    def predict(self, img: PILImage, *args, **kwargs) -> List[PILImage]:
        """Produce masks for `img`; subclasses must implement this."""
        raise NotImplementedError

    @classmethod
    def checksum_disabled(cls, *args, **kwargs):
        """Return True when the MODEL_CHECKSUM_DISABLED environment variable is set to any value."""
        return "MODEL_CHECKSUM_DISABLED" in os.environ

    @classmethod
    def u2net_home(cls, *args, **kwargs):
        """
        Return the directory used for model files.

        U2NET_HOME wins when set; otherwise `.u2net` under XDG_DATA_HOME, or
        under `~` when that is unset. The result is passed through expanduser.
        """
        fallback = os.path.join(os.getenv("XDG_DATA_HOME", "~"), ".u2net")
        return os.path.expanduser(os.getenv("U2NET_HOME", fallback))

    @classmethod
    def download_models(cls, *args, **kwargs):
        """Fetch the model file and return its path; subclasses must implement this."""
        raise NotImplementedError

    @classmethod
    def name(cls, *args, **kwargs):
        """Return the session's model name; subclasses must implement this."""
        raise NotImplementedError
class BiRefNetSessionGeneral(BaseSession):
    """
    This class represents a BiRefNet-General session, which is a subclass of BaseSession.
    """

    def sigmoid(self, mat):
        """Apply the logistic function element-wise to `mat`."""
        return 1 / (1 + np.exp(-mat))

    def predict(self, img: PILImage, *args, **kwargs) -> List[PILImage]:
        """
        Predicts the output masks for the input image using the inner session.

        The image is normalised to 1024x1024, run through the model, squashed
        with a sigmoid, min-max rescaled and resized back to the input size.

        Parameters:
            img (PILImage): The input image.
            *args: Additional positional arguments.
            **kwargs: Additional keyword arguments.

        Returns:
            List[PILImage]: A single-element list holding the "L" mode mask.
        """
        ort_outs = self.inner_session.run(
            None,
            self.normalize(
                img, (0.485, 0.456, 0.406), (0.229, 0.224, 0.225), (1024, 1024)
            ),
        )

        pred = self.sigmoid(ort_outs[0][:, 0, :, :])

        mi = np.min(pred)
        span = np.max(pred) - mi

        # A constant prediction map has zero range; dividing by it would give
        # NaN and an undefined uint8 cast, so emit an all-zero mask instead.
        if span > 0:
            pred = (pred - mi) / span
        else:
            pred = np.zeros_like(pred)
        pred = np.squeeze(pred)

        mask = Image.fromarray((pred * 255).astype("uint8"), mode="L")
        mask = mask.resize(img.size, Image.Resampling.LANCZOS)

        return [mask]

    @classmethod
    def download_models(cls, *args, **kwargs):
        """
        Downloads the BiRefNet-General model file from a specific URL and saves it.

        The md5 checksum is verified unless `checksum_disabled` returns True.

        Parameters:
            *args: Additional positional arguments.
            **kwargs: Additional keyword arguments.

        Returns:
            str: The path to the downloaded model file.
        """
        fname = f"{cls.name(*args, **kwargs)}.onnx"
        pooch.retrieve(
            "https://github.com/danielgatis/rembg/releases/download/v0.0.0/BiRefNet-general-epoch_244.onnx",
            (
                None
                if cls.checksum_disabled(*args, **kwargs)
                else "md5:7a35a0141cbbc80de11d9c9a28f52697"
            ),
            fname=fname,
            path=cls.u2net_home(*args, **kwargs),
            progressbar=True,
        )

        return os.path.join(cls.u2net_home(*args, **kwargs), fname)

    @classmethod
    def name(cls, *args, **kwargs):
        """
        Returns the name of the BiRefNet-General session.

        Parameters:
            *args: Additional positional arguments.
            **kwargs: Additional keyword arguments.

        Returns:
            str: The name of the session.
        """
        return "birefnet-general"
class BiRefNetSessionHRSOD(BiRefNetSessionGeneral):
    """
    BiRefNet-HRSOD session.

    Subclasses BiRefNetSessionGeneral and overrides only the session name and
    the model download.
    """

    @classmethod
    def name(cls, *args, **kwargs):
        """
        Returns the identifier of the BiRefNet-HRSOD session.

        Parameters:
            *args: Additional positional arguments (unused).
            **kwargs: Additional keyword arguments (unused).

        Returns:
            str: The name of the session.
        """
        return "birefnet-hrsod"

    @classmethod
    def download_models(cls, *args, **kwargs):
        """
        Fetches the BiRefNet-HRSOD ONNX file through pooch.

        The file is stored as "<name>.onnx" inside ``cls.u2net_home(...)``.
        No hash is supplied when ``cls.checksum_disabled(...)`` is truthy.

        Parameters:
            *args: Additional positional arguments.
            **kwargs: Additional keyword arguments.

        Returns:
            str: The path to the downloaded model file.
        """
        model_file = f"{cls.name(*args, **kwargs)}.onnx"
        if cls.checksum_disabled(*args, **kwargs):
            known_hash = None
        else:
            known_hash = "md5:c017ade5de8a50ff0fd74d790d268dda"

        pooch.retrieve(
            "https://github.com/danielgatis/rembg/releases/download/v0.0.0/BiRefNet-HRSOD_DHU-epoch_115.onnx",
            known_hash,
            fname=model_file,
            path=cls.u2net_home(*args, **kwargs),
            progressbar=True,
        )

        return os.path.join(cls.u2net_home(*args, **kwargs), model_file)


class BiRefNetSessionMassive(BiRefNetSessionGeneral):
    """
    BiRefNet-Massive session.

    Subclasses BiRefNetSessionGeneral and overrides only the session name and
    the model download.
    """

    @classmethod
    def name(cls, *args, **kwargs):
        """
        Returns the identifier of the BiRefNet-Massive session.

        Parameters:
            *args: Additional positional arguments (unused).
            **kwargs: Additional keyword arguments (unused).

        Returns:
            str: The name of the session.
        """
        return "birefnet-massive"

    @classmethod
    def download_models(cls, *args, **kwargs):
        """
        Fetches the BiRefNet-Massive ONNX file through pooch.

        The file is stored as "<name>.onnx" inside ``cls.u2net_home(...)``.
        No hash is supplied when ``cls.checksum_disabled(...)`` is truthy.

        Parameters:
            *args: Additional positional arguments.
            **kwargs: Additional keyword arguments.

        Returns:
            str: The path to the downloaded model file.
        """
        model_file = f"{cls.name(*args, **kwargs)}.onnx"
        if cls.checksum_disabled(*args, **kwargs):
            known_hash = None
        else:
            known_hash = "md5:33e726a2136a3d59eb0fdf613e31e3e9"

        pooch.retrieve(
            "https://github.com/danielgatis/rembg/releases/download/v0.0.0/BiRefNet-massive-TR_DIS5K_TR_TEs-epoch_420.onnx",
            known_hash,
            fname=model_file,
            path=cls.u2net_home(*args, **kwargs),
            progressbar=True,
        )

        return os.path.join(cls.u2net_home(*args, **kwargs), model_file)
class BiRefNetSessionPortrait(BiRefNetSessionGeneral):
    """
    BiRefNet-Portrait session.

    Subclasses BiRefNetSessionGeneral and overrides only the session name and
    the model download.
    """

    @classmethod
    def name(cls, *args, **kwargs):
        """
        Returns the identifier of the BiRefNet-Portrait session.

        Parameters:
            *args: Additional positional arguments (unused).
            **kwargs: Additional keyword arguments (unused).

        Returns:
            str: The name of the session.
        """
        return "birefnet-portrait"

    @classmethod
    def download_models(cls, *args, **kwargs):
        """
        Fetches the BiRefNet-Portrait ONNX file through pooch.

        The file is stored as "<name>.onnx" inside ``cls.u2net_home(...)``.
        No hash is supplied when ``cls.checksum_disabled(...)`` is truthy.

        Parameters:
            *args: Additional positional arguments.
            **kwargs: Additional keyword arguments.

        Returns:
            str: The path to the downloaded model file.
        """
        target_name = f"{cls.name(*args, **kwargs)}.onnx"
        expected = (
            None
            if cls.checksum_disabled(*args, **kwargs)
            else "md5:c3a64a6abf20250d090cd055f12a3b67"
        )
        pooch.retrieve(
            "https://github.com/danielgatis/rembg/releases/download/v0.0.0/BiRefNet-portrait-epoch_150.onnx",
            expected,
            fname=target_name,
            path=cls.u2net_home(*args, **kwargs),
            progressbar=True,
        )

        return os.path.join(cls.u2net_home(*args, **kwargs), target_name)
class BriaRmBgSession(BaseSession):
    """
    This class represents a Bria-rmbg-2.0 session, which is a subclass of BaseSession.
    """

    def predict(self, img: PILImage, *args, **kwargs) -> List[PILImage]:
        """
        Predicts the output masks for the input image using the inner session.

        The first model output is min-max rescaled, converted to an 8-bit "L"
        image and resized back to ``img.size``.

        Parameters:
            img (PILImage): The input image.
            *args: Additional positional arguments.
            **kwargs: Additional keyword arguments.

        Returns:
            List[PILImage]: The list of output masks (a single mask).
        """
        ort_outs = self.inner_session.run(
            None,
            self.normalize(
                img, (0.485, 0.456, 0.406), (0.229, 0.224, 0.225), (1024, 1024)
            ),
        )

        pred = ort_outs[0][:, 0, :, :]

        mi = np.min(pred)
        value_range = np.max(pred) - mi

        if value_range > 0:
            pred = (pred - mi) / value_range
        else:
            # A constant prediction has zero range; dividing would produce NaNs
            # whose uint8 cast is undefined, so emit an all-zero mask instead.
            pred = np.zeros_like(pred)
        pred = np.squeeze(pred)

        mask = Image.fromarray((pred * 255).astype("uint8"), mode="L")
        mask = mask.resize(img.size, Image.Resampling.LANCZOS)

        return [mask]

    @classmethod
    def download_models(cls, *args, **kwargs):
        """
        Downloads the BRIA-RMBG 2.0 model file from a specific URL and saves it.

        The file is stored as "<name>.onnx" inside ``cls.u2net_home(...)``. The
        SHA-256 hash is handed to pooch unless ``cls.checksum_disabled(...)``
        is truthy.

        Parameters:
            *args: Additional positional arguments.
            **kwargs: Additional keyword arguments.

        Returns:
            str: The path to the downloaded model file.
        """
        fname = f"{cls.name(*args, **kwargs)}.onnx"
        pooch.retrieve(
            "https://github.com/danielgatis/rembg/releases/download/v0.0.0/bria-rmbg-2.0.onnx",
            (
                None
                if cls.checksum_disabled(*args, **kwargs)
                else "sha256:5b486f08200f513f460da46dd701db5fbb47d79b4be4b708a19444bcd4e79958"
            ),
            fname=fname,
            path=cls.u2net_home(*args, **kwargs),
            progressbar=True,
        )

        return os.path.join(cls.u2net_home(*args, **kwargs), fname)

    @classmethod
    def name(cls, *args, **kwargs):
        """
        Returns the name of the Bria-rmbg session.

        Parameters:
            *args: Additional positional arguments.
            **kwargs: Additional keyword arguments.

        Returns:
            str: The name of the session.
        """
        return "bria-rmbg"
class DisSession(BaseSession):
    """
    This class represents a session for the "isnet-anime" segmentation model.
    """

    def predict(self, img: PILImage, *args, **kwargs) -> List[PILImage]:
        """
        Use a pre-trained model to predict the mask of the given image.

        The first model output is min-max rescaled, converted to an 8-bit "L"
        image and resized back to ``img.size``.

        Parameters:
            img (PILImage): The input image.
            *args: Variable length argument list.
            **kwargs: Arbitrary keyword arguments.

        Returns:
            List[PILImage]: A list of predicted mask images (a single mask).
        """
        ort_outs = self.inner_session.run(
            None,
            self.normalize(img, (0.485, 0.456, 0.406), (1.0, 1.0, 1.0), (1024, 1024)),
        )

        pred = ort_outs[0][:, 0, :, :]

        mi = np.min(pred)
        value_range = np.max(pred) - mi

        if value_range > 0:
            pred = (pred - mi) / value_range
        else:
            # A constant prediction has zero range; dividing would produce NaNs
            # whose uint8 cast is undefined, so emit an all-zero mask instead.
            pred = np.zeros_like(pred)
        pred = np.squeeze(pred)

        mask = Image.fromarray((pred * 255).astype("uint8"), mode="L")
        mask = mask.resize(img.size, Image.Resampling.LANCZOS)

        return [mask]

    @classmethod
    def download_models(cls, *args, **kwargs):
        """
        Download the pre-trained model.

        The file is stored as "<name>.onnx" inside ``cls.u2net_home(...)``. The
        MD5 hash is handed to pooch unless ``cls.checksum_disabled(...)`` is
        truthy.

        Parameters:
            *args: Variable length argument list.
            **kwargs: Arbitrary keyword arguments.

        Returns:
            str: The path of the downloaded model file.
        """
        fname = f"{cls.name(*args, **kwargs)}.onnx"
        pooch.retrieve(
            "https://github.com/danielgatis/rembg/releases/download/v0.0.0/isnet-anime.onnx",
            (
                None
                if cls.checksum_disabled(*args, **kwargs)
                else "md5:6f184e756bb3bd901c8849220a83e38e"
            ),
            fname=fname,
            path=cls.u2net_home(*args, **kwargs),
            progressbar=True,
        )

        return os.path.join(cls.u2net_home(*args, **kwargs), fname)

    @classmethod
    def name(cls, *args, **kwargs):
        """
        Get the name of the pre-trained model.

        Parameters:
            *args: Variable length argument list.
            **kwargs: Arbitrary keyword arguments.

        Returns:
            str: The name of the pre-trained model.
        """
        return "isnet-anime"
class DisSession(BaseSession):
    """
    This class represents a session for the "isnet-general-use" segmentation model.
    """

    def predict(self, img: PILImage, *args, **kwargs) -> List[PILImage]:
        """
        Predicts the mask image for the input image.

        The first model output is min-max rescaled, converted to an 8-bit "L"
        image and resized back to ``img.size``.

        Parameters:
            img (PILImage): The input image.
            *args: Additional positional arguments.
            **kwargs: Additional keyword arguments.

        Returns:
            List[PILImage]: A list holding the generated mask image.
        """
        ort_outs = self.inner_session.run(
            None,
            self.normalize(img, (0.5, 0.5, 0.5), (1.0, 1.0, 1.0), (1024, 1024)),
        )

        pred = ort_outs[0][:, 0, :, :]

        mi = np.min(pred)
        value_range = np.max(pred) - mi

        if value_range > 0:
            pred = (pred - mi) / value_range
        else:
            # A constant prediction has zero range; dividing would produce NaNs
            # whose uint8 cast is undefined, so emit an all-zero mask instead.
            pred = np.zeros_like(pred)
        pred = np.squeeze(pred)

        mask = Image.fromarray((pred * 255).astype("uint8"), mode="L")
        mask = mask.resize(img.size, Image.Resampling.LANCZOS)

        return [mask]

    @classmethod
    def download_models(cls, *args, **kwargs):
        """
        Downloads the pre-trained model file.

        This class method downloads the pre-trained model file from a specified
        URL using the pooch library. The file is stored as "<name>.onnx" inside
        ``cls.u2net_home(...)``; the MD5 hash is handed to pooch unless
        ``cls.checksum_disabled(...)`` is truthy.

        Parameters:
            args: Additional positional arguments.
            kwargs: Additional keyword arguments.

        Returns:
            str: The path to the downloaded model file.
        """
        fname = f"{cls.name(*args, **kwargs)}.onnx"
        pooch.retrieve(
            "https://github.com/danielgatis/rembg/releases/download/v0.0.0/isnet-general-use.onnx",
            (
                None
                if cls.checksum_disabled(*args, **kwargs)
                else "md5:fc16ebd8b0c10d971d3513d564d01e29"
            ),
            fname=fname,
            path=cls.u2net_home(*args, **kwargs),
            progressbar=True,
        )

        return os.path.join(cls.u2net_home(*args, **kwargs), fname)

    @classmethod
    def name(cls, *args, **kwargs):
        """
        Returns the name of the model.

        Parameters:
            args: Additional positional arguments.
            kwargs: Additional keyword arguments.

        Returns:
            str: The name of the model.
        """
        return "isnet-general-use"
def get_preprocess_shape(oldh: int, oldw: int, long_side_length: int):
    """
    Scale a (height, width) pair so its longer side equals ``long_side_length``.

    Both sides are multiplied by the same factor and rounded half-up.

    Returns:
        tuple: ``(new_height, new_width)`` as integers.
    """
    ratio = long_side_length * 1.0 / max(oldh, oldw)
    return (int(oldh * ratio + 0.5), int(oldw * ratio + 0.5))


def apply_coords(coords: np.ndarray, original_size, target_length):
    """
    Rescale point coordinates to the size given by ``get_preprocess_shape``.

    Parameters:
        coords (np.ndarray): Array whose last axis holds ``(x, y)``.
        original_size: ``(height, width)`` the coordinates refer to.
        target_length: Desired length of the longer image side.

    Returns:
        np.ndarray: A float copy of ``coords`` with x scaled by the width
        ratio and y by the height ratio; the input array is not modified.
    """
    src_h, src_w = original_size
    dst_h, dst_w = get_preprocess_shape(
        original_size[0], original_size[1], target_length
    )

    scaled = deepcopy(coords).astype(float)
    scaled[..., 0] *= dst_w / src_w
    scaled[..., 1] *= dst_h / src_h

    return scaled


def get_input_points(prompt):
    """
    Turn a prompt (list of mark dicts) into point and label arrays.

    A "point" mark contributes its ``data`` with its own ``label``. A
    "rectangle" mark contributes two points, ``data[0:2]`` and ``data[2:4]``,
    labelled 2 and 3. Marks of any other type are ignored.

    Returns:
        tuple: ``(points, labels)`` as NumPy arrays.
    """
    coords = []
    tags = []

    for mark in prompt:
        kind = mark["type"]
        if kind == "point":
            coords.append(mark["data"])
            tags.append(mark["label"])
        elif kind == "rectangle":
            box = mark["data"]
            coords.extend(([box[0], box[1]], [box[2], box[3]]))
            tags.extend((2, 3))

    return np.array(coords), np.array(tags)


def transform_masks(masks, original_size, transform_matrix):
    """
    Warp every mask with ``cv2.warpAffine`` into the original image size.

    Parameters:
        masks: Array indexed as ``masks[batch, mask_id]``.
        original_size: ``(height, width)`` of the output masks.
        transform_matrix: Matrix whose first two rows form the affine transform.

    Returns:
        np.ndarray: The warped masks, nested as ``[batch][mask_id]``.
    """
    return np.array(
        [
            [
                cv2.warpAffine(
                    masks[batch, mask_id],
                    transform_matrix[:2],
                    (original_size[1], original_size[0]),
                    flags=cv2.INTER_LINEAR,
                )
                for mask_id in range(masks.shape[1])
            ]
            for batch in range(masks.shape[0])
        ]
    )
class SamSession(BaseSession):
    """
    This class represents a session for the Sam model.

    It holds two ONNX Runtime sessions, ``self.encoder`` and ``self.decoder``,
    created from the files returned by ``download_models``.

    Args:
        model_name (str): The name of the model.
        sess_opts (ort.SessionOptions): The session options.
        *args: Variable length argument list.
        **kwargs: Arbitrary keyword arguments.
    """

    def __init__(
        self,
        model_name: str,
        sess_opts: ort.SessionOptions,
        *args,
        **kwargs,
    ):
        """
        Initialize a new SamSession with the given model name and session options.

        Args:
            model_name (str): The name of the model.
            sess_opts (ort.SessionOptions): The session options.
            *args: Forwarded to ``download_models``.
            **kwargs: Forwarded to ``download_models``.
        """
        self.model_name = model_name

        paths = self.__class__.download_models(*args, **kwargs)
        self.encoder = ort.InferenceSession(
            str(paths[0]),
            sess_options=sess_opts,
        )
        self.decoder = ort.InferenceSession(
            str(paths[1]),
            sess_options=sess_opts,
        )

    def predict(
        self,
        img: PILImage,
        *args,
        **kwargs,
    ) -> List[PILImage]:
        """
        Predict masks for an input image.

        The image is converted to RGB, scaled into a 684x1024 canvas and run
        through the encoder. The prompt (``sam_prompt`` keyword, defaulting to
        a single positive point at the image centre) is validated, converted
        to point/label arrays, padded with one extra point labelled -1 and
        passed to the decoder together with an empty mask input. The decoder
        masks are warped back to the original image size and merged into one
        binary mask.

        Parameters:
            img (PILImage): The input image.
            *args: Additional arguments.
            **kwargs: Additional keyword arguments; ``sam_prompt`` is read.

        Returns:
            List[PILImage]: A list holding the merged "L"-mode mask.
        """
        prompt = kwargs.get(
            "sam_prompt",
            [
                {
                    "type": "point",
                    "label": 1,
                    "data": [int(img.width / 2), int(img.height / 2)],
                }
            ],
        )
        schema = {
            "type": "array",
            "items": {
                "type": "object",
                "properties": {
                    "type": {"type": "string"},
                    "label": {"type": "integer"},
                    "data": {
                        "type": "array",
                        "items": {"type": "number"},
                    },
                },
            },
        }

        validate(instance=prompt, schema=schema)

        target_size = 1024
        input_size = (684, 1024)
        encoder_input_name = self.encoder.get_inputs()[0].name

        img = img.convert("RGB")
        cv_image = np.array(img)
        original_size = cv_image.shape[:2]

        # Uniform scale that fits the image inside the encoder input canvas.
        scale_x = input_size[1] / cv_image.shape[1]
        scale_y = input_size[0] / cv_image.shape[0]
        scale = min(scale_x, scale_y)

        transform_matrix = np.array(
            [
                [scale, 0, 0],
                [0, scale, 0],
                [0, 0, 1],
            ]
        )

        cv_image = cv2.warpAffine(
            cv_image,
            transform_matrix[:2],
            (input_size[1], input_size[0]),
            flags=cv2.INTER_LINEAR,
        )

        ## encoder

        encoder_inputs = {
            encoder_input_name: cv_image.astype(np.float32),
        }

        encoder_output = self.encoder.run(None, encoder_inputs)
        image_embedding = encoder_output[0]

        ## decoder

        input_points, input_labels = get_input_points(prompt)
        # Append one padding point (label -1) after the user-supplied points.
        onnx_coord = np.concatenate([input_points, np.array([[0.0, 0.0]])], axis=0)[
            None, :, :
        ]
        onnx_label = np.concatenate([input_labels, np.array([-1])], axis=0)[
            None, :
        ].astype(np.float32)
        onnx_coord = apply_coords(onnx_coord, input_size, target_size).astype(
            np.float32
        )

        # Homogeneous coordinates so the same affine transform as the image
        # can be applied to the prompt points.
        onnx_coord = np.concatenate(
            [
                onnx_coord,
                np.ones((1, onnx_coord.shape[1], 1), dtype=np.float32),
            ],
            axis=2,
        )
        onnx_coord = np.matmul(onnx_coord, transform_matrix.T)
        onnx_coord = onnx_coord[:, :, :2].astype(np.float32)

        onnx_mask_input = np.zeros((1, 1, 256, 256), dtype=np.float32)
        onnx_has_mask_input = np.zeros(1, dtype=np.float32)

        decoder_inputs = {
            "image_embeddings": image_embedding,
            "point_coords": onnx_coord,
            "point_labels": onnx_label,
            "mask_input": onnx_mask_input,
            "has_mask_input": onnx_has_mask_input,
            "orig_im_size": np.array(input_size, dtype=np.float32),
        }

        masks, _, _ = self.decoder.run(None, decoder_inputs)
        inv_transform_matrix = np.linalg.inv(transform_matrix)
        masks = transform_masks(masks, original_size, inv_transform_matrix)

        # Union of all positive decoder masks of the first batch entry.
        mask = np.zeros((masks.shape[2], masks.shape[3], 3), dtype=np.uint8)
        for m in masks[0, :, :, :]:
            mask[m > 0.0] = [255, 255, 255]

        return [Image.fromarray(mask).convert("L")]

    @classmethod
    def download_models(cls, *args, **kwargs):
        """
        Class method to download ONNX model files.

        Two files, "<model>.encoder.onnx" and "<model>.decoder.onnx" (or the
        ".quant.onnx" variants when ``sam_quant`` is truthy), are fetched with
        pooch into ``cls.u2net_home(...)``; ``<model>`` comes from the
        ``sam_model`` keyword. No checksum is passed to pooch for these files.

        For the non-quantized "sam_vit_h_4b8939" encoder, the three
        "encoder_data.N.bin" parts are additionally downloaded and concatenated
        into "sam_vit_h_4b8939.encoder_data.bin" unless that file already
        exists. Parts are streamed into a temporary ".part" file that is
        renamed once complete, so an interrupted run never leaves a truncated
        file that a later run would mistake for a finished one.

        Parameters:
            *args: Variable length argument list.
            **kwargs: Arbitrary keyword arguments (``sam_model``, ``sam_quant``).

        Returns:
            tuple: The file paths of the downloaded encoder and decoder models.
        """
        # Local import: only needed while assembling the split weight file.
        import shutil

        model_name = kwargs.get("sam_model", "sam_vit_b_01ec64")
        quant = kwargs.get("sam_quant", False)

        suffix = ".quant.onnx" if quant else ".onnx"
        fname_encoder = f"{model_name}.encoder{suffix}"
        fname_decoder = f"{model_name}.decoder{suffix}"

        home = cls.u2net_home(*args, **kwargs)
        release_url = "https://github.com/danielgatis/rembg/releases/download/v0.0.0"

        for fname in (fname_encoder, fname_decoder):
            pooch.retrieve(
                f"{release_url}/{fname}",
                None,
                fname=fname,
                path=home,
                progressbar=True,
            )

        data_path = os.path.join(home, "sam_vit_h_4b8939.encoder_data.bin")
        if fname_encoder == "sam_vit_h_4b8939.encoder.onnx" and not os.path.exists(
            data_path
        ):
            partial_path = f"{data_path}.part"

            with open(partial_path, "wb") as merged:
                for i in range(1, 4):
                    part_name = f"sam_vit_h_4b8939.encoder_data.{i}.bin"
                    pooch.retrieve(
                        f"{release_url}/{part_name}",
                        None,
                        fname=part_name,
                        path=home,
                        progressbar=True,
                    )

                    part_path = os.path.join(home, part_name)
                    # Stream each part so neither the part nor the merged
                    # file is ever held in memory as a whole.
                    with open(part_path, "rb") as part:
                        shutil.copyfileobj(part, merged)
                    os.remove(part_path)

            os.replace(partial_path, data_path)

        return (
            os.path.join(home, fname_encoder),
            os.path.join(home, fname_decoder),
        )

    @classmethod
    def name(cls, *args, **kwargs):
        """
        Class method to return the session name.

        Parameters:
            *args: Variable length argument list.
            **kwargs: Arbitrary keyword arguments.

        Returns:
            str: The string value 'sam'.
        """
        return "sam"
class SiluetaSession(BaseSession):
    """This is a class representing a SiluetaSession object."""

    def predict(self, img: PILImage, *args, **kwargs) -> List[PILImage]:
        """
        Predict the mask of the input image.

        The first model output is min-max rescaled, converted to an 8-bit "L"
        image and resized back to ``img.size``.

        Parameters:
            img (PILImage): The input image to be processed.
            *args: Variable length argument list.
            **kwargs: Arbitrary keyword arguments.

        Returns:
            List[PILImage]: A list of post-processed masks (a single mask).
        """
        ort_outs = self.inner_session.run(
            None,
            self.normalize(
                img, (0.485, 0.456, 0.406), (0.229, 0.224, 0.225), (320, 320)
            ),
        )

        pred = ort_outs[0][:, 0, :, :]

        mi = np.min(pred)
        value_range = np.max(pred) - mi

        if value_range > 0:
            pred = (pred - mi) / value_range
        else:
            # A constant prediction has zero range; dividing would produce NaNs
            # whose uint8 cast is undefined, so emit an all-zero mask instead.
            pred = np.zeros_like(pred)
        pred = np.squeeze(pred)

        mask = Image.fromarray((pred * 255).astype("uint8"), mode="L")
        mask = mask.resize(img.size, Image.Resampling.LANCZOS)

        return [mask]

    @classmethod
    def download_models(cls, *args, **kwargs):
        """
        Download the pre-trained model file.

        The file is stored as "<name>.onnx" inside ``cls.u2net_home(...)``. The
        MD5 hash is handed to pooch unless ``cls.checksum_disabled(...)`` is
        truthy.

        Parameters:
            *args: Variable length argument list.
            **kwargs: Arbitrary keyword arguments.

        Returns:
            str: The path to the downloaded model file.
        """
        # Forward the arguments to name() like every other session does.
        fname = f"{cls.name(*args, **kwargs)}.onnx"
        pooch.retrieve(
            "https://github.com/danielgatis/rembg/releases/download/v0.0.0/silueta.onnx",
            (
                None
                if cls.checksum_disabled(*args, **kwargs)
                else "md5:55e59e0d8062d2f5d013f4725ee84782"
            ),
            fname=fname,
            path=cls.u2net_home(*args, **kwargs),
            progressbar=True,
        )

        return os.path.join(cls.u2net_home(*args, **kwargs), fname)

    @classmethod
    def name(cls, *args, **kwargs):
        """
        Return the name of the model.

        Parameters:
            *args: Variable length argument list.
            **kwargs: Arbitrary keyword arguments.

        Returns:
            str: The name of the model.
        """
        return "silueta"
class U2netSession(BaseSession):
    """
    This class represents a U2net session, which is a subclass of BaseSession.
    """

    def predict(self, img: PILImage, *args, **kwargs) -> List[PILImage]:
        """
        Predicts the output masks for the input image using the inner session.

        The first model output is min-max rescaled, clipped to [0, 1],
        converted to an 8-bit "L" image and resized back to ``img.size``.

        Parameters:
            img (PILImage): The input image.
            *args: Additional positional arguments.
            **kwargs: Additional keyword arguments.

        Returns:
            List[PILImage]: The list of output masks (a single mask).
        """
        ort_outs = self.inner_session.run(
            None,
            self.normalize(
                img, (0.485, 0.456, 0.406), (0.229, 0.224, 0.225), (320, 320)
            ),
        )

        pred = ort_outs[0][:, 0, :, :]

        mi = np.min(pred)
        value_range = np.max(pred) - mi

        if value_range > 0:
            pred = (pred - mi) / value_range
        else:
            # A constant prediction has zero range; dividing would produce NaNs
            # (which clip() does not remove), so emit an all-zero mask instead.
            pred = np.zeros_like(pred)
        pred = np.squeeze(pred)

        mask = Image.fromarray((pred.clip(0, 1) * 255).astype("uint8"), mode="L")
        mask = mask.resize(img.size, Image.Resampling.LANCZOS)

        return [mask]

    @classmethod
    def download_models(cls, *args, **kwargs):
        """
        Downloads the U2net model file from a specific URL and saves it.

        The file is stored as "<name>.onnx" inside ``cls.u2net_home(...)``. The
        MD5 hash is handed to pooch unless ``cls.checksum_disabled(...)`` is
        truthy.

        Parameters:
            *args: Additional positional arguments.
            **kwargs: Additional keyword arguments.

        Returns:
            str: The path to the downloaded model file.
        """
        fname = f"{cls.name(*args, **kwargs)}.onnx"
        pooch.retrieve(
            "https://github.com/danielgatis/rembg/releases/download/v0.0.0/u2net.onnx",
            (
                None
                if cls.checksum_disabled(*args, **kwargs)
                else "md5:60024c5c889badc19c04ad937298a77b"
            ),
            fname=fname,
            path=cls.u2net_home(*args, **kwargs),
            progressbar=True,
        )

        return os.path.join(cls.u2net_home(*args, **kwargs), fname)

    @classmethod
    def name(cls, *args, **kwargs):
        """
        Returns the name of the U2net session.

        Parameters:
            *args: Additional positional arguments.
            **kwargs: Additional keyword arguments.

        Returns:
            str: The name of the session.
        """
        return "u2net"
# Flat RGB palettes for label values 0-3: each maps exactly one label to
# white (255, 255, 255) and every other label to black.
palette1 = [0, 0, 0, 255, 255, 255, 0, 0, 0, 0, 0, 0]  # label 1 -> white

palette2 = [0, 0, 0, 0, 0, 0, 255, 255, 255, 0, 0, 0]  # label 2 -> white

palette3 = [0, 0, 0, 0, 0, 0, 0, 0, 0, 255, 255, 255]  # label 3 -> white


class Unet2ClothSession(BaseSession):
    """Session for the "u2net_cloth_seg" clothing segmentation model."""

    def predict(self, img: PILImage, *args, **kwargs) -> List[PILImage]:
        """
        Predict clothing masks for an image.

        The model output is reduced to a per-pixel label map (argmax over the
        log-softmax scores), turned into an "L" image and resized to
        ``img.size``. One mask is then rendered per requested category by
        applying the matching palette.

        Parameters:
            img (PILImage): The input image.
            *args: Additional positional arguments.
            **kwargs: Additional keyword arguments. ``cc`` or
                ``cloth_category`` may be "upper", "lower" or "full"; any
                other value (or none) yields all three masks in that order.

        Returns:
            List[PILImage]: A list of images representing the predicted masks.
        """
        ort_outs = self.inner_session.run(
            None,
            self.normalize(
                img, (0.485, 0.456, 0.406), (0.229, 0.224, 0.225), (768, 768)
            ),
        )

        scores = log_softmax(ort_outs[0], 1)
        labels = np.argmax(scores, axis=1, keepdims=True)
        labels = np.squeeze(np.squeeze(labels, 0), 0)

        label_img = Image.fromarray(labels.astype("uint8"), mode="L")
        label_img = label_img.resize(img.size, Image.Resampling.LANCZOS)

        cloth_category = kwargs.get("cc") or kwargs.get("cloth_category")

        def render(palette):
            # Colour a copy of the label map, then flatten back to grayscale.
            layer = label_img.copy()
            layer.putpalette(palette)
            return layer.convert("RGB").convert("L")

        by_category = (
            ("upper", palette1),
            ("lower", palette2),
            ("full", palette3),
        )
        chosen = next(
            (pal for key, pal in by_category if cloth_category == key), None
        )
        if chosen is not None:
            return [render(chosen)]

        return [render(pal) for _, pal in by_category]

    @classmethod
    def download_models(cls, *args, **kwargs):
        """
        Fetches the u2net_cloth_seg ONNX file through pooch.

        The file is stored as "<name>.onnx" inside ``cls.u2net_home(...)``.
        No hash is supplied when ``cls.checksum_disabled(...)`` is truthy.

        Returns:
            str: The path to the downloaded model file.
        """
        model_file = f"{cls.name(*args, **kwargs)}.onnx"
        known_hash = (
            None
            if cls.checksum_disabled(*args, **kwargs)
            else "md5:2434d1f3cb744e0e49386c906e5a08bb"
        )
        pooch.retrieve(
            "https://github.com/danielgatis/rembg/releases/download/v0.0.0/u2net_cloth_seg.onnx",
            known_hash,
            fname=model_file,
            path=cls.u2net_home(*args, **kwargs),
            progressbar=True,
        )

        return os.path.join(cls.u2net_home(*args, **kwargs), model_file)

    @classmethod
    def name(cls, *args, **kwargs):
        """Returns the session name, "u2net_cloth_seg"."""
        return "u2net_cloth_seg"
class U2netHumanSegSession(BaseSession):
    """
    This class represents a session for performing human segmentation using the U2Net model.
    """

    def predict(self, img: PILImage, *args, **kwargs) -> List[PILImage]:
        """
        Predicts human segmentation masks for the input image.

        The image is passed through ``self.normalize`` (with the per-channel
        mean/std constants below and a 320x320 target size) and fed to the
        inner ONNX session. Channel 0 of the first output is min-max scaled,
        converted to an 8-bit grayscale image and resized back to the size of
        the input image.

        Parameters:
            img (PILImage): The input image.
            *args: Variable length argument list.
            **kwargs: Arbitrary keyword arguments.

        Returns:
            List[PILImage]: A list of predicted masks (a single "L" mode mask).
        """
        ort_outs = self.inner_session.run(
            None,
            self.normalize(
                img, (0.485, 0.456, 0.406), (0.229, 0.224, 0.225), (320, 320)
            ),
        )

        # Keep only channel 0 of the first model output.
        pred = ort_outs[0][:, 0, :, :]

        ma = np.max(pred)
        mi = np.min(pred)
        value_range = ma - mi

        if value_range > 0:
            pred = (pred - mi) / value_range
        else:
            # A constant (or non-finite) prediction would make the min-max
            # scaling a 0/0 division; the NaNs it produces have no defined
            # uint8 value. Emit an empty mask instead.
            pred = np.zeros_like(pred)
        pred = np.squeeze(pred)

        mask = Image.fromarray((pred * 255).astype("uint8"), mode="L")
        mask = mask.resize(img.size, Image.Resampling.LANCZOS)

        return [mask]

    @classmethod
    def download_models(cls, *args, **kwargs):
        """
        Downloads the U2Net model weights.

        The file is fetched with ``pooch.retrieve`` into the directory
        returned by ``cls.u2net_home``; the MD5 checksum is only passed when
        ``cls.checksum_disabled`` returns a falsy value.

        Parameters:
            *args: Variable length argument list.
            **kwargs: Arbitrary keyword arguments.

        Returns:
            str: The path to the downloaded model weights.
        """
        fname = f"{cls.name(*args, **kwargs)}.onnx"
        pooch.retrieve(
            "https://github.com/danielgatis/rembg/releases/download/v0.0.0/u2net_human_seg.onnx",
            (
                None
                if cls.checksum_disabled(*args, **kwargs)
                else "md5:c09ddc2e0104f800e3e1bb4652583d1f"
            ),
            fname=fname,
            path=cls.u2net_home(*args, **kwargs),
            progressbar=True,
        )

        return os.path.join(cls.u2net_home(*args, **kwargs), fname)

    @classmethod
    def name(cls, *args, **kwargs):
        """
        Returns the name of the U2Net model.

        Parameters:
            *args: Variable length argument list.
            **kwargs: Arbitrary keyword arguments.

        Returns:
            str: The name of the model.
        """
        return "u2net_human_seg"
52 | """ 53 | fname = f"{cls.name(*args, **kwargs)}.onnx" 54 | pooch.retrieve( 55 | "https://github.com/danielgatis/rembg/releases/download/v0.0.0/u2netp.onnx", 56 | ( 57 | None 58 | if cls.checksum_disabled(*args, **kwargs) 59 | else "md5:8e83ca70e441ab06c318d82300c84806" 60 | ), 61 | fname=fname, 62 | path=cls.u2net_home(*args, **kwargs), 63 | progressbar=True, 64 | ) 65 | 66 | return os.path.join(cls.u2net_home(*args, **kwargs), fname) 67 | 68 | @classmethod 69 | def name(cls, *args, **kwargs): 70 | """ 71 | Returns the name of the U2netp model. 72 | 73 | Returns: 74 | str: The name of the model. 75 | """ 76 | return "u2netp" 77 | -------------------------------------------------------------------------------- /setup.cfg: -------------------------------------------------------------------------------- 1 | [metadata] 2 | # This includes the license file(s) in the wheel. 3 | # https://wheel.readthedocs.io/en/stable/user_guide.html#including-license-files-in-the-generated-wheel-file 4 | license_files = LICENSE.txt 5 | 6 | # See the docstring in versioneer.py for instructions. Note that you must 7 | # re-run 'versioneer.py setup' after changing this section, and commit the 8 | # resulting files. 
"""Packaging script for rembg: declares metadata, dependencies and the CLI entry point."""
import os
import pathlib
import sys

# Extend the import path with this file's directory before importing
# versioneer. NOTE(review): presumably so `import versioneer` resolves a
# versioneer.py located next to setup.py - confirm.
sys.path.append(os.path.dirname(__file__))
from setuptools import find_packages, setup

import versioneer

PROJECT_ROOT = pathlib.Path(__file__).parent.resolve()

# The README doubles as the long description shown on the package index.
README_TEXT = PROJECT_ROOT.joinpath("README.md").read_text(encoding="utf-8")

# Packages required for every installation.
RUNTIME_REQUIREMENTS = [
    "jsonschema",
    "numpy",
    "opencv-python-headless",
    "pillow",
    "pooch",
    "pymatting",
    "scikit-image",
    "scipy",
    "tqdm",
]

# Optional dependency groups, selected with e.g. `rembg[cpu,cli]`.
DEV_REQUIREMENTS = [
    "bandit",
    "black",
    "flake8",
    "imagehash",
    "isort",
    "mypy",
    "pytest",
    "setuptools",
    "twine",
    "wheel",
]
CLI_REQUIREMENTS = [
    "aiohttp",
    "asyncer",
    "click",
    "fastapi",
    "filetype",
    "gradio",
    "python-multipart",
    "uvicorn",
    "watchdog",
]
OPTIONAL_REQUIREMENTS = {
    "dev": DEV_REQUIREMENTS,
    "cpu": ["onnxruntime"],
    "gpu": ["onnxruntime-gpu"],
    "rocm": ["onnxruntime-rocm"],
    "cli": CLI_REQUIREMENTS,
}

# Installs the `rembg` console command, backed by rembg.cli:main.
CONSOLE_ENTRY_POINTS = {
    "console_scripts": [
        "rembg=rembg.cli:main",
    ],
}

TROVE_CLASSIFIERS = [
    "License :: OSI Approved :: MIT License",
    "Topic :: Scientific/Engineering",
    "Topic :: Scientific/Engineering :: Mathematics",
    "Topic :: Scientific/Engineering :: Artificial Intelligence",
    "Topic :: Software Development",
    "Topic :: Software Development :: Libraries",
    "Topic :: Software Development :: Libraries :: Python Modules",
    "Programming Language :: Python",
    "Programming Language :: Python :: 3 :: Only",
    "Programming Language :: Python :: 3.10",
    "Programming Language :: Python :: 3.11",
    "Programming Language :: Python :: 3.12",
    "Programming Language :: Python :: 3.13",
]


setup(
    name="rembg",
    # Version and build commands are both supplied by versioneer.
    version=versioneer.get_version(),
    cmdclass=versioneer.get_cmdclass(),
    description="Remove image background",
    long_description=README_TEXT,
    long_description_content_type="text/markdown",
    url="https://github.com/danielgatis/rembg",
    author="Daniel Gatis",
    author_email="danielgatis@gmail.com",
    classifiers=TROVE_CLASSIFIERS,
    keywords="remove, background, u2net",
    python_requires=">=3.10, <3.14",
    packages=find_packages(),
    install_requires=RUNTIME_REQUIREMENTS,
    extras_require=OPTIONAL_REQUIREMENTS,
    entry_points=CONSOLE_ENTRY_POINTS,
)
-------------------------------------------------------------------------------- /tests/results/anime-girl-1.birefnet-cod.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/danielgatis/rembg/bc1436cad8dd2c94aa396604f9afdc2dde95cf55/tests/results/anime-girl-1.birefnet-cod.png -------------------------------------------------------------------------------- /tests/results/anime-girl-1.birefnet-dis.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/danielgatis/rembg/bc1436cad8dd2c94aa396604f9afdc2dde95cf55/tests/results/anime-girl-1.birefnet-dis.png -------------------------------------------------------------------------------- /tests/results/anime-girl-1.birefnet-general-lite.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/danielgatis/rembg/bc1436cad8dd2c94aa396604f9afdc2dde95cf55/tests/results/anime-girl-1.birefnet-general-lite.png -------------------------------------------------------------------------------- /tests/results/anime-girl-1.birefnet-general.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/danielgatis/rembg/bc1436cad8dd2c94aa396604f9afdc2dde95cf55/tests/results/anime-girl-1.birefnet-general.png -------------------------------------------------------------------------------- /tests/results/anime-girl-1.birefnet-hrsod.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/danielgatis/rembg/bc1436cad8dd2c94aa396604f9afdc2dde95cf55/tests/results/anime-girl-1.birefnet-hrsod.png -------------------------------------------------------------------------------- /tests/results/anime-girl-1.birefnet-massive.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/danielgatis/rembg/bc1436cad8dd2c94aa396604f9afdc2dde95cf55/tests/results/anime-girl-1.birefnet-massive.png -------------------------------------------------------------------------------- /tests/results/anime-girl-1.birefnet-portrait.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/danielgatis/rembg/bc1436cad8dd2c94aa396604f9afdc2dde95cf55/tests/results/anime-girl-1.birefnet-portrait.png -------------------------------------------------------------------------------- /tests/results/anime-girl-1.isnet-anime.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/danielgatis/rembg/bc1436cad8dd2c94aa396604f9afdc2dde95cf55/tests/results/anime-girl-1.isnet-anime.png -------------------------------------------------------------------------------- /tests/results/anime-girl-1.isnet-general-use.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/danielgatis/rembg/bc1436cad8dd2c94aa396604f9afdc2dde95cf55/tests/results/anime-girl-1.isnet-general-use.png -------------------------------------------------------------------------------- /tests/results/anime-girl-1.sam.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/danielgatis/rembg/bc1436cad8dd2c94aa396604f9afdc2dde95cf55/tests/results/anime-girl-1.sam.png -------------------------------------------------------------------------------- /tests/results/anime-girl-1.silueta.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/danielgatis/rembg/bc1436cad8dd2c94aa396604f9afdc2dde95cf55/tests/results/anime-girl-1.silueta.png -------------------------------------------------------------------------------- /tests/results/anime-girl-1.u2net.png: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/danielgatis/rembg/bc1436cad8dd2c94aa396604f9afdc2dde95cf55/tests/results/anime-girl-1.u2net.png -------------------------------------------------------------------------------- /tests/results/anime-girl-1.u2net_cloth_seg.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/danielgatis/rembg/bc1436cad8dd2c94aa396604f9afdc2dde95cf55/tests/results/anime-girl-1.u2net_cloth_seg.png -------------------------------------------------------------------------------- /tests/results/anime-girl-1.u2net_human_seg.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/danielgatis/rembg/bc1436cad8dd2c94aa396604f9afdc2dde95cf55/tests/results/anime-girl-1.u2net_human_seg.png -------------------------------------------------------------------------------- /tests/results/anime-girl-1.u2netp.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/danielgatis/rembg/bc1436cad8dd2c94aa396604f9afdc2dde95cf55/tests/results/anime-girl-1.u2netp.png -------------------------------------------------------------------------------- /tests/results/car-1.birefnet-cod.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/danielgatis/rembg/bc1436cad8dd2c94aa396604f9afdc2dde95cf55/tests/results/car-1.birefnet-cod.png -------------------------------------------------------------------------------- /tests/results/car-1.birefnet-dis.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/danielgatis/rembg/bc1436cad8dd2c94aa396604f9afdc2dde95cf55/tests/results/car-1.birefnet-dis.png -------------------------------------------------------------------------------- 
/tests/results/car-1.birefnet-general-lite.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/danielgatis/rembg/bc1436cad8dd2c94aa396604f9afdc2dde95cf55/tests/results/car-1.birefnet-general-lite.png -------------------------------------------------------------------------------- /tests/results/car-1.birefnet-general.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/danielgatis/rembg/bc1436cad8dd2c94aa396604f9afdc2dde95cf55/tests/results/car-1.birefnet-general.png -------------------------------------------------------------------------------- /tests/results/car-1.birefnet-hrsod.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/danielgatis/rembg/bc1436cad8dd2c94aa396604f9afdc2dde95cf55/tests/results/car-1.birefnet-hrsod.png -------------------------------------------------------------------------------- /tests/results/car-1.birefnet-massive.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/danielgatis/rembg/bc1436cad8dd2c94aa396604f9afdc2dde95cf55/tests/results/car-1.birefnet-massive.png -------------------------------------------------------------------------------- /tests/results/car-1.birefnet-portrait.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/danielgatis/rembg/bc1436cad8dd2c94aa396604f9afdc2dde95cf55/tests/results/car-1.birefnet-portrait.png -------------------------------------------------------------------------------- /tests/results/car-1.isnet-anime.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/danielgatis/rembg/bc1436cad8dd2c94aa396604f9afdc2dde95cf55/tests/results/car-1.isnet-anime.png 
-------------------------------------------------------------------------------- /tests/results/car-1.isnet-general-use.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/danielgatis/rembg/bc1436cad8dd2c94aa396604f9afdc2dde95cf55/tests/results/car-1.isnet-general-use.png -------------------------------------------------------------------------------- /tests/results/car-1.sam.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/danielgatis/rembg/bc1436cad8dd2c94aa396604f9afdc2dde95cf55/tests/results/car-1.sam.png -------------------------------------------------------------------------------- /tests/results/car-1.silueta.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/danielgatis/rembg/bc1436cad8dd2c94aa396604f9afdc2dde95cf55/tests/results/car-1.silueta.png -------------------------------------------------------------------------------- /tests/results/car-1.u2net.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/danielgatis/rembg/bc1436cad8dd2c94aa396604f9afdc2dde95cf55/tests/results/car-1.u2net.png -------------------------------------------------------------------------------- /tests/results/car-1.u2net_cloth_seg.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/danielgatis/rembg/bc1436cad8dd2c94aa396604f9afdc2dde95cf55/tests/results/car-1.u2net_cloth_seg.png -------------------------------------------------------------------------------- /tests/results/car-1.u2net_human_seg.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/danielgatis/rembg/bc1436cad8dd2c94aa396604f9afdc2dde95cf55/tests/results/car-1.u2net_human_seg.png 
-------------------------------------------------------------------------------- /tests/results/car-1.u2netp.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/danielgatis/rembg/bc1436cad8dd2c94aa396604f9afdc2dde95cf55/tests/results/car-1.u2netp.png -------------------------------------------------------------------------------- /tests/results/cloth-1.birefnet-cod.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/danielgatis/rembg/bc1436cad8dd2c94aa396604f9afdc2dde95cf55/tests/results/cloth-1.birefnet-cod.png -------------------------------------------------------------------------------- /tests/results/cloth-1.birefnet-dis.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/danielgatis/rembg/bc1436cad8dd2c94aa396604f9afdc2dde95cf55/tests/results/cloth-1.birefnet-dis.png -------------------------------------------------------------------------------- /tests/results/cloth-1.birefnet-general-lite.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/danielgatis/rembg/bc1436cad8dd2c94aa396604f9afdc2dde95cf55/tests/results/cloth-1.birefnet-general-lite.png -------------------------------------------------------------------------------- /tests/results/cloth-1.birefnet-general.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/danielgatis/rembg/bc1436cad8dd2c94aa396604f9afdc2dde95cf55/tests/results/cloth-1.birefnet-general.png -------------------------------------------------------------------------------- /tests/results/cloth-1.birefnet-hrsod.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/danielgatis/rembg/bc1436cad8dd2c94aa396604f9afdc2dde95cf55/tests/results/cloth-1.birefnet-hrsod.png -------------------------------------------------------------------------------- /tests/results/cloth-1.birefnet-massive.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/danielgatis/rembg/bc1436cad8dd2c94aa396604f9afdc2dde95cf55/tests/results/cloth-1.birefnet-massive.png -------------------------------------------------------------------------------- /tests/results/cloth-1.birefnet-portrait.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/danielgatis/rembg/bc1436cad8dd2c94aa396604f9afdc2dde95cf55/tests/results/cloth-1.birefnet-portrait.png -------------------------------------------------------------------------------- /tests/results/cloth-1.isnet-anime.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/danielgatis/rembg/bc1436cad8dd2c94aa396604f9afdc2dde95cf55/tests/results/cloth-1.isnet-anime.png -------------------------------------------------------------------------------- /tests/results/cloth-1.isnet-general-use.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/danielgatis/rembg/bc1436cad8dd2c94aa396604f9afdc2dde95cf55/tests/results/cloth-1.isnet-general-use.png -------------------------------------------------------------------------------- /tests/results/cloth-1.sam.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/danielgatis/rembg/bc1436cad8dd2c94aa396604f9afdc2dde95cf55/tests/results/cloth-1.sam.png -------------------------------------------------------------------------------- /tests/results/cloth-1.silueta.png: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/danielgatis/rembg/bc1436cad8dd2c94aa396604f9afdc2dde95cf55/tests/results/cloth-1.silueta.png -------------------------------------------------------------------------------- /tests/results/cloth-1.u2net.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/danielgatis/rembg/bc1436cad8dd2c94aa396604f9afdc2dde95cf55/tests/results/cloth-1.u2net.png -------------------------------------------------------------------------------- /tests/results/cloth-1.u2net_cloth_seg.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/danielgatis/rembg/bc1436cad8dd2c94aa396604f9afdc2dde95cf55/tests/results/cloth-1.u2net_cloth_seg.png -------------------------------------------------------------------------------- /tests/results/cloth-1.u2net_human_seg.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/danielgatis/rembg/bc1436cad8dd2c94aa396604f9afdc2dde95cf55/tests/results/cloth-1.u2net_human_seg.png -------------------------------------------------------------------------------- /tests/results/cloth-1.u2netp.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/danielgatis/rembg/bc1436cad8dd2c94aa396604f9afdc2dde95cf55/tests/results/cloth-1.u2netp.png -------------------------------------------------------------------------------- /tests/results/plants-1.birefnet-cod.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/danielgatis/rembg/bc1436cad8dd2c94aa396604f9afdc2dde95cf55/tests/results/plants-1.birefnet-cod.png -------------------------------------------------------------------------------- /tests/results/plants-1.birefnet-dis.png: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/danielgatis/rembg/bc1436cad8dd2c94aa396604f9afdc2dde95cf55/tests/results/plants-1.birefnet-dis.png -------------------------------------------------------------------------------- /tests/results/plants-1.birefnet-general-lite.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/danielgatis/rembg/bc1436cad8dd2c94aa396604f9afdc2dde95cf55/tests/results/plants-1.birefnet-general-lite.png -------------------------------------------------------------------------------- /tests/results/plants-1.birefnet-general.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/danielgatis/rembg/bc1436cad8dd2c94aa396604f9afdc2dde95cf55/tests/results/plants-1.birefnet-general.png -------------------------------------------------------------------------------- /tests/results/plants-1.birefnet-hrsod.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/danielgatis/rembg/bc1436cad8dd2c94aa396604f9afdc2dde95cf55/tests/results/plants-1.birefnet-hrsod.png -------------------------------------------------------------------------------- /tests/results/plants-1.birefnet-massive.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/danielgatis/rembg/bc1436cad8dd2c94aa396604f9afdc2dde95cf55/tests/results/plants-1.birefnet-massive.png -------------------------------------------------------------------------------- /tests/results/plants-1.birefnet-portrait.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/danielgatis/rembg/bc1436cad8dd2c94aa396604f9afdc2dde95cf55/tests/results/plants-1.birefnet-portrait.png 
-------------------------------------------------------------------------------- /tests/results/plants-1.isnet-anime.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/danielgatis/rembg/bc1436cad8dd2c94aa396604f9afdc2dde95cf55/tests/results/plants-1.isnet-anime.png -------------------------------------------------------------------------------- /tests/results/plants-1.isnet-general-use.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/danielgatis/rembg/bc1436cad8dd2c94aa396604f9afdc2dde95cf55/tests/results/plants-1.isnet-general-use.png -------------------------------------------------------------------------------- /tests/results/plants-1.sam.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/danielgatis/rembg/bc1436cad8dd2c94aa396604f9afdc2dde95cf55/tests/results/plants-1.sam.png -------------------------------------------------------------------------------- /tests/results/plants-1.silueta.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/danielgatis/rembg/bc1436cad8dd2c94aa396604f9afdc2dde95cf55/tests/results/plants-1.silueta.png -------------------------------------------------------------------------------- /tests/results/plants-1.u2net.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/danielgatis/rembg/bc1436cad8dd2c94aa396604f9afdc2dde95cf55/tests/results/plants-1.u2net.png -------------------------------------------------------------------------------- /tests/results/plants-1.u2net_cloth_seg.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/danielgatis/rembg/bc1436cad8dd2c94aa396604f9afdc2dde95cf55/tests/results/plants-1.u2net_cloth_seg.png 
from io import BytesIO
from pathlib import Path

from imagehash import phash as hash_img
from PIL import Image

from rembg import new_session, remove

here = Path(__file__).parent.resolve()


def test_remove():
    """
    Check every model's output against the stored reference images.

    For each model and each fixture picture, the background is removed and
    the perceptual hash (``imagehash.phash``) of the result is compared with
    the hash of ``results/<picture>.<model>.png``. Hashes are compared rather
    than raw bytes.
    """
    # Extra keyword arguments for `remove`, keyed by model and then picture.
    # Only "sam" has entries: one labelled point prompt per picture.
    kwargs = {
        "sam": {
            "anime-girl-1": {
                "sam_prompt": [{"type": "point", "data": [400, 165], "label": 1}],
            },
            "car-1": {
                "sam_prompt": [{"type": "point", "data": [250, 200], "label": 1}],
            },
            "cloth-1": {
                "sam_prompt": [{"type": "point", "data": [370, 495], "label": 1}],
            },
            "plants-1": {
                "sam_prompt": [{"type": "point", "data": [724, 740], "label": 1}],
            },
        }
    }

    for model in [
        "u2net",
        "u2netp",
        "u2net_human_seg",
        "u2net_cloth_seg",
        "silueta",
        "isnet-general-use",
        "isnet-anime",
        "sam",
        "birefnet-general",
        "birefnet-general-lite",
        "birefnet-portrait",
        "birefnet-dis",
        "birefnet-hrsod",
        "birefnet-cod",
        "birefnet-massive",
    ]:
        # Create the session once per model instead of once per picture, so
        # each model is only set up a single time.
        session = new_session(model)
        model_kwargs = kwargs.get(model, {})

        for picture in ["anime-girl-1", "car-1", "cloth-1", "plants-1"]:
            image_path = here / "fixtures" / f"{picture}.jpg"
            image = image_path.read_bytes()

            actual = remove(image, session=session, **model_kwargs.get(picture, {}))
            actual_hash = hash_img(Image.open(BytesIO(actual)))

            expected_path = here / "results" / f"{picture}.{model}.png"
            # Uncomment to update the expected results
            # expected_path.write_bytes(actual)

            expected = expected_path.read_bytes()
            expected_hash = hash_img(Image.open(BytesIO(expected)))

            print(f"image_path: {image_path}")
            print(f"expected_path: {expected_path}")
            print(f"actual_hash: {actual_hash}")
            print(f"expected_hash: {expected_hash}")
            print(f"actual_hash == expected_hash: {actual_hash == expected_hash}")
            print("---\n")

            assert actual_hash == expected_hash, (
                f"{model} output for {picture} does not match {expected_path}"
            )