├── .dockerignore ├── .flake8 ├── .gitattributes ├── .github └── workflows │ ├── codeql-analysis.yml │ ├── docker.yaml │ ├── master.yml │ ├── master_docker.yaml │ └── python-publish.yml ├── .gitignore ├── .pre-commit-config.yaml ├── Dockerfile.cpu ├── Dockerfile.cuda ├── LICENSE ├── MANIFEST.in ├── README.md ├── carvekit ├── __init__.py ├── __main__.py ├── api │ ├── __init__.py │ ├── high.py │ └── interface.py ├── ml │ ├── __init__.py │ ├── arch │ │ ├── __init__.py │ │ ├── basnet │ │ │ ├── __init__.py │ │ │ └── basnet.py │ │ ├── fba_matting │ │ │ ├── __init__.py │ │ │ ├── layers_WS.py │ │ │ ├── models.py │ │ │ ├── resnet_GN_WS.py │ │ │ ├── resnet_bn.py │ │ │ └── transforms.py │ │ ├── tracerb7 │ │ │ ├── __init__.py │ │ │ ├── att_modules.py │ │ │ ├── conv_modules.py │ │ │ ├── effi_utils.py │ │ │ ├── efficientnet.py │ │ │ └── tracer.py │ │ └── u2net │ │ │ ├── __init__.py │ │ │ └── u2net.py │ ├── files │ │ ├── __init__.py │ │ └── models_loc.py │ └── wrap │ │ ├── __init__.py │ │ ├── basnet.py │ │ ├── deeplab_v3.py │ │ ├── fba_matting.py │ │ ├── tracer_b7.py │ │ └── u2net.py ├── pipelines │ ├── __init__.py │ ├── postprocessing.py │ └── preprocessing.py ├── trimap │ ├── __init__.py │ ├── add_ops.py │ ├── cv_gen.py │ └── generator.py ├── utils │ ├── __init__.py │ ├── download_models.py │ ├── fs_utils.py │ ├── image_utils.py │ ├── mask_utils.py │ ├── models_utils.py │ └── pool_utils.py └── web │ ├── __init__.py │ ├── app.py │ ├── deps.py │ ├── handlers │ ├── __init__.py │ └── response.py │ ├── other │ ├── __init__.py │ └── removebg.py │ ├── responses │ ├── __init__.py │ └── api.py │ ├── routers │ ├── __init__.py │ └── api_router.py │ ├── schemas │ ├── __init__.py │ ├── config.py │ └── request.py │ ├── static │ ├── css │ │ ├── animate.css │ │ ├── bootstrap.min.css │ │ ├── fancybox_loading.gif │ │ ├── fancybox_overlay.png │ │ ├── fancybox_sprite.png │ │ ├── font-awesome.min.css │ │ ├── jquery.fancybox.css │ │ ├── main.css │ │ ├── media-queries.css │ │ ├── normalize.min.css │ │ └── particles.css │ ├── fonts │ │ ├── FontAwesome.otf │ │ ├── fontawesome-webfont.eot │ │ ├── fontawesome-webfont.svg │ │ ├── fontawesome-webfont.ttf │ │ └── fontawesome-webfont.woff │ ├── img │ │ ├── CarveKit_logo_main.png │ │ ├── art.gif │ │ ├── envelop.png │ │ ├── icon.png │ │ └── preloader.gif │ ├── index.html │ └── js │ │ ├── bootstrap.min.js │ │ ├── custom.js │ │ ├── jquery-1.11.1.min.js │ │ ├── jquery-countTo.js │ │ ├── jquery.appear.js │ │ ├── jquery.easing.min.js │ │ ├── jquery.fancybox.pack.js │ │ ├── jquery.mixitup.min.js │ │ ├── jquery.parallax-1.1.3.js │ │ ├── jquery.singlePageNav.min.js │ │ ├── modernizr-2.6.2.min.js │ │ ├── particles.js │ │ └── wow.min.js │ └── utils │ ├── __init__.py │ ├── init_utils.py │ ├── net_utils.py │ └── task_queue.py ├── conftest.py ├── docker-compose.cpu.yml ├── docker-compose.cuda.yml ├── docs ├── CREDITS.md ├── code_examples │ └── python │ │ ├── http_api_lib.py │ │ ├── http_api_requests.py │ │ └── images │ │ └── 4.jpg ├── imgs │ ├── compare │ │ ├── 1.png │ │ ├── 2.png │ │ ├── 3.png │ │ └── readme.jpg │ ├── input │ │ ├── 1.jpg │ │ ├── 1_bg_removed.png │ │ ├── 2.jpg │ │ ├── 2_bg_removed.png │ │ ├── 3.jpg │ │ ├── 3_bg_removed.png │ │ ├── 4.jpg │ │ └── 4_bg_removed.png │ ├── logo.png │ └── screenshot │ │ ├── docs_fastapi.png │ │ └── frontend.png ├── other │ └── carvekit_try.ipynb └── readme │ └── ru.md ├── requirements.txt ├── requirements_dev.txt ├── requirements_test.txt ├── setup.py └── tests ├── data ├── cat.JPG ├── cat.MP3 ├── cat.jpg ├── cat.mp3 ├── cat_mask.png └── cat_trimap.png 
├── test_basnet.py ├── test_deeplabv3.py ├── test_fba.py ├── test_fs_utils.py ├── test_high.py ├── test_image_utils.py ├── test_interface.py ├── test_mask_utils.py ├── test_models_utils.py ├── test_pool_utils.py ├── test_postprocessing.py ├── test_preprocessing.py ├── test_tracer.py ├── test_trimap.py └── test_u2net.py /.dockerignore: -------------------------------------------------------------------------------- 1 | venv 2 | -------------------------------------------------------------------------------- /.flake8: -------------------------------------------------------------------------------- 1 | [flake8] 2 | ignore = E203, E266, E501, W503, F403, F401 3 | max-line-length = 79 4 | max-complexity = 50 5 | select = B,C,E,F,W,T4,B9 6 | -------------------------------------------------------------------------------- /.gitattributes: -------------------------------------------------------------------------------- 1 | *.html linguist-detectable=false 2 | *.py linguist-detectable=true 3 | -------------------------------------------------------------------------------- /.github/workflows/codeql-analysis.yml: -------------------------------------------------------------------------------- 1 | # For most projects, this workflow file will not need changing; you simply need 2 | # to commit it to your repository. 3 | # 4 | # You may wish to alter this file to override the set of languages analyzed, 5 | # or to provide custom queries or build logic. 6 | # 7 | # ******** NOTE ******** 8 | # We have attempted to detect the languages in your repository. Please check 9 | # the `language` matrix defined below to confirm you have the correct set of 10 | # supported CodeQL languages. 11 | # 12 | name: "CodeQL" 13 | 14 | on: 15 | push: 16 | branches: [ master ] 17 | pull_request: 18 | # The branches below must be a subset of the branches above 19 | branches: [ master ] 20 | schedule: 21 | - cron: '23 5 * * 5' 22 | 23 | jobs: 24 | analyze: 25 | name: Analyze 26 | runs-on: ubuntu-latest 27 | permissions: 28 | actions: read 29 | contents: read 30 | security-events: write 31 | 32 | strategy: 33 | fail-fast: false 34 | matrix: 35 | language: [ 'javascript', 'python' ] 36 | # CodeQL supports [ 'cpp', 'csharp', 'go', 'java', 'javascript', 'python', 'ruby' ] 37 | # Learn more about CodeQL language support at https://aka.ms/codeql-docs/language-support 38 | 39 | steps: 40 | - name: Checkout repository 41 | uses: actions/checkout@v3 42 | 43 | # Initializes the CodeQL tools for scanning. 44 | - name: Initialize CodeQL 45 | uses: github/codeql-action/init@v2 46 | with: 47 | languages: ${{ matrix.language }} 48 | # If you wish to specify custom queries, you can do so here or in a config file. 49 | # By default, queries listed here will override any specified in a config file. 50 | # Prefix the list here with "+" to use these queries and those in the config file. 51 | 52 | # Details on CodeQL's query packs refer to : https://docs.github.com/en/code-security/code-scanning/automatically-scanning-your-code-for-vulnerabilities-and-errors/configuring-code-scanning#using-queries-in-ql-packs 53 | # queries: security-extended,security-and-quality 54 | 55 | 56 | # Autobuild attempts to build any compiled languages (C/C++, C#, or Java). 57 | # If this step fails, then you should remove it and run the build manually (see below) 58 | - name: Autobuild 59 | uses: github/codeql-action/autobuild@v2 60 | 61 | # ℹ️ Command-line programs to run using the OS shell. 
62 | # 📚 See https://docs.github.com/en/actions/using-workflows/workflow-syntax-for-github-actions#jobsjob_idstepsrun 63 | 64 | # If the Autobuild fails above, remove it and uncomment the following three lines. 65 | # modify them (or add more) to build your code if your project, please refer to the EXAMPLE below for guidance. 66 | 67 | # - run: | 68 | # echo "Run, Build Application using script" 69 | # ./location_of_script_within_repo/buildscript.sh 70 | 71 | - name: Perform CodeQL Analysis 72 | uses: github/codeql-action/analyze@v2 73 | -------------------------------------------------------------------------------- /.github/workflows/docker.yaml: -------------------------------------------------------------------------------- 1 | name: Docker images 2 | 3 | on: 4 | release: 5 | types: [ published ] 6 | 7 | permissions: 8 | contents: read 9 | 10 | jobs: 11 | publish: 12 | runs-on: ubuntu-latest 13 | steps: 14 | - name: Checkout 15 | uses: actions/checkout@v3 16 | - name: Set up Docker BuildX 17 | uses: docker/setup-buildx-action@v2 18 | - name: Login to Docker Hub 19 | uses: docker/login-action@v2 20 | with: 21 | username: ${{ secrets.DOCKERHUB_USERNAME }} 22 | password: ${{ secrets.DOCKERHUB_TOKEN }} 23 | - name: Set env 24 | run: echo "RELEASE_VERSION=${GITHUB_REF#refs/*/}" >> $GITHUB_ENV 25 | - name: Build and push 26 | uses: docker/build-push-action@v3 27 | with: 28 | push: true 29 | context: ./ 30 | file: "./Dockerfile.cpu" 31 | tags: | 32 | anodev/carvekit:latest-cpu 33 | anodev/carvekit:${{ env.RELEASE_VERSION }}-cpu 34 | - name: Build and push cuda 35 | uses: docker/build-push-action@v3 36 | with: 37 | push: true 38 | context: ./ 39 | file: "./Dockerfile.cuda" 40 | tags: | 41 | anodev/carvekit:latest-cuda 42 | anodev/carvekit:${{ env.RELEASE_VERSION }}-cuda -------------------------------------------------------------------------------- /.github/workflows/master.yml: -------------------------------------------------------------------------------- 1 | name: Test Master CPU CI/CD 2 | 3 | on: 4 | push: 5 | branches: [ master ] 6 | pull_request: 7 | branches: [ master ] 8 | schedule: 9 | - cron: '23 5 * * 5' 10 | 11 | jobs: 12 | test_package: 13 | strategy: 14 | matrix: 15 | python-version: [3.9, 3.10.11, 3.11.7] 16 | os: [ubuntu-latest, windows-latest, macos-latest] 17 | runs-on: ${{ matrix.os }} 18 | 19 | steps: 20 | - uses: actions/checkout@v2 21 | - name: Set up Python ${{ matrix.python-version }} 22 | uses: actions/setup-python@v2 23 | with: 24 | python-version: ${{ matrix.python-version }} 25 | 26 | - name: Install dependencies 27 | run: | 28 | python -m pip install --upgrade pip 29 | pip install flake8 pytest 30 | pip3 install --no-cache-dir -r requirements.txt --extra-index-url https://download.pytorch.org/whl/cpu 31 | pip3 install --no-cache-dir -r requirements_test.txt 32 | 33 | - name: Download all models 34 | run: | 35 | python3 -c "from carvekit.ml.files.models_loc import download_all; download_all();" 36 | 37 | - name: Lint with flake8 38 | run: | 39 | # stop the build if there are Python syntax errors or undefined names 40 | flake8 . --count --select=E9,F63,F7,F82 --show-source --statistics 41 | # exit-zero treats all errors as warnings. The GitHub editor is 127 chars wide 42 | flake8 . 
--count --exit-zero --max-complexity=10 --max-line-length=127 --statistics 43 | 44 | - name: Test with pytest 45 | run: | 46 | pytest -s -v 47 | 48 | -------------------------------------------------------------------------------- /.github/workflows/master_docker.yaml: -------------------------------------------------------------------------------- 1 | name: Test Docker CPU CI/CD 2 | 3 | on: 4 | push: 5 | branches: [ master ] 6 | pull_request: 7 | branches: [ master ] 8 | schedule: 9 | - cron: '23 5 * * 5' 10 | 11 | jobs: 12 | build: 13 | runs-on: ubuntu-latest 14 | steps: 15 | - uses: actions/checkout@v3 16 | - name: Build the CPU Docker img 17 | run: docker build . --file Dockerfile.cpu --tag carvekit 18 | - name: Test CPU Docker image 19 | run: docker run --rm carvekit pytest 20 | - name: Build the GPU Docker image 21 | run: docker build . --file Dockerfile.cuda --tag carvekit_gpu 22 | - name: Test GPU Docker image on CPU 23 | run: docker run --rm carvekit_gpu pytest 24 | 25 | -------------------------------------------------------------------------------- /.github/workflows/python-publish.yml: -------------------------------------------------------------------------------- 1 | name: Upload Python Package 2 | 3 | on: 4 | release: 5 | types: [published] 6 | 7 | permissions: 8 | contents: read 9 | 10 | jobs: 11 | deploy: 12 | 13 | runs-on: ubuntu-latest 14 | 15 | steps: 16 | - uses: actions/checkout@v3 17 | - name: Set up Python 18 | uses: actions/setup-python@v3 19 | with: 20 | python-version: '3.9' 21 | 22 | - name: Install dependencies 23 | run: | 24 | python -m pip install --upgrade pip 25 | pip install build twine 26 | 27 | - name: Build package 28 | run: python -m build 29 | 30 | - name: Publish package 31 | run: twine upload -u __token__ -p ${{ secrets.PYPIKEY }} dist/* 32 | 33 | - name: Clear dist 34 | run: rm -rf dist 35 | 36 | - name: Build colab ready package 37 | run: python -m build 38 | env: 39 | COLAB_PACKAGE_RELEASE: true 40 | 41 | - name: Publish colab package 42 | run: twine upload -u __token__ -p ${{ secrets.PYPIKEY2 }} dist/* 43 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | /.idea/ 2 | __pycache__ 3 | /tests/.pytest_cache/ 4 | /tests/tests_temp/ 5 | /.pytest_cache/ 6 | /carvekit.egg-info/ 7 | venv -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | repos: 2 | - repo: https://github.com/psf/black 3 | rev: 22.10.0 4 | hooks: 5 | - id: black 6 | language_version: python3.10 7 | - repo: https://gitlab.com/pycqa/flake8 8 | rev: 3.9.2 9 | hooks: 10 | - id: flake8 -------------------------------------------------------------------------------- /Dockerfile.cpu: -------------------------------------------------------------------------------- 1 | FROM python:3.10.4 as builder 2 | WORKDIR /app 3 | 4 | # Models download stage 5 | RUN pip3 install --no-cache-dir tqdm==4.66.1 requests==2.31.0 6 | RUN mkdir -p ./carvekit/utils/ 7 | RUN mkdir -p ./carvekit/ml/files 8 | COPY ./carvekit/__init__.py ./carvekit/__init__.py 9 | RUN touch ./carvekit/ml/__init__.py 10 | RUN touch ./carvekit/utils/__init__.py 11 | COPY ./carvekit/utils/download_models.py ./carvekit/utils/download_models.py 12 | COPY ./carvekit/ml/files/__init__.py ./carvekit/ml/files/__init__.py 13 | COPY ./carvekit/ml/files/models_loc.py 
./carvekit/ml/files/models_loc.py 14 | RUN python3 -c "from carvekit.ml.files.models_loc import download_all; download_all();" 15 | RUN rm -rf ./carvekit 16 | 17 | FROM python:3.10.4 18 | WORKDIR /app 19 | 20 | RUN apt-get update && apt-get -y install libgl1 # Install cv2 dep. 21 | COPY --from=builder /root/.cache/carvekit /root/.cache/carvekit 22 | 23 | # Install requirements 24 | COPY requirements.txt ./ 25 | RUN pip3 install --no-cache-dir -r requirements.txt --extra-index-url https://download.pytorch.org/whl/cpu 26 | COPY requirements_test.txt ./ 27 | RUN pip3 install --no-cache-dir -r requirements_test.txt 28 | 29 | # Update code 30 | COPY ./ ./ 31 | 32 | # Install to site-packages to make possible run tests and process images by cli 33 | RUN pip3 install -e ./ 34 | 35 | ENV CARVEKIT_PORT '5000' 36 | ENV CARVEKIT_HOST '0.0.0.0' 37 | ENV CARVEKIT_SEGMENTATION_NETWORK 'tracer_b7' 38 | ENV CARVEKIT_PREPROCESSING_METHOD 'none' 39 | ENV CARVEKIT_POSTPROCESSING_METHOD 'fba' 40 | ENV CARVEKIT_DEVICE 'cpu' 41 | ENV CARVEKIT_BATCH_SIZE_SEG '5' 42 | ENV CARVEKIT_BATCH_SIZE_MATTING '1' 43 | ENV CARVEKIT_SEG_MASK_SIZE '640' 44 | ENV CARVEKIT_MATTING_MASK_SIZE '2048' 45 | ENV CARVEKIT_AUTH_ENABLE '1' 46 | ENV CARVEKIT_FP16 '0' 47 | ENV CARVEKIT_TRIMAP_PROB_THRESHOLD=231 48 | ENV CARVEKIT_TRIMAP_DILATION=30 49 | ENV CARVEKIT_TRIMAP_EROSION=5 50 | 51 | # Tokens will be generated automatically every time the container is restarted if ENV is not set. 52 | 53 | # ENV CARVEKIT_ADMIN_TOKEN 'admin_token' # Do not use this env when creating an image as it is not safe. 54 | # ENV CARVEKIT_ALLOWED_TOKENS 'test_token1,test_token2' # Do not use this env when creating an image as it is not safe. 55 | 56 | EXPOSE 5000 57 | 58 | CMD ["/bin/sh", "-c", "uvicorn carvekit.web.app:app --host $CARVEKIT_HOST --port $CARVEKIT_PORT"] 59 | -------------------------------------------------------------------------------- /Dockerfile.cuda: -------------------------------------------------------------------------------- 1 | FROM python:3.10.4 as builder 2 | WORKDIR /app 3 | 4 | # Models download stage 5 | RUN pip3 install --no-cache-dir tqdm==4.66.1 requests==2.31.0 6 | RUN mkdir -p ./carvekit/utils/ 7 | RUN mkdir -p ./carvekit/ml/files 8 | COPY ./carvekit/__init__.py ./carvekit/__init__.py 9 | RUN touch ./carvekit/ml/__init__.py 10 | RUN touch ./carvekit/utils/__init__.py 11 | COPY ./carvekit/utils/download_models.py ./carvekit/utils/download_models.py 12 | COPY ./carvekit/ml/files/__init__.py ./carvekit/ml/files/__init__.py 13 | COPY ./carvekit/ml/files/models_loc.py ./carvekit/ml/files/models_loc.py 14 | RUN python3 -c "from carvekit.ml.files.models_loc import download_all; download_all();" 15 | RUN rm -rf ./carvekit 16 | 17 | FROM pytorch/pytorch:2.2.2-cuda12.1-cudnn8-runtime 18 | WORKDIR /app 19 | 20 | RUN apt-get update && apt-get -y install libgl1 libglib2.0-0 # Install cv2 dep. 
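(Illustrative aside, not part of Dockerfile.cuda.) Both Dockerfiles pre-fetch the pretrained checkpoints in a throwaway builder stage and then copy the populated ~/.cache/carvekit directory into the final image, so the heavy downloads are cached between builds. A minimal standalone sketch of that pre-download step, using only the download_all() helper defined in carvekit/ml/files/models_loc.py further below:

# Equivalent of the builder-stage "RUN python3 -c ..." step in the Dockerfiles:
# fetches the u2net, fba_matting, deeplab, basnet and tracer_b7 checkpoints
# into ~/.cache/carvekit (the directory the Dockerfiles copy out of the builder stage).
from carvekit.ml.files.models_loc import download_all

download_all()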
21 | COPY --from=builder /root/.cache/carvekit /root/.cache/carvekit 22 | 23 | # Install requirements 24 | COPY requirements.txt ./ 25 | RUN pip3 install --no-cache-dir -r requirements.txt 26 | COPY requirements_test.txt ./ 27 | RUN pip3 install --no-cache-dir -r requirements_test.txt 28 | 29 | # Update code 30 | COPY ./ ./ 31 | 32 | # Install to site-packages to make possible run tests and process images by cli 33 | RUN pip3 install -e ./ 34 | 35 | ENV CARVEKIT_PORT '5000' 36 | ENV CARVEKIT_HOST '0.0.0.0' 37 | ENV CARVEKIT_SEGMENTATION_NETWORK 'tracer_b7' 38 | ENV CARVEKIT_PREPROCESSING_METHOD 'none' 39 | ENV CARVEKIT_POSTPROCESSING_METHOD 'fba' 40 | ENV CARVEKIT_DEVICE 'cuda' 41 | ENV CARVEKIT_BATCH_SIZE_SEG '5' 42 | ENV CARVEKIT_BATCH_SIZE_MATTING '1' 43 | ENV CARVEKIT_SEG_MASK_SIZE '640' 44 | ENV CARVEKIT_MATTING_MASK_SIZE '2048' 45 | ENV CARVEKIT_AUTH_ENABLE '1' 46 | ENV CARVEKIT_FP16 '0' 47 | ENV CARVEKIT_TRIMAP_PROB_THRESHOLD=231 48 | ENV CARVEKIT_TRIMAP_DILATION=30 49 | ENV CARVEKIT_TRIMAP_EROSION=5 50 | 51 | # Tokens will be generated automatically every time the container is restarted if ENV is not set. 52 | 53 | # ENV CARVEKIT_ADMIN_TOKEN 'admin_token' # Do not use this env when creating an image as it is not safe. 54 | # ENV CARVEKIT_ALLOWED_TOKENS 'test_token1,test_token2' # Do not use this env when creating an image as it is not safe. 55 | 56 | EXPOSE 5000 57 | 58 | CMD ["/bin/sh", "-c", "uvicorn carvekit.web.app:app --host $CARVEKIT_HOST --port $CARVEKIT_PORT"] 59 | -------------------------------------------------------------------------------- /MANIFEST.in: -------------------------------------------------------------------------------- 1 | include requirements*.txt -------------------------------------------------------------------------------- /carvekit/__init__.py: -------------------------------------------------------------------------------- 1 | version = "4.1.2" 2 | -------------------------------------------------------------------------------- /carvekit/__main__.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | 3 | import click 4 | import tqdm 5 | 6 | from carvekit.utils.image_utils import ALLOWED_SUFFIXES 7 | from carvekit.utils.pool_utils import batch_generator, thread_pool_processing 8 | from carvekit.web.schemas.config import MLConfig 9 | from carvekit.web.utils.init_utils import init_interface 10 | from carvekit.utils.fs_utils import save_file 11 | 12 | 13 | @click.command( 14 | "removebg", 15 | help="Performs background removal on specified photos using console interface.", 16 | ) 17 | @click.option("-i", required=True, type=str, help="Path to input file or dir") 18 | @click.option("-o", default="none", type=str, help="Path to output file or dir") 19 | @click.option("--pre", default="none", type=str, help="Preprocessing method") 20 | @click.option("--post", default="fba", type=str, help="Postprocessing method.") 21 | @click.option("--net", default="tracer_b7", type=str, help="Segmentation Network") 22 | @click.option( 23 | "--recursive", 24 | default=False, 25 | type=bool, 26 | help="Enables recursive search for images in a folder", 27 | ) 28 | @click.option( 29 | "--batch_size", 30 | default=10, 31 | type=int, 32 | help="Batch Size for list of images to be loaded to RAM", 33 | ) 34 | @click.option( 35 | "--batch_size_seg", 36 | default=5, 37 | type=int, 38 | help="Batch size for list of images to be processed by segmentation " "network", 39 | ) 40 | @click.option( 41 | "--batch_size_mat", 
42 | default=1, 43 | type=int, 44 | help="Batch size for list of images to be processed by matting " "network", 45 | ) 46 | @click.option( 47 | "--seg_mask_size", 48 | default=640, 49 | type=int, 50 | help="The size of the input image for the segmentation neural network.", 51 | ) 52 | @click.option( 53 | "--matting_mask_size", 54 | default=2048, 55 | type=int, 56 | help="The size of the input image for the matting neural network.", 57 | ) 58 | @click.option( 59 | "--trimap_dilation", 60 | default=30, 61 | type=int, 62 | help="The size of the offset radius from the object mask in " 63 | "pixels when forming an unknown area", 64 | ) 65 | @click.option( 66 | "--trimap_erosion", 67 | default=5, 68 | type=int, 69 | help="The number of iterations of erosion that the object's " 70 | "mask will be subjected to before forming an unknown area", 71 | ) 72 | @click.option( 73 | "--trimap_prob_threshold", 74 | default=231, 75 | type=int, 76 | help="Probability threshold at which the prob_filter " 77 | "and prob_as_unknown_area operations will be " 78 | "applied", 79 | ) 80 | @click.option("--device", default="cpu", type=str, help="Processing Device.") 81 | @click.option( 82 | "--fp16", default=False, type=bool, help="Enables mixed precision processing." 83 | ) 84 | def removebg( 85 | i: str, 86 | o: str, 87 | pre: str, 88 | post: str, 89 | net: str, 90 | recursive: bool, 91 | batch_size: int, 92 | batch_size_seg: int, 93 | batch_size_mat: int, 94 | seg_mask_size: int, 95 | matting_mask_size: int, 96 | device: str, 97 | fp16: bool, 98 | trimap_dilation: int, 99 | trimap_erosion: int, 100 | trimap_prob_threshold: int, 101 | ): 102 | out_path = Path(o) 103 | input_path = Path(i) 104 | if input_path.is_dir(): 105 | if recursive: 106 | all_images = input_path.rglob("*.*") 107 | else: 108 | all_images = input_path.glob("*.*") 109 | all_images = [ 110 | i 111 | for i in all_images 112 | if i.suffix.lower() in ALLOWED_SUFFIXES and "_bg_removed" not in i.name 113 | ] 114 | else: 115 | all_images = [input_path] 116 | 117 | interface_config = MLConfig( 118 | segmentation_network=net, 119 | preprocessing_method=pre, 120 | postprocessing_method=post, 121 | device=device, 122 | batch_size_seg=batch_size_seg, 123 | batch_size_matting=batch_size_mat, 124 | seg_mask_size=seg_mask_size, 125 | matting_mask_size=matting_mask_size, 126 | fp16=fp16, 127 | trimap_dilation=trimap_dilation, 128 | trimap_erosion=trimap_erosion, 129 | trimap_prob_threshold=trimap_prob_threshold, 130 | ) 131 | 132 | interface = init_interface(interface_config) 133 | 134 | for image_batch in tqdm.tqdm( 135 | batch_generator(all_images, n=batch_size), 136 | total=int(len(all_images) / batch_size), 137 | desc="Removing background", 138 | unit=" image batch", 139 | colour="blue", 140 | ): 141 | images_without_background = interface(image_batch) # Remove background 142 | thread_pool_processing( 143 | lambda x: save_file(out_path, image_batch[x], images_without_background[x]), 144 | range((len(image_batch))), 145 | ) # Drop images to fs 146 | 147 | 148 | if __name__ == "__main__": 149 | removebg() 150 | -------------------------------------------------------------------------------- /carvekit/api/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OPHoperHPO/image-background-remove-tool/f141a311af67fb1da64269c508a6d1f786420801/carvekit/api/__init__.py -------------------------------------------------------------------------------- /carvekit/api/high.py: 
-------------------------------------------------------------------------------- 1 | """ 2 | Source url: https://github.com/OPHoperHPO/image-background-remove-tool 3 | Author: Nikita Selin (OPHoperHPO)[https://github.com/OPHoperHPO]. 4 | License: Apache License 2.0 5 | """ 6 | import warnings 7 | 8 | from carvekit.api.interface import Interface 9 | from carvekit.ml.wrap.fba_matting import FBAMatting 10 | from carvekit.ml.wrap.tracer_b7 import TracerUniversalB7 11 | from carvekit.ml.wrap.u2net import U2NET 12 | from carvekit.pipelines.postprocessing import MattingMethod 13 | from carvekit.trimap.generator import TrimapGenerator 14 | 15 | 16 | class HiInterface(Interface): 17 | def __init__( 18 | self, 19 | object_type: str = "object", 20 | batch_size_seg=2, 21 | batch_size_matting=1, 22 | device="cpu", 23 | seg_mask_size=640, 24 | matting_mask_size=2048, 25 | trimap_prob_threshold=231, 26 | trimap_dilation=30, 27 | trimap_erosion_iters=5, 28 | fp16=False, 29 | ): 30 | """ 31 | Initializes High Level interface. 32 | 33 | Args: 34 | object_type: Interest object type. Can be "object" or "hairs-like". 35 | matting_mask_size: The size of the input image for the matting neural network. 36 | seg_mask_size: The size of the input image for the segmentation neural network. 37 | batch_size_seg: Number of images processed per one segmentation neural network call. 38 | batch_size_matting: Number of images processed per one matting neural network call. 39 | device: Processing device 40 | fp16: Use half precision. Reduce memory usage and increase speed. Experimental support 41 | trimap_prob_threshold: Probability threshold at which the prob_filter and prob_as_unknown_area operations will be applied 42 | trimap_dilation: The size of the offset radius from the object mask in pixels when forming an unknown area 43 | trimap_erosion_iters: The number of iterations of erosion that the object's mask will be subjected to before forming an unknown area 44 | 45 | Notes: 46 | 1. Changing seg_mask_size may cause an out-of-memory error if the value is too large, and it may also 47 | result in reduced precision. I do not recommend changing this value. You can change matting_mask_size in 48 | range from (1024 to 4096) to improve object edge refining quality, but it will cause extra large RAM and 49 | video memory consume. Also, you can change batch size to accelerate background removal, but it also causes 50 | extra large video memory consume, if value is too big. 51 | 52 | 2. Changing trimap_prob_threshold, trimap_kernel_size, trimap_erosion_iters may improve object edge 53 | refining quality, 54 | """ 55 | if object_type == "object": 56 | self.u2net = TracerUniversalB7( 57 | device=device, 58 | batch_size=batch_size_seg, 59 | input_image_size=seg_mask_size, 60 | fp16=fp16, 61 | ) 62 | elif object_type == "hairs-like": 63 | self.u2net = U2NET( 64 | device=device, 65 | batch_size=batch_size_seg, 66 | input_image_size=seg_mask_size, 67 | fp16=fp16, 68 | ) 69 | else: 70 | warnings.warn( 71 | f"Unknown object type: {object_type}. 
Using default object type: object" 72 | ) 73 | self.u2net = TracerUniversalB7( 74 | device=device, 75 | batch_size=batch_size_seg, 76 | input_image_size=seg_mask_size, 77 | fp16=fp16, 78 | ) 79 | 80 | self.fba = FBAMatting( 81 | batch_size=batch_size_matting, 82 | device=device, 83 | input_tensor_size=matting_mask_size, 84 | fp16=fp16, 85 | ) 86 | self.trimap_generator = TrimapGenerator( 87 | prob_threshold=trimap_prob_threshold, 88 | kernel_size=trimap_dilation, 89 | erosion_iters=trimap_erosion_iters, 90 | ) 91 | super(HiInterface, self).__init__( 92 | pre_pipe=None, 93 | seg_pipe=self.u2net, 94 | post_pipe=MattingMethod( 95 | matting_module=self.fba, 96 | trimap_generator=self.trimap_generator, 97 | device=device, 98 | ), 99 | device=device, 100 | ) 101 | -------------------------------------------------------------------------------- /carvekit/api/interface.py: -------------------------------------------------------------------------------- 1 | """ 2 | Source url: https://github.com/OPHoperHPO/image-background-remove-tool 3 | Author: Nikita Selin (OPHoperHPO)[https://github.com/OPHoperHPO]. 4 | License: Apache License 2.0 5 | """ 6 | from pathlib import Path 7 | from typing import Union, List, Optional 8 | 9 | from PIL import Image 10 | 11 | from carvekit.ml.wrap.basnet import BASNET 12 | from carvekit.ml.wrap.deeplab_v3 import DeepLabV3 13 | from carvekit.ml.wrap.u2net import U2NET 14 | from carvekit.ml.wrap.tracer_b7 import TracerUniversalB7 15 | from carvekit.pipelines.preprocessing import PreprocessingStub 16 | from carvekit.pipelines.postprocessing import MattingMethod 17 | from carvekit.utils.image_utils import load_image 18 | from carvekit.utils.mask_utils import apply_mask 19 | from carvekit.utils.pool_utils import thread_pool_processing 20 | 21 | 22 | class Interface: 23 | def __init__( 24 | self, 25 | seg_pipe: Union[U2NET, BASNET, DeepLabV3, TracerUniversalB7], 26 | pre_pipe: Optional[Union[PreprocessingStub]] = None, 27 | post_pipe: Optional[Union[MattingMethod]] = None, 28 | device="cpu", 29 | ): 30 | """ 31 | Initializes an object for interacting with pipelines and other components of the CarveKit framework. 32 | 33 | Args: 34 | pre_pipe: Initialized pre-processing pipeline object 35 | seg_pipe: Initialized segmentation network object 36 | post_pipe: Initialized postprocessing pipeline object 37 | device: The processing device that will be used to apply the masks to the images. 38 | """ 39 | self.device = device 40 | self.preprocessing_pipeline = pre_pipe 41 | self.segmentation_pipeline = seg_pipe 42 | self.postprocessing_pipeline = post_pipe 43 | 44 | def __call__( 45 | self, images: List[Union[str, Path, Image.Image]] 46 | ) -> List[Image.Image]: 47 | """ 48 | Removes the background from the specified images. 
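(Illustrative aside.) A minimal usage sketch for the HiInterface class defined in carvekit/api/high.py above; the input and output paths are placeholders and all keyword values mirror the defaults shown in its constructor:

import torch

from carvekit.api.high import HiInterface

interface = HiInterface(
    object_type="object",            # "object" selects TRACER-B7, "hairs-like" selects U2NET
    batch_size_seg=2,
    batch_size_matting=1,
    device="cuda" if torch.cuda.is_available() else "cpu",
    seg_mask_size=640,
    matting_mask_size=2048,
    trimap_prob_threshold=231,
    trimap_dilation=30,
    trimap_erosion_iters=5,
    fp16=False,
)
images_without_background = interface(["tests/data/cat.jpg"])  # accepts paths or PIL images
images_without_background[0].save("cat_bg_removed.png")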
49 | 50 | Args: 51 | images: list of input images 52 | 53 | Returns: 54 | List of images without background as PIL.Image.Image instances 55 | """ 56 | images = thread_pool_processing(load_image, images) 57 | if self.preprocessing_pipeline is not None: 58 | masks: List[Image.Image] = self.preprocessing_pipeline( 59 | interface=self, images=images 60 | ) 61 | else: 62 | masks: List[Image.Image] = self.segmentation_pipeline(images=images) 63 | 64 | if self.postprocessing_pipeline is not None: 65 | images: List[Image.Image] = self.postprocessing_pipeline( 66 | images=images, masks=masks 67 | ) 68 | else: 69 | images = list( 70 | map( 71 | lambda x: apply_mask( 72 | image=images[x], mask=masks[x], device=self.device 73 | ), 74 | range(len(images)), 75 | ) 76 | ) 77 | return images 78 | -------------------------------------------------------------------------------- /carvekit/ml/__init__.py: -------------------------------------------------------------------------------- 1 | from carvekit.utils.models_utils import fix_seed, suppress_warnings 2 | 3 | fix_seed() 4 | suppress_warnings() 5 | -------------------------------------------------------------------------------- /carvekit/ml/arch/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OPHoperHPO/image-background-remove-tool/f141a311af67fb1da64269c508a6d1f786420801/carvekit/ml/arch/__init__.py -------------------------------------------------------------------------------- /carvekit/ml/arch/basnet/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OPHoperHPO/image-background-remove-tool/f141a311af67fb1da64269c508a6d1f786420801/carvekit/ml/arch/basnet/__init__.py -------------------------------------------------------------------------------- /carvekit/ml/arch/fba_matting/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OPHoperHPO/image-background-remove-tool/f141a311af67fb1da64269c508a6d1f786420801/carvekit/ml/arch/fba_matting/__init__.py -------------------------------------------------------------------------------- /carvekit/ml/arch/fba_matting/layers_WS.py: -------------------------------------------------------------------------------- 1 | """ 2 | Modified by Nikita Selin (OPHoperHPO)[https://github.com/OPHoperHPO]. 
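(Illustrative aside.) For finer control, the lower-level Interface shown above can be assembled by hand from the wrapper classes that appear later in this dump; a sketch using the same components and parameter values that HiInterface wires together, with a placeholder input path:

import PIL.Image

from carvekit.api.interface import Interface
from carvekit.ml.wrap.fba_matting import FBAMatting
from carvekit.ml.wrap.tracer_b7 import TracerUniversalB7
from carvekit.pipelines.postprocessing import MattingMethod
from carvekit.trimap.generator import TrimapGenerator

seg_net = TracerUniversalB7(device="cpu", batch_size=1, input_image_size=640, fp16=False)
fba = FBAMatting(device="cpu", batch_size=1, input_tensor_size=2048, fp16=False)
trimap = TrimapGenerator(prob_threshold=231, kernel_size=30, erosion_iters=5)
interface = Interface(
    seg_pipe=seg_net,
    pre_pipe=None,  # no pre-processing step, exactly as HiInterface does
    post_pipe=MattingMethod(matting_module=fba, trimap_generator=trimap, device="cpu"),
    device="cpu",
)
image = PIL.Image.open("tests/data/cat.jpg")  # placeholder path
interface([image])[0].save("cat_bg_removed.png")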
3 | Source url: https://github.com/MarcoForte/FBA_Matting 4 | License: MIT License 5 | """ 6 | import torch 7 | import torch.nn as nn 8 | from torch.nn import functional as F 9 | 10 | 11 | class Conv2d(nn.Conv2d): 12 | def __init__( 13 | self, 14 | in_channels, 15 | out_channels, 16 | kernel_size, 17 | stride=1, 18 | padding=0, 19 | dilation=1, 20 | groups=1, 21 | bias=True, 22 | ): 23 | super(Conv2d, self).__init__( 24 | in_channels, 25 | out_channels, 26 | kernel_size, 27 | stride, 28 | padding, 29 | dilation, 30 | groups, 31 | bias, 32 | ) 33 | 34 | def forward(self, x): 35 | # return super(Conv2d, self).forward(x) 36 | weight = self.weight 37 | weight_mean = ( 38 | weight.mean(dim=1, keepdim=True) 39 | .mean(dim=2, keepdim=True) 40 | .mean(dim=3, keepdim=True) 41 | ) 42 | weight = weight - weight_mean 43 | # std = (weight).view(weight.size(0), -1).std(dim=1).view(-1, 1, 1, 1) + 1e-5 44 | std = ( 45 | torch.sqrt(torch.var(weight.view(weight.size(0), -1), dim=1) + 1e-12).view( 46 | -1, 1, 1, 1 47 | ) 48 | + 1e-5 49 | ) 50 | weight = weight / std.expand_as(weight) 51 | return F.conv2d( 52 | x, weight, self.bias, self.stride, self.padding, self.dilation, self.groups 53 | ) 54 | 55 | 56 | def BatchNorm2d(num_features): 57 | return nn.GroupNorm(num_channels=num_features, num_groups=32) 58 | -------------------------------------------------------------------------------- /carvekit/ml/arch/fba_matting/resnet_GN_WS.py: -------------------------------------------------------------------------------- 1 | """ 2 | Modified by Nikita Selin (OPHoperHPO)[https://github.com/OPHoperHPO]. 3 | Source url: https://github.com/MarcoForte/FBA_Matting 4 | License: MIT License 5 | """ 6 | import torch.nn as nn 7 | import carvekit.ml.arch.fba_matting.layers_WS as L 8 | 9 | __all__ = ["ResNet", "l_resnet50"] 10 | 11 | 12 | def conv3x3(in_planes, out_planes, stride=1): 13 | """3x3 convolution with padding""" 14 | return L.Conv2d( 15 | in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=False 16 | ) 17 | 18 | 19 | def conv1x1(in_planes, out_planes, stride=1): 20 | """1x1 convolution""" 21 | return L.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False) 22 | 23 | 24 | class BasicBlock(nn.Module): 25 | expansion = 1 26 | 27 | def __init__(self, inplanes, planes, stride=1, downsample=None): 28 | super(BasicBlock, self).__init__() 29 | self.conv1 = conv3x3(inplanes, planes, stride) 30 | self.bn1 = L.BatchNorm2d(planes) 31 | self.relu = nn.ReLU(inplace=True) 32 | self.conv2 = conv3x3(planes, planes) 33 | self.bn2 = L.BatchNorm2d(planes) 34 | self.downsample = downsample 35 | self.stride = stride 36 | 37 | def forward(self, x): 38 | identity = x 39 | 40 | out = self.conv1(x) 41 | out = self.bn1(out) 42 | out = self.relu(out) 43 | 44 | out = self.conv2(out) 45 | out = self.bn2(out) 46 | 47 | if self.downsample is not None: 48 | identity = self.downsample(x) 49 | 50 | out += identity 51 | out = self.relu(out) 52 | 53 | return out 54 | 55 | 56 | class Bottleneck(nn.Module): 57 | expansion = 4 58 | 59 | def __init__(self, inplanes, planes, stride=1, downsample=None): 60 | super(Bottleneck, self).__init__() 61 | self.conv1 = conv1x1(inplanes, planes) 62 | self.bn1 = L.BatchNorm2d(planes) 63 | self.conv2 = conv3x3(planes, planes, stride) 64 | self.bn2 = L.BatchNorm2d(planes) 65 | self.conv3 = conv1x1(planes, planes * self.expansion) 66 | self.bn3 = L.BatchNorm2d(planes * self.expansion) 67 | self.relu = nn.ReLU(inplace=True) 68 | self.downsample = downsample 69 | self.stride = stride 70 | 71 
| def forward(self, x): 72 | identity = x 73 | 74 | out = self.conv1(x) 75 | out = self.bn1(out) 76 | out = self.relu(out) 77 | 78 | out = self.conv2(out) 79 | out = self.bn2(out) 80 | out = self.relu(out) 81 | 82 | out = self.conv3(out) 83 | out = self.bn3(out) 84 | 85 | if self.downsample is not None: 86 | identity = self.downsample(x) 87 | 88 | out += identity 89 | out = self.relu(out) 90 | 91 | return out 92 | 93 | 94 | class ResNet(nn.Module): 95 | def __init__(self, block, layers, num_classes=1000): 96 | super(ResNet, self).__init__() 97 | self.inplanes = 64 98 | self.conv1 = L.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False) 99 | self.bn1 = L.BatchNorm2d(64) 100 | self.relu = nn.ReLU(inplace=True) 101 | self.maxpool = nn.MaxPool2d( 102 | kernel_size=3, stride=2, padding=1, return_indices=True 103 | ) 104 | self.layer1 = self._make_layer(block, 64, layers[0]) 105 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2) 106 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2) 107 | self.layer4 = self._make_layer(block, 512, layers[3], stride=2) 108 | self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) 109 | self.fc = nn.Linear(512 * block.expansion, num_classes) 110 | 111 | def _make_layer(self, block, planes, blocks, stride=1): 112 | downsample = None 113 | if stride != 1 or self.inplanes != planes * block.expansion: 114 | downsample = nn.Sequential( 115 | conv1x1(self.inplanes, planes * block.expansion, stride), 116 | L.BatchNorm2d(planes * block.expansion), 117 | ) 118 | 119 | layers = [] 120 | layers.append(block(self.inplanes, planes, stride, downsample)) 121 | self.inplanes = planes * block.expansion 122 | for _ in range(1, blocks): 123 | layers.append(block(self.inplanes, planes)) 124 | 125 | return nn.Sequential(*layers) 126 | 127 | def forward(self, x): 128 | x = self.conv1(x) 129 | x = self.bn1(x) 130 | x = self.relu(x) 131 | x = self.maxpool(x) 132 | 133 | x = self.layer1(x) 134 | x = self.layer2(x) 135 | x = self.layer3(x) 136 | x = self.layer4(x) 137 | 138 | x = self.avgpool(x) 139 | x = x.view(x.size(0), -1) 140 | x = self.fc(x) 141 | 142 | return x 143 | 144 | 145 | def l_resnet50(pretrained=False, **kwargs): 146 | """Constructs a ResNet-50 model. 147 | Args: 148 | pretrained (bool): If True, returns a model pre-trained on ImageNet 149 | """ 150 | model = ResNet(Bottleneck, [3, 4, 6, 3], **kwargs) 151 | return model 152 | -------------------------------------------------------------------------------- /carvekit/ml/arch/fba_matting/resnet_bn.py: -------------------------------------------------------------------------------- 1 | """ 2 | Modified by Nikita Selin (OPHoperHPO)[https://github.com/OPHoperHPO]. 
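(Illustrative aside.) The Conv2d defined in layers_WS.py above implements Weight Standardization: each filter is re-centered to zero mean and divided by its standard deviation inside forward(), and the module is paired with GroupNorm (the BatchNorm2d factory above) throughout resnet_GN_WS.py. A tiny sketch of using the pair directly, with shapes chosen only for illustration:

import torch

from carvekit.ml.arch.fba_matting.layers_WS import Conv2d, BatchNorm2d

conv = Conv2d(3, 64, kernel_size=3, padding=1)  # standardizes its own weights in forward()
norm = BatchNorm2d(64)                          # actually nn.GroupNorm with 32 groups
x = torch.randn(1, 3, 64, 64)
y = norm(conv(x))
print(y.shape)  # torch.Size([1, 64, 64, 64])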
3 | Source url: https://github.com/MarcoForte/FBA_Matting 4 | License: MIT License 5 | """ 6 | import torch.nn as nn 7 | import math 8 | from torch.nn import BatchNorm2d 9 | 10 | __all__ = ["ResNet"] 11 | 12 | 13 | def conv3x3(in_planes, out_planes, stride=1): 14 | "3x3 convolution with padding" 15 | return nn.Conv2d( 16 | in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=False 17 | ) 18 | 19 | 20 | class BasicBlock(nn.Module): 21 | expansion = 1 22 | 23 | def __init__(self, inplanes, planes, stride=1, downsample=None): 24 | super(BasicBlock, self).__init__() 25 | self.conv1 = conv3x3(inplanes, planes, stride) 26 | self.bn1 = BatchNorm2d(planes) 27 | self.relu = nn.ReLU(inplace=True) 28 | self.conv2 = conv3x3(planes, planes) 29 | self.bn2 = BatchNorm2d(planes) 30 | self.downsample = downsample 31 | self.stride = stride 32 | 33 | def forward(self, x): 34 | residual = x 35 | 36 | out = self.conv1(x) 37 | out = self.bn1(out) 38 | out = self.relu(out) 39 | 40 | out = self.conv2(out) 41 | out = self.bn2(out) 42 | 43 | if self.downsample is not None: 44 | residual = self.downsample(x) 45 | 46 | out += residual 47 | out = self.relu(out) 48 | 49 | return out 50 | 51 | 52 | class Bottleneck(nn.Module): 53 | expansion = 4 54 | 55 | def __init__(self, inplanes, planes, stride=1, downsample=None): 56 | super(Bottleneck, self).__init__() 57 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False) 58 | self.bn1 = BatchNorm2d(planes) 59 | self.conv2 = nn.Conv2d( 60 | planes, planes, kernel_size=3, stride=stride, padding=1, bias=False 61 | ) 62 | self.bn2 = BatchNorm2d(planes, momentum=0.01) 63 | self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False) 64 | self.bn3 = BatchNorm2d(planes * 4) 65 | self.relu = nn.ReLU(inplace=True) 66 | self.downsample = downsample 67 | self.stride = stride 68 | 69 | def forward(self, x): 70 | residual = x 71 | 72 | out = self.conv1(x) 73 | out = self.bn1(out) 74 | out = self.relu(out) 75 | 76 | out = self.conv2(out) 77 | out = self.bn2(out) 78 | out = self.relu(out) 79 | 80 | out = self.conv3(out) 81 | out = self.bn3(out) 82 | 83 | if self.downsample is not None: 84 | residual = self.downsample(x) 85 | 86 | out += residual 87 | out = self.relu(out) 88 | 89 | return out 90 | 91 | 92 | class ResNet(nn.Module): 93 | def __init__(self, block, layers, num_classes=1000): 94 | self.inplanes = 128 95 | super(ResNet, self).__init__() 96 | self.conv1 = conv3x3(3, 64, stride=2) 97 | self.bn1 = BatchNorm2d(64) 98 | self.relu1 = nn.ReLU(inplace=True) 99 | self.conv2 = conv3x3(64, 64) 100 | self.bn2 = BatchNorm2d(64) 101 | self.relu2 = nn.ReLU(inplace=True) 102 | self.conv3 = conv3x3(64, 128) 103 | self.bn3 = BatchNorm2d(128) 104 | self.relu3 = nn.ReLU(inplace=True) 105 | self.maxpool = nn.MaxPool2d( 106 | kernel_size=3, stride=2, padding=1, return_indices=True 107 | ) 108 | 109 | self.layer1 = self._make_layer(block, 64, layers[0]) 110 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2) 111 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2) 112 | self.layer4 = self._make_layer(block, 512, layers[3], stride=2) 113 | self.avgpool = nn.AvgPool2d(7, stride=1) 114 | self.fc = nn.Linear(512 * block.expansion, num_classes) 115 | 116 | for m in self.modules(): 117 | if isinstance(m, nn.Conv2d): 118 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 119 | m.weight.data.normal_(0, math.sqrt(2.0 / n)) 120 | elif isinstance(m, BatchNorm2d): 121 | m.weight.data.fill_(1) 122 | m.bias.data.zero_() 123 | 124 | def 
_make_layer(self, block, planes, blocks, stride=1): 125 | downsample = None 126 | if stride != 1 or self.inplanes != planes * block.expansion: 127 | downsample = nn.Sequential( 128 | nn.Conv2d( 129 | self.inplanes, 130 | planes * block.expansion, 131 | kernel_size=1, 132 | stride=stride, 133 | bias=False, 134 | ), 135 | BatchNorm2d(planes * block.expansion), 136 | ) 137 | 138 | layers = [] 139 | layers.append(block(self.inplanes, planes, stride, downsample)) 140 | self.inplanes = planes * block.expansion 141 | for i in range(1, blocks): 142 | layers.append(block(self.inplanes, planes)) 143 | 144 | return nn.Sequential(*layers) 145 | 146 | def forward(self, x): 147 | x = self.relu1(self.bn1(self.conv1(x))) 148 | x = self.relu2(self.bn2(self.conv2(x))) 149 | x = self.relu3(self.bn3(self.conv3(x))) 150 | x, indices = self.maxpool(x) 151 | 152 | x = self.layer1(x) 153 | x = self.layer2(x) 154 | x = self.layer3(x) 155 | x = self.layer4(x) 156 | 157 | x = self.avgpool(x) 158 | x = x.view(x.size(0), -1) 159 | x = self.fc(x) 160 | return x 161 | 162 | 163 | def l_resnet50(): 164 | """Constructs a ResNet-50 model. 165 | Args: 166 | pretrained (bool): If True, returns a model pre-trained on ImageNet 167 | """ 168 | model = ResNet(Bottleneck, [3, 4, 6, 3]) 169 | return model 170 | -------------------------------------------------------------------------------- /carvekit/ml/arch/fba_matting/transforms.py: -------------------------------------------------------------------------------- 1 | """ 2 | Modified by Nikita Selin (OPHoperHPO)[https://github.com/OPHoperHPO]. 3 | Source url: https://github.com/MarcoForte/FBA_Matting 4 | License: MIT License 5 | """ 6 | import cv2 7 | import numpy as np 8 | 9 | group_norm_std = [0.229, 0.224, 0.225] 10 | group_norm_mean = [0.485, 0.456, 0.406] 11 | 12 | 13 | def dt(a): 14 | return cv2.distanceTransform((a * 255).astype(np.uint8), cv2.DIST_L2, 0) 15 | 16 | 17 | def trimap_transform(trimap): 18 | h, w = trimap.shape[0], trimap.shape[1] 19 | 20 | clicks = np.zeros((h, w, 6)) 21 | for k in range(2): 22 | if np.count_nonzero(trimap[:, :, k]) > 0: 23 | dt_mask = -dt(1 - trimap[:, :, k]) ** 2 24 | L = 320 25 | clicks[:, :, 3 * k] = np.exp(dt_mask / (2 * ((0.02 * L) ** 2))) 26 | clicks[:, :, 3 * k + 1] = np.exp(dt_mask / (2 * ((0.08 * L) ** 2))) 27 | clicks[:, :, 3 * k + 2] = np.exp(dt_mask / (2 * ((0.16 * L) ** 2))) 28 | 29 | return clicks 30 | 31 | 32 | def groupnorm_normalise_image(img, format="nhwc"): 33 | """ 34 | Accept rgb in range 0,1 35 | """ 36 | if format == "nhwc": 37 | for i in range(3): 38 | img[..., i] = (img[..., i] - group_norm_mean[i]) / group_norm_std[i] 39 | else: 40 | for i in range(3): 41 | img[..., i, :, :] = ( 42 | img[..., i, :, :] - group_norm_mean[i] 43 | ) / group_norm_std[i] 44 | 45 | return img 46 | -------------------------------------------------------------------------------- /carvekit/ml/arch/tracerb7/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OPHoperHPO/image-background-remove-tool/f141a311af67fb1da64269c508a6d1f786420801/carvekit/ml/arch/tracerb7/__init__.py -------------------------------------------------------------------------------- /carvekit/ml/arch/tracerb7/conv_modules.py: -------------------------------------------------------------------------------- 1 | """ 2 | Source url: https://github.com/Karel911/TRACER 3 | Author: Min Seok Lee and Wooseok Shin 4 | License: Apache License 2.0 5 | """ 6 | import torch.nn as nn 7 | 8 | 9 | class 
BasicConv2d(nn.Module): 10 | def __init__( 11 | self, 12 | in_channel, 13 | out_channel, 14 | kernel_size, 15 | stride=(1, 1), 16 | padding=(0, 0), 17 | dilation=(1, 1), 18 | ): 19 | super(BasicConv2d, self).__init__() 20 | self.conv = nn.Conv2d( 21 | in_channel, 22 | out_channel, 23 | kernel_size=kernel_size, 24 | stride=stride, 25 | padding=padding, 26 | dilation=dilation, 27 | bias=False, 28 | ) 29 | self.bn = nn.BatchNorm2d(out_channel) 30 | self.selu = nn.SELU() 31 | 32 | def forward(self, x): 33 | x = self.conv(x) 34 | x = self.bn(x) 35 | x = self.selu(x) 36 | 37 | return x 38 | 39 | 40 | class DWConv(nn.Module): 41 | def __init__(self, in_channel, out_channel, kernel, dilation, padding): 42 | super(DWConv, self).__init__() 43 | self.out_channel = out_channel 44 | self.DWConv = nn.Conv2d( 45 | in_channel, 46 | out_channel, 47 | kernel_size=kernel, 48 | padding=padding, 49 | groups=in_channel, 50 | dilation=dilation, 51 | bias=False, 52 | ) 53 | self.bn = nn.BatchNorm2d(out_channel) 54 | self.selu = nn.SELU() 55 | 56 | def forward(self, x): 57 | x = self.DWConv(x) 58 | out = self.selu(self.bn(x)) 59 | 60 | return out 61 | 62 | 63 | class DWSConv(nn.Module): 64 | def __init__(self, in_channel, out_channel, kernel, padding, kernels_per_layer): 65 | super(DWSConv, self).__init__() 66 | self.out_channel = out_channel 67 | self.DWConv = nn.Conv2d( 68 | in_channel, 69 | in_channel * kernels_per_layer, 70 | kernel_size=kernel, 71 | padding=padding, 72 | groups=in_channel, 73 | bias=False, 74 | ) 75 | self.bn = nn.BatchNorm2d(in_channel * kernels_per_layer) 76 | self.selu = nn.SELU() 77 | self.PWConv = nn.Conv2d( 78 | in_channel * kernels_per_layer, out_channel, kernel_size=1, bias=False 79 | ) 80 | self.bn2 = nn.BatchNorm2d(out_channel) 81 | 82 | def forward(self, x): 83 | x = self.DWConv(x) 84 | x = self.selu(self.bn(x)) 85 | out = self.PWConv(x) 86 | out = self.selu(self.bn2(out)) 87 | 88 | return out 89 | -------------------------------------------------------------------------------- /carvekit/ml/arch/tracerb7/tracer.py: -------------------------------------------------------------------------------- 1 | """ 2 | Source url: https://github.com/Karel911/TRACER 3 | Author: Min Seok Lee and Wooseok Shin 4 | Modified by Nikita Selin (OPHoperHPO)[https://github.com/OPHoperHPO]. 5 | License: Apache License 2.0 6 | Changes: 7 | - Refactored code 8 | - Removed unused code 9 | - Added comments 10 | """ 11 | 12 | import torch 13 | import torch.nn as nn 14 | import torch.nn.functional as F 15 | from typing import List, Optional, Tuple 16 | 17 | from torch import Tensor 18 | 19 | from carvekit.ml.arch.tracerb7.efficientnet import EfficientEncoderB7 20 | from carvekit.ml.arch.tracerb7.att_modules import ( 21 | RFB_Block, 22 | aggregation, 23 | ObjectAttention, 24 | ) 25 | 26 | 27 | class TracerDecoder(nn.Module): 28 | """Tracer Decoder""" 29 | 30 | def __init__( 31 | self, 32 | encoder: EfficientEncoderB7, 33 | features_channels: Optional[List[int]] = None, 34 | rfb_channel: Optional[List[int]] = None, 35 | ): 36 | """ 37 | Initialize the tracer decoder. 38 | 39 | Args: 40 | encoder: The encoder to use. 41 | features_channels: The channels of the backbone features at different stages. default: [48, 80, 224, 640] 42 | rfb_channel: The channels of the RFB features. 
default: [32, 64, 128] 43 | """ 44 | super().__init__() 45 | if rfb_channel is None: 46 | rfb_channel = [32, 64, 128] 47 | if features_channels is None: 48 | features_channels = [48, 80, 224, 640] 49 | self.encoder = encoder 50 | self.features_channels = features_channels 51 | 52 | # Receptive Field Blocks 53 | features_channels = rfb_channel 54 | self.rfb2 = RFB_Block(self.features_channels[1], features_channels[0]) 55 | self.rfb3 = RFB_Block(self.features_channels[2], features_channels[1]) 56 | self.rfb4 = RFB_Block(self.features_channels[3], features_channels[2]) 57 | 58 | # Multi-level aggregation 59 | self.agg = aggregation(features_channels) 60 | 61 | # Object Attention 62 | self.ObjectAttention2 = ObjectAttention( 63 | channel=self.features_channels[1], kernel_size=3 64 | ) 65 | self.ObjectAttention1 = ObjectAttention( 66 | channel=self.features_channels[0], kernel_size=3 67 | ) 68 | 69 | def forward(self, inputs: torch.Tensor) -> Tensor: 70 | """ 71 | Forward pass of the tracer decoder. 72 | 73 | Args: 74 | inputs: Preprocessed images. 75 | 76 | Returns: 77 | Tensors of segmentation masks and mask of object edges. 78 | """ 79 | features = self.encoder(inputs) 80 | x3_rfb = self.rfb2(features[1]) 81 | x4_rfb = self.rfb3(features[2]) 82 | x5_rfb = self.rfb4(features[3]) 83 | 84 | D_0 = self.agg(x5_rfb, x4_rfb, x3_rfb) 85 | 86 | ds_map0 = F.interpolate(D_0, scale_factor=8, mode="bilinear") 87 | 88 | D_1 = self.ObjectAttention2(D_0, features[1]) 89 | ds_map1 = F.interpolate(D_1, scale_factor=8, mode="bilinear") 90 | 91 | ds_map = F.interpolate(D_1, scale_factor=2, mode="bilinear") 92 | D_2 = self.ObjectAttention1(ds_map, features[0]) 93 | ds_map2 = F.interpolate(D_2, scale_factor=4, mode="bilinear") 94 | 95 | final_map = (ds_map2 + ds_map1 + ds_map0) / 3 96 | 97 | return torch.sigmoid(final_map) 98 | -------------------------------------------------------------------------------- /carvekit/ml/arch/u2net/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OPHoperHPO/image-background-remove-tool/f141a311af67fb1da64269c508a6d1f786420801/carvekit/ml/arch/u2net/__init__.py -------------------------------------------------------------------------------- /carvekit/ml/files/__init__.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | 3 | carvekit_dir = Path.home().joinpath(".cache/carvekit") 4 | 5 | carvekit_dir.mkdir(parents=True, exist_ok=True) 6 | 7 | checkpoints_dir = carvekit_dir.joinpath("checkpoints") 8 | -------------------------------------------------------------------------------- /carvekit/ml/files/models_loc.py: -------------------------------------------------------------------------------- 1 | """ 2 | Source url: https://github.com/OPHoperHPO/image-background-remove-tool 3 | Author: Nikita Selin (OPHoperHPO)[https://github.com/OPHoperHPO]. 
4 | License: Apache License 2.0 5 | """ 6 | import pathlib 7 | from carvekit.ml.files import checkpoints_dir 8 | from carvekit.utils.download_models import downloader 9 | 10 | 11 | def u2net_full_pretrained() -> pathlib.Path: 12 | """Returns u2net pretrained model location 13 | 14 | Returns: 15 | pathlib.Path to model location 16 | """ 17 | return downloader("u2net.pth") 18 | 19 | 20 | def basnet_pretrained() -> pathlib.Path: 21 | """Returns basnet pretrained model location 22 | 23 | Returns: 24 | pathlib.Path to model location 25 | """ 26 | return downloader("basnet.pth") 27 | 28 | 29 | def deeplab_pretrained() -> pathlib.Path: 30 | """Returns basnet pretrained model location 31 | 32 | Returns: 33 | pathlib.Path to model location 34 | """ 35 | return downloader("deeplab.pth") 36 | 37 | 38 | def fba_pretrained() -> pathlib.Path: 39 | """Returns basnet pretrained model location 40 | 41 | Returns: 42 | pathlib.Path to model location 43 | """ 44 | return downloader("fba_matting.pth") 45 | 46 | 47 | def tracer_b7_pretrained() -> pathlib.Path: 48 | """Returns TRACER with EfficientNet v1 b7 encoder pretrained model location 49 | 50 | Returns: 51 | pathlib.Path to model location 52 | """ 53 | return downloader("tracer_b7.pth") 54 | 55 | 56 | def tracer_hair_pretrained() -> pathlib.Path: 57 | """Returns TRACER with EfficientNet v1 b7 encoder model for hair segmentation location 58 | 59 | Returns: 60 | pathlib.Path to model location 61 | """ 62 | return downloader("tracer_hair.pth") 63 | 64 | 65 | def download_all(): 66 | u2net_full_pretrained() 67 | fba_pretrained() 68 | deeplab_pretrained() 69 | basnet_pretrained() 70 | tracer_b7_pretrained() 71 | -------------------------------------------------------------------------------- /carvekit/ml/wrap/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OPHoperHPO/image-background-remove-tool/f141a311af67fb1da64269c508a6d1f786420801/carvekit/ml/wrap/__init__.py -------------------------------------------------------------------------------- /carvekit/ml/wrap/basnet.py: -------------------------------------------------------------------------------- 1 | """ 2 | Source url: https://github.com/OPHoperHPO/image-background-remove-tool 3 | Author: Nikita Selin (OPHoperHPO)[https://github.com/OPHoperHPO]. 
4 | License: Apache License 2.0 5 | """ 6 | import pathlib 7 | from typing import Union, List 8 | 9 | import PIL 10 | import numpy as np 11 | import torch 12 | from PIL import Image 13 | 14 | from carvekit.ml.arch.basnet.basnet import BASNet 15 | from carvekit.ml.files.models_loc import basnet_pretrained 16 | from carvekit.utils.image_utils import convert_image, load_image 17 | from carvekit.utils.pool_utils import batch_generator, thread_pool_processing 18 | 19 | __all__ = ["BASNET"] 20 | 21 | 22 | class BASNET(BASNet): 23 | """BASNet model interface""" 24 | 25 | def __init__( 26 | self, 27 | device="cpu", 28 | input_image_size: Union[List[int], int] = 320, 29 | batch_size: int = 10, 30 | load_pretrained: bool = True, 31 | fp16: bool = False, 32 | ): 33 | """ 34 | Initialize the BASNET model 35 | 36 | Args: 37 | device: processing device 38 | input_image_size: input image size 39 | batch_size: the number of images that the neural network processes in one run 40 | load_pretrained: loading pretrained model 41 | fp16: use fp16 precision // not supported at this moment 42 | 43 | """ 44 | super(BASNET, self).__init__(n_channels=3, n_classes=1) 45 | self.device = device 46 | self.batch_size = batch_size 47 | if isinstance(input_image_size, list): 48 | self.input_image_size = input_image_size[:2] 49 | else: 50 | self.input_image_size = (input_image_size, input_image_size) 51 | self.to(device) 52 | if load_pretrained: 53 | self.load_state_dict( 54 | torch.load(basnet_pretrained(), map_location=self.device) 55 | ) 56 | self.eval() 57 | 58 | def data_preprocessing(self, data: PIL.Image.Image) -> torch.Tensor: 59 | """ 60 | Transform input image to suitable data format for neural network 61 | 62 | Args: 63 | data: input image 64 | 65 | Returns: 66 | input for neural network 67 | 68 | """ 69 | resized = data.resize(self.input_image_size) 70 | # noinspection PyTypeChecker 71 | resized_arr = np.array(resized, dtype=np.float64) 72 | temp_image = np.zeros((resized_arr.shape[0], resized_arr.shape[1], 3)) 73 | if np.max(resized_arr) != 0: 74 | resized_arr /= np.max(resized_arr) 75 | temp_image[:, :, 0] = (resized_arr[:, :, 0] - 0.485) / 0.229 76 | temp_image[:, :, 1] = (resized_arr[:, :, 1] - 0.456) / 0.224 77 | temp_image[:, :, 2] = (resized_arr[:, :, 2] - 0.406) / 0.225 78 | temp_image = temp_image.transpose((2, 0, 1)) 79 | temp_image = np.expand_dims(temp_image, 0) 80 | return torch.from_numpy(temp_image).type(torch.FloatTensor) 81 | 82 | @staticmethod 83 | def data_postprocessing( 84 | data: torch.tensor, original_image: PIL.Image.Image 85 | ) -> PIL.Image.Image: 86 | """ 87 | Transforms output data from neural network to suitable data 88 | format for using with other components of this framework. 
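(Illustrative aside.) Each segmentation wrapper can also be used on its own; a minimal sketch for the BASNET wrapper whose constructor is shown above (its __call__, shown further below, accepts file paths or PIL images and returns PIL mask images), with a placeholder input path. With load_pretrained left at its default, the basnet.pth checkpoint is downloaded on first use if it is not already cached:

from carvekit.ml.wrap.basnet import BASNET

basnet = BASNET(device="cpu", input_image_size=320, batch_size=1, fp16=False)
masks = basnet(["tests/data/cat.jpg"])  # list of L-mode PIL.Image masks, one per input
masks[0].save("cat_mask.png")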
89 | 90 | Args: 91 | data: output data from neural network 92 | original_image: input image which was used for predicted data 93 | 94 | Returns: 95 | Segmentation mask as PIL Image instance 96 | 97 | """ 98 | data = data.unsqueeze(0) 99 | mask = data[:, 0, :, :] 100 | ma = torch.max(mask) # Normalizes prediction 101 | mi = torch.min(mask) 102 | predict = ((mask - mi) / (ma - mi)).squeeze() 103 | predict_np = predict.cpu().data.numpy() * 255 104 | mask = Image.fromarray(predict_np).convert("L") 105 | mask = mask.resize(original_image.size, resample=3) 106 | return mask 107 | 108 | def __call__( 109 | self, images: List[Union[str, pathlib.Path, PIL.Image.Image]] 110 | ) -> List[PIL.Image.Image]: 111 | """ 112 | Passes input images through neural network and returns segmentation masks as PIL.Image.Image instances 113 | 114 | Args: 115 | images: input images 116 | 117 | Returns: 118 | segmentation masks as for input images, as PIL.Image.Image instances 119 | 120 | """ 121 | collect_masks = [] 122 | for image_batch in batch_generator(images, self.batch_size): 123 | images = thread_pool_processing( 124 | lambda x: convert_image(load_image(x)), image_batch 125 | ) 126 | batches = torch.vstack( 127 | thread_pool_processing(self.data_preprocessing, images) 128 | ) 129 | with torch.no_grad(): 130 | batches = batches.to(self.device) 131 | masks, d2, d3, d4, d5, d6, d7, d8 = super(BASNET, self).__call__( 132 | batches 133 | ) 134 | masks_cpu = masks.cpu() 135 | del d2, d3, d4, d5, d6, d7, d8, batches, masks 136 | masks = thread_pool_processing( 137 | lambda x: self.data_postprocessing(masks_cpu[x], images[x]), 138 | range(len(images)), 139 | ) 140 | collect_masks += masks 141 | return collect_masks 142 | -------------------------------------------------------------------------------- /carvekit/ml/wrap/deeplab_v3.py: -------------------------------------------------------------------------------- 1 | """ 2 | Source url: https://github.com/OPHoperHPO/image-background-remove-tool 3 | Author: Nikita Selin (OPHoperHPO)[https://github.com/OPHoperHPO]. 
4 | License: Apache License 2.0 5 | """ 6 | import pathlib 7 | from typing import List, Union 8 | 9 | import PIL.Image 10 | import torch 11 | from PIL import Image 12 | from torchvision import transforms 13 | from torchvision.models.segmentation import deeplabv3_resnet101 14 | from carvekit.ml.files.models_loc import deeplab_pretrained 15 | from carvekit.utils.image_utils import convert_image, load_image 16 | from carvekit.utils.models_utils import get_precision_autocast, cast_network 17 | from carvekit.utils.pool_utils import batch_generator, thread_pool_processing 18 | 19 | __all__ = ["DeepLabV3"] 20 | 21 | 22 | class DeepLabV3: 23 | def __init__( 24 | self, 25 | device="cpu", 26 | batch_size: int = 10, 27 | input_image_size: Union[List[int], int] = 1024, 28 | load_pretrained: bool = True, 29 | fp16: bool = False, 30 | ): 31 | """ 32 | Initialize the DeepLabV3 model 33 | 34 | Args: 35 | device: processing device 36 | input_image_size: input image size 37 | batch_size: the number of images that the neural network processes in one run 38 | load_pretrained: loading pretrained model 39 | fp16: use half precision 40 | 41 | """ 42 | self.device = device 43 | self.batch_size = batch_size 44 | self.network = deeplabv3_resnet101( 45 | pretrained=False, pretrained_backbone=False, aux_loss=True 46 | ) 47 | self.network.to(self.device) 48 | if load_pretrained: 49 | self.network.load_state_dict( 50 | torch.load(deeplab_pretrained(), map_location=self.device) 51 | ) 52 | if isinstance(input_image_size, list): 53 | self.input_image_size = input_image_size[:2] 54 | else: 55 | self.input_image_size = (input_image_size, input_image_size) 56 | self.network.eval() 57 | self.fp16 = fp16 58 | self.transform = transforms.Compose( 59 | [ 60 | transforms.ToTensor(), 61 | transforms.Normalize( 62 | mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225] 63 | ), 64 | ] 65 | ) 66 | 67 | def to(self, device: str): 68 | """ 69 | Moves neural network to specified processing device 70 | 71 | Args: 72 | device (:class:`torch.device`): the desired device. 73 | Returns: 74 | None 75 | 76 | """ 77 | self.network.to(device) 78 | 79 | def data_preprocessing(self, data: PIL.Image.Image) -> torch.Tensor: 80 | """ 81 | Transform input image to suitable data format for neural network 82 | 83 | Args: 84 | data: input image 85 | 86 | Returns: 87 | input for neural network 88 | 89 | """ 90 | copy = data.copy() 91 | copy.thumbnail(self.input_image_size, resample=3) 92 | return self.transform(copy) 93 | 94 | @staticmethod 95 | def data_postprocessing( 96 | data: torch.tensor, original_image: PIL.Image.Image 97 | ) -> PIL.Image.Image: 98 | """ 99 | Transforms output data from neural network to suitable data 100 | format for using with other components of this framework. 
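        This conversion is normally invoked for you by __call__; as a rough end-to-end sketch (the file path, device and sizes below are example values, not requirements):

            >>> from carvekit.ml.wrap.deeplab_v3 import DeepLabV3
            >>> seg = DeepLabV3(device="cpu", batch_size=1, input_image_size=1024)
            >>> masks = seg(["docs/imgs/input/1.jpg"])  # one PIL "L" mask per input image
            >>> masks[0].size  # matches the source image size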
101 | 102 | Args: 103 | data: output data from neural network 104 | original_image: input image which was used for predicted data 105 | 106 | Returns: 107 | Segmentation mask as PIL Image instance 108 | 109 | """ 110 | return ( 111 | Image.fromarray(data.numpy() * 255).convert("L").resize(original_image.size) 112 | ) 113 | 114 | def __call__( 115 | self, images: List[Union[str, pathlib.Path, PIL.Image.Image]] 116 | ) -> List[PIL.Image.Image]: 117 | """ 118 | Passes input images though neural network and returns segmentation masks as PIL.Image.Image instances 119 | 120 | Args: 121 | images: input images 122 | 123 | Returns: 124 | segmentation masks as for input images, as PIL.Image.Image instances 125 | 126 | """ 127 | collect_masks = [] 128 | autocast, dtype = get_precision_autocast(device=self.device, fp16=self.fp16) 129 | with autocast: 130 | cast_network(self.network, dtype) 131 | for image_batch in batch_generator(images, self.batch_size): 132 | images = thread_pool_processing( 133 | lambda x: convert_image(load_image(x)), image_batch 134 | ) 135 | batches = thread_pool_processing(self.data_preprocessing, images) 136 | with torch.no_grad(): 137 | masks = [ 138 | self.network(i.to(self.device).unsqueeze(0))["out"][0] 139 | .argmax(0) 140 | .byte() 141 | .cpu() 142 | for i in batches 143 | ] 144 | del batches 145 | masks = thread_pool_processing( 146 | lambda x: self.data_postprocessing(masks[x], images[x]), 147 | range(len(images)), 148 | ) 149 | collect_masks += masks 150 | return collect_masks 151 | -------------------------------------------------------------------------------- /carvekit/ml/wrap/tracer_b7.py: -------------------------------------------------------------------------------- 1 | """ 2 | Source url: https://github.com/OPHoperHPO/image-background-remove-tool 3 | Author: Nikita Selin (OPHoperHPO)[https://github.com/OPHoperHPO]. 
4 | License: Apache License 2.0 5 | """ 6 | import pathlib 7 | import warnings 8 | from typing import List, Union 9 | import PIL.Image 10 | import numpy as np 11 | import torch 12 | import torchvision.transforms as transforms 13 | from PIL import Image 14 | 15 | from carvekit.ml.arch.tracerb7.tracer import TracerDecoder 16 | from carvekit.ml.arch.tracerb7.efficientnet import EfficientEncoderB7 17 | from carvekit.ml.files.models_loc import tracer_b7_pretrained, tracer_hair_pretrained 18 | from carvekit.utils.models_utils import get_precision_autocast, cast_network 19 | from carvekit.utils.image_utils import load_image, convert_image 20 | from carvekit.utils.pool_utils import thread_pool_processing, batch_generator 21 | 22 | __all__ = ["TracerUniversalB7"] 23 | 24 | 25 | class TracerUniversalB7(TracerDecoder): 26 | """TRACER B7 model interface""" 27 | 28 | def __init__( 29 | self, 30 | device="cpu", 31 | input_image_size: Union[List[int], int] = 640, 32 | batch_size: int = 4, 33 | load_pretrained: bool = True, 34 | fp16: bool = False, 35 | model_path: Union[str, pathlib.Path] = None, 36 | ): 37 | """ 38 | Initialize the U2NET model 39 | 40 | Args: 41 | layers_cfg: neural network layers configuration 42 | device: processing device 43 | input_image_size: input image size 44 | batch_size: the number of images that the neural network processes in one run 45 | load_pretrained: loading pretrained model 46 | fp16: use fp16 precision 47 | 48 | """ 49 | if model_path is None: 50 | model_path = tracer_b7_pretrained() 51 | super(TracerUniversalB7, self).__init__( 52 | encoder=EfficientEncoderB7(), 53 | rfb_channel=[32, 64, 128], 54 | features_channels=[48, 80, 224, 640], 55 | ) 56 | 57 | self.fp16 = fp16 58 | self.device = device 59 | self.batch_size = batch_size 60 | if isinstance(input_image_size, list): 61 | self.input_image_size = input_image_size[:2] 62 | else: 63 | self.input_image_size = (input_image_size, input_image_size) 64 | 65 | self.transform = transforms.Compose( 66 | [ 67 | transforms.ToTensor(), 68 | transforms.Resize(self.input_image_size), 69 | transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]), 70 | ] 71 | ) 72 | self.to(device) 73 | if load_pretrained: 74 | # TODO remove edge detector from weights. It doesn't work well with this model! 75 | self.load_state_dict( 76 | torch.load(model_path, map_location=self.device), strict=False 77 | ) 78 | self.eval() 79 | 80 | def data_preprocessing(self, data: PIL.Image.Image) -> torch.FloatTensor: 81 | """ 82 | Transform input image to suitable data format for neural network 83 | 84 | Args: 85 | data: input image 86 | 87 | Returns: 88 | input for neural network 89 | 90 | """ 91 | 92 | return torch.unsqueeze(self.transform(data), 0).type(torch.FloatTensor) 93 | 94 | @staticmethod 95 | def data_postprocessing( 96 | data: torch.tensor, original_image: PIL.Image.Image 97 | ) -> PIL.Image.Image: 98 | """ 99 | Transforms output data from neural network to suitable data 100 | format for using with other components of this framework. 
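        This step runs automatically inside __call__; a hedged usage sketch (path, device and sizes are placeholders chosen for the example):

            >>> from carvekit.ml.wrap.tracer_b7 import TracerUniversalB7
            >>> seg = TracerUniversalB7(device="cpu", batch_size=1, input_image_size=640)
            >>> mask = seg(["docs/imgs/input/2.jpg"])[0]  # PIL "L" mask, same size as the input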
101 | 102 | Args: 103 | data: output data from neural network 104 | original_image: input image which was used for predicted data 105 | 106 | Returns: 107 | Segmentation mask as PIL Image instance 108 | 109 | """ 110 | output = (data.type(torch.FloatTensor).detach().cpu().numpy() * 255.0).astype( 111 | np.uint8 112 | ) 113 | output = output.squeeze(0) 114 | mask = Image.fromarray(output).convert("L") 115 | mask = mask.resize(original_image.size, resample=Image.BILINEAR) 116 | return mask 117 | 118 | def __call__( 119 | self, images: List[Union[str, pathlib.Path, PIL.Image.Image]] 120 | ) -> List[PIL.Image.Image]: 121 | """ 122 | Passes input images though neural network and returns segmentation masks as PIL.Image.Image instances 123 | 124 | Args: 125 | images: input images 126 | 127 | Returns: 128 | segmentation masks as for input images, as PIL.Image.Image instances 129 | 130 | """ 131 | collect_masks = [] 132 | autocast, dtype = get_precision_autocast(device=self.device, fp16=self.fp16) 133 | with autocast: 134 | cast_network(self, dtype) 135 | for image_batch in batch_generator(images, self.batch_size): 136 | images = thread_pool_processing( 137 | lambda x: convert_image(load_image(x)), image_batch 138 | ) 139 | batches = torch.vstack( 140 | thread_pool_processing(self.data_preprocessing, images) 141 | ) 142 | with torch.no_grad(): 143 | batches = batches.to(self.device) 144 | masks = super(TracerDecoder, self).__call__(batches) 145 | masks_cpu = masks.cpu() 146 | del batches, masks 147 | masks = thread_pool_processing( 148 | lambda x: self.data_postprocessing(masks_cpu[x], images[x]), 149 | range(len(images)), 150 | ) 151 | collect_masks += masks 152 | 153 | return collect_masks 154 | 155 | 156 | class TracerHair(TracerUniversalB7): 157 | """TRACER HAIR model interface""" 158 | 159 | def __init__( 160 | self, 161 | device="cpu", 162 | input_image_size: Union[List[int], int] = 640, 163 | batch_size: int = 4, 164 | load_pretrained: bool = True, 165 | fp16: bool = False, 166 | model_path: Union[str, pathlib.Path] = None, 167 | ): 168 | if model_path is None: 169 | model_path = tracer_hair_pretrained() 170 | warnings.warn("TracerHair has not public model yet. Don't use it!", UserWarning) 171 | super(TracerHair, self).__init__( 172 | device=device, 173 | input_image_size=input_image_size, 174 | batch_size=batch_size, 175 | load_pretrained=load_pretrained, 176 | fp16=fp16, 177 | model_path=model_path, 178 | ) 179 | -------------------------------------------------------------------------------- /carvekit/ml/wrap/u2net.py: -------------------------------------------------------------------------------- 1 | """ 2 | Source url: https://github.com/OPHoperHPO/image-background-remove-tool 3 | Author: Nikita Selin (OPHoperHPO)[https://github.com/OPHoperHPO]. 
4 | License: Apache License 2.0 5 | """ 6 | import pathlib 7 | from typing import List, Union 8 | import PIL.Image 9 | import numpy as np 10 | import torch 11 | from PIL import Image 12 | 13 | from carvekit.ml.arch.u2net.u2net import U2NETArchitecture 14 | from carvekit.ml.files.models_loc import u2net_full_pretrained 15 | from carvekit.utils.image_utils import load_image, convert_image 16 | from carvekit.utils.pool_utils import thread_pool_processing, batch_generator 17 | 18 | __all__ = ["U2NET"] 19 | 20 | 21 | class U2NET(U2NETArchitecture): 22 | """U^2-Net model interface""" 23 | 24 | def __init__( 25 | self, 26 | layers_cfg="full", 27 | device="cpu", 28 | input_image_size: Union[List[int], int] = 320, 29 | batch_size: int = 10, 30 | load_pretrained: bool = True, 31 | fp16: bool = False, 32 | ): 33 | """ 34 | Initialize the U2NET model 35 | 36 | Args: 37 | layers_cfg: neural network layers configuration 38 | device: processing device 39 | input_image_size: input image size 40 | batch_size: the number of images that the neural network processes in one run 41 | load_pretrained: loading pretrained model 42 | fp16: use fp16 precision // not supported at this moment. 43 | 44 | """ 45 | super(U2NET, self).__init__(cfg_type=layers_cfg, out_ch=1) 46 | self.device = device 47 | self.batch_size = batch_size 48 | if isinstance(input_image_size, list): 49 | self.input_image_size = input_image_size[:2] 50 | else: 51 | self.input_image_size = (input_image_size, input_image_size) 52 | self.to(device) 53 | if load_pretrained: 54 | self.load_state_dict( 55 | torch.load(u2net_full_pretrained(), map_location=self.device) 56 | ) 57 | self.eval() 58 | 59 | def data_preprocessing(self, data: PIL.Image.Image) -> torch.FloatTensor: 60 | """ 61 | Transform input image to suitable data format for neural network 62 | 63 | Args: 64 | data: input image 65 | 66 | Returns: 67 | input for neural network 68 | 69 | """ 70 | resized = data.resize(self.input_image_size, resample=3) 71 | # noinspection PyTypeChecker 72 | resized_arr = np.array(resized, dtype=float) 73 | temp_image = np.zeros((resized_arr.shape[0], resized_arr.shape[1], 3)) 74 | if np.max(resized_arr) != 0: 75 | resized_arr /= np.max(resized_arr) 76 | temp_image[:, :, 0] = (resized_arr[:, :, 0] - 0.485) / 0.229 77 | temp_image[:, :, 1] = (resized_arr[:, :, 1] - 0.456) / 0.224 78 | temp_image[:, :, 2] = (resized_arr[:, :, 2] - 0.406) / 0.225 79 | temp_image = temp_image.transpose((2, 0, 1)) 80 | temp_image = np.expand_dims(temp_image, 0) 81 | return torch.from_numpy(temp_image).type(torch.FloatTensor) 82 | 83 | @staticmethod 84 | def data_postprocessing( 85 | data: torch.tensor, original_image: PIL.Image.Image 86 | ) -> PIL.Image.Image: 87 | """ 88 | Transforms output data from neural network to suitable data 89 | format for using with other components of this framework. 
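        The per-image min-max normalization below is what __call__ relies on when producing final masks; a minimal sketch (example values only):

            >>> from carvekit.ml.wrap.u2net import U2NET
            >>> seg = U2NET(layers_cfg="full", device="cpu", input_image_size=320, batch_size=1)
            >>> masks = seg(["docs/imgs/input/3.jpg"])  # list of PIL "L" masks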
90 | 91 | Args: 92 | data: output data from neural network 93 | original_image: input image which was used for predicted data 94 | 95 | Returns: 96 | Segmentation mask as PIL Image instance 97 | 98 | """ 99 | data = data.unsqueeze(0) 100 | mask = data[:, 0, :, :] 101 | ma = torch.max(mask) # Normalizes prediction 102 | mi = torch.min(mask) 103 | predict = ((mask - mi) / (ma - mi)).squeeze() 104 | predict_np = predict.cpu().data.numpy() * 255 105 | mask = Image.fromarray(predict_np).convert("L") 106 | mask = mask.resize(original_image.size, resample=3) 107 | return mask 108 | 109 | def __call__( 110 | self, images: List[Union[str, pathlib.Path, PIL.Image.Image]] 111 | ) -> List[PIL.Image.Image]: 112 | """ 113 | Passes input images though neural network and returns segmentation masks as PIL.Image.Image instances 114 | 115 | Args: 116 | images: input images 117 | 118 | Returns: 119 | segmentation masks as for input images, as PIL.Image.Image instances 120 | 121 | """ 122 | collect_masks = [] 123 | for image_batch in batch_generator(images, self.batch_size): 124 | images = thread_pool_processing( 125 | lambda x: convert_image(load_image(x)), image_batch 126 | ) 127 | batches = torch.vstack( 128 | thread_pool_processing(self.data_preprocessing, images) 129 | ) 130 | with torch.no_grad(): 131 | batches = batches.to(self.device) 132 | masks, d2, d3, d4, d5, d6, d7 = super(U2NET, self).__call__(batches) 133 | masks_cpu = masks.cpu() 134 | del d2, d3, d4, d5, d6, d7, batches, masks 135 | masks = thread_pool_processing( 136 | lambda x: self.data_postprocessing(masks_cpu[x], images[x]), 137 | range(len(images)), 138 | ) 139 | collect_masks += masks 140 | return collect_masks 141 | -------------------------------------------------------------------------------- /carvekit/pipelines/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OPHoperHPO/image-background-remove-tool/f141a311af67fb1da64269c508a6d1f786420801/carvekit/pipelines/__init__.py -------------------------------------------------------------------------------- /carvekit/pipelines/postprocessing.py: -------------------------------------------------------------------------------- 1 | """ 2 | Source url: https://github.com/OPHoperHPO/image-background-remove-tool 3 | Author: Nikita Selin (OPHoperHPO)[https://github.com/OPHoperHPO]. 4 | License: Apache License 2.0 5 | """ 6 | from carvekit.ml.wrap.fba_matting import FBAMatting 7 | from typing import Union, List 8 | from PIL import Image 9 | from pathlib import Path 10 | from carvekit.trimap.cv_gen import CV2TrimapGenerator 11 | from carvekit.trimap.generator import TrimapGenerator 12 | from carvekit.utils.mask_utils import apply_mask 13 | from carvekit.utils.pool_utils import thread_pool_processing 14 | from carvekit.utils.image_utils import load_image, convert_image 15 | 16 | __all__ = ["MattingMethod"] 17 | 18 | 19 | class MattingMethod: 20 | """ 21 | Improving the edges of the object mask using neural networks for matting and algorithms for creating trimap. 22 | Neural network for matting performs accurate object edge detection by using a special map called trimap, 23 | with unknown area that we scan for boundary, already known general object area and the background.""" 24 | 25 | def __init__( 26 | self, 27 | matting_module: Union[FBAMatting], 28 | trimap_generator: Union[TrimapGenerator, CV2TrimapGenerator], 29 | device="cpu", 30 | ): 31 | """ 32 | Initializes Matting Method class. 
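        A hedged construction sketch (the FBAMatting arguments are assumed to mirror the other wrappers in this package; the file paths point at the bundled test data):

            >>> from carvekit.ml.wrap.fba_matting import FBAMatting
            >>> from carvekit.trimap.generator import TrimapGenerator
            >>> matting = MattingMethod(
            ...     matting_module=FBAMatting(device="cpu"),  # device kwarg assumed, as in the other wrappers
            ...     trimap_generator=TrimapGenerator(),
            ...     device="cpu",
            ... )
            >>> result = matting(images=["tests/data/cat.jpg"], masks=["tests/data/cat_mask.png"])
            >>> result[0].mode
            'RGBA'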
33 | 34 | Args: 35 | matting_module: Initialized matting neural network class 36 | trimap_generator: Initialized trimap generator class 37 | device: Processing device used for applying mask to image 38 | """ 39 | self.device = device 40 | self.matting_module = matting_module 41 | self.trimap_generator = trimap_generator 42 | 43 | def __call__( 44 | self, 45 | images: List[Union[str, Path, Image.Image]], 46 | masks: List[Union[str, Path, Image.Image]], 47 | ): 48 | """ 49 | Passes data through apply_mask function 50 | 51 | Args: 52 | images: list of images 53 | masks: list pf masks 54 | 55 | Returns: 56 | list of images 57 | """ 58 | if len(images) != len(masks): 59 | raise ValueError("Images and Masks lists should have same length!") 60 | images = thread_pool_processing(lambda x: convert_image(load_image(x)), images) 61 | masks = thread_pool_processing( 62 | lambda x: convert_image(load_image(x), mode="L"), masks 63 | ) 64 | trimaps = thread_pool_processing( 65 | lambda x: self.trimap_generator(original_image=images[x], mask=masks[x]), 66 | range(len(images)), 67 | ) 68 | alpha = self.matting_module(images=images, trimaps=trimaps) 69 | return list( 70 | map( 71 | lambda x: apply_mask( 72 | image=images[x], mask=alpha[x], device=self.device 73 | ), 74 | range(len(images)), 75 | ) 76 | ) 77 | -------------------------------------------------------------------------------- /carvekit/pipelines/preprocessing.py: -------------------------------------------------------------------------------- 1 | """ 2 | Source url: https://github.com/OPHoperHPO/image-background-remove-tool 3 | Author: Nikita Selin (OPHoperHPO)[https://github.com/OPHoperHPO]. 4 | License: Apache License 2.0 5 | """ 6 | from pathlib import Path 7 | from typing import Union, List 8 | 9 | from PIL import Image 10 | 11 | __all__ = ["PreprocessingStub"] 12 | 13 | 14 | class PreprocessingStub: 15 | """Stub for future preprocessing methods""" 16 | 17 | def __call__(self, interface, images: List[Union[str, Path, Image.Image]]): 18 | """ 19 | Passes data though interface.segmentation_pipeline() method 20 | 21 | Args: 22 | interface: Interface instance 23 | images: list of images 24 | 25 | Returns: 26 | the result of passing data through segmentation_pipeline method of interface 27 | """ 28 | return interface.segmentation_pipeline(images=images) 29 | -------------------------------------------------------------------------------- /carvekit/trimap/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OPHoperHPO/image-background-remove-tool/f141a311af67fb1da64269c508a6d1f786420801/carvekit/trimap/__init__.py -------------------------------------------------------------------------------- /carvekit/trimap/add_ops.py: -------------------------------------------------------------------------------- 1 | """ 2 | Source url: https://github.com/OPHoperHPO/image-background-remove-tool 3 | Author: Nikita Selin (OPHoperHPO)[https://github.com/OPHoperHPO]. 4 | License: Apache License 2.0 5 | """ 6 | import cv2 7 | import numpy as np 8 | from PIL import Image 9 | 10 | 11 | def prob_filter(mask: Image.Image, prob_threshold=231) -> Image.Image: 12 | """ 13 | Applies a filter to the mask by the probability of locating an object in the object area. 14 | 15 | Args: 16 | prob_threshold: Threshold of probability for mark area as background. 
17 | mask: Predicted object mask 18 | 19 | Raises: 20 | ValueError if mask or trimap has wrong color mode 21 | 22 | Returns: 23 | Generated trimap for image. 24 | """ 25 | if mask.mode != "L": 26 | raise ValueError("Input mask has wrong color mode.") 27 | # noinspection PyTypeChecker 28 | mask_array = np.array(mask) 29 | mask_array[mask_array > prob_threshold] = 255 # Probability filter for mask 30 | mask_array[mask_array <= prob_threshold] = 0 31 | return Image.fromarray(mask_array).convert("L") 32 | 33 | 34 | def prob_as_unknown_area( 35 | trimap: Image.Image, mask: Image.Image, prob_threshold=255 36 | ) -> Image.Image: 37 | """ 38 | Marks any uncertainty in the seg mask as an unknown region. 39 | 40 | Args: 41 | prob_threshold: Threshold of probability for mark area as unknown. 42 | trimap: Generated trimap. 43 | mask: Predicted object mask 44 | 45 | Raises: 46 | ValueError if mask or trimap has wrong color mode 47 | 48 | Returns: 49 | Generated trimap for image. 50 | """ 51 | if mask.mode != "L" or trimap.mode != "L": 52 | raise ValueError("Input mask has wrong color mode.") 53 | # noinspection PyTypeChecker 54 | mask_array = np.array(mask) 55 | # noinspection PyTypeChecker 56 | trimap_array = np.array(trimap) 57 | trimap_array[np.logical_and(mask_array <= prob_threshold, mask_array > 0)] = 127 58 | return Image.fromarray(trimap_array).convert("L") 59 | 60 | 61 | def post_erosion(trimap: Image.Image, erosion_iters=1) -> Image.Image: 62 | """ 63 | Performs erosion on the mask and marks the resulting area as an unknown region. 64 | 65 | Args: 66 | erosion_iters: The number of iterations of erosion that 67 | the object's mask will be subjected to before forming an unknown area 68 | trimap: Generated trimap. 69 | mask: Predicted object mask 70 | 71 | Returns: 72 | Generated trimap for image. 73 | """ 74 | if trimap.mode != "L": 75 | raise ValueError("Input mask has wrong color mode.") 76 | # noinspection PyTypeChecker 77 | trimap_array = np.array(trimap) 78 | if erosion_iters > 0: 79 | without_unknown_area = trimap_array.copy() 80 | without_unknown_area[without_unknown_area == 127] = 0 81 | 82 | erosion_kernel = np.ones((3, 3), np.uint8) 83 | erode = cv2.erode( 84 | without_unknown_area, erosion_kernel, iterations=erosion_iters 85 | ) 86 | erode = np.where(erode == 0, 0, without_unknown_area) 87 | trimap_array[np.logical_and(erode == 0, without_unknown_area > 0)] = 127 88 | erode = trimap_array.copy() 89 | else: 90 | erode = trimap_array.copy() 91 | return Image.fromarray(erode).convert("L") 92 | -------------------------------------------------------------------------------- /carvekit/trimap/cv_gen.py: -------------------------------------------------------------------------------- 1 | """ 2 | Source url: https://github.com/OPHoperHPO/image-background-remove-tool 3 | Author: Nikita Selin (OPHoperHPO)[https://github.com/OPHoperHPO]. 
4 | License: Apache License 2.0 5 | """ 6 | import PIL.Image 7 | import cv2 8 | import numpy as np 9 | 10 | 11 | class CV2TrimapGenerator: 12 | def __init__(self, kernel_size: int = 30, erosion_iters: int = 1): 13 | """ 14 | Initialize a new CV2TrimapGenerator instance 15 | 16 | Args: 17 | kernel_size: The size of the offset from the object mask 18 | in pixels when an unknown area is detected in the trimap 19 | erosion_iters: The number of iterations of erosion that 20 | the object's mask will be subjected to before forming an unknown area 21 | """ 22 | self.kernel_size = kernel_size 23 | self.erosion_iters = erosion_iters 24 | 25 | def __call__( 26 | self, original_image: PIL.Image.Image, mask: PIL.Image.Image 27 | ) -> PIL.Image.Image: 28 | """ 29 | Generates trimap based on predicted object mask to refine object mask borders. 30 | Based on cv2 erosion algorithm. 31 | 32 | Args: 33 | original_image: Original image 34 | mask: Predicted object mask 35 | 36 | Returns: 37 | Generated trimap for image. 38 | """ 39 | if mask.mode != "L": 40 | raise ValueError("Input mask has wrong color mode.") 41 | if mask.size != original_image.size: 42 | raise ValueError("Sizes of input image and predicted mask doesn't equal") 43 | # noinspection PyTypeChecker 44 | mask_array = np.array(mask) 45 | pixels = 2 * self.kernel_size + 1 46 | kernel = np.ones((pixels, pixels), np.uint8) 47 | 48 | if self.erosion_iters > 0: 49 | erosion_kernel = np.ones((3, 3), np.uint8) 50 | erode = cv2.erode(mask_array, erosion_kernel, iterations=self.erosion_iters) 51 | erode = np.where(erode == 0, 0, mask_array) 52 | else: 53 | erode = mask_array.copy() 54 | 55 | dilation = cv2.dilate(erode, kernel, iterations=1) 56 | 57 | dilation = np.where(dilation == 255, 127, dilation) # WHITE to GRAY 58 | trimap = np.where(erode > 127, 200, dilation) # mark the tumor inside GRAY 59 | 60 | trimap = np.where(trimap < 127, 0, trimap) # Embelishment 61 | trimap = np.where(trimap > 200, 0, trimap) # Embelishment 62 | trimap = np.where(trimap == 200, 255, trimap) # GRAY to WHITE 63 | 64 | return PIL.Image.fromarray(trimap).convert("L") 65 | -------------------------------------------------------------------------------- /carvekit/trimap/generator.py: -------------------------------------------------------------------------------- 1 | """ 2 | Source url: https://github.com/OPHoperHPO/image-background-remove-tool 3 | Author: Nikita Selin (OPHoperHPO)[https://github.com/OPHoperHPO]. 
4 | License: Apache License 2.0 5 | """ 6 | from PIL import Image 7 | from carvekit.trimap.cv_gen import CV2TrimapGenerator 8 | from carvekit.trimap.add_ops import prob_filter, prob_as_unknown_area, post_erosion 9 | 10 | 11 | class TrimapGenerator(CV2TrimapGenerator): 12 | def __init__( 13 | self, prob_threshold: int = 231, kernel_size: int = 30, erosion_iters: int = 5 14 | ): 15 | """ 16 | Initialize a TrimapGenerator instance 17 | 18 | Args: 19 | prob_threshold: Probability threshold at which the 20 | prob_filter and prob_as_unknown_area operations will be applied 21 | kernel_size: The size of the offset from the object mask 22 | in pixels when an unknown area is detected in the trimap 23 | erosion_iters: The number of iterations of erosion that 24 | the object's mask will be subjected to before forming an unknown area 25 | """ 26 | super().__init__(kernel_size, erosion_iters=0) 27 | self.prob_threshold = prob_threshold 28 | self.__erosion_iters = erosion_iters 29 | 30 | def __call__(self, original_image: Image.Image, mask: Image.Image) -> Image.Image: 31 | """ 32 | Generates trimap based on predicted object mask to refine object mask borders. 33 | Based on cv2 erosion algorithm and additional prob. filters. 34 | Args: 35 | original_image: Original image 36 | mask: Predicted object mask 37 | 38 | Returns: 39 | Generated trimap for image. 40 | """ 41 | filter_mask = prob_filter(mask=mask, prob_threshold=self.prob_threshold) 42 | trimap = super(TrimapGenerator, self).__call__(original_image, filter_mask) 43 | new_trimap = prob_as_unknown_area( 44 | trimap=trimap, mask=mask, prob_threshold=self.prob_threshold 45 | ) 46 | new_trimap = post_erosion(new_trimap, self.__erosion_iters) 47 | return new_trimap 48 | -------------------------------------------------------------------------------- /carvekit/utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OPHoperHPO/image-background-remove-tool/f141a311af67fb1da64269c508a6d1f786420801/carvekit/utils/__init__.py -------------------------------------------------------------------------------- /carvekit/utils/fs_utils.py: -------------------------------------------------------------------------------- 1 | """ 2 | Source url: https://github.com/OPHoperHPO/image-background-remove-tool 3 | Author: Nikita Selin (OPHoperHPO)[https://github.com/OPHoperHPO]. 4 | License: Apache License 2.0 5 | """ 6 | from pathlib import Path 7 | from PIL import Image 8 | import warnings 9 | from typing import Optional 10 | 11 | 12 | def save_file(output: Optional[Path], input_path: Path, image: Image.Image): 13 | """ 14 | Saves an image to the file system 15 | 16 | Args: 17 | output: Output path [dir or end file] 18 | input_path: Input path of the image 19 | image: Image to be saved. 20 | """ 21 | if isinstance(output, Path) and str(output) != "none": 22 | if output.is_dir() and output.exists(): 23 | image.save(output.joinpath(input_path.with_suffix(".png").name)) 24 | elif output.suffix != "": 25 | if output.suffix != ".png": 26 | warnings.warn( 27 | f"Only export with .png extension is supported! Your {output.suffix}" 28 | f" extension will be ignored and replaced with .png!" 
29 | ) 30 | image.save(output.with_suffix(".png")) 31 | else: 32 | raise ValueError("Wrong output path!") 33 | elif output is None or str(output) == "none": 34 | image.save( 35 | input_path.with_name( 36 | input_path.stem.split(".")[0] + "_bg_removed" 37 | ).with_suffix(".png") 38 | ) 39 | -------------------------------------------------------------------------------- /carvekit/utils/image_utils.py: -------------------------------------------------------------------------------- 1 | """ 2 | Source url: https://github.com/OPHoperHPO/image-background-remove-tool 3 | Author: Nikita Selin (OPHoperHPO)[https://github.com/OPHoperHPO]. 4 | License: Apache License 2.0 5 | """ 6 | 7 | import pathlib 8 | from typing import Union, Any, Tuple 9 | 10 | import PIL.Image 11 | import numpy as np 12 | import torch 13 | 14 | ALLOWED_SUFFIXES = [".jpg", ".jpeg", ".bmp", ".png", ".webp"] 15 | 16 | 17 | def to_tensor(x: Any) -> torch.Tensor: 18 | """ 19 | Returns a PIL.Image.Image as torch tensor without swap tensor dims. 20 | 21 | Args: 22 | x: PIL.Image.Image instance 23 | 24 | Returns: 25 | torch.Tensor instance 26 | """ 27 | return torch.tensor(np.array(x, copy=True)) 28 | 29 | 30 | def load_image(file: Union[str, pathlib.Path, PIL.Image.Image]) -> PIL.Image.Image: 31 | """Returns a PIL.Image.Image class by string path or pathlib path or PIL.Image.Image instance 32 | 33 | Args: 34 | file: File path or PIL.Image.Image instance 35 | 36 | Returns: 37 | PIL.Image.Image instance 38 | 39 | Raises: 40 | ValueError: If file not exists or file is directory or file isn't an image or file is not correct PIL Image 41 | 42 | """ 43 | if isinstance(file, str) and is_image_valid(pathlib.Path(file)): 44 | return PIL.Image.open(file) 45 | elif isinstance(file, PIL.Image.Image): 46 | return file 47 | elif isinstance(file, pathlib.Path) and is_image_valid(file): 48 | return PIL.Image.open(str(file)) 49 | else: 50 | raise ValueError("Unknown input file type") 51 | 52 | 53 | def convert_image(image: PIL.Image.Image, mode="RGB") -> PIL.Image.Image: 54 | """Performs image conversion to correct color mode 55 | 56 | Args: 57 | image: PIL.Image.Image instance 58 | mode: Colort Mode to convert 59 | 60 | Returns: 61 | PIL.Image.Image instance 62 | 63 | Raises: 64 | ValueError: If image hasn't convertable color mode, or it is too small 65 | """ 66 | if is_image_valid(image): 67 | return image.convert(mode) 68 | 69 | 70 | def is_image_valid(image: Union[pathlib.Path, PIL.Image.Image]) -> bool: 71 | """This function performs image validation. 72 | 73 | Args: 74 | image: Path to the image or PIL.Image.Image instance being checked. 75 | 76 | Returns: 77 | True if image is valid 78 | 79 | Raises: 80 | ValueError: If file not a valid image path or image hasn't convertable color mode, or it is too small 81 | 82 | """ 83 | if isinstance(image, pathlib.Path): 84 | if not image.exists(): 85 | raise ValueError("File is not exists") 86 | elif image.is_dir(): 87 | raise ValueError("File is a directory") 88 | elif image.suffix.lower() not in ALLOWED_SUFFIXES: 89 | raise ValueError( 90 | f"Unsupported image format. 
Supported file formats: {', '.join(ALLOWED_SUFFIXES)}" 91 | ) 92 | elif isinstance(image, PIL.Image.Image): 93 | if not (image.size[0] > 32 and image.size[1] > 32): 94 | raise ValueError("Image should be bigger then (32x32) pixels.") 95 | elif image.mode not in ["RGB", "RGBA", "L"]: 96 | raise ValueError("Wrong image color mode.") 97 | else: 98 | raise ValueError("Unknown input file type") 99 | return True 100 | 101 | 102 | def transparency_paste( 103 | bg_img: PIL.Image.Image, fg_img: PIL.Image.Image, box=(0, 0) 104 | ) -> PIL.Image.Image: 105 | """ 106 | Inserts an image into another image while maintaining transparency. 107 | 108 | Args: 109 | bg_img: background image 110 | fg_img: foreground image 111 | box: place to paste 112 | 113 | Returns: 114 | Background image with pasted foreground image at point or in the specified box 115 | """ 116 | fg_img_trans = PIL.Image.new("RGBA", bg_img.size) 117 | fg_img_trans.paste(fg_img, box, mask=fg_img) 118 | new_img = PIL.Image.alpha_composite(bg_img, fg_img_trans) 119 | return new_img 120 | 121 | 122 | def add_margin( 123 | pil_img: PIL.Image.Image, 124 | top: int, 125 | right: int, 126 | bottom: int, 127 | left: int, 128 | color: Tuple[int, int, int, int], 129 | ) -> PIL.Image.Image: 130 | """ 131 | Adds margin to the image. 132 | 133 | Args: 134 | pil_img: Image that needed to add margin. 135 | top: pixels count at top side 136 | right: pixels count at right side 137 | bottom: pixels count at bottom side 138 | left: pixels count at left side 139 | color: color of margin 140 | 141 | Returns: 142 | Image with margin. 143 | """ 144 | width, height = pil_img.size 145 | new_width = width + right + left 146 | new_height = height + top + bottom 147 | # noinspection PyTypeChecker 148 | result = PIL.Image.new(pil_img.mode, (new_width, new_height), color) 149 | result.paste(pil_img, (left, top)) 150 | return result 151 | -------------------------------------------------------------------------------- /carvekit/utils/mask_utils.py: -------------------------------------------------------------------------------- 1 | """ 2 | Source url: https://github.com/OPHoperHPO/image-background-remove-tool 3 | Author: Nikita Selin (OPHoperHPO)[https://github.com/OPHoperHPO]. 4 | License: Apache License 2.0 5 | """ 6 | import PIL.Image 7 | import torch 8 | from carvekit.utils.image_utils import to_tensor 9 | 10 | 11 | def composite( 12 | foreground: PIL.Image.Image, 13 | background: PIL.Image.Image, 14 | alpha: PIL.Image.Image, 15 | device="cpu", 16 | ): 17 | """ 18 | Composites foreground with background by following 19 | https://pymatting.github.io/intro.html#alpha-matting math formula. 20 | 21 | Args: 22 | device: Processing device 23 | foreground: Image that will be pasted to background image with following alpha mask. 24 | background: Background image 25 | alpha: Alpha Image 26 | 27 | Returns: 28 | Composited image as PIL.Image instance. 
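    Example (a rough sketch; the files are stand-ins for any pair of equally sized image and alpha mask):

        >>> from PIL import Image
        >>> fg = Image.open("tests/data/cat.jpg")
        >>> bg = Image.new("RGBA", fg.size, (255, 255, 255, 255))
        >>> alpha = Image.open("tests/data/cat_mask.png")
        >>> out = composite(fg, bg, alpha, device="cpu")  # RGBA composite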
29 | """ 30 | 31 | foreground = foreground.convert("RGBA") 32 | background = background.convert("RGBA") 33 | alpha_rgba = alpha.convert("RGBA") 34 | alpha_l = alpha.convert("L") 35 | 36 | fg = to_tensor(foreground).to(device) 37 | alpha_rgba = to_tensor(alpha_rgba).to(device) 38 | alpha_l = to_tensor(alpha_l).to(device) 39 | bg = to_tensor(background).to(device) 40 | 41 | alpha_l = alpha_l / 255 42 | alpha_rgba = alpha_rgba / 255 43 | 44 | bg = torch.where(torch.logical_not(alpha_rgba >= 1), bg, fg) 45 | bg[:, :, 0] = alpha_l[:, :] * fg[:, :, 0] + (1 - alpha_l[:, :]) * bg[:, :, 0] 46 | bg[:, :, 1] = alpha_l[:, :] * fg[:, :, 1] + (1 - alpha_l[:, :]) * bg[:, :, 1] 47 | bg[:, :, 2] = alpha_l[:, :] * fg[:, :, 2] + (1 - alpha_l[:, :]) * bg[:, :, 2] 48 | bg[:, :, 3] = alpha_l[:, :] * 255 49 | 50 | del alpha_l, alpha_rgba, fg 51 | return PIL.Image.fromarray(bg.cpu().numpy()).convert("RGBA") 52 | 53 | 54 | def apply_mask( 55 | image: PIL.Image.Image, mask: PIL.Image.Image, device="cpu" 56 | ) -> PIL.Image.Image: 57 | """ 58 | Applies mask to foreground. 59 | 60 | Args: 61 | device: Processing device. 62 | image: Image with background. 63 | mask: Alpha Channel mask for this image. 64 | 65 | Returns: 66 | Image without background, where mask was black. 67 | """ 68 | background = PIL.Image.new("RGBA", image.size, color=(130, 130, 130, 0)) 69 | return composite(image, background, mask, device=device).convert("RGBA") 70 | 71 | 72 | def extract_alpha_channel(image: PIL.Image.Image) -> PIL.Image.Image: 73 | """ 74 | Extracts alpha channel from the RGBA image. 75 | 76 | Args: 77 | image: RGBA PIL image 78 | 79 | Returns: 80 | RGBA alpha channel image 81 | """ 82 | alpha = image.split()[-1] 83 | bg = PIL.Image.new("RGBA", image.size, (0, 0, 0, 255)) 84 | bg.paste(alpha, mask=alpha) 85 | return bg.convert("RGBA") 86 | -------------------------------------------------------------------------------- /carvekit/utils/models_utils.py: -------------------------------------------------------------------------------- 1 | """ 2 | Source url: https://github.com/OPHoperHPO/image-background-remove-tool 3 | Author: Nikita Selin (OPHoperHPO)[https://github.com/OPHoperHPO]. 4 | License: Apache License 2.0 5 | """ 6 | 7 | import random 8 | import warnings 9 | from typing import Union, Tuple, Any 10 | 11 | import torch 12 | from torch import autocast 13 | 14 | 15 | class EmptyAutocast(object): 16 | """ 17 | Empty class for disable any autocasting. 18 | """ 19 | 20 | def __enter__(self): 21 | return None 22 | 23 | def __exit__(self, exc_type, exc_val, exc_tb): 24 | return 25 | 26 | def __call__(self, func): 27 | return 28 | 29 | 30 | def get_precision_autocast( 31 | device="cpu", fp16=True, override_dtype=None 32 | ) -> Union[ 33 | Tuple[EmptyAutocast, Union[torch.dtype, Any]], 34 | Tuple[autocast, Union[torch.dtype, Any]], 35 | ]: 36 | """ 37 | Returns precision and autocast settings for given device and fp16 settings. 38 | Args: 39 | device: Device to get precision and autocast settings for. 40 | fp16: Whether to use fp16 precision. 41 | override_dtype: Override dtype for autocast. 42 | 43 | Returns: 44 | Autocast object, dtype 45 | """ 46 | dtype = torch.float32 47 | cache_enabled = None 48 | 49 | if device == "cpu" and fp16: 50 | warnings.warn('FP16 is not supported on CPU. Using FP32 instead.') 51 | dtype = torch.float32 52 | 53 | # TODO: Implement BFP16 on CPU. There are unexpected slowdowns on cpu on a clean environment. 54 | # warnings.warn( 55 | # "Accuracy BFP16 has experimental support on the CPU. 
" 56 | # "This may result in an unexpected reduction in quality." 57 | # ) 58 | # dtype = ( 59 | # torch.bfloat16 60 | # ) # Using bfloat16 for CPU, since autocast is not supported for float16 61 | 62 | 63 | if "cuda" in device and fp16: 64 | dtype = torch.float16 65 | cache_enabled = True 66 | 67 | if override_dtype is not None: 68 | dtype = override_dtype 69 | 70 | if dtype == torch.float32 and device == "cpu": 71 | return EmptyAutocast(), dtype 72 | 73 | return ( 74 | torch.autocast( 75 | device_type=device, dtype=dtype, enabled=True, cache_enabled=cache_enabled 76 | ), 77 | dtype, 78 | ) 79 | 80 | 81 | def cast_network(network: torch.nn.Module, dtype: torch.dtype): 82 | """Cast network to given dtype 83 | 84 | Args: 85 | network: Network to be casted 86 | dtype: Dtype to cast network to 87 | """ 88 | if dtype == torch.float16: 89 | network.half() 90 | elif dtype == torch.bfloat16: 91 | network.bfloat16() 92 | elif dtype == torch.float32: 93 | network.float() 94 | else: 95 | raise ValueError(f"Unknown dtype {dtype}") 96 | 97 | 98 | def fix_seed(seed=42): 99 | """Sets fixed random seed 100 | 101 | Args: 102 | seed: Random seed to be set 103 | """ 104 | random.seed(seed) 105 | torch.manual_seed(seed) 106 | if torch.cuda.is_available(): 107 | torch.cuda.manual_seed(seed) 108 | torch.cuda.manual_seed_all(seed) 109 | # noinspection PyUnresolvedReferences 110 | torch.backends.cudnn.deterministic = True 111 | # noinspection PyUnresolvedReferences 112 | torch.backends.cudnn.benchmark = False 113 | return True 114 | 115 | 116 | def suppress_warnings(): 117 | # Suppress PyTorch 1.11.0 warning associated with changing order of args in nn.MaxPool2d layer, 118 | # since source code is not affected by this issue and there aren't any other correct way to hide this message. 119 | warnings.filterwarnings( 120 | "ignore", 121 | category=UserWarning, 122 | message="Note that order of the arguments: ceil_mode and " 123 | "return_indices will changeto match the args list " 124 | "in nn.MaxPool2d in a future release.", 125 | module="torch", 126 | ) 127 | -------------------------------------------------------------------------------- /carvekit/utils/pool_utils.py: -------------------------------------------------------------------------------- 1 | """ 2 | Source url: https://github.com/OPHoperHPO/image-background-remove-tool 3 | Author: Nikita Selin (OPHoperHPO)[https://github.com/OPHoperHPO]. 4 | License: Apache License 2.0 5 | """ 6 | from concurrent.futures import ThreadPoolExecutor 7 | from typing import Any, Iterable 8 | 9 | 10 | def thread_pool_processing(func: Any, data: Iterable, workers=18): 11 | """ 12 | Passes all iterator data through the given function 13 | 14 | Args: 15 | workers: Count of workers. 
16 | func: function to pass data through 17 | data: input iterator 18 | 19 | Returns: 20 | function return list 21 | 22 | """ 23 | with ThreadPoolExecutor(workers) as p: 24 | return list(p.map(func, data)) 25 | 26 | 27 | def batch_generator(iterable, n=1): 28 | """ 29 | Splits any iterable into n-size packets 30 | 31 | Args: 32 | iterable: iterator 33 | n: size of packets 34 | 35 | Returns: 36 | new n-size packet 37 | """ 38 | it = len(iterable) 39 | for ndx in range(0, it, n): 40 | yield iterable[ndx : min(ndx + n, it)] 41 | -------------------------------------------------------------------------------- /carvekit/web/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OPHoperHPO/image-background-remove-tool/f141a311af67fb1da64269c508a6d1f786420801/carvekit/web/__init__.py -------------------------------------------------------------------------------- /carvekit/web/app.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | 3 | import uvicorn 4 | from fastapi import FastAPI 5 | from fastapi.middleware.cors import CORSMiddleware 6 | from starlette.staticfiles import StaticFiles 7 | 8 | from carvekit import version 9 | from carvekit.web.deps import config 10 | from carvekit.web.routers.api_router import api_router 11 | 12 | app = FastAPI(title="CarveKit Web API", version=version) 13 | 14 | app.add_middleware( 15 | CORSMiddleware, 16 | allow_origins=["*"], 17 | allow_credentials=True, 18 | allow_methods=["*"], 19 | allow_headers=["*"], 20 | ) 21 | 22 | app.include_router(api_router, prefix="/api") 23 | app.mount( 24 | "/", 25 | StaticFiles(directory=Path(__file__).parent.joinpath("static"), html=True), 26 | name="static", 27 | ) 28 | 29 | if __name__ == "__main__": 30 | uvicorn.run(app, host=config.host, port=config.port) 31 | -------------------------------------------------------------------------------- /carvekit/web/deps.py: -------------------------------------------------------------------------------- 1 | from carvekit.web.schemas.config import WebAPIConfig 2 | from carvekit.web.utils.init_utils import init_config 3 | from carvekit.web.utils.task_queue import MLProcessor 4 | 5 | config: WebAPIConfig = init_config() 6 | ml_processor = MLProcessor(api_config=config) 7 | -------------------------------------------------------------------------------- /carvekit/web/handlers/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OPHoperHPO/image-background-remove-tool/f141a311af67fb1da64269c508a6d1f786420801/carvekit/web/handlers/__init__.py -------------------------------------------------------------------------------- /carvekit/web/handlers/response.py: -------------------------------------------------------------------------------- 1 | from typing import Union 2 | 3 | from fastapi import Header 4 | from fastapi.responses import Response, JSONResponse 5 | from carvekit.web.deps import config 6 | 7 | 8 | def Authenticate(x_api_key: Union[str, None] = Header(None)) -> Union[bool, str]: 9 | if x_api_key in config.auth.allowed_tokens: 10 | return "allowed" 11 | elif x_api_key == config.auth.admin_token: 12 | return "admin" 13 | elif config.auth.auth is False: 14 | return "allowed" 15 | else: 16 | return False 17 | 18 | 19 | def handle_response(response, original_image) -> Response: 20 | """ 21 | Response handler from TaskQueue 22 | :param response: TaskQueue response 23 | :param 
original_image: Original PIL image 24 | :return: Complete flask response 25 | """ 26 | response_object = None 27 | if isinstance(response, dict): 28 | if response["type"] == "jpg": 29 | response_object = Response( 30 | content=response["data"][0].read(), media_type="image/jpeg" 31 | ) 32 | elif response["type"] == "png": 33 | response_object = Response( 34 | content=response["data"][0].read(), media_type="image/png" 35 | ) 36 | elif response["type"] == "zip": 37 | response_object = Response( 38 | content=response["data"][0], media_type="application/zip" 39 | ) 40 | response_object.headers[ 41 | "Content-Disposition" 42 | ] = "attachment; filename='no-bg.zip'" 43 | 44 | # Add headers to output result 45 | response_object.headers["X-Credits-Charged"] = "0" 46 | response_object.headers["X-Type"] = "other" # TODO Make support for this 47 | response_object.headers["X-Max-Width"] = str(original_image.size[0]) 48 | response_object.headers["X-Max-Height"] = str(original_image.size[1]) 49 | response_object.headers[ 50 | "X-Ratelimit-Limit" 51 | ] = "500" # TODO Make ratelimit support 52 | response_object.headers["X-Ratelimit-Remaining"] = "500" 53 | response_object.headers["X-Ratelimit-Reset"] = "1" 54 | response_object.headers["X-Width"] = str(response["data"][1][0]) 55 | response_object.headers["X-Height"] = str(response["data"][1][1]) 56 | 57 | else: 58 | response = JSONResponse(content=response[0]) 59 | response.headers["X-Credits-Charged"] = "0" 60 | 61 | return response_object 62 | -------------------------------------------------------------------------------- /carvekit/web/other/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OPHoperHPO/image-background-remove-tool/f141a311af67fb1da64269c508a6d1f786420801/carvekit/web/other/__init__.py -------------------------------------------------------------------------------- /carvekit/web/responses/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OPHoperHPO/image-background-remove-tool/f141a311af67fb1da64269c508a6d1f786420801/carvekit/web/responses/__init__.py -------------------------------------------------------------------------------- /carvekit/web/responses/api.py: -------------------------------------------------------------------------------- 1 | def error_dict(error_text: str): 2 | """ 3 | Generates a dictionary containing $error_text error 4 | :param error_text: Error text 5 | :return: error dictionary 6 | """ 7 | resp = {"errors": [{"title": error_text}]} 8 | return resp 9 | -------------------------------------------------------------------------------- /carvekit/web/routers/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OPHoperHPO/image-background-remove-tool/f141a311af67fb1da64269c508a6d1f786420801/carvekit/web/routers/__init__.py -------------------------------------------------------------------------------- /carvekit/web/schemas/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OPHoperHPO/image-background-remove-tool/f141a311af67fb1da64269c508a6d1f786420801/carvekit/web/schemas/__init__.py -------------------------------------------------------------------------------- /carvekit/web/schemas/config.py: -------------------------------------------------------------------------------- 1 | import secrets 2 | from typing import 
List 3 | from typing_extensions import Literal 4 | 5 | import torch.cuda 6 | from pydantic import BaseModel, validator 7 | 8 | 9 | class AuthConfig(BaseModel): 10 | """Config for web api token authentication""" 11 | 12 | auth: bool = True 13 | """Enables Token Authentication for API""" 14 | admin_token: str = secrets.token_hex(32) 15 | """Admin Token""" 16 | allowed_tokens: List[str] = [secrets.token_hex(32)] 17 | """All allowed tokens""" 18 | 19 | 20 | class MLConfig(BaseModel): 21 | """Config for ml part of framework""" 22 | 23 | segmentation_network: Literal[ 24 | "u2net", "deeplabv3", "basnet", "tracer_b7" 25 | ] = "tracer_b7" 26 | """Segmentation Network""" 27 | preprocessing_method: Literal["none", "stub"] = "none" 28 | """Pre-processing Method""" 29 | postprocessing_method: Literal["fba", "none"] = "fba" 30 | """Post-Processing Network""" 31 | device: str = "cpu" 32 | """Processing device""" 33 | batch_size_seg: int = 5 34 | """Batch size for segmentation network""" 35 | batch_size_matting: int = 1 36 | """Batch size for matting network""" 37 | seg_mask_size: int = 640 38 | """The size of the input image for the segmentation neural network.""" 39 | matting_mask_size: int = 2048 40 | """The size of the input image for the matting neural network.""" 41 | fp16: bool = False 42 | """Use half precision for inference""" 43 | trimap_dilation: int = 30 44 | """Dilation size for trimap""" 45 | trimap_erosion: int = 5 46 | """Erosion levels for trimap""" 47 | trimap_prob_threshold: int = 231 48 | """Probability threshold for trimap generation""" 49 | 50 | @validator("seg_mask_size") 51 | def seg_mask_size_validator(cls, value: int, values): 52 | if value > 0: 53 | return value 54 | else: 55 | raise ValueError("Incorrect seg_mask_size!") 56 | 57 | @validator("matting_mask_size") 58 | def matting_mask_size_validator(cls, value: int, values): 59 | if value > 0: 60 | return value 61 | else: 62 | raise ValueError("Incorrect matting_mask_size!") 63 | 64 | @validator("batch_size_seg") 65 | def batch_size_seg_validator(cls, value: int, values): 66 | if value > 0: 67 | return value 68 | else: 69 | raise ValueError("Incorrect batch size!") 70 | 71 | @validator("batch_size_matting") 72 | def batch_size_matting_validator(cls, value: int, values): 73 | if value > 0: 74 | return value 75 | else: 76 | raise ValueError("Incorrect batch size!") 77 | 78 | @validator("device") 79 | def device_validator(cls, value): 80 | if torch.cuda.is_available() is False and "cuda" in value: 81 | raise ValueError( 82 | "GPU is not available, but specified as processing device!" 83 | ) 84 | if "cuda" not in value and "cpu" != value: 85 | raise ValueError("Unknown processing device! 
It should be cpu or cuda!") 86 | return value 87 | 88 | 89 | class WebAPIConfig(BaseModel): 90 | """FastAPI app config""" 91 | 92 | port: int = 5000 93 | """Web API port""" 94 | host: str = "0.0.0.0" 95 | """Web API host""" 96 | ml: MLConfig = MLConfig() 97 | """Config for ml part of framework""" 98 | auth: AuthConfig = AuthConfig() 99 | """Config for web api token authentication """ 100 | -------------------------------------------------------------------------------- /carvekit/web/schemas/request.py: -------------------------------------------------------------------------------- 1 | import re 2 | from typing import Optional 3 | 4 | from pydantic import BaseModel, validator 5 | from typing_extensions import Literal 6 | 7 | 8 | class Parameters(BaseModel): 9 | image_file_b64: Optional[str] = "" 10 | image_url: Optional[str] = "" 11 | size: Optional[Literal["preview", "full", "auto"]] = "preview" 12 | type: Optional[ 13 | Literal["auto", "product", "person", "car"] 14 | ] = "auto" # Not supported at the moment 15 | format: Optional[Literal["auto", "jpg", "png", "zip"]] = "auto" 16 | roi: str = "0% 0% 100% 100%" 17 | crop: bool = False 18 | crop_margin: Optional[str] = "0px" 19 | scale: Optional[str] = "original" 20 | position: Optional[str] = "original" 21 | channels: Optional[Literal["rgba", "alpha"]] = "rgba" 22 | add_shadow: str = "false" # Not supported at the moment 23 | semitransparency: str = "false" # Not supported at the moment 24 | bg_color: Optional[str] = "" 25 | bg_image_url: Optional[str] = "" 26 | 27 | @validator("crop_margin") 28 | def crop_margin_validator(cls, value): 29 | if not re.match(r"[0-9]+(px|%)$", value): 30 | raise ValueError( 31 | "crop_margin paramter is not valid" 32 | ) # TODO: Add support of several values 33 | if "%" in value and (int(value[:-1]) < 0 or int(value[:-1]) > 100): 34 | raise ValueError("crop_margin mast be in range between 0% and 100%") 35 | return value 36 | 37 | @validator("scale") 38 | def scale_validator(cls, value): 39 | if value != "original" and ( 40 | not re.match(r"[0-9]+%$", value) 41 | or not int(value[:-1]) <= 100 42 | or not int(value[:-1]) >= 10 43 | ): 44 | raise ValueError("scale must be original or in between of 10% and 100%") 45 | 46 | if value == "original": 47 | return 100 48 | 49 | return int(value[:-1]) 50 | 51 | @validator("position") 52 | def position_validator(cls, value, values): 53 | if len(value.split(" ")) > 2: 54 | raise ValueError( 55 | "Position must be a value from 0 to 100 " 56 | "for both vertical and horizontal axises or for both axises respectively" 57 | ) 58 | 59 | if value == "original": 60 | return "original" 61 | elif len(value.split(" ")) == 1: 62 | return [int(value[:-1]), int(value[:-1])] 63 | else: 64 | return [int(value.split(" ")[0][:-1]), int(value.split(" ")[1][:-1])] 65 | 66 | @validator("bg_color") 67 | def bg_color_validator(cls, value): 68 | if not re.match(r"(#{0,1}[0-9a-f]{3}){0,2}$", value): 69 | raise ValueError("bg_color is not in hex") 70 | if len(value) and value[0] != "#": 71 | value = "#" + value 72 | return value 73 | -------------------------------------------------------------------------------- /carvekit/web/static/css/fancybox_loading.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OPHoperHPO/image-background-remove-tool/f141a311af67fb1da64269c508a6d1f786420801/carvekit/web/static/css/fancybox_loading.gif -------------------------------------------------------------------------------- 
/carvekit/web/static/css/fancybox_overlay.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OPHoperHPO/image-background-remove-tool/f141a311af67fb1da64269c508a6d1f786420801/carvekit/web/static/css/fancybox_overlay.png -------------------------------------------------------------------------------- /carvekit/web/static/css/fancybox_sprite.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OPHoperHPO/image-background-remove-tool/f141a311af67fb1da64269c508a6d1f786420801/carvekit/web/static/css/fancybox_sprite.png -------------------------------------------------------------------------------- /carvekit/web/static/css/jquery.fancybox.css: -------------------------------------------------------------------------------- 1 | /*! fancyBox v2.1.5 fancyapps.com | fancyapps.com/fancybox/#license */ 2 | .fancybox-wrap, 3 | .fancybox-skin, 4 | .fancybox-outer, 5 | .fancybox-inner, 6 | .fancybox-image, 7 | .fancybox-wrap iframe, 8 | .fancybox-wrap object, 9 | .fancybox-nav, 10 | .fancybox-nav span, 11 | .fancybox-tmp 12 | { 13 | padding: 0; 14 | margin: 0; 15 | border: 0; 16 | outline: none; 17 | vertical-align: top; 18 | } 19 | 20 | .fancybox-wrap { 21 | position: absolute; 22 | top: 0; 23 | left: 0; 24 | z-index: 8020; 25 | } 26 | 27 | .fancybox-skin { 28 | position: relative; 29 | background: #f9f9f9; 30 | color: #444; 31 | text-shadow: none; 32 | -webkit-border-radius: 4px; 33 | -moz-border-radius: 4px; 34 | border-radius: 4px; 35 | } 36 | 37 | .fancybox-opened { 38 | z-index: 8030; 39 | } 40 | 41 | .fancybox-opened .fancybox-skin { 42 | -webkit-box-shadow: 0 10px 25px rgba(0, 0, 0, 0.5); 43 | -moz-box-shadow: 0 10px 25px rgba(0, 0, 0, 0.5); 44 | box-shadow: 0 10px 25px rgba(0, 0, 0, 0.5); 45 | } 46 | 47 | .fancybox-outer, .fancybox-inner { 48 | position: relative; 49 | } 50 | 51 | .fancybox-inner { 52 | overflow: hidden; 53 | } 54 | 55 | .fancybox-type-iframe .fancybox-inner { 56 | -webkit-overflow-scrolling: touch; 57 | } 58 | 59 | .fancybox-error { 60 | color: #444; 61 | font: 14px/20px "Helvetica Neue",Helvetica,Arial,sans-serif; 62 | margin: 0; 63 | padding: 15px; 64 | white-space: nowrap; 65 | } 66 | 67 | .fancybox-image, .fancybox-iframe { 68 | display: block; 69 | width: 100%; 70 | height: 100%; 71 | } 72 | 73 | .fancybox-image { 74 | max-width: 100%; 75 | max-height: 100%; 76 | } 77 | 78 | #fancybox-loading, .fancybox-close, .fancybox-prev span, .fancybox-next span { 79 | background-image: url('fancybox_sprite.png'); 80 | } 81 | 82 | #fancybox-loading { 83 | position: fixed; 84 | top: 50%; 85 | left: 50%; 86 | margin-top: -22px; 87 | margin-left: -22px; 88 | background-position: 0 -108px; 89 | opacity: 0.8; 90 | cursor: pointer; 91 | z-index: 8060; 92 | } 93 | 94 | #fancybox-loading div { 95 | width: 44px; 96 | height: 44px; 97 | background: url('fancybox_loading.gif') center center no-repeat; 98 | } 99 | 100 | .fancybox-close { 101 | position: absolute; 102 | top: -18px; 103 | right: -18px; 104 | width: 36px; 105 | height: 36px; 106 | cursor: pointer; 107 | z-index: 8040; 108 | } 109 | 110 | .fancybox-nav { 111 | position: absolute; 112 | top: 0; 113 | width: 40%; 114 | height: 100%; 115 | cursor: pointer; 116 | text-decoration: none; 117 | background: transparent url('blank.gif'); /* helps IE */ 118 | -webkit-tap-highlight-color: rgba(0,0,0,0); 119 | z-index: 8040; 120 | } 121 | 122 | .fancybox-prev { 123 | left: 0; 124 | } 125 | 126 | .fancybox-next { 
127 | right: 0; 128 | } 129 | 130 | .fancybox-nav span { 131 | position: absolute; 132 | top: 50%; 133 | width: 36px; 134 | height: 34px; 135 | margin-top: -18px; 136 | cursor: pointer; 137 | z-index: 8040; 138 | visibility: hidden; 139 | } 140 | 141 | .fancybox-prev span { 142 | left: 10px; 143 | background-position: 0 -36px; 144 | } 145 | 146 | .fancybox-next span { 147 | right: 10px; 148 | background-position: 0 -72px; 149 | } 150 | 151 | .fancybox-nav:hover span { 152 | visibility: visible; 153 | } 154 | 155 | .fancybox-tmp { 156 | position: absolute; 157 | top: -99999px; 158 | left: -99999px; 159 | visibility: hidden; 160 | max-width: 99999px; 161 | max-height: 99999px; 162 | overflow: visible !important; 163 | } 164 | 165 | /* Overlay helper */ 166 | 167 | .fancybox-lock { 168 | overflow: hidden !important; 169 | width: auto; 170 | } 171 | 172 | .fancybox-lock body { 173 | overflow: hidden !important; 174 | } 175 | 176 | .fancybox-lock-test { 177 | overflow-y: hidden !important; 178 | } 179 | 180 | .fancybox-overlay { 181 | position: absolute; 182 | top: 0; 183 | left: 0; 184 | overflow: hidden; 185 | display: none; 186 | z-index: 8010; 187 | background: url('fancybox_overlay.png'); 188 | } 189 | 190 | .fancybox-overlay-fixed { 191 | position: fixed; 192 | bottom: 0; 193 | right: 0; 194 | } 195 | 196 | .fancybox-lock .fancybox-overlay { 197 | overflow: auto; 198 | overflow-y: scroll; 199 | } 200 | 201 | /* Title helper */ 202 | 203 | .fancybox-title { 204 | visibility: hidden; 205 | font: normal 13px/20px "Helvetica Neue",Helvetica,Arial,sans-serif; 206 | position: relative; 207 | text-shadow: none; 208 | z-index: 8050; 209 | } 210 | 211 | .fancybox-opened .fancybox-title { 212 | visibility: visible; 213 | } 214 | 215 | .fancybox-title-float-wrap { 216 | position: absolute; 217 | bottom: 0; 218 | right: 50%; 219 | margin-bottom: -35px; 220 | z-index: 8050; 221 | text-align: center; 222 | } 223 | 224 | .fancybox-title-float-wrap .child { 225 | display: inline-block; 226 | margin-right: -100%; 227 | padding: 2px 20px; 228 | background: transparent; /* Fallback for web browsers that doesn't support RGBa */ 229 | background: rgba(0, 0, 0, 0.8); 230 | -webkit-border-radius: 15px; 231 | -moz-border-radius: 15px; 232 | border-radius: 15px; 233 | text-shadow: 0 1px 2px #222; 234 | color: #FFF; 235 | font-weight: bold; 236 | line-height: 24px; 237 | white-space: nowrap; 238 | } 239 | 240 | .fancybox-title-outside-wrap { 241 | position: relative; 242 | margin-top: 10px; 243 | color: #fff; 244 | } 245 | 246 | .fancybox-title-inside-wrap { 247 | padding-top: 10px; 248 | } 249 | 250 | .fancybox-title-over-wrap { 251 | position: absolute; 252 | bottom: 0; 253 | left: 0; 254 | color: #fff; 255 | padding: 10px; 256 | background: #000; 257 | background: rgba(0, 0, 0, .8); 258 | } 259 | 260 | /*Retina graphics!*/ 261 | @media only screen and (-webkit-min-device-pixel-ratio: 1.5), 262 | only screen and (min--moz-device-pixel-ratio: 1.5), 263 | only screen and (min-device-pixel-ratio: 1.5){ 264 | 265 | #fancybox-loading, .fancybox-close, .fancybox-prev span, .fancybox-next span { 266 | background-image: url('fancybox_sprite@2x.png'); 267 | background-size: 44px 152px; /*The size of the normal image, half the size of the hi-res image*/ 268 | } 269 | 270 | #fancybox-loading div { 271 | background-image: url('fancybox_loading@2x.gif'); 272 | background-size: 24px 24px; /*The size of the normal image, half the size of the hi-res image*/ 273 | } 274 | } 
-------------------------------------------------------------------------------- /carvekit/web/static/css/media-queries.css: -------------------------------------------------------------------------------- 1 | /*============================================================ 2 | For Small Desktop 3 | ==============================================================*/ 4 | 5 | @media (min-width: 980px) and (max-width: 1150px) { 6 | 7 | /* slider */ 8 | .carousel-caption h3 { 9 | font-size: 45px; 10 | } 11 | 12 | /* works */ 13 | 14 | 15 | /* team */ 16 | 17 | .member-thumb { 18 | width: auto; 19 | } 20 | 21 | } 22 | 23 | 24 | /*============================================================ 25 | Tablet (Portrait) Design for a width of 768px 26 | ==============================================================*/ 27 | 28 | @media (min-width: 768px) and (max-width: 979px) { 29 | 30 | 31 | /* slider */ 32 | 33 | .carousel-caption h2 { 34 | font-size: 55px; 35 | } 36 | 37 | .carousel-caption h3 { 38 | font-size: 36px; 39 | } 40 | 41 | /* services */ 42 | 43 | .service-item { 44 | margin: 0 auto 30px; 45 | text-align: center; 46 | width: 325px; 47 | } 48 | 49 | .service-icon { 50 | float: none; 51 | margin: 0 auto 15px; 52 | text-align: center; 53 | width: 50px; 54 | } 55 | 56 | .service-desc { 57 | margin-left: 0; 58 | position: relative; 59 | top: 0; 60 | } 61 | 62 | /* works */ 63 | 64 | .work-item { 65 | width: 33%; 66 | } 67 | 68 | /* team */ 69 | 70 | .member-thumb .overlay h5 { 71 | margin: 25px 0; 72 | } 73 | 74 | .member-thumb { 75 | margin: 0 auto; 76 | } 77 | 78 | /* fatcs */ 79 | 80 | #facts { 81 | background-position: center top !important; 82 | } 83 | .counters-item { 84 | margin-bottom: 30px; 85 | } 86 | 87 | .counters-item i { 88 | margin: 0 0 15px; 89 | } 90 | 91 | .counters-item strong { 92 | font-size: 45px; 93 | } 94 | 95 | /* contact */ 96 | 97 | .contact-form .name-email input { 98 | margin-right: 0; 99 | width: 100%; 100 | } 101 | 102 | .footer-social { 103 | margin-top: 45px; 104 | } 105 | 106 | /* footer */ 107 | 108 | .footer-single { 109 | margin-bottom: 30px; 110 | } 111 | 112 | } 113 | 114 | 115 | /*============================================================ 116 | Mobile (Portrait) Design for a width of 320px 117 | ==============================================================*/ 118 | 119 | @media only screen and (max-width: 767px) { 120 | 121 | .sec-sub-title p { 122 | font-size: 14px; 123 | } 124 | 125 | /* slider */ 126 | .carousel-caption h2 { 127 | font-size: 35px; 128 | } 129 | 130 | .carousel-caption h3 { 131 | font-size: 22px; 132 | } 133 | 134 | .carousel-caption p { 135 | font-size: 14px; 136 | } 137 | 138 | .social-links { 139 | margin-top: 20%; 140 | } 141 | 142 | /* services */ 143 | 144 | .service-item { 145 | margin: 0 auto 30px; 146 | text-align: center; 147 | width: 280px; 148 | } 149 | 150 | .service-icon { 151 | float: none; 152 | margin: 0 auto 15px; 153 | text-align: center; 154 | width: 50px; 155 | } 156 | 157 | .service-desc { 158 | margin-left: 0; 159 | position: relative; 160 | top: 0; 161 | } 162 | 163 | /* works */ 164 | 165 | .work-item { 166 | left: 5% !important; 167 | width: 90%; 168 | } 169 | 170 | /* team */ 171 | 172 | .team-member { 173 | margin-bottom: 30px; 174 | } 175 | 176 | .team-member:last-child { 177 | margin-bottom: 0; 178 | } 179 | 180 | .member-thumb { 181 | margin: 0 auto; 182 | } 183 | 184 | /* facts */ 185 | 186 | #facts { 187 | background-position: center top !important; 188 | } 189 | 190 | .counters-item { 191 | margin-bottom: 
30px; 192 | } 193 | 194 | /* contact */ 195 | .contact-address { 196 | margin-bottom: 30px; 197 | } 198 | 199 | .footer-social { 200 | margin-top: 20px; 201 | text-align: center; 202 | } 203 | 204 | .footer-social li { 205 | display: inline-block; 206 | } 207 | 208 | .footer-social li a { 209 | margin: 0 10px; 210 | } 211 | 212 | /* footer */ 213 | 214 | .footer-single { 215 | margin-bottom: 30px; 216 | } 217 | 218 | } 219 | 220 | 221 | /*============================================================ 222 | Mobile (Landscape) Design for a width of 480px 223 | ==============================================================*/ 224 | 225 | @media only screen and (min-width: 480px) and (max-width: 767px) { 226 | 227 | 228 | /* services */ 229 | 230 | .service-item { 231 | margin: 0 auto 30px; 232 | text-align: center; 233 | width: 325px; 234 | } 235 | 236 | .service-icon { 237 | float: none; 238 | margin: 0 auto 15px; 239 | text-align: center; 240 | width: 50px; 241 | } 242 | 243 | .service-desc { 244 | margin-left: 0; 245 | position: relative; 246 | top: 0; 247 | } 248 | 249 | /* works */ 250 | 251 | .work-item { 252 | left: inherit !important; 253 | width: 50%; 254 | } 255 | 256 | } -------------------------------------------------------------------------------- /carvekit/web/static/css/normalize.min.css: -------------------------------------------------------------------------------- 1 | /*! normalize.css v1.1.3 | MIT License | git.io/normalize */article,aside,details,figcaption,figure,footer,header,hgroup,main,nav,section,summary{display:block}audio,canvas,video{display:inline-block;*display:inline;*zoom:1}audio:not([controls]){display:none;height:0}[hidden]{display:none}html{font-size:100%;-ms-text-size-adjust:100%;-webkit-text-size-adjust:100%}html,button,input,select,textarea{font-family:sans-serif}body{margin:0}a:focus{outline:thin dotted}a:active,a:hover{outline:0}h1{font-size:2em;margin:.67em 0}h2{font-size:1.5em;margin:.83em 0}h3{font-size:1.17em;margin:1em 0}h4{font-size:1em;margin:1.33em 0}h5{font-size:.83em;margin:1.67em 0}h6{font-size:.67em;margin:2.33em 0}abbr[title]{border-bottom:1px dotted}b,strong{font-weight:bold}blockquote{margin:1em 40px}dfn{font-style:italic}hr{-moz-box-sizing:content-box;box-sizing:content-box;height:0}mark{background:#ff0;color:#000}p,pre{margin:1em 0}code,kbd,pre,samp{font-family:monospace,serif;_font-family:'courier new',monospace;font-size:1em}pre{white-space:pre;white-space:pre-wrap;word-wrap:break-word}q{quotes:none}q:before,q:after{content:'';content:none}small{font-size:80%}sub,sup{font-size:75%;line-height:0;position:relative;vertical-align:baseline}sup{top:-0.5em}sub{bottom:-0.25em}dl,menu,ol,ul{margin:1em 0}dd{margin:0 0 0 40px}menu,ol,ul{padding:0 0 0 40px}nav ul,nav ol{list-style:none;list-style-image:none}img{border:0;-ms-interpolation-mode:bicubic}svg:not(:root){overflow:hidden}figure{margin:0}form{margin:0}fieldset{border:1px solid silver;margin:0 2px;padding:.35em .625em .75em}legend{border:0;padding:0;white-space:normal;*margin-left:-7px}button,input,select,textarea{font-size:100%;margin:0;vertical-align:baseline;*vertical-align:middle}button,input{line-height:normal}button,select{text-transform:none}button,html input[type="button"],input[type="reset"],input[type="submit"]{-webkit-appearance:button;cursor:pointer;*overflow:visible}button[disabled],html 
input[disabled]{cursor:default}input[type="checkbox"],input[type="radio"]{box-sizing:border-box;padding:0;*height:13px;*width:13px}input[type="search"]{-webkit-appearance:textfield;-moz-box-sizing:content-box;-webkit-box-sizing:content-box;box-sizing:content-box}input[type="search"]::-webkit-search-cancel-button,input[type="search"]::-webkit-search-decoration{-webkit-appearance:none}button::-moz-focus-inner,input::-moz-focus-inner{border:0;padding:0}textarea{overflow:auto;vertical-align:top}table{border-collapse:collapse;border-spacing:0} -------------------------------------------------------------------------------- /carvekit/web/static/css/particles.css: -------------------------------------------------------------------------------- 1 | /* ---- reset ---- */ 2 | 3 | 4 | canvas { 5 | 6 | display: block; 7 | vertical-align: bottom; 8 | } 9 | 10 | /* ---- particles.js container ---- */ 11 | 12 | #particles-js { 13 | 14 | position: absolute; 15 | width: 100%; 16 | height: 100%; 17 | } 18 | 19 | /* ---- stats.js ---- */ 20 | 21 | .count-particles{ 22 | 23 | background: #000022; 24 | position: absolute; 25 | top: 48px; 26 | left: 0; 27 | width: 80px; 28 | color: #13E8E9; 29 | font-size: .8em; 30 | text-align: left; 31 | text-indent: 4px; 32 | line-height: 14px; 33 | padding-bottom: 2px; 34 | font-family: Helvetica, Arial, sans-serif; 35 | font-weight: bold; 36 | } 37 | 38 | .js-count-particles{ 39 | 40 | font-size: 1.1em; 41 | } 42 | 43 | #stats, 44 | .count-particles{ 45 | -webkit-user-select: none; 46 | } 47 | 48 | #stats{ 49 | border-radius: 3px 3px 0 0; 50 | overflow: hidden; 51 | } 52 | 53 | .count-particles{ 54 | border-radius: 0 0 3px 3px; 55 | } -------------------------------------------------------------------------------- /carvekit/web/static/fonts/FontAwesome.otf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OPHoperHPO/image-background-remove-tool/f141a311af67fb1da64269c508a6d1f786420801/carvekit/web/static/fonts/FontAwesome.otf -------------------------------------------------------------------------------- /carvekit/web/static/fonts/fontawesome-webfont.eot: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OPHoperHPO/image-background-remove-tool/f141a311af67fb1da64269c508a6d1f786420801/carvekit/web/static/fonts/fontawesome-webfont.eot -------------------------------------------------------------------------------- /carvekit/web/static/fonts/fontawesome-webfont.ttf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OPHoperHPO/image-background-remove-tool/f141a311af67fb1da64269c508a6d1f786420801/carvekit/web/static/fonts/fontawesome-webfont.ttf -------------------------------------------------------------------------------- /carvekit/web/static/fonts/fontawesome-webfont.woff: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OPHoperHPO/image-background-remove-tool/f141a311af67fb1da64269c508a6d1f786420801/carvekit/web/static/fonts/fontawesome-webfont.woff -------------------------------------------------------------------------------- /carvekit/web/static/img/CarveKit_logo_main.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OPHoperHPO/image-background-remove-tool/f141a311af67fb1da64269c508a6d1f786420801/carvekit/web/static/img/CarveKit_logo_main.png 
-------------------------------------------------------------------------------- /carvekit/web/static/img/art.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OPHoperHPO/image-background-remove-tool/f141a311af67fb1da64269c508a6d1f786420801/carvekit/web/static/img/art.gif -------------------------------------------------------------------------------- /carvekit/web/static/img/envelop.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OPHoperHPO/image-background-remove-tool/f141a311af67fb1da64269c508a6d1f786420801/carvekit/web/static/img/envelop.png -------------------------------------------------------------------------------- /carvekit/web/static/img/icon.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OPHoperHPO/image-background-remove-tool/f141a311af67fb1da64269c508a6d1f786420801/carvekit/web/static/img/icon.png -------------------------------------------------------------------------------- /carvekit/web/static/img/preloader.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OPHoperHPO/image-background-remove-tool/f141a311af67fb1da64269c508a6d1f786420801/carvekit/web/static/img/preloader.gif -------------------------------------------------------------------------------- /carvekit/web/static/js/custom.js: -------------------------------------------------------------------------------- 1 | /* ========================================================================= */ 2 | /* Preloader 3 | /* ========================================================================= */ 4 | 5 | jQuery(window).load(function(){ 6 | 7 | $("#preloader").fadeOut("slow"); 8 | 9 | }); 10 | 11 | 12 | $(document).ready(function(){ 13 | 14 | /* ========================================================================= */ 15 | /* Menu item highlighting 16 | /* ========================================================================= */ 17 | 18 | jQuery('#nav').singlePageNav({ 19 | offset: jQuery('#nav').outerHeight(), 20 | filter: ':not(.external)', 21 | speed: 1200, 22 | currentClass: 'current', 23 | easing: 'easeInOutExpo', 24 | updateHash: true, 25 | beforeStart: function() { 26 | console.log('begin scrolling'); 27 | }, 28 | onComplete: function() { 29 | console.log('done scrolling'); 30 | } 31 | }); 32 | 33 | $(window).scroll(function () { 34 | if ($(window).scrollTop() > 400) { 35 | $("#navigation").css("background-color","#021323"); 36 | } else { 37 | $("#navigation").css("background-color","rgba(2, 19, 35, 0.7)"); 38 | } 39 | }); 40 | 41 | /* ========================================================================= */ 42 | /* Fix Slider Height 43 | /* ========================================================================= */ 44 | 45 | var slideHeight = $(window).height(); 46 | 47 | $('#slider, .carousel.slide, .carousel-inner, .carousel-inner .item').css('height',slideHeight); 48 | 49 | $(window).resize(function(){'use strict', 50 | $('#slider, .carousel.slide, .carousel-inner, .carousel-inner .item').css('height',slideHeight); 51 | }); 52 | 53 | 54 | /* ========================================================================= */ 55 | /* Portfolio Filtering 56 | /* ========================================================================= */ 57 | 58 | 59 | // portfolio filtering 60 | 61 | $(".project-wrapper").mixItUp(); 62 | 63 | 64 
| $(".fancybox").fancybox({ 65 | padding: 0, 66 | 67 | openEffect : 'elastic', 68 | openSpeed : 650, 69 | 70 | closeEffect : 'elastic', 71 | closeSpeed : 550, 72 | 73 | closeClick : true, 74 | }); 75 | 76 | /* ========================================================================= */ 77 | /* Parallax 78 | /* ========================================================================= */ 79 | 80 | $('#facts').parallax("50%", 0.3); 81 | 82 | /* ========================================================================= */ 83 | /* Timer count 84 | /* ========================================================================= */ 85 | 86 | "use strict"; 87 | $(".number-counters").appear(function () { 88 | $(".number-counters [data-to]").each(function () { 89 | var e = $(this).attr("data-to"); 90 | $(this).delay(6e3).countTo({ 91 | from: 50, 92 | to: e, 93 | speed: 3e3, 94 | refreshInterval: 50 95 | }) 96 | }) 97 | }); 98 | 99 | /* ========================================================================= */ 100 | /* Back to Top 101 | /* ========================================================================= */ 102 | 103 | 104 | $(window).scroll(function () { 105 | if ($(window).scrollTop() > 400) { 106 | $("#back-top").fadeIn(200) 107 | } else { 108 | $("#back-top").fadeOut(200) 109 | } 110 | }); 111 | $("#back-top").click(function () { 112 | $("html, body").stop().animate({ 113 | scrollTop: 0 114 | }, 1500, "easeInOutExpo") 115 | }); 116 | 117 | }); 118 | 119 | 120 | // ========== START GOOGLE MAP ========== // 121 | function initialize() { 122 | var myLatLng = new google.maps.LatLng(22.402789, 91.822156); 123 | 124 | var mapOptions = { 125 | zoom: 14, 126 | center: myLatLng, 127 | disableDefaultUI: true, 128 | scrollwheel: false, 129 | navigationControl: false, 130 | mapTypeControl: false, 131 | scaleControl: false, 132 | draggable: false, 133 | mapTypeControlOptions: { 134 | mapTypeIds: [google.maps.MapTypeId.ROADMAP, 'roadatlas'] 135 | } 136 | }; 137 | 138 | var map = new google.maps.Map(document.getElementById('map_canvas'), mapOptions); 139 | 140 | 141 | var marker = new google.maps.Marker({ 142 | position: myLatLng, 143 | map: map, 144 | icon: 'img/location-icon.png', 145 | title: '', 146 | }); 147 | 148 | } 149 | 150 | google.maps.event.addDomListener(window, "load", initialize); 151 | // ========== END GOOGLE MAP ========== // 152 | -------------------------------------------------------------------------------- /carvekit/web/static/js/jquery-countTo.js: -------------------------------------------------------------------------------- 1 | /* 2 | Plugin Name: Count To 3 | Written by: Matt Huggins - https://github.com/mhuggins/jquery-countTo 4 | */ 5 | 6 | (function ($) { 7 | $.fn.countTo = function (options) { 8 | options = options || {}; 9 | 10 | return $(this).each(function () { 11 | // set options for current element 12 | var settings = $.extend({}, $.fn.countTo.defaults, { 13 | from: $(this).data('from'), 14 | to: $(this).data('to'), 15 | speed: $(this).data('speed'), 16 | refreshInterval: $(this).data('refresh-interval'), 17 | decimals: $(this).data('decimals') 18 | }, options); 19 | 20 | // how many times to update the value, and how much to increment the value on each update 21 | var loops = Math.ceil(settings.speed / settings.refreshInterval), 22 | increment = (settings.to - settings.from) / loops; 23 | 24 | // references & variables that will change with each update 25 | var self = this, 26 | $self = $(this), 27 | loopCount = 0, 28 | value = settings.from, 29 | data = 
$self.data('countTo') || {}; 30 | 31 | $self.data('countTo', data); 32 | 33 | // if an existing interval can be found, clear it first 34 | if (data.interval) { 35 | clearInterval(data.interval); 36 | } 37 | data.interval = setInterval(updateTimer, settings.refreshInterval); 38 | 39 | // initialize the element with the starting value 40 | render(value); 41 | 42 | function updateTimer() { 43 | value += increment; 44 | loopCount++; 45 | 46 | render(value); 47 | 48 | if (typeof(settings.onUpdate) == 'function') { 49 | settings.onUpdate.call(self, value); 50 | } 51 | 52 | if (loopCount >= loops) { 53 | // remove the interval 54 | $self.removeData('countTo'); 55 | clearInterval(data.interval); 56 | value = settings.to; 57 | 58 | if (typeof(settings.onComplete) == 'function') { 59 | settings.onComplete.call(self, value); 60 | } 61 | } 62 | } 63 | 64 | function render(value) { 65 | var formattedValue = settings.formatter.call(self, value, settings); 66 | $self.text(formattedValue); 67 | } 68 | }); 69 | }; 70 | 71 | $.fn.countTo.defaults = { 72 | from: 0, // the number the element should start at 73 | to: 0, // the number the element should end at 74 | speed: 1000, // how long it should take to count between the target numbers 75 | refreshInterval: 100, // how often the element should be updated 76 | decimals: 0, // the number of decimal places to show 77 | formatter: formatter, // handler for formatting the value before rendering 78 | onUpdate: null, // callback method for every time the element is updated 79 | onComplete: null // callback method for when the element finishes updating 80 | }; 81 | 82 | function formatter(value, settings) { 83 | return value.toFixed(settings.decimals); 84 | } 85 | }(jQuery)); -------------------------------------------------------------------------------- /carvekit/web/static/js/jquery.appear.js: -------------------------------------------------------------------------------- 1 | /* 2 | * jQuery.appear 3 | * https://github.com/bas2k/jquery.appear/ 4 | * http://code.google.com/p/jquery-appear/ 5 | * 6 | * Copyright (c) 2009 Michael Hixson 7 | * Copyright (c) 2012 Alexander Brovikov 8 | * Licensed under the MIT license (http://www.opensource.org/licenses/mit-license.php) 9 | */ 10 | (function($) { 11 | $.fn.appear = function(fn, options) { 12 | 13 | var settings = $.extend({ 14 | 15 | //arbitrary data to pass to fn 16 | data: undefined, 17 | 18 | //call fn only on the first appear? 19 | one: true, 20 | 21 | // X & Y accuracy 22 | accX: 0, 23 | accY: 0 24 | 25 | }, options); 26 | 27 | return this.each(function() { 28 | 29 | var t = $(this); 30 | 31 | //whether the element is currently visible 32 | t.appeared = false; 33 | 34 | if (!fn) { 35 | 36 | //trigger the custom event 37 | t.trigger('appear', settings.data); 38 | return; 39 | } 40 | 41 | var w = $(window); 42 | 43 | //fires the appear event when appropriate 44 | var check = function() { 45 | 46 | //is the element hidden? 47 | if (!t.is(':visible')) { 48 | 49 | //it became hidden 50 | t.appeared = false; 51 | return; 52 | } 53 | 54 | //is the element inside the visible window? 
55 | var a = w.scrollLeft(); 56 | var b = w.scrollTop(); 57 | var o = t.offset(); 58 | var x = o.left; 59 | var y = o.top; 60 | 61 | var ax = settings.accX; 62 | var ay = settings.accY; 63 | var th = t.height(); 64 | var wh = w.height(); 65 | var tw = t.width(); 66 | var ww = w.width(); 67 | 68 | if (y + th + ay >= b && 69 | y <= b + wh + ay && 70 | x + tw + ax >= a && 71 | x <= a + ww + ax) { 72 | 73 | //trigger the custom event 74 | if (!t.appeared) t.trigger('appear', settings.data); 75 | 76 | } else { 77 | 78 | //it scrolled out of view 79 | t.appeared = false; 80 | } 81 | }; 82 | 83 | //create a modified fn with some additional logic 84 | var modifiedFn = function() { 85 | 86 | //mark the element as visible 87 | t.appeared = true; 88 | 89 | //is this supposed to happen only once? 90 | if (settings.one) { 91 | 92 | //remove the check 93 | w.unbind('scroll', check); 94 | var i = $.inArray(check, $.fn.appear.checks); 95 | if (i >= 0) $.fn.appear.checks.splice(i, 1); 96 | } 97 | 98 | //trigger the original fn 99 | fn.apply(this, arguments); 100 | }; 101 | 102 | //bind the modified fn to the element 103 | if (settings.one) t.one('appear', settings.data, modifiedFn); 104 | else t.bind('appear', settings.data, modifiedFn); 105 | 106 | //check whenever the window scrolls 107 | w.scroll(check); 108 | 109 | //check whenever the dom changes 110 | $.fn.appear.checks.push(check); 111 | 112 | //check now 113 | (check)(); 114 | }); 115 | }; 116 | 117 | //keep a queue of appearance checks 118 | $.extend($.fn.appear, { 119 | 120 | checks: [], 121 | timeout: null, 122 | 123 | //process the queue 124 | checkAll: function() { 125 | var length = $.fn.appear.checks.length; 126 | if (length > 0) while (length--) ($.fn.appear.checks[length])(); 127 | }, 128 | 129 | //check the queue asynchronously 130 | run: function() { 131 | if ($.fn.appear.timeout) clearTimeout($.fn.appear.timeout); 132 | $.fn.appear.timeout = setTimeout($.fn.appear.checkAll, 20); 133 | } 134 | }); 135 | 136 | //run checks when these methods are called 137 | $.each(['append', 'prepend', 'after', 'before', 'attr', 138 | 'removeAttr', 'addClass', 'removeClass', 'toggleClass', 139 | 'remove', 'css', 'show', 'hide'], function(i, n) { 140 | var old = $.fn[n]; 141 | if (old) { 142 | $.fn[n] = function() { 143 | var r = old.apply(this, arguments); 144 | $.fn.appear.run(); 145 | return r; 146 | } 147 | } 148 | }); 149 | 150 | })(jQuery); -------------------------------------------------------------------------------- /carvekit/web/static/js/jquery.easing.min.js: -------------------------------------------------------------------------------- 1 | /* 2 | * jQuery Easing v1.3 - http://gsgd.co.uk/sandbox/jquery/easing/ 3 | * 4 | * Uses the built in easing capabilities added In jQuery 1.1 5 | * to offer multiple easing options 6 | * 7 | * TERMS OF USE - EASING EQUATIONS 8 | * 9 | * Open source under the BSD License. 10 | * 11 | * Copyright © 2001 Robert Penner 12 | * All rights reserved. 13 | * 14 | * TERMS OF USE - jQuery Easing 15 | * 16 | * Open source under the BSD License. 17 | * 18 | * Copyright © 2008 George McGinley Smith 19 | * All rights reserved. 20 | * 21 | * Redistribution and use in source and binary forms, with or without modification, 22 | * are permitted provided that the following conditions are met: 23 | * 24 | * Redistributions of source code must retain the above copyright notice, this list of 25 | * conditions and the following disclaimer. 
26 | * Redistributions in binary form must reproduce the above copyright notice, this list 27 | * of conditions and the following disclaimer in the documentation and/or other materials 28 | * provided with the distribution. 29 | * 30 | * Neither the name of the author nor the names of contributors may be used to endorse 31 | * or promote products derived from this software without specific prior written permission. 32 | * 33 | * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY 34 | * EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF 35 | * MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE 36 | * COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, 37 | * EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE 38 | * GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED 39 | * AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING 40 | * NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED 41 | * OF THE POSSIBILITY OF SUCH DAMAGE. 42 | * 43 | */ 44 | jQuery.easing.jswing=jQuery.easing.swing;jQuery.extend(jQuery.easing,{def:"easeOutQuad",swing:function(e,f,a,h,g){return jQuery.easing[jQuery.easing.def](e,f,a,h,g)},easeInQuad:function(e,f,a,h,g){return h*(f/=g)*f+a},easeOutQuad:function(e,f,a,h,g){return -h*(f/=g)*(f-2)+a},easeInOutQuad:function(e,f,a,h,g){if((f/=g/2)<1){return h/2*f*f+a}return -h/2*((--f)*(f-2)-1)+a},easeInCubic:function(e,f,a,h,g){return h*(f/=g)*f*f+a},easeOutCubic:function(e,f,a,h,g){return h*((f=f/g-1)*f*f+1)+a},easeInOutCubic:function(e,f,a,h,g){if((f/=g/2)<1){return h/2*f*f*f+a}return h/2*((f-=2)*f*f+2)+a},easeInQuart:function(e,f,a,h,g){return h*(f/=g)*f*f*f+a},easeOutQuart:function(e,f,a,h,g){return -h*((f=f/g-1)*f*f*f-1)+a},easeInOutQuart:function(e,f,a,h,g){if((f/=g/2)<1){return h/2*f*f*f*f+a}return -h/2*((f-=2)*f*f*f-2)+a},easeInQuint:function(e,f,a,h,g){return h*(f/=g)*f*f*f*f+a},easeOutQuint:function(e,f,a,h,g){return h*((f=f/g-1)*f*f*f*f+1)+a},easeInOutQuint:function(e,f,a,h,g){if((f/=g/2)<1){return h/2*f*f*f*f*f+a}return h/2*((f-=2)*f*f*f*f+2)+a},easeInSine:function(e,f,a,h,g){return -h*Math.cos(f/g*(Math.PI/2))+h+a},easeOutSine:function(e,f,a,h,g){return h*Math.sin(f/g*(Math.PI/2))+a},easeInOutSine:function(e,f,a,h,g){return -h/2*(Math.cos(Math.PI*f/g)-1)+a},easeInExpo:function(e,f,a,h,g){return(f==0)?a:h*Math.pow(2,10*(f/g-1))+a},easeOutExpo:function(e,f,a,h,g){return(f==g)?a+h:h*(-Math.pow(2,-10*f/g)+1)+a},easeInOutExpo:function(e,f,a,h,g){if(f==0){return a}if(f==g){return a+h}if((f/=g/2)<1){return h/2*Math.pow(2,10*(f-1))+a}return h/2*(-Math.pow(2,-10*--f)+2)+a},easeInCirc:function(e,f,a,h,g){return -h*(Math.sqrt(1-(f/=g)*f)-1)+a},easeOutCirc:function(e,f,a,h,g){return h*Math.sqrt(1-(f=f/g-1)*f)+a},easeInOutCirc:function(e,f,a,h,g){if((f/=g/2)<1){return -h/2*(Math.sqrt(1-f*f)-1)+a}return h/2*(Math.sqrt(1-(f-=2)*f)+1)+a},easeInElastic:function(f,h,e,l,k){var i=1.70158;var j=0;var g=l;if(h==0){return e}if((h/=k)==1){return e+l}if(!j){j=k*0.3}if(g pos + windowHeight) { 59 | return; 60 | } 61 | 62 | $this.css('backgroundPosition', xpos + " " + Math.round((firstTop - pos) * speedFactor) + "px"); 63 | }); 64 | } 65 | 66 | $window.bind('scroll', update).resize(update); 67 | update(); 68 | }; 69 | })(jQuery); 70 | 
-------------------------------------------------------------------------------- /carvekit/web/static/js/jquery.singlePageNav.min.js: -------------------------------------------------------------------------------- 1 | /** 2 | * Single Page Nav Plugin 3 | * Copyright (c) 2014 Chris Wojcik 4 | * Dual licensed under MIT and GPL. 5 | * @author Chris Wojcik 6 | * @version 1.2.0 7 | */ 8 | if(typeof Object.create!=="function"){Object.create=function(e){function t(){}t.prototype=e;return new t}}(function(e,t,n,r){"use strict";var i={init:function(n,r){this.options=e.extend({},e.fn.singlePageNav.defaults,n);this.container=r;this.$container=e(r);this.$links=this.$container.find("a");if(this.options.filter!==""){this.$links=this.$links.filter(this.options.filter)}this.$window=e(t);this.$htmlbody=e("html, body");this.$links.on("click.singlePageNav",e.proxy(this.handleClick,this));this.didScroll=false;this.checkPosition();this.setTimer()},handleClick:function(t){var n=this,r=t.currentTarget,i=e(r.hash);t.preventDefault();if(i.length){n.clearTimer();if(typeof n.options.beforeStart==="function"){n.options.beforeStart()}n.setActiveLink(r.hash);n.scrollTo(i,function(){if(n.options.updateHash&&history.pushState){history.pushState(null,null,r.hash)}n.setTimer();if(typeof n.options.onComplete==="function"){n.options.onComplete()}})}},scrollTo:function(e,t){var n=this;var r=n.getCoords(e).top;var i=false;n.$htmlbody.stop().animate({scrollTop:r},{duration:n.options.speed,easing:n.options.easing,complete:function(){if(typeof t==="function"&&!i){t()}i=true}})},setTimer:function(){var e=this;e.$window.on("scroll.singlePageNav",function(){e.didScroll=true});e.timer=setInterval(function(){if(e.didScroll){e.didScroll=false;e.checkPosition()}},250)},clearTimer:function(){clearInterval(this.timer);this.$window.off("scroll.singlePageNav");this.didScroll=false},checkPosition:function(){var e=this.$window.scrollTop();var t=this.getCurrentSection(e);this.setActiveLink(t)},getCoords:function(e){return{top:Math.round(e.offset().top)-this.options.offset}},setActiveLink:function(e){var t=this.$container.find("a[href$='"+e+"']");if(!t.hasClass(this.options.currentClass)){this.$links.removeClass(this.options.currentClass);t.addClass(this.options.currentClass)}},getCurrentSection:function(t){var n,r,i,s;for(n=0;n=i.top-this.options.threshold){s=r}}}return s||this.$links[0].hash}};e.fn.singlePageNav=function(e){return this.each(function(){var t=Object.create(i);t.init(e,this)})};e.fn.singlePageNav.defaults={offset:0,threshold:120,speed:400,currentClass:"current",easing:"swing",updateHash:false,filter:"",onComplete:false,beforeStart:false}})(jQuery,window,document) -------------------------------------------------------------------------------- /carvekit/web/static/js/particles.js: -------------------------------------------------------------------------------- 1 | /* ---- particles.js config ---- */ 2 | 3 | particlesJS("particles-js", { 4 | "particles": { 5 | "number": { 6 | "value": 40, 7 | "density": { 8 | "enable": true, 9 | "value_area": 800 10 | } 11 | }, 12 | "color": { 13 | "value": "#fc900a" 14 | }, 15 | "shape": { 16 | "type": "circle", 17 | "stroke": { 18 | "width": 0, 19 | "color": "#000000" 20 | }, 21 | "polygon": { 22 | "nb_sides": 5 23 | }, 24 | "image": { 25 | "src": "img/icon.png", 26 | "width": 100, 27 | "height": 100 28 | } 29 | }, 30 | "opacity": { 31 | "value": 0.5, 32 | "random": false, 33 | "anim": { 34 | "enable": false, 35 | "speed": 1, 36 | "opacity_min": 0.1, 37 | "sync": false 38 | } 39 | }, 40 | 
"size": { 41 | "value": 3, 42 | "random": true, 43 | "anim": { 44 | "enable": false, 45 | "speed": 4, 46 | "size_min": 0.1, 47 | "sync": false 48 | } 49 | }, 50 | "line_linked": { 51 | "enable": true, 52 | "distance": 300, 53 | "color": "#ff9100", 54 | "opacity": 0.8, 55 | "width": 1 56 | }, 57 | "move": { 58 | "enable": true, 59 | "speed": 0.2, 60 | "direction": "none", 61 | "random": false, 62 | "straight": false, 63 | "out_mode": "out", 64 | "bounce": false, 65 | "attract": { 66 | "enable": false, 67 | "rotateX": 600, 68 | "rotateY": 1200 69 | } 70 | } 71 | }, 72 | "interactivity": { 73 | "detect_on": "canvas", 74 | "events": { 75 | "onhover": { 76 | "enable": false, 77 | "mode": "grab" 78 | }, 79 | "onclick": { 80 | "enable": false, 81 | "mode": "grab" 82 | }, 83 | "resize": true 84 | }, 85 | "modes": { 86 | "grab": { 87 | "distance": 140, 88 | "line_linked": { 89 | "opacity": 1 90 | } 91 | }, 92 | "bubble": { 93 | "distance": 400, 94 | "size": 40, 95 | "duration": 8, 96 | "opacity": 8, 97 | "speed": 1 98 | }, 99 | "repulse": { 100 | "distance": 200, 101 | "duration": 0.5 102 | }, 103 | "push": { 104 | "particles_nb": 4 105 | }, 106 | "remove": { 107 | "particles_nb": 2 108 | } 109 | } 110 | }, 111 | "retina_detect": true 112 | }); 113 | 114 | 115 | /* ---- stats.js config ---- */ 116 | 117 | var count_particles, stats, update; 118 | stats = new Stats; 119 | stats.setMode(0); 120 | stats.domElement.style.position = 'absolute'; 121 | stats.domElement.style.left = '0px'; 122 | stats.domElement.style.top = '0px'; 123 | document.body.appendChild(stats.domElement); 124 | count_particles = document.querySelector('.js-count-particles'); 125 | update = function() { 126 | stats.begin(); 127 | stats.end(); 128 | if (window.pJSDom[0].pJS.particles && window.pJSDom[0].pJS.particles.array) { 129 | count_particles.innerText = window.pJSDom[0].pJS.particles.array.length; 130 | } 131 | requestAnimationFrame(update); 132 | }; 133 | requestAnimationFrame(update); 134 | -------------------------------------------------------------------------------- /carvekit/web/static/js/wow.min.js: -------------------------------------------------------------------------------- 1 | /*! 
WOW - v0.1.9 - 2014-05-10 2 | * Copyright (c) 2014 Matthieu Aussaguel; Licensed MIT */(function(){var a,b,c=function(a,b){return function(){return a.apply(b,arguments)}};a=function(){function a(){}return a.prototype.extend=function(a,b){var c,d;for(c in a)d=a[c],null!=d&&(b[c]=d);return b},a.prototype.isMobile=function(a){return/Android|webOS|iPhone|iPad|iPod|BlackBerry|IEMobile|Opera Mini/i.test(a)},a}(),b=this.WeakMap||(b=function(){function a(){this.keys=[],this.values=[]}return a.prototype.get=function(a){var b,c,d,e,f;for(f=this.keys,b=d=0,e=f.length;e>d;b=++d)if(c=f[b],c===a)return this.values[b]},a.prototype.set=function(a,b){var c,d,e,f,g;for(g=this.keys,c=e=0,f=g.length;f>e;c=++e)if(d=g[c],d===a)return void(this.values[c]=b);return this.keys.push(a),this.values.push(b)},a}()),this.WOW=function(){function d(a){null==a&&(a={}),this.scrollCallback=c(this.scrollCallback,this),this.scrollHandler=c(this.scrollHandler,this),this.start=c(this.start,this),this.scrolled=!0,this.config=this.util().extend(a,this.defaults),this.animationNameCache=new b}return d.prototype.defaults={boxClass:"wow",animateClass:"animated",offset:0,mobile:!0},d.prototype.init=function(){var a;return this.element=window.document.documentElement,"interactive"===(a=document.readyState)||"complete"===a?this.start():document.addEventListener("DOMContentLoaded",this.start)},d.prototype.start=function(){var a,b,c,d;if(this.boxes=this.element.getElementsByClassName(this.config.boxClass),this.boxes.length){if(this.disabled())return this.resetStyle();for(d=this.boxes,b=0,c=d.length;c>b;b++)a=d[b],this.applyStyle(a,!0);return window.addEventListener("scroll",this.scrollHandler,!1),window.addEventListener("resize",this.scrollHandler,!1),this.interval=setInterval(this.scrollCallback,50)}},d.prototype.stop=function(){return window.removeEventListener("scroll",this.scrollHandler,!1),window.removeEventListener("resize",this.scrollHandler,!1),null!=this.interval?clearInterval(this.interval):void 0},d.prototype.show=function(a){return this.applyStyle(a),a.className=""+a.className+" "+this.config.animateClass},d.prototype.applyStyle=function(a,b){var c,d,e;return d=a.getAttribute("data-wow-duration"),c=a.getAttribute("data-wow-delay"),e=a.getAttribute("data-wow-iteration"),this.animate(function(f){return function(){return f.customStyle(a,b,d,c,e)}}(this))},d.prototype.animate=function(){return"requestAnimationFrame"in window?function(a){return window.requestAnimationFrame(a)}:function(a){return a()}}(),d.prototype.resetStyle=function(){var a,b,c,d,e;for(d=this.boxes,e=[],b=0,c=d.length;c>b;b++)a=d[b],e.push(a.setAttribute("style","visibility: visible;"));return e},d.prototype.customStyle=function(a,b,c,d,e){return b&&this.cacheAnimationName(a),a.style.visibility=b?"hidden":"visible",c&&this.vendorSet(a.style,{animationDuration:c}),d&&this.vendorSet(a.style,{animationDelay:d}),e&&this.vendorSet(a.style,{animationIterationCount:e}),this.vendorSet(a.style,{animationName:b?"none":this.cachedAnimationName(a)}),a},d.prototype.vendors=["moz","webkit"],d.prototype.vendorSet=function(a,b){var c,d,e,f;f=[];for(c in b)d=b[c],a[""+c]=d,f.push(function(){var b,f,g,h;for(g=this.vendors,h=[],b=0,f=g.length;f>b;b++)e=g[b],h.push(a[""+e+c.charAt(0).toUpperCase()+c.substr(1)]=d);return h}.call(this));return f},d.prototype.vendorCSS=function(a,b){var c,d,e,f,g,h;for(d=window.getComputedStyle(a),c=d.getPropertyCSSValue(b),h=this.vendors,f=0,g=h.length;g>f;f++)e=h[f],c=c||d.getPropertyCSSValue("-"+e+"-"+b);return 
c},d.prototype.animationName=function(a){var b;try{b=this.vendorCSS(a,"animation-name").cssText}catch(c){b=window.getComputedStyle(a).getPropertyValue("animation-name")}return"none"===b?"":b},d.prototype.cacheAnimationName=function(a){return this.animationNameCache.set(a,this.animationName(a))},d.prototype.cachedAnimationName=function(a){return this.animationNameCache.get(a)},d.prototype.scrollHandler=function(){return this.scrolled=!0},d.prototype.scrollCallback=function(){var a;return this.scrolled&&(this.scrolled=!1,this.boxes=function(){var b,c,d,e;for(d=this.boxes,e=[],b=0,c=d.length;c>b;b++)a=d[b],a&&(this.isVisible(a)?this.show(a):e.push(a));return e}.call(this),!this.boxes.length)?this.stop():void 0},d.prototype.offsetTop=function(a){for(var b;void 0===a.offsetTop;)a=a.parentNode;for(b=a.offsetTop;a=a.offsetParent;)b+=a.offsetTop;return b},d.prototype.isVisible=function(a){var b,c,d,e,f;return c=a.getAttribute("data-wow-offset")||this.config.offset,f=window.pageYOffset,e=f+this.element.clientHeight-c,d=this.offsetTop(a),b=d+a.clientHeight,e>=d&&b>=f},d.prototype.util=function(){return this._util||(this._util=new a)},d.prototype.disabled=function(){return!this.config.mobile&&this.util().isMobile(navigator.userAgent)},d}()}).call(this); -------------------------------------------------------------------------------- /carvekit/web/utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OPHoperHPO/image-background-remove-tool/f141a311af67fb1da64269c508a6d1f786420801/carvekit/web/utils/__init__.py -------------------------------------------------------------------------------- /carvekit/web/utils/init_utils.py: -------------------------------------------------------------------------------- 1 | from os import getenv 2 | from typing import Union 3 | 4 | from loguru import logger 5 | 6 | from carvekit.web.schemas.config import WebAPIConfig, MLConfig, AuthConfig 7 | from carvekit.api.interface import Interface 8 | from carvekit.ml.wrap.fba_matting import FBAMatting 9 | from carvekit.ml.wrap.u2net import U2NET 10 | from carvekit.ml.wrap.deeplab_v3 import DeepLabV3 11 | from carvekit.ml.wrap.basnet import BASNET 12 | from carvekit.ml.wrap.tracer_b7 import TracerUniversalB7 13 | 14 | from carvekit.pipelines.postprocessing import MattingMethod 15 | from carvekit.pipelines.preprocessing import PreprocessingStub 16 | from carvekit.trimap.generator import TrimapGenerator 17 | 18 | 19 | def init_config() -> WebAPIConfig: 20 | default_config = WebAPIConfig() 21 | config = WebAPIConfig( 22 | **dict( 23 | port=int(getenv("CARVEKIT_PORT", default_config.port)), 24 | host=getenv("CARVEKIT_HOST", default_config.host), 25 | ml=MLConfig( 26 | segmentation_network=getenv( 27 | "CARVEKIT_SEGMENTATION_NETWORK", 28 | default_config.ml.segmentation_network, 29 | ), 30 | preprocessing_method=getenv( 31 | "CARVEKIT_PREPROCESSING_METHOD", 32 | default_config.ml.preprocessing_method, 33 | ), 34 | postprocessing_method=getenv( 35 | "CARVEKIT_POSTPROCESSING_METHOD", 36 | default_config.ml.postprocessing_method, 37 | ), 38 | device=getenv("CARVEKIT_DEVICE", default_config.ml.device), 39 | batch_size_seg=int( 40 | getenv("CARVEKIT_BATCH_SIZE_SEG", default_config.ml.batch_size_seg) 41 | ), 42 | batch_size_matting=int( 43 | getenv( 44 | "CARVEKIT_BATCH_SIZE_MATTING", 45 | default_config.ml.batch_size_matting, 46 | ) 47 | ), 48 | seg_mask_size=int( 49 | getenv("CARVEKIT_SEG_MASK_SIZE", default_config.ml.seg_mask_size) 50 | ), 51 | 
matting_mask_size=int( 52 | getenv( 53 | "CARVEKIT_MATTING_MASK_SIZE", 54 | default_config.ml.matting_mask_size, 55 | ) 56 | ), 57 | fp16=bool(int(getenv("CARVEKIT_FP16", default_config.ml.fp16))), 58 | trimap_prob_threshold=int( 59 | getenv( 60 | "CARVEKIT_TRIMAP_PROB_THRESHOLD", 61 | default_config.ml.trimap_prob_threshold, 62 | ) 63 | ), 64 | trimap_dilation=int( 65 | getenv( 66 | "CARVEKIT_TRIMAP_DILATION", default_config.ml.trimap_dilation 67 | ) 68 | ), 69 | trimap_erosion=int( 70 | getenv("CARVEKIT_TRIMAP_EROSION", default_config.ml.trimap_erosion) 71 | ), 72 | ), 73 | auth=AuthConfig( 74 | auth=bool( 75 | int(getenv("CARVEKIT_AUTH_ENABLE", default_config.auth.auth)) 76 | ), 77 | admin_token=getenv( 78 | "CARVEKIT_ADMIN_TOKEN", default_config.auth.admin_token 79 | ), 80 | allowed_tokens=default_config.auth.allowed_tokens 81 | if getenv("CARVEKIT_ALLOWED_TOKENS") is None 82 | else getenv("CARVEKIT_ALLOWED_TOKENS").split(","), 83 | ), 84 | ) 85 | ) 86 | 87 | logger.info(f"Admin token for Web API is {config.auth.admin_token}") 88 | logger.debug(f"Running Web API with this config: {config.json()}") 89 | return config 90 | 91 | 92 | def init_interface(config: Union[WebAPIConfig, MLConfig]) -> Interface: 93 | if isinstance(config, WebAPIConfig): 94 | config = config.ml 95 | if config.segmentation_network == "u2net": 96 | seg_net = U2NET( 97 | device=config.device, 98 | batch_size=config.batch_size_seg, 99 | input_image_size=config.seg_mask_size, 100 | fp16=config.fp16, 101 | ) 102 | elif config.segmentation_network == "deeplabv3": 103 | seg_net = DeepLabV3( 104 | device=config.device, 105 | batch_size=config.batch_size_seg, 106 | input_image_size=config.seg_mask_size, 107 | fp16=config.fp16, 108 | ) 109 | elif config.segmentation_network == "basnet": 110 | seg_net = BASNET( 111 | device=config.device, 112 | batch_size=config.batch_size_seg, 113 | input_image_size=config.seg_mask_size, 114 | fp16=config.fp16, 115 | ) 116 | elif config.segmentation_network == "tracer_b7": 117 | seg_net = TracerUniversalB7( 118 | device=config.device, 119 | batch_size=config.batch_size_seg, 120 | input_image_size=config.seg_mask_size, 121 | fp16=config.fp16, 122 | ) 123 | else: 124 | seg_net = TracerUniversalB7( 125 | device=config.device, 126 | batch_size=config.batch_size_seg, 127 | input_image_size=config.seg_mask_size, 128 | fp16=config.fp16, 129 | ) 130 | 131 | if config.preprocessing_method == "stub": 132 | preprocessing = PreprocessingStub() 133 | elif config.preprocessing_method == "none": 134 | preprocessing = None 135 | else: 136 | preprocessing = None 137 | 138 | if config.postprocessing_method == "fba": 139 | fba = FBAMatting( 140 | device=config.device, 141 | batch_size=config.batch_size_matting, 142 | input_tensor_size=config.matting_mask_size, 143 | fp16=config.fp16, 144 | ) 145 | trimap_generator = TrimapGenerator( 146 | prob_threshold=config.trimap_prob_threshold, 147 | kernel_size=config.trimap_dilation, 148 | erosion_iters=config.trimap_erosion, 149 | ) 150 | postprocessing = MattingMethod( 151 | device=config.device, matting_module=fba, trimap_generator=trimap_generator 152 | ) 153 | 154 | elif config.postprocessing_method == "none": 155 | postprocessing = None 156 | else: 157 | postprocessing = None 158 | 159 | interface = Interface( 160 | pre_pipe=preprocessing, 161 | post_pipe=postprocessing, 162 | seg_pipe=seg_net, 163 | device=config.device, 164 | ) 165 | return interface 166 | -------------------------------------------------------------------------------- 
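The two helpers above are the bridge between the CARVEKIT_* environment variables and the ML pipeline: init_config() assembles a WebAPIConfig from the environment, and init_interface() turns its ml section into a ready-to-use Interface. The snippet below is a minimal usage sketch, not a file from the repository; it assumes that the object returned by init_interface() is callable with a list of PIL images and returns the processed images, and the environment overrides and image paths are placeholders.

import os

from PIL import Image

from carvekit.web.utils.init_utils import init_config, init_interface

# Hypothetical overrides; any unset CARVEKIT_* variable falls back to the
# defaults defined in WebAPIConfig / MLConfig.
os.environ.setdefault("CARVEKIT_SEGMENTATION_NETWORK", "tracer_b7")
os.environ.setdefault("CARVEKIT_DEVICE", "cpu")

config = init_config()              # WebAPIConfig built from the environment variables
interface = init_interface(config)  # segmentation network + optional FBA matting

image = Image.open("input.jpg")     # placeholder input path
result = interface([image])[0]      # assumed call signature: list of images in, list of images out
result.save("input_no_bg.png")      # placeholder output path
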
/carvekit/web/utils/net_utils.py: -------------------------------------------------------------------------------- 1 | import socket 2 | import struct 3 | from typing import Optional 4 | from urllib.parse import urlparse 5 | 6 | 7 | def is_loopback(address): 8 | host: Optional[str] = None 9 | 10 | try: 11 | parsed_url = urlparse(address) 12 | host = parsed_url.hostname 13 | except ValueError: 14 | return False # url is not even a url 15 | 16 | loopback_checker = { 17 | socket.AF_INET: lambda x: struct.unpack("!I", socket.inet_aton(x))[0] 18 | >> (32 - 8) 19 | == 127, 20 | socket.AF_INET6: lambda x: x == "::1", 21 | } 22 | for family in (socket.AF_INET, socket.AF_INET6): 23 | try: 24 | r = socket.getaddrinfo(host, None, family, socket.SOCK_STREAM) 25 | except socket.gaierror: 26 | continue 27 | for family, _, _, _, sockaddr in r: 28 | if loopback_checker[family](sockaddr[0]): 29 | return True 30 | 31 | if host in ("localhost",): 32 | return True 33 | 34 | return False 35 | --------------------------------------------------------------------------------
/carvekit/web/utils/task_queue.py: -------------------------------------------------------------------------------- 1 | import gc 2 | import threading 3 | import time 4 | import uuid 5 | from typing import Optional 6 | 7 | from loguru import logger 8 | 9 | from carvekit.api.interface import Interface 10 | from carvekit.web.schemas.config import WebAPIConfig 11 | from carvekit.web.utils.init_utils import init_interface 12 | from carvekit.web.other.removebg import process_remove_bg 13 | 14 | 15 | class MLProcessor(threading.Thread): 16 | """Simple ml task queue processor""" 17 | 18 | def __init__(self, api_config: WebAPIConfig): 19 | super().__init__() 20 | self.api_config = api_config 21 | self.interface: Optional[Interface] = None 22 | self.jobs = {} 23 | self.completed_jobs = {} 24 | 25 | def run(self): 26 | """Starts listening for new jobs.""" 27 | unused_completed_jobs_timer = time.time() 28 | if self.interface is None: 29 | self.interface = init_interface(self.api_config) 30 | while True: 31 | # Every 60 seconds, clear completed jobs that were finished more than an hour ago 32 | if time.time() - unused_completed_jobs_timer > 60: 33 | self.clear_old_completed_jobs() 34 | unused_completed_jobs_timer = time.time() 35 | 36 | if len(self.jobs.keys()) >= 1: 37 | id = list(self.jobs.keys())[0] 38 | data = self.jobs[id] 39 | # TODO add pydantic scheme here 40 | response = process_remove_bg( 41 | self.interface, data[0], data[1], data[2], data[3] 42 | ) 43 | self.completed_jobs[id] = [response, time.time()] 44 | try: 45 | del self.jobs[id] 46 | except (KeyError, NameError) as e: 47 | logger.error(f"Something went wrong with Task Queue: {str(e)}") 48 | gc.collect() 49 | else: 50 | time.sleep(1) 51 | continue 52 | 53 | def clear_old_completed_jobs(self): 54 | """Clears old completed jobs""" 55 | 56 | if len(self.completed_jobs.keys()) >= 1: 57 | for job_id in list(self.completed_jobs.keys()): # copy keys; entries are deleted below 58 | job_finished_time = self.completed_jobs[job_id][1] 59 | if time.time() - job_finished_time > 3600: 60 | try: 61 | del self.completed_jobs[job_id] 62 | except (KeyError, NameError) as e: 63 | logger.error(f"Something went wrong with Task Queue: {str(e)}") 64 | gc.collect() 65 | 66 | def job_status(self, id: str) -> str: 67 | """ 68 | Returns current job status 69 | 70 | Args: 71 | id: id of the job 72 | 73 | Returns: 74 | Current job status for specified id.
Job status can be [finished, wait, not_found] 75 | """ 76 | if id in self.completed_jobs.keys(): 77 | return "finished" 78 | elif id in self.jobs.keys(): 79 | return "wait" 80 | else: 81 | return "not_found" 82 | 83 | def job_result(self, id: str): 84 | """ 85 | Returns job processing result. 86 | 87 | Args: 88 | id: id of the job 89 | 90 | Returns: 91 | job processing result. 92 | """ 93 | if id in self.completed_jobs.keys(): 94 | data = self.completed_jobs[id][0] 95 | try: 96 | del self.completed_jobs[id] 97 | except KeyError or NameError: 98 | pass 99 | return data 100 | else: 101 | return False 102 | 103 | def job_create(self, data: list): 104 | """ 105 | Send job to ML Processor 106 | 107 | Args: 108 | data: data object 109 | """ 110 | if self.is_alive() is False: 111 | self.start() 112 | id = uuid.uuid4().hex 113 | self.jobs[id] = data 114 | return id 115 | -------------------------------------------------------------------------------- /conftest.py: -------------------------------------------------------------------------------- 1 | """ 2 | Source url: https://github.com/OPHoperHPO/image-background-remove-tool 3 | Author: Nikita Selin (OPHoperHPO)[https://github.com/OPHoperHPO]. 4 | License: Apache License 2.0 5 | """ 6 | from pathlib import Path 7 | 8 | import pytest 9 | import torch 10 | from PIL import Image 11 | from typing import Callable, Tuple, List, Union, Optional, Any 12 | 13 | from carvekit.api.high import HiInterface 14 | from carvekit.api.interface import Interface 15 | from carvekit.trimap.cv_gen import CV2TrimapGenerator 16 | from carvekit.trimap.generator import TrimapGenerator 17 | from carvekit.utils.image_utils import convert_image, load_image 18 | from carvekit.pipelines.postprocessing import MattingMethod 19 | from carvekit.pipelines.preprocessing import PreprocessingStub 20 | 21 | from carvekit.ml.wrap.u2net import U2NET 22 | from carvekit.ml.wrap.basnet import BASNET 23 | from carvekit.ml.wrap.fba_matting import FBAMatting 24 | from carvekit.ml.wrap.deeplab_v3 import DeepLabV3 25 | from carvekit.ml.wrap.tracer_b7 import TracerUniversalB7 26 | 27 | 28 | @pytest.fixture() 29 | def u2net_model() -> Callable[[bool], U2NET]: 30 | return lambda fb16: U2NET( 31 | layers_cfg="full", 32 | device="cuda" if torch.cuda.is_available() else "cpu", 33 | input_image_size=320, 34 | batch_size=10, 35 | load_pretrained=True, 36 | fp16=fb16, 37 | ) 38 | 39 | 40 | @pytest.fixture() 41 | def tracer_model() -> Callable[[bool], TracerUniversalB7]: 42 | return lambda fb16: TracerUniversalB7( 43 | device="cuda" if torch.cuda.is_available() else "cpu", 44 | input_image_size=320, 45 | batch_size=10, 46 | load_pretrained=True, 47 | fp16=fb16, 48 | ) 49 | 50 | 51 | @pytest.fixture() 52 | def trimap_instance() -> Callable[[], TrimapGenerator]: 53 | return lambda: TrimapGenerator() 54 | 55 | 56 | @pytest.fixture() 57 | def cv2_trimap_instance() -> Callable[[], CV2TrimapGenerator]: 58 | return lambda: CV2TrimapGenerator(kernel_size=30, erosion_iters=0) 59 | 60 | 61 | @pytest.fixture() 62 | def preprocessing_stub_instance() -> Callable[[], PreprocessingStub]: 63 | return lambda: PreprocessingStub() 64 | 65 | 66 | @pytest.fixture() 67 | def matting_method_instance(fba_model, trimap_instance): 68 | return lambda: MattingMethod( 69 | matting_module=fba_model(False), 70 | trimap_generator=trimap_instance(), 71 | device="cpu", 72 | ) 73 | 74 | 75 | @pytest.fixture() 76 | def high_interface_instance() -> Callable[[], HiInterface]: 77 | return lambda: HiInterface( 78 | batch_size_seg=5, 79 | 
batch_size_matting=1, 80 | device="cuda" if torch.cuda.is_available() else "cpu", 81 | seg_mask_size=320, 82 | matting_mask_size=2048, 83 | ) 84 | 85 | 86 | @pytest.fixture() 87 | def interface_instance( 88 | u2net_model, preprocessing_stub_instance, matting_method_instance 89 | ) -> Callable[[], Interface]: 90 | return lambda: Interface( 91 | u2net_model(False), 92 | pre_pipe=preprocessing_stub_instance(), 93 | post_pipe=matting_method_instance(), 94 | device="cuda" if torch.cuda.is_available() else "cpu", 95 | ) 96 | 97 | 98 | @pytest.fixture() 99 | def fba_model() -> Callable[[bool], FBAMatting]: 100 | return lambda fp16: FBAMatting( 101 | device="cuda" if torch.cuda.is_available() else "cpu", 102 | input_tensor_size=1024, 103 | batch_size=2, 104 | load_pretrained=True, 105 | fp16=fp16, 106 | ) 107 | 108 | 109 | @pytest.fixture() 110 | def deeplabv3_model() -> Callable[[bool], DeepLabV3]: 111 | return lambda fp16: DeepLabV3( 112 | device="cuda" if torch.cuda.is_available() else "cpu", 113 | batch_size=10, 114 | load_pretrained=True, 115 | fp16=fp16, 116 | ) 117 | 118 | 119 | @pytest.fixture() 120 | def basnet_model() -> Callable[[bool], BASNET]: 121 | return lambda fp16: BASNET( 122 | device="cuda" if torch.cuda.is_available() else "cpu", 123 | input_image_size=320, 124 | batch_size=10, 125 | load_pretrained=True, 126 | fp16=fp16, 127 | ) 128 | 129 | 130 | @pytest.fixture() 131 | def image_str(image_path) -> str: 132 | return str(image_path.absolute()) 133 | 134 | 135 | @pytest.fixture() 136 | def image_path() -> Path: 137 | return Path(__file__).parent.joinpath("tests").joinpath("data", "cat.jpg") 138 | 139 | 140 | @pytest.fixture() 141 | def image_mask(image_path) -> Image.Image: 142 | return Image.open(image_path.with_name("cat_mask").with_suffix(".png")) 143 | 144 | 145 | @pytest.fixture() 146 | def image_trimap(image_path) -> Image.Image: 147 | return Image.open(image_path.with_name("cat_trimap").with_suffix(".png")).convert( 148 | "L" 149 | ) 150 | 151 | 152 | @pytest.fixture() 153 | def image_pil(image_path) -> Image.Image: 154 | return Image.open(image_path) 155 | 156 | 157 | @pytest.fixture() 158 | def black_image_pil() -> Image.Image: 159 | return Image.new("RGB", (512, 512)) 160 | 161 | 162 | @pytest.fixture() 163 | def converted_pil_image(image_pil) -> Image.Image: 164 | return convert_image(load_image(image_pil)) 165 | 166 | 167 | @pytest.fixture() 168 | def available_models( 169 | u2net_model, 170 | deeplabv3_model, 171 | basnet_model, 172 | preprocessing_stub_instance, 173 | matting_method_instance, 174 | ) -> Tuple[ 175 | List[Union[Callable[[], U2NET], Callable[[], DeepLabV3], Callable[[], BASNET]]], 176 | List[Optional[Callable[[], PreprocessingStub]]], 177 | List[Union[Optional[Callable[[], MattingMethod]], Any]], 178 | ]: 179 | models = [u2net_model, deeplabv3_model, basnet_model] 180 | pre_pipes = [None, preprocessing_stub_instance] 181 | post_pipes = [None, matting_method_instance] 182 | return models, pre_pipes, post_pipes 183 | -------------------------------------------------------------------------------- /docker-compose.cpu.yml: -------------------------------------------------------------------------------- 1 | services: 2 | carvekit_api: 3 | image: anodev/carvekit:latest-cpu 4 | ports: 5 | - "5000:5000" # 5000 6 | environment: 7 | - CARVEKIT_PORT=5000 8 | - CARVEKIT_HOST=0.0.0.0 9 | - CARVEKIT_SEGMENTATION_NETWORK=tracer_b7 # can be u2net, tracer_b7, basnet, deeplabv3 10 | - CARVEKIT_PREPROCESSING_METHOD=none # can be none, stub 11 | - 
CARVEKIT_POSTPROCESSING_METHOD=fba # can be none, fba 12 | - CARVEKIT_DEVICE=cpu # can be cuda (req. cuda docker image), cpu 13 | - CARVEKIT_BATCH_SIZE_SEG=5 # Number of images processed per one segmentation nn call. NOT USED IF WEB API IS USED 14 | - CARVEKIT_BATCH_SIZE_MATTING=1 # Number of images processed per one matting nn call. NOT USED IF WEB API IS USED 15 | - CARVEKIT_SEG_MASK_SIZE=640 # The size of the input image for the segmentation neural network. 16 | - CARVEKIT_MATTING_MASK_SIZE=2048 # The size of the input image for the matting neural network. 17 | - CARVEKIT_FP16=0 # Enables FP16 mode (Only CUDA at the moment) 18 | - CARVEKIT_TRIMAP_PROB_THRESHOLD=231 # Probability threshold at which the prob_filter and prob_as_unknown_area operations will be applied 19 | - CARVEKIT_TRIMAP_DILATION=30 # The size of the offset radius from the object mask in pixels when forming an unknown area 20 | - CARVEKIT_TRIMAP_EROSION=5 # The number of iterations of erosion that the object's mask will be subjected to before forming an unknown area 21 | - CARVEKIT_AUTH_ENABLE=1 # Enables authentication by tokens 22 | # Tokens will be generated automatically every time the container is restarted if these ENV variables are not set. 23 | #- CARVEKIT_ADMIN_TOKEN=admin 24 | #- CARVEKIT_ALLOWED_TOKENS=test_token1,test_token2 25 | -------------------------------------------------------------------------------- /docker-compose.cuda.yml: -------------------------------------------------------------------------------- 1 | services: 2 | carvekit_api: 3 | image: anodev/carvekit:latest-cuda 4 | ports: 5 | - "5000:5000" # 5000 6 | environment: 7 | - CARVEKIT_PORT=5000 8 | - CARVEKIT_HOST=0.0.0.0 9 | - CARVEKIT_SEGMENTATION_NETWORK=tracer_b7 # can be u2net, tracer_b7, basnet, deeplabv3 10 | - CARVEKIT_PREPROCESSING_METHOD=none # can be none, stub 11 | - CARVEKIT_POSTPROCESSING_METHOD=fba # can be none, fba 12 | - CARVEKIT_DEVICE=cuda # can be cuda (req. cuda docker image), cpu 13 | - CARVEKIT_BATCH_SIZE_SEG=5 # Number of images processed per one segmentation nn call. NOT USED IF WEB API IS USED 14 | - CARVEKIT_BATCH_SIZE_MATTING=1 # Number of images processed per one matting nn call. NOT USED IF WEB API IS USED 15 | - CARVEKIT_SEG_MASK_SIZE=640 # The size of the input image for the segmentation neural network. 16 | - CARVEKIT_MATTING_MASK_SIZE=2048 # The size of the input image for the matting neural network. 17 | - CARVEKIT_FP16=0 # Enables FP16 mode (Only CUDA at the moment) 18 | - CARVEKIT_TRIMAP_PROB_THRESHOLD=231 # Probability threshold at which the prob_filter and prob_as_unknown_area operations will be applied 19 | - CARVEKIT_TRIMAP_DILATION=30 # The size of the offset radius from the object mask in pixels when forming an unknown area 20 | - CARVEKIT_TRIMAP_EROSION=5 # The number of iterations of erosion that the object's mask will be subjected to before forming an unknown area 21 | - CARVEKIT_AUTH_ENABLE=1 # Enables authentication by tokens 22 | # Tokens will be generated automatically every time the container is restarted if these ENV variables are not set.
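# A minimal usage sketch, assuming the stack has been started with `docker compose -f docker-compose.cuda.yml up -d`
# and that the token sent in X-Api-Key is one of CARVEKIT_ALLOWED_TOKENS (or the admin token); it mirrors
# docs/code_examples/python/http_api_requests.py:
#   curl -X POST http://localhost:5000/api/removebg -H "X-Api-Key: <token>" -F "image_file=@docs/imgs/input/4.jpg" -o image_without_bg.png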
23 | #- CARVEKIT_ADMIN_TOKEN=admin 24 | #- CARVEKIT_ALLOWED_TOKENS=test_token1,test_token2 25 | deploy: 26 | resources: 27 | reservations: 28 | devices: 29 | - driver: nvidia 30 | device_ids: ['0'] 31 | capabilities: [gpu] -------------------------------------------------------------------------------- /docs/CREDITS.md: -------------------------------------------------------------------------------- 1 | # Credits: 2 | ## Disclaimer for pretrained models: 3 | All rights to the pretrained models used in this project belong to their authors. \ 4 | I do not vouch for their quality and do not claim to be licensed to use any model. \ 5 | It is your responsibility to determine if you have permission to use the pretrained model under the license for the dataset it was trained on or licensed under. \ 6 | Any use of the pretrained model is strictly regulated by the licenses under which the model is distributed. \ 7 | If you own the model and want to update it (file, segmentation quality information, etc.) or don't want your model to be included in this tool, please get in touch through a GitHub issue. 8 | 9 | ## Photos: 10 | The photos in the `docs/imgs/input/` and `docs/code_examples/python/input/` folders were taken from the Pexels website. \ 11 | The original photos in the `docs/imgs/compare` folder were taken from the Unsplash site. \ 12 | All images are copyrighted by their authors. 13 | 14 | ## References: 15 | 1. https://pytorch.org/hub/pytorch_vision_deeplabv3_resnet101/ 16 | 2. https://github.com/NathanUA/U-2-Net 17 | 3. https://github.com/NathanUA/BASNet 18 | 4. https://github.com/MarcoForte/FBA_Matting 19 | 5. https://arxiv.org/abs/1706.05587 20 | 6. https://arxiv.org/pdf/2005.09007.pdf 21 | 7. http://openaccess.thecvf.com/content_CVPR_2019/html/Qin_BASNet_Boundary-Aware_Salient_Object_Detection_CVPR_2019_paper.html 22 | 8. https://arxiv.org/abs/2003.07711 23 | 9. https://arxiv.org/abs/1506.01497 24 | 10. https://arxiv.org/abs/1703.06870 25 | 11. https://github.com/Karel911/TRACER 26 | 12. https://arxiv.org/abs/2112.07380 27 | -------------------------------------------------------------------------------- /docs/code_examples/python/http_api_lib.py: -------------------------------------------------------------------------------- 1 | """ 2 | Source url: https://github.com/OPHoperHPO/image-background-remove-tool 3 | Author: Nikita Selin (OPHoperHPO)[https://github.com/OPHoperHPO]. 4 | License: Apache License 2.0 5 | """ 6 | # Install this library before using this example! 7 | # https://github.com/OPHoperHPO/remove-bg-api 8 | import remove_bg_api 9 | from pathlib import Path 10 | 11 | remove_bg_api.API_URL = "http://localhost:5000/api" # Change the endpoint url 12 | removebg = remove_bg_api.RemoveBg("test") 13 | 14 | settings = { # API settings. See https://www.remove.bg/api for more details. 
15 | "size": "preview", # ["preview", "full", "auto", "medium", "hd", "4k", "small", "regular"] 16 | "type": "auto", # ["auto", "person", "product", "car"] 17 | "format": "auto", # ["auto", "png", "jpg", "zip"] 18 | "roi": "", # {}% {}% {}% {}% or {}px {}px {}px {}px 19 | "crop": False, # True or False 20 | "crop_margin": "0px", # {}% or {}px 21 | "scale": "original", # "{}%" or "original" 22 | "position": "original", # "original" "center", or {}% 23 | "channels": "rgba", # "rgba" or "alpha" 24 | "add_shadow": "false", # Not supported at the moment 25 | "semitransparency": "false", # Not supported at the moment 26 | "bg_color": "", # "81d4fa" or "red" or any other color 27 | "bg_image_url": "", # URL 28 | } 29 | 30 | removebg.remove_bg_file( 31 | str(Path("images/4.jpg").absolute()), 32 | raw=False, 33 | out_path=str(Path("./4.png").absolute()), 34 | data=settings, 35 | ) 36 | -------------------------------------------------------------------------------- /docs/code_examples/python/http_api_requests.py: -------------------------------------------------------------------------------- 1 | """ 2 | Source url: https://github.com/OPHoperHPO/image-background-remove-tool 3 | Author: Nikita Selin (OPHoperHPO)[https://github.com/OPHoperHPO]. 4 | License: Apache License 2.0 5 | """ 6 | # Requires "requests" to be installed 7 | import requests 8 | from pathlib import Path 9 | 10 | response = requests.post( 11 | "http://localhost:5000/api/removebg", 12 | files={"image_file": Path("images/4.jpg").read_bytes()}, 13 | data={"size": "auto"}, 14 | headers={"X-Api-Key": "test"}, 15 | ) 16 | if response.status_code == 200: 17 | Path("image_without_bg.png").write_bytes(response.content) 18 | else: 19 | print("Error:", response.status_code, response.text) 20 | -------------------------------------------------------------------------------- /docs/code_examples/python/images/4.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OPHoperHPO/image-background-remove-tool/f141a311af67fb1da64269c508a6d1f786420801/docs/code_examples/python/images/4.jpg -------------------------------------------------------------------------------- /docs/imgs/compare/1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OPHoperHPO/image-background-remove-tool/f141a311af67fb1da64269c508a6d1f786420801/docs/imgs/compare/1.png -------------------------------------------------------------------------------- /docs/imgs/compare/2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OPHoperHPO/image-background-remove-tool/f141a311af67fb1da64269c508a6d1f786420801/docs/imgs/compare/2.png -------------------------------------------------------------------------------- /docs/imgs/compare/3.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OPHoperHPO/image-background-remove-tool/f141a311af67fb1da64269c508a6d1f786420801/docs/imgs/compare/3.png -------------------------------------------------------------------------------- /docs/imgs/compare/readme.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OPHoperHPO/image-background-remove-tool/f141a311af67fb1da64269c508a6d1f786420801/docs/imgs/compare/readme.jpg -------------------------------------------------------------------------------- /docs/imgs/input/1.jpg: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/OPHoperHPO/image-background-remove-tool/f141a311af67fb1da64269c508a6d1f786420801/docs/imgs/input/1.jpg -------------------------------------------------------------------------------- /docs/imgs/input/1_bg_removed.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OPHoperHPO/image-background-remove-tool/f141a311af67fb1da64269c508a6d1f786420801/docs/imgs/input/1_bg_removed.png -------------------------------------------------------------------------------- /docs/imgs/input/2.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OPHoperHPO/image-background-remove-tool/f141a311af67fb1da64269c508a6d1f786420801/docs/imgs/input/2.jpg -------------------------------------------------------------------------------- /docs/imgs/input/2_bg_removed.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OPHoperHPO/image-background-remove-tool/f141a311af67fb1da64269c508a6d1f786420801/docs/imgs/input/2_bg_removed.png -------------------------------------------------------------------------------- /docs/imgs/input/3.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OPHoperHPO/image-background-remove-tool/f141a311af67fb1da64269c508a6d1f786420801/docs/imgs/input/3.jpg -------------------------------------------------------------------------------- /docs/imgs/input/3_bg_removed.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OPHoperHPO/image-background-remove-tool/f141a311af67fb1da64269c508a6d1f786420801/docs/imgs/input/3_bg_removed.png -------------------------------------------------------------------------------- /docs/imgs/input/4.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OPHoperHPO/image-background-remove-tool/f141a311af67fb1da64269c508a6d1f786420801/docs/imgs/input/4.jpg -------------------------------------------------------------------------------- /docs/imgs/input/4_bg_removed.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OPHoperHPO/image-background-remove-tool/f141a311af67fb1da64269c508a6d1f786420801/docs/imgs/input/4_bg_removed.png -------------------------------------------------------------------------------- /docs/imgs/logo.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OPHoperHPO/image-background-remove-tool/f141a311af67fb1da64269c508a6d1f786420801/docs/imgs/logo.png -------------------------------------------------------------------------------- /docs/imgs/screenshot/docs_fastapi.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OPHoperHPO/image-background-remove-tool/f141a311af67fb1da64269c508a6d1f786420801/docs/imgs/screenshot/docs_fastapi.png -------------------------------------------------------------------------------- /docs/imgs/screenshot/frontend.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/OPHoperHPO/image-background-remove-tool/f141a311af67fb1da64269c508a6d1f786420801/docs/imgs/screenshot/frontend.png -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | requests>=2.31.0 2 | torch>=2.2.2 3 | Pillow>=10.3.0 4 | typing>=3.7.4.3 5 | torchvision>=0.17.2 6 | opencv-python>=4.9.0.80 7 | numpy>=1.26.4 8 | loguru>=0.7.2 9 | uvicorn>=0.29.0 10 | fastapi>=0.110.1 11 | pydantic>=2.6.4 12 | click>=8.1.7 13 | tqdm>=4.66.2 14 | setuptools>=69.2.0 15 | aiofiles>=23.2.1 16 | python-multipart>=0.0.9 -------------------------------------------------------------------------------- /requirements_dev.txt: -------------------------------------------------------------------------------- 1 | pre-commit>=3.7.0 2 | 3 | -------------------------------------------------------------------------------- /requirements_test.txt: -------------------------------------------------------------------------------- 1 | pytest>=8.1.1 2 | 3 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | """ 2 | Source url: https://github.com/OPHoperHPO/image-background-remove-tool 3 | Author: Nikita Selin (OPHoperHPO)[https://github.com/OPHoperHPO]. 4 | License: Apache License 2.0 5 | """ 6 | import os 7 | import re 8 | 9 | from setuptools import setup, find_packages 10 | 11 | from carvekit import version 12 | 13 | IS_COLAB_PACKAGE = os.getenv("COLAB_PACKAGE_RELEASE", None) 14 | 15 | 16 | def read(filename: str): 17 | filepath = os.path.join(os.path.dirname(__file__), filename) 18 | file = open(filepath, "r", encoding="utf-8") 19 | return file.read() 20 | 21 | 22 | def req_file(filename: str, folder: str = "."): 23 | with open(os.path.join(folder, filename), encoding="utf-8") as f: 24 | content = f.readlines() 25 | # you may also want to remove whitespace characters 26 | # Example: `\n` at the end of each line 27 | if os.getenv("COLAB_PACKAGE_RELEASE") is not None: 28 | return [re.sub("(>=.*)|(~=.*)|(==.*)|(typing.*)", "", x.strip()) for x in content] 29 | return [x.strip() for x in content] 30 | 31 | 32 | setup( 33 | name="carvekit" if IS_COLAB_PACKAGE is None else "carvekit_colab", 34 | version=version, 35 | author="Nikita Selin (Anodev)", 36 | author_email="farvard34@gmail.com", 37 | description="Open-Source background removal framework", 38 | long_description=read("README.md"), 39 | long_description_content_type="text/markdown", 40 | license="Apache License v2.0", 41 | keywords=[ 42 | "ml", 43 | "carvekit", 44 | "background removal", 45 | "neural networks", 46 | "machine learning", 47 | "remove bg", 48 | ], 49 | url="https://github.com/OPHoperHPO/image-background-remove-tool", 50 | packages=find_packages(), 51 | scripts=[], 52 | install_requires=req_file("requirements.txt"), 53 | include_package_data=True, 54 | zip_safe=False, 55 | entry_points={ 56 | "console_scripts": [ 57 | "carvekit=carvekit:__main__.removebg", 58 | ], 59 | }, 60 | python_requires=">=3.8" if IS_COLAB_PACKAGE is None else ">=3.6", 61 | classifiers=[ 62 | "Development Status :: 5 - Production/Stable", 63 | "Environment :: Web Environment", 64 | "Intended Audience :: Developers", 65 | "License :: OSI Approved :: Apache Software License", 66 | "Natural Language :: English", 67 | "Operating System :: OS Independent", 68 | "Programming Language :: Python :: 3.10", 69 | 
"Topic :: Scientific/Engineering :: Artificial Intelligence", 70 | ], 71 | ) 72 | -------------------------------------------------------------------------------- /tests/data/cat.JPG: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OPHoperHPO/image-background-remove-tool/f141a311af67fb1da64269c508a6d1f786420801/tests/data/cat.JPG -------------------------------------------------------------------------------- /tests/data/cat.MP3: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OPHoperHPO/image-background-remove-tool/f141a311af67fb1da64269c508a6d1f786420801/tests/data/cat.MP3 -------------------------------------------------------------------------------- /tests/data/cat.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OPHoperHPO/image-background-remove-tool/f141a311af67fb1da64269c508a6d1f786420801/tests/data/cat.jpg -------------------------------------------------------------------------------- /tests/data/cat.mp3: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OPHoperHPO/image-background-remove-tool/f141a311af67fb1da64269c508a6d1f786420801/tests/data/cat.mp3 -------------------------------------------------------------------------------- /tests/data/cat_mask.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OPHoperHPO/image-background-remove-tool/f141a311af67fb1da64269c508a6d1f786420801/tests/data/cat_mask.png -------------------------------------------------------------------------------- /tests/data/cat_trimap.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OPHoperHPO/image-background-remove-tool/f141a311af67fb1da64269c508a6d1f786420801/tests/data/cat_trimap.png -------------------------------------------------------------------------------- /tests/test_basnet.py: -------------------------------------------------------------------------------- 1 | """ 2 | Source url: https://github.com/OPHoperHPO/image-background-remove-tool 3 | Author: Nikita Selin (OPHoperHPO)[https://github.com/OPHoperHPO]. 
4 | License: Apache License 2.0 5 | """ 6 | 7 | import torch 8 | from PIL import Image 9 | 10 | from carvekit.ml.wrap.basnet import BASNET 11 | 12 | 13 | def test_init(): 14 | BASNET(input_image_size=[320, 320], load_pretrained=True) 15 | BASNET(load_pretrained=False) 16 | 17 | 18 | def test_preprocessing(basnet_model, converted_pil_image, black_image_pil): 19 | basnet_model = basnet_model(False) 20 | assert ( 21 | isinstance( 22 | basnet_model.data_preprocessing(converted_pil_image), torch.FloatTensor 23 | ) 24 | is True 25 | ) 26 | assert ( 27 | isinstance(basnet_model.data_preprocessing(black_image_pil), torch.FloatTensor) 28 | is True 29 | ) 30 | 31 | 32 | def test_postprocessing(basnet_model, converted_pil_image, black_image_pil): 33 | basnet_model = basnet_model(False) 34 | assert isinstance( 35 | basnet_model.data_postprocessing( 36 | torch.ones((1, 320, 320), dtype=torch.float64), converted_pil_image 37 | ), 38 | Image.Image, 39 | ) 40 | 41 | 42 | def test_seg(basnet_model, image_pil, image_str, image_path, black_image_pil): 43 | basnet_model = basnet_model(False) 44 | basnet_model([image_pil]) 45 | basnet_model([image_pil, image_str, image_path, black_image_pil]) 46 | 47 | 48 | def test_seg_fp16(basnet_model, image_pil, image_str, image_path, black_image_pil): 49 | basnet_model = basnet_model(True) 50 | basnet_model([image_pil]) 51 | basnet_model([image_pil, image_str, image_path, black_image_pil]) 52 | -------------------------------------------------------------------------------- /tests/test_deeplabv3.py: -------------------------------------------------------------------------------- 1 | """ 2 | Source url: https://github.com/OPHoperHPO/image-background-remove-tool 3 | Author: Nikita Selin (OPHoperHPO)[https://github.com/OPHoperHPO]. 4 | License: Apache License 2.0 5 | """ 6 | 7 | import torch 8 | from PIL import Image 9 | 10 | from carvekit.ml.wrap.deeplab_v3 import DeepLabV3 11 | 12 | 13 | def test_init(): 14 | DeepLabV3(load_pretrained=True) 15 | DeepLabV3(load_pretrained=False).to("cpu") 16 | DeepLabV3(input_image_size=[128, 256]) 17 | 18 | 19 | def test_preprocessing(deeplabv3_model, converted_pil_image, black_image_pil): 20 | deeplabv3_model = deeplabv3_model(False) 21 | assert ( 22 | isinstance( 23 | deeplabv3_model.data_preprocessing(converted_pil_image), torch.FloatTensor 24 | ) 25 | is True 26 | ) 27 | assert ( 28 | isinstance( 29 | deeplabv3_model.data_preprocessing(black_image_pil), torch.FloatTensor 30 | ) 31 | is True 32 | ) 33 | 34 | 35 | def test_postprocessing(deeplabv3_model, converted_pil_image, black_image_pil): 36 | deeplabv3_model = deeplabv3_model(False) 37 | assert isinstance( 38 | deeplabv3_model.data_postprocessing( 39 | torch.ones((320, 320), dtype=torch.float64), converted_pil_image 40 | ), 41 | Image.Image, 42 | ) 43 | 44 | 45 | def test_seg(deeplabv3_model, image_pil, image_str, image_path, black_image_pil): 46 | deeplabv3_model = deeplabv3_model(False) 47 | deeplabv3_model([image_pil]) 48 | deeplabv3_model([image_pil, image_str, image_path, black_image_pil]) 49 | 50 | 51 | def test_seg_with_fp16( 52 | deeplabv3_model, image_pil, image_str, image_path, black_image_pil 53 | ): 54 | deeplabv3_model = deeplabv3_model(True) 55 | deeplabv3_model([image_pil]) 56 | deeplabv3_model([image_pil, image_str, image_path, black_image_pil]) 57 | -------------------------------------------------------------------------------- /tests/test_fba.py: -------------------------------------------------------------------------------- 1 | """ 2 | Source url: 
https://github.com/OPHoperHPO/image-background-remove-tool 3 | Author: Nikita Selin (OPHoperHPO)[https://github.com/OPHoperHPO]. 4 | License: Apache License 2.0 5 | """ 6 | 7 | import pytest 8 | import torch 9 | from PIL import Image 10 | 11 | from carvekit.ml.wrap.fba_matting import FBAMatting 12 | 13 | 14 | def test_init(): 15 | FBAMatting(load_pretrained=True) 16 | FBAMatting(load_pretrained=False) 17 | FBAMatting(input_tensor_size=[128, 256]) 18 | 19 | 20 | def test_preprocessing(fba_model, converted_pil_image, black_image_pil, image_mask): 21 | fba_model = fba_model(False) 22 | assert ( 23 | isinstance( 24 | fba_model.data_preprocessing(converted_pil_image)[0], torch.FloatTensor 25 | ) 26 | is True 27 | ) 28 | assert ( 29 | isinstance(fba_model.data_preprocessing(black_image_pil)[0], torch.FloatTensor) 30 | is True 31 | ) 32 | assert ( 33 | isinstance(fba_model.data_preprocessing(image_mask)[0], torch.FloatTensor) 34 | is True 35 | ) 36 | with pytest.raises(ValueError): 37 | assert ( 38 | isinstance( 39 | fba_model.data_preprocessing(Image.new("P", (512, 512)))[0], 40 | torch.FloatTensor, 41 | ) 42 | is True 43 | ) 44 | fba_model = FBAMatting( 45 | device="cuda" if torch.cuda.is_available() else "cpu", 46 | input_tensor_size=1024, 47 | batch_size=1, 48 | load_pretrained=True, 49 | ) 50 | assert ( 51 | isinstance( 52 | fba_model.data_preprocessing(converted_pil_image)[0], torch.FloatTensor 53 | ) 54 | is True 55 | ) 56 | assert ( 57 | isinstance(fba_model.data_preprocessing(black_image_pil)[0], torch.FloatTensor) 58 | is True 59 | ) 60 | assert ( 61 | isinstance(fba_model.data_preprocessing(image_mask)[0], torch.FloatTensor) 62 | is True 63 | ) 64 | with pytest.raises(ValueError): 65 | assert ( 66 | isinstance( 67 | fba_model.data_preprocessing(Image.new("P", (512, 512)))[0], 68 | torch.FloatTensor, 69 | ) 70 | is True 71 | ) 72 | 73 | 74 | def test_postprocessing(fba_model, converted_pil_image, black_image_pil): 75 | fba_model = fba_model(False) 76 | assert isinstance( 77 | fba_model.data_postprocessing( 78 | torch.ones((7, 320, 320), dtype=torch.float64), black_image_pil.convert("L") 79 | ), 80 | Image.Image, 81 | ) 82 | with pytest.raises(ValueError): 83 | assert isinstance( 84 | fba_model.data_postprocessing( 85 | torch.ones((7, 320, 320), dtype=torch.float64), 86 | black_image_pil.convert("RGBA"), 87 | ), 88 | Image.Image, 89 | ) 90 | 91 | 92 | def test_seg( 93 | fba_model, image_pil, image_str, image_path, black_image_pil, image_trimap 94 | ): 95 | fba_model = fba_model(False) 96 | fba_model([image_pil], [image_trimap]) 97 | fba_model( 98 | [image_pil, image_str, image_path], [image_trimap, image_trimap, image_trimap] 99 | ) 100 | fba_model( 101 | [Image.new("RGB", (512, 512)), Image.new("RGB", (512, 512))], 102 | [Image.new("L", (512, 512)), Image.new("L", (512, 512))], 103 | ) 104 | with pytest.raises(ValueError): 105 | fba_model([image_pil], [image_trimap, image_trimap]) 106 | 107 | 108 | def test_seg_with_fp16( 109 | fba_model, image_pil, image_str, image_path, black_image_pil, image_trimap 110 | ): 111 | fba_model = fba_model(True) 112 | fba_model([image_pil], [image_trimap]) 113 | fba_model( 114 | [image_pil, image_str, image_path], [image_trimap, image_trimap, image_trimap] 115 | ) 116 | fba_model( 117 | [Image.new("RGB", (512, 512)), Image.new("RGB", (512, 512))], 118 | [Image.new("L", (512, 512)), Image.new("L", (512, 512))], 119 | ) 120 | with pytest.raises(ValueError): 121 | fba_model([image_pil], [image_trimap, image_trimap]) 122 | 
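# A minimal illustrative sketch of how the FBAMatting model exercised above is combined with a
# trimap generator via MattingMethod, mirroring the fixtures in conftest.py; the asset paths are
# the test data used throughout this suite, and the exact output format is not asserted by these
# tests, so it is left unspecified here.
from pathlib import Path

from carvekit.ml.wrap.fba_matting import FBAMatting
from carvekit.pipelines.postprocessing import MattingMethod
from carvekit.trimap.generator import TrimapGenerator

fba = FBAMatting(device="cpu", input_tensor_size=1024, batch_size=1, load_pretrained=True)
matting = MattingMethod(matting_module=fba, trimap_generator=TrimapGenerator(), device="cpu")
# images and masks are parallel lists; mismatched lengths raise ValueError (see test_seg above)
results = matting(
    images=[Path("tests/data/cat.jpg")],
    masks=[Path("tests/data/cat_mask.png")],
)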
-------------------------------------------------------------------------------- /tests/test_fs_utils.py: -------------------------------------------------------------------------------- 1 | """ 2 | Source url: https://github.com/OPHoperHPO/image-background-remove-tool 3 | Author: Nikita Selin (OPHoperHPO)[https://github.com/OPHoperHPO]. 4 | License: Apache License 2.0 5 | """ 6 | import pytest 7 | from carvekit.utils.fs_utils import save_file 8 | from pathlib import Path 9 | import PIL.Image 10 | import os 11 | 12 | 13 | def test_save_file(): 14 | save_file(Path("output.png"), Path("input.png"), PIL.Image.new("RGB", (512, 512))) 15 | os.remove(Path("output.png")) 16 | save_file( 17 | Path(__file__).parent.joinpath("data"), 18 | Path("input.png"), 19 | PIL.Image.new("RGB", (512, 512)), 20 | ) 21 | os.remove(Path(__file__).parent.joinpath("data").joinpath("input.png")) 22 | save_file(Path("output.jpg"), Path("input.jpg"), PIL.Image.new("RGB", (512, 512))) 23 | os.remove(Path("output.png")) 24 | with pytest.raises(ValueError): 25 | save_file( 26 | Path("NotExistedPath"), Path("input.png"), PIL.Image.new("RGB", (512, 512)) 27 | ) 28 | save_file( 29 | output=None, 30 | input_path=Path("input.png"), 31 | image=PIL.Image.new("RGB", (512, 512)), 32 | ) 33 | os.remove(Path("input_bg_removed.png")) 34 | -------------------------------------------------------------------------------- /tests/test_high.py: -------------------------------------------------------------------------------- 1 | """ 2 | Source url: https://github.com/OPHoperHPO/image-background-remove-tool 3 | Author: Nikita Selin (OPHoperHPO)[https://github.com/OPHoperHPO]. 4 | License: Apache License 2.0 5 | """ 6 | 7 | from carvekit.api.high import HiInterface 8 | 9 | 10 | def test_init(): 11 | HiInterface( 12 | batch_size_seg=1, 13 | batch_size_matting=4, 14 | device="cpu", 15 | seg_mask_size=160, 16 | matting_mask_size=1024, 17 | trimap_prob_threshold=1, 18 | trimap_dilation=2, 19 | trimap_erosion_iters=3, 20 | fp16=False, 21 | ) 22 | HiInterface( 23 | batch_size_seg=0, 24 | batch_size_matting=0, 25 | device="cpu", 26 | seg_mask_size=0, 27 | matting_mask_size=0, 28 | trimap_prob_threshold=0, 29 | trimap_dilation=0, 30 | trimap_erosion_iters=0, 31 | fp16=True, 32 | ) 33 | -------------------------------------------------------------------------------- /tests/test_image_utils.py: -------------------------------------------------------------------------------- 1 | """ 2 | Source url: https://github.com/OPHoperHPO/image-background-remove-tool 3 | Author: Nikita Selin (OPHoperHPO)[https://github.com/OPHoperHPO]. 
4 | License: Apache License 2.0 5 | """ 6 | import uuid 7 | from pathlib import Path 8 | 9 | import PIL.Image 10 | import pytest 11 | import torch 12 | from PIL import Image 13 | from carvekit.utils.image_utils import ( 14 | load_image, 15 | convert_image, 16 | is_image_valid, 17 | to_tensor, 18 | transparency_paste, 19 | add_margin, 20 | ) 21 | 22 | 23 | def test_load_image(image_path, image_pil, image_str): 24 | assert isinstance(load_image(image_path), Image.Image) is True 25 | assert isinstance(load_image(image_pil), Image.Image) is True 26 | assert isinstance(load_image(image_str), Image.Image) is True 27 | 28 | with pytest.raises(ValueError): 29 | load_image(23) 30 | 31 | 32 | def test_is_image_valid(image_path, image_pil, image_str): 33 | assert is_image_valid(image_path) is True 34 | assert is_image_valid(image_path.with_suffix(".JPG")) is True 35 | 36 | with pytest.raises(ValueError): 37 | is_image_valid(Path(uuid.uuid1().hex).with_suffix(".jpg")) 38 | with pytest.raises(ValueError): 39 | is_image_valid(Path(__file__).parent) 40 | with pytest.raises(ValueError): 41 | is_image_valid(image_path.with_suffix(".mp3")) 42 | with pytest.raises(ValueError): 43 | is_image_valid(image_path.with_suffix(".MP3")) 44 | with pytest.raises(ValueError): 45 | is_image_valid(23) 46 | 47 | assert is_image_valid(image_pil) is True 48 | assert is_image_valid(Image.new("RGB", (512, 512))) is True 49 | assert is_image_valid(Image.new("L", (512, 512))) is True 50 | assert is_image_valid(Image.new("RGBA", (512, 512))) is True 51 | 52 | with pytest.raises(ValueError): 53 | is_image_valid(Image.new("P", (512, 512))) 54 | with pytest.raises(ValueError): 55 | is_image_valid(Image.new("RGB", (32, 10))) 56 | 57 | 58 | def test_convert_image(image_pil): 59 | with pytest.raises(ValueError): 60 | convert_image(Image.new("L", (10, 10))) 61 | assert convert_image(image_pil.convert("RGBA")).mode == "RGB" 62 | 63 | 64 | def test_to_tensor(image_pil): 65 | assert isinstance(to_tensor(image_pil), torch.Tensor) 66 | 67 | 68 | def test_transparency_paste(): 69 | assert isinstance( 70 | transparency_paste( 71 | PIL.Image.new("RGBA", (1024, 1024)), PIL.Image.new("RGBA", (1024, 1024)) 72 | ), 73 | PIL.Image.Image, 74 | ) 75 | assert isinstance( 76 | transparency_paste( 77 | PIL.Image.new("RGBA", (512, 512)), PIL.Image.new("RGBA", (512, 512)) 78 | ), 79 | PIL.Image.Image, 80 | ) 81 | 82 | 83 | def test_add_margin(): 84 | assert ( 85 | isinstance( 86 | add_margin( 87 | PIL.Image.new("RGB", (512, 512)), 10, 10, 10, 10, (10, 10, 10, 10) 88 | ), 89 | PIL.Image.Image, 90 | ) 91 | is True 92 | ) 93 | -------------------------------------------------------------------------------- /tests/test_interface.py: -------------------------------------------------------------------------------- 1 | """ 2 | Source url: https://github.com/OPHoperHPO/image-background-remove-tool 3 | Author: Nikita Selin (OPHoperHPO)[https://github.com/OPHoperHPO]. 
4 | License: Apache License 2.0 5 | """ 6 | import warnings 7 | 8 | import torch 9 | 10 | from carvekit.api.interface import Interface 11 | 12 | 13 | def test_init(available_models): 14 | models, pre_pipes, post_pipes = available_models 15 | devices = ["cpu", "cuda"] 16 | for model in models: 17 | mdl = model(False) 18 | for pre_pipe in pre_pipes: 19 | pre = pre_pipe() if pre_pipe is not None else pre_pipe 20 | for post_pipe in post_pipes: 21 | post = post_pipe() if post_pipe is not None else post_pipe 22 | for device in devices: 23 | if device == "cuda" and torch.cuda.is_available() is False: 24 | warnings.warn( 25 | "Cuda GPU is not available! Testing on cuda skipped!" 26 | ) 27 | continue 28 | inf = Interface( 29 | seg_pipe=mdl, post_pipe=post, pre_pipe=pre, device=device 30 | ) 31 | del inf 32 | del post 33 | del pre 34 | del mdl 35 | 36 | 37 | def test_seg(image_pil, image_str, image_path, available_models): 38 | models, pre_pipes, post_pipes = available_models 39 | for model in models: 40 | mdl = model(False) 41 | for pre_pipe in pre_pipes: 42 | pre = pre_pipe() if pre_pipe is not None else pre_pipe 43 | for post_pipe in post_pipes: 44 | post = post_pipe() if post_pipe is not None else post_pipe 45 | interface = Interface( 46 | seg_pipe=mdl, 47 | post_pipe=post, 48 | pre_pipe=pre, 49 | device="cuda" if torch.cuda.is_available() else "cpu", 50 | ) 51 | interface([image_pil, image_str, image_path]) 52 | del post, interface 53 | del pre 54 | del mdl 55 | -------------------------------------------------------------------------------- /tests/test_mask_utils.py: -------------------------------------------------------------------------------- 1 | """ 2 | Source url: https://github.com/OPHoperHPO/image-background-remove-tool 3 | Author: Nikita Selin (OPHoperHPO)[https://github.com/OPHoperHPO]. 4 | License: Apache License 2.0 5 | """ 6 | import pytest 7 | import PIL.Image 8 | from carvekit.utils.mask_utils import composite, apply_mask, extract_alpha_channel 9 | 10 | 11 | def test_composite(): 12 | assert ( 13 | isinstance( 14 | composite( 15 | PIL.Image.new("RGB", (512, 512)), 16 | PIL.Image.new("RGB", (512, 512)), 17 | PIL.Image.new("RGB", (512, 512)), 18 | device="cpu", 19 | ), 20 | PIL.Image.Image, 21 | ) 22 | is True 23 | ) 24 | 25 | 26 | def test_apply_mask(): 27 | assert ( 28 | isinstance( 29 | apply_mask( 30 | image=PIL.Image.new("RGB", (512, 512)), 31 | mask=PIL.Image.new("RGB", (512, 512)), 32 | device="cpu", 33 | ), 34 | PIL.Image.Image, 35 | ) 36 | is True 37 | ) 38 | 39 | 40 | def test_extract_alpha_channel(): 41 | assert ( 42 | isinstance( 43 | extract_alpha_channel(PIL.Image.new("RGB", (512, 512))), PIL.Image.Image 44 | ) 45 | is True 46 | ) 47 | -------------------------------------------------------------------------------- /tests/test_models_utils.py: -------------------------------------------------------------------------------- 1 | """ 2 | Source url: https://github.com/OPHoperHPO/image-background-remove-tool 3 | Author: Nikita Selin (OPHoperHPO)[https://github.com/OPHoperHPO]. 
4 | License: Apache License 2.0 5 | """ 6 | import os 7 | import pytest 8 | from pathlib import Path 9 | from carvekit.utils.download_models import sha512_checksum_calc 10 | from carvekit.ml.files.models_loc import ( 11 | u2net_full_pretrained, 12 | fba_pretrained, 13 | deeplab_pretrained, 14 | basnet_pretrained, 15 | download_all, 16 | checkpoints_dir, 17 | downloader, 18 | tracer_b7_pretrained, 19 | ) 20 | from carvekit.utils.models_utils import fix_seed, suppress_warnings 21 | 22 | 23 | def test_fix_seed(): 24 | fix_seed(seed=42) 25 | 26 | 27 | def test_suppress_warnings(): 28 | suppress_warnings() 29 | 30 | 31 | def test_download_all(): 32 | download_all() 33 | 34 | 35 | def test_download_model(): 36 | hh = checkpoints_dir / "u2net-universal" / "u2net.pth" 37 | hh.write_text("1234") 38 | assert downloader("u2net.pth") == hh 39 | os.remove(hh) 40 | with pytest.raises(FileNotFoundError): 41 | downloader("NotExistedPath/2.dl") 42 | 43 | 44 | def test_sha512(): 45 | hh = checkpoints_dir / "basnet-universal" / "basnet.pth" 46 | hh.write_text("1234") 47 | assert ( 48 | sha512_checksum_calc(hh) 49 | == "d404559f602eab6fd602ac7680dacbfaadd13630335e951f097a" 50 | "f3900e9de176b6db28512f2e000" 51 | "b9d04fba5133e8b1c6e8df59db3a8ab9d60be4b97cc9e81db" 52 | ) 53 | 54 | 55 | def test_check_model(): 56 | invalid_hash_file = checkpoints_dir / "basnet-universal" / "basnet.pth" 57 | invalid_hash_file.write_text("1234") 58 | downloader("basnet.pth") 59 | assert ( 60 | sha512_checksum_calc(invalid_hash_file) 61 | != "d404559f602eab6fd602ac7680dacbfaadd13630335e951f097a" 62 | "f3900e9de176b6db28512f2e000" 63 | "b9d04fba5133e8b1c6e8df59db3a8ab9d60be4b97cc9e81db" 64 | ) 65 | 66 | 67 | def test_check_for_exists(): 68 | assert u2net_full_pretrained().exists() 69 | assert fba_pretrained().exists() 70 | assert deeplab_pretrained().exists() 71 | assert basnet_pretrained().exists() 72 | assert tracer_b7_pretrained().exists() 73 | -------------------------------------------------------------------------------- /tests/test_pool_utils.py: -------------------------------------------------------------------------------- 1 | """ 2 | Source url: https://github.com/OPHoperHPO/image-background-remove-tool 3 | Author: Nikita Selin (OPHoperHPO)[https://github.com/OPHoperHPO]. 4 | License: Apache License 2.0 5 | """ 6 | from carvekit.utils.pool_utils import batch_generator, thread_pool_processing 7 | 8 | 9 | def test_thread_pool_processing(): 10 | assert thread_pool_processing(int, ["1", "2", "3"]) == [1, 2, 3] 11 | assert thread_pool_processing(int, ["1", "2", "3"], workers=1) == [1, 2, 3] 12 | 13 | 14 | def test_batch_generator(): 15 | assert list(batch_generator([1, 2, 3], n=1)) == [[1], [2], [3]] 16 | assert list(batch_generator([1, 2, 3, 4], n=2)) == [[1, 2], [3, 4]] 17 | -------------------------------------------------------------------------------- /tests/test_postprocessing.py: -------------------------------------------------------------------------------- 1 | """ 2 | Source url: https://github.com/OPHoperHPO/image-background-remove-tool 3 | Author: Nikita Selin (OPHoperHPO)[https://github.com/OPHoperHPO]. 
4 | License: Apache License 2.0 5 | """ 6 | import pytest 7 | from carvekit.pipelines.postprocessing import MattingMethod 8 | 9 | 10 | def test_init(fba_model, trimap_instance): 11 | fba_model = fba_model(False) 12 | trimap_instance = trimap_instance() 13 | MattingMethod(fba_model, trimap_instance, "cpu") 14 | MattingMethod(fba_model, trimap_instance, device="cuda") 15 | 16 | 17 | def test_seg(matting_method_instance, image_str, image_path, image_pil): 18 | matting_method_instance = matting_method_instance() 19 | matting_method_instance( 20 | images=[image_str, image_path], masks=[image_pil, image_path] 21 | ) 22 | with pytest.raises(ValueError): 23 | matting_method_instance(images=[image_str], masks=[image_pil, image_path]) 24 | -------------------------------------------------------------------------------- /tests/test_preprocessing.py: -------------------------------------------------------------------------------- 1 | """ 2 | Source url: https://github.com/OPHoperHPO/image-background-remove-tool 3 | Author: Nikita Selin (OPHoperHPO)[https://github.com/OPHoperHPO]. 4 | License: Apache License 2.0 5 | """ 6 | 7 | 8 | def test_seg( 9 | preprocessing_stub_instance, image_str, image_path, image_pil, interface_instance 10 | ): 11 | preprocessing_stub_instance = preprocessing_stub_instance() 12 | interface_instance = interface_instance() 13 | preprocessing_stub_instance(interface_instance, [image_str, image_path]) 14 | preprocessing_stub_instance(interface_instance, [image_pil, image_path]) 15 | -------------------------------------------------------------------------------- /tests/test_tracer.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | import torch 3 | from PIL import Image 4 | 5 | from carvekit.ml.wrap.tracer_b7 import TracerUniversalB7 6 | 7 | 8 | def test_init(): 9 | TracerUniversalB7(input_image_size=[640, 640], load_pretrained=True) 10 | TracerUniversalB7(input_image_size=640, load_pretrained=True) 11 | TracerUniversalB7(load_pretrained=False) 12 | TracerUniversalB7(fp16=True) 13 | 14 | 15 | def test_preprocessing(tracer_model, converted_pil_image, black_image_pil): 16 | tracer_model = tracer_model(False) 17 | assert ( 18 | isinstance( 19 | tracer_model.data_preprocessing(converted_pil_image), torch.FloatTensor 20 | ) 21 | is True 22 | ) 23 | assert ( 24 | isinstance(tracer_model.data_preprocessing(black_image_pil), torch.FloatTensor) 25 | is True 26 | ) 27 | 28 | 29 | def test_postprocessing(tracer_model, converted_pil_image, black_image_pil): 30 | tracer_model = tracer_model(False) 31 | assert isinstance( 32 | tracer_model.data_postprocessing( 33 | torch.ones((1, 640, 640), dtype=torch.float64), converted_pil_image 34 | ), 35 | Image.Image, 36 | ) 37 | 38 | 39 | def test_seg(tracer_model, image_pil, image_str, image_path, black_image_pil): 40 | tracer_model = tracer_model(False) 41 | tracer_model([image_pil]) 42 | tracer_model([image_pil, image_str, image_path, black_image_pil]) 43 | 44 | 45 | def test_seg_with_fp16(tracer_model, image_pil, image_str, image_path, black_image_pil): 46 | tracer_model = tracer_model(True) 47 | tracer_model([image_pil]) 48 | tracer_model([image_pil, image_str, image_path, black_image_pil]) 49 | -------------------------------------------------------------------------------- /tests/test_trimap.py: -------------------------------------------------------------------------------- 1 | """ 2 | Source url: https://github.com/OPHoperHPO/image-background-remove-tool 3 | Author: Nikita Selin 
(OPHoperHPO)[https://github.com/OPHoperHPO]. 4 | License: Apache License 2.0 5 | """ 6 | import PIL.Image 7 | import pytest 8 | 9 | from carvekit.trimap.add_ops import prob_as_unknown_area 10 | 11 | 12 | def test_trimap_generator(trimap_instance, image_mask, image_pil): 13 | te = trimap_instance() 14 | assert isinstance(te(image_pil, image_mask), PIL.Image.Image) 15 | assert isinstance( 16 | te(PIL.Image.new("RGB", (512, 512)), PIL.Image.new("L", (512, 512))), 17 | PIL.Image.Image, 18 | ) 19 | assert isinstance( 20 | te( 21 | PIL.Image.new("RGB", (512, 512), color=(255, 255, 255)), 22 | PIL.Image.new("L", (512, 512), color=255), 23 | ), 24 | PIL.Image.Image, 25 | ) 26 | with pytest.raises(ValueError): 27 | te(PIL.Image.new("RGB", (512, 512)), PIL.Image.new("RGB", (512, 512))) 28 | with pytest.raises(ValueError): 29 | te(PIL.Image.new("RGB", (512, 512)), PIL.Image.new("RGB", (512, 512))) 30 | 31 | 32 | def test_cv2_generator(cv2_trimap_instance, image_pil, image_mask): 33 | cv2trimapgen = cv2_trimap_instance() 34 | assert isinstance(cv2trimapgen(image_pil, image_mask), PIL.Image.Image) 35 | with pytest.raises(ValueError): 36 | cv2trimapgen(PIL.Image.new("RGB", (512, 512)), PIL.Image.new("RGB", (512, 512))) 37 | with pytest.raises(ValueError): 38 | cv2trimapgen(PIL.Image.new("L", (256, 256)), PIL.Image.new("L", (512, 512))) 39 | 40 | 41 | def test_prob_as_unknown_area(image_pil, image_mask): 42 | with pytest.raises(ValueError): 43 | prob_as_unknown_area(image_pil, image_mask) 44 | -------------------------------------------------------------------------------- /tests/test_u2net.py: -------------------------------------------------------------------------------- 1 | """ 2 | Source url: https://github.com/OPHoperHPO/image-background-remove-tool 3 | Author: Nikita Selin (OPHoperHPO)[https://github.com/OPHoperHPO]. 
4 | License: Apache License 2.0 5 | """ 6 | 7 | import pytest 8 | import torch 9 | from PIL import Image 10 | 11 | from carvekit.ml.wrap.u2net import U2NET 12 | 13 | 14 | def test_init(): 15 | U2NET(layers_cfg="full", input_image_size=[320, 320], load_pretrained=True) 16 | U2NET(layers_cfg="full", load_pretrained=False) 17 | U2NET( 18 | layers_cfg={ 19 | "stage1": ["En_1", (7, 3, 32, 64), -1], 20 | "stage2": ["En_2", (6, 64, 32, 128), -1], 21 | "stage3": ["En_3", (5, 128, 64, 256), -1], 22 | "stage4": ["En_4", (4, 256, 128, 512), -1], 23 | "stage5": ["En_5", (4, 512, 256, 512, True), -1], 24 | "stage6": ["En_6", (4, 512, 256, 512, True), 512], 25 | "stage5d": ["De_5", (4, 1024, 256, 512, True), 512], 26 | "stage4d": ["De_4", (4, 1024, 128, 256), 256], 27 | "stage3d": ["De_3", (5, 512, 64, 128), 128], 28 | "stage2d": ["De_2", (6, 256, 32, 64), 64], 29 | "stage1d": ["De_1", (7, 128, 16, 64), 64], 30 | } 31 | ) 32 | with pytest.raises(ValueError): 33 | U2NET(layers_cfg="nan") 34 | with pytest.raises(ValueError): 35 | U2NET(layers_cfg=[]) 36 | 37 | 38 | def test_preprocessing(u2net_model, converted_pil_image, black_image_pil): 39 | u2net_model = u2net_model(False) 40 | assert ( 41 | isinstance( 42 | u2net_model.data_preprocessing(converted_pil_image), torch.FloatTensor 43 | ) 44 | is True 45 | ) 46 | assert ( 47 | isinstance(u2net_model.data_preprocessing(black_image_pil), torch.FloatTensor) 48 | is True 49 | ) 50 | 51 | 52 | def test_postprocessing(u2net_model, converted_pil_image, black_image_pil): 53 | u2net_model = u2net_model(False) 54 | assert isinstance( 55 | u2net_model.data_postprocessing( 56 | torch.ones((1, 320, 320), dtype=torch.float64), converted_pil_image 57 | ), 58 | Image.Image, 59 | ) 60 | 61 | 62 | def test_seg(u2net_model, image_pil, image_str, image_path, black_image_pil): 63 | u2net_model = u2net_model(False) 64 | u2net_model([image_pil]) 65 | u2net_model([image_pil, image_str, image_path, black_image_pil]) 66 | 67 | 68 | def test_seg_with_fp16(u2net_model, image_pil, image_str, image_path, black_image_pil): 69 | u2net_model = u2net_model(True) 70 | u2net_model([image_pil]) 71 | u2net_model([image_pil, image_str, image_path, black_image_pil]) 72 | --------------------------------------------------------------------------------
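The fixtures and tests above exercise each model wrapper in isolation; the sketch below ties them together through the high-level HiInterface API. It is a minimal sketch under two assumptions the tests themselves do not pin down: that HiInterface is called the same way as the Interface objects in tests/test_interface.py, and that it returns one processed PIL image per input. Parameter values follow tests/test_high.py, conftest.py, and the docker-compose defaults shown earlier.

from PIL import Image

from carvekit.api.high import HiInterface

interface = HiInterface(
    batch_size_seg=5,
    batch_size_matting=1,
    device="cpu",  # or "cuda" when available
    seg_mask_size=320,
    matting_mask_size=2048,
    trimap_prob_threshold=231,
    trimap_dilation=30,
    trimap_erosion_iters=5,
    fp16=False,
)
images_without_background = interface([Image.open("tests/data/cat.jpg")])
images_without_background[0].save("cat_bg_removed.png")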