├── .github └── workflows │ └── publish.yml ├── .gitignore ├── LICENSE ├── README.md ├── README_CN.md ├── __init__.py ├── birefnet ├── __init__.py ├── config.py ├── dataset.py ├── image_proc.py ├── models │ ├── __init__.py │ ├── backbones │ │ ├── __init__.py │ │ ├── build_backbone.py │ │ ├── pvt_v2.py │ │ └── swin_v1.py │ ├── birefnet.py │ ├── modules │ │ ├── __init__.py │ │ ├── aspp.py │ │ ├── decoder_blocks.py │ │ ├── deform_conv.py │ │ ├── lateral_blocks.py │ │ ├── prompt_encoder.py │ │ └── utils.py │ └── refinement │ │ ├── __init__.py │ │ ├── refiner.py │ │ └── stem_layer.py └── utils.py ├── birefnetNode.py ├── birefnet_old ├── __init__.py ├── config.py ├── dataset.py ├── models │ ├── backbones │ │ ├── __init__.py │ │ ├── build_backbone.py │ │ ├── pvt_v2.py │ │ └── swin_v1.py │ ├── birefnet.py │ ├── modules │ │ ├── __init__.py │ │ ├── aspp.py │ │ ├── attentions.py │ │ ├── decoder_blocks.py │ │ ├── deform_conv.py │ │ ├── ing.py │ │ ├── lateral_blocks.py │ │ ├── mlp.py │ │ └── utils.py │ └── refinement │ │ ├── __init__.py │ │ ├── refiner.py │ │ └── stem_layer.py ├── preproc.py └── utils.py ├── doc ├── base.png └── video.gif ├── example └── workflow_base.png ├── pyproject.toml ├── requirements.txt └── util.py /.github/workflows/publish.yml: -------------------------------------------------------------------------------- 1 | name: Publish to Comfy registry 2 | on: 3 | workflow_dispatch: 4 | push: 5 | branches: 6 | - main 7 | paths: 8 | - "pyproject.toml" 9 | 10 | permissions: 11 | issues: write 12 | 13 | jobs: 14 | publish-node: 15 | name: Publish Custom Node to registry 16 | runs-on: ubuntu-latest 17 | if: ${{ github.repository_owner == 'lldacing' }} 18 | steps: 19 | - name: Check out code 20 | uses: actions/checkout@v4 21 | - name: Publish Custom Node 22 | uses: Comfy-Org/publish-node-action@v1 23 | with: 24 | ## Add your own personal access token to your Github Repository secrets and reference it here. 25 | personal_access_token: ${{ secrets.REGISTRY_ACCESS_TOKEN }} 26 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # poetry 98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 99 | # This is especially recommended for binary packages to ensure reproducibility, and is more 100 | # commonly ignored for libraries. 101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 102 | #poetry.lock 103 | 104 | # pdm 105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 106 | #pdm.lock 107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 108 | # in version control. 109 | # https://pdm.fming.dev/latest/usage/project/#working-with-version-control 110 | .pdm.toml 111 | .pdm-python 112 | .pdm-build/ 113 | 114 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 115 | __pypackages__/ 116 | 117 | # Celery stuff 118 | celerybeat-schedule 119 | celerybeat.pid 120 | 121 | # SageMath parsed files 122 | *.sage.py 123 | 124 | # Environments 125 | .env 126 | .venv 127 | env/ 128 | venv/ 129 | ENV/ 130 | env.bak/ 131 | venv.bak/ 132 | 133 | # Spyder project settings 134 | .spyderproject 135 | .spyproject 136 | 137 | # Rope project settings 138 | .ropeproject 139 | 140 | # mkdocs documentation 141 | /site 142 | 143 | # mypy 144 | .mypy_cache/ 145 | .dmypy.json 146 | dmypy.json 147 | 148 | # Pyre type checker 149 | .pyre/ 150 | 151 | # pytype static type analyzer 152 | .pytype/ 153 | 154 | # Cython debug symbols 155 | cython_debug/ 156 | 157 | # PyCharm 158 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 159 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 160 | # and can be added to the global gitignore or merged into this file. For a more nuclear 161 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 
162 | #.idea/ 163 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2024 lldacing 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | 23 | --- 24 | 25 | The code and models of BiRefNet are released under the MIT License. -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | [中文文档](README_CN.md) 2 | 3 | Support the use of new and old versions of BiRefNet models 4 | 5 | ## Preview 6 | ![save api extended](doc/base.png) 7 | ![save api extended](doc/video.gif) 8 | 9 | ## Install 10 | 11 | - Manual 12 | ```shell 13 | cd custom_nodes 14 | git clone https://github.com/lldacing/ComfyUI_BiRefNet_ll.git 15 | cd ComfyUI_BiRefNet_ll 16 | pip install -r requirements.txt 17 | # restart ComfyUI 18 | ``` 19 | - Via ComfyUI Manager 20 | 21 | 22 | ## Models 23 | 24 | ### The available newest models are: 25 | 26 | - General: A pre-trained model for general use cases. 27 | - General-HR: A pre-trained model for general use cases which shows great performance on higher resolution images (2048x2048). 28 | - General-Lite: A light pre-trained model for general use cases. 29 | - General-Lite-2K: A light pre-trained model for general use cases in high resolution (2560x1440). 30 | - General-dynamic: A pre-trained model for dynamic resolution, trained with images in range from 256x256 to 2304x2304. 31 | - General-reso_512: A pre-trained model for faster and more accurate lower resolution, trained with images in 512x512. 32 | - General-legacy: A pre-trained model for general use trained on DIS5K-TR,DIS-TEs, DUTS-TR_TE,HRSOD-TR_TE,UHRSD-TR_TE, HRS10K-TR_TE (w/o portrait seg data). 33 | - Portrait: A pre-trained model for human portraits. 34 | - Matting: A pre-trained model for general trimap-free matting use. 35 | - Matting-HR: A pre-trained model for general matting use which shows great matting performance on higher resolution images (2048x2048). 36 | - Matting-Lite: A light pre-trained model for general trimap-free matting use. 37 | - DIS: A pre-trained model for dichotomous image segmentation (DIS). 38 | - HRSOD: A pre-trained model for high-resolution salient object detection (HRSOD). 39 | - COD: A pre-trained model for concealed object detection (COD). 
40 | - DIS-TR_TEs: A pre-trained model with massive dataset. 41 | 42 | Model files go here (when use AutoDownloadBiRefNetModel automatically downloaded if the folder is not present during first run): `${comfyui_rootpath}/models/BiRefNet`. 43 | 44 | If necessary, they can be downloaded from: 45 | - [General](https://huggingface.co/ZhengPeng7/BiRefNet/resolve/main/model.safetensors) ➔ `model.safetensors` must be renamed `General.safetensors` 46 | - [General-HR](https://huggingface.co/ZhengPeng7/BiRefNet_HR/resolve/main/model.safetensors) ➔ `model.safetensors` must be renamed `General-HR.safetensors` 47 | - [General-Lite](https://huggingface.co/ZhengPeng7/BiRefNet_T/resolve/main/model.safetensors) ➔ `model.safetensors` must be renamed `General-Lite.safetensors` 48 | - [General-Lite-2K](https://huggingface.co/ZhengPeng7/BiRefNet_lite-2K/resolve/main/model.safetensors) ➔ `model.safetensors` must be renamed `General-Lite-2K.safetensors` 49 | - [General-dynamic](https://huggingface.co/ZhengPeng7/BiRefNet_dynamic/resolve/main/model.safetensors) ➔ `model.safetensors` must be renamed `General-dynamic.safetensors` 50 | - [General-legacy](https://huggingface.co/ZhengPeng7/BiRefNet-legacy/resolve/main/model.safetensors) ➔ `model.safetensors` must be renamed `General-legacy.safetensors` 51 | - [General-reso_512](https://huggingface.co/ZhengPeng7/BiRefNet_512x512/resolve/main/model.safetensors) ➔ `model.safetensors` must be renamed `General-reso_512.safetensors` 52 | - [Portrait](https://huggingface.co/ZhengPeng7/BiRefNet-portrait/resolve/main/model.safetensors) ➔ `model.safetensors` must be renamed `Portrait.safetensors` 53 | - [Matting](https://huggingface.co/ZhengPeng7/BiRefNet-matting/resolve/main/model.safetensors) ➔ `model.safetensors` must be renamed `Matting.safetensors` 54 | - [Matting-HR](https://huggingface.co/ZhengPeng7/BiRefNet_HR-matting/resolve/main/model.safetensors) ➔ `model.safetensors` must be renamed `Matting-HR.safetensors` 55 | - [Matting-Lite](https://huggingface.co/ZhengPeng7/BiRefNet_lite-matting/resolve/main/model.safetensors) ➔ `model.safetensors` must be renamed `Matting-Lite.safetensors` 56 | - [DIS](https://huggingface.co/ZhengPeng7/BiRefNet-DIS5K/resolve/main/model.safetensors) ➔ `model.safetensors` must be renamed `DIS.safetensors` 57 | - [HRSOD](https://huggingface.co/ZhengPeng7/BiRefNet-HRSOD/resolve/main/model.safetensors) ➔ `model.safetensors` must be renamed `HRSOD.safetensors` 58 | - [COD](https://huggingface.co/ZhengPeng7/BiRefNet-COD/resolve/main/model.safetensors) ➔ `model.safetensors` must be renamed `COD.safetensors` 59 | - [DIS-TR_TEs](https://huggingface.co/ZhengPeng7/BiRefNet-DIS5K-TR_TEs/resolve/main/model.safetensors) ➔ `model.safetensors` must be renamed `DIS-TR_TEs.safetensors` 60 | 61 | Some models on GitHub: 62 | [BiRefNet Releases](https://github.com/ZhengPeng7/BiRefNet/releases) 63 | 64 | ### Old models: 65 | - [BiRefNet-DIS_ep580.pth](https://huggingface.co/ViperYX/BiRefNet/resolve/main/BiRefNet-DIS_ep580.pth) 66 | - [BiRefNet-ep480.pth](https://huggingface.co/ViperYX/BiRefNet/resolve/main/BiRefNet-ep480.pth) 67 | 68 | ## Weight Models (Optional) 69 | - [swin_large_patch4_window12_384_22kto1k.pth](https://huggingface.co/ViperYX/BiRefNet/resolve/main/swin_large_patch4_window12_384_22kto1k.pth)(not General-Lite, General-Lite-2K and Matting-Lite model) 70 | - [swin_tiny_patch4_window7_224_22kto1k_finetune.pth](https://drive.google.com/drive/folders/1cmce_emsS8A5ha5XT2c_CZiJzlLM81ms)(just General-Lite, General-Lite-2K and Matting-Lite model) 71 | 72 | 73 
| ## Nodes 74 | - AutoDownloadBiRefNetModel 75 | - Automatically downloads the model into `${comfyui_rootpath}/models/BiRefNet`; weight models are not supported 76 | - LoadRembgByBiRefNetModel 77 | - Selects a model from `${comfyui_rootpath}/models/BiRefNet` or from the `birefnet` path configured in the extra YAML file 78 | - You can download the latest models from [BiRefNet Releases](https://github.com/ZhengPeng7/BiRefNet/releases), or the old models [BiRefNet-DIS_ep580.pth](https://huggingface.co/ViperYX/BiRefNet/resolve/main/BiRefNet-DIS_ep580.pth) and [BiRefNet-ep480.pth](https://huggingface.co/ViperYX/BiRefNet/resolve/main/BiRefNet-ep480.pth) 79 | - When the param use_weight is True, the weight model [swin_large_patch4_window12_384_22kto1k.pth](https://huggingface.co/ViperYX/BiRefNet/resolve/main/swin_large_patch4_window12_384_22kto1k.pth) must be downloaded; 80 | the General-Lite, General-Lite-2K and Matting-Lite models must use the weight model [swin_tiny_patch4_window7_224_22kto1k_finetune.pth](https://drive.google.com/drive/folders/1cmce_emsS8A5ha5XT2c_CZiJzlLM81ms) instead 81 | - RembgByBiRefNet 82 | - Outputs a transparent foreground image and a mask 83 | - RembgByBiRefNetAdvanced 84 | - Outputs a foreground image and a mask, and exposes some fine-tuning parameters 85 | - GetMaskByBiRefNet 86 | - Outputs only the mask 87 | - BlurFusionForegroundEstimation 88 | - Uses the [fast-foreground-estimation method](https://github.com/Photoroom/fast-foreground-estimation) to estimate the foreground image 89 | 90 | ## Thanks 91 | 92 | [ZhengPeng7/BiRefNet](https://github.com/zhengpeng7/birefnet) 93 | 94 | [dimitribarbot/sd-webui-birefnet](https://github.com/dimitribarbot/sd-webui-birefnet) 95 | 96 | -------------------------------------------------------------------------------- /README_CN.md: -------------------------------------------------------------------------------- 1 | [English](README.md) 2 | 3 | 支持使用新老版本BiRefNet模型进行抠图 4 | 5 | ## 预览 6 | ![save api extended](doc/base.png) 7 | ![save api extended](doc/video.gif) 8 | 9 | ## 安装 10 | 11 | - 手动安装 12 | ```shell 13 | cd custom_nodes 14 | git clone https://github.com/lldacing/ComfyUI_BiRefNet_ll.git 15 | cd ComfyUI_BiRefNet_ll 16 | pip install -r requirements.txt 17 | # restart ComfyUI 18 | ``` 19 | - ComfyUI管理器搜索安装 20 | 21 | 22 | ## 模型 23 | 24 | ### 最新的模型: 25 | 26 | - General: 用于一般用例的预训练模型。 27 | - General-HR: 用于一般用例的预训练模型,在更高分辨率的图像上表现出色(训练分辨率2048x2048)。 28 | - General-Lite: 用于一般用例的轻量级预训练模型。 29 | - General-Lite-2K: 用于一般用例的轻量级预训练模型,适用于高分辨率图像(最佳分辨率2560x1440)。 30 | - General-dynamic: 用于动态分辨率的预训练模型,基于256x256到2304x2304的图片分辨率进行训练。 31 | - General-reso_512: 一个更快、更准确的低分辨率预训练模型,基于512x512的图像训练。 32 | - General-legacy: 一般用例的预训练模型,基于DIS5K-TR,DIS-TEs, DUTS-TR_TE,HRSOD-TR_TE,UHRSD-TR_TE, HRS10K-TR_TE (w/o portrait seg data)。 33 | - Portrait: 人物肖像预训练模型。 34 | - Matting: 一种使用无过渡遮罩抠图的预训练模型。 35 | - Matting-HR: 在更高分辨率的图像上显示出出色的性能的预训练模型 (2048x2048)。 36 | - Matting-Lite: 用一般用例的无过渡遮罩抠图的轻量级预训练模型。 37 | - DIS: 一种用于二分图像分割(DIS)的预训练模型。 38 | - HRSOD: 一种用于高分辨率显著目标检测(HRSOD)的预训练模型。 39 | - COD: 一种用于隐蔽目标检测(COD)的预训练模型。 40 | - DIS-TR_TEs: 具有大量数据集的预训练模型。 41 | 42 | 模型文件放在`${comfyui_rootpath}/models/BiRefNet`(当使用AutoDownloadBiRefNetModel时,则会自动下载模型)。 43 | 44 | 也可以手动下载模型: 45 | - [General](https://huggingface.co/ZhengPeng7/BiRefNet/resolve/main/model.safetensors) ➔ `model.safetensors` 重命名为 `General.safetensors` 46 | - [General-HR](https://huggingface.co/ZhengPeng7/BiRefNet_HR/resolve/main/model.safetensors) ➔ `model.safetensors` 重命名为 `General-HR.safetensors` 47 | - 
[General-Lite](https://huggingface.co/ZhengPeng7/BiRefNet_T/resolve/main/model.safetensors) ➔ `model.safetensors` 重命名为 `General-Lite.safetensors` 48 | - [General-Lite-2K](https://huggingface.co/ZhengPeng7/BiRefNet_lite-2K/resolve/main/model.safetensors) ➔ `model.safetensors` 重命名为 `General-Lite-2K.safetensors` 49 | - [General-dynamic](https://huggingface.co/ZhengPeng7/BiRefNet_dynamic/resolve/main/model.safetensors) ➔ `model.safetensors` 重命名为 `General-dynamic.safetensors` 50 | - [General-legacy](https://huggingface.co/ZhengPeng7/BiRefNet-legacy/resolve/main/model.safetensors) ➔ `model.safetensors` 重命名为 `General-legacy.safetensors` 51 | - [General-reso_512](https://huggingface.co/ZhengPeng7/BiRefNet_512x512/resolve/main/model.safetensors) ➔ `model.safetensors` 重命名为 `General-reso_512.safetensors` 52 | - [Portrait](https://huggingface.co/ZhengPeng7/BiRefNet-portrait/resolve/main/model.safetensors) ➔ `model.safetensors` 重命名为 `Portrait.safetensors` 53 | - [Matting](https://huggingface.co/ZhengPeng7/BiRefNet-matting/resolve/main/model.safetensors) ➔ `model.safetensors` 重命名为 `Matting.safetensors` 54 | - [Matting-HR](https://huggingface.co/ZhengPeng7/BiRefNet_HR-matting/resolve/main/model.safetensors) ➔ `model.safetensors` 重命名为 `Matting-HR.safetensors` 55 | - [Matting-Lite](https://huggingface.co/ZhengPeng7/BiRefNet_lite-matting/resolve/main/model.safetensors) ➔ `model.safetensors` 重命名为 `Matting-Lite.safetensors` 56 | - [DIS](https://huggingface.co/ZhengPeng7/BiRefNet-DIS5K/resolve/main/model.safetensors) ➔ `model.safetensors` 重命名为 `DIS.safetensors` 57 | - [HRSOD](https://huggingface.co/ZhengPeng7/BiRefNet-HRSOD/resolve/main/model.safetensors) ➔ `model.safetensors` 重命名为 `HRSOD.safetensors` 58 | - [COD](https://huggingface.co/ZhengPeng7/BiRefNet-COD/resolve/main/model.safetensors) ➔ `model.safetensors` 重命名为 `COD.safetensors` 59 | - [DIS-TR_TEs](https://huggingface.co/ZhengPeng7/BiRefNet-DIS5K-TR_TEs/resolve/main/model.safetensors) ➔ `model.safetensors` 重命名为 `DIS-TR_TEs.safetensors` 60 | 61 | 62 | GitHub上的模型: 63 | [BiRefNet Releases](https://github.com/ZhengPeng7/BiRefNet/releases) 64 | 65 | ### 旧模型: 66 | - [BiRefNet-DIS_ep580.pth](https://huggingface.co/ViperYX/BiRefNet/resolve/main/BiRefNet-DIS_ep580.pth) 67 | - [BiRefNet-ep480.pth](https://huggingface.co/ViperYX/BiRefNet/resolve/main/BiRefNet-ep480.pth) 68 | 69 | ## 权重模型(非必须) 70 | 下载放在`models/BiRefNet` 71 | - [swin_large_patch4_window12_384_22kto1k.pth](https://huggingface.co/ViperYX/BiRefNet/resolve/main/swin_large_patch4_window12_384_22kto1k.pth)(非General-Lite、General-Lite-2K和Matting-Lite模型) 72 | - [swin_tiny_patch4_window7_224_22kto1k_finetune.pth](https://drive.google.com/drive/folders/1cmce_emsS8A5ha5XT2c_CZiJzlLM81ms)(仅General-Lite、General-Lite-2K和Matting-Lite模型) 73 | 74 | 75 | ## 节点 76 | - AutoDownloadBiRefNetModel 77 | - 自动下载模型到 `${comfyui_rootpath}/models/BiRefNet`,不支持权重 78 | - LoadRembgByBiRefNetModel 79 | - 从 `${comfyui_rootpath}/models/BiRefNet` 和 在extra YAML 文件中通过`birefnet`配置的路径中选择模型 80 | - 支持 [BiRefNet Releases](https://github.com/ZhengPeng7/BiRefNet/releases) 中的新模型 和 老的模型[BiRefNet-DIS_ep580.pth](https://huggingface.co/ViperYX/BiRefNet/resolve/main/BiRefNet-DIS_ep580.pth) 与 [BiRefNet-ep480.pth](https://huggingface.co/ViperYX/BiRefNet/resolve/main/BiRefNet-ep480.pth) 81 | - 参数use_weight设为True时, 需要下载权重模型,General-Lite、General-Lite-2K和Matting-Lite模型使用[swin_tiny_patch4_window7_224_22kto1k_finetune.pth](https://drive.google.com/drive/folders/1cmce_emsS8A5ha5XT2c_CZiJzlLM81ms),其它模型使用 
[swin_large_patch4_window12_384_22kto1k.pth](https://huggingface.co/ViperYX/BiRefNet/resolve/main/swin_large_patch4_window12_384_22kto1k.pth) 82 | - RembgByBiRefNet 83 | - 输出透明前景图和遮罩 84 | - RembgByBiRefNetAdvanced 85 | - 输出前景图和遮罩,提供一些微调参数 86 | - GetMaskByBiRefNet 87 | - 仅输出遮罩 88 | - BlurFusionForegroundEstimation 89 | - 使用[fast-foreground-estimation](https://github.com/Photoroom/fast-foreground-estimation)方法预估前景图 90 | 91 | ## 感谢 92 | 93 | [ZhengPeng7/BiRefNet](https://github.com/zhengpeng7/birefnet) 94 | 95 | [dimitribarbot/sd-webui-birefnet](https://github.com/dimitribarbot/sd-webui-birefnet) 96 | 97 | -------------------------------------------------------------------------------- /__init__.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | 4 | import folder_paths 5 | 6 | # 获取当前目录的父目录的父目录 7 | parent_dir = os.path.dirname(os.path.abspath(__file__)) 8 | 9 | # 添加父目录的父目录到系统路径 10 | sys.path.insert(0, parent_dir) 11 | 12 | models_dir_key = "birefnet" 13 | models_dir_default = os.path.join(folder_paths.models_dir, "BiRefNet") 14 | if models_dir_key not in folder_paths.folder_names_and_paths: 15 | folder_paths.folder_names_and_paths[models_dir_key] = ( 16 | [os.path.join(folder_paths.models_dir, "BiRefNet")], folder_paths.supported_pt_extensions) 17 | else: 18 | if not os.path.exists(models_dir_default): 19 | os.makedirs(models_dir_default, exist_ok=True) 20 | folder_paths.add_model_folder_path(models_dir_key, models_dir_default) 21 | 22 | from . import birefnetNode 23 | 24 | NODE_CLASS_MAPPINGS = {**birefnetNode.NODE_CLASS_MAPPINGS} 25 | NODE_DISPLAY_NAME_MAPPINGS = {**birefnetNode.NODE_DISPLAY_NAME_MAPPINGS} 26 | -------------------------------------------------------------------------------- /birefnet/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lldacing/ComfyUI_BiRefNet_ll/5443a2aa16cfbd98bb2f7dcc8bdcb70439e08529/birefnet/__init__.py -------------------------------------------------------------------------------- /birefnet/config.py: -------------------------------------------------------------------------------- 1 | import os 2 | import math 3 | 4 | import folder_paths 5 | 6 | 7 | class Config: 8 | def __init__(self, bb_index: int = 6) -> None: 9 | # PATH settings 10 | # Make up your file system as: SYS_HOME_DIR/codes/dis/BiRefNet, SYS_HOME_DIR/datasets/dis/xx, SYS_HOME_DIR/weights/xx 11 | # self.sys_home_dir = [os.path.expanduser('~'), '/mnt/data'][0] # Default, custom 12 | # self.data_root_dir = os.path.join(self.sys_home_dir, 'datasets/dis') 13 | 14 | # TASK settings 15 | self.task = ['DIS5K', 'COD', 'HRSOD', 'General', 'General-2K', 'Matting'][0] 16 | self.testsets = { 17 | # Benchmarks 18 | 'DIS5K': ','.join(['DIS-VD', 'DIS-TE1', 'DIS-TE2', 'DIS-TE3', 'DIS-TE4']), 19 | 'COD': ','.join(['CHAMELEON', 'NC4K', 'TE-CAMO', 'TE-COD10K']), 20 | 'HRSOD': ','.join(['DAVIS-S', 'TE-HRSOD', 'TE-UHRSD', 'DUT-OMRON', 'TE-DUTS']), 21 | # Practical use 22 | 'General': ','.join(['DIS-VD', 'TE-P3M-500-NP']), 23 | 'General-2K': ','.join(['DIS-VD', 'TE-P3M-500-NP']), 24 | 'Matting': ','.join(['TE-P3M-500-NP', 'TE-AM-2k']), 25 | }[self.task] 26 | # datasets_all = '+'.join([ds for ds in (os.listdir(os.path.join(self.data_root_dir, self.task)) if os.path.isdir(os.path.join(self.data_root_dir, self.task)) else []) if ds not in self.testsets.split(',')]) 27 | self.training_set = { 28 | 'DIS5K': ['DIS-TR', 'DIS-TR+DIS-TE1+DIS-TE2+DIS-TE3+DIS-TE4'][0], 29 | 'COD': 
'TR-COD10K+TR-CAMO', 30 | 'HRSOD': ['TR-DUTS', 'TR-HRSOD', 'TR-UHRSD', 'TR-DUTS+TR-HRSOD', 'TR-DUTS+TR-UHRSD', 'TR-HRSOD+TR-UHRSD', 'TR-DUTS+TR-HRSOD+TR-UHRSD'][5], 31 | 'General': 'DIS-TE1+DIS-TE2+DIS-TE3+DIS-TE4+DIS-TR+TR-HRSOD+TE-HRSOD+TR-HRS10K+TE-HRS10K+TR-UHRSD+TE-UHRSD+TR-P3M-10k+TE-P3M-500-P+TR-humans+DIS-VD-ori', # datasets_all 32 | 'General-2K': 'DIS-TE1+DIS-TE2+DIS-TE3+DIS-TE4+DIS-TR+TR-HRSOD+TE-HRSOD+TR-HRS10K+TE-HRS10K+TR-UHRSD+TE-UHRSD+TR-P3M-10k+TE-P3M-500-P+TR-humans+DIS-VD-ori', # datasets_all 33 | 'Matting': 'TR-P3M-10k+TE-P3M-500-NP+TR-humans+TR-Distrinctions-646', # datasets_all 34 | }[self.task] 35 | 36 | # Data settings 37 | self.size = (1024, 1024) if self.task not in ['General-2K'] else (2560, 1440) # wid, hei. Can be overwritten by dynamic_size in training. 38 | self.dynamic_size = [None, ((512-256, 2048+256), (512-256, 2048+256))][0] # wid, hei. It might cause errors in using compile. 39 | self.background_color_synthesis = False # whether to use pure bg color to replace the original backgrounds. 40 | 41 | # Faster-Training settings 42 | self.load_all = False and self.dynamic_size is None # Turn it on/off by your case. It may consume a lot of CPU memory. And for multi-GPU (N), it would cost N times the CPU memory to load the data. 43 | self.compile = True # 1. Trigger CPU memory leak in some extend, which is an inherent problem of PyTorch. 44 | # Machines with > 70GB CPU memory can run the whole training on DIS5K with default setting. 45 | # 2. Higher PyTorch version may fix it: https://github.com/pytorch/pytorch/issues/119607. 46 | # 3. But compile in 2.0.1 < Pytorch < 2.5.0 seems to bring no acceleration for training. 47 | self.precisionHigh = True 48 | 49 | # MODEL settings 50 | self.ms_supervision = True 51 | self.out_ref = self.ms_supervision and True 52 | self.dec_ipt = True 53 | self.dec_ipt_split = True 54 | self.cxt_num = [0, 3][1] # multi-scale skip connections from encoder 55 | self.mul_scl_ipt = ['', 'add', 'cat'][2] 56 | self.dec_att = ['', 'ASPP', 'ASPPDeformable'][2] 57 | self.squeeze_block = ['', 'BasicDecBlk_x1', 'ResBlk_x4', 'ASPP_x3', 'ASPPDeformable_x3'][1] 58 | self.dec_blk = ['BasicDecBlk', 'ResBlk'][0] 59 | 60 | # TRAINING settings 61 | self.batch_size = 4 62 | self.finetune_last_epochs = [ 63 | 0, 64 | { 65 | 'DIS5K': -40, 66 | 'COD': -20, 67 | 'HRSOD': -20, 68 | 'General': -20, 69 | 'General-2K': -20, 70 | 'Matting': -20, 71 | }[self.task] 72 | ][1] # choose 0 to skip 73 | self.lr = (1e-4 if 'DIS5K' in self.task else 1e-5) * math.sqrt(self.batch_size / 4) # DIS needs high lr to converge faster. 
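(With the default batch_size=4 the scaling factor is sqrt(4/4)=1, so the base lr is used unchanged; e.g. batch_size=16 would scale it by sqrt(16/4)=2.)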
Adapt the lr linearly 74 | self.num_workers = max(4, self.batch_size) # will be decrease to min(it, batch_size) at the initialization of the data_loader 75 | 76 | # Backbone settings 77 | self.bb = [ 78 | 'vgg16', 'vgg16bn', 'resnet50', # 0, 1, 2 79 | 'swin_v1_t', 'swin_v1_s', # 3, 4 80 | 'swin_v1_b', 'swin_v1_l', # 5-bs9, 6-bs4 81 | 'pvt_v2_b0', 'pvt_v2_b1', # 7, 8 82 | 'pvt_v2_b2', 'pvt_v2_b5', # 9-bs10, 10-bs5 83 | ][bb_index] 84 | self.lateral_channels_in_collection = { 85 | 'vgg16': [512, 256, 128, 64], 'vgg16bn': [512, 256, 128, 64], 'resnet50': [1024, 512, 256, 64], 86 | 'pvt_v2_b2': [512, 320, 128, 64], 'pvt_v2_b5': [512, 320, 128, 64], 87 | 'swin_v1_b': [1024, 512, 256, 128], 'swin_v1_l': [1536, 768, 384, 192], 88 | 'swin_v1_t': [768, 384, 192, 96], 'swin_v1_s': [768, 384, 192, 96], 89 | 'pvt_v2_b0': [256, 160, 64, 32], 'pvt_v2_b1': [512, 320, 128, 64], 90 | }[self.bb] 91 | if self.mul_scl_ipt == 'cat': 92 | self.lateral_channels_in_collection = [channel * 2 for channel in self.lateral_channels_in_collection] 93 | self.cxt = self.lateral_channels_in_collection[1:][::-1][-self.cxt_num:] if self.cxt_num else [] 94 | 95 | # MODEL settings - inactive 96 | self.lat_blk = ['BasicLatBlk'][0] 97 | self.dec_channels_inter = ['fixed', 'adap'][0] 98 | self.refine = ['', 'itself', 'RefUNet', 'Refiner', 'RefinerPVTInChannels4'][0] 99 | self.progressive_ref = self.refine and True 100 | self.ender = self.progressive_ref and False 101 | self.scale = self.progressive_ref and 2 102 | self.auxiliary_classification = False # Only for DIS5K, where class labels are saved in `dataset.py`. 103 | self.refine_iteration = 1 104 | self.freeze_bb = False 105 | self.model = [ 106 | 'BiRefNet', 107 | 'BiRefNetC2F', 108 | ][0] 109 | 110 | # TRAINING settings - inactive 111 | self.preproc_methods = ['flip', 'enhance', 'rotate', 'pepper', 'crop'][:4 if not self.background_color_synthesis else 1] 112 | self.optimizer = ['Adam', 'AdamW'][1] 113 | self.lr_decay_epochs = [1e5] # Set to negative N to decay the lr in the last N-th epoch. 114 | self.lr_decay_rate = 0.5 115 | # Loss 116 | if self.task in ['Matting']: 117 | self.lambdas_pix_last = { 118 | 'bce': 30 * 1, 119 | 'iou': 0.5 * 0, 120 | 'iou_patch': 0.5 * 0, 121 | 'mae': 100 * 1, 122 | 'mse': 30 * 0, 123 | 'triplet': 3 * 0, 124 | 'reg': 100 * 0, 125 | 'ssim': 10 * 1, 126 | 'cnt': 5 * 0, 127 | 'structure': 5 * 0, 128 | } 129 | elif self.task in ['General', 'General-2K']: 130 | self.lambdas_pix_last = { 131 | 'bce': 30 * 1, 132 | 'iou': 0.5 * 1, 133 | 'iou_patch': 0.5 * 0, 134 | 'mae': 100 * 1, 135 | 'mse': 30 * 0, 136 | 'triplet': 3 * 0, 137 | 'reg': 100 * 0, 138 | 'ssim': 10 * 1, 139 | 'cnt': 5 * 0, 140 | 'structure': 5 * 0, 141 | } 142 | else: 143 | self.lambdas_pix_last = { 144 | # not 0 means opening this loss 145 | # original rate -- 1 : 30 : 1.5 : 0.2, bce x 30 146 | 'bce': 30 * 1, # high performance 147 | 'iou': 0.5 * 1, # 0 / 255 148 | 'iou_patch': 0.5 * 0, # 0 / 255, win_size = (64, 64) 149 | 'mae': 30 * 0, 150 | 'mse': 30 * 0, # can smooth the saliency map 151 | 'triplet': 3 * 0, 152 | 'reg': 100 * 0, 153 | 'ssim': 10 * 1, # help contours, 154 | 'cnt': 5 * 0, # help contours 155 | 'structure': 5 * 0, # structure loss from codes of MVANet. A little improvement on DIS-TE[1,2,3], a bit more decrease on DIS-TE4. 
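# Net effect of the multipliers above: only bce (30), iou (0.5) and ssim (10) are active; every term scaled by 0 is disabled.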
156 | } 157 | self.lambdas_cls = { 158 | 'ce': 5.0 159 | } 160 | 161 | # PATH settings - inactive 162 | # https://drive.google.com/drive/folders/1cmce_emsS8A5ha5XT2c_CZiJzlLM81ms 163 | # self.weights_root_dir = os.path.join(self.sys_home_dir, 'weights/cv') 164 | # self.weights = { 165 | # 'pvt_v2_b2': os.path.join(self.weights_root_dir, 'pvt_v2_b2.pth'), 166 | # 'pvt_v2_b5': os.path.join(self.weights_root_dir, ['pvt_v2_b5.pth', 'pvt_v2_b5_22k.pth'][0]), 167 | # 'swin_v1_b': os.path.join(self.weights_root_dir, ['swin_base_patch4_window12_384_22kto1k.pth', 'swin_base_patch4_window12_384_22k.pth'][0]), 168 | # 'swin_v1_l': os.path.join(self.weights_root_dir, ['swin_large_patch4_window12_384_22kto1k.pth', 'swin_large_patch4_window12_384_22k.pth'][0]), 169 | # 'swin_v1_t': os.path.join(self.weights_root_dir, ['swin_tiny_patch4_window7_224_22kto1k_finetune.pth'][0]), 170 | # 'swin_v1_s': os.path.join(self.weights_root_dir, ['swin_small_patch4_window7_224_22kto1k_finetune.pth'][0]), 171 | # 'pvt_v2_b0': os.path.join(self.weights_root_dir, ['pvt_v2_b0.pth'][0]), 172 | # 'pvt_v2_b1': os.path.join(self.weights_root_dir, ['pvt_v2_b1.pth'][0]), 173 | # } 174 | weight_paths_name = "birefnet" 175 | self.weights = { 176 | 'pvt_v2_b2': folder_paths.get_full_path(weight_paths_name, 'pvt_v2_b2.pth'), 177 | 'pvt_v2_b5': folder_paths.get_full_path(weight_paths_name, ['pvt_v2_b5.pth', 'pvt_v2_b5_22k.pth'][0]), 178 | 'swin_v1_b': folder_paths.get_full_path(weight_paths_name, ['swin_base_patch4_window12_384_22kto1k.pth', 'swin_base_patch4_window12_384_22k.pth'][0]), 179 | 'swin_v1_l': folder_paths.get_full_path(weight_paths_name, ['swin_large_patch4_window12_384_22kto1k.pth', 'swin_large_patch4_window12_384_22k.pth'][0]), 180 | 'swin_v1_t': folder_paths.get_full_path(weight_paths_name, ['swin_tiny_patch4_window7_224_22kto1k_finetune.pth'][0]), 181 | 'swin_v1_s': folder_paths.get_full_path(weight_paths_name, ['swin_small_patch4_window7_224_22kto1k_finetune.pth'][0]), 182 | 'pvt_v2_b0': folder_paths.get_full_path(weight_paths_name, ['pvt_v2_b0.pth'][0]), 183 | 'pvt_v2_b1': folder_paths.get_full_path(weight_paths_name, ['pvt_v2_b1.pth'][0]), 184 | } 185 | 186 | # Callbacks - inactive 187 | self.verbose_eval = True 188 | self.only_S_MAE = False 189 | self.SDPA_enabled = False # Bugs. 
Slower and errors occur in multi-GPUs 190 | 191 | # others 192 | self.device = [0, 'cpu'][0] # .to(0) == .to('cuda:0') 193 | 194 | self.batch_size_valid = 1 195 | self.rand_seed = 7 196 | # run_sh_file = [f for f in os.listdir('.') if 'train.sh' == f] + [os.path.join('..', f) for f in os.listdir('..') if 'train.sh' == f] 197 | # if run_sh_file: 198 | # with open(run_sh_file[0], 'r') as f: 199 | # lines = f.readlines() 200 | # self.save_last = int([l.strip() for l in lines if '"{}")'.format(self.task) in l and 'val_last=' in l][0].split('val_last=')[-1].split()[0]) 201 | # self.save_step = int([l.strip() for l in lines if '"{}")'.format(self.task) in l and 'step=' in l][0].split('step=')[-1].split()[0]) 202 | 203 | 204 | -------------------------------------------------------------------------------- /birefnet/dataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | import random 3 | import numpy as np 4 | import cv2 5 | from tqdm import tqdm 6 | from PIL import Image 7 | from torch.utils import data 8 | from torchvision import transforms 9 | 10 | from .image_proc import preproc 11 | from .config import Config 12 | from .utils import path_to_image 13 | 14 | 15 | Image.MAX_IMAGE_PIXELS = None # remove DecompressionBombWarning 16 | config = Config() 17 | _class_labels_TR_sorted = ( 18 | 'Airplane, Ant, Antenna, Archery, Axe, BabyCarriage, Bag, BalanceBeam, Balcony, Balloon, Basket, BasketballHoop, Beatle, Bed, Bee, Bench, Bicycle, ' 19 | 'BicycleFrame, BicycleStand, Boat, Bonsai, BoomLift, Bridge, BunkBed, Butterfly, Button, Cable, CableLift, Cage, Camcorder, Cannon, Canoe, Car, ' 20 | 'CarParkDropArm, Carriage, Cart, Caterpillar, CeilingLamp, Centipede, Chair, Clip, Clock, Clothes, CoatHanger, Comb, ConcretePumpTruck, Crack, Crane, ' 21 | 'Cup, DentalChair, Desk, DeskChair, Diagram, DishRack, DoorHandle, Dragonfish, Dragonfly, Drum, Earphone, Easel, ElectricIron, Excavator, Eyeglasses, ' 22 | 'Fan, Fence, Fencing, FerrisWheel, FireExtinguisher, Fishing, Flag, FloorLamp, Forklift, GasStation, Gate, Gear, Goal, Golf, GymEquipment, Hammock, ' 23 | 'Handcart, Handcraft, Handrail, HangGlider, Harp, Harvester, Headset, Helicopter, Helmet, Hook, HorizontalBar, Hydrovalve, IroningTable, Jewelry, Key, ' 24 | 'KidsPlayground, Kitchenware, Kite, Knife, Ladder, LaundryRack, Lightning, Lobster, Locust, Machine, MachineGun, MagazineRack, Mantis, Medal, MemorialArchway, ' 25 | 'Microphone, Missile, MobileHolder, Monitor, Mosquito, Motorcycle, MovingTrolley, Mower, MusicPlayer, MusicStand, ObservationTower, Octopus, OilWell, ' 26 | 'OlympicLogo, OperatingTable, OutdoorFitnessEquipment, Parachute, Pavilion, Piano, Pipe, PlowHarrow, PoleVault, Punchbag, Rack, Racket, Rifle, Ring, Robot, ' 27 | 'RockClimbing, Rope, Sailboat, Satellite, Scaffold, Scale, Scissor, Scooter, Sculpture, Seadragon, Seahorse, Seal, SewingMachine, Ship, Shoe, ShoppingCart, ' 28 | 'ShoppingTrolley, Shower, Shrimp, Signboard, Skateboarding, Skeleton, Skiing, Spade, SpeedBoat, Spider, Spoon, Stair, Stand, Stationary, SteeringWheel, ' 29 | 'Stethoscope, Stool, Stove, StreetLamp, SweetStand, Swing, Sword, TV, Table, TableChair, TableLamp, TableTennis, Tank, Tapeline, Teapot, Telescope, Tent, ' 30 | 'TobaccoPipe, Toy, Tractor, TrafficLight, TrafficSign, Trampoline, TransmissionTower, Tree, Tricycle, TrimmerCover, Tripod, Trombone, Truck, Trumpet, Tuba, ' 31 | 'UAV, Umbrella, UnevenBars, UtilityPole, VacuumCleaner, Violin, Wakesurfing, Watch, WaterTower, WateringPot, Well, WellLid, Wheel, 
Wheelchair, WindTurbine, Windmill, WineGlass, WireWhisk, Yacht' 32 | ) 33 | class_labels_TR_sorted = _class_labels_TR_sorted.split(', ') 34 | 35 | 36 | class MyData(data.Dataset): 37 | def __init__(self, datasets, data_size, is_train=True): 38 | # data_size is None when using dynamic_size or data_size is manually set to None (for inference in the original size). 39 | self.is_train = is_train 40 | self.data_size = data_size 41 | self.load_all = config.load_all 42 | self.device = config.device 43 | valid_extensions = ['.png', '.jpg', '.PNG', '.JPG', '.JPEG'] 44 | 45 | if self.is_train and config.auxiliary_classification: 46 | self.cls_name2id = {_name: _id for _id, _name in enumerate(class_labels_TR_sorted)} 47 | self.transform_image = transforms.Compose([ 48 | transforms.ToTensor(), 49 | transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]), 50 | ]) 51 | self.transform_label = transforms.Compose([ 52 | transforms.ToTensor(), 53 | ]) 54 | dataset_root = os.path.join(config.data_root_dir, config.task) 55 | # datasets can be a list of different datasets for training on combined sets. 56 | self.image_paths = [] 57 | for dataset in datasets.split('+'): 58 | image_root = os.path.join(dataset_root, dataset, 'im') 59 | self.image_paths += [os.path.join(image_root, p) for p in os.listdir(image_root) if any(p.endswith(ext) for ext in valid_extensions)] 60 | self.label_paths = [] 61 | for p in self.image_paths: 62 | for ext in valid_extensions: 63 | ## 'im' and 'gt' may need modifying 64 | p_gt = p.replace('/im/', '/gt/')[:-(len(p.split('.')[-1])+1)] + ext 65 | file_exists = False 66 | if os.path.exists(p_gt): 67 | self.label_paths.append(p_gt) 68 | file_exists = True 69 | break 70 | if not file_exists: 71 | print('Not exists:', p_gt) 72 | 73 | if len(self.label_paths) != len(self.image_paths): 74 | set_image_paths = set([os.path.splitext(p.split(os.sep)[-1])[0] for p in self.image_paths]) 75 | set_label_paths = set([os.path.splitext(p.split(os.sep)[-1])[0] for p in self.label_paths]) 76 | print('Path diff:', set_image_paths - set_label_paths) 77 | raise ValueError(f"There are different numbers of images ({len(self.label_paths)}) and labels ({len(self.image_paths)})") 78 | 79 | if self.load_all: 80 | self.images_loaded, self.labels_loaded = [], [] 81 | self.class_labels_loaded = [] 82 | # for image_path, label_path in zip(self.image_paths, self.label_paths): 83 | for image_path, label_path in tqdm(zip(self.image_paths, self.label_paths), total=len(self.image_paths)): 84 | _image = path_to_image(image_path, size=self.data_size, color_type='rgb') 85 | _label = path_to_image(label_path, size=self.data_size, color_type='gray') 86 | self.images_loaded.append(_image) 87 | self.labels_loaded.append(_label) 88 | self.class_labels_loaded.append( 89 | self.cls_name2id[label_path.split('/')[-1].split('#')[3]] if self.is_train and config.auxiliary_classification else -1 90 | ) 91 | 92 | def __getitem__(self, index): 93 | if self.load_all: 94 | image = self.images_loaded[index] 95 | label = self.labels_loaded[index] 96 | class_label = self.class_labels_loaded[index] if self.is_train and config.auxiliary_classification else -1 97 | else: 98 | image = path_to_image(self.image_paths[index], size=self.data_size, color_type='rgb') 99 | label = path_to_image(self.label_paths[index], size=self.data_size, color_type='gray') 100 | class_label = self.cls_name2id[self.label_paths[index].split('/')[-1].split('#')[3]] if self.is_train and config.auxiliary_classification else -1 101 | 102 | # loading image and 
label 103 | if self.is_train: 104 | if config.background_color_synthesis: 105 | image.putalpha(label) 106 | array_image = np.array(image) 107 | array_foreground = array_image[:, :, :3].astype(np.float32) 108 | array_mask = (array_image[:, :, 3:] / 255).astype(np.float32) 109 | array_background = np.zeros_like(array_foreground) 110 | choice = random.random() 111 | if choice < 0.4: 112 | # Black/Gray/White backgrounds 113 | array_background[:, :, :] = random.randint(0, 255) 114 | elif choice < 0.8: 115 | # Background color that similar to the foreground object. Hard negative samples. 116 | foreground_pixel_number = np.sum(array_mask > 0) 117 | color_foreground_mean = np.mean(array_foreground * array_mask, axis=(0, 1)) * (np.prod(array_foreground.shape[:2]) / foreground_pixel_number) 118 | color_up_or_down = random.choice((-1, 1)) 119 | # Up or down for 20% range from 255 or 0, respectively. 120 | color_foreground_mean += (255 - color_foreground_mean if color_up_or_down == 1 else color_foreground_mean) * (random.random() * 0.2) * color_up_or_down 121 | array_background[:, :, :] = color_foreground_mean 122 | else: 123 | # Any color 124 | for idx_channel in range(3): 125 | array_background[:, :, idx_channel] = random.randint(0, 255) 126 | array_foreground_background = array_foreground * array_mask + array_background * (1 - array_mask) 127 | image = Image.fromarray(array_foreground_background.astype(np.uint8)) 128 | image, label = preproc(image, label, preproc_methods=config.preproc_methods) 129 | # else: 130 | # if _label.shape[0] > 2048 or _label.shape[1] > 2048: 131 | # _image = cv2.resize(_image, (2048, 2048), interpolation=cv2.INTER_LINEAR) 132 | # _label = cv2.resize(_label, (2048, 2048), interpolation=cv2.INTER_LINEAR) 133 | 134 | # At present, we use fixed sizes in inference, instead of consistent dynamic size with training. 135 | if self.is_train: 136 | if config.dynamic_size is None: 137 | image, label = self.transform_image(image), self.transform_label(label) 138 | else: 139 | size_div_32 = (int(image.size[0] // 32 * 32), int(image.size[1] // 32 * 32)) 140 | if image.size != size_div_32: 141 | image = image.resize(size_div_32) 142 | label = label.resize(size_div_32) 143 | image, label = self.transform_image(image), self.transform_label(label) 144 | 145 | if self.is_train: 146 | return image, label, class_label 147 | else: 148 | return image, label, self.label_paths[index] 149 | 150 | def __len__(self): 151 | return len(self.image_paths) 152 | 153 | 154 | def custom_collate_fn(batch): 155 | if config.dynamic_size: 156 | dynamic_size = tuple(sorted(config.dynamic_size)) 157 | dynamic_size_batch = (random.randint(dynamic_size[0][0], dynamic_size[0][1]) // 32 * 32, random.randint(dynamic_size[1][0], dynamic_size[1][1]) // 32 * 32) # select a value randomly in the range of [dynamic_size[0/1][0], dynamic_size[0/1][1]]. 
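E.g. a draw of 1000 for the width becomes 1000 // 32 * 32 = 992, so every batch is collated at a resolution that is a multiple of 32.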
158 | data_size = dynamic_size_batch 159 | else: 160 | data_size = config.size 161 | new_batch = [] 162 | transform_image = transforms.Compose([ 163 | transforms.Resize(data_size[::-1]), 164 | transforms.ToTensor(), 165 | transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]), 166 | ]) 167 | transform_label = transforms.Compose([ 168 | transforms.Resize(data_size[::-1]), 169 | transforms.ToTensor(), 170 | ]) 171 | for image, label, class_label in batch: 172 | new_batch.append((transform_image(image), transform_label(label), class_label)) 173 | return data._utils.collate.default_collate(new_batch) 174 | -------------------------------------------------------------------------------- /birefnet/image_proc.py: -------------------------------------------------------------------------------- 1 | import random 2 | from PIL import Image, ImageEnhance 3 | import numpy as np 4 | import cv2 5 | 6 | 7 | def refine_foreground(image, mask, r=90): 8 | if mask.size != image.size: 9 | mask = mask.resize(image.size) 10 | image = np.array(image) / 255.0 11 | mask = np.array(mask) / 255.0 12 | estimated_foreground = FB_blur_fusion_foreground_estimator_2(image, mask, r=r) 13 | image_masked = Image.fromarray((estimated_foreground * 255.0).astype(np.uint8)) 14 | return image_masked 15 | 16 | 17 | def FB_blur_fusion_foreground_estimator_2(image, alpha, r=90): 18 | # Thanks to the source: https://github.com/Photoroom/fast-foreground-estimation 19 | alpha = alpha[:, :, None] 20 | F, blur_B = FB_blur_fusion_foreground_estimator( 21 | image, image, image, alpha, r) 22 | return FB_blur_fusion_foreground_estimator(image, F, blur_B, alpha, r=6)[0] 23 | 24 | 25 | def FB_blur_fusion_foreground_estimator(image, F, B, alpha, r=90): 26 | if isinstance(image, Image.Image): 27 | image = np.array(image) / 255.0 28 | blurred_alpha = cv2.blur(alpha, (r, r))[:, :, None] 29 | 30 | blurred_FA = cv2.blur(F * alpha, (r, r)) 31 | blurred_F = blurred_FA / (blurred_alpha + 1e-5) 32 | 33 | blurred_B1A = cv2.blur(B * (1 - alpha), (r, r)) 34 | blurred_B = blurred_B1A / ((1 - blurred_alpha) + 1e-5) 35 | F = blurred_F + alpha * \ 36 | (image - alpha * blurred_F - (1 - alpha) * blurred_B) 37 | F = np.clip(F, 0, 1) 38 | return F, blurred_B 39 | 40 | 41 | def preproc(image, label, preproc_methods=['flip']): 42 | if 'flip' in preproc_methods: 43 | image, label = cv_random_flip(image, label) 44 | if 'crop' in preproc_methods: 45 | image, label = random_crop(image, label) 46 | if 'rotate' in preproc_methods: 47 | image, label = random_rotate(image, label) 48 | if 'enhance' in preproc_methods: 49 | image = color_enhance(image) 50 | if 'pepper' in preproc_methods: 51 | image = random_pepper(image) 52 | return image, label 53 | 54 | 55 | def cv_random_flip(img, label): 56 | if random.random() > 0.5: 57 | img = img.transpose(Image.FLIP_LEFT_RIGHT) 58 | label = label.transpose(Image.FLIP_LEFT_RIGHT) 59 | return img, label 60 | 61 | 62 | def random_crop(image, label): 63 | border = 30 64 | image_width = image.size[0] 65 | image_height = image.size[1] 66 | border = int(min(image_width, image_height) * 0.1) 67 | crop_win_width = np.random.randint(image_width - border, image_width) 68 | crop_win_height = np.random.randint(image_height - border, image_height) 69 | random_region = ( 70 | (image_width - crop_win_width) >> 1, (image_height - crop_win_height) >> 1, (image_width + crop_win_width) >> 1, 71 | (image_height + crop_win_height) >> 1) 72 | return image.crop(random_region), label.crop(random_region) 73 | 74 | 75 | def random_rotate(image, 
label, angle=15): 76 | mode = Image.BICUBIC 77 | if random.random() > 0.8: 78 | random_angle = np.random.randint(-angle, angle) 79 | image = image.rotate(random_angle, mode) 80 | label = label.rotate(random_angle, mode) 81 | return image, label 82 | 83 | 84 | def color_enhance(image): 85 | bright_intensity = random.randint(5, 15) / 10.0 86 | image = ImageEnhance.Brightness(image).enhance(bright_intensity) 87 | contrast_intensity = random.randint(5, 15) / 10.0 88 | image = ImageEnhance.Contrast(image).enhance(contrast_intensity) 89 | color_intensity = random.randint(0, 20) / 10.0 90 | image = ImageEnhance.Color(image).enhance(color_intensity) 91 | sharp_intensity = random.randint(0, 30) / 10.0 92 | image = ImageEnhance.Sharpness(image).enhance(sharp_intensity) 93 | return image 94 | 95 | 96 | def random_gaussian(image, mean=0.1, sigma=0.35): 97 | def gaussianNoisy(im, mean=mean, sigma=sigma): 98 | for _i in range(len(im)): 99 | im[_i] += random.gauss(mean, sigma) 100 | return im 101 | 102 | img = np.asarray(image) 103 | width, height = img.shape 104 | img = gaussianNoisy(img[:].flatten(), mean, sigma) 105 | img = img.reshape([width, height]) 106 | return Image.fromarray(np.uint8(img)) 107 | 108 | 109 | def random_pepper(img, N=0.0015): 110 | img = np.array(img) 111 | noiseNum = int(N * img.shape[0] * img.shape[1]) 112 | for i in range(noiseNum): 113 | randX = random.randint(0, img.shape[0] - 1) 114 | randY = random.randint(0, img.shape[1] - 1) 115 | img[randX, randY] = random.randint(0, 1) * 255 116 | return Image.fromarray(img) -------------------------------------------------------------------------------- /birefnet/models/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lldacing/ComfyUI_BiRefNet_ll/5443a2aa16cfbd98bb2f7dcc8bdcb70439e08529/birefnet/models/__init__.py -------------------------------------------------------------------------------- /birefnet/models/backbones/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lldacing/ComfyUI_BiRefNet_ll/5443a2aa16cfbd98bb2f7dcc8bdcb70439e08529/birefnet/models/backbones/__init__.py -------------------------------------------------------------------------------- /birefnet/models/backbones/build_backbone.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from collections import OrderedDict 4 | from torchvision.models import vgg16, vgg16_bn, VGG16_Weights, VGG16_BN_Weights, resnet50, ResNet50_Weights 5 | from ..backbones.pvt_v2 import pvt_v2_b0, pvt_v2_b1, pvt_v2_b2, pvt_v2_b5 6 | from ..backbones.swin_v1 import swin_v1_t, swin_v1_s, swin_v1_b, swin_v1_l 7 | from ...config import Config 8 | 9 | 10 | config = Config() 11 | 12 | def build_backbone(bb_name, pretrained=True, params_settings=''): 13 | if bb_name == 'vgg16': 14 | bb_net = list(vgg16(pretrained=VGG16_Weights.DEFAULT if pretrained else None).children())[0] 15 | bb = nn.Sequential(OrderedDict({'conv1': bb_net[:4], 'conv2': bb_net[4:9], 'conv3': bb_net[9:16], 'conv4': bb_net[16:23]})) 16 | elif bb_name == 'vgg16bn': 17 | bb_net = list(vgg16_bn(pretrained=VGG16_BN_Weights.DEFAULT if pretrained else None).children())[0] 18 | bb = nn.Sequential(OrderedDict({'conv1': bb_net[:6], 'conv2': bb_net[6:13], 'conv3': bb_net[13:23], 'conv4': bb_net[23:33]})) 19 | elif bb_name == 'resnet50': 20 | bb_net = list(resnet50(pretrained=ResNet50_Weights.DEFAULT if pretrained 
else None).children()) 21 | bb = nn.Sequential(OrderedDict({'conv1': nn.Sequential(*bb_net[0:3]), 'conv2': bb_net[4], 'conv3': bb_net[5], 'conv4': bb_net[6]})) 22 | else: 23 | bb = eval('{}({})'.format(bb_name, params_settings)) 24 | if pretrained: 25 | bb = load_weights(bb, bb_name) 26 | return bb 27 | 28 | def load_weights(model, model_name): 29 | save_model = torch.load(config.weights[model_name], map_location='cpu', weights_only=True) 30 | model_dict = model.state_dict() 31 | state_dict = {k: v if v.size() == model_dict[k].size() else model_dict[k] for k, v in save_model.items() if k in model_dict.keys()} 32 | # to ignore the weights with mismatched size when I modify the backbone itself. 33 | if not state_dict: 34 | save_model_keys = list(save_model.keys()) 35 | sub_item = save_model_keys[0] if len(save_model_keys) == 1 else None 36 | state_dict = {k: v if v.size() == model_dict[k].size() else model_dict[k] for k, v in save_model[sub_item].items() if k in model_dict.keys()} 37 | if not state_dict or not sub_item: 38 | print('Weights are not successfully loaded. Check the state dict of weights file.') 39 | return None 40 | else: 41 | print('Found correct weights in the "{}" item of loaded state_dict.'.format(sub_item)) 42 | model_dict.update(state_dict) 43 | model.load_state_dict(model_dict) 44 | return model 45 | -------------------------------------------------------------------------------- /birefnet/models/backbones/pvt_v2.py: -------------------------------------------------------------------------------- 1 | import math 2 | from functools import partial 3 | import torch 4 | import torch.nn as nn 5 | 6 | try: 7 | # version > 0.6.13 8 | from timm.layers import DropPath, to_2tuple, trunc_normal_ 9 | except Exception: 10 | from timm.models.layers import DropPath, to_2tuple, trunc_normal_ 11 | 12 | from ...config import Config 13 | 14 | config = Config() 15 | 16 | class Mlp(nn.Module): 17 | def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.): 18 | super().__init__() 19 | out_features = out_features or in_features 20 | hidden_features = hidden_features or in_features 21 | self.fc1 = nn.Linear(in_features, hidden_features) 22 | self.dwconv = DWConv(hidden_features) 23 | self.act = act_layer() 24 | self.fc2 = nn.Linear(hidden_features, out_features) 25 | self.drop = nn.Dropout(drop) 26 | 27 | self.apply(self._init_weights) 28 | 29 | def _init_weights(self, m): 30 | if isinstance(m, nn.Linear): 31 | trunc_normal_(m.weight, std=.02) 32 | if isinstance(m, nn.Linear) and m.bias is not None: 33 | nn.init.constant_(m.bias, 0) 34 | elif isinstance(m, nn.LayerNorm): 35 | nn.init.constant_(m.bias, 0) 36 | nn.init.constant_(m.weight, 1.0) 37 | elif isinstance(m, nn.Conv2d): 38 | fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 39 | fan_out //= m.groups 40 | m.weight.data.normal_(0, math.sqrt(2.0 / fan_out)) 41 | if m.bias is not None: 42 | m.bias.data.zero_() 43 | 44 | def forward(self, x, H, W): 45 | x = self.fc1(x) 46 | x = self.dwconv(x, H, W) 47 | x = self.act(x) 48 | x = self.drop(x) 49 | x = self.fc2(x) 50 | x = self.drop(x) 51 | return x 52 | 53 | 54 | class Attention(nn.Module): 55 | def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0., sr_ratio=1): 56 | super().__init__() 57 | assert dim % num_heads == 0, f"dim {dim} should be divided by num_heads {num_heads}." 
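# Spatial-reduction attention (PVT v2): when sr_ratio > 1, keys/values are computed on a feature map downsampled by a strided conv (self.sr), shrinking the attention matrix by a factor of sr_ratio ** 2.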
58 | 59 | self.dim = dim 60 | self.num_heads = num_heads 61 | head_dim = dim // num_heads 62 | self.scale = qk_scale or head_dim ** -0.5 63 | 64 | self.q = nn.Linear(dim, dim, bias=qkv_bias) 65 | self.kv = nn.Linear(dim, dim * 2, bias=qkv_bias) 66 | self.attn_drop_prob = attn_drop 67 | self.attn_drop = nn.Dropout(attn_drop) 68 | self.proj = nn.Linear(dim, dim) 69 | self.proj_drop = nn.Dropout(proj_drop) 70 | 71 | self.sr_ratio = sr_ratio 72 | if sr_ratio > 1: 73 | self.sr = nn.Conv2d(dim, dim, kernel_size=sr_ratio, stride=sr_ratio) 74 | self.norm = nn.LayerNorm(dim) 75 | 76 | self.apply(self._init_weights) 77 | 78 | def _init_weights(self, m): 79 | if isinstance(m, nn.Linear): 80 | trunc_normal_(m.weight, std=.02) 81 | if isinstance(m, nn.Linear) and m.bias is not None: 82 | nn.init.constant_(m.bias, 0) 83 | elif isinstance(m, nn.LayerNorm): 84 | nn.init.constant_(m.bias, 0) 85 | nn.init.constant_(m.weight, 1.0) 86 | elif isinstance(m, nn.Conv2d): 87 | fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 88 | fan_out //= m.groups 89 | m.weight.data.normal_(0, math.sqrt(2.0 / fan_out)) 90 | if m.bias is not None: 91 | m.bias.data.zero_() 92 | 93 | def forward(self, x, H, W): 94 | B, N, C = x.shape 95 | q = self.q(x).reshape(B, N, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3) 96 | 97 | if self.sr_ratio > 1: 98 | x_ = x.permute(0, 2, 1).reshape(B, C, H, W) 99 | x_ = self.sr(x_).reshape(B, C, -1).permute(0, 2, 1) 100 | x_ = self.norm(x_) 101 | kv = self.kv(x_).reshape(B, -1, 2, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) 102 | else: 103 | kv = self.kv(x).reshape(B, -1, 2, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) 104 | k, v = kv[0], kv[1] 105 | 106 | if config.SDPA_enabled: 107 | x = torch.nn.functional.scaled_dot_product_attention( 108 | q, k, v, 109 | attn_mask=None, dropout_p=self.attn_drop_prob, is_causal=False 110 | ).transpose(1, 2).reshape(B, N, C) 111 | else: 112 | attn = (q @ k.transpose(-2, -1)) * self.scale 113 | attn = attn.softmax(dim=-1) 114 | attn = self.attn_drop(attn) 115 | 116 | x = (attn @ v).transpose(1, 2).reshape(B, N, C) 117 | x = self.proj(x) 118 | x = self.proj_drop(x) 119 | 120 | return x 121 | 122 | 123 | class Block(nn.Module): 124 | 125 | def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0., 126 | drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm, sr_ratio=1): 127 | super().__init__() 128 | self.norm1 = norm_layer(dim) 129 | self.attn = Attention( 130 | dim, 131 | num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, 132 | attn_drop=attn_drop, proj_drop=drop, sr_ratio=sr_ratio) 133 | # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here 134 | self.drop_path = DropPath(drop_path) if drop_path > 0. 
else nn.Identity() 135 | self.norm2 = norm_layer(dim) 136 | mlp_hidden_dim = int(dim * mlp_ratio) 137 | self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) 138 | 139 | self.apply(self._init_weights) 140 | 141 | def _init_weights(self, m): 142 | if isinstance(m, nn.Linear): 143 | trunc_normal_(m.weight, std=.02) 144 | if isinstance(m, nn.Linear) and m.bias is not None: 145 | nn.init.constant_(m.bias, 0) 146 | elif isinstance(m, nn.LayerNorm): 147 | nn.init.constant_(m.bias, 0) 148 | nn.init.constant_(m.weight, 1.0) 149 | elif isinstance(m, nn.Conv2d): 150 | fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 151 | fan_out //= m.groups 152 | m.weight.data.normal_(0, math.sqrt(2.0 / fan_out)) 153 | if m.bias is not None: 154 | m.bias.data.zero_() 155 | 156 | def forward(self, x, H, W): 157 | x = x + self.drop_path(self.attn(self.norm1(x), H, W)) 158 | x = x + self.drop_path(self.mlp(self.norm2(x), H, W)) 159 | 160 | return x 161 | 162 | 163 | class OverlapPatchEmbed(nn.Module): 164 | """ Image to Patch Embedding 165 | """ 166 | 167 | def __init__(self, img_size=224, patch_size=7, stride=4, in_channels=3, embed_dim=768): 168 | super().__init__() 169 | img_size = to_2tuple(img_size) 170 | patch_size = to_2tuple(patch_size) 171 | 172 | self.img_size = img_size 173 | self.patch_size = patch_size 174 | self.H, self.W = img_size[0] // patch_size[0], img_size[1] // patch_size[1] 175 | self.num_patches = self.H * self.W 176 | self.proj = nn.Conv2d(in_channels, embed_dim, kernel_size=patch_size, stride=stride, 177 | padding=(patch_size[0] // 2, patch_size[1] // 2)) 178 | self.norm = nn.LayerNorm(embed_dim) 179 | 180 | self.apply(self._init_weights) 181 | 182 | def _init_weights(self, m): 183 | if isinstance(m, nn.Linear): 184 | trunc_normal_(m.weight, std=.02) 185 | if isinstance(m, nn.Linear) and m.bias is not None: 186 | nn.init.constant_(m.bias, 0) 187 | elif isinstance(m, nn.LayerNorm): 188 | nn.init.constant_(m.bias, 0) 189 | nn.init.constant_(m.weight, 1.0) 190 | elif isinstance(m, nn.Conv2d): 191 | fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 192 | fan_out //= m.groups 193 | m.weight.data.normal_(0, math.sqrt(2.0 / fan_out)) 194 | if m.bias is not None: 195 | m.bias.data.zero_() 196 | 197 | def forward(self, x): 198 | x = self.proj(x) 199 | _, _, H, W = x.shape 200 | x = x.flatten(2).transpose(1, 2) 201 | x = self.norm(x) 202 | 203 | return x, H, W 204 | 205 | 206 | class PyramidVisionTransformerImpr(nn.Module): 207 | def __init__(self, img_size=224, patch_size=16, in_channels=3, num_classes=1000, embed_dims=[64, 128, 256, 512], 208 | num_heads=[1, 2, 4, 8], mlp_ratios=[4, 4, 4, 4], qkv_bias=False, qk_scale=None, drop_rate=0., 209 | attn_drop_rate=0., drop_path_rate=0., norm_layer=nn.LayerNorm, 210 | depths=[3, 4, 6, 3], sr_ratios=[8, 4, 2, 1]): 211 | super().__init__() 212 | self.num_classes = num_classes 213 | self.depths = depths 214 | 215 | # patch_embed 216 | self.patch_embed1 = OverlapPatchEmbed(img_size=img_size, patch_size=7, stride=4, in_channels=in_channels, 217 | embed_dim=embed_dims[0]) 218 | self.patch_embed2 = OverlapPatchEmbed(img_size=img_size // 4, patch_size=3, stride=2, in_channels=embed_dims[0], 219 | embed_dim=embed_dims[1]) 220 | self.patch_embed3 = OverlapPatchEmbed(img_size=img_size // 8, patch_size=3, stride=2, in_channels=embed_dims[1], 221 | embed_dim=embed_dims[2]) 222 | self.patch_embed4 = OverlapPatchEmbed(img_size=img_size // 16, patch_size=3, stride=2, in_channels=embed_dims[2], 223 | 
embed_dim=embed_dims[3]) 224 | 225 | # transformer encoder 226 | dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] # stochastic depth decay rule 227 | cur = 0 228 | self.block1 = nn.ModuleList([Block( 229 | dim=embed_dims[0], num_heads=num_heads[0], mlp_ratio=mlp_ratios[0], qkv_bias=qkv_bias, qk_scale=qk_scale, 230 | drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[cur + i], norm_layer=norm_layer, 231 | sr_ratio=sr_ratios[0]) 232 | for i in range(depths[0])]) 233 | self.norm1 = norm_layer(embed_dims[0]) 234 | 235 | cur += depths[0] 236 | self.block2 = nn.ModuleList([Block( 237 | dim=embed_dims[1], num_heads=num_heads[1], mlp_ratio=mlp_ratios[1], qkv_bias=qkv_bias, qk_scale=qk_scale, 238 | drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[cur + i], norm_layer=norm_layer, 239 | sr_ratio=sr_ratios[1]) 240 | for i in range(depths[1])]) 241 | self.norm2 = norm_layer(embed_dims[1]) 242 | 243 | cur += depths[1] 244 | self.block3 = nn.ModuleList([Block( 245 | dim=embed_dims[2], num_heads=num_heads[2], mlp_ratio=mlp_ratios[2], qkv_bias=qkv_bias, qk_scale=qk_scale, 246 | drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[cur + i], norm_layer=norm_layer, 247 | sr_ratio=sr_ratios[2]) 248 | for i in range(depths[2])]) 249 | self.norm3 = norm_layer(embed_dims[2]) 250 | 251 | cur += depths[2] 252 | self.block4 = nn.ModuleList([Block( 253 | dim=embed_dims[3], num_heads=num_heads[3], mlp_ratio=mlp_ratios[3], qkv_bias=qkv_bias, qk_scale=qk_scale, 254 | drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[cur + i], norm_layer=norm_layer, 255 | sr_ratio=sr_ratios[3]) 256 | for i in range(depths[3])]) 257 | self.norm4 = norm_layer(embed_dims[3]) 258 | 259 | # classification head 260 | # self.head = nn.Linear(embed_dims[3], num_classes) if num_classes > 0 else nn.Identity() 261 | 262 | self.apply(self._init_weights) 263 | 264 | def _init_weights(self, m): 265 | if isinstance(m, nn.Linear): 266 | trunc_normal_(m.weight, std=.02) 267 | if isinstance(m, nn.Linear) and m.bias is not None: 268 | nn.init.constant_(m.bias, 0) 269 | elif isinstance(m, nn.LayerNorm): 270 | nn.init.constant_(m.bias, 0) 271 | nn.init.constant_(m.weight, 1.0) 272 | elif isinstance(m, nn.Conv2d): 273 | fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 274 | fan_out //= m.groups 275 | m.weight.data.normal_(0, math.sqrt(2.0 / fan_out)) 276 | if m.bias is not None: 277 | m.bias.data.zero_() 278 | 279 | def init_weights(self, pretrained=None): 280 | if isinstance(pretrained, str): 281 | logger = 1 282 | #load_checkpoint(self, pretrained, map_location='cpu', strict=False, logger=logger) 283 | 284 | def reset_drop_path(self, drop_path_rate): 285 | dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(self.depths))] 286 | cur = 0 287 | for i in range(self.depths[0]): 288 | self.block1[i].drop_path.drop_prob = dpr[cur + i] 289 | 290 | cur += self.depths[0] 291 | for i in range(self.depths[1]): 292 | self.block2[i].drop_path.drop_prob = dpr[cur + i] 293 | 294 | cur += self.depths[1] 295 | for i in range(self.depths[2]): 296 | self.block3[i].drop_path.drop_prob = dpr[cur + i] 297 | 298 | cur += self.depths[2] 299 | for i in range(self.depths[3]): 300 | self.block4[i].drop_path.drop_prob = dpr[cur + i] 301 | 302 | def freeze_patch_emb(self): 303 | self.patch_embed1.requires_grad = False 304 | 305 | @torch.jit.ignore 306 | def no_weight_decay(self): 307 | return {'pos_embed1', 'pos_embed2', 'pos_embed3', 'pos_embed4', 'cls_token'} # has pos_embed may be better 308 | 309 | def 
get_classifier(self): 310 | return self.head 311 | 312 | def reset_classifier(self, num_classes, global_pool=''): 313 | self.num_classes = num_classes 314 | self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity() 315 | 316 | def forward_features(self, x): 317 | B = x.shape[0] 318 | outs = [] 319 | 320 | # stage 1 321 | x, H, W = self.patch_embed1(x) 322 | for i, blk in enumerate(self.block1): 323 | x = blk(x, H, W) 324 | x = self.norm1(x) 325 | x = x.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous() 326 | outs.append(x) 327 | 328 | # stage 2 329 | x, H, W = self.patch_embed2(x) 330 | for i, blk in enumerate(self.block2): 331 | x = blk(x, H, W) 332 | x = self.norm2(x) 333 | x = x.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous() 334 | outs.append(x) 335 | 336 | # stage 3 337 | x, H, W = self.patch_embed3(x) 338 | for i, blk in enumerate(self.block3): 339 | x = blk(x, H, W) 340 | x = self.norm3(x) 341 | x = x.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous() 342 | outs.append(x) 343 | 344 | # stage 4 345 | x, H, W = self.patch_embed4(x) 346 | for i, blk in enumerate(self.block4): 347 | x = blk(x, H, W) 348 | x = self.norm4(x) 349 | x = x.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous() 350 | outs.append(x) 351 | 352 | return outs 353 | 354 | # return x.mean(dim=1) 355 | 356 | def forward(self, x): 357 | x = self.forward_features(x) 358 | # x = self.head(x) 359 | 360 | return x 361 | 362 | 363 | class DWConv(nn.Module): 364 | def __init__(self, dim=768): 365 | super(DWConv, self).__init__() 366 | self.dwconv = nn.Conv2d(dim, dim, 3, 1, 1, bias=True, groups=dim) 367 | 368 | def forward(self, x, H, W): 369 | B, N, C = x.shape 370 | x = x.transpose(1, 2).view(B, C, H, W).contiguous() 371 | x = self.dwconv(x) 372 | x = x.flatten(2).transpose(1, 2) 373 | 374 | return x 375 | 376 | 377 | def _conv_filter(state_dict, patch_size=16): 378 | """ convert patch embedding weight from manual patchify + linear proj to conv""" 379 | out_dict = {} 380 | for k, v in state_dict.items(): 381 | if 'patch_embed.proj.weight' in k: 382 | v = v.reshape((v.shape[0], 3, patch_size, patch_size)) 383 | out_dict[k] = v 384 | 385 | return out_dict 386 | 387 | 388 | class pvt_v2_b0(PyramidVisionTransformerImpr): 389 | def __init__(self, **kwargs): 390 | super(pvt_v2_b0, self).__init__( 391 | patch_size=4, embed_dims=[32, 64, 160, 256], num_heads=[1, 2, 5, 8], mlp_ratios=[8, 8, 4, 4], 392 | qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), depths=[2, 2, 2, 2], sr_ratios=[8, 4, 2, 1], 393 | drop_rate=0.0, drop_path_rate=0.1) 394 | 395 | 396 | class pvt_v2_b1(PyramidVisionTransformerImpr): 397 | def __init__(self, **kwargs): 398 | super(pvt_v2_b1, self).__init__( 399 | patch_size=4, embed_dims=[64, 128, 320, 512], num_heads=[1, 2, 5, 8], mlp_ratios=[8, 8, 4, 4], 400 | qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), depths=[2, 2, 2, 2], sr_ratios=[8, 4, 2, 1], 401 | drop_rate=0.0, drop_path_rate=0.1) 402 | 403 | 404 | class pvt_v2_b2(PyramidVisionTransformerImpr): 405 | def __init__(self, in_channels=3, **kwargs): 406 | super(pvt_v2_b2, self).__init__( 407 | patch_size=4, embed_dims=[64, 128, 320, 512], num_heads=[1, 2, 5, 8], mlp_ratios=[8, 8, 4, 4], 408 | qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), depths=[3, 4, 6, 3], sr_ratios=[8, 4, 2, 1], 409 | drop_rate=0.0, drop_path_rate=0.1, in_channels=in_channels) 410 | 411 | 412 | class pvt_v2_b3(PyramidVisionTransformerImpr): 413 | def __init__(self, **kwargs): 414 | super(pvt_v2_b3, 
self).__init__( 415 | patch_size=4, embed_dims=[64, 128, 320, 512], num_heads=[1, 2, 5, 8], mlp_ratios=[8, 8, 4, 4], 416 | qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), depths=[3, 4, 18, 3], sr_ratios=[8, 4, 2, 1], 417 | drop_rate=0.0, drop_path_rate=0.1) 418 | 419 | 420 | class pvt_v2_b4(PyramidVisionTransformerImpr): 421 | def __init__(self, **kwargs): 422 | super(pvt_v2_b4, self).__init__( 423 | patch_size=4, embed_dims=[64, 128, 320, 512], num_heads=[1, 2, 5, 8], mlp_ratios=[8, 8, 4, 4], 424 | qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), depths=[3, 8, 27, 3], sr_ratios=[8, 4, 2, 1], 425 | drop_rate=0.0, drop_path_rate=0.1) 426 | 427 | 428 | class pvt_v2_b5(PyramidVisionTransformerImpr): 429 | def __init__(self, **kwargs): 430 | super(pvt_v2_b5, self).__init__( 431 | patch_size=4, embed_dims=[64, 128, 320, 512], num_heads=[1, 2, 5, 8], mlp_ratios=[4, 4, 4, 4], 432 | qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), depths=[3, 6, 40, 3], sr_ratios=[8, 4, 2, 1], 433 | drop_rate=0.0, drop_path_rate=0.1) 434 | -------------------------------------------------------------------------------- /birefnet/models/modules/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lldacing/ComfyUI_BiRefNet_ll/5443a2aa16cfbd98bb2f7dcc8bdcb70439e08529/birefnet/models/modules/__init__.py -------------------------------------------------------------------------------- /birefnet/models/modules/aspp.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from ..modules.deform_conv import DeformableConv2d 5 | from ...config import Config 6 | 7 | 8 | config = Config() 9 | 10 | 11 | class _ASPPModule(nn.Module): 12 | def __init__(self, in_channels, planes, kernel_size, padding, dilation): 13 | super(_ASPPModule, self).__init__() 14 | self.atrous_conv = nn.Conv2d(in_channels, planes, kernel_size=kernel_size, 15 | stride=1, padding=padding, dilation=dilation, bias=False) 16 | self.bn = nn.BatchNorm2d(planes) if config.batch_size > 1 else nn.Identity() 17 | self.relu = nn.ReLU(inplace=True) 18 | 19 | def forward(self, x): 20 | x = self.atrous_conv(x) 21 | x = self.bn(x) 22 | 23 | return self.relu(x) 24 | 25 | 26 | class ASPP(nn.Module): 27 | def __init__(self, in_channels=64, out_channels=None, output_stride=16): 28 | super(ASPP, self).__init__() 29 | self.down_scale = 1 30 | if out_channels is None: 31 | out_channels = in_channels 32 | self.in_channelster = 256 // self.down_scale 33 | if output_stride == 16: 34 | dilations = [1, 6, 12, 18] 35 | elif output_stride == 8: 36 | dilations = [1, 12, 24, 36] 37 | else: 38 | raise NotImplementedError 39 | 40 | self.aspp1 = _ASPPModule(in_channels, self.in_channelster, 1, padding=0, dilation=dilations[0]) 41 | self.aspp2 = _ASPPModule(in_channels, self.in_channelster, 3, padding=dilations[1], dilation=dilations[1]) 42 | self.aspp3 = _ASPPModule(in_channels, self.in_channelster, 3, padding=dilations[2], dilation=dilations[2]) 43 | self.aspp4 = _ASPPModule(in_channels, self.in_channelster, 3, padding=dilations[3], dilation=dilations[3]) 44 | 45 | self.global_avg_pool = nn.Sequential(nn.AdaptiveAvgPool2d((1, 1)), 46 | nn.Conv2d(in_channels, self.in_channelster, 1, stride=1, bias=False), 47 | nn.BatchNorm2d(self.in_channelster) if config.batch_size > 1 else nn.Identity(), 48 | nn.ReLU(inplace=True)) 49 | self.conv1 = nn.Conv2d(self.in_channelster * 5, 
out_channels, 1, bias=False) 50 | self.bn1 = nn.BatchNorm2d(out_channels) if config.batch_size > 1 else nn.Identity() 51 | self.relu = nn.ReLU(inplace=True) 52 | self.dropout = nn.Dropout(0.5) 53 | 54 | def forward(self, x): 55 | x1 = self.aspp1(x) 56 | x2 = self.aspp2(x) 57 | x3 = self.aspp3(x) 58 | x4 = self.aspp4(x) 59 | x5 = self.global_avg_pool(x) 60 | x5 = F.interpolate(x5, size=x1.size()[2:], mode='bilinear', align_corners=True) 61 | x = torch.cat((x1, x2, x3, x4, x5), dim=1) 62 | 63 | x = self.conv1(x) 64 | x = self.bn1(x) 65 | x = self.relu(x) 66 | 67 | return self.dropout(x) 68 | 69 | 70 | ##################### Deformable 71 | class _ASPPModuleDeformable(nn.Module): 72 | def __init__(self, in_channels, planes, kernel_size, padding): 73 | super(_ASPPModuleDeformable, self).__init__() 74 | self.atrous_conv = DeformableConv2d(in_channels, planes, kernel_size=kernel_size, 75 | stride=1, padding=padding, bias=False) 76 | self.bn = nn.BatchNorm2d(planes) if config.batch_size > 1 else nn.Identity() 77 | self.relu = nn.ReLU(inplace=True) 78 | 79 | def forward(self, x): 80 | x = self.atrous_conv(x) 81 | x = self.bn(x) 82 | 83 | return self.relu(x) 84 | 85 | 86 | class ASPPDeformable(nn.Module): 87 | def __init__(self, in_channels, out_channels=None, parallel_block_sizes=[1, 3, 7]): 88 | super(ASPPDeformable, self).__init__() 89 | self.down_scale = 1 90 | if out_channels is None: 91 | out_channels = in_channels 92 | self.in_channelster = 256 // self.down_scale 93 | 94 | self.aspp1 = _ASPPModuleDeformable(in_channels, self.in_channelster, 1, padding=0) 95 | self.aspp_deforms = nn.ModuleList([ 96 | _ASPPModuleDeformable(in_channels, self.in_channelster, conv_size, padding=int(conv_size//2)) for conv_size in parallel_block_sizes 97 | ]) 98 | 99 | self.global_avg_pool = nn.Sequential(nn.AdaptiveAvgPool2d((1, 1)), 100 | nn.Conv2d(in_channels, self.in_channelster, 1, stride=1, bias=False), 101 | nn.BatchNorm2d(self.in_channelster) if config.batch_size > 1 else nn.Identity(), 102 | nn.ReLU(inplace=True)) 103 | self.conv1 = nn.Conv2d(self.in_channelster * (2 + len(self.aspp_deforms)), out_channels, 1, bias=False) 104 | self.bn1 = nn.BatchNorm2d(out_channels) if config.batch_size > 1 else nn.Identity() 105 | self.relu = nn.ReLU(inplace=True) 106 | self.dropout = nn.Dropout(0.5) 107 | 108 | def forward(self, x): 109 | x1 = self.aspp1(x) 110 | x_aspp_deforms = [aspp_deform(x) for aspp_deform in self.aspp_deforms] 111 | x5 = self.global_avg_pool(x) 112 | x5 = F.interpolate(x5, size=x1.size()[2:], mode='bilinear', align_corners=True) 113 | x = torch.cat((x1, *x_aspp_deforms, x5), dim=1) 114 | 115 | x = self.conv1(x) 116 | x = self.bn1(x) 117 | x = self.relu(x) 118 | 119 | return self.dropout(x) 120 | -------------------------------------------------------------------------------- /birefnet/models/modules/decoder_blocks.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from ..modules.aspp import ASPP, ASPPDeformable 4 | from ...config import Config 5 | 6 | 7 | config = Config() 8 | 9 | 10 | class BasicDecBlk(nn.Module): 11 | def __init__(self, in_channels=64, out_channels=64, inter_channels=64): 12 | super(BasicDecBlk, self).__init__() 13 | inter_channels = in_channels // 4 if config.dec_channels_inter == 'adap' else 64 14 | self.conv_in = nn.Conv2d(in_channels, inter_channels, 3, 1, padding=1) 15 | self.relu_in = nn.ReLU(inplace=True) 16 | if config.dec_att == 'ASPP': 17 | self.dec_att = ASPP(in_channels=inter_channels) 
18 | elif config.dec_att == 'ASPPDeformable': 19 | self.dec_att = ASPPDeformable(in_channels=inter_channels) 20 | self.conv_out = nn.Conv2d(inter_channels, out_channels, 3, 1, padding=1) 21 | self.bn_in = nn.BatchNorm2d(inter_channels) if config.batch_size > 1 else nn.Identity() 22 | self.bn_out = nn.BatchNorm2d(out_channels) if config.batch_size > 1 else nn.Identity() 23 | 24 | def forward(self, x): 25 | x = self.conv_in(x) 26 | x = self.bn_in(x) 27 | x = self.relu_in(x) 28 | if hasattr(self, 'dec_att'): 29 | x = self.dec_att(x) 30 | x = self.conv_out(x) 31 | x = self.bn_out(x) 32 | return x 33 | 34 | 35 | class ResBlk(nn.Module): 36 | def __init__(self, in_channels=64, out_channels=None, inter_channels=64): 37 | super(ResBlk, self).__init__() 38 | if out_channels is None: 39 | out_channels = in_channels 40 | inter_channels = in_channels // 4 if config.dec_channels_inter == 'adap' else 64 41 | 42 | self.conv_in = nn.Conv2d(in_channels, inter_channels, 3, 1, padding=1) 43 | self.bn_in = nn.BatchNorm2d(inter_channels) if config.batch_size > 1 else nn.Identity() 44 | self.relu_in = nn.ReLU(inplace=True) 45 | 46 | if config.dec_att == 'ASPP': 47 | self.dec_att = ASPP(in_channels=inter_channels) 48 | elif config.dec_att == 'ASPPDeformable': 49 | self.dec_att = ASPPDeformable(in_channels=inter_channels) 50 | 51 | self.conv_out = nn.Conv2d(inter_channels, out_channels, 3, 1, padding=1) 52 | self.bn_out = nn.BatchNorm2d(out_channels) if config.batch_size > 1 else nn.Identity() 53 | 54 | self.conv_resi = nn.Conv2d(in_channels, out_channels, 1, 1, 0) 55 | 56 | def forward(self, x): 57 | _x = self.conv_resi(x) 58 | x = self.conv_in(x) 59 | x = self.bn_in(x) 60 | x = self.relu_in(x) 61 | if hasattr(self, 'dec_att'): 62 | x = self.dec_att(x) 63 | x = self.conv_out(x) 64 | x = self.bn_out(x) 65 | return x + _x -------------------------------------------------------------------------------- /birefnet/models/modules/deform_conv.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torchvision.ops import deform_conv2d 4 | 5 | 6 | class DeformableConv2d(nn.Module): 7 | def __init__(self, 8 | in_channels, 9 | out_channels, 10 | kernel_size=3, 11 | stride=1, 12 | padding=1, 13 | bias=False): 14 | 15 | super(DeformableConv2d, self).__init__() 16 | 17 | assert type(kernel_size) == tuple or type(kernel_size) == int 18 | 19 | kernel_size = kernel_size if type(kernel_size) == tuple else (kernel_size, kernel_size) 20 | self.stride = stride if type(stride) == tuple else (stride, stride) 21 | self.padding = padding 22 | 23 | self.offset_conv = nn.Conv2d(in_channels, 24 | 2 * kernel_size[0] * kernel_size[1], 25 | kernel_size=kernel_size, 26 | stride=stride, 27 | padding=self.padding, 28 | bias=True) 29 | 30 | nn.init.constant_(self.offset_conv.weight, 0.) 31 | nn.init.constant_(self.offset_conv.bias, 0.) 32 | 33 | self.modulator_conv = nn.Conv2d(in_channels, 34 | 1 * kernel_size[0] * kernel_size[1], 35 | kernel_size=kernel_size, 36 | stride=stride, 37 | padding=self.padding, 38 | bias=True) 39 | 40 | nn.init.constant_(self.modulator_conv.weight, 0.) 41 | nn.init.constant_(self.modulator_conv.bias, 0.) 42 | 43 | self.regular_conv = nn.Conv2d(in_channels, 44 | out_channels=out_channels, 45 | kernel_size=kernel_size, 46 | stride=stride, 47 | padding=self.padding, 48 | bias=bias) 49 | 50 | def forward(self, x): 51 | #h, w = x.shape[2:] 52 | #max_offset = max(h, w)/4. 
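# Modulated deformable convolution: offset_conv predicts 2*K*K per-pixel sampling offsets and modulator_conv predicts K*K modulation scalars (mapped into (0, 2) by the 2 * sigmoid below); torchvision's deform_conv2d then applies regular_conv's weights at the shifted sampling locations, weighted by the modulator mask.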
53 | 54 | offset = self.offset_conv(x)#.clamp(-max_offset, max_offset) 55 | modulator = 2. * torch.sigmoid(self.modulator_conv(x)) 56 | 57 | x = deform_conv2d( 58 | input=x, 59 | offset=offset, 60 | weight=self.regular_conv.weight, 61 | bias=self.regular_conv.bias, 62 | padding=self.padding, 63 | mask=modulator, 64 | stride=self.stride, 65 | ) 66 | return x 67 | -------------------------------------------------------------------------------- /birefnet/models/modules/lateral_blocks.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | from functools import partial 6 | 7 | from ...config import Config 8 | 9 | 10 | config = Config() 11 | 12 | 13 | class BasicLatBlk(nn.Module): 14 | def __init__(self, in_channels=64, out_channels=64, inter_channels=64): 15 | super(BasicLatBlk, self).__init__() 16 | inter_channels = in_channels // 4 if config.dec_channels_inter == 'adap' else 64 17 | self.conv = nn.Conv2d(in_channels, out_channels, 1, 1, 0) 18 | 19 | def forward(self, x): 20 | x = self.conv(x) 21 | return x 22 | -------------------------------------------------------------------------------- /birefnet/models/modules/prompt_encoder.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn as nn 4 | from typing import Any, Optional, Tuple, Type 5 | 6 | 7 | class PromptEncoder(nn.Module): 8 | def __init__( 9 | self, 10 | embed_dim=256, 11 | image_embedding_size=1024, 12 | input_image_size=(1024, 1024), 13 | mask_in_chans=16, 14 | activation=nn.GELU 15 | ) -> None: 16 | super().__init__() 17 | """ 18 | Codes are partially from SAM: https://github.com/facebookresearch/segment-anything/blob/6fdee8f2727f4506cfbbe553e23b895e27956588/segment_anything/modeling/prompt_encoder.py. 19 | 20 | Arguments: 21 | embed_dim (int): The prompts' embedding dimension 22 | image_embedding_size (tuple(int, int)): The spatial size of the 23 | image embedding, as (H, W). 24 | input_image_size (int): The padded size of the image as input 25 | to the image encoder, as (H, W). 26 | mask_in_chans (int): The number of hidden channels used for 27 | encoding input masks. 28 | activation (nn.Module): The activation to use when encoding 29 | input masks. 
30 | """ 31 | super().__init__() 32 | self.embed_dim = embed_dim 33 | self.input_image_size = input_image_size 34 | self.image_embedding_size = image_embedding_size 35 | self.pe_layer = PositionEmbeddingRandom(embed_dim // 2) 36 | 37 | self.num_point_embeddings: int = 4 # pos/neg point + 2 box corners 38 | point_embeddings = [nn.Embedding(1, embed_dim) for i in range(self.num_point_embeddings)] 39 | self.point_embeddings = nn.ModuleList(point_embeddings) 40 | self.not_a_point_embed = nn.Embedding(1, embed_dim) 41 | 42 | self.mask_input_size = (4 * image_embedding_size[0], 4 * image_embedding_size[1]) 43 | self.mask_downscaling = nn.Sequential( 44 | nn.Conv2d(1, mask_in_chans // 4, kernel_size=2, stride=2), 45 | LayerNorm2d(mask_in_chans // 4), 46 | activation(), 47 | nn.Conv2d(mask_in_chans // 4, mask_in_chans, kernel_size=2, stride=2), 48 | LayerNorm2d(mask_in_chans), 49 | activation(), 50 | nn.Conv2d(mask_in_chans, embed_dim, kernel_size=1), 51 | ) 52 | self.no_mask_embed = nn.Embedding(1, embed_dim) 53 | 54 | def get_dense_pe(self) -> torch.Tensor: 55 | """ 56 | Returns the positional encoding used to encode point prompts, 57 | applied to a dense set of points the shape of the image encoding. 58 | 59 | Returns: 60 | torch.Tensor: Positional encoding with shape 61 | 1x(embed_dim)x(embedding_h)x(embedding_w) 62 | """ 63 | return self.pe_layer(self.image_embedding_size).unsqueeze(0) 64 | 65 | def _embed_points( 66 | self, 67 | points: torch.Tensor, 68 | labels: torch.Tensor, 69 | pad: bool, 70 | ) -> torch.Tensor: 71 | """Embeds point prompts.""" 72 | points = points + 0.5 # Shift to center of pixel 73 | if pad: 74 | padding_point = torch.zeros((points.shape[0], 1, 2), device=points.device) 75 | padding_label = -torch.ones((labels.shape[0], 1), device=labels.device) 76 | points = torch.cat([points, padding_point], dim=1) 77 | labels = torch.cat([labels, padding_label], dim=1) 78 | point_embedding = self.pe_layer.forward_with_coords(points, self.input_image_size) 79 | point_embedding[labels == -1] = 0.0 80 | point_embedding[labels == -1] += self.not_a_point_embed.weight 81 | point_embedding[labels == 0] += self.point_embeddings[0].weight 82 | point_embedding[labels == 1] += self.point_embeddings[1].weight 83 | return point_embedding 84 | 85 | def _embed_boxes(self, boxes: torch.Tensor) -> torch.Tensor: 86 | """Embeds box prompts.""" 87 | boxes = boxes + 0.5 # Shift to center of pixel 88 | coords = boxes.reshape(-1, 2, 2) 89 | corner_embedding = self.pe_layer.forward_with_coords(coords, self.input_image_size) 90 | corner_embedding[:, 0, :] += self.point_embeddings[2].weight 91 | corner_embedding[:, 1, :] += self.point_embeddings[3].weight 92 | return corner_embedding 93 | 94 | def _embed_masks(self, masks: torch.Tensor) -> torch.Tensor: 95 | """Embeds mask inputs.""" 96 | mask_embedding = self.mask_downscaling(masks) 97 | return mask_embedding 98 | 99 | def _get_batch_size( 100 | self, 101 | points: Optional[Tuple[torch.Tensor, torch.Tensor]], 102 | boxes: Optional[torch.Tensor], 103 | masks: Optional[torch.Tensor], 104 | ) -> int: 105 | """ 106 | Gets the batch size of the output given the batch size of the input prompts. 
107 | """ 108 | if points is not None: 109 | return points[0].shape[0] 110 | elif boxes is not None: 111 | return boxes.shape[0] 112 | elif masks is not None: 113 | return masks.shape[0] 114 | else: 115 | return 1 116 | 117 | def _get_device(self) -> torch.device: 118 | return self.point_embeddings[0].weight.device 119 | 120 | def forward( 121 | self, 122 | points: Optional[Tuple[torch.Tensor, torch.Tensor]], 123 | boxes: Optional[torch.Tensor], 124 | masks: Optional[torch.Tensor], 125 | ) -> Tuple[torch.Tensor, torch.Tensor]: 126 | """ 127 | Embeds different types of prompts, returning both sparse and dense 128 | embeddings. 129 | 130 | Arguments: 131 | points (tuple(torch.Tensor, torch.Tensor) or none): point coordinates 132 | and labels to embed. 133 | boxes (torch.Tensor or none): boxes to embed 134 | masks (torch.Tensor or none): masks to embed 135 | 136 | Returns: 137 | torch.Tensor: sparse embeddings for the points and boxes, with shape 138 | BxNx(embed_dim), where N is determined by the number of input points 139 | and boxes. 140 | torch.Tensor: dense embeddings for the masks, in the shape 141 | Bx(embed_dim)x(embed_H)x(embed_W) 142 | """ 143 | bs = self._get_batch_size(points, boxes, masks) 144 | sparse_embeddings = torch.empty((bs, 0, self.embed_dim), device=self._get_device()) 145 | if points is not None: 146 | coords, labels = points 147 | point_embeddings = self._embed_points(coords, labels, pad=(boxes is None)) 148 | sparse_embeddings = torch.cat([sparse_embeddings, point_embeddings], dim=1) 149 | if boxes is not None: 150 | box_embeddings = self._embed_boxes(boxes) 151 | sparse_embeddings = torch.cat([sparse_embeddings, box_embeddings], dim=1) 152 | 153 | if masks is not None: 154 | dense_embeddings = self._embed_masks(masks) 155 | else: 156 | dense_embeddings = self.no_mask_embed.weight.reshape(1, -1, 1, 1).expand( 157 | bs, -1, self.image_embedding_size[0], self.image_embedding_size[1] 158 | ) 159 | 160 | return sparse_embeddings, dense_embeddings 161 | 162 | 163 | class PositionEmbeddingRandom(nn.Module): 164 | """ 165 | Positional encoding using random spatial frequencies. 166 | """ 167 | 168 | def __init__(self, num_pos_feats: int = 64, scale: Optional[float] = None) -> None: 169 | super().__init__() 170 | if scale is None or scale <= 0.0: 171 | scale = 1.0 172 | self.register_buffer( 173 | "positional_encoding_gaussian_matrix", 174 | scale * torch.randn((2, num_pos_feats)), 175 | ) 176 | 177 | def _pe_encoding(self, coords: torch.Tensor) -> torch.Tensor: 178 | """Positionally encode points that are normalized to [0,1].""" 179 | # assuming coords are in [0, 1]^2 square and have d_1 x ... x d_n x 2 shape 180 | coords = 2 * coords - 1 181 | coords = coords @ self.positional_encoding_gaussian_matrix 182 | coords = 2 * np.pi * coords 183 | # outputs d_1 x ... 
x d_n x C shape 184 | return torch.cat([torch.sin(coords), torch.cos(coords)], dim=-1) 185 | 186 | def forward(self, size: Tuple[int, int]) -> torch.Tensor: 187 | """Generate positional encoding for a grid of the specified size.""" 188 | h, w = size 189 | device: Any = self.positional_encoding_gaussian_matrix.device 190 | grid = torch.ones((h, w), device=device, dtype=torch.float32) 191 | y_embed = grid.cumsum(dim=0) - 0.5 192 | x_embed = grid.cumsum(dim=1) - 0.5 193 | y_embed = y_embed / h 194 | x_embed = x_embed / w 195 | 196 | pe = self._pe_encoding(torch.stack([x_embed, y_embed], dim=-1)) 197 | return pe.permute(2, 0, 1) # C x H x W 198 | 199 | def forward_with_coords( 200 | self, coords_input: torch.Tensor, image_size: Tuple[int, int] 201 | ) -> torch.Tensor: 202 | """Positionally encode points that are not normalized to [0,1].""" 203 | coords = coords_input.clone() 204 | coords[:, :, 0] = coords[:, :, 0] / image_size[1] 205 | coords[:, :, 1] = coords[:, :, 1] / image_size[0] 206 | return self._pe_encoding(coords.to(torch.float)) # B x N x C 207 | 208 | 209 | class LayerNorm2d(nn.Module): 210 | def __init__(self, num_channels: int, eps: float = 1e-6) -> None: 211 | super().__init__() 212 | self.weight = nn.Parameter(torch.ones(num_channels)) 213 | self.bias = nn.Parameter(torch.zeros(num_channels)) 214 | self.eps = eps 215 | 216 | def forward(self, x: torch.Tensor) -> torch.Tensor: 217 | u = x.mean(1, keepdim=True) 218 | s = (x - u).pow(2).mean(1, keepdim=True) 219 | x = (x - u) / torch.sqrt(s + self.eps) 220 | x = self.weight[:, None, None] * x + self.bias[:, None, None] 221 | return x 222 | 223 | -------------------------------------------------------------------------------- /birefnet/models/modules/utils.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | 3 | 4 | def build_act_layer(act_layer): 5 | if act_layer == 'ReLU': 6 | return nn.ReLU(inplace=True) 7 | elif act_layer == 'SiLU': 8 | return nn.SiLU(inplace=True) 9 | elif act_layer == 'GELU': 10 | return nn.GELU() 11 | 12 | raise NotImplementedError(f'build_act_layer does not support {act_layer}') 13 | 14 | 15 | def build_norm_layer(dim, 16 | norm_layer, 17 | in_format='channels_last', 18 | out_format='channels_last', 19 | eps=1e-6): 20 | layers = [] 21 | if norm_layer == 'BN': 22 | if in_format == 'channels_last': 23 | layers.append(to_channels_first()) 24 | layers.append(nn.BatchNorm2d(dim)) 25 | if out_format == 'channels_last': 26 | layers.append(to_channels_last()) 27 | elif norm_layer == 'LN': 28 | if in_format == 'channels_first': 29 | layers.append(to_channels_last()) 30 | layers.append(nn.LayerNorm(dim, eps=eps)) 31 | if out_format == 'channels_first': 32 | layers.append(to_channels_first()) 33 | else: 34 | raise NotImplementedError( 35 | f'build_norm_layer does not support {norm_layer}') 36 | return nn.Sequential(*layers) 37 | 38 | 39 | class to_channels_first(nn.Module): 40 | 41 | def __init__(self): 42 | super().__init__() 43 | 44 | def forward(self, x): 45 | return x.permute(0, 3, 1, 2) 46 | 47 | 48 | class to_channels_last(nn.Module): 49 | 50 | def __init__(self): 51 | super().__init__() 52 | 53 | def forward(self, x): 54 | return x.permute(0, 2, 3, 1) 55 | -------------------------------------------------------------------------------- /birefnet/models/refinement/__init__.py: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/lldacing/ComfyUI_BiRefNet_ll/5443a2aa16cfbd98bb2f7dcc8bdcb70439e08529/birefnet/models/refinement/__init__.py -------------------------------------------------------------------------------- /birefnet/models/refinement/refiner.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from collections import OrderedDict 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | from torchvision.models import vgg16, vgg16_bn 8 | from torchvision.models import resnet50 9 | 10 | from ...config import Config 11 | from ...dataset import class_labels_TR_sorted 12 | from ..backbones.build_backbone import build_backbone 13 | from ..modules.decoder_blocks import BasicDecBlk 14 | from ..modules.lateral_blocks import BasicLatBlk 15 | from ..refinement.stem_layer import StemLayer 16 | 17 | 18 | class RefinerPVTInChannels4(nn.Module): 19 | def __init__(self, in_channels=3+1): 20 | super(RefinerPVTInChannels4, self).__init__() 21 | self.config = Config() 22 | self.epoch = 1 23 | self.bb = build_backbone(self.config.bb, params_settings='in_channels=4') 24 | 25 | lateral_channels_in_collection = { 26 | 'vgg16': [512, 256, 128, 64], 'vgg16bn': [512, 256, 128, 64], 'resnet50': [1024, 512, 256, 64], 27 | 'pvt_v2_b2': [512, 320, 128, 64], 'pvt_v2_b5': [512, 320, 128, 64], 28 | 'swin_v1_b': [1024, 512, 256, 128], 'swin_v1_l': [1536, 768, 384, 192], 29 | } 30 | channels = lateral_channels_in_collection[self.config.bb] 31 | self.squeeze_module = BasicDecBlk(channels[0], channels[0]) 32 | 33 | self.decoder = Decoder(channels) 34 | 35 | if 0: 36 | for key, value in self.named_parameters(): 37 | if 'bb.' in key: 38 | value.requires_grad = False 39 | 40 | def forward(self, x): 41 | if isinstance(x, list): 42 | x = torch.cat(x, dim=1) 43 | ########## Encoder ########## 44 | if self.config.bb in ['vgg16', 'vgg16bn', 'resnet50']: 45 | x1 = self.bb.conv1(x) 46 | x2 = self.bb.conv2(x1) 47 | x3 = self.bb.conv3(x2) 48 | x4 = self.bb.conv4(x3) 49 | else: 50 | x1, x2, x3, x4 = self.bb(x) 51 | 52 | x4 = self.squeeze_module(x4) 53 | 54 | ########## Decoder ########## 55 | 56 | features = [x, x1, x2, x3, x4] 57 | scaled_preds = self.decoder(features) 58 | 59 | return scaled_preds 60 | 61 | 62 | class Refiner(nn.Module): 63 | def __init__(self, in_channels=3+1): 64 | super(Refiner, self).__init__() 65 | self.config = Config() 66 | self.epoch = 1 67 | self.stem_layer = StemLayer(in_channels=in_channels, inter_channels=48, out_channels=3, norm_layer='BN' if self.config.batch_size > 1 else 'LN') 68 | self.bb = build_backbone(self.config.bb) 69 | 70 | lateral_channels_in_collection = { 71 | 'vgg16': [512, 256, 128, 64], 'vgg16bn': [512, 256, 128, 64], 'resnet50': [1024, 512, 256, 64], 72 | 'pvt_v2_b2': [512, 320, 128, 64], 'pvt_v2_b5': [512, 320, 128, 64], 73 | 'swin_v1_b': [1024, 512, 256, 128], 'swin_v1_l': [1536, 768, 384, 192], 74 | } 75 | channels = lateral_channels_in_collection[self.config.bb] 76 | self.squeeze_module = BasicDecBlk(channels[0], channels[0]) 77 | 78 | self.decoder = Decoder(channels) 79 | 80 | if 0: 81 | for key, value in self.named_parameters(): 82 | if 'bb.' 
in key: 83 | value.requires_grad = False 84 | 85 | def forward(self, x): 86 | if isinstance(x, list): 87 | x = torch.cat(x, dim=1) 88 | x = self.stem_layer(x) 89 | ########## Encoder ########## 90 | if self.config.bb in ['vgg16', 'vgg16bn', 'resnet50']: 91 | x1 = self.bb.conv1(x) 92 | x2 = self.bb.conv2(x1) 93 | x3 = self.bb.conv3(x2) 94 | x4 = self.bb.conv4(x3) 95 | else: 96 | x1, x2, x3, x4 = self.bb(x) 97 | 98 | x4 = self.squeeze_module(x4) 99 | 100 | ########## Decoder ########## 101 | 102 | features = [x, x1, x2, x3, x4] 103 | scaled_preds = self.decoder(features) 104 | 105 | return scaled_preds 106 | 107 | 108 | class Decoder(nn.Module): 109 | def __init__(self, channels): 110 | super(Decoder, self).__init__() 111 | self.config = Config() 112 | DecoderBlock = eval('BasicDecBlk') 113 | LateralBlock = eval('BasicLatBlk') 114 | 115 | self.decoder_block4 = DecoderBlock(channels[0], channels[1]) 116 | self.decoder_block3 = DecoderBlock(channels[1], channels[2]) 117 | self.decoder_block2 = DecoderBlock(channels[2], channels[3]) 118 | self.decoder_block1 = DecoderBlock(channels[3], channels[3]//2) 119 | 120 | self.lateral_block4 = LateralBlock(channels[1], channels[1]) 121 | self.lateral_block3 = LateralBlock(channels[2], channels[2]) 122 | self.lateral_block2 = LateralBlock(channels[3], channels[3]) 123 | 124 | if self.config.ms_supervision: 125 | self.conv_ms_spvn_4 = nn.Conv2d(channels[1], 1, 1, 1, 0) 126 | self.conv_ms_spvn_3 = nn.Conv2d(channels[2], 1, 1, 1, 0) 127 | self.conv_ms_spvn_2 = nn.Conv2d(channels[3], 1, 1, 1, 0) 128 | self.conv_out1 = nn.Sequential(nn.Conv2d(channels[3]//2, 1, 1, 1, 0)) 129 | 130 | def forward(self, features): 131 | x, x1, x2, x3, x4 = features 132 | outs = [] 133 | p4 = self.decoder_block4(x4) 134 | _p4 = F.interpolate(p4, size=x3.shape[2:], mode='bilinear', align_corners=True) 135 | _p3 = _p4 + self.lateral_block4(x3) 136 | 137 | p3 = self.decoder_block3(_p3) 138 | _p3 = F.interpolate(p3, size=x2.shape[2:], mode='bilinear', align_corners=True) 139 | _p2 = _p3 + self.lateral_block3(x2) 140 | 141 | p2 = self.decoder_block2(_p2) 142 | _p2 = F.interpolate(p2, size=x1.shape[2:], mode='bilinear', align_corners=True) 143 | _p1 = _p2 + self.lateral_block2(x1) 144 | 145 | _p1 = self.decoder_block1(_p1) 146 | _p1 = F.interpolate(_p1, size=x.shape[2:], mode='bilinear', align_corners=True) 147 | p1_out = self.conv_out1(_p1) 148 | 149 | if self.config.ms_supervision: 150 | outs.append(self.conv_ms_spvn_4(p4)) 151 | outs.append(self.conv_ms_spvn_3(p3)) 152 | outs.append(self.conv_ms_spvn_2(p2)) 153 | outs.append(p1_out) 154 | return outs 155 | 156 | 157 | class RefUNet(nn.Module): 158 | # Refinement 159 | def __init__(self, in_channels=3+1): 160 | super(RefUNet, self).__init__() 161 | self.encoder_1 = nn.Sequential( 162 | nn.Conv2d(in_channels, 64, 3, 1, 1), 163 | nn.Conv2d(64, 64, 3, 1, 1), 164 | nn.BatchNorm2d(64), 165 | nn.ReLU(inplace=True) 166 | ) 167 | 168 | self.encoder_2 = nn.Sequential( 169 | nn.MaxPool2d(2, 2, ceil_mode=True), 170 | nn.Conv2d(64, 64, 3, 1, 1), 171 | nn.BatchNorm2d(64), 172 | nn.ReLU(inplace=True) 173 | ) 174 | 175 | self.encoder_3 = nn.Sequential( 176 | nn.MaxPool2d(2, 2, ceil_mode=True), 177 | nn.Conv2d(64, 64, 3, 1, 1), 178 | nn.BatchNorm2d(64), 179 | nn.ReLU(inplace=True) 180 | ) 181 | 182 | self.encoder_4 = nn.Sequential( 183 | nn.MaxPool2d(2, 2, ceil_mode=True), 184 | nn.Conv2d(64, 64, 3, 1, 1), 185 | nn.BatchNorm2d(64), 186 | nn.ReLU(inplace=True) 187 | ) 188 | 189 | self.pool4 = nn.MaxPool2d(2, 2, ceil_mode=True) 190 | ##### 191 | 
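# Decoder: decoder_5 refines the pooled bottleneck; decoder_4..decoder_1 each consume the 2x-upsampled previous output concatenated with the matching encoder feature (64 + 64 = 128 channels), and conv_d0 maps the final 64-channel feature to the 1-channel output map.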
self.decoder_5 = nn.Sequential( 192 | nn.Conv2d(64, 64, 3, 1, 1), 193 | nn.BatchNorm2d(64), 194 | nn.ReLU(inplace=True) 195 | ) 196 | ##### 197 | self.decoder_4 = nn.Sequential( 198 | nn.Conv2d(128, 64, 3, 1, 1), 199 | nn.BatchNorm2d(64), 200 | nn.ReLU(inplace=True) 201 | ) 202 | 203 | self.decoder_3 = nn.Sequential( 204 | nn.Conv2d(128, 64, 3, 1, 1), 205 | nn.BatchNorm2d(64), 206 | nn.ReLU(inplace=True) 207 | ) 208 | 209 | self.decoder_2 = nn.Sequential( 210 | nn.Conv2d(128, 64, 3, 1, 1), 211 | nn.BatchNorm2d(64), 212 | nn.ReLU(inplace=True) 213 | ) 214 | 215 | self.decoder_1 = nn.Sequential( 216 | nn.Conv2d(128, 64, 3, 1, 1), 217 | nn.BatchNorm2d(64), 218 | nn.ReLU(inplace=True) 219 | ) 220 | 221 | self.conv_d0 = nn.Conv2d(64, 1, 3, 1, 1) 222 | 223 | self.upscore2 = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True) 224 | 225 | def forward(self, x): 226 | outs = [] 227 | if isinstance(x, list): 228 | x = torch.cat(x, dim=1) 229 | hx = x 230 | 231 | hx1 = self.encoder_1(hx) 232 | hx2 = self.encoder_2(hx1) 233 | hx3 = self.encoder_3(hx2) 234 | hx4 = self.encoder_4(hx3) 235 | 236 | hx = self.decoder_5(self.pool4(hx4)) 237 | hx = torch.cat((self.upscore2(hx), hx4), 1) 238 | 239 | d4 = self.decoder_4(hx) 240 | hx = torch.cat((self.upscore2(d4), hx3), 1) 241 | 242 | d3 = self.decoder_3(hx) 243 | hx = torch.cat((self.upscore2(d3), hx2), 1) 244 | 245 | d2 = self.decoder_2(hx) 246 | hx = torch.cat((self.upscore2(d2), hx1), 1) 247 | 248 | d1 = self.decoder_1(hx) 249 | 250 | x = self.conv_d0(d1) 251 | outs.append(x) 252 | return outs 253 | -------------------------------------------------------------------------------- /birefnet/models/refinement/stem_layer.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | from ..modules.utils import build_act_layer, build_norm_layer 3 | 4 | 5 | class StemLayer(nn.Module): 6 | r""" Stem layer of InternImage 7 | Args: 8 | in_channels (int): number of input channels 9 | out_channels (int): number of output channels 10 | act_layer (str): activation layer 11 | norm_layer (str): normalization layer 12 | """ 13 | 14 | def __init__(self, 15 | in_channels=3+1, 16 | inter_channels=48, 17 | out_channels=96, 18 | act_layer='GELU', 19 | norm_layer='BN'): 20 | super().__init__() 21 | self.conv1 = nn.Conv2d(in_channels, 22 | inter_channels, 23 | kernel_size=3, 24 | stride=1, 25 | padding=1) 26 | self.norm1 = build_norm_layer( 27 | inter_channels, norm_layer, 'channels_first', 'channels_first' 28 | ) 29 | self.act = build_act_layer(act_layer) 30 | self.conv2 = nn.Conv2d(inter_channels, 31 | out_channels, 32 | kernel_size=3, 33 | stride=1, 34 | padding=1) 35 | self.norm2 = build_norm_layer( 36 | out_channels, norm_layer, 'channels_first', 'channels_first' 37 | ) 38 | 39 | def forward(self, x): 40 | x = self.conv1(x) 41 | x = self.norm1(x) 42 | x = self.act(x) 43 | x = self.conv2(x) 44 | x = self.norm2(x) 45 | return x 46 | -------------------------------------------------------------------------------- /birefnet/utils.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import os 3 | import torch 4 | from torchvision import transforms 5 | import numpy as np 6 | import random 7 | import cv2 8 | from PIL import Image 9 | 10 | 11 | def path_to_image(path, size=(1024, 1024), color_type=['rgb', 'gray'][0]): 12 | if color_type.lower() == 'rgb': 13 | image = cv2.imread(path) 14 | elif color_type.lower() == 'gray': 15 | image = cv2.imread(path, 
cv2.IMREAD_GRAYSCALE) 16 | else: 17 | print('Select the color_type to return, either to RGB or gray image.') 18 | return 19 | if size: 20 | image = cv2.resize(image, size, interpolation=cv2.INTER_LINEAR) 21 | if color_type.lower() == 'rgb': 22 | image = Image.fromarray(cv2.cvtColor(image, cv2.COLOR_BGR2RGB)).convert('RGB') 23 | else: 24 | image = Image.fromarray(image).convert('L') 25 | return image 26 | 27 | 28 | 29 | def check_state_dict(state_dict, unwanted_prefixes=['module.', '_orig_mod.']): 30 | for k, v in list(state_dict.items()): 31 | prefix_length = 0 32 | for unwanted_prefix in unwanted_prefixes: 33 | if k[prefix_length:].startswith(unwanted_prefix): 34 | prefix_length += len(unwanted_prefix) 35 | state_dict[k[prefix_length:]] = state_dict.pop(k) 36 | return state_dict 37 | 38 | 39 | def generate_smoothed_gt(gts): 40 | epsilon = 0.001 41 | new_gts = (1-epsilon)*gts+epsilon/2 42 | return new_gts 43 | 44 | 45 | class Logger(): 46 | def __init__(self, path="log.txt"): 47 | self.logger = logging.getLogger('BiRefNet') 48 | self.file_handler = logging.FileHandler(path, "w") 49 | self.stdout_handler = logging.StreamHandler() 50 | self.stdout_handler.setFormatter(logging.Formatter('%(asctime)s %(levelname)s %(message)s')) 51 | self.file_handler.setFormatter(logging.Formatter('%(asctime)s %(levelname)s %(message)s')) 52 | self.logger.addHandler(self.file_handler) 53 | self.logger.addHandler(self.stdout_handler) 54 | self.logger.setLevel(logging.INFO) 55 | self.logger.propagate = False 56 | 57 | def info(self, txt): 58 | self.logger.info(txt) 59 | 60 | def close(self): 61 | self.file_handler.close() 62 | self.stdout_handler.close() 63 | 64 | 65 | class AverageMeter(object): 66 | """Computes and stores the average and current value""" 67 | def __init__(self): 68 | self.reset() 69 | 70 | def reset(self): 71 | self.val = 0.0 72 | self.avg = 0.0 73 | self.sum = 0.0 74 | self.count = 0.0 75 | 76 | def update(self, val, n=1): 77 | self.val = val 78 | self.sum += val * n 79 | self.count += n 80 | self.avg = self.sum / self.count 81 | 82 | 83 | def save_checkpoint(state, path, filename="latest.pth"): 84 | torch.save(state, os.path.join(path, filename)) 85 | 86 | 87 | def save_tensor_img(tenor_im, path): 88 | im = tenor_im.cpu().clone() 89 | im = im.squeeze(0) 90 | tensor2pil = transforms.ToPILImage() 91 | im = tensor2pil(im) 92 | im.save(path) 93 | 94 | 95 | def set_seed(seed): 96 | torch.manual_seed(seed) 97 | torch.cuda.manual_seed_all(seed) 98 | np.random.seed(seed) 99 | random.seed(seed) 100 | torch.backends.cudnn.deterministic = True -------------------------------------------------------------------------------- /birefnetNode.py: -------------------------------------------------------------------------------- 1 | import os 2 | import safetensors.torch 3 | import torch 4 | from torchvision import transforms 5 | from torch.hub import download_url_to_file 6 | import comfy 7 | from comfy import model_management 8 | import folder_paths 9 | from birefnet.models.birefnet import BiRefNet 10 | from birefnet_old.models.birefnet import BiRefNet as OldBiRefNet 11 | from birefnet.utils import check_state_dict 12 | from .util import filter_mask, add_mask_as_alpha, refine_foreground_pil, tensor_to_pil, pil_to_tensor 13 | deviceType = model_management.get_torch_device().type 14 | 15 | models_dir_key = "birefnet" 16 | 17 | models_path_default = folder_paths.get_folder_paths(models_dir_key)[0] 18 | 19 | usage_to_weights_file = { 20 | 'General': 'BiRefNet', 21 | 'General-HR': 'BiRefNet_HR', 22 | 'Matting-HR': 
'BiRefNet_HR-matting', 23 | 'General-Lite': 'BiRefNet_lite', 24 | 'General-Lite-2K': 'BiRefNet_lite-2K', 25 | 'General-reso_512': 'BiRefNet_512x512', 26 | 'Portrait': 'BiRefNet-portrait', 27 | 'Matting': 'BiRefNet-matting', 28 | 'Matting-Lite': 'BiRefNet_lite-matting', 29 | # 'Anime-Lite': 'BiRefNet_lite-Anime', 30 | 'DIS': 'BiRefNet-DIS5K', 31 | 'HRSOD': 'BiRefNet-HRSOD', 32 | 'COD': 'BiRefNet-COD', 33 | 'DIS-TR_TEs': 'BiRefNet-DIS5K-TR_TEs', 34 | 'General-legacy': 'BiRefNet-legacy', 35 | 'General-dynamic': 'BiRefNet_dynamic', 36 | } 37 | 38 | modelNameList = list(usage_to_weights_file.keys()) 39 | 40 | 41 | def get_model_path(model_name): 42 | return os.path.join(models_path_default, f"{model_name}.safetensors") 43 | 44 | 45 | def download_models(model_root, model_urls): 46 | if not os.path.exists(model_root): 47 | os.makedirs(model_root, exist_ok=True) 48 | 49 | for local_file, url in model_urls: 50 | local_path = os.path.join(model_root, local_file) 51 | if not os.path.exists(local_path): 52 | local_path = os.path.abspath(os.path.join(model_root, local_file)) 53 | download_url_to_file(url, dst=local_path) 54 | 55 | 56 | def download_birefnet_model(model_name): 57 | """ 58 | Downloading model from huggingface. 59 | """ 60 | model_root = os.path.join(models_path_default) 61 | model_urls = ( 62 | (f"{model_name}.safetensors", 63 | f"https://huggingface.co/ZhengPeng7/{usage_to_weights_file[model_name]}/resolve/main/model.safetensors"), 64 | ) 65 | download_models(model_root, model_urls) 66 | 67 | interpolation_modes_mapping = { 68 | "nearest": 0, 69 | "bilinear": 2, 70 | "bicubic": 3, 71 | "nearest-exact": 0, 72 | # "lanczos": 1, #不支持 73 | } 74 | 75 | class ImagePreprocessor: 76 | def __init__(self, resolution, upscale_method="bilinear") -> None: 77 | interpolation = interpolation_modes_mapping.get(upscale_method, 2) 78 | self.transform_image = transforms.Compose([ 79 | transforms.Resize(resolution, interpolation=interpolation), 80 | # transforms.ToTensor(), 81 | transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]), 82 | ]) 83 | self.transform_image_old = transforms.Compose([ 84 | transforms.Resize(resolution, interpolation=interpolation), 85 | # transforms.ToTensor(), 86 | transforms.Normalize([0.5, 0.5, 0.5], [1.0, 1.0, 1.0]), 87 | ]) 88 | 89 | def proc(self, image) -> torch.Tensor: 90 | image = self.transform_image(image) 91 | return image 92 | 93 | def old_proc(self, image) -> torch.Tensor: 94 | image = self.transform_image_old(image) 95 | return image 96 | 97 | VERSION = ["old", "v1"] 98 | old_models_name = ["BiRefNet-DIS_ep580.pth", "BiRefNet-ep480.pth"] 99 | 100 | torch_dtype={ 101 | "float16": torch.float16, 102 | "float32": torch.float32, 103 | "bfloat16": torch.bfloat16, 104 | } 105 | 106 | class AutoDownloadBiRefNetModel: 107 | 108 | @classmethod 109 | def INPUT_TYPES(cls): 110 | return { 111 | "required": { 112 | "model_name": (modelNameList,), 113 | "device": (["AUTO", "CPU"],) 114 | }, 115 | "optional": { 116 | "dtype": (["float32", "float16"], {"default": "float32"}) 117 | } 118 | } 119 | 120 | RETURN_TYPES = ("BIREFNET",) 121 | RETURN_NAMES = ("model",) 122 | FUNCTION = "load_model" 123 | CATEGORY = "image/BiRefNet" 124 | DESCRIPTION = "Auto download BiRefNet model from huggingface to models/BiRefNet/{model_name}.safetensors" 125 | 126 | def load_model(self, model_name, device, dtype="float32"): 127 | bb_index = 3 if model_name == "General-Lite" or model_name == "General-Lite-2K" or model_name == "Matting-Lite" else 6 128 | biRefNet_model = 
BiRefNet(bb_pretrained=False, bb_index=bb_index) 129 | model_file_name = f'{model_name}.safetensors' 130 | model_full_path = folder_paths.get_full_path(models_dir_key, model_file_name) 131 | if model_full_path is None: 132 | download_birefnet_model(model_name) 133 | model_full_path = folder_paths.get_full_path(models_dir_key, model_file_name) 134 | if device == "AUTO": 135 | device_type = deviceType 136 | else: 137 | device_type = "cpu" 138 | state_dict = safetensors.torch.load_file(model_full_path, device=device_type) 139 | biRefNet_model.load_state_dict(state_dict) 140 | biRefNet_model.to(device_type, dtype=torch_dtype[dtype]) 141 | biRefNet_model.eval() 142 | return [(biRefNet_model, VERSION[1])] 143 | 144 | 145 | class LoadRembgByBiRefNetModel: 146 | 147 | @classmethod 148 | def INPUT_TYPES(cls): 149 | return { 150 | "required": { 151 | "model": (folder_paths.get_filename_list(models_dir_key),), 152 | "device": (["AUTO", "CPU"], ) 153 | }, 154 | "optional": { 155 | "use_weight": ("BOOLEAN", {"default": False}), 156 | "dtype": (["float32", "float16"], {"default": "float32"}) 157 | } 158 | } 159 | 160 | RETURN_TYPES = ("BIREFNET",) 161 | RETURN_NAMES = ("model",) 162 | FUNCTION = "load_model" 163 | CATEGORY = "rembg/BiRefNet" 164 | DESCRIPTION = "Load BiRefNet model from folder models/BiRefNet or the path of birefnet configured in the extra YAML file" 165 | 166 | def load_model(self, model, device, use_weight=False, dtype="float32"): 167 | if model in old_models_name: 168 | version = VERSION[0] 169 | biRefNet_model = OldBiRefNet(bb_pretrained=use_weight) 170 | else: 171 | version = VERSION[1] 172 | bb_index = 3 if model == "General-Lite.safetensors" or model == "General-Lite-2K.safetensors" or model == "Matting-Lite.safetensors" else 6 173 | biRefNet_model = BiRefNet(bb_pretrained=use_weight, bb_index=bb_index) 174 | 175 | model_path = folder_paths.get_full_path(models_dir_key, model) 176 | if device == "AUTO": 177 | device_type = deviceType 178 | else: 179 | device_type = "cpu" 180 | if model_path.endswith(".safetensors"): 181 | state_dict = safetensors.torch.load_file(model_path, device=device_type) 182 | else: 183 | state_dict = torch.load(model_path, map_location=device_type) 184 | state_dict = check_state_dict(state_dict) 185 | 186 | biRefNet_model.load_state_dict(state_dict) 187 | biRefNet_model.to(device_type, dtype=torch_dtype[dtype]) 188 | biRefNet_model.eval() 189 | return [(biRefNet_model, version)] 190 | 191 | 192 | class GetMaskByBiRefNet: 193 | 194 | @classmethod 195 | def INPUT_TYPES(cls): 196 | return { 197 | "required": { 198 | "model": ("BIREFNET",), 199 | "images": ("IMAGE",), 200 | "width": ("INT", 201 | { 202 | "default": 1024, 203 | "min": 0, 204 | "max": 16384, 205 | "tooltip": "The width of the pre-processing image, does not affect the final output image size" 206 | }), 207 | "height": ("INT", 208 | { 209 | "default": 1024, 210 | "min": 0, 211 | "max": 16384, 212 | "tooltip": "The height of the pre-processing image, does not affect the final output image size" 213 | }), 214 | "upscale_method": (["bilinear", "nearest", "nearest-exact", "bicubic"], 215 | { 216 | "default": "bilinear", 217 | "tooltip": "Interpolation method for pre-processing image and post-processing mask" 218 | }), 219 | "mask_threshold": ("FLOAT", {"default": 0.000, "min": 0.0, "max": 1.0, "step": 0.004, }), 220 | } 221 | } 222 | 223 | RETURN_TYPES = ("MASK",) 224 | RETURN_NAMES = ("mask",) 225 | FUNCTION = "get_mask" 226 | CATEGORY = "rembg/BiRefNet" 227 | 228 | def get_mask(self, model, images, 
width=1024, height=1024, upscale_method='bilinear', mask_threshold=0.000): 229 | model, version = model 230 | one_torch = next(model.parameters()) 231 | model_device_type = one_torch.device.type 232 | model_dtype = one_torch.dtype 233 | b, h, w, c = images.shape 234 | image_bchw = images.permute(0, 3, 1, 2) 235 | 236 | image_preproc = ImagePreprocessor(resolution=(height, width), upscale_method=upscale_method) 237 | if VERSION[0] == version: 238 | im_tensor = image_preproc.old_proc(image_bchw) 239 | else: 240 | im_tensor = image_preproc.proc(image_bchw) 241 | 242 | del image_preproc 243 | 244 | _mask_bchw = [] 245 | for each_image in im_tensor: 246 | with torch.no_grad(): 247 | each_mask = model(each_image.unsqueeze(0).to(model_device_type, dtype=model_dtype))[-1].sigmoid().cpu().float() 248 | _mask_bchw.append(each_mask) 249 | del each_mask 250 | 251 | mask_bchw = torch.cat(_mask_bchw, dim=0) 252 | del _mask_bchw 253 | # The mask must be resized back to the original image size 254 | mask = comfy.utils.common_upscale(mask_bchw, w, h, upscale_method, "disabled") 255 | # (b, 1, h, w) 256 | if mask_threshold > 0: 257 | mask = filter_mask(mask, threshold=mask_threshold) 258 | # else: 259 | # seems to have almost no effect 260 | # mask = normalize_mask(mask) 261 | 262 | return mask.squeeze(1), 263 | 264 | 265 | class BlurFusionForegroundEstimation: 266 | 267 | @classmethod 268 | def INPUT_TYPES(cls): 269 | return { 270 | "required": { 271 | "images": ("IMAGE",), 272 | "masks": ("MASK",), 273 | "blur_size": ("INT", {"default": 90, "min": 1, "max": 255, "step": 1, }), 274 | "blur_size_two": ("INT", {"default": 6, "min": 1, "max": 255, "step": 1, }), 275 | "fill_color": ("BOOLEAN", {"default": False}), 276 | "color": ("INT", {"default": 0, "min": 0, "max": 0xFFFFFF, "step": 1, "display": "color"}), 277 | } 278 | } 279 | 280 | RETURN_TYPES = ("IMAGE", "MASK",) 281 | RETURN_NAMES = ("image", "mask",) 282 | FUNCTION = "get_foreground" 283 | CATEGORY = "rembg/BiRefNet" 284 | DESCRIPTION = "Approximate Fast Foreground Colour Estimation. 
https://github.com/Photoroom/fast-foreground-estimation" 285 | 286 | def get_foreground(self, images, masks, blur_size=91, blur_size_two=7, fill_color=False, color=None): 287 | b, h, w, c = images.shape 288 | if b != masks.shape[0]: 289 | raise ValueError("images and masks must have the same batch size") 290 | 291 | # image_bchw = images.permute(0, 3, 1, 2) 292 | 293 | if masks.dim() == 3: 294 | # (b, h, w) => (b, 1, h, w) 295 | out_masks = masks.unsqueeze(1) 296 | 297 | # Convert to PIL and blur with cv2.blur: the result has a cleaner background color (gaussian_blur leaves an impure background and heavier edge contours); do not apply the mask by element-wise multiplication, or edge outlines may remain 298 | _image_maskeds = [] 299 | # for _image, _out_mask in images, out_masks: 300 | for idx, (_image, _out_mask) in enumerate(zip(images.unbind(dim=0), out_masks.unbind(dim=0))): 301 | _image_masked = refine_foreground_pil(tensor_to_pil(_image), tensor_to_pil(_out_mask.permute(1, 2, 0))) 302 | _image_masked = pil_to_tensor(_image_masked) 303 | _image_maskeds.append(_image_masked) 304 | del _image_masked 305 | 306 | _image_masked_tensor = torch.cat(_image_maskeds, dim=0) 307 | del _image_maskeds 308 | 309 | # (b, c, h, w) 310 | # _image_masked = refine_foreground(image_bchw, out_masks, r1=blur_size, r2=blur_size_two) 311 | # (b, c, h, w) => (b, h, w, c) 312 | # _image_masked = _image_masked.permute(0, 2, 3, 1) 313 | if fill_color and color is not None: 314 | r = torch.full([b, h, w, 1], ((color >> 16) & 0xFF) / 0xFF) 315 | g = torch.full([b, h, w, 1], ((color >> 8) & 0xFF) / 0xFF) 316 | b = torch.full([b, h, w, 1], (color & 0xFF) / 0xFF) 317 | # (b, h, w, 3) 318 | background_color = torch.cat((r, g, b), dim=-1) 319 | # (b, 1, h, w) => (b, h, w, 3) 320 | apply_mask = out_masks.permute(0, 2, 3, 1).expand_as(_image_masked_tensor) 321 | out_images = _image_masked_tensor * apply_mask + background_color * (1 - apply_mask) 322 | # (b, h, w, 3)=>(b, h, w, 3) 323 | del background_color, apply_mask 324 | out_masks = out_masks.squeeze(1) 325 | else: 326 | # (b, 1, h, w) => (b, h, w) 327 | out_masks = out_masks.squeeze(1) 328 | # The parts of the image outside the mask are made transparent => (b, h, w, 4) 329 | out_images = add_mask_as_alpha(_image_masked_tensor.cpu(), out_masks.cpu()) 330 | 331 | del _image_masked_tensor 332 | 333 | return out_images, out_masks 334 | 335 | 336 | class RembgByBiRefNetAdvanced(GetMaskByBiRefNet, BlurFusionForegroundEstimation): 337 | 338 | @classmethod 339 | def INPUT_TYPES(cls): 340 | return { 341 | "required": { 342 | "model": ("BIREFNET",), 343 | "images": ("IMAGE",), 344 | "width": ("INT", 345 | { 346 | "default": 1024, 347 | "min": 0, 348 | "max": 16384, 349 | "tooltip": "The width of the pre-processing image, does not affect the final output image size" 350 | }), 351 | "height": ("INT", 352 | { 353 | "default": 1024, 354 | "min": 0, 355 | "max": 16384, 356 | "tooltip": "The height of the pre-processing image, does not affect the final output image size" 357 | }), 358 | "upscale_method": (["bilinear", "nearest", "nearest-exact", "bicubic"], 359 | { 360 | "default": "bilinear", 361 | "tooltip": "Interpolation method for pre-processing image and post-processing mask" 362 | }), 363 | "blur_size": ("INT", {"default": 90, "min": 1, "max": 255, "step": 1, }), 364 | "blur_size_two": ("INT", {"default": 6, "min": 1, "max": 255, "step": 1, }), 365 | "fill_color": ("BOOLEAN", {"default": False}), 366 | "color": ("INT", {"default": 0, "min": 0, "max": 0xFFFFFF, "step": 1, "display": "color"}), 367 | "mask_threshold": ("FLOAT", {"default": 0.000, "min": 0.0, "max": 1.0, "step": 0.001, }), 368 | } 369 | } 370 | 371 | RETURN_TYPES = ("IMAGE", "MASK",) 372 | RETURN_NAMES = 
("image", "mask",) 373 | FUNCTION = "rem_bg" 374 | CATEGORY = "rembg/BiRefNet" 375 | 376 | def rem_bg(self, model, images, upscale_method='bilinear', width=1024, height=1024, blur_size=91, blur_size_two=7, fill_color=False, color=None, mask_threshold=0.000): 377 | 378 | masks = super().get_mask(model, images, width, height, upscale_method, mask_threshold) 379 | 380 | out_images, out_masks = super().get_foreground(images, masks=masks[0], blur_size=blur_size, blur_size_two=blur_size_two, fill_color=fill_color, color=color) 381 | 382 | return out_images, out_masks 383 | 384 | 385 | class RembgByBiRefNet(RembgByBiRefNetAdvanced): 386 | 387 | @classmethod 388 | def INPUT_TYPES(cls): 389 | return { 390 | "required": { 391 | "model": ("BIREFNET",), 392 | "images": ("IMAGE",), 393 | } 394 | } 395 | 396 | RETURN_TYPES = ("IMAGE", "MASK",) 397 | RETURN_NAMES = ("image", "mask",) 398 | FUNCTION = "rem_bg" 399 | CATEGORY = "rembg/BiRefNet" 400 | 401 | def rem_bg(self, model, images): 402 | return super().rem_bg(model, images) 403 | 404 | 405 | NODE_CLASS_MAPPINGS = { 406 | "AutoDownloadBiRefNetModel": AutoDownloadBiRefNetModel, 407 | "LoadRembgByBiRefNetModel": LoadRembgByBiRefNetModel, 408 | "RembgByBiRefNet": RembgByBiRefNet, 409 | "RembgByBiRefNetAdvanced": RembgByBiRefNetAdvanced, 410 | "GetMaskByBiRefNet": GetMaskByBiRefNet, 411 | "BlurFusionForegroundEstimation": BlurFusionForegroundEstimation, 412 | } 413 | 414 | NODE_DISPLAY_NAME_MAPPINGS = { 415 | "AutoDownloadBiRefNetModel": "AutoDownloadBiRefNetModel", 416 | "LoadRembgByBiRefNetModel": "LoadRembgByBiRefNetModel", 417 | "RembgByBiRefNet": "RembgByBiRefNet", 418 | "RembgByBiRefNetAdvanced": "RembgByBiRefNetAdvanced", 419 | "GetMaskByBiRefNet": "GetMaskByBiRefNet", 420 | "BlurFusionForegroundEstimation": "BlurFusionForegroundEstimation", 421 | } 422 | -------------------------------------------------------------------------------- /birefnet_old/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lldacing/ComfyUI_BiRefNet_ll/5443a2aa16cfbd98bb2f7dcc8bdcb70439e08529/birefnet_old/__init__.py -------------------------------------------------------------------------------- /birefnet_old/config.py: -------------------------------------------------------------------------------- 1 | import os 2 | import math 3 | import torch 4 | import folder_paths 5 | 6 | 7 | class Config: 8 | def __init__(self) -> None: 9 | self.ms_supervision = True 10 | self.out_ref = self.ms_supervision and True 11 | self.dec_ipt = True 12 | self.dec_ipt_split = True 13 | self.locate_head = False 14 | self.cxt_num = [0, 3][1] # multi-scale skip connections from encoder 15 | self.mul_scl_ipt = ['', 'add', 'cat'][2] 16 | self.refine = ['', 'itself', 'RefUNet', 'Refiner', 'RefinerPVTInChannels4'][0] 17 | self.progressive_ref = self.refine and True 18 | self.ender = self.progressive_ref and False 19 | self.scale = self.progressive_ref and 2 20 | self.dec_att = ['', 'ASPP', 'ASPPDeformable'][2] 21 | self.squeeze_block = ['', 'BasicDecBlk_x1', 'ResBlk_x4', 'ASPP_x3', 'ASPPDeformable_x3'][1] 22 | self.dec_blk = ['BasicDecBlk', 'ResBlk', 'HierarAttDecBlk'][0] 23 | self.auxiliary_classification = False 24 | self.refine_iteration = 1 25 | self.freeze_bb = False 26 | self.precisionHigh = True 27 | self.compile = True 28 | self.load_all = True 29 | self.verbose_eval = True 30 | 31 | self.size = 1024 32 | self.batch_size = 2 33 | self.IoU_finetune_last_epochs = [0, -20][1] # choose 0 to skip 34 | if self.dec_blk == 
'HierarAttDecBlk': 35 | self.batch_size = 2 ** [0, 1, 2, 3, 4][2] 36 | self.model = [ 37 | 'BiRefNet', 38 | ][0] 39 | 40 | # Components 41 | self.lat_blk = ['BasicLatBlk'][0] 42 | self.dec_channels_inter = ['fixed', 'adap'][0] 43 | 44 | # Backbone 45 | self.bb = [ 46 | 'vgg16', 'vgg16bn', 'resnet50', # 0, 1, 2 47 | 'pvt_v2_b2', 'pvt_v2_b5', # 3-bs10, 4-bs5 48 | 'swin_v1_b', 'swin_v1_l', # 5-bs9, 6-bs6 49 | 'swin_v1_t', 'swin_v1_s', # 7, 8 50 | 'pvt_v2_b0', 'pvt_v2_b1', # 9, 10 51 | ][6] 52 | self.lateral_channels_in_collection = { 53 | 'vgg16': [512, 256, 128, 64], 'vgg16bn': [512, 256, 128, 64], 'resnet50': [1024, 512, 256, 64], 54 | 'pvt_v2_b2': [512, 320, 128, 64], 'pvt_v2_b5': [512, 320, 128, 64], 55 | 'swin_v1_b': [1024, 512, 256, 128], 'swin_v1_l': [1536, 768, 384, 192], 56 | 'swin_v1_t': [768, 384, 192, 96], 'swin_v1_s': [768, 384, 192, 96], 57 | 'pvt_v2_b0': [256, 160, 64, 32], 'pvt_v2_b1': [512, 320, 128, 64], 58 | }[self.bb] 59 | if self.mul_scl_ipt == 'cat': 60 | self.lateral_channels_in_collection = [channel * 2 for channel in self.lateral_channels_in_collection] 61 | self.cxt = self.lateral_channels_in_collection[1:][::-1][-self.cxt_num:] if self.cxt_num else [] 62 | # self.sys_home_dir = '/root/autodl-tmp' 63 | # self.weights_root_dir = os.path.join(self.sys_home_dir, 'weights') 64 | # self.weights = { 65 | # 'pvt_v2_b2': os.path.join(self.weights_root_dir, 'pvt_v2_b2.pth'), 66 | # 'pvt_v2_b5': os.path.join(self.weights_root_dir, ['pvt_v2_b5.pth', 'pvt_v2_b5_22k.pth'][0]), 67 | # 'swin_v1_b': os.path.join(self.weights_root_dir, ['swin_base_patch4_window12_384_22kto1k.pth', 'swin_base_patch4_window12_384_22k.pth'][0]), 68 | # 'swin_v1_l': os.path.join(self.weights_root_dir, ['swin_large_patch4_window12_384_22kto1k.pth', 'swin_large_patch4_window12_384_22k.pth'][0]), 69 | # 'swin_v1_t': os.path.join(self.weights_root_dir, ['swin_tiny_patch4_window7_224_22kto1k_finetune.pth'][0]), 70 | # 'swin_v1_s': os.path.join(self.weights_root_dir, ['swin_small_patch4_window7_224_22kto1k_finetune.pth'][0]), 71 | # 'pvt_v2_b0': os.path.join(self.weights_root_dir, ['pvt_v2_b0.pth'][0]), 72 | # 'pvt_v2_b1': os.path.join(self.weights_root_dir, ['pvt_v2_b1.pth'][0]), 73 | # } 74 | weight_paths_name = "birefnet" 75 | self.weights = { 76 | 'pvt_v2_b2': folder_paths.get_full_path(weight_paths_name, 'pvt_v2_b2.pth'), 77 | 'pvt_v2_b5': folder_paths.get_full_path(weight_paths_name, ['pvt_v2_b5.pth', 'pvt_v2_b5_22k.pth'][0]), 78 | 'swin_v1_b': folder_paths.get_full_path(weight_paths_name, ['swin_base_patch4_window12_384_22kto1k.pth', 'swin_base_patch4_window12_384_22k.pth'][0]), 79 | 'swin_v1_l': folder_paths.get_full_path(weight_paths_name, ['swin_large_patch4_window12_384_22kto1k.pth', 'swin_large_patch4_window12_384_22k.pth'][0]), 80 | 'swin_v1_t': folder_paths.get_full_path(weight_paths_name, ['swin_tiny_patch4_window7_224_22kto1k_finetune.pth'][0]), 81 | 'swin_v1_s': folder_paths.get_full_path(weight_paths_name, ['swin_small_patch4_window7_224_22kto1k_finetune.pth'][0]), 82 | 'pvt_v2_b0': folder_paths.get_full_path(weight_paths_name, ['pvt_v2_b0.pth'][0]), 83 | 'pvt_v2_b1': folder_paths.get_full_path(weight_paths_name, ['pvt_v2_b1.pth'][0]), 84 | } 85 | 86 | # Training 87 | self.num_workers = 5 # will be decrease to min(it, batch_size) at the initialization of the data_loader 88 | self.optimizer = ['Adam', 'AdamW'][0] 89 | self.lr = 1e-5 * math.sqrt(self.batch_size / 5) # adapt the lr linearly 90 | self.lr_decay_epochs = [1e4] # Set to negative N to decay the lr in the last N-th epoch. 
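        # e.g. with the default batch_size = 2 set above, lr = 1e-5 * sqrt(2 / 5) ≈ 6.3e-6; the
        # sqrt term grows the learning rate sub-linearly as the batch size increases.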
91 | self.lr_decay_rate = 0.5 92 | self.only_S_MAE = False 93 | self.SDPA_enabled = False # Bug. Slower and errors occur in multi-GPUs 94 | 95 | # Data 96 | # self.data_root_dir = os.path.join(self.sys_home_dir, 'datasets/dis') 97 | self.task = ['DIS5K', 'COD', 'HRSOD'][0] 98 | self.training_set = { 99 | 'DIS5K': 'DIS-TR', 100 | 'COD': 'TR-COD10K+TR-CAMO', 101 | 'HRSOD': ['TR-DUTS', 'TR-HRSOD+TR-UHRSD', 'TR-DUTS+TR-HRSOD+TR-UHRSD'][1] 102 | }[self.task] 103 | self.preproc_methods = ['flip', 'enhance', 'rotate', 'pepper', 'crop'][:4] 104 | 105 | # Loss 106 | self.lambdas_pix_last = { 107 | # not 0 means opening this loss 108 | # original rate -- 1 : 30 : 1.5 : 0.2, bce x 30 109 | 'bce': 30 * 1, # high performance 110 | 'iou': 0.5 * 1, # 0 / 255 111 | 'iou_patch': 0.5 * 0, # 0 / 255, win_size = (64, 64) 112 | 'mse': 150 * 0, # can smooth the saliency map 113 | 'triplet': 3 * 0, 114 | 'reg': 100 * 0, 115 | 'ssim': 10 * 1, # help contours, 116 | 'cnt': 5 * 0, # help contours 117 | } 118 | self.lambdas_cls = { 119 | 'ce': 5.0 120 | } 121 | # Adv 122 | self.lambda_adv_g = 10. * 0 # turn to 0 to avoid adv training 123 | self.lambda_adv_d = 3. * (self.lambda_adv_g > 0) 124 | 125 | # others 126 | self.device = [0, 'cpu'][0 if torch.cuda.is_available() else 1] # .to(0) == .to('cuda:0') 127 | 128 | self.batch_size_valid = 1 129 | self.rand_seed = 7 130 | # run_sh_file = [f for f in os.listdir('.') if 'train.sh' == f] + [os.path.join('..', f) for f in os.listdir('..') if 'train.sh' == f] 131 | # with open(run_sh_file[0], 'r') as f: 132 | # lines = f.readlines() 133 | # self.save_last = int([l.strip() for l in lines if 'val_last=' in l][0].split('=')[-1]) 134 | # self.save_step = int([l.strip() for l in lines if 'step=' in l][0].split('=')[-1]) 135 | # self.val_step = [0, self.save_step][0] 136 | -------------------------------------------------------------------------------- /birefnet_old/dataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | # import cv2 3 | from tqdm import tqdm 4 | from PIL import Image 5 | from torch.utils import data 6 | from torchvision import transforms 7 | 8 | from birefnet_old.preproc import preproc 9 | from birefnet_old.config import Config 10 | from birefnet_old.utils import path_to_image 11 | 12 | 13 | Image.MAX_IMAGE_PIXELS = None # remove DecompressionBombWarning 14 | config = Config() 15 | _class_labels_TR_sorted = 'Airplane, Ant, Antenna, Archery, Axe, BabyCarriage, Bag, BalanceBeam, Balcony, Balloon, Basket, BasketballHoop, Beatle, Bed, Bee, Bench, Bicycle, BicycleFrame, BicycleStand, Boat, Bonsai, BoomLift, Bridge, BunkBed, Butterfly, Button, Cable, CableLift, Cage, Camcorder, Cannon, Canoe, Car, CarParkDropArm, Carriage, Cart, Caterpillar, CeilingLamp, Centipede, Chair, Clip, Clock, Clothes, CoatHanger, Comb, ConcretePumpTruck, Crack, Crane, Cup, DentalChair, Desk, DeskChair, Diagram, DishRack, DoorHandle, Dragonfish, Dragonfly, Drum, Earphone, Easel, ElectricIron, Excavator, Eyeglasses, Fan, Fence, Fencing, FerrisWheel, FireExtinguisher, Fishing, Flag, FloorLamp, Forklift, GasStation, Gate, Gear, Goal, Golf, GymEquipment, Hammock, Handcart, Handcraft, Handrail, HangGlider, Harp, Harvester, Headset, Helicopter, Helmet, Hook, HorizontalBar, Hydrovalve, IroningTable, Jewelry, Key, KidsPlayground, Kitchenware, Kite, Knife, Ladder, LaundryRack, Lightning, Lobster, Locust, Machine, MachineGun, MagazineRack, Mantis, Medal, MemorialArchway, Microphone, Missile, MobileHolder, Monitor, Mosquito, Motorcycle, 
MovingTrolley, Mower, MusicPlayer, MusicStand, ObservationTower, Octopus, OilWell, OlympicLogo, OperatingTable, OutdoorFitnessEquipment, Parachute, Pavilion, Piano, Pipe, PlowHarrow, PoleVault, Punchbag, Rack, Racket, Rifle, Ring, Robot, RockClimbing, Rope, Sailboat, Satellite, Scaffold, Scale, Scissor, Scooter, Sculpture, Seadragon, Seahorse, Seal, SewingMachine, Ship, Shoe, ShoppingCart, ShoppingTrolley, Shower, Shrimp, Signboard, Skateboarding, Skeleton, Skiing, Spade, SpeedBoat, Spider, Spoon, Stair, Stand, Stationary, SteeringWheel, Stethoscope, Stool, Stove, StreetLamp, SweetStand, Swing, Sword, TV, Table, TableChair, TableLamp, TableTennis, Tank, Tapeline, Teapot, Telescope, Tent, TobaccoPipe, Toy, Tractor, TrafficLight, TrafficSign, Trampoline, TransmissionTower, Tree, Tricycle, TrimmerCover, Tripod, Trombone, Truck, Trumpet, Tuba, UAV, Umbrella, UnevenBars, UtilityPole, VacuumCleaner, Violin, Wakesurfing, Watch, WaterTower, WateringPot, Well, WellLid, Wheel, Wheelchair, WindTurbine, Windmill, WineGlass, WireWhisk, Yacht' 16 | class_labels_TR_sorted = _class_labels_TR_sorted.split(', ') 17 | 18 | 19 | class MyData(data.Dataset): 20 | def __init__(self, datasets, image_size, is_train=True): 21 | self.size_train = image_size 22 | self.size_test = image_size 23 | self.keep_size = not config.size 24 | self.data_size = (config.size, config.size) 25 | self.is_train = is_train 26 | self.load_all = config.load_all 27 | self.device = config.device 28 | if self.is_train and config.auxiliary_classification: 29 | self.cls_name2id = {_name: _id for _id, _name in enumerate(class_labels_TR_sorted)} 30 | self.transform_image = transforms.Compose([ 31 | transforms.Resize(self.data_size), 32 | transforms.ToTensor(), 33 | transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]), 34 | ][self.load_all or self.keep_size:]) 35 | self.transform_label = transforms.Compose([ 36 | transforms.Resize(self.data_size), 37 | transforms.ToTensor(), 38 | ][self.load_all or self.keep_size:]) 39 | dataset_root = os.path.join(config.data_root_dir, config.task) 40 | # datasets can be a list of different datasets for training on combined sets. 
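        # e.g. datasets='DIS-TR+DIS-TE1' (names are only illustrative) collects image paths from both
        # <data_root_dir>/<task>/DIS-TR/im and <data_root_dir>/<task>/DIS-TE1/im below.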
41 | self.image_paths = [] 42 | for dataset in datasets.split('+'): 43 | image_root = os.path.join(dataset_root, dataset, 'im') 44 | self.image_paths += [os.path.join(image_root, p) for p in os.listdir(image_root)] 45 | self.label_paths = [] 46 | for p in self.image_paths: 47 | for ext in ['.png', '.jpg', '.PNG', '.JPG', '.JPEG']: 48 | ## 'im' and 'gt' may need modifying 49 | p_gt = p.replace('/im/', '/gt/').replace('.'+p.split('.')[-1], ext) 50 | if os.path.exists(p_gt): 51 | self.label_paths.append(p_gt) 52 | break 53 | if self.load_all: 54 | self.images_loaded, self.labels_loaded = [], [] 55 | self.class_labels_loaded = [] 56 | # for image_path, label_path in zip(self.image_paths, self.label_paths): 57 | for image_path, label_path in tqdm(zip(self.image_paths, self.label_paths), total=len(self.image_paths)): 58 | _image = path_to_image(image_path, size=(config.size, config.size), color_type='rgb') 59 | _label = path_to_image(label_path, size=(config.size, config.size), color_type='gray') 60 | self.images_loaded.append(_image) 61 | self.labels_loaded.append(_label) 62 | self.class_labels_loaded.append( 63 | self.cls_name2id[label_path.split('/')[-1].split('#')[3]] if self.is_train and config.auxiliary_classification else -1 64 | ) 65 | 66 | 67 | def __getitem__(self, index): 68 | 69 | if self.load_all: 70 | image = self.images_loaded[index] 71 | label = self.labels_loaded[index] 72 | class_label = self.class_labels_loaded[index] if self.is_train and config.auxiliary_classification else -1 73 | else: 74 | image = path_to_image(self.image_paths[index], size=(config.size, config.size), color_type='rgb') 75 | label = path_to_image(self.label_paths[index], size=(config.size, config.size), color_type='gray') 76 | class_label = self.cls_name2id[self.label_paths[index].split('/')[-1].split('#')[3]] if self.is_train and config.auxiliary_classification else -1 77 | 78 | # loading image and label 79 | if self.is_train: 80 | image, label = preproc(image, label, preproc_methods=config.preproc_methods) 81 | # else: 82 | # if _label.shape[0] > 2048 or _label.shape[1] > 2048: 83 | # _image = cv2.resize(_image, (2048, 2048), interpolation=cv2.INTER_LINEAR) 84 | # _label = cv2.resize(_label, (2048, 2048), interpolation=cv2.INTER_LINEAR) 85 | 86 | image, label = self.transform_image(image), self.transform_label(label) 87 | 88 | if self.is_train: 89 | return image, label, class_label 90 | else: 91 | return image, label, self.label_paths[index] 92 | 93 | def __len__(self): 94 | return len(self.image_paths) 95 | -------------------------------------------------------------------------------- /birefnet_old/models/backbones/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lldacing/ComfyUI_BiRefNet_ll/5443a2aa16cfbd98bb2f7dcc8bdcb70439e08529/birefnet_old/models/backbones/__init__.py -------------------------------------------------------------------------------- /birefnet_old/models/backbones/build_backbone.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from collections import OrderedDict 4 | from torchvision.models import vgg16, vgg16_bn, VGG16_Weights, VGG16_BN_Weights, resnet50, ResNet50_Weights 5 | from ..backbones.pvt_v2 import pvt_v2_b0, pvt_v2_b1, pvt_v2_b2, pvt_v2_b5 6 | from ..backbones.swin_v1 import swin_v1_t, swin_v1_s, swin_v1_b, swin_v1_l 7 | from ...config import Config 8 | 9 | 10 | config = Config() 11 | 12 | def 
build_backbone(bb_name, pretrained=True, params_settings=''): 13 |     if bb_name == 'vgg16': 14 |         bb_net = list(vgg16(pretrained=VGG16_Weights.DEFAULT if pretrained else None).children())[0] 15 |         bb = nn.Sequential(OrderedDict({'conv1': bb_net[:4], 'conv2': bb_net[4:9], 'conv3': bb_net[9:16], 'conv4': bb_net[16:23]})) 16 |     elif bb_name == 'vgg16bn': 17 |         bb_net = list(vgg16_bn(pretrained=VGG16_BN_Weights.DEFAULT if pretrained else None).children())[0] 18 |         bb = nn.Sequential(OrderedDict({'conv1': bb_net[:6], 'conv2': bb_net[6:13], 'conv3': bb_net[13:23], 'conv4': bb_net[23:33]})) 19 |     elif bb_name == 'resnet50': 20 |         bb_net = list(resnet50(pretrained=ResNet50_Weights.DEFAULT if pretrained else None).children()) 21 |         bb = nn.Sequential(OrderedDict({'conv1': nn.Sequential(*bb_net[0:3]), 'conv2': bb_net[4], 'conv3': bb_net[5], 'conv4': bb_net[6]})) 22 |     else: 23 |         bb = eval('{}({})'.format(bb_name, params_settings)) 24 |         if pretrained: 25 |             bb = load_weights(bb, bb_name) 26 |     return bb 27 | 28 | def load_weights(model, model_name): 29 |     save_model = torch.load(config.weights[model_name]) 30 |     model_dict = model.state_dict() 31 |     state_dict = {k: v if v.size() == model_dict[k].size() else model_dict[k] for k, v in save_model.items() if k in model_dict.keys()} 32 |     # to ignore the weights with mismatched size when I modify the backbone itself. 33 |     if not state_dict: 34 |         save_model_keys = list(save_model.keys()) 35 |         sub_item = save_model_keys[0] if len(save_model_keys) == 1 else None 36 |         state_dict = {k: v if v.size() == model_dict[k].size() else model_dict[k] for k, v in save_model[sub_item].items() if k in model_dict.keys()} 37 |         if not state_dict or not sub_item: 38 |             print('Weights are not successfully loaded. Check the state dict of the weights file.') 39 |             return None 40 |         else: 41 |             print('Found correct weights in the "{}" item of loaded state_dict.'.format(sub_item)) 42 |     model_dict.update(state_dict) 43 |     model.load_state_dict(model_dict) 44 |     return model 45 | -------------------------------------------------------------------------------- /birefnet_old/models/backbones/pvt_v2.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from functools import partial 4 | 5 | try: 6 |     # version > 0.6.13 7 |     from timm.layers import DropPath, to_2tuple, trunc_normal_ 8 | except Exception: 9 |     from timm.models.layers import DropPath, to_2tuple, trunc_normal_ 10 | 11 | import math 12 | 13 | from ...config import Config 14 | 15 | config = Config() 16 | 17 | class Mlp(nn.Module): 18 |     def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.): 19 |         super().__init__() 20 |         out_features = out_features or in_features 21 |         hidden_features = hidden_features or in_features 22 |         self.fc1 = nn.Linear(in_features, hidden_features) 23 |         self.dwconv = DWConv(hidden_features) 24 |         self.act = act_layer() 25 |         self.fc2 = nn.Linear(hidden_features, out_features) 26 |         self.drop = nn.Dropout(drop) 27 | 28 |         self.apply(self._init_weights) 29 | 30 |     def _init_weights(self, m): 31 |         if isinstance(m, nn.Linear): 32 |             trunc_normal_(m.weight, std=.02) 33 |             if isinstance(m, nn.Linear) and m.bias is not None: 34 |                 nn.init.constant_(m.bias, 0) 35 |         elif isinstance(m, nn.LayerNorm): 36 |             nn.init.constant_(m.bias, 0) 37 |             nn.init.constant_(m.weight, 1.0) 38 |         elif isinstance(m, nn.Conv2d): 39 |             fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 40 |             fan_out //= m.groups 41 |             m.weight.data.normal_(0,
math.sqrt(2.0 / fan_out)) 42 | if m.bias is not None: 43 | m.bias.data.zero_() 44 | 45 | def forward(self, x, H, W): 46 | x = self.fc1(x) 47 | x = self.dwconv(x, H, W) 48 | x = self.act(x) 49 | x = self.drop(x) 50 | x = self.fc2(x) 51 | x = self.drop(x) 52 | return x 53 | 54 | 55 | class Attention(nn.Module): 56 | def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0., sr_ratio=1): 57 | super().__init__() 58 | assert dim % num_heads == 0, f"dim {dim} should be divided by num_heads {num_heads}." 59 | 60 | self.dim = dim 61 | self.num_heads = num_heads 62 | head_dim = dim // num_heads 63 | self.scale = qk_scale or head_dim ** -0.5 64 | 65 | self.q = nn.Linear(dim, dim, bias=qkv_bias) 66 | self.kv = nn.Linear(dim, dim * 2, bias=qkv_bias) 67 | self.attn_drop_prob = attn_drop 68 | self.attn_drop = nn.Dropout(attn_drop) 69 | self.proj = nn.Linear(dim, dim) 70 | self.proj_drop = nn.Dropout(proj_drop) 71 | 72 | self.sr_ratio = sr_ratio 73 | if sr_ratio > 1: 74 | self.sr = nn.Conv2d(dim, dim, kernel_size=sr_ratio, stride=sr_ratio) 75 | self.norm = nn.LayerNorm(dim) 76 | 77 | self.apply(self._init_weights) 78 | 79 | def _init_weights(self, m): 80 | if isinstance(m, nn.Linear): 81 | trunc_normal_(m.weight, std=.02) 82 | if isinstance(m, nn.Linear) and m.bias is not None: 83 | nn.init.constant_(m.bias, 0) 84 | elif isinstance(m, nn.LayerNorm): 85 | nn.init.constant_(m.bias, 0) 86 | nn.init.constant_(m.weight, 1.0) 87 | elif isinstance(m, nn.Conv2d): 88 | fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 89 | fan_out //= m.groups 90 | m.weight.data.normal_(0, math.sqrt(2.0 / fan_out)) 91 | if m.bias is not None: 92 | m.bias.data.zero_() 93 | 94 | def forward(self, x, H, W): 95 | B, N, C = x.shape 96 | q = self.q(x).reshape(B, N, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3) 97 | 98 | if self.sr_ratio > 1: 99 | x_ = x.permute(0, 2, 1).reshape(B, C, H, W) 100 | x_ = self.sr(x_).reshape(B, C, -1).permute(0, 2, 1) 101 | x_ = self.norm(x_) 102 | kv = self.kv(x_).reshape(B, -1, 2, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) 103 | else: 104 | kv = self.kv(x).reshape(B, -1, 2, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) 105 | k, v = kv[0], kv[1] 106 | 107 | if config.SDPA_enabled: 108 | x = torch.nn.functional.scaled_dot_product_attention( 109 | q, k, v, 110 | attn_mask=None, dropout_p=self.attn_drop_prob, is_causal=False 111 | ).transpose(1, 2).reshape(B, N, C) 112 | else: 113 | attn = (q @ k.transpose(-2, -1)) * self.scale 114 | attn = attn.softmax(dim=-1) 115 | attn = self.attn_drop(attn) 116 | 117 | x = (attn @ v).transpose(1, 2).reshape(B, N, C) 118 | x = self.proj(x) 119 | x = self.proj_drop(x) 120 | 121 | return x 122 | 123 | 124 | class Block(nn.Module): 125 | 126 | def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0., 127 | drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm, sr_ratio=1): 128 | super().__init__() 129 | self.norm1 = norm_layer(dim) 130 | self.attn = Attention( 131 | dim, 132 | num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, 133 | attn_drop=attn_drop, proj_drop=drop, sr_ratio=sr_ratio) 134 | # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here 135 | self.drop_path = DropPath(drop_path) if drop_path > 0. 
else nn.Identity() 136 | self.norm2 = norm_layer(dim) 137 | mlp_hidden_dim = int(dim * mlp_ratio) 138 | self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) 139 | 140 | self.apply(self._init_weights) 141 | 142 | def _init_weights(self, m): 143 | if isinstance(m, nn.Linear): 144 | trunc_normal_(m.weight, std=.02) 145 | if isinstance(m, nn.Linear) and m.bias is not None: 146 | nn.init.constant_(m.bias, 0) 147 | elif isinstance(m, nn.LayerNorm): 148 | nn.init.constant_(m.bias, 0) 149 | nn.init.constant_(m.weight, 1.0) 150 | elif isinstance(m, nn.Conv2d): 151 | fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 152 | fan_out //= m.groups 153 | m.weight.data.normal_(0, math.sqrt(2.0 / fan_out)) 154 | if m.bias is not None: 155 | m.bias.data.zero_() 156 | 157 | def forward(self, x, H, W): 158 | x = x + self.drop_path(self.attn(self.norm1(x), H, W)) 159 | x = x + self.drop_path(self.mlp(self.norm2(x), H, W)) 160 | 161 | return x 162 | 163 | 164 | class OverlapPatchEmbed(nn.Module): 165 | """ Image to Patch Embedding 166 | """ 167 | 168 | def __init__(self, img_size=224, patch_size=7, stride=4, in_channels=3, embed_dim=768): 169 | super().__init__() 170 | img_size = to_2tuple(img_size) 171 | patch_size = to_2tuple(patch_size) 172 | 173 | self.img_size = img_size 174 | self.patch_size = patch_size 175 | self.H, self.W = img_size[0] // patch_size[0], img_size[1] // patch_size[1] 176 | self.num_patches = self.H * self.W 177 | self.proj = nn.Conv2d(in_channels, embed_dim, kernel_size=patch_size, stride=stride, 178 | padding=(patch_size[0] // 2, patch_size[1] // 2)) 179 | self.norm = nn.LayerNorm(embed_dim) 180 | 181 | self.apply(self._init_weights) 182 | 183 | def _init_weights(self, m): 184 | if isinstance(m, nn.Linear): 185 | trunc_normal_(m.weight, std=.02) 186 | if isinstance(m, nn.Linear) and m.bias is not None: 187 | nn.init.constant_(m.bias, 0) 188 | elif isinstance(m, nn.LayerNorm): 189 | nn.init.constant_(m.bias, 0) 190 | nn.init.constant_(m.weight, 1.0) 191 | elif isinstance(m, nn.Conv2d): 192 | fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 193 | fan_out //= m.groups 194 | m.weight.data.normal_(0, math.sqrt(2.0 / fan_out)) 195 | if m.bias is not None: 196 | m.bias.data.zero_() 197 | 198 | def forward(self, x): 199 | x = self.proj(x) 200 | _, _, H, W = x.shape 201 | x = x.flatten(2).transpose(1, 2) 202 | x = self.norm(x) 203 | 204 | return x, H, W 205 | 206 | 207 | class PyramidVisionTransformerImpr(nn.Module): 208 | def __init__(self, img_size=224, patch_size=16, in_channels=3, num_classes=1000, embed_dims=[64, 128, 256, 512], 209 | num_heads=[1, 2, 4, 8], mlp_ratios=[4, 4, 4, 4], qkv_bias=False, qk_scale=None, drop_rate=0., 210 | attn_drop_rate=0., drop_path_rate=0., norm_layer=nn.LayerNorm, 211 | depths=[3, 4, 6, 3], sr_ratios=[8, 4, 2, 1]): 212 | super().__init__() 213 | self.num_classes = num_classes 214 | self.depths = depths 215 | 216 | # patch_embed 217 | self.patch_embed1 = OverlapPatchEmbed(img_size=img_size, patch_size=7, stride=4, in_channels=in_channels, 218 | embed_dim=embed_dims[0]) 219 | self.patch_embed2 = OverlapPatchEmbed(img_size=img_size // 4, patch_size=3, stride=2, in_channels=embed_dims[0], 220 | embed_dim=embed_dims[1]) 221 | self.patch_embed3 = OverlapPatchEmbed(img_size=img_size // 8, patch_size=3, stride=2, in_channels=embed_dims[1], 222 | embed_dim=embed_dims[2]) 223 | self.patch_embed4 = OverlapPatchEmbed(img_size=img_size // 16, patch_size=3, stride=2, in_channels=embed_dims[2], 224 | 
embed_dim=embed_dims[3]) 225 | 226 | # transformer encoder 227 | dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] # stochastic depth decay rule 228 | cur = 0 229 | self.block1 = nn.ModuleList([Block( 230 | dim=embed_dims[0], num_heads=num_heads[0], mlp_ratio=mlp_ratios[0], qkv_bias=qkv_bias, qk_scale=qk_scale, 231 | drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[cur + i], norm_layer=norm_layer, 232 | sr_ratio=sr_ratios[0]) 233 | for i in range(depths[0])]) 234 | self.norm1 = norm_layer(embed_dims[0]) 235 | 236 | cur += depths[0] 237 | self.block2 = nn.ModuleList([Block( 238 | dim=embed_dims[1], num_heads=num_heads[1], mlp_ratio=mlp_ratios[1], qkv_bias=qkv_bias, qk_scale=qk_scale, 239 | drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[cur + i], norm_layer=norm_layer, 240 | sr_ratio=sr_ratios[1]) 241 | for i in range(depths[1])]) 242 | self.norm2 = norm_layer(embed_dims[1]) 243 | 244 | cur += depths[1] 245 | self.block3 = nn.ModuleList([Block( 246 | dim=embed_dims[2], num_heads=num_heads[2], mlp_ratio=mlp_ratios[2], qkv_bias=qkv_bias, qk_scale=qk_scale, 247 | drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[cur + i], norm_layer=norm_layer, 248 | sr_ratio=sr_ratios[2]) 249 | for i in range(depths[2])]) 250 | self.norm3 = norm_layer(embed_dims[2]) 251 | 252 | cur += depths[2] 253 | self.block4 = nn.ModuleList([Block( 254 | dim=embed_dims[3], num_heads=num_heads[3], mlp_ratio=mlp_ratios[3], qkv_bias=qkv_bias, qk_scale=qk_scale, 255 | drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[cur + i], norm_layer=norm_layer, 256 | sr_ratio=sr_ratios[3]) 257 | for i in range(depths[3])]) 258 | self.norm4 = norm_layer(embed_dims[3]) 259 | 260 | # classification head 261 | # self.head = nn.Linear(embed_dims[3], num_classes) if num_classes > 0 else nn.Identity() 262 | 263 | self.apply(self._init_weights) 264 | 265 | def _init_weights(self, m): 266 | if isinstance(m, nn.Linear): 267 | trunc_normal_(m.weight, std=.02) 268 | if isinstance(m, nn.Linear) and m.bias is not None: 269 | nn.init.constant_(m.bias, 0) 270 | elif isinstance(m, nn.LayerNorm): 271 | nn.init.constant_(m.bias, 0) 272 | nn.init.constant_(m.weight, 1.0) 273 | elif isinstance(m, nn.Conv2d): 274 | fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 275 | fan_out //= m.groups 276 | m.weight.data.normal_(0, math.sqrt(2.0 / fan_out)) 277 | if m.bias is not None: 278 | m.bias.data.zero_() 279 | 280 | def init_weights(self, pretrained=None): 281 | if isinstance(pretrained, str): 282 | logger = 1 283 | #load_checkpoint(self, pretrained, map_location='cpu', strict=False, logger=logger) 284 | 285 | def reset_drop_path(self, drop_path_rate): 286 | dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(self.depths))] 287 | cur = 0 288 | for i in range(self.depths[0]): 289 | self.block1[i].drop_path.drop_prob = dpr[cur + i] 290 | 291 | cur += self.depths[0] 292 | for i in range(self.depths[1]): 293 | self.block2[i].drop_path.drop_prob = dpr[cur + i] 294 | 295 | cur += self.depths[1] 296 | for i in range(self.depths[2]): 297 | self.block3[i].drop_path.drop_prob = dpr[cur + i] 298 | 299 | cur += self.depths[2] 300 | for i in range(self.depths[3]): 301 | self.block4[i].drop_path.drop_prob = dpr[cur + i] 302 | 303 | def freeze_patch_emb(self): 304 | self.patch_embed1.requires_grad = False 305 | 306 | @torch.jit.ignore 307 | def no_weight_decay(self): 308 | return {'pos_embed1', 'pos_embed2', 'pos_embed3', 'pos_embed4', 'cls_token'} # has pos_embed may be better 309 | 310 | def 
get_classifier(self): 311 | return self.head 312 | 313 | def reset_classifier(self, num_classes, global_pool=''): 314 | self.num_classes = num_classes 315 | self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity() 316 | 317 | def forward_features(self, x): 318 | B = x.shape[0] 319 | outs = [] 320 | 321 | # stage 1 322 | x, H, W = self.patch_embed1(x) 323 | for i, blk in enumerate(self.block1): 324 | x = blk(x, H, W) 325 | x = self.norm1(x) 326 | x = x.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous() 327 | outs.append(x) 328 | 329 | # stage 2 330 | x, H, W = self.patch_embed2(x) 331 | for i, blk in enumerate(self.block2): 332 | x = blk(x, H, W) 333 | x = self.norm2(x) 334 | x = x.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous() 335 | outs.append(x) 336 | 337 | # stage 3 338 | x, H, W = self.patch_embed3(x) 339 | for i, blk in enumerate(self.block3): 340 | x = blk(x, H, W) 341 | x = self.norm3(x) 342 | x = x.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous() 343 | outs.append(x) 344 | 345 | # stage 4 346 | x, H, W = self.patch_embed4(x) 347 | for i, blk in enumerate(self.block4): 348 | x = blk(x, H, W) 349 | x = self.norm4(x) 350 | x = x.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous() 351 | outs.append(x) 352 | 353 | return outs 354 | 355 | # return x.mean(dim=1) 356 | 357 | def forward(self, x): 358 | x = self.forward_features(x) 359 | # x = self.head(x) 360 | 361 | return x 362 | 363 | 364 | class DWConv(nn.Module): 365 | def __init__(self, dim=768): 366 | super(DWConv, self).__init__() 367 | self.dwconv = nn.Conv2d(dim, dim, 3, 1, 1, bias=True, groups=dim) 368 | 369 | def forward(self, x, H, W): 370 | B, N, C = x.shape 371 | x = x.transpose(1, 2).view(B, C, H, W).contiguous() 372 | x = self.dwconv(x) 373 | x = x.flatten(2).transpose(1, 2) 374 | 375 | return x 376 | 377 | 378 | def _conv_filter(state_dict, patch_size=16): 379 | """ convert patch embedding weight from manual patchify + linear proj to conv""" 380 | out_dict = {} 381 | for k, v in state_dict.items(): 382 | if 'patch_embed.proj.weight' in k: 383 | v = v.reshape((v.shape[0], 3, patch_size, patch_size)) 384 | out_dict[k] = v 385 | 386 | return out_dict 387 | 388 | 389 | ## @register_model 390 | class pvt_v2_b0(PyramidVisionTransformerImpr): 391 | def __init__(self, **kwargs): 392 | super(pvt_v2_b0, self).__init__( 393 | patch_size=4, embed_dims=[32, 64, 160, 256], num_heads=[1, 2, 5, 8], mlp_ratios=[8, 8, 4, 4], 394 | qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), depths=[2, 2, 2, 2], sr_ratios=[8, 4, 2, 1], 395 | drop_rate=0.0, drop_path_rate=0.1) 396 | 397 | 398 | 399 | ## @register_model 400 | class pvt_v2_b1(PyramidVisionTransformerImpr): 401 | def __init__(self, **kwargs): 402 | super(pvt_v2_b1, self).__init__( 403 | patch_size=4, embed_dims=[64, 128, 320, 512], num_heads=[1, 2, 5, 8], mlp_ratios=[8, 8, 4, 4], 404 | qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), depths=[2, 2, 2, 2], sr_ratios=[8, 4, 2, 1], 405 | drop_rate=0.0, drop_path_rate=0.1) 406 | 407 | ## @register_model 408 | class pvt_v2_b2(PyramidVisionTransformerImpr): 409 | def __init__(self, in_channels=3, **kwargs): 410 | super(pvt_v2_b2, self).__init__( 411 | patch_size=4, embed_dims=[64, 128, 320, 512], num_heads=[1, 2, 5, 8], mlp_ratios=[8, 8, 4, 4], 412 | qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), depths=[3, 4, 6, 3], sr_ratios=[8, 4, 2, 1], 413 | drop_rate=0.0, drop_path_rate=0.1, in_channels=in_channels) 414 | 415 | ## @register_model 416 | class 
pvt_v2_b3(PyramidVisionTransformerImpr): 417 | def __init__(self, **kwargs): 418 | super(pvt_v2_b3, self).__init__( 419 | patch_size=4, embed_dims=[64, 128, 320, 512], num_heads=[1, 2, 5, 8], mlp_ratios=[8, 8, 4, 4], 420 | qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), depths=[3, 4, 18, 3], sr_ratios=[8, 4, 2, 1], 421 | drop_rate=0.0, drop_path_rate=0.1) 422 | 423 | ## @register_model 424 | class pvt_v2_b4(PyramidVisionTransformerImpr): 425 | def __init__(self, **kwargs): 426 | super(pvt_v2_b4, self).__init__( 427 | patch_size=4, embed_dims=[64, 128, 320, 512], num_heads=[1, 2, 5, 8], mlp_ratios=[8, 8, 4, 4], 428 | qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), depths=[3, 8, 27, 3], sr_ratios=[8, 4, 2, 1], 429 | drop_rate=0.0, drop_path_rate=0.1) 430 | 431 | 432 | ## @register_model 433 | class pvt_v2_b5(PyramidVisionTransformerImpr): 434 | def __init__(self, **kwargs): 435 | super(pvt_v2_b5, self).__init__( 436 | patch_size=4, embed_dims=[64, 128, 320, 512], num_heads=[1, 2, 5, 8], mlp_ratios=[4, 4, 4, 4], 437 | qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), depths=[3, 6, 40, 3], sr_ratios=[8, 4, 2, 1], 438 | drop_rate=0.0, drop_path_rate=0.1) 439 | -------------------------------------------------------------------------------- /birefnet_old/models/birefnet.py: -------------------------------------------------------------------------------- 1 | # import torch 2 | # import torch.nn as nn 3 | from collections import OrderedDict 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | from torchvision.models import vgg16, vgg16_bn 8 | from torchvision.models import resnet50 9 | from kornia.filters import laplacian 10 | 11 | from ..config import Config 12 | from ..dataset import class_labels_TR_sorted 13 | from .backbones.build_backbone import build_backbone 14 | from .modules.decoder_blocks import BasicDecBlk, ResBlk, HierarAttDecBlk 15 | from .modules.lateral_blocks import BasicLatBlk 16 | from .modules.aspp import ASPP, ASPPDeformable 17 | from .modules.ing import * 18 | from .refinement.refiner import Refiner, RefinerPVTInChannels4, RefUNet 19 | from .refinement.stem_layer import StemLayer 20 | 21 | 22 | class BiRefNet(nn.Module): 23 | def __init__(self, bb_pretrained=True): 24 | super(BiRefNet, self).__init__() 25 | self.config = Config() 26 | self.epoch = 1 27 | self.bb = build_backbone(self.config.bb, pretrained=bb_pretrained) 28 | 29 | channels = self.config.lateral_channels_in_collection 30 | 31 | if self.config.auxiliary_classification: 32 | self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) 33 | self.cls_head = nn.Sequential( 34 | nn.Linear(channels[0], len(class_labels_TR_sorted)) 35 | ) 36 | 37 | if self.config.squeeze_block: 38 | self.squeeze_module = nn.Sequential(*[ 39 | eval(self.config.squeeze_block.split('_x')[0])(channels[0]+sum(self.config.cxt), channels[0]) 40 | for _ in range(eval(self.config.squeeze_block.split('_x')[1])) 41 | ]) 42 | 43 | self.decoder = Decoder(channels) 44 | 45 | if self.config.locate_head: 46 | self.locate_header = nn.ModuleList([ 47 | BasicDecBlk(channels[0], channels[-1]), 48 | nn.Sequential( 49 | nn.Conv2d(channels[-1], 1, 1, 1, 0), 50 | ) 51 | ]) 52 | 53 | if self.config.ender: 54 | self.dec_end = nn.Sequential( 55 | nn.Conv2d(1, 16, 3, 1, 1), 56 | nn.Conv2d(16, 1, 3, 1, 1), 57 | nn.ReLU(inplace=True), 58 | ) 59 | 60 | # refine patch-level segmentation 61 | if self.config.refine: 62 | if self.config.refine == 'itself': 63 | self.stem_layer = StemLayer(in_channels=3+1, inter_channels=48, 
out_channels=3) 64 | else: 65 | self.refiner = eval('{}({})'.format(self.config.refine, 'in_channels=3+1')) 66 | 67 | if self.config.freeze_bb: 68 | # Freeze the backbone... 69 | print(self.named_parameters()) 70 | for key, value in self.named_parameters(): 71 | if 'bb.' in key and 'refiner.' not in key: 72 | value.requires_grad = False 73 | 74 | def forward_enc(self, x): 75 | if self.config.bb in ['vgg16', 'vgg16bn', 'resnet50']: 76 | x1 = self.bb.conv1(x); x2 = self.bb.conv2(x1); x3 = self.bb.conv3(x2); x4 = self.bb.conv4(x3) 77 | else: 78 | x1, x2, x3, x4 = self.bb(x) 79 | if self.config.mul_scl_ipt == 'cat': 80 | B, C, H, W = x.shape 81 | x1_, x2_, x3_, x4_ = self.bb(F.interpolate(x, size=(H//2, W//2), mode='bilinear', align_corners=True)) 82 | x1 = torch.cat([x1, F.interpolate(x1_, size=x1.shape[2:], mode='bilinear', align_corners=True)], dim=1) 83 | x2 = torch.cat([x2, F.interpolate(x2_, size=x2.shape[2:], mode='bilinear', align_corners=True)], dim=1) 84 | x3 = torch.cat([x3, F.interpolate(x3_, size=x3.shape[2:], mode='bilinear', align_corners=True)], dim=1) 85 | x4 = torch.cat([x4, F.interpolate(x4_, size=x4.shape[2:], mode='bilinear', align_corners=True)], dim=1) 86 | elif self.config.mul_scl_ipt == 'add': 87 | B, C, H, W = x.shape 88 | x1_, x2_, x3_, x4_ = self.bb(F.interpolate(x, size=(H//2, W//2), mode='bilinear', align_corners=True)) 89 | x1 = x1 + F.interpolate(x1_, size=x1.shape[2:], mode='bilinear', align_corners=True) 90 | x2 = x2 + F.interpolate(x2_, size=x2.shape[2:], mode='bilinear', align_corners=True) 91 | x3 = x3 + F.interpolate(x3_, size=x3.shape[2:], mode='bilinear', align_corners=True) 92 | x4 = x4 + F.interpolate(x4_, size=x4.shape[2:], mode='bilinear', align_corners=True) 93 | class_preds = self.cls_head(self.avgpool(x4).view(x4.shape[0], -1)) if self.training and self.config.auxiliary_classification else None 94 | if self.config.cxt: 95 | x4 = torch.cat( 96 | ( 97 | *[ 98 | F.interpolate(x1, size=x4.shape[2:], mode='bilinear', align_corners=True), 99 | F.interpolate(x2, size=x4.shape[2:], mode='bilinear', align_corners=True), 100 | F.interpolate(x3, size=x4.shape[2:], mode='bilinear', align_corners=True), 101 | ][-len(self.config.cxt):], 102 | x4 103 | ), 104 | dim=1 105 | ) 106 | return (x1, x2, x3, x4), class_preds 107 | 108 | def forward_ori(self, x): 109 | ########## Encoder ########## 110 | (x1, x2, x3, x4), class_preds = self.forward_enc(x) 111 | if self.config.squeeze_block: 112 | x4 = self.squeeze_module(x4) 113 | ########## Decoder ########## 114 | features = [x, x1, x2, x3, x4] 115 | if self.config.out_ref: 116 | features.append(laplacian(torch.mean(x, dim=1).unsqueeze(1), kernel_size=5)) 117 | scaled_preds = self.decoder(features) 118 | return scaled_preds, class_preds 119 | 120 | def forward_ref(self, x, pred): 121 | # refine patch-level segmentation 122 | if pred.shape[2:] != x.shape[2:]: 123 | pred = F.interpolate(pred, size=x.shape[2:], mode='bilinear', align_corners=True) 124 | # pred = pred.sigmoid() 125 | if self.config.refine == 'itself': 126 | x = self.stem_layer(torch.cat([x, pred], dim=1)) 127 | scaled_preds, class_preds = self.forward_ori(x) 128 | else: 129 | scaled_preds = self.refiner([x, pred]) 130 | class_preds = None 131 | return scaled_preds, class_preds 132 | 133 | def forward_ref_end(self, x): 134 | # remove the grids of concatenated preds 135 | return self.dec_end(x) if self.config.ender else x 136 | 137 | 138 | def forward(self, x): 139 | scaled_preds, class_preds = self.forward_ori(x) 140 | class_preds_lst = [class_preds] 141 | 
return [scaled_preds, class_preds_lst] if self.training else scaled_preds 142 | 143 | 144 | class Decoder(nn.Module): 145 | def __init__(self, channels): 146 | super(Decoder, self).__init__() 147 | self.config = Config() 148 | DecoderBlock = eval(self.config.dec_blk) 149 | LateralBlock = eval(self.config.lat_blk) 150 | 151 | if self.config.dec_ipt: 152 | self.split = self.config.dec_ipt_split 153 | N_dec_ipt = 64 154 | DBlock = SimpleConvs 155 | ic = 64 156 | ipt_cha_opt = 1 157 | self.ipt_blk4 = DBlock(2**8*3 if self.split else 3, [N_dec_ipt, channels[0]//8][ipt_cha_opt], inter_channels=ic) 158 | self.ipt_blk3 = DBlock(2**6*3 if self.split else 3, [N_dec_ipt, channels[1]//8][ipt_cha_opt], inter_channels=ic) 159 | self.ipt_blk2 = DBlock(2**4*3 if self.split else 3, [N_dec_ipt, channels[2]//8][ipt_cha_opt], inter_channels=ic) 160 | self.ipt_blk1 = DBlock(2**0*3 if self.split else 3, [N_dec_ipt, channels[3]//8][ipt_cha_opt], inter_channels=ic) 161 | else: 162 | self.split = None 163 | 164 | self.decoder_block4 = DecoderBlock(channels[0], channels[1]) 165 | self.decoder_block3 = DecoderBlock(channels[1]+([N_dec_ipt, channels[0]//8][ipt_cha_opt] if self.config.dec_ipt else 0), channels[2]) 166 | self.decoder_block2 = DecoderBlock(channels[2]+([N_dec_ipt, channels[1]//8][ipt_cha_opt] if self.config.dec_ipt else 0), channels[3]) 167 | self.decoder_block1 = DecoderBlock(channels[3]+([N_dec_ipt, channels[2]//8][ipt_cha_opt] if self.config.dec_ipt else 0), channels[3]//2) 168 | self.conv_out1 = nn.Sequential(nn.Conv2d(channels[3]//2+([N_dec_ipt, channels[3]//8][ipt_cha_opt] if self.config.dec_ipt else 0), 1, 1, 1, 0)) 169 | 170 | self.lateral_block4 = LateralBlock(channels[1], channels[1]) 171 | self.lateral_block3 = LateralBlock(channels[2], channels[2]) 172 | self.lateral_block2 = LateralBlock(channels[3], channels[3]) 173 | 174 | if self.config.ms_supervision: 175 | self.conv_ms_spvn_4 = nn.Conv2d(channels[1], 1, 1, 1, 0) 176 | self.conv_ms_spvn_3 = nn.Conv2d(channels[2], 1, 1, 1, 0) 177 | self.conv_ms_spvn_2 = nn.Conv2d(channels[3], 1, 1, 1, 0) 178 | 179 | if self.config.out_ref: 180 | _N = 16 181 | # self.gdt_convs_4 = nn.Sequential(nn.Conv2d(channels[1], _N, 3, 1, 1), nn.BatchNorm2d(_N), nn.ReLU(inplace=True)) 182 | self.gdt_convs_3 = nn.Sequential(nn.Conv2d(channels[2], _N, 3, 1, 1), nn.BatchNorm2d(_N), nn.ReLU(inplace=True)) 183 | self.gdt_convs_2 = nn.Sequential(nn.Conv2d(channels[3], _N, 3, 1, 1), nn.BatchNorm2d(_N), nn.ReLU(inplace=True)) 184 | 185 | # self.gdt_convs_pred_4 = nn.Sequential(nn.Conv2d(_N, 1, 1, 1, 0)) 186 | self.gdt_convs_pred_3 = nn.Sequential(nn.Conv2d(_N, 1, 1, 1, 0)) 187 | self.gdt_convs_pred_2 = nn.Sequential(nn.Conv2d(_N, 1, 1, 1, 0)) 188 | 189 | # self.gdt_convs_attn_4 = nn.Sequential(nn.Conv2d(_N, 1, 1, 1, 0)) 190 | self.gdt_convs_attn_3 = nn.Sequential(nn.Conv2d(_N, 1, 1, 1, 0)) 191 | self.gdt_convs_attn_2 = nn.Sequential(nn.Conv2d(_N, 1, 1, 1, 0)) 192 | 193 | 194 | def get_patches_batch(self, x, p): 195 | _size_h, _size_w = p.shape[2:] 196 | patches_batch = [] 197 | for idx in range(x.shape[0]): 198 | columns_x = torch.split(x[idx], split_size_or_sections=_size_w, dim=-1) 199 | patches_x = [] 200 | for column_x in columns_x: 201 | patches_x += [p.unsqueeze(0) for p in torch.split(column_x, split_size_or_sections=_size_h, dim=-2)] 202 | patch_sample = torch.cat(patches_x, dim=1) 203 | patches_batch.append(patch_sample) 204 | return torch.cat(patches_batch, dim=0) 205 | 206 | def forward(self, features): 207 | if self.config.out_ref: 208 | outs_gdt_pred = [] 209 | 
outs_gdt_label = [] 210 | x, x1, x2, x3, x4, gdt_gt = features 211 | else: 212 | x, x1, x2, x3, x4 = features 213 | outs = [] 214 | p4 = self.decoder_block4(x4) 215 | m4 = self.conv_ms_spvn_4(p4) if self.config.ms_supervision else None 216 | _p4 = F.interpolate(p4, size=x3.shape[2:], mode='bilinear', align_corners=True) 217 | _p3 = _p4 + self.lateral_block4(x3) 218 | if self.config.dec_ipt: 219 | patches_batch = self.get_patches_batch(x, _p3) if self.split else x 220 | _p3 = torch.cat((_p3, self.ipt_blk4(F.interpolate(patches_batch, size=x3.shape[2:], mode='bilinear', align_corners=True))), 1) 221 | 222 | p3 = self.decoder_block3(_p3) 223 | m3 = self.conv_ms_spvn_3(p3) if self.config.ms_supervision else None 224 | if self.config.out_ref: 225 | # >> GT: 226 | # m3 --dilation--> m3_dia 227 | # G_3^gt * m3_dia --> G_3^m, which is the label of gradient 228 | m3_dia = m3 229 | gdt_label_main_3 = gdt_gt * F.interpolate(m3_dia, size=gdt_gt.shape[2:], mode='bilinear', align_corners=True) 230 | outs_gdt_label.append(gdt_label_main_3) 231 | # >> Pred: 232 | # p3 --conv--BN--> F_3^G, where F_3^G predicts the \hat{G_3} with xx 233 | # F_3^G --sigmoid--> A_3^G 234 | p3_gdt = self.gdt_convs_3(p3) 235 | gdt_pred_3 = self.gdt_convs_pred_3(p3_gdt) 236 | outs_gdt_pred.append(gdt_pred_3) 237 | gdt_attn_3 = self.gdt_convs_attn_3(p3_gdt).sigmoid() 238 | # >> Finally: 239 | # p3 = p3 * A_3^G 240 | p3 = p3 * gdt_attn_3 241 | _p3 = F.interpolate(p3, size=x2.shape[2:], mode='bilinear', align_corners=True) 242 | _p2 = _p3 + self.lateral_block3(x2) 243 | if self.config.dec_ipt: 244 | patches_batch = self.get_patches_batch(x, _p2) if self.split else x 245 | _p2 = torch.cat((_p2, self.ipt_blk3(F.interpolate(patches_batch, size=x2.shape[2:], mode='bilinear', align_corners=True))), 1) 246 | 247 | p2 = self.decoder_block2(_p2) 248 | m2 = self.conv_ms_spvn_2(p2) if self.config.ms_supervision else None 249 | if self.config.out_ref: 250 | # >> GT: 251 | m2_dia = m2 252 | gdt_label_main_2 = gdt_gt * F.interpolate(m2_dia, size=gdt_gt.shape[2:], mode='bilinear', align_corners=True) 253 | outs_gdt_label.append(gdt_label_main_2) 254 | # >> Pred: 255 | p2_gdt = self.gdt_convs_2(p2) 256 | gdt_pred_2 = self.gdt_convs_pred_2(p2_gdt) 257 | outs_gdt_pred.append(gdt_pred_2) 258 | gdt_attn_2 = self.gdt_convs_attn_2(p2_gdt).sigmoid() 259 | # >> Finally: 260 | p2 = p2 * gdt_attn_2 261 | _p2 = F.interpolate(p2, size=x1.shape[2:], mode='bilinear', align_corners=True) 262 | _p1 = _p2 + self.lateral_block2(x1) 263 | if self.config.dec_ipt: 264 | patches_batch = self.get_patches_batch(x, _p1) if self.split else x 265 | _p1 = torch.cat((_p1, self.ipt_blk2(F.interpolate(patches_batch, size=x1.shape[2:], mode='bilinear', align_corners=True))), 1) 266 | 267 | _p1 = self.decoder_block1(_p1) 268 | _p1 = F.interpolate(_p1, size=x.shape[2:], mode='bilinear', align_corners=True) 269 | if self.config.dec_ipt: 270 | patches_batch = self.get_patches_batch(x, _p1) if self.split else x 271 | _p1 = torch.cat((_p1, self.ipt_blk1(F.interpolate(patches_batch, size=x.shape[2:], mode='bilinear', align_corners=True))), 1) 272 | p1_out = self.conv_out1(_p1) 273 | 274 | if self.config.ms_supervision: 275 | outs.append(m4) 276 | outs.append(m3) 277 | outs.append(m2) 278 | outs.append(p1_out) 279 | return outs if not (self.config.out_ref and self.training) else ([outs_gdt_pred, outs_gdt_label], outs) 280 | 281 | 282 | class SimpleConvs(nn.Module): 283 | def __init__( 284 | self, in_channels: int, out_channels: int, inter_channels=64 285 | ) -> None: 286 | 
super().__init__() 287 | self.conv1 = nn.Conv2d(in_channels, inter_channels, 3, 1, 1) 288 | self.conv_out = nn.Conv2d(inter_channels, out_channels, 3, 1, 1) 289 | 290 | def forward(self, x): 291 | return self.conv_out(self.conv1(x)) 292 | -------------------------------------------------------------------------------- /birefnet_old/models/modules/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lldacing/ComfyUI_BiRefNet_ll/5443a2aa16cfbd98bb2f7dcc8bdcb70439e08529/birefnet_old/models/modules/__init__.py -------------------------------------------------------------------------------- /birefnet_old/models/modules/aspp.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from ..modules.deform_conv import DeformableConv2d 5 | from ...config import Config 6 | 7 | 8 | config = Config() 9 | 10 | 11 | class ASPPComplex(nn.Module): 12 | def __init__(self, in_channels=64, out_channels=None, output_stride=16): 13 | super(ASPPComplex, self).__init__() 14 | self.down_scale = 1 15 | if out_channels is None: 16 | out_channels = in_channels 17 | self.in_channelster = 256 // self.down_scale 18 | if output_stride == 16: 19 | dilations = [1, 6, 12, 18] 20 | elif output_stride == 8: 21 | dilations = [1, 12, 24, 36] 22 | else: 23 | raise NotImplementedError 24 | 25 | self.aspp1 = _ASPPModule(in_channels, self.in_channelster, 1, padding=0, dilation=dilations[0]) 26 | self.aspp2 = _ASPPModule(in_channels, self.in_channelster, 3, padding=dilations[1], dilation=dilations[1]) 27 | self.aspp3 = _ASPPModule(in_channels, self.in_channelster, 3, padding=dilations[2], dilation=dilations[2]) 28 | self.aspp4 = _ASPPModule(in_channels, self.in_channelster, 3, padding=dilations[3], dilation=dilations[3]) 29 | 30 | self.global_avg_pool = nn.Sequential(nn.AdaptiveAvgPool2d((1, 1)), 31 | nn.Conv2d(in_channels, self.in_channelster, 1, stride=1, bias=False), 32 | nn.BatchNorm2d(self.in_channelster) if config.batch_size > 1 else nn.Identity(), 33 | nn.ReLU(inplace=True)) 34 | self.conv1 = nn.Conv2d(self.in_channelster * 5, out_channels, 1, bias=False) 35 | self.bn1 = nn.BatchNorm2d(out_channels) 36 | self.relu = nn.ReLU(inplace=True) 37 | self.dropout = nn.Dropout(0.5) 38 | 39 | def forward(self, x): 40 | x1 = self.aspp1(x) 41 | x2 = self.aspp2(x) 42 | x3 = self.aspp3(x) 43 | x4 = self.aspp4(x) 44 | x5 = self.global_avg_pool(x) 45 | x5 = F.interpolate(x5, size=x1.size()[2:], mode='bilinear', align_corners=True) 46 | x = torch.cat((x1, x2, x3, x4, x5), dim=1) 47 | 48 | x = self.conv1(x) 49 | x = self.bn1(x) 50 | x = self.relu(x) 51 | 52 | return self.dropout(x) 53 | 54 | 55 | class _ASPPModule(nn.Module): 56 | def __init__(self, in_channels, planes, kernel_size, padding, dilation): 57 | super(_ASPPModule, self).__init__() 58 | self.atrous_conv = nn.Conv2d(in_channels, planes, kernel_size=kernel_size, 59 | stride=1, padding=padding, dilation=dilation, bias=False) 60 | self.bn = nn.BatchNorm2d(planes) 61 | self.relu = nn.ReLU(inplace=True) 62 | 63 | def forward(self, x): 64 | x = self.atrous_conv(x) 65 | x = self.bn(x) 66 | 67 | return self.relu(x) 68 | 69 | 70 | class ASPP(nn.Module): 71 | def __init__(self, in_channels=64, out_channels=None, output_stride=16): 72 | super(ASPP, self).__init__() 73 | self.down_scale = 1 74 | if out_channels is None: 75 | out_channels = in_channels 76 | self.in_channelster = 256 // self.down_scale 77 | if 
output_stride == 16: 78 | dilations = [1, 6, 12, 18] 79 | elif output_stride == 8: 80 | dilations = [1, 12, 24, 36] 81 | else: 82 | raise NotImplementedError 83 | 84 | self.aspp1 = _ASPPModule(in_channels, self.in_channelster, 1, padding=0, dilation=dilations[0]) 85 | self.aspp2 = _ASPPModule(in_channels, self.in_channelster, 3, padding=dilations[1], dilation=dilations[1]) 86 | self.aspp3 = _ASPPModule(in_channels, self.in_channelster, 3, padding=dilations[2], dilation=dilations[2]) 87 | self.aspp4 = _ASPPModule(in_channels, self.in_channelster, 3, padding=dilations[3], dilation=dilations[3]) 88 | 89 | self.global_avg_pool = nn.Sequential(nn.AdaptiveAvgPool2d((1, 1)), 90 | nn.Conv2d(in_channels, self.in_channelster, 1, stride=1, bias=False), 91 | nn.BatchNorm2d(self.in_channelster) if config.batch_size > 1 else nn.Identity(), 92 | nn.ReLU(inplace=True)) 93 | self.conv1 = nn.Conv2d(self.in_channelster * 5, out_channels, 1, bias=False) 94 | self.bn1 = nn.BatchNorm2d(out_channels) 95 | self.relu = nn.ReLU(inplace=True) 96 | self.dropout = nn.Dropout(0.5) 97 | 98 | def forward(self, x): 99 | x1 = self.aspp1(x) 100 | x2 = self.aspp2(x) 101 | x3 = self.aspp3(x) 102 | x4 = self.aspp4(x) 103 | x5 = self.global_avg_pool(x) 104 | x5 = F.interpolate(x5, size=x1.size()[2:], mode='bilinear', align_corners=True) 105 | x = torch.cat((x1, x2, x3, x4, x5), dim=1) 106 | 107 | x = self.conv1(x) 108 | x = self.bn1(x) 109 | x = self.relu(x) 110 | 111 | return self.dropout(x) 112 | 113 | 114 | ##################### Deformable 115 | class _ASPPModuleDeformable(nn.Module): 116 | def __init__(self, in_channels, planes, kernel_size, padding): 117 | super(_ASPPModuleDeformable, self).__init__() 118 | self.atrous_conv = DeformableConv2d(in_channels, planes, kernel_size=kernel_size, 119 | stride=1, padding=padding, bias=False) 120 | self.bn = nn.BatchNorm2d(planes) 121 | self.relu = nn.ReLU(inplace=True) 122 | 123 | def forward(self, x): 124 | x = self.atrous_conv(x) 125 | x = self.bn(x) 126 | 127 | return self.relu(x) 128 | 129 | 130 | class ASPPDeformable(nn.Module): 131 | def __init__(self, in_channels, out_channels=None, num_parallel_block=1): 132 | super(ASPPDeformable, self).__init__() 133 | self.down_scale = 1 134 | if out_channels is None: 135 | out_channels = in_channels 136 | self.in_channelster = 256 // self.down_scale 137 | 138 | self.aspp1 = _ASPPModuleDeformable(in_channels, self.in_channelster, 1, padding=0) 139 | self.aspp_deforms = nn.ModuleList([ 140 | _ASPPModuleDeformable(in_channels, self.in_channelster, 3, padding=1) for _ in range(num_parallel_block) 141 | ]) 142 | 143 | self.global_avg_pool = nn.Sequential(nn.AdaptiveAvgPool2d((1, 1)), 144 | nn.Conv2d(in_channels, self.in_channelster, 1, stride=1, bias=False), 145 | nn.BatchNorm2d(self.in_channelster) if config.batch_size > 1 else nn.Identity(), 146 | nn.ReLU(inplace=True)) 147 | self.conv1 = nn.Conv2d(self.in_channelster * (2 + len(self.aspp_deforms)), out_channels, 1, bias=False) 148 | self.bn1 = nn.BatchNorm2d(out_channels) 149 | self.relu = nn.ReLU(inplace=True) 150 | self.dropout = nn.Dropout(0.5) 151 | 152 | def forward(self, x): 153 | x1 = self.aspp1(x) 154 | x_aspp_deforms = [aspp_deform(x) for aspp_deform in self.aspp_deforms] 155 | x5 = self.global_avg_pool(x) 156 | x5 = F.interpolate(x5, size=x1.size()[2:], mode='bilinear', align_corners=True) 157 | x = torch.cat((x1, *x_aspp_deforms, x5), dim=1) 158 | 159 | x = self.conv1(x) 160 | x = self.bn1(x) 161 | x = self.relu(x) 162 | 163 | return self.dropout(x) 164 | 
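# Note: ASPPDeformable.conv1 fuses the 1x1 branch, the num_parallel_block deformable 3x3 branches and
# the global-pooled branch, i.e. (2 + len(self.aspp_deforms)) maps of self.in_channelster channels each,
# which matches the input width self.in_channelster * (2 + len(self.aspp_deforms)) defined above.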
-------------------------------------------------------------------------------- /birefnet_old/models/modules/attentions.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | from torch import nn 4 | from torch.nn import init 5 | 6 | 7 | class SEWeightModule(nn.Module): 8 | def __init__(self, channels, reduction=16): 9 | super(SEWeightModule, self).__init__() 10 | self.avg_pool = nn.AdaptiveAvgPool2d(1) 11 | self.fc1 = nn.Conv2d(channels, channels//reduction, kernel_size=1, padding=0) 12 | self.relu = nn.ReLU(inplace=True) 13 | self.fc2 = nn.Conv2d(channels//reduction, channels, kernel_size=1, padding=0) 14 | self.sigmoid = nn.Sigmoid() 15 | 16 | def forward(self, x): 17 | out = self.avg_pool(x) 18 | out = self.fc1(out) 19 | out = self.relu(out) 20 | out = self.fc2(out) 21 | weight = self.sigmoid(out) 22 | return weight 23 | 24 | 25 | class PSA(nn.Module): 26 | 27 | def __init__(self, in_channels, S=4, reduction=4): 28 | super().__init__() 29 | self.S = S 30 | 31 | _convs = [] 32 | for i in range(S): 33 | _convs.append(nn.Conv2d(in_channels//S, in_channels//S, kernel_size=2*(i+1)+1, padding=i+1)) 34 | self.convs = nn.ModuleList(_convs) 35 | 36 | self.se_block = SEWeightModule(in_channels//S, reduction=S*reduction) 37 | 38 | self.softmax = nn.Softmax(dim=1) 39 | 40 | def forward(self, x): 41 | b, c, h, w = x.size() 42 | 43 | # Step1: SPC module 44 | SPC_out = x.view(b, self.S, c//self.S, h, w) #bs,s,ci,h,w 45 | for idx, conv in enumerate(self.convs): 46 | SPC_out[:,idx,:,:,:] = conv(SPC_out[:,idx,:,:,:].clone()) 47 | 48 | # Step2: SE weight 49 | se_out=[] 50 | for idx in range(self.S): 51 | se_out.append(self.se_block(SPC_out[:, idx, :, :, :])) 52 | SE_out = torch.stack(se_out, dim=1) 53 | SE_out = SE_out.expand_as(SPC_out) 54 | 55 | # Step3: Softmax 56 | softmax_out = self.softmax(SE_out) 57 | 58 | # Step4: SPA 59 | PSA_out = SPC_out * softmax_out 60 | PSA_out = PSA_out.view(b, -1, h, w) 61 | 62 | return PSA_out 63 | 64 | 65 | class SGE(nn.Module): 66 | 67 | def __init__(self, groups): 68 | super().__init__() 69 | self.groups=groups 70 | self.avg_pool = nn.AdaptiveAvgPool2d(1) 71 | self.weight=nn.Parameter(torch.zeros(1,groups,1,1)) 72 | self.bias=nn.Parameter(torch.zeros(1,groups,1,1)) 73 | self.sig=nn.Sigmoid() 74 | 75 | def forward(self, x): 76 | b, c, h,w=x.shape 77 | x=x.view(b*self.groups,-1,h,w) #bs*g,dim//g,h,w 78 | xn=x*self.avg_pool(x) #bs*g,dim//g,h,w 79 | xn=xn.sum(dim=1,keepdim=True) #bs*g,1,h,w 80 | t=xn.view(b*self.groups,-1) #bs*g,h*w 81 | 82 | t=t-t.mean(dim=1,keepdim=True) #bs*g,h*w 83 | std=t.std(dim=1,keepdim=True)+1e-5 84 | t=t/std #bs*g,h*w 85 | t=t.view(b,self.groups,h,w) #bs,g,h*w 86 | 87 | t=t*self.weight+self.bias #bs,g,h*w 88 | t=t.view(b*self.groups,1,h,w) #bs*g,1,h*w 89 | x=x*self.sig(t) 90 | x=x.view(b,c,h,w) 91 | 92 | return x 93 | 94 | -------------------------------------------------------------------------------- /birefnet_old/models/modules/decoder_blocks.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from ..modules.aspp import ASPP, ASPPDeformable 4 | from ..modules.attentions import PSA, SGE 5 | from ...config import Config 6 | 7 | 8 | config = Config() 9 | 10 | 11 | class BasicDecBlk(nn.Module): 12 | def __init__(self, in_channels=64, out_channels=64, inter_channels=64): 13 | super(BasicDecBlk, self).__init__() 14 | inter_channels = in_channels // 4 if config.dec_channels_inter == 'adap' else 64 
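        # 'adap' shrinks this bottleneck to in_channels // 4, while the default 'fixed' setting of
        # Config.dec_channels_inter keeps a constant 64-channel bottleneck.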
15 | self.conv_in = nn.Conv2d(in_channels, inter_channels, 3, 1, padding=1) 16 | self.relu_in = nn.ReLU(inplace=True) 17 | if config.dec_att == 'ASPP': 18 | self.dec_att = ASPP(in_channels=inter_channels) 19 | elif config.dec_att == 'ASPPDeformable': 20 | self.dec_att = ASPPDeformable(in_channels=inter_channels) 21 | self.conv_out = nn.Conv2d(inter_channels, out_channels, 3, 1, padding=1) 22 | self.bn_in = nn.BatchNorm2d(inter_channels) 23 | self.bn_out = nn.BatchNorm2d(out_channels) 24 | 25 | def forward(self, x): 26 | x = self.conv_in(x) 27 | x = self.bn_in(x) 28 | x = self.relu_in(x) 29 | if hasattr(self, 'dec_att'): 30 | x = self.dec_att(x) 31 | x = self.conv_out(x) 32 | x = self.bn_out(x) 33 | return x 34 | 35 | 36 | class ResBlk(nn.Module): 37 | def __init__(self, in_channels=64, out_channels=None, inter_channels=64): 38 | super(ResBlk, self).__init__() 39 | if out_channels is None: 40 | out_channels = in_channels 41 | inter_channels = in_channels // 4 if config.dec_channels_inter == 'adap' else 64 42 | 43 | self.conv_in = nn.Conv2d(in_channels, inter_channels, 3, 1, padding=1) 44 | self.bn_in = nn.BatchNorm2d(inter_channels) 45 | self.relu_in = nn.ReLU(inplace=True) 46 | 47 | if config.dec_att == 'ASPP': 48 | self.dec_att = ASPP(in_channels=inter_channels) 49 | elif config.dec_att == 'ASPPDeformable': 50 | self.dec_att = ASPPDeformable(in_channels=inter_channels) 51 | 52 | self.conv_out = nn.Conv2d(inter_channels, out_channels, 3, 1, padding=1) 53 | self.bn_out = nn.BatchNorm2d(out_channels) 54 | 55 | self.conv_resi = nn.Conv2d(in_channels, out_channels, 1, 1, 0) 56 | 57 | def forward(self, x): 58 | _x = self.conv_resi(x) 59 | x = self.conv_in(x) 60 | x = self.bn_in(x) 61 | x = self.relu_in(x) 62 | if hasattr(self, 'dec_att'): 63 | x = self.dec_att(x) 64 | x = self.conv_out(x) 65 | x = self.bn_out(x) 66 | return x + _x 67 | 68 | 69 | class HierarAttDecBlk(nn.Module): 70 | def __init__(self, in_channels=64, out_channels=None, inter_channels=64): 71 | super(HierarAttDecBlk, self).__init__() 72 | if out_channels is None: 73 | out_channels = in_channels 74 | inter_channels = in_channels // 4 if config.dec_channels_inter == 'adap' else 64 75 | self.split_y = 8 # must be divided by channels of all intermediate features 76 | self.split_x = 8 77 | 78 | self.conv_in = nn.Conv2d(in_channels, inter_channels, 3, 1, 1) 79 | 80 | self.psa = PSA(inter_channels*self.split_y*self.split_x, S=config.batch_size) 81 | self.sge = SGE(groups=config.batch_size) 82 | 83 | if config.dec_att == 'ASPP': 84 | self.dec_att = ASPP(in_channels=inter_channels) 85 | elif config.dec_att == 'ASPPDeformable': 86 | self.dec_att = ASPPDeformable(in_channels=inter_channels) 87 | self.conv_out = nn.Conv2d(inter_channels, out_channels, 3, 1, 1) 88 | 89 | def forward(self, x): 90 | x = self.conv_in(x) 91 | N, C, H, W = x.shape 92 | x_patchs = x.reshape(N, -1, H//self.split_y, W//self.split_x) 93 | 94 | # Hierarchical attention: group attention X patch spatial attention 95 | x_patchs = self.psa(x_patchs) # Group Channel Attention -- each group is a single image 96 | x_patchs = self.sge(x_patchs) # Patch Spatial Attention 97 | x = x.reshape(N, C, H, W) 98 | if hasattr(self, 'dec_att'): 99 | x = self.dec_att(x) 100 | x = self.conv_out(x) 101 | return x 102 | -------------------------------------------------------------------------------- /birefnet_old/models/modules/deform_conv.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torchvision.ops 
import deform_conv2d 4 | 5 | 6 | class DeformableConv2d(nn.Module): 7 | def __init__(self, 8 | in_channels, 9 | out_channels, 10 | kernel_size=3, 11 | stride=1, 12 | padding=1, 13 | bias=False): 14 | 15 | super(DeformableConv2d, self).__init__() 16 | 17 | assert type(kernel_size) == tuple or type(kernel_size) == int 18 | 19 | kernel_size = kernel_size if type(kernel_size) == tuple else (kernel_size, kernel_size) 20 | self.stride = stride if type(stride) == tuple else (stride, stride) 21 | self.padding = padding 22 | 23 | self.offset_conv = nn.Conv2d(in_channels, 24 | 2 * kernel_size[0] * kernel_size[1], 25 | kernel_size=kernel_size, 26 | stride=stride, 27 | padding=self.padding, 28 | bias=True) 29 | 30 | nn.init.constant_(self.offset_conv.weight, 0.) 31 | nn.init.constant_(self.offset_conv.bias, 0.) 32 | 33 | self.modulator_conv = nn.Conv2d(in_channels, 34 | 1 * kernel_size[0] * kernel_size[1], 35 | kernel_size=kernel_size, 36 | stride=stride, 37 | padding=self.padding, 38 | bias=True) 39 | 40 | nn.init.constant_(self.modulator_conv.weight, 0.) 41 | nn.init.constant_(self.modulator_conv.bias, 0.) 42 | 43 | self.regular_conv = nn.Conv2d(in_channels, 44 | out_channels=out_channels, 45 | kernel_size=kernel_size, 46 | stride=stride, 47 | padding=self.padding, 48 | bias=bias) 49 | 50 | def forward(self, x): 51 | #h, w = x.shape[2:] 52 | #max_offset = max(h, w)/4. 53 | 54 | offset = self.offset_conv(x)#.clamp(-max_offset, max_offset) 55 | modulator = 2. * torch.sigmoid(self.modulator_conv(x)) 56 | 57 | x = deform_conv2d( 58 | input=x, 59 | offset=offset, 60 | weight=self.regular_conv.weight, 61 | bias=self.regular_conv.bias, 62 | padding=self.padding, 63 | mask=modulator, 64 | stride=self.stride, 65 | ) 66 | return x 67 | -------------------------------------------------------------------------------- /birefnet_old/models/modules/ing.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | from ..modules.mlp import MLPLayer 3 | 4 | 5 | class BlockA(nn.Module): 6 | def __init__(self, in_channels=64, out_channels=64, inter_channels=64, mlp_ratio=4.): 7 | super(BlockA, self).__init__() 8 | inter_channels = in_channels 9 | self.conv = nn.Conv2d(in_channels, inter_channels, 3, 1, 1) 10 | self.norm1 = nn.LayerNorm(inter_channels) 11 | self.ffn = MLPLayer(in_features=inter_channels, 12 | hidden_features=int(inter_channels * mlp_ratio), 13 | act_layer=nn.GELU, 14 | drop=0.) 
15 | self.norm2 = nn.LayerNorm(inter_channels) 16 | 17 | def forward(self, x): 18 | B, C, H, W = x.shape 19 | _x = self.conv(x) 20 | _x = _x.flatten(2).transpose(1, 2) 21 | _x = self.norm1(_x) 22 | x = x.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous() 23 | 24 | x = x + _x 25 | _x1 = self.ffn(x) 26 | _x1 = self.norm2(_x1) 27 | _x1 = _x1.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous() 28 | x = x + _x1 29 | return x -------------------------------------------------------------------------------- /birefnet_old/models/modules/lateral_blocks.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | from functools import partial 6 | 7 | from ...config import Config 8 | 9 | 10 | config = Config() 11 | 12 | 13 | class BasicLatBlk(nn.Module): 14 | def __init__(self, in_channels=64, out_channels=64, inter_channels=64): 15 | super(BasicLatBlk, self).__init__() 16 | inter_channels = in_channels // 4 if config.dec_channels_inter == 'adap' else 64 17 | self.conv = nn.Conv2d(in_channels, out_channels, 1, 1, 0) 18 | 19 | def forward(self, x): 20 | x = self.conv(x) 21 | return x 22 | -------------------------------------------------------------------------------- /birefnet_old/models/modules/mlp.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from functools import partial 4 | 5 | try: 6 | # version > 0.6.13 7 | from timm.layers import DropPath, to_2tuple, trunc_normal_ 8 | except Exception: 9 | from timm.models.layers import DropPath, to_2tuple, trunc_normal_ 10 | 11 | import math 12 | 13 | 14 | class MLPLayer(nn.Module): 15 | def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.): 16 | super().__init__() 17 | out_features = out_features or in_features 18 | hidden_features = hidden_features or in_features 19 | self.fc1 = nn.Linear(in_features, hidden_features) 20 | self.act = act_layer() 21 | self.fc2 = nn.Linear(hidden_features, out_features) 22 | self.drop = nn.Dropout(drop) 23 | 24 | def forward(self, x): 25 | x = self.fc1(x) 26 | x = self.act(x) 27 | x = self.drop(x) 28 | x = self.fc2(x) 29 | x = self.drop(x) 30 | return x 31 | 32 | 33 | class Attention(nn.Module): 34 | def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0., sr_ratio=1): 35 | super().__init__() 36 | assert dim % num_heads == 0, f"dim {dim} should be divided by num_heads {num_heads}." 
37 | 38 | self.dim = dim 39 | self.num_heads = num_heads 40 | head_dim = dim // num_heads 41 | self.scale = qk_scale or head_dim ** -0.5 42 | 43 | self.q = nn.Linear(dim, dim, bias=qkv_bias) 44 | self.kv = nn.Linear(dim, dim * 2, bias=qkv_bias) 45 | self.attn_drop = nn.Dropout(attn_drop) 46 | self.proj = nn.Linear(dim, dim) 47 | self.proj_drop = nn.Dropout(proj_drop) 48 | 49 | self.sr_ratio = sr_ratio 50 | if sr_ratio > 1: 51 | self.sr = nn.Conv2d(dim, dim, kernel_size=sr_ratio, stride=sr_ratio) 52 | self.norm = nn.LayerNorm(dim) 53 | 54 | def forward(self, x, H, W): 55 | B, N, C = x.shape 56 | q = self.q(x).reshape(B, N, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3) 57 | 58 | if self.sr_ratio > 1: 59 | x_ = x.permute(0, 2, 1).reshape(B, C, H, W) 60 | x_ = self.sr(x_).reshape(B, C, -1).permute(0, 2, 1) 61 | x_ = self.norm(x_) 62 | kv = self.kv(x_).reshape(B, -1, 2, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) 63 | else: 64 | kv = self.kv(x).reshape(B, -1, 2, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) 65 | k, v = kv[0], kv[1] 66 | 67 | attn = (q @ k.transpose(-2, -1)) * self.scale 68 | attn = attn.softmax(dim=-1) 69 | attn = self.attn_drop(attn) 70 | 71 | x = (attn @ v).transpose(1, 2).reshape(B, N, C) 72 | x = self.proj(x) 73 | x = self.proj_drop(x) 74 | return x 75 | 76 | 77 | class Block(nn.Module): 78 | def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0., 79 | drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm, sr_ratio=1): 80 | super().__init__() 81 | self.norm1 = norm_layer(dim) 82 | self.attn = Attention( 83 | dim, 84 | num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, 85 | attn_drop=attn_drop, proj_drop=drop, sr_ratio=sr_ratio) 86 | # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here 87 | self.drop_path = DropPath(drop_path) if drop_path > 0. 
else nn.Identity() 88 | self.norm2 = norm_layer(dim) 89 | mlp_hidden_dim = int(dim * mlp_ratio) 90 | self.mlp = MLPLayer(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) 91 | 92 | def forward(self, x, H, W): 93 | x = x + self.drop_path(self.attn(self.norm1(x), H, W)) 94 | x = x + self.drop_path(self.mlp(self.norm2(x), H, W)) 95 | return x 96 | 97 | 98 | class OverlapPatchEmbed(nn.Module): 99 | """ Image to Patch Embedding 100 | """ 101 | 102 | def __init__(self, img_size=224, patch_size=7, stride=4, in_channels=3, embed_dim=768): 103 | super().__init__() 104 | img_size = to_2tuple(img_size) 105 | patch_size = to_2tuple(patch_size) 106 | 107 | self.img_size = img_size 108 | self.patch_size = patch_size 109 | self.H, self.W = img_size[0] // patch_size[0], img_size[1] // patch_size[1] 110 | self.num_patches = self.H * self.W 111 | self.proj = nn.Conv2d(in_channels, embed_dim, kernel_size=patch_size, stride=stride, 112 | padding=(patch_size[0] // 2, patch_size[1] // 2)) 113 | self.norm = nn.LayerNorm(embed_dim) 114 | 115 | def forward(self, x): 116 | x = self.proj(x) 117 | _, _, H, W = x.shape 118 | x = x.flatten(2).transpose(1, 2) 119 | x = self.norm(x) 120 | return x, H, W 121 | 122 | -------------------------------------------------------------------------------- /birefnet_old/models/modules/utils.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | 3 | 4 | def build_act_layer(act_layer): 5 | if act_layer == 'ReLU': 6 | return nn.ReLU(inplace=True) 7 | elif act_layer == 'SiLU': 8 | return nn.SiLU(inplace=True) 9 | elif act_layer == 'GELU': 10 | return nn.GELU() 11 | 12 | raise NotImplementedError(f'build_act_layer does not support {act_layer}') 13 | 14 | 15 | def build_norm_layer(dim, 16 | norm_layer, 17 | in_format='channels_last', 18 | out_format='channels_last', 19 | eps=1e-6): 20 | layers = [] 21 | if norm_layer == 'BN': 22 | if in_format == 'channels_last': 23 | layers.append(to_channels_first()) 24 | layers.append(nn.BatchNorm2d(dim)) 25 | if out_format == 'channels_last': 26 | layers.append(to_channels_last()) 27 | elif norm_layer == 'LN': 28 | if in_format == 'channels_first': 29 | layers.append(to_channels_last()) 30 | layers.append(nn.LayerNorm(dim, eps=eps)) 31 | if out_format == 'channels_first': 32 | layers.append(to_channels_first()) 33 | else: 34 | raise NotImplementedError( 35 | f'build_norm_layer does not support {norm_layer}') 36 | return nn.Sequential(*layers) 37 | 38 | 39 | class to_channels_first(nn.Module): 40 | 41 | def __init__(self): 42 | super().__init__() 43 | 44 | def forward(self, x): 45 | return x.permute(0, 3, 1, 2) 46 | 47 | 48 | class to_channels_last(nn.Module): 49 | 50 | def __init__(self): 51 | super().__init__() 52 | 53 | def forward(self, x): 54 | return x.permute(0, 2, 3, 1) 55 | -------------------------------------------------------------------------------- /birefnet_old/models/refinement/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lldacing/ComfyUI_BiRefNet_ll/5443a2aa16cfbd98bb2f7dcc8bdcb70439e08529/birefnet_old/models/refinement/__init__.py -------------------------------------------------------------------------------- /birefnet_old/models/refinement/refiner.py: -------------------------------------------------------------------------------- 1 | # import torch 2 | # import torch.nn as nn 3 | # from collections import OrderedDict 4 | import torch 5 | import torch.nn as nn 6 
| import torch.nn.functional as F 7 | # from torchvision.models import vgg16, vgg16_bn 8 | # from torchvision.models import resnet50 9 | 10 | from birefnet_old.config import Config 11 | from birefnet_old.dataset import class_labels_TR_sorted 12 | from birefnet_old.models.backbones.build_backbone import build_backbone 13 | from birefnet_old.models.modules.decoder_blocks import BasicDecBlk 14 | from birefnet_old.models.modules.lateral_blocks import BasicLatBlk 15 | from birefnet_old.models.modules.ing import * 16 | from birefnet_old.models.refinement.stem_layer import StemLayer 17 | 18 | 19 | class RefinerPVTInChannels4(nn.Module): 20 | def __init__(self, in_channels=3+1): 21 | super(RefinerPVTInChannels4, self).__init__() 22 | self.config = Config() 23 | self.epoch = 1 24 | self.bb = build_backbone(self.config.bb, params_settings='in_channels=4') 25 | 26 | lateral_channels_in_collection = { 27 | 'vgg16': [512, 256, 128, 64], 'vgg16bn': [512, 256, 128, 64], 'resnet50': [1024, 512, 256, 64], 28 | 'pvt_v2_b2': [512, 320, 128, 64], 'pvt_v2_b5': [512, 320, 128, 64], 29 | 'swin_v1_b': [1024, 512, 256, 128], 'swin_v1_l': [1536, 768, 384, 192], 30 | } 31 | channels = lateral_channels_in_collection[self.config.bb] 32 | self.squeeze_module = BasicDecBlk(channels[0], channels[0]) 33 | 34 | self.decoder = Decoder(channels) 35 | 36 | if 0: 37 | for key, value in self.named_parameters(): 38 | if 'bb.' in key: 39 | value.requires_grad = False 40 | 41 | def forward(self, x): 42 | if isinstance(x, list): 43 | x = torch.cat(x, dim=1) 44 | ########## Encoder ########## 45 | if self.config.bb in ['vgg16', 'vgg16bn', 'resnet50']: 46 | x1 = self.bb.conv1(x) 47 | x2 = self.bb.conv2(x1) 48 | x3 = self.bb.conv3(x2) 49 | x4 = self.bb.conv4(x3) 50 | else: 51 | x1, x2, x3, x4 = self.bb(x) 52 | 53 | x4 = self.squeeze_module(x4) 54 | 55 | ########## Decoder ########## 56 | 57 | features = [x, x1, x2, x3, x4] 58 | scaled_preds = self.decoder(features) 59 | 60 | return scaled_preds 61 | 62 | 63 | class Refiner(nn.Module): 64 | def __init__(self, in_channels=3+1): 65 | super(Refiner, self).__init__() 66 | self.config = Config() 67 | self.epoch = 1 68 | self.stem_layer = StemLayer(in_channels=in_channels, inter_channels=48, out_channels=3) 69 | self.bb = build_backbone(self.config.bb) 70 | 71 | lateral_channels_in_collection = { 72 | 'vgg16': [512, 256, 128, 64], 'vgg16bn': [512, 256, 128, 64], 'resnet50': [1024, 512, 256, 64], 73 | 'pvt_v2_b2': [512, 320, 128, 64], 'pvt_v2_b5': [512, 320, 128, 64], 74 | 'swin_v1_b': [1024, 512, 256, 128], 'swin_v1_l': [1536, 768, 384, 192], 75 | } 76 | channels = lateral_channels_in_collection[self.config.bb] 77 | self.squeeze_module = BasicDecBlk(channels[0], channels[0]) 78 | 79 | self.decoder = Decoder(channels) 80 | 81 | if 0: 82 | for key, value in self.named_parameters(): 83 | if 'bb.' 
in key: 84 | value.requires_grad = False 85 | 86 | def forward(self, x): 87 | if isinstance(x, list): 88 | x = torch.cat(x, dim=1) 89 | x = self.stem_layer(x) 90 | ########## Encoder ########## 91 | if self.config.bb in ['vgg16', 'vgg16bn', 'resnet50']: 92 | x1 = self.bb.conv1(x) 93 | x2 = self.bb.conv2(x1) 94 | x3 = self.bb.conv3(x2) 95 | x4 = self.bb.conv4(x3) 96 | else: 97 | x1, x2, x3, x4 = self.bb(x) 98 | 99 | x4 = self.squeeze_module(x4) 100 | 101 | ########## Decoder ########## 102 | 103 | features = [x, x1, x2, x3, x4] 104 | scaled_preds = self.decoder(features) 105 | 106 | return scaled_preds 107 | 108 | 109 | class Decoder(nn.Module): 110 | def __init__(self, channels): 111 | super(Decoder, self).__init__() 112 | self.config = Config() 113 | DecoderBlock = eval('BasicDecBlk') 114 | LateralBlock = eval('BasicLatBlk') 115 | 116 | self.decoder_block4 = DecoderBlock(channels[0], channels[1]) 117 | self.decoder_block3 = DecoderBlock(channels[1], channels[2]) 118 | self.decoder_block2 = DecoderBlock(channels[2], channels[3]) 119 | self.decoder_block1 = DecoderBlock(channels[3], channels[3]//2) 120 | 121 | self.lateral_block4 = LateralBlock(channels[1], channels[1]) 122 | self.lateral_block3 = LateralBlock(channels[2], channels[2]) 123 | self.lateral_block2 = LateralBlock(channels[3], channels[3]) 124 | 125 | if self.config.ms_supervision: 126 | self.conv_ms_spvn_4 = nn.Conv2d(channels[1], 1, 1, 1, 0) 127 | self.conv_ms_spvn_3 = nn.Conv2d(channels[2], 1, 1, 1, 0) 128 | self.conv_ms_spvn_2 = nn.Conv2d(channels[3], 1, 1, 1, 0) 129 | self.conv_out1 = nn.Sequential(nn.Conv2d(channels[3]//2, 1, 1, 1, 0)) 130 | 131 | def forward(self, features): 132 | x, x1, x2, x3, x4 = features 133 | outs = [] 134 | p4 = self.decoder_block4(x4) 135 | _p4 = F.interpolate(p4, size=x3.shape[2:], mode='bilinear', align_corners=True) 136 | _p3 = _p4 + self.lateral_block4(x3) 137 | 138 | p3 = self.decoder_block3(_p3) 139 | _p3 = F.interpolate(p3, size=x2.shape[2:], mode='bilinear', align_corners=True) 140 | _p2 = _p3 + self.lateral_block3(x2) 141 | 142 | p2 = self.decoder_block2(_p2) 143 | _p2 = F.interpolate(p2, size=x1.shape[2:], mode='bilinear', align_corners=True) 144 | _p1 = _p2 + self.lateral_block2(x1) 145 | 146 | _p1 = self.decoder_block1(_p1) 147 | _p1 = F.interpolate(_p1, size=x.shape[2:], mode='bilinear', align_corners=True) 148 | p1_out = self.conv_out1(_p1) 149 | 150 | if self.config.ms_supervision: 151 | outs.append(self.conv_ms_spvn_4(p4)) 152 | outs.append(self.conv_ms_spvn_3(p3)) 153 | outs.append(self.conv_ms_spvn_2(p2)) 154 | outs.append(p1_out) 155 | return outs 156 | 157 | 158 | class RefUNet(nn.Module): 159 | # Refinement 160 | def __init__(self, in_channels=3+1): 161 | super(RefUNet, self).__init__() 162 | self.encoder_1 = nn.Sequential( 163 | nn.Conv2d(in_channels, 64, 3, 1, 1), 164 | nn.Conv2d(64, 64, 3, 1, 1), 165 | nn.BatchNorm2d(64), 166 | nn.ReLU(inplace=True) 167 | ) 168 | 169 | self.encoder_2 = nn.Sequential( 170 | nn.MaxPool2d(2, 2, ceil_mode=True), 171 | nn.Conv2d(64, 64, 3, 1, 1), 172 | nn.BatchNorm2d(64), 173 | nn.ReLU(inplace=True) 174 | ) 175 | 176 | self.encoder_3 = nn.Sequential( 177 | nn.MaxPool2d(2, 2, ceil_mode=True), 178 | nn.Conv2d(64, 64, 3, 1, 1), 179 | nn.BatchNorm2d(64), 180 | nn.ReLU(inplace=True) 181 | ) 182 | 183 | self.encoder_4 = nn.Sequential( 184 | nn.MaxPool2d(2, 2, ceil_mode=True), 185 | nn.Conv2d(64, 64, 3, 1, 1), 186 | nn.BatchNorm2d(64), 187 | nn.ReLU(inplace=True) 188 | ) 189 | 190 | self.pool4 = nn.MaxPool2d(2, 2, ceil_mode=True) 191 | ##### 192 | 
self.decoder_5 = nn.Sequential( 193 | nn.Conv2d(64, 64, 3, 1, 1), 194 | nn.BatchNorm2d(64), 195 | nn.ReLU(inplace=True) 196 | ) 197 | ##### 198 | self.decoder_4 = nn.Sequential( 199 | nn.Conv2d(128, 64, 3, 1, 1), 200 | nn.BatchNorm2d(64), 201 | nn.ReLU(inplace=True) 202 | ) 203 | 204 | self.decoder_3 = nn.Sequential( 205 | nn.Conv2d(128, 64, 3, 1, 1), 206 | nn.BatchNorm2d(64), 207 | nn.ReLU(inplace=True) 208 | ) 209 | 210 | self.decoder_2 = nn.Sequential( 211 | nn.Conv2d(128, 64, 3, 1, 1), 212 | nn.BatchNorm2d(64), 213 | nn.ReLU(inplace=True) 214 | ) 215 | 216 | self.decoder_1 = nn.Sequential( 217 | nn.Conv2d(128, 64, 3, 1, 1), 218 | nn.BatchNorm2d(64), 219 | nn.ReLU(inplace=True) 220 | ) 221 | 222 | self.conv_d0 = nn.Conv2d(64, 1, 3, 1, 1) 223 | 224 | self.upscore2 = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True) 225 | 226 | def forward(self, x): 227 | outs = [] 228 | if isinstance(x, list): 229 | x = torch.cat(x, dim=1) 230 | hx = x 231 | 232 | hx1 = self.encoder_1(hx) 233 | hx2 = self.encoder_2(hx1) 234 | hx3 = self.encoder_3(hx2) 235 | hx4 = self.encoder_4(hx3) 236 | 237 | hx = self.decoder_5(self.pool4(hx4)) 238 | hx = torch.cat((self.upscore2(hx), hx4), 1) 239 | 240 | d4 = self.decoder_4(hx) 241 | hx = torch.cat((self.upscore2(d4), hx3), 1) 242 | 243 | d3 = self.decoder_3(hx) 244 | hx = torch.cat((self.upscore2(d3), hx2), 1) 245 | 246 | d2 = self.decoder_2(hx) 247 | hx = torch.cat((self.upscore2(d2), hx1), 1) 248 | 249 | d1 = self.decoder_1(hx) 250 | 251 | x = self.conv_d0(d1) 252 | outs.append(x) 253 | return outs 254 | -------------------------------------------------------------------------------- /birefnet_old/models/refinement/stem_layer.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | from birefnet_old.models.modules.utils import build_act_layer, build_norm_layer 3 | 4 | 5 | class StemLayer(nn.Module): 6 | r""" Stem layer of InternImage 7 | Args: 8 | in_channels (int): number of input channels 9 | out_channels (int): number of output channels 10 | act_layer (str): activation layer 11 | norm_layer (str): normalization layer 12 | """ 13 | 14 | def __init__(self, 15 | in_channels=3+1, 16 | inter_channels=48, 17 | out_channels=96, 18 | act_layer='GELU', 19 | norm_layer='BN'): 20 | super().__init__() 21 | self.conv1 = nn.Conv2d(in_channels, 22 | inter_channels, 23 | kernel_size=3, 24 | stride=1, 25 | padding=1) 26 | self.norm1 = build_norm_layer( 27 | inter_channels, norm_layer, 'channels_first', 'channels_first' 28 | ) 29 | self.act = build_act_layer(act_layer) 30 | self.conv2 = nn.Conv2d(inter_channels, 31 | out_channels, 32 | kernel_size=3, 33 | stride=1, 34 | padding=1) 35 | self.norm2 = build_norm_layer( 36 | out_channels, norm_layer, 'channels_first', 'channels_first' 37 | ) 38 | 39 | def forward(self, x): 40 | x = self.conv1(x) 41 | x = self.norm1(x) 42 | x = self.act(x) 43 | x = self.conv2(x) 44 | x = self.norm2(x) 45 | return x 46 | -------------------------------------------------------------------------------- /birefnet_old/preproc.py: -------------------------------------------------------------------------------- 1 | from PIL import Image, ImageEnhance 2 | import random 3 | import numpy as np 4 | import random 5 | 6 | 7 | def preproc(image, label, preproc_methods=['flip']): 8 | if 'flip' in preproc_methods: 9 | image, label = cv_random_flip(image, label) 10 | if 'crop' in preproc_methods: 11 | image, label = random_crop(image, label) 12 | if 'rotate' in preproc_methods: 13 | image, label 
= random_rotate(image, label) 14 | if 'enhance' in preproc_methods: 15 | image = color_enhance(image) 16 | if 'pepper' in preproc_methods: 17 | label = random_pepper(label) 18 | return image, label 19 | 20 | 21 | def cv_random_flip(img, label): 22 | if random.random() > 0.5: 23 | img = img.transpose(Image.FLIP_LEFT_RIGHT) 24 | label = label.transpose(Image.FLIP_LEFT_RIGHT) 25 | return img, label 26 | 27 | 28 | def random_crop(image, label): 29 | border = 30 30 | image_width = image.size[0] 31 | image_height = image.size[1] 32 | border = int(min(image_width, image_height) * 0.1) 33 | crop_win_width = np.random.randint(image_width - border, image_width) 34 | crop_win_height = np.random.randint(image_height - border, image_height) 35 | random_region = ( 36 | (image_width - crop_win_width) >> 1, (image_height - crop_win_height) >> 1, (image_width + crop_win_width) >> 1, 37 | (image_height + crop_win_height) >> 1) 38 | return image.crop(random_region), label.crop(random_region) 39 | 40 | 41 | def random_rotate(image, label, angle=15): 42 | mode = Image.BICUBIC 43 | if random.random() > 0.8: 44 | random_angle = np.random.randint(-angle, angle) 45 | image = image.rotate(random_angle, mode) 46 | label = label.rotate(random_angle, mode) 47 | return image, label 48 | 49 | 50 | def color_enhance(image): 51 | bright_intensity = random.randint(5, 15) / 10.0 52 | image = ImageEnhance.Brightness(image).enhance(bright_intensity) 53 | contrast_intensity = random.randint(5, 15) / 10.0 54 | image = ImageEnhance.Contrast(image).enhance(contrast_intensity) 55 | color_intensity = random.randint(0, 20) / 10.0 56 | image = ImageEnhance.Color(image).enhance(color_intensity) 57 | sharp_intensity = random.randint(0, 30) / 10.0 58 | image = ImageEnhance.Sharpness(image).enhance(sharp_intensity) 59 | return image 60 | 61 | 62 | def random_gaussian(image, mean=0.1, sigma=0.35): 63 | def gaussianNoisy(im, mean=mean, sigma=sigma): 64 | for _i in range(len(im)): 65 | im[_i] += random.gauss(mean, sigma) 66 | return im 67 | 68 | img = np.asarray(image) 69 | width, height = img.shape 70 | img = gaussianNoisy(img[:].flatten(), mean, sigma) 71 | img = img.reshape([width, height]) 72 | return Image.fromarray(np.uint8(img)) 73 | 74 | 75 | def random_pepper(img, N=0.0015): 76 | img = np.array(img) 77 | noiseNum = int(N * img.shape[0] * img.shape[1]) 78 | for i in range(noiseNum): 79 | randX = random.randint(0, img.shape[0] - 1) 80 | randY = random.randint(0, img.shape[1] - 1) 81 | if random.randint(0, 1) == 0: 82 | img[randX, randY] = 0 83 | else: 84 | img[randX, randY] = 255 85 | return Image.fromarray(img) 86 | -------------------------------------------------------------------------------- /birefnet_old/utils.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import os 3 | import torch 4 | from torchvision import transforms 5 | import numpy as np 6 | import random 7 | import cv2 8 | from PIL import Image 9 | 10 | 11 | def path_to_image(path, size=(1024, 1024), color_type=['rgb', 'gray'][0]): 12 | if color_type.lower() == 'rgb': 13 | image = cv2.imread(path) 14 | elif color_type.lower() == 'gray': 15 | image = cv2.imread(path, cv2.IMREAD_GRAYSCALE) 16 | else: 17 | print('Select the color_type to return, either to RGB or gray image.') 18 | return 19 | if size: 20 | image = cv2.resize(image, size, interpolation=cv2.INTER_LINEAR) 21 | if color_type.lower() == 'rgb': 22 | image = Image.fromarray(cv2.cvtColor(image, cv2.COLOR_BGR2RGB)).convert('RGB') 23 | else: 24 | image = 
Image.fromarray(image).convert('L') 25 | return image 26 | 27 | 28 | 29 | def check_state_dict(state_dict, unwanted_prefix='_orig_mod.'): 30 | for k, v in list(state_dict.items()): 31 | if k.startswith(unwanted_prefix): 32 | state_dict[k[len(unwanted_prefix):]] = state_dict.pop(k) 33 | return state_dict 34 | 35 | 36 | def generate_smoothed_gt(gts): 37 | epsilon = 0.001 38 | new_gts = (1-epsilon)*gts+epsilon/2 39 | return new_gts 40 | 41 | 42 | class Logger(): 43 | def __init__(self, path="log.txt"): 44 | self.logger = logging.getLogger('BiRefNet') 45 | self.file_handler = logging.FileHandler(path, "w") 46 | self.stdout_handler = logging.StreamHandler() 47 | self.stdout_handler.setFormatter(logging.Formatter('%(asctime)s %(levelname)s %(message)s')) 48 | self.file_handler.setFormatter(logging.Formatter('%(asctime)s %(levelname)s %(message)s')) 49 | self.logger.addHandler(self.file_handler) 50 | self.logger.addHandler(self.stdout_handler) 51 | self.logger.setLevel(logging.INFO) 52 | self.logger.propagate = False 53 | 54 | def info(self, txt): 55 | self.logger.info(txt) 56 | 57 | def close(self): 58 | self.file_handler.close() 59 | self.stdout_handler.close() 60 | 61 | 62 | class AverageMeter(object): 63 | """Computes and stores the average and current value""" 64 | def __init__(self): 65 | self.reset() 66 | 67 | def reset(self): 68 | self.val = 0.0 69 | self.avg = 0.0 70 | self.sum = 0.0 71 | self.count = 0.0 72 | 73 | def update(self, val, n=1): 74 | self.val = val 75 | self.sum += val * n 76 | self.count += n 77 | self.avg = self.sum / self.count 78 | 79 | 80 | def save_checkpoint(state, path, filename="latest.pth"): 81 | torch.save(state, os.path.join(path, filename)) 82 | 83 | 84 | def save_tensor_img(tenor_im, path): 85 | im = tenor_im.cpu().clone() 86 | im = im.squeeze(0) 87 | tensor2pil = transforms.ToPILImage() 88 | im = tensor2pil(im) 89 | im.save(path) 90 | 91 | 92 | def set_seed(seed): 93 | torch.manual_seed(seed) 94 | torch.cuda.manual_seed_all(seed) 95 | np.random.seed(seed) 96 | random.seed(seed) 97 | torch.backends.cudnn.deterministic = True -------------------------------------------------------------------------------- /doc/base.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lldacing/ComfyUI_BiRefNet_ll/5443a2aa16cfbd98bb2f7dcc8bdcb70439e08529/doc/base.png -------------------------------------------------------------------------------- /doc/video.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lldacing/ComfyUI_BiRefNet_ll/5443a2aa16cfbd98bb2f7dcc8bdcb70439e08529/doc/video.gif -------------------------------------------------------------------------------- /example/workflow_base.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lldacing/ComfyUI_BiRefNet_ll/5443a2aa16cfbd98bb2f7dcc8bdcb70439e08529/example/workflow_base.png -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [project] 2 | name = "comfyui_birefnet_ll" 3 | description = "Sync with version of BiRefNet. NODES:AutoDownloadBiRefNetModel, LoadRembgByBiRefNetModel, RembgByBiRefNet, RembgByBiRefNetAdvanced, GetMaskByBiRefNet, BlurFusionForegroundEstimation." 
4 | version = "1.1.4" 5 | license = {file = "LICENSE"} 6 | dependencies = ["numpy", "opencv-python", "timm"] 7 | 8 | [project.urls] 9 | Repository = "https://github.com/lldacing/ComfyUI_BiRefNet_ll" 10 | # Used by Comfy Registry https://comfyregistry.org 11 | 12 | [tool.comfy] 13 | PublisherId = "lldacing" 14 | DisplayName = "ComfyUI_BiRefNet_ll" 15 | Icon = "" 16 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | numpy 2 | opencv-python 3 | timm -------------------------------------------------------------------------------- /util.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | from PIL import Image 4 | import torchvision.transforms.v2 as T 5 | import cv2 6 | 7 | 8 | def tensor_to_pil(image): 9 | return Image.fromarray(np.clip(255. * image.cpu().numpy().squeeze(), 0, 255).astype(np.uint8)) 10 | 11 | 12 | def pil_to_tensor(image): 13 | return torch.from_numpy(np.array(image).astype(np.float32) / 255.0).unsqueeze(0) 14 | 15 | 16 | def refine_foreground(image_tensor, mask_tensor, r1=90, r2=7): 17 | if r1 % 2 == 0: 18 | r1 += 1 19 | 20 | if r2 % 2 == 0: 21 | r2 += 1 22 | 23 | return FB_blur_fusion_foreground_estimator_2(image_tensor, mask_tensor, r1=r1, r2=r2)[0] 24 | 25 | 26 | def FB_blur_fusion_foreground_estimator_2(image_tensor, alpha_tensor, r1=90, r2=7): 27 | # https://github.com/Photoroom/fast-foreground-estimation 28 | if alpha_tensor.dim() == 3: 29 | alpha_tensor = alpha_tensor.unsqueeze(0) # Add batch 30 | F, blur_B = FB_blur_fusion_foreground_estimator(image_tensor, image_tensor, image_tensor, alpha_tensor, r=r1) 31 | return FB_blur_fusion_foreground_estimator(image_tensor, F, blur_B, alpha_tensor, r=r2) 32 | 33 | 34 | def FB_blur_fusion_foreground_estimator(image_tensor, F_tensor, B_tensor, alpha_tensor, r=90): 35 | if image_tensor.dim() == 3: 36 | image_tensor = image_tensor.unsqueeze(0) 37 | 38 | blurred_alpha = T.functional.gaussian_blur(alpha_tensor, r) 39 | 40 | blurred_FA = T.functional.gaussian_blur(F_tensor * alpha_tensor, r) 41 | blurred_F = blurred_FA / (blurred_alpha + 1e-5) 42 | 43 | blurred_B1A = T.functional.gaussian_blur(B_tensor * (1 - alpha_tensor), r) 44 | blurred_B = blurred_B1A / ((1 - blurred_alpha) + 1e-5) 45 | F_tensor = blurred_F + alpha_tensor * (image_tensor - alpha_tensor * blurred_F - (1 - alpha_tensor) * blurred_B) 46 | F_tensor = torch.clamp(F_tensor, 0, 1) 47 | return F_tensor, blurred_B 48 | 49 | 50 | ### copied and modified image_proc.py 51 | def refine_foreground_pil(image, mask, r1=90, r2=6): 52 | if mask.size != image.size: 53 | mask = mask.resize(image.size) 54 | image = np.array(image) / 255.0 55 | mask = np.array(mask) / 255.0 56 | estimated_foreground = FB_blur_fusion_foreground_estimator_pil_2(image, mask, r1=r1, r2=r2) 57 | image_masked = Image.fromarray((estimated_foreground * 255.0).astype(np.uint8)) 58 | return image_masked 59 | 60 | 61 | def FB_blur_fusion_foreground_estimator_pil_2(image, alpha, r1=90, r2=6): 62 | # Thanks to the source: https://github.com/Photoroom/fast-foreground-estimation 63 | alpha = alpha[:, :, None] 64 | F, blur_B = FB_blur_fusion_foreground_estimator_pil( 65 | image, image, image, alpha, r=r1) 66 | return FB_blur_fusion_foreground_estimator_pil(image, F, blur_B, alpha, r=r2)[0] 67 | 68 | 69 | def FB_blur_fusion_foreground_estimator_pil(image, F, B, alpha, r=90): 70 | if isinstance(image, 
Image.Image): 71 | image = np.array(image) / 255.0 72 | blurred_alpha = cv2.blur(alpha, (r, r))[:, :, None] 73 | 74 | blurred_FA = cv2.blur(F * alpha, (r, r)) 75 | blurred_F = blurred_FA / (blurred_alpha + 1e-5) 76 | 77 | blurred_B1A = cv2.blur(B * (1 - alpha), (r, r)) 78 | blurred_B = blurred_B1A / ((1 - blurred_alpha) + 1e-5) 79 | F = blurred_F + alpha * (image - alpha * blurred_F - (1 - alpha) * blurred_B) 80 | F = np.clip(F, 0, 1) 81 | return F, blurred_B 82 | 83 | 84 | def apply_mask_to_image(image, mask): 85 | """ 86 | Apply a mask to an image and set non-masked parts to transparent. 87 | 88 | Args: 89 | image (torch.Tensor): Image tensor of shape (h, w, c) or (1, h, w, c). 90 | mask (torch.Tensor): Mask tensor of shape (1, 1, h, w) or (h, w). 91 | 92 | Returns: 93 | torch.Tensor: Masked image tensor of shape (h, w, c+1) with transparency. 94 | """ 95 | # Check the shape of image 96 | if image.dim() == 3: 97 | pass 98 | elif image.dim() == 4: 99 | image = image.squeeze(0) 100 | else: 101 | raise ValueError("Image should be of shape (h, w, c) or (1, h, w, c).") 102 | 103 | h, w, c = image.shape 104 | # Check the shape of mask 105 | if mask.dim() == 4: 106 | mask = mask.squeeze(0).squeeze(0) # drop the first two dims to get (h, w) 107 | elif mask.dim() == 3: 108 | mask = mask.squeeze(0) 109 | elif mask.dim() == 2: 110 | pass 111 | else: 112 | raise ValueError("Mask should be of shape (1, 1, h, w) or (h, w).") 113 | 114 | assert mask.shape == (h, w), "Mask shape does not match image shape." 115 | 116 | # Expand the mask to the same number of channels as the image 117 | image_mask = mask.unsqueeze(-1).expand(h, w, c) 118 | 119 | # Apply the mask: black areas are 0, so after multiplication the white (1) regions are kept and everything else becomes black 120 | masked_image = image * image_mask 121 | 122 | # Use the mask values as alpha-channel opacity: black (0) means transparent, white (1) means opaque 123 | alpha = mask 124 | # Concatenate the alpha channel onto the image's RGB channels 125 | masked_image_with_alpha = torch.cat((masked_image[:, :, :3], alpha.unsqueeze(2)), dim=2) 126 | 127 | return masked_image_with_alpha.unsqueeze(0) 128 | 129 | 130 | def normalize_mask(mask_tensor): 131 | max_val = torch.max(mask_tensor) 132 | min_val = torch.min(mask_tensor) 133 | 134 | if max_val == min_val: 135 | return mask_tensor 136 | 137 | normalized_mask = (mask_tensor - min_val) / (max_val - min_val) 138 | 139 | return normalized_mask 140 | 141 | def add_mask_as_alpha(image, mask): 142 | """ 143 | Add a mask of shape (b, h, w) as the 4th (alpha) channel of an image of shape (b, h, w, 3). 144 | """ 145 | # Check input shapes 146 | assert image.dim() == 4 and image.size(-1) == 3, "The shape of image should be (b, h, w, 3)." 147 | assert mask.dim() == 3, "The shape of mask should be (b, h, w)" 148 | assert image.size(0) == mask.size(0) and image.size(1) == mask.size(1) and image.size(2) == mask.size(2), "The batch, height, and width dimensions of the image and mask must be consistent" 149 | 150 | # Expand mask to (b, h, w, 1) 151 | mask = mask[..., None] 152 | 153 | # Skip the element-wise multiplication here; it can leave outline artifacts along the edges 154 | # image = image * mask 155 | # Concatenate image and mask into (b, h, w, 4) 156 | image_with_alpha = torch.cat([image, mask], dim=-1) 157 | 158 | return image_with_alpha 159 | 160 | def filter_mask(mask, threshold=4e-3): 161 | mask_binary = mask > threshold 162 | filtered_mask = mask * mask_binary 163 | return filtered_mask 164 | --------------------------------------------------------------------------------
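A closing note on the mask helpers at the end of util.py above: they are shape-sensitive, so a short usage sketch may help. This is an illustrative example under stated assumptions, not code from the repository — the random tensors below stand in for real ComfyUI IMAGE/MASK tensors, and the import assumes util.py is importable from the working directory.

import torch
from util import add_mask_as_alpha, filter_mask, normalize_mask

# Placeholder inputs: a ComfyUI-style IMAGE batch (b, h, w, 3) and MASK batch (b, h, w), values in [0, 1].
image = torch.rand(1, 512, 512, 3)
mask = torch.rand(1, 512, 512)

mask = normalize_mask(mask)            # stretch the mask to span the full [0, 1] range
mask = filter_mask(mask)               # zero out responses below the default 4e-3 threshold
rgba = add_mask_as_alpha(image, mask)  # append the mask as an alpha channel -> (b, h, w, 4)
print(rgba.shape)                      # torch.Size([1, 512, 512, 4])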