├── .gitattributes ├── .github └── workflows │ └── publish.yml ├── .gitignore ├── LICENSE ├── README.md ├── __init__.py ├── depth_estimation_node.py ├── images ├── depth-estimation-icon.png ├── depth-estimation-icon.svg ├── depth-estimation-logo-with-smaller-z.svg ├── depth-estimation-node.png └── depth_map_generator_showcase.jpg ├── publish.yaml ├── pyproject.toml ├── requirements.txt └── workflows ├── Depth_Map_Generator_V1.json └── Depth_Map_Generator_V1.png /.gitattributes: -------------------------------------------------------------------------------- 1 | # Handle line endings (CRLF for Windows, LF for Unix) 2 | * text=auto 3 | 4 | # Ensure Python files use LF for Unix compatibility 5 | *.py text eol=lf 6 | 7 | # Treat Markdown and YAML files as text 8 | *.md text 9 | *.yaml text 10 | *.yml text 11 | 12 | # Handle images and binaries as binary 13 | *.png binary 14 | *.jpg binary 15 | *.zip binary 16 | 17 | # Handle compiled files as binary (just in case you have any) 18 | *.exe binary 19 | *.dll binary 20 | *.so binary 21 | 22 | # Ensure JSON files use LF line endings 23 | *.json text eol=lf 24 | 25 | # Handle Node.js package lock file (if relevant) 26 | package-lock.json merge=union 27 | 28 | -------------------------------------------------------------------------------- /.github/workflows/publish.yml: -------------------------------------------------------------------------------- 1 | name: Publish to Comfy registry 2 | on: 3 | workflow_dispatch: 4 | push: 5 | branches: 6 | - main 7 | - master 8 | paths: 9 | - "pyproject.toml" 10 | 11 | permissions: 12 | issues: write 13 | 14 | jobs: 15 | publish-node: 16 | name: Publish Custom Node to registry 17 | runs-on: ubuntu-latest 18 | if: ${{ github.repository_owner == 'Limbicnation' }} 19 | steps: 20 | - name: Check out code 21 | uses: actions/checkout@v4 22 | - name: Publish Custom Node 23 | uses: Comfy-Org/publish-node-action@v1 24 | with: 25 | ## Add your own personal access token to your Github Repository secrets and reference it here. 26 | personal_access_token: ${{ secrets.REGISTRY_ACCESS_TOKEN }} 27 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # poetry 98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 99 | # This is especially recommended for binary packages to ensure reproducibility, and is more 100 | # commonly ignored for libraries. 101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 102 | #poetry.lock 103 | 104 | # pdm 105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 106 | #pdm.lock 107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 108 | # in version control. 109 | # https://pdm.fming.dev/latest/usage/project/#working-with-version-control 110 | .pdm.toml 111 | .pdm-python 112 | .pdm-build/ 113 | 114 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 115 | __pypackages__/ 116 | 117 | # Celery stuff 118 | celerybeat-schedule 119 | celerybeat.pid 120 | 121 | # SageMath parsed files 122 | *.sage.py 123 | 124 | # Environments 125 | .env 126 | .venv 127 | env/ 128 | venv/ 129 | ENV/ 130 | env.bak/ 131 | venv.bak/ 132 | 133 | # Spyder project settings 134 | .spyderproject 135 | .spyproject 136 | 137 | # Rope project settings 138 | .ropeproject 139 | 140 | # mkdocs documentation 141 | /site 142 | 143 | # mypy 144 | .mypy_cache/ 145 | .dmypy.json 146 | dmypy.json 147 | 148 | # Pyre type checker 149 | .pyre/ 150 | 151 | # pytype static type analyzer 152 | .pytype/ 153 | 154 | # Cython debug symbols 155 | cython_debug/ 156 | 157 | # PyCharm 158 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 159 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 160 | # and can be added to the global gitignore or merged into this file. For a more nuclear 161 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 
162 | #.idea/ 163 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 
61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 
122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 
179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | ## ComfyUI Depth Estimation Node 2 | 3 |
4 | Depth Estimation Logo 5 |
6 | 7 | A robust custom depth estimation node for ComfyUI using Depth-Anything models to generate depth maps from images. 8 | 9 | ## Features 10 | - Multiple model options: 11 | - Depth-Anything-Small 12 | - Depth-Anything-Base 13 | - Depth-Anything-Large 14 | - Depth-Anything-V2-Small 15 | - Depth-Anything-V2-Base 16 | - Post-processing options: 17 | - Gaussian blur (adjustable radius) 18 | - Median filtering (configurable size) 19 | - Automatic contrast enhancement 20 | - Gamma correction 21 | - Advanced options: 22 | - Force CPU processing for compatibility 23 | - Force model reload for troubleshooting 24 | 25 | ## Installation 26 | 27 | ### Method 1: Install via ComfyUI Manager (Recommended) 28 | 1. Open ComfyUI and install the ComfyUI Manager if you haven't already 29 | 2. Go to the Manager tab 30 | 3. Search for "Depth Estimation" and install the node 31 | 32 | ### Method 2: Manual Installation 33 | 1. Navigate to your ComfyUI custom nodes directory: 34 | ```bash 35 | cd ComfyUI/custom_nodes/ 36 | ``` 37 | 38 | 2. Clone the repository: 39 | ```bash 40 | git clone https://github.com/Limbicnation/ComfyUIDepthEstimation.git 41 | ``` 42 | 43 | 3. Install the required dependencies: 44 | ```bash 45 | cd ComfyUIDepthEstimation 46 | pip install -r requirements.txt 47 | ``` 48 | 49 | 4. Restart ComfyUI to load the new custom node. 50 | 51 | > **Note**: On first use, the node will download the selected model from Hugging Face. This may take some time depending on your internet connection. 52 | 53 | ## Usage 54 | 55 |
56 | Depth Estimation Node Preview 57 |
58 |
59 | Depth Map Generator Showcase 60 |
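For a sense of what the node does before wiring it up, here is a rough standalone sketch of its two stages: depth inference through the Hugging Face `depth-estimation` pipeline, followed by the optional post-processing controlled by the parameters documented below. This is an approximation, not the node's actual code: the checkpoint id is the one this repo maps to `Depth-Anything-V2-Small`, the input filename is illustrative, and the gamma value and the exact processing order are assumptions.

```python
# Rough standalone sketch of the node's two stages; not the node's actual code.
from transformers import pipeline            # transformers>=4.20.0 (see requirements.txt)
from PIL import Image, ImageFilter, ImageOps
import numpy as np

# 1) Depth inference. The checkpoint id below is the one this repo maps to
#    "Depth-Anything-V2-Small"; other Depth-Anything checkpoints work the same way.
depth_pipe = pipeline("depth-estimation", model="depth-anything/Depth-Anything-V2-Small-hf")
depth = depth_pipe(Image.open("input.png").convert("RGB"))["depth"]   # grayscale PIL depth map

# 2) Post-processing roughly corresponding to the node parameters documented below.
blur_radius, median_size = 2.0, 3            # README defaults
if blur_radius > 0:
    depth = depth.filter(ImageFilter.GaussianBlur(radius=blur_radius))
if median_size >= 3:
    depth = depth.filter(ImageFilter.MedianFilter(size=median_size))  # size must be odd
depth = ImageOps.autocontrast(depth)         # apply_auto_contrast
gamma = 2.2                                  # assumed value; the node's internal value may differ
arr = np.asarray(depth).astype(np.float32) / 255.0
depth = Image.fromarray((np.power(arr, 1.0 / gamma) * 255).astype(np.uint8))  # apply_gamma
depth.save("depth_map.png")
```

Inside ComfyUI none of this is needed: the node downloads the selected model on first use, handles device placement (including the `force_cpu` fallback), and applies the post-processing for you.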
61 | 62 | ### Node Parameters 63 | 64 | #### Required Parameters 65 | - **image**: Input image (IMAGE type) 66 | - **model_name**: Select from available Depth-Anything models 67 | - **blur_radius**: Gaussian blur radius (0.0 - 10.0, default: 2.0) 68 | - **median_size**: Median filter size (3, 5, 7, 9, 11) 69 | - **apply_auto_contrast**: Enable automatic contrast enhancement 70 | - **apply_gamma**: Enable gamma correction 71 | 72 | #### Optional Parameters 73 | - **force_reload**: Force the model to reload (useful for troubleshooting) 74 | - **force_cpu**: Use CPU for processing instead of GPU (slower but more compatible) 75 | 76 | ### Example Usage 77 | 1. Add the `Depth Estimation` node to your ComfyUI workflow 78 | 2. Connect an image source to the node's image input 79 | 3. Configure the parameters: 80 | - Select a model (e.g., "Depth-Anything-V2-Small" is fastest) 81 | - Adjust blur_radius (0-10) for depth map smoothing 82 | - Choose median_size (3-11) for noise reduction 83 | - Toggle auto_contrast and gamma correction as needed 84 | 4. Connect the output to a Preview Image node or other image processing nodes 85 | 86 | ## Model Information 87 | 88 | | Model Name | Quality | VRAM Usage | Speed | 89 | |------------|---------|------------|-------| 90 | | Depth-Anything-V2-Small | Good | ~1.5 GB | Fast | 91 | | Depth-Anything-Small | Good | ~1.5 GB | Fast | 92 | | Depth-Anything-V2-Base | Better | ~2.5 GB | Medium | 93 | | Depth-Anything-Base | Better | ~2.5 GB | Medium | 94 | | Depth-Anything-Large | Best | ~4.0 GB | Slow | 95 | 96 | ## Troubleshooting Guide 97 | 98 | ### Common Issues and Solutions 99 | 100 | #### Model Download Issues 101 | - **Error**: "Failed to load model" or "Model not found" 102 | - **Solution**: 103 | 1. Check your internet connection 104 | 2. Try authenticating with Hugging Face: 105 | ```bash 106 | pip install huggingface_hub 107 | huggingface-cli login 108 | ``` 109 | 3. Try a different model (e.g., switch to Depth-Anything-V2-Small) 110 | 4. Check the ComfyUI console for detailed error messages 111 | 112 | #### CUDA Out of Memory Errors 113 | - **Error**: "CUDA out of memory" or node shows red error image 114 | - **Solution**: 115 | 1. Try a smaller model (Depth-Anything-V2-Small uses the least memory) 116 | 2. Enable the `force_cpu` option (slower but uses less VRAM) 117 | 3. Reduce the size of your input image 118 | 4. Close other VRAM-intensive applications 119 | 120 | #### Node Not Appearing in ComfyUI 121 | - **Solution**: 122 | 1. Check your ComfyUI console for error messages 123 | 2. Verify that all dependencies are installed: 124 | ```bash 125 | pip install transformers>=4.20.0 Pillow>=9.1.0 numpy>=1.23.0 timm>=0.6.12 126 | ``` 127 | 3. Try restarting ComfyUI 128 | 4. Check that the node files are in the correct directory 129 | 130 | #### Node Returns Original Image or Black Image 131 | - **Solution**: 132 | 1. Try enabling the `force_reload` option 133 | 2. Check the ComfyUI console for error messages 134 | 3. Try using a different model 135 | 4. Make sure your input image is valid (not corrupted or empty) 136 | 5. Try restarting ComfyUI 137 | 138 | #### Slow Performance 139 | - **Solution**: 140 | 1. Use a smaller model (Depth-Anything-V2-Small is fastest) 141 | 2. Reduce input image size 142 | 3. If using CPU mode, consider using GPU if available 143 | 4. 
Close other applications that might be using GPU resources 144 | 145 | ### Where to Get Help 146 | - Create an issue on the [GitHub repository](https://github.com/Limbicnation/ComfyUIDepthEstimation/issues) 147 | - Check the ComfyUI console for detailed error messages 148 | - Visit the ComfyUI Discord for community support 149 | 150 | ## License 151 | 152 | This project is licensed under the Apache License. 153 | 154 | --- -------------------------------------------------------------------------------- /__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | ComfyUI Depth Estimation Node 3 | A custom node for depth map estimation using Depth-Anything models. 4 | """ 5 | 6 | import os 7 | import logging 8 | import importlib.util 9 | 10 | # Setup logging 11 | logging.basicConfig(level=logging.INFO) 12 | logger = logging.getLogger("DepthEstimation") 13 | 14 | # Version info 15 | __version__ = "1.1.1" 16 | 17 | # Node class mappings - will be populated based on dependency checks 18 | NODE_CLASS_MAPPINGS = {} 19 | NODE_DISPLAY_NAME_MAPPINGS = {} 20 | 21 | # Web extension info for ComfyUI 22 | WEB_DIRECTORY = "./js" 23 | 24 | # Graceful dependency checking 25 | required_dependencies = { 26 | "torch": "2.0.0", 27 | "transformers": "4.20.0", 28 | "numpy": "1.23.0", 29 | "PIL": "9.2.0", # Pillow is imported as PIL 30 | "timm": "0.6.12", 31 | "huggingface_hub": "0.16.0" 32 | } 33 | 34 | missing_dependencies = [] 35 | 36 | # Check each dependency 37 | for module_name, min_version in required_dependencies.items(): 38 | try: 39 | if module_name == "PIL": 40 | # Special case for Pillow/PIL 41 | import PIL 42 | module_version = PIL.__version__ 43 | else: 44 | module = __import__(module_name) 45 | module_version = getattr(module, "__version__", "unknown") 46 | 47 | logger.info(f"Found {module_name} version {module_version}") 48 | except ImportError: 49 | missing_dependencies.append(f"{module_name}>={min_version}") 50 | logger.warning(f"Missing required dependency: {module_name}>={min_version}") 51 | 52 | if missing_dependencies: 53 | # Create placeholder node with dependency error 54 | class DependencyErrorNode: 55 | """Placeholder node that shows dependency installation instructions.""" 56 | 57 | @classmethod 58 | def INPUT_TYPES(cls): 59 | return {"required": {}} 60 | 61 | RETURN_TYPES = ("STRING",) 62 | FUNCTION = "error_message" 63 | CATEGORY = "depth" 64 | 65 | def error_message(self): 66 | missing = ", ".join(missing_dependencies) 67 | message = f"Dependencies missing: {missing}. 
Please install with: pip install {' '.join(missing_dependencies)}" 68 | print(f"DepthEstimation Node Error: {message}") 69 | return (message,) 70 | 71 | # Register the error node instead of the real node 72 | NODE_CLASS_MAPPINGS = { 73 | "DepthEstimationNode": DependencyErrorNode 74 | } 75 | 76 | NODE_DISPLAY_NAME_MAPPINGS = { 77 | "DepthEstimationNode": "Depth Estimation (Missing Dependencies)" 78 | } 79 | else: 80 | # All dependencies are available, try to import the actual node 81 | try: 82 | from .depth_estimation_node import DepthEstimationNode 83 | 84 | # Register the actual depth estimation node 85 | NODE_CLASS_MAPPINGS = { 86 | "DepthEstimationNode": DepthEstimationNode 87 | } 88 | 89 | NODE_DISPLAY_NAME_MAPPINGS = { 90 | "DepthEstimationNode": "Depth Estimation" 91 | } 92 | except Exception as e: 93 | # Capture any import errors that might occur with transformers 94 | logger.error(f"Error importing depth estimation node: {str(e)}") 95 | 96 | # Create a more specific error node 97 | class TransformersErrorNode: 98 | @classmethod 99 | def INPUT_TYPES(cls): 100 | return {"required": {}} 101 | 102 | RETURN_TYPES = ("STRING",) 103 | FUNCTION = "error_message" 104 | CATEGORY = "depth" 105 | 106 | def error_message(self): 107 | if "Descriptors cannot be created directly" in str(e): 108 | message = "Protobuf version conflict. Run: pip install protobuf==3.20.3" 109 | else: 110 | message = f"Error loading depth estimation: {str(e)}" 111 | return (message,) 112 | 113 | NODE_CLASS_MAPPINGS = { 114 | "DepthEstimationNode": TransformersErrorNode 115 | } 116 | 117 | NODE_DISPLAY_NAME_MAPPINGS = { 118 | "DepthEstimationNode": "Depth Estimation (Error)" 119 | } 120 | 121 | # Module exports 122 | __all__ = [ 123 | "NODE_CLASS_MAPPINGS", 124 | "NODE_DISPLAY_NAME_MAPPINGS", 125 | "__version__", 126 | "WEB_DIRECTORY" 127 | ] -------------------------------------------------------------------------------- /depth_estimation_node.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import torch 4 | import traceback 5 | import time 6 | import requests 7 | import urllib.request 8 | import wget 9 | from pathlib import Path 10 | from transformers import pipeline 11 | from PIL import Image, ImageFilter, ImageOps, ImageDraw, ImageFont 12 | import folder_paths 13 | from comfy.model_management import get_torch_device, get_free_memory 14 | import gc 15 | import logging 16 | import torch.nn as nn 17 | import torch.nn.functional as F 18 | from typing import Tuple, List, Dict, Any, Optional, Union 19 | 20 | # Try to import timm (for vision transformers) 21 | try: 22 | import timm 23 | TIMM_AVAILABLE = True 24 | except ImportError: 25 | TIMM_AVAILABLE = False 26 | print("Warning: timm not available. 
Direct loading of Depth Anything models may not work.") 27 | 28 | # Setup logging 29 | logging.basicConfig(level=logging.INFO) 30 | logger = logging.getLogger("DepthEstimation") 31 | 32 | # Depth Anything V2 Implementation 33 | class DepthAnythingV2(nn.Module): 34 | """Direct implementation of Depth Anything V2 model""" 35 | def __init__(self, encoder='vits', features=64, out_channels=[48, 96, 192, 384]): 36 | super().__init__() 37 | self.encoder = encoder 38 | self.features = features 39 | self.out_channels = out_channels 40 | self.device = 'cuda' if torch.cuda.is_available() else 'cpu' 41 | 42 | # Create encoder based on specification 43 | if TIMM_AVAILABLE: 44 | if encoder == 'vits': 45 | self.backbone = timm.create_model('vit_small_patch16_224', pretrained=False) 46 | self.embed_dim = 384 47 | elif encoder == 'vitb': 48 | self.backbone = timm.create_model('vit_base_patch16_224', pretrained=False) 49 | self.embed_dim = 768 50 | elif encoder == 'vitl': 51 | self.backbone = timm.create_model('vit_large_patch16_224', pretrained=False) 52 | self.embed_dim = 1024 53 | else: # fallback to vits 54 | self.backbone = timm.create_model('vit_small_patch16_224', pretrained=False) 55 | self.embed_dim = 384 56 | 57 | # Implement the rest of the model architecture 58 | self.initialize_decoder() 59 | else: 60 | # Fallback if timm is not available 61 | from torchvision.models import resnet50 62 | self.backbone = resnet50(pretrained=False) 63 | self.embed_dim = 2048 64 | logger.warning("Using fallback ResNet50 model (timm not available)") 65 | 66 | def initialize_decoder(self): 67 | """Initialize the decoder layers""" 68 | self.neck = nn.Sequential( 69 | nn.Conv2d(self.embed_dim, self.features, 1, 1, 0), 70 | nn.Conv2d(self.features, self.features, 3, 1, 1), 71 | ) 72 | 73 | # Create decoders for each level 74 | self.decoders = nn.ModuleList([ 75 | self.create_decoder_level(self.features, self.out_channels[0]), 76 | self.create_decoder_level(self.out_channels[0], self.out_channels[1]), 77 | self.create_decoder_level(self.out_channels[1], self.out_channels[2]), 78 | self.create_decoder_level(self.out_channels[2], self.out_channels[3]) 79 | ]) 80 | 81 | # Final depth head 82 | self.depth_head = nn.Sequential( 83 | nn.Conv2d(self.out_channels[3], self.out_channels[3], 3, 1, 1), 84 | nn.BatchNorm2d(self.out_channels[3]), 85 | nn.ReLU(True), 86 | nn.Conv2d(self.out_channels[3], 1, 1) 87 | ) 88 | 89 | def create_decoder_level(self, in_channels, out_channels): 90 | """Create a decoder level""" 91 | return nn.Sequential( 92 | nn.Conv2d(in_channels, out_channels, 3, 1, 1), 93 | nn.BatchNorm2d(out_channels), 94 | nn.ReLU(True), 95 | nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True) 96 | ) 97 | 98 | def forward(self, x): 99 | """Forward pass of the model""" 100 | # For timm ViT models 101 | if hasattr(self.backbone, 'forward_features'): 102 | features = self.backbone.forward_features(x) 103 | 104 | # Reshape features based on model type 105 | if 'vit' in self.encoder: 106 | # Reshape transformer output to spatial features 107 | # Exact reshape depends on the model details 108 | h = w = int(features.shape[1]**0.5) 109 | features = features.reshape(-1, h, w, self.embed_dim).permute(0, 3, 1, 2) 110 | 111 | # Process through decoder 112 | x = self.neck(features) 113 | 114 | # Apply decoder stages 115 | for decoder in self.decoders: 116 | x = decoder(x) 117 | 118 | # Final depth prediction 119 | depth = self.depth_head(x) 120 | 121 | return depth 122 | else: 123 | # Fallback for ResNet 124 | x = 
self.backbone.conv1(x) 125 | x = self.backbone.bn1(x) 126 | x = self.backbone.relu(x) 127 | x = self.backbone.maxpool(x) 128 | 129 | x = self.backbone.layer1(x) 130 | x = self.backbone.layer2(x) 131 | x = self.backbone.layer3(x) 132 | x = self.backbone.layer4(x) 133 | 134 | # Process through simple decoder 135 | x = F.interpolate(x, scale_factor=2, mode='bilinear', align_corners=True) 136 | x = self.depth_head(x) 137 | 138 | return x 139 | 140 | def infer_image(self, image): 141 | """Process an image and return the depth map 142 | 143 | Args: 144 | image: A numpy image in BGR format (OpenCV) or RGB PIL Image 145 | 146 | Returns: 147 | depth: A numpy array containing the depth map 148 | """ 149 | # Convert input to tensor 150 | if isinstance(image, np.ndarray): 151 | # Convert BGR to RGB 152 | if image.shape[2] == 3: 153 | image = image[:, :, ::-1] 154 | # Normalize 155 | image = image.astype(np.float32) / 255.0 156 | # HWC to CHW 157 | image = image.transpose(2, 0, 1) 158 | # Add batch dimension 159 | image = torch.from_numpy(image).unsqueeze(0) 160 | elif isinstance(image, Image.Image): 161 | # Convert PIL image to numpy 162 | image = np.array(image).astype(np.float32) / 255.0 163 | # HWC to CHW 164 | image = image.transpose(2, 0, 1) 165 | # Add batch dimension 166 | image = torch.from_numpy(image).unsqueeze(0) 167 | 168 | # Move to device 169 | image = image.to(self.device) 170 | 171 | # Set model to eval mode 172 | self.eval() 173 | 174 | # Get prediction 175 | with torch.no_grad(): 176 | depth = self.forward(image) 177 | 178 | # Convert to numpy 179 | depth = depth.squeeze().cpu().numpy() 180 | 181 | return depth 182 | 183 | def __call__(self, image): 184 | """Compatible interface with the pipeline API""" 185 | if isinstance(image, Image.Image): 186 | # Convert to numpy for processing 187 | depth = self.infer_image(image) 188 | # Return in the format expected by the node 189 | return {"predicted_depth": torch.from_numpy(depth).unsqueeze(0)} 190 | else: 191 | # Already a tensor, process directly 192 | self.eval() 193 | with torch.no_grad(): 194 | depth = self.forward(image) 195 | return {"predicted_depth": depth} 196 | 197 | # Configure model paths 198 | if not hasattr(folder_paths, "models_dir"): 199 | folder_paths.models_dir = os.path.join(folder_paths.base_path, "models") 200 | 201 | # Register depth models path - support multiple possible directory structures 202 | DEPTH_DIR = "depth_anything" 203 | DEPTH_ANYTHING_DIR = "depthanything" 204 | 205 | # Check which directory structure exists 206 | possible_paths = [ 207 | os.path.join(folder_paths.models_dir, DEPTH_DIR), 208 | os.path.join(folder_paths.models_dir, DEPTH_ANYTHING_DIR), 209 | os.path.join(folder_paths.models_dir, DEPTH_ANYTHING_DIR, DEPTH_DIR), 210 | os.path.join(folder_paths.models_dir, "checkpoints", DEPTH_DIR), 211 | os.path.join(folder_paths.models_dir, "checkpoints", DEPTH_ANYTHING_DIR), 212 | ] 213 | 214 | # Filter to only paths that exist 215 | existing_paths = [p for p in possible_paths if os.path.exists(p)] 216 | if not existing_paths: 217 | # If none exists, create the default one 218 | existing_paths = [os.path.join(folder_paths.models_dir, DEPTH_DIR)] 219 | os.makedirs(existing_paths[0], exist_ok=True) 220 | logger.info(f"Created model directory: {existing_paths[0]}") 221 | 222 | # Log all found paths for debugging 223 | logger.info(f"Found depth model directories: {existing_paths}") 224 | 225 | # Register all possible paths for model loading 226 | folder_paths.folder_names_and_paths[DEPTH_DIR] = 
(existing_paths, folder_paths.supported_pt_extensions) 227 | 228 | # Set primary models directory to the first available path 229 | MODELS_DIR = existing_paths[0] 230 | logger.info(f"Using primary models directory: {MODELS_DIR}") 231 | 232 | # Set Hugging Face cache to the models directory to ensure models are saved there 233 | os.environ["TRANSFORMERS_CACHE"] = MODELS_DIR 234 | os.environ["HF_HOME"] = MODELS_DIR 235 | 236 | # Define model configurations for direct loading 237 | MODEL_CONFIGS = { 238 | 'vits': {'encoder': 'vits', 'features': 64, 'out_channels': [48, 96, 192, 384]}, 239 | 'vitb': {'encoder': 'vitb', 'features': 128, 'out_channels': [96, 192, 384, 768]}, 240 | 'vitl': {'encoder': 'vitl', 'features': 256, 'out_channels': [256, 512, 1024, 1024]}, 241 | 'vitg': {'encoder': 'vitg', 'features': 384, 'out_channels': [1536, 1536, 1536, 1536]} 242 | } 243 | 244 | # Define all models mentioned in the README with memory requirements 245 | DEPTH_MODELS = { 246 | "Depth-Anything-Small": { 247 | "path": "LiheYoung/depth-anything-small-hf", # Correct HF path for V1 248 | "vram_mb": 1500, 249 | "direct_url": "https://github.com/LiheYoung/Depth-Anything/releases/download/v1.0/depth_anything_vitb14.pt", 250 | "model_type": "v1", 251 | "encoder": "vitb" 252 | }, 253 | "Depth-Anything-Base": { 254 | "path": "LiheYoung/depth-anything-base-hf", # Correct HF path for V1 255 | "vram_mb": 2500, 256 | "direct_url": "https://github.com/LiheYoung/Depth-Anything/releases/download/v1.0/depth_anything_vitl14.pt", 257 | "model_type": "v1", 258 | "encoder": "vitl" 259 | }, 260 | "Depth-Anything-Large": { 261 | "path": "LiheYoung/depth-anything-large-hf", # Correct HF path for V1 262 | "vram_mb": 4000, 263 | "direct_url": "https://github.com/LiheYoung/Depth-Anything/releases/download/v1.0/depth_anything_vitl14.pt", 264 | "model_type": "v1", 265 | "encoder": "vitl" 266 | }, 267 | "Depth-Anything-V2-Small": { 268 | "path": "depth-anything/Depth-Anything-V2-Small-hf", # Updated corrected path as shown in example 269 | "vram_mb": 1500, 270 | "direct_url": "https://huggingface.co/depth-anything/Depth-Anything-V2-Small-hf/resolve/main/pytorch_model.bin", 271 | "model_type": "v2", 272 | "encoder": "vits", 273 | "config": MODEL_CONFIGS["vits"] 274 | }, 275 | "Depth-Anything-V2-Base": { 276 | "path": "depth-anything/Depth-Anything-V2-Base-hf", # Updated corrected path 277 | "vram_mb": 2500, 278 | "direct_url": "https://huggingface.co/depth-anything/Depth-Anything-V2-Base-hf/resolve/main/pytorch_model.bin", 279 | "model_type": "v2", 280 | "encoder": "vitb", 281 | "config": MODEL_CONFIGS["vitb"] 282 | }, 283 | # Add MiDaS models as dedicated options with direct download URLs 284 | "MiDaS-Small": { 285 | "path": "Intel/dpt-hybrid-midas", 286 | "vram_mb": 1000, 287 | "midas_type": "MiDaS_small", 288 | "direct_url": "https://github.com/intel-isl/MiDaS/releases/download/v2_1/midas_v21_small_256.pt" 289 | }, 290 | "MiDaS-Base": { 291 | "path": "Intel/dpt-hybrid-midas", 292 | "vram_mb": 1200, 293 | "midas_type": "DPT_Hybrid", 294 | "direct_url": "https://github.com/intel-isl/MiDaS/releases/download/v3/dpt_hybrid-midas-501f0c75.pt" 295 | } 296 | } 297 | 298 | class MiDaSWrapper: 299 | def __init__(self, model_type, device): 300 | self.device = device 301 | 302 | try: 303 | # Import required libraries 304 | import torch.nn.functional as F 305 | 306 | # Use a more reliable approach to loading MiDaS models 307 | if model_type == "DPT_Hybrid" or model_type == "dpt_hybrid": 308 | # Use direct URL download for MiDaS models 309 | 
midas_url = "https://github.com/intel-isl/MiDaS/releases/download/v3/dpt_hybrid-midas-501f0c75.pt" 310 | local_path = os.path.join(MODELS_DIR, "dpt_hybrid_midas.pt") 311 | 312 | if not os.path.exists(local_path): 313 | logger.info(f"Downloading MiDaS model from {midas_url}") 314 | try: 315 | response = requests.get(midas_url, stream=True) 316 | if response.status_code == 200: 317 | with open(local_path, 'wb') as f: 318 | for chunk in response.iter_content(chunk_size=8192): 319 | f.write(chunk) 320 | logger.info(f"Downloaded MiDaS model to {local_path}") 321 | else: 322 | logger.error(f"Failed to download model: {response.status_code}") 323 | except Exception as e: 324 | logger.error(f"Error downloading MiDaS model: {e}") 325 | 326 | # Load pretrained model 327 | try: 328 | # Create a simple model architecture 329 | from torchvision.models import resnet50 330 | self.model = resnet50() 331 | self.model.fc = torch.nn.Linear(2048, 1) 332 | 333 | # Load state dict if available 334 | if os.path.exists(local_path): 335 | logger.info(f"Loading MiDaS model from {local_path}") 336 | state_dict = torch.load(local_path, map_location=device) 337 | # Convert all parameters to float 338 | floated_state_dict = {k: v.float() for k, v in state_dict.items()} 339 | self.model.load_state_dict(floated_state_dict) 340 | 341 | except Exception as e: 342 | logger.error(f"Error loading MiDaS model state dict: {e}") 343 | # Fallback to ResNet 344 | self.model = resnet50(pretrained=True) 345 | self.model.fc = torch.nn.Linear(2048, 1) 346 | 347 | else: # Other model types or fallback 348 | from torchvision.models import resnet50 349 | self.model = resnet50(pretrained=True) 350 | self.model.fc = torch.nn.Linear(2048, 1) 351 | 352 | # Ensure model parameters are float 353 | for param in self.model.parameters(): 354 | param.data = param.data.float() 355 | 356 | # Explicitly convert model to FloatTensor 357 | self.model = self.model.float() 358 | 359 | # Move model to device and set to eval mode 360 | self.model = self.model.to(device) 361 | self.model.eval() 362 | 363 | except Exception as e: 364 | logger.error(f"Failed to load MiDaS model: {e}") 365 | logger.error(traceback.format_exc()) 366 | # Create a minimal model as absolute fallback 367 | from torchvision.models import resnet18 368 | self.model = resnet18(pretrained=True).float().to(device) 369 | self.model.fc = torch.nn.Linear(512, 1).float().to(device) 370 | self.model.eval() 371 | 372 | def __call__(self, image): 373 | """Process an image and return the depth map""" 374 | try: 375 | # Convert PIL image to tensor for processing 376 | if isinstance(image, Image.Image): 377 | # Get original dimensions 378 | original_width, original_height = image.size 379 | 380 | # Ensure dimensions are multiple of 32 (required for some models) 381 | # This helps prevent tensor dimension mismatches 382 | target_height = ((original_height + 31) // 32) * 32 383 | target_width = ((original_width + 31) // 32) * 32 384 | 385 | # Keep original dimensions - don't force 384x384 386 | # The caller should already have resized to the requested input_size 387 | 388 | # Log resize information if needed 389 | if (target_width != original_width) or (target_height != original_height): 390 | logger.info(f"Adjusting dimensions from {original_width}x{original_height} to {target_width}x{target_height} (multiples of 32)") 391 | img_resized = image.resize((target_width, target_height), Image.LANCZOS) 392 | else: 393 | img_resized = image 394 | 395 | # Convert to numpy array 396 | img_np = 
np.array(img_resized).astype(np.float32) / 255.0 397 | 398 | # Check for NaN values and replace them with zeros 399 | if np.isnan(img_np).any(): 400 | logger.warning("Input image contains NaN values. Replacing with zeros.") 401 | img_np = np.nan_to_num(img_np, nan=0.0) 402 | 403 | # Convert to tensor with proper shape (B,C,H,W) 404 | if len(img_np.shape) == 3: 405 | # RGB image 406 | img_np = img_np.transpose(2, 0, 1) # (H,W,C) -> (C,H,W) 407 | else: 408 | # Grayscale image - add channel dimension 409 | img_np = np.expand_dims(img_np, axis=0) 410 | 411 | # Add batch dimension and ensure float32 412 | input_tensor = torch.from_numpy(img_np).unsqueeze(0).float() 413 | else: 414 | # Already a tensor - ensure float32 by explicitly converting 415 | # This is the key fix for the "Input type (torch.cuda.DoubleTensor) and weight type (torch.cuda.FloatTensor)" error 416 | input_tensor = None 417 | 418 | # Handle potential error cases with clearer messages 419 | if not torch.is_tensor(image): 420 | logger.error(f"Expected tensor or PIL image, got {type(image)}") 421 | # Create dummy tensor as fallback 422 | input_tensor = torch.ones((1, 3, 512, 512), dtype=torch.float32) 423 | elif image.numel() == 0: 424 | logger.error("Input tensor is empty") 425 | # Create dummy tensor as fallback 426 | input_tensor = torch.ones((1, 3, 512, 512), dtype=torch.float32) 427 | else: 428 | # Check for NaN values 429 | if torch.isnan(image).any(): 430 | logger.warning("Input tensor contains NaN values. Replacing with zeros.") 431 | image = torch.nan_to_num(image, nan=0.0) 432 | 433 | # Always convert to float32 to prevent type mismatches 434 | input_tensor = image.float() # Convert any tensor to FloatTensor 435 | 436 | # Handle tensor shape issues with more robust dimension checking 437 | if input_tensor.dim() == 2: # [H, W] 438 | # Single channel 2D tensor 439 | input_tensor = input_tensor.unsqueeze(0).unsqueeze(0) # Add batch and channel dims [1, 1, H, W] 440 | logger.info(f"Converted 2D tensor to 4D with shape: {input_tensor.shape}") 441 | elif input_tensor.dim() == 3: 442 | # Could be [C, H, W] or [B, H, W] or [H, W, C] 443 | shape = input_tensor.shape 444 | if shape[-1] == 3 or shape[-1] == 1: # [H, W, C] format 445 | # Convert from HWC to BCHW 446 | input_tensor = input_tensor.permute(2, 0, 1).unsqueeze(0) # [H, W, C] -> [1, C, H, W] 447 | logger.info(f"Converted HWC tensor to BCHW with shape: {input_tensor.shape}") 448 | elif shape[0] <= 3: # Likely [C, H, W] 449 | input_tensor = input_tensor.unsqueeze(0) # Add batch dim [1, C, H, W] 450 | logger.info(f"Added batch dimension to CHW tensor: {input_tensor.shape}") 451 | else: # Likely [B, H, W] 452 | input_tensor = input_tensor.unsqueeze(1) # Add channel dim [B, 1, H, W] 453 | logger.info(f"Added channel dimension to BHW tensor: {input_tensor.shape}") 454 | 455 | # Ensure proper shape after corrections 456 | if input_tensor.dim() != 4: 457 | logger.warning(f"Tensor still has incorrect dimensions ({input_tensor.dim()}). 
Forcing reshape.") 458 | # Force reshape to 4D 459 | orig_shape = input_tensor.shape 460 | if input_tensor.dim() > 4: 461 | # Too many dimensions, collapse extras 462 | input_tensor = input_tensor.reshape(1, -1, orig_shape[-2], orig_shape[-1]) 463 | else: 464 | # Create a standard 4D tensor as fallback 465 | input_tensor = torch.ones((1, 3, 512, 512), dtype=torch.float32) 466 | 467 | # Move to device and ensure float type 468 | input_tensor = input_tensor.to(self.device).float() 469 | 470 | # Log tensor shape for debugging 471 | logger.info(f"MiDaS input tensor shape: {input_tensor.shape}, dtype: {input_tensor.dtype}") 472 | 473 | # Run inference with better error handling 474 | with torch.no_grad(): 475 | try: 476 | # Make sure input is float32 and model weights are float32 477 | output = self.model(input_tensor) 478 | 479 | # Handle various output shapes 480 | if output.dim() == 1: # [B*H*W] flattened output 481 | # Reshape based on input dimensions 482 | b, _, h, w = input_tensor.shape 483 | output = output.reshape(b, 1, h, w) 484 | elif output.dim() == 2: # [B, H*W] or similar 485 | # Could be flattened spatial dimensions 486 | b = output.shape[0] 487 | if b == input_tensor.shape[0]: # Batch size matches 488 | h = int(np.sqrt(output.shape[1])) # Estimate height assuming square 489 | w = h 490 | if h * w == output.shape[1]: # Perfect square 491 | output = output.reshape(b, 1, h, w) 492 | else: 493 | # Not a perfect square, use input dimensions 494 | _, _, h, w = input_tensor.shape 495 | output = output.reshape(b, 1, h, w) 496 | else: 497 | # Add dimensions to make 4D 498 | output = output.unsqueeze(1).unsqueeze(1) 499 | 500 | # Ensure output has standard 4D shape (B,C,H,W) for interpolation 501 | if output.dim() != 4: 502 | logger.warning(f"Output has non-standard dimensions: {output.shape}, adding dimensions") 503 | # Add dimensions until we have 4D 504 | while output.dim() < 4: 505 | output = output.unsqueeze(-1) 506 | 507 | # Resize to match input resolution 508 | if isinstance(image, Image.Image): 509 | w, h = image.size 510 | 511 | # Log the shape for debugging 512 | logger.info(f"Resizing output tensor from shape {output.shape} to size ({h}, {w})") 513 | 514 | # Ensure output tensor has correct number of dimensions for interpolation 515 | # Standard interpolation requires 4D tensor (B,C,H,W) 516 | try: 517 | # Now interpolate with proper dimensions 518 | output = torch.nn.functional.interpolate( 519 | output, 520 | size=(h, w), 521 | mode="bicubic", 522 | align_corners=False 523 | ) 524 | except RuntimeError as resize_err: 525 | logger.error(f"Interpolation error: {resize_err}. Attempting to fix tensor shape.") 526 | 527 | # Last resort: create compatible tensor from output data 528 | try: 529 | # Get data and reshape to simple 2D first 530 | output_data = output.view(-1).cpu().numpy() 531 | output_reshaped = torch.from_numpy( 532 | np.resize(output_data, (h * w)) 533 | ).reshape(1, 1, h, w).to(self.device).float() 534 | 535 | logger.info(f"Corrected output shape to {output_reshaped.shape}") 536 | output = output_reshaped 537 | except Exception as reshape_err: 538 | logger.error(f"Reshape fix failed: {reshape_err}. 
Using fallback tensor.") 539 | # Create a basic gradient as fallback 540 | output = torch.ones((1, 1, h, w), device=self.device, dtype=torch.float32) 541 | y_coords = torch.linspace(0, 1, h).reshape(-1, 1).repeat(1, w) 542 | output[0, 0, :, :] = y_coords.to(self.device) 543 | 544 | except Exception as model_err: 545 | logger.error(f"Model inference error: {model_err}") 546 | logger.error(traceback.format_exc()) 547 | 548 | # Create a visually distinguishable gradient pattern fallback 549 | if isinstance(image, Image.Image): 550 | w, h = image.size 551 | else: 552 | # Extract dimensions from input tensor 553 | _, _, h, w = input_tensor.shape if input_tensor.dim() >= 4 else (1, 1, 512, 512) 554 | 555 | # Create gradient depth map as fallback 556 | output = torch.ones((1, 1, h, w), device=self.device, dtype=torch.float32) 557 | y_coords = torch.linspace(0, 1, h).reshape(-1, 1).repeat(1, w) 558 | output[0, 0, :, :] = y_coords.to(self.device) 559 | 560 | # Final validation - ensure output is float32 and has no NaNs 561 | output = output.float() 562 | if torch.isnan(output).any(): 563 | logger.warning("Output contains NaN values. Replacing with zeros.") 564 | output = torch.nan_to_num(output, nan=0.0) 565 | 566 | # Use same interface as the pipeline 567 | return {"predicted_depth": output} 568 | 569 | except Exception as e: 570 | logger.error(f"Error in MiDaS inference: {e}") 571 | logger.error(traceback.format_exc()) 572 | 573 | # Return a placeholder depth map 574 | if isinstance(image, Image.Image): 575 | w, h = image.size 576 | dummy_tensor = torch.ones((1, 1, h, w), device=self.device) 577 | else: 578 | # Try to get shape from tensor 579 | shape = image.shape 580 | if len(shape) >= 3: 581 | if shape[0] == 3: # CHW format 582 | h, w = shape[1], shape[2] 583 | else: # HWC format 584 | h, w = shape[0], shape[1] 585 | else: 586 | h, w = 512, 512 587 | dummy_tensor = torch.ones((1, 1, h, w), device=self.device) 588 | 589 | return {"predicted_depth": dummy_tensor} 590 | 591 | class DepthEstimationNode: 592 | """ 593 | ComfyUI node for depth estimation using Depth Anything models. 594 | 595 | This node provides depth map generation from images using various Depth Anything models 596 | with configurable post-processing options like blur, median filtering, contrast enhancement, 597 | and gamma correction. 
598 | """ 599 | 600 | MEDIAN_SIZES = ["3", "5", "7", "9", "11"] 601 | 602 | def __init__(self): 603 | self.device = None 604 | self.depth_estimator = None 605 | self.current_model = None 606 | logger.info("Initialized DepthEstimationNode") 607 | 608 | @classmethod 609 | def INPUT_TYPES(cls) -> Dict[str, Dict[str, Any]]: 610 | """Define the input types for the node.""" 611 | return { 612 | "required": { 613 | "image": ("IMAGE",), 614 | "model_name": (list(DEPTH_MODELS.keys()),), 615 | # Ensure minimum size is enforced by the UI 616 | "input_size": ("INT", {"default": 1024, "min": 256, "max": 1024, "step": 1}), 617 | "blur_radius": ("FLOAT", {"default": 2.0, "min": 0.0, "max": 10.0, "step": 0.1}), 618 | # Define median_size as a dropdown with specific string values 619 | "median_size": (cls.MEDIAN_SIZES, {"default": "3"}), 620 | "apply_auto_contrast": ("BOOLEAN", {"default": True}), 621 | "apply_gamma": ("BOOLEAN", {"default": True}) 622 | }, 623 | "optional": { 624 | "force_reload": ("BOOLEAN", {"default": False}), 625 | "force_cpu": ("BOOLEAN", {"default": False}) 626 | } 627 | } 628 | 629 | RETURN_TYPES = ("IMAGE",) 630 | FUNCTION = "estimate_depth" 631 | CATEGORY = "depth" 632 | 633 | def cleanup(self) -> None: 634 | """Clean up resources and free VRAM.""" 635 | try: 636 | if self.depth_estimator is not None: 637 | # Save model name before deletion for logging 638 | model_name = self.current_model 639 | 640 | # Delete the estimator 641 | del self.depth_estimator 642 | self.depth_estimator = None 643 | self.current_model = None 644 | 645 | # Force CUDA cache clearing 646 | if torch.cuda.is_available(): 647 | torch.cuda.empty_cache() 648 | gc.collect() 649 | 650 | logger.info(f"Cleaned up model resources for {model_name}") 651 | 652 | # Log available memory after cleanup if CUDA is available 653 | if torch.cuda.is_available(): 654 | try: 655 | free_mem_info = get_free_memory(get_torch_device()) 656 | # Handle return value whether it's a tuple or a single value 657 | if isinstance(free_mem_info, tuple): 658 | free_mem, total_mem = free_mem_info 659 | logger.info(f"Available VRAM after cleanup: {free_mem/1024:.2f}MB of {total_mem/1024:.2f}MB") 660 | else: 661 | logger.info(f"Available VRAM after cleanup: {free_mem_info/1024:.2f}MB") 662 | except Exception as e: 663 | logger.warning(f"Error getting memory info: {e}") 664 | except Exception as e: 665 | logger.warning(f"Error during cleanup: {e}") 666 | logger.debug(traceback.format_exc()) 667 | 668 | def ensure_model_loaded(self, model_name: str, force_reload: bool = False, force_cpu: bool = False) -> None: 669 | """ 670 | Ensures the correct model is loaded with proper VRAM management and fallback options. 671 | 672 | Args: 673 | model_name: The name of the model to load 674 | force_reload: If True, reload the model even if it's already loaded 675 | force_cpu: If True, force loading on CPU regardless of GPU availability 676 | 677 | Raises: 678 | RuntimeError: If the model fails to load after all fallback attempts 679 | """ 680 | try: 681 | # Check for valid model name with more helpful fallback 682 | if model_name not in DEPTH_MODELS: 683 | # Find the most similar model name if possible 684 | available_models = list(DEPTH_MODELS.keys()) 685 | 686 | if len(available_models) > 0: 687 | # First try to find a model with a similar name 688 | model_name_lower = model_name.lower() 689 | 690 | # Prioritized fallback selection logic: 691 | # 1. Try to match on similar name 692 | # 2. Prefer V2 models if V2 was requested 693 | # 3. 
Prefer smaller models (more reliable) 694 | if "v2" in model_name_lower and "small" in model_name_lower: 695 | fallback_model = "Depth-Anything-V2-Small" 696 | elif "v2" in model_name_lower and "base" in model_name_lower: 697 | fallback_model = "Depth-Anything-V2-Base" 698 | elif "v2" in model_name_lower: 699 | fallback_model = "Depth-Anything-V2-Small" 700 | elif "small" in model_name_lower: 701 | fallback_model = "Depth-Anything-Small" 702 | elif "midas" in model_name_lower: 703 | fallback_model = "MiDaS-Small" 704 | else: 705 | # Default to the first model if no better match found 706 | fallback_model = available_models[0] 707 | 708 | logger.warning(f"Unknown model: {model_name}. Falling back to {fallback_model}") 709 | model_name = fallback_model 710 | else: 711 | raise ValueError(f"No depth models available. Please check your installation.") 712 | 713 | # Get model info and validate 714 | model_info = DEPTH_MODELS[model_name] 715 | 716 | # Handle model_info as string or dict with better defaults 717 | if isinstance(model_info, dict): 718 | model_path = model_info.get("path", "") 719 | required_vram = model_info.get("vram_mb", 2000) * 1024 # Convert to KB 720 | model_type = model_info.get("model_type", "v1") # v1 or v2 721 | encoder = model_info.get("encoder", "vits") # Model encoder type 722 | config = model_info.get("config", None) # Model config for direct loading 723 | direct_url = model_info.get("direct_url", None) # Direct download URL 724 | else: 725 | model_path = str(model_info) 726 | required_vram = 2000 * 1024 # Default 2GB 727 | model_type = "v1" 728 | encoder = "vits" 729 | config = None 730 | direct_url = None 731 | 732 | # Only reload if needed or forced 733 | if not force_reload and self.depth_estimator is not None and self.current_model == model_path: 734 | logger.info(f"Model '{model_name}' already loaded") 735 | return 736 | 737 | # Clean up any existing model to free memory before loading new one 738 | self.cleanup() 739 | 740 | # Set up device for model 741 | if self.device is None: 742 | self.device = get_torch_device() 743 | 744 | logger.info(f"Loading depth model: {model_name} on {'CPU' if force_cpu else self.device}") 745 | 746 | # Enhanced VRAM check with better error handling 747 | if torch.cuda.is_available() and not force_cpu: 748 | try: 749 | free_mem_info = get_free_memory(self.device) 750 | 751 | # Process different return formats from get_free_memory 752 | if isinstance(free_mem_info, tuple): 753 | free_mem, total_mem = free_mem_info 754 | logger.info(f"Available VRAM: {free_mem/1024:.2f}MB, Required: {required_vram/1024:.2f}MB") 755 | else: 756 | free_mem = free_mem_info 757 | logger.info(f"Available VRAM: {free_mem/1024:.2f}MB, Required: {required_vram/1024:.2f}MB") 758 | total_mem = free_mem * 2 # Estimate if not available 759 | 760 | # Add buffer to required memory to avoid OOM issues 761 | required_vram_with_buffer = required_vram * 1.2 # Add 20% buffer 762 | 763 | # If not enough memory, fall back to CPU with warning 764 | if free_mem < required_vram_with_buffer: 765 | logger.warning( 766 | f"Insufficient VRAM for {model_name} (need ~{required_vram/1024:.1f}MB, " + 767 | f"have {free_mem/1024:.1f}MB). Using CPU instead." 768 | ) 769 | force_cpu = True 770 | except Exception as mem_error: 771 | logger.warning(f"Error checking VRAM: {str(mem_error)}. 
Using CPU to be safe.") 772 | force_cpu = True 773 | 774 | # Determine optimal device configuration 775 | device_type = 'cpu' if force_cpu else ('cuda' if torch.cuda.is_available() else 'cpu') 776 | 777 | # Use appropriate dtype based on device and model 778 | # FP16 for CUDA saves VRAM but doesn't work well for all models 779 | if 'cuda' in str(self.device) and not force_cpu: 780 | # V2 models have issues with FP16 - use FP32 for them 781 | if model_type == "v2": 782 | dtype = torch.float32 783 | else: 784 | # Other models can use FP16 to save VRAM 785 | dtype = torch.float16 786 | else: 787 | # CPU always uses FP32 788 | dtype = torch.float32 789 | 790 | # Create model-specific cache directory 791 | # Use consistent naming to improve cache hits 792 | model_cache_name = model_name.replace("-", "_").lower() 793 | cache_dir = os.path.join(MODELS_DIR, model_cache_name) 794 | os.makedirs(cache_dir, exist_ok=True) 795 | 796 | # Prioritized loading strategy: 797 | # 1. Check for locally cached model files 798 | # 2. Try direct download from URLs that don't require authentication 799 | # 3. Try loading from Hugging Face using transformers pipeline 800 | # 4. Fall back to direct model implementation 801 | # 5. Fall back to MiDaS model 802 | 803 | # Step 1: First check if we already have a local model file 804 | local_model_file = None 805 | 806 | # Search all valid model directories for existing files 807 | for base_path in existing_paths: 808 | # Check multiple possible locations and naming patterns 809 | locations_to_check = [ 810 | os.path.join(base_path, model_cache_name), 811 | os.path.join(base_path, model_path.replace("/", "_")), 812 | base_path, 813 | os.path.join(base_path, "v2") if model_type == "v2" else None, 814 | ] 815 | 816 | # Filter out None values 817 | locations_to_check = [loc for loc in locations_to_check if loc is not None] 818 | 819 | # Common model filenames to check 820 | model_filenames = [ 821 | "pytorch_model.bin", 822 | "model.pt", 823 | "model.pth", 824 | f"{model_cache_name}.pt", 825 | f"{model_cache_name}.bin", 826 | f"depth_anything_{encoder}.pt", # Common naming for Depth Anything models 827 | f"depth_anything_v2_{encoder}.pt", # V2 naming format 828 | ] 829 | 830 | # Search all locations and filenames 831 | for location in locations_to_check: 832 | if os.path.exists(location): 833 | for filename in model_filenames: 834 | file_path = os.path.join(location, filename) 835 | if os.path.exists(file_path) and os.path.getsize(file_path) > 1000000: # >1MB to avoid empty files 836 | local_model_file = file_path 837 | logger.info(f"Found existing model file: {local_model_file}") 838 | break 839 | 840 | if local_model_file: 841 | break 842 | 843 | if local_model_file: 844 | break 845 | 846 | # Step 2: If no local file found, try downloading from direct URLs 847 | # These URLs don't require authentication and are more reliable 848 | if not local_model_file and model_type == "v2": 849 | # Comprehensive list of URLs to try for V2 models 850 | alternative_urls = { 851 | "Depth-Anything-V2-Small": [ 852 | "https://huggingface.co/depth-anything/Depth-Anything-V2-Small-hf/resolve/main/pytorch_model.bin", 853 | "https://github.com/LiheYoung/Depth-Anything/releases/download/v2.0/depth_anything_v2_small.pt", 854 | "https://huggingface.co/ckpt/depth-anything-v2/resolve/main/depth_anything_v2_small.pt" 855 | ], 856 | "Depth-Anything-V2-Base": [ 857 | "https://huggingface.co/depth-anything/Depth-Anything-V2-Base-hf/resolve/main/pytorch_model.bin", 858 | 
"https://github.com/LiheYoung/Depth-Anything/releases/download/v2.0/depth_anything_v2_base.pt", 859 | "https://huggingface.co/ckpt/depth-anything-v2/resolve/main/depth_anything_v2_base.pt" 860 | ], 861 | "MiDaS-Small": [ 862 | "https://github.com/intel-isl/MiDaS/releases/download/v2_1/midas_v21_small_256.pt" 863 | ], 864 | "MiDaS-Base": [ 865 | "https://github.com/intel-isl/MiDaS/releases/download/v3/dpt_hybrid-midas-501f0c75.pt" 866 | ] 867 | } 868 | 869 | # Get URLs to try (including the direct_url from model_info) 870 | urls_to_try = [] 871 | if direct_url: 872 | urls_to_try.append(direct_url) 873 | 874 | # Add alternative URLs for this specific model if available 875 | if model_name in alternative_urls: 876 | urls_to_try.extend(alternative_urls[model_name]) 877 | 878 | # Try downloading the model if not found locally 879 | if urls_to_try: 880 | for url in urls_to_try: 881 | try: 882 | # Determine output filename and path 883 | model_filename = os.path.basename(url) 884 | download_path = os.path.join(cache_dir, model_filename) 885 | 886 | # Check if already downloaded 887 | if os.path.exists(download_path) and os.path.getsize(download_path) > 1000000: # >1MB to avoid empty files 888 | logger.info(f"Found existing downloaded model at {download_path}") 889 | local_model_file = download_path 890 | break 891 | 892 | # Create parent directory if needed 893 | os.makedirs(os.path.dirname(download_path), exist_ok=True) 894 | 895 | # Download using a reliable method with multiple retries 896 | logger.info(f"Downloading model from {url} to {download_path}") 897 | 898 | success = False 899 | 900 | # Try wget first (most reliable for large files) 901 | try: 902 | import wget 903 | wget.download(url, out=download_path) 904 | if os.path.exists(download_path) and os.path.getsize(download_path) > 1000000: 905 | logger.info(f"Successfully downloaded model with wget to {download_path}") 906 | success = True 907 | except Exception as wget_error: 908 | logger.warning(f"wget download failed: {str(wget_error)}") 909 | 910 | # Try requests if wget failed 911 | if not success: 912 | try: 913 | # Download with progress reporting 914 | with requests.get(url, stream=True, timeout=60) as response: 915 | response.raise_for_status() 916 | total_size = int(response.headers.get('content-length', 0)) 917 | 918 | with open(download_path, 'wb') as f: 919 | downloaded = 0 920 | for chunk in response.iter_content(chunk_size=8192): 921 | f.write(chunk) 922 | downloaded += len(chunk) 923 | 924 | if total_size > 0 and downloaded % (20 * 1024 * 1024) == 0: # Log every 20MB 925 | percent = int(100 * downloaded / total_size) 926 | logger.info(f"Download progress: {downloaded/1024/1024:.1f}MB of {total_size/1024/1024:.1f}MB ({percent}%)") 927 | 928 | if os.path.exists(download_path) and os.path.getsize(download_path) > 1000000: 929 | logger.info(f"Successfully downloaded model with requests to {download_path}") 930 | success = True 931 | except Exception as req_error: 932 | logger.warning(f"requests download failed: {str(req_error)}") 933 | 934 | # Try urllib as last resort 935 | if not success: 936 | try: 937 | import urllib.request 938 | urllib.request.urlretrieve(url, download_path) 939 | if os.path.exists(download_path) and os.path.getsize(download_path) > 1000000: 940 | logger.info(f"Successfully downloaded model with urllib to {download_path}") 941 | success = True 942 | except Exception as urllib_error: 943 | logger.warning(f"urllib download failed: {str(urllib_error)}") 944 | 945 | # Set the local_model_file if download 
succeeded 946 | if success: 947 | local_model_file = download_path 948 | break 949 | else: 950 | # Clean up failed download 951 | if os.path.exists(download_path): 952 | try: 953 | os.remove(download_path) 954 | except: 955 | pass 956 | 957 | except Exception as download_error: 958 | logger.warning(f"Failed to download from {url}: {str(download_error)}") 959 | continue 960 | 961 | # Step 3: Try loading with transformers pipeline 962 | # This is the most feature-complete approach but may fail with auth issues 963 | logger.info(f"Trying to load model '{model_name}' using transformers pipeline") 964 | 965 | # Priority-ordered list of model paths to try 966 | # Ordered from most to least likely to work 967 | model_paths_to_try = [] 968 | 969 | # Start with the specific model requested 970 | model_paths_to_try.append(model_path) 971 | 972 | # Add V2-specific paths for V2 models 973 | if model_type == "v2": 974 | if "small" in model_name.lower(): 975 | model_paths_to_try.append("depth-anything/Depth-Anything-V2-Small-hf") 976 | elif "base" in model_name.lower(): 977 | model_paths_to_try.append("depth-anything/Depth-Anything-V2-Base-hf") 978 | 979 | # Add variants with and without -hf suffix 980 | model_paths_to_try.append(model_path.replace("-hf", "")) 981 | if "-hf" not in model_path: 982 | model_paths_to_try.append(model_path + "-hf") 983 | 984 | # Try both organization name formats 985 | model_paths_to_try.append(model_path.replace("LiheYoung", "depth-anything")) 986 | model_paths_to_try.append(model_path.replace("depth-anything", "LiheYoung")) 987 | else: 988 | # For V1 models, add common variants 989 | model_paths_to_try.append(model_path.replace("-hf", "")) 990 | if "-hf" not in model_path: 991 | model_paths_to_try.append(model_path + "-hf") 992 | 993 | # Add MiDaS fallbacks only if not already trying MiDaS 994 | if "midas" not in model_name.lower(): 995 | model_paths_to_try.append("Intel/dpt-hybrid-midas") 996 | 997 | # Remove duplicates while preserving order 998 | model_paths_to_try = list(dict.fromkeys(model_paths_to_try)) 999 | 1000 | # Log all paths we're going to try 1001 | logger.info(f"Will try loading from these paths in order: {model_paths_to_try}") 1002 | 1003 | # Try loading with transformers pipeline 1004 | pipeline_success = False 1005 | pipeline_error = None 1006 | 1007 | for path in model_paths_to_try: 1008 | # Skip empty paths 1009 | if not path.strip(): 1010 | continue 1011 | 1012 | logger.info(f"Attempting to load model from: {path}") 1013 | 1014 | # First try online loading (allows downloading new models) 1015 | try: 1016 | logger.info(f"Loading with online mode: model={path}, device={device_type}, dtype={dtype}") 1017 | 1018 | # Try standard pipeline creation first 1019 | try: 1020 | from transformers import pipeline 1021 | 1022 | # Create pipeline with timeout and error handling 1023 | self.depth_estimator = pipeline( 1024 | "depth-estimation", 1025 | model=path, 1026 | cache_dir=cache_dir, 1027 | local_files_only=False, # Try online first 1028 | device_map=device_type, 1029 | torch_dtype=dtype 1030 | ) 1031 | 1032 | # Verify that pipeline was created 1033 | if self.depth_estimator is None: 1034 | raise RuntimeError(f"Pipeline initialization returned None for {path}") 1035 | 1036 | # Validate by running a test inference 1037 | test_img = Image.new("RGB", (64, 64), color=(128, 128, 128)) 1038 | try: 1039 | test_result = self.depth_estimator(test_img) 1040 | 1041 | # Further verify the output format 1042 | if not isinstance(test_result, dict) or "predicted_depth" not 
in test_result: 1043 | raise RuntimeError("Invalid output format from pipeline") 1044 | 1045 | # Success - log and break 1046 | logger.info(f"Successfully loaded model from {path} with online mode") 1047 | pipeline_success = True 1048 | break 1049 | 1050 | except Exception as test_error: 1051 | logger.warning(f"Pipeline created but test failed: {str(test_error)}") 1052 | raise 1053 | 1054 | except TypeError as type_error: 1055 | # Handle unpacking errors which are common with older transformers versions 1056 | logger.warning(f"TypeError when creating pipeline: {str(type_error)}") 1057 | 1058 | # Try alternative approach with manual component loading 1059 | logger.info("Trying manual component loading as alternative...") 1060 | 1061 | try: 1062 | from transformers import AutoModelForDepthEstimation, AutoImageProcessor 1063 | 1064 | # Load components separately 1065 | processor = AutoImageProcessor.from_pretrained( 1066 | path, cache_dir=cache_dir, local_files_only=False 1067 | ) 1068 | 1069 | model = AutoModelForDepthEstimation.from_pretrained( 1070 | path, cache_dir=cache_dir, local_files_only=False, 1071 | torch_dtype=dtype 1072 | ) 1073 | 1074 | # Move model to correct device 1075 | if not force_cpu and 'cuda' in device_type: 1076 | model = model.to(self.device) 1077 | 1078 | # Create custom pipeline class 1079 | class CustomDepthEstimator: 1080 | def __init__(self, model, processor, device): 1081 | self.model = model 1082 | self.processor = processor 1083 | self.device = device 1084 | 1085 | def __call__(self, image): 1086 | # Process image and run model 1087 | inputs = self.processor(images=image, return_tensors="pt") 1088 | 1089 | # Move inputs to correct device 1090 | if torch.cuda.is_available() and not force_cpu: 1091 | inputs = {k: v.to(self.device) for k, v in inputs.items()} 1092 | 1093 | # Run model 1094 | with torch.no_grad(): 1095 | outputs = self.model(**inputs) 1096 | 1097 | # Return results in standard format 1098 | return {"predicted_depth": outputs.predicted_depth} 1099 | 1100 | # Create custom pipeline 1101 | self.depth_estimator = CustomDepthEstimator(model, processor, self.device) 1102 | 1103 | # Test the pipeline 1104 | test_img = Image.new("RGB", (64, 64), color=(128, 128, 128)) 1105 | test_result = self.depth_estimator(test_img) 1106 | 1107 | # Verify output format 1108 | if not isinstance(test_result, dict) or "predicted_depth" not in test_result: 1109 | raise RuntimeError("Invalid output format from custom pipeline") 1110 | 1111 | logger.info(f"Successfully loaded model from {path} with custom pipeline") 1112 | pipeline_success = True 1113 | break 1114 | 1115 | except Exception as custom_error: 1116 | logger.warning(f"Custom pipeline creation failed: {str(custom_error)}") 1117 | raise 1118 | 1119 | except Exception as online_error: 1120 | logger.warning(f"Online loading failed for {path}: {str(online_error)}") 1121 | pipeline_error = online_error 1122 | 1123 | # Try local-only mode if online fails (faster and often works with cached files) 1124 | try: 1125 | logger.info(f"Trying local-only mode for {path}") 1126 | 1127 | from transformers import pipeline 1128 | 1129 | self.depth_estimator = pipeline( 1130 | "depth-estimation", 1131 | model=path, 1132 | cache_dir=cache_dir, 1133 | local_files_only=True, # Only use local files 1134 | device_map=device_type, 1135 | torch_dtype=dtype 1136 | ) 1137 | 1138 | # Verify pipeline 1139 | if self.depth_estimator is None: 1140 | raise RuntimeError(f"Local pipeline initialization returned None for {path}") 1141 | 1142 | # Test 
with small image 1143 | test_img = Image.new("RGB", (64, 64), color=(128, 128, 128)) 1144 | test_result = self.depth_estimator(test_img) 1145 | 1146 | logger.info(f"Successfully loaded model from {path} with local-only mode") 1147 | pipeline_success = True 1148 | break 1149 | 1150 | except Exception as local_error: 1151 | logger.warning(f"Local-only mode failed for {path}: {str(local_error)}") 1152 | pipeline_error = local_error 1153 | # Continue to next path 1154 | 1155 | # Step 4: If transformers pipeline failed, try direct model loading 1156 | if not pipeline_success: 1157 | logger.info("Pipeline loading failed. Trying direct model implementation.") 1158 | 1159 | # If we have a local model file from previous steps, use it with direct loading 1160 | direct_success = False 1161 | 1162 | if local_model_file: 1163 | logger.info(f"Loading model directly from: {local_model_file}") 1164 | 1165 | # For V2 models, use the DepthAnythingV2 implementation 1166 | if model_type == "v2" and TIMM_AVAILABLE: 1167 | try: 1168 | logger.info(f"Using DepthAnythingV2 implementation with config: {config}") 1169 | 1170 | # Create model instance with appropriate config 1171 | if config: 1172 | model_instance = DepthAnythingV2(**config) 1173 | else: 1174 | # Use default config for this encoder type 1175 | default_config = MODEL_CONFIGS.get(encoder, MODEL_CONFIGS["vits"]) 1176 | model_instance = DepthAnythingV2(**default_config) 1177 | 1178 | # Load state dict from file 1179 | logger.info(f"Loading weights from {local_model_file}") 1180 | state_dict = torch.load(local_model_file, map_location="cpu") 1181 | 1182 | # Convert state dict to float32 1183 | if any(v.dtype == torch.float64 for v in state_dict.values() if hasattr(v, 'dtype')): 1184 | logger.info("Converting state dict from float64 to float32") 1185 | state_dict = {k: v.float() if hasattr(v, 'dtype') else v for k, v in state_dict.items()} 1186 | 1187 | # Try loading with different state dict formats 1188 | try: 1189 | if "model" in state_dict: 1190 | model_instance.load_state_dict(state_dict["model"]) 1191 | else: 1192 | model_instance.load_state_dict(state_dict) 1193 | except Exception as e: 1194 | logger.warning(f"Strict loading failed: {str(e)}. 
Trying non-strict loading.") 1195 | if "model" in state_dict: 1196 | model_instance.load_state_dict(state_dict["model"], strict=False) 1197 | else: 1198 | model_instance.load_state_dict(state_dict, strict=False) 1199 | 1200 | # Move to correct device and set eval mode 1201 | model_instance = model_instance.to(self.device).float().eval() 1202 | 1203 | # Test the model 1204 | test_img = Image.new("RGB", (64, 64), color=(128, 128, 128)) 1205 | _ = model_instance(test_img) 1206 | 1207 | # Success - assign model 1208 | self.depth_estimator = model_instance 1209 | direct_success = True 1210 | logger.info("Successfully loaded model with direct implementation") 1211 | 1212 | except Exception as v2_error: 1213 | logger.warning(f"DepthAnythingV2 loading failed: {str(v2_error)}") 1214 | 1215 | # If V2 direct loading failed or isn't applicable, try MiDaS wrapper 1216 | if not direct_success: 1217 | try: 1218 | logger.info("Falling back to MiDaS wrapper implementation") 1219 | 1220 | # Determine MiDaS model type based on model name 1221 | midas_type = "DPT_Hybrid" # Default 1222 | 1223 | if "small" in model_name.lower(): 1224 | midas_type = "MiDaS_small" 1225 | elif "large" in model_name.lower(): 1226 | midas_type = "DPT_Large" 1227 | 1228 | # Create MiDaS wrapper with model type 1229 | midas_model = MiDaSWrapper(midas_type, self.device) 1230 | 1231 | # Test the model 1232 | test_img = Image.new("RGB", (64, 64), color=(128, 128, 128)) 1233 | _ = midas_model(test_img) 1234 | 1235 | # Success - assign model 1236 | self.depth_estimator = midas_model 1237 | direct_success = True 1238 | logger.info("Successfully loaded MiDaS fallback model") 1239 | 1240 | except Exception as midas_error: 1241 | logger.warning(f"MiDaS wrapper loading failed: {str(midas_error)}") 1242 | 1243 | # Step 5: If all previous attempts failed, try one last MiDaS fallback 1244 | if not direct_success and not pipeline_success: 1245 | try: 1246 | logger.info("All model loading attempts failed. Trying basic MiDaS fallback.") 1247 | 1248 | # Create simple MiDaS wrapper with default settings 1249 | midas_model = MiDaSWrapper("dpt_hybrid", self.device) 1250 | 1251 | # Test with simple image 1252 | test_img = Image.new("RGB", (64, 64), color=(128, 128, 128)) 1253 | _ = midas_model(test_img) 1254 | 1255 | # Assign model 1256 | self.depth_estimator = midas_model 1257 | logger.info("Successfully loaded basic MiDaS fallback model") 1258 | 1259 | except Exception as final_error: 1260 | # If we get here, all attempts have failed 1261 | error_msg = f"All model loading attempts failed for {model_name}. Last error: {str(final_error)}" 1262 | logger.error(error_msg) 1263 | 1264 | # Create a helpful error message with instructions 1265 | all_model_dirs = "\n".join(existing_paths) 1266 | 1267 | # Determine error type for better help message 1268 | if pipeline_error: 1269 | error_str = str(pipeline_error).lower() 1270 | 1271 | if "unauthorized" in error_str or "401" in error_str or "authentication" in error_str: 1272 | # Authentication error - guide for manual download 1273 | error_solution = f""" 1274 | AUTHENTICATION ERROR: The model couldn't be downloaded due to Hugging Face authentication requirements. 1275 | 1276 | SOLUTION: 1277 | 1. Use force_cpu=True in the node settings 1278 | 2. Try a different model like MiDaS-Small 1279 | 3. 
Download the model manually using one of these direct links: 1280 | - https://huggingface.co/depth-anything/Depth-Anything-V2-Small-hf/resolve/main/pytorch_model.bin 1281 | - https://github.com/LiheYoung/Depth-Anything/releases/download/v2.0/depth_anything_v2_small.pt 1282 | - https://huggingface.co/ckpt/depth-anything-v2/resolve/main/depth_anything_v2_small.pt 1283 | 1284 | Save the file to one of these directories: 1285 | {all_model_dirs} 1286 | """ 1287 | elif "cuda" in error_str or "gpu" in error_str or "vram" in error_str or "memory" in error_str: 1288 | # GPU/memory error 1289 | error_solution = """ 1290 | GPU ERROR: The model failed to load on your GPU. 1291 | 1292 | SOLUTION: 1293 | 1. Use force_cpu=True to use CPU processing instead 1294 | 2. Reduce input_size parameter to 384 to reduce memory requirements 1295 | 3. Try a smaller model like MiDaS-Small 1296 | 4. Ensure you have the latest GPU drivers installed 1297 | """ 1298 | else: 1299 | # Generic error 1300 | error_solution = f""" 1301 | Failed to load any depth estimation model. 1302 | 1303 | SOLUTION: 1304 | 1. Use force_cpu=True to use CPU for processing 1305 | 2. Try using a different model like MiDaS-Small 1306 | 3. Download a model file manually and place it in one of these directories: 1307 | {all_model_dirs} 1308 | 4. Restart ComfyUI to ensure clean state 1309 | """ 1310 | else: 1311 | # Generic error when pipeline_error isn't set 1312 | error_solution = f""" 1313 | Failed to load any depth estimation model. 1314 | 1315 | SOLUTION: 1316 | 1. Use force_cpu=True to use CPU for processing 1317 | 2. Try using a different model like MiDaS-Small 1318 | 3. Download a model file manually and place it in one of these directories: 1319 | {all_model_dirs} 1320 | 4. Restart ComfyUI to ensure clean state 1321 | """ 1322 | 1323 | # Raise helpful error 1324 | raise RuntimeError(f"MODEL LOADING ERROR: {error_solution}") 1325 | 1326 | # Ensure the model is on the correct device 1327 | if hasattr(self.depth_estimator, 'model') and hasattr(self.depth_estimator.model, 'to'): 1328 | if force_cpu: 1329 | self.depth_estimator.model = self.depth_estimator.model.to('cpu') 1330 | else: 1331 | self.depth_estimator.model = self.depth_estimator.model.to(self.device) 1332 | 1333 | # Set model to eval mode if applicable 1334 | if hasattr(self.depth_estimator, 'eval'): 1335 | self.depth_estimator.eval() 1336 | elif hasattr(self.depth_estimator, 'model') and hasattr(self.depth_estimator.model, 'eval'): 1337 | self.depth_estimator.model.eval() 1338 | 1339 | # Store current model info 1340 | self.current_model = model_path 1341 | logger.info(f"Model '{model_name}' loaded successfully") 1342 | 1343 | except Exception as e: 1344 | # Clean up on failure 1345 | self.cleanup() 1346 | 1347 | # Log detailed error info 1348 | error_msg = f"Failed to load model {model_name}: {str(e)}" 1349 | logger.error(error_msg) 1350 | logger.error(traceback.format_exc()) 1351 | 1352 | # Re-raise with clear message 1353 | raise RuntimeError(error_msg) 1354 | 1355 | def load_model_direct(self, model_name, model_info, force_cpu=False): 1356 | """ 1357 | Directly loads a depth model without using transformers pipeline. 1358 | This is a fallback method when the normal pipeline loading fails. 
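        Illustrative call (a sketch only, not a pattern taken from elsewhere in
        this file; the dict keys mirror the fields read from DEPTH_MODELS entries
        in ensure_model_loaded, and the values shown are assumed examples):

            model = self.load_model_direct(
                "Depth-Anything-V2-Small",
                {"path": "depth-anything/Depth-Anything-V2-Small-hf",
                 "model_type": "v2", "encoder": "vits"},
                force_cpu=True,
            )
            if model is None:
                logger.error("Direct loading failed; caller should fall back")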
1359 | 1360 | Args: 1361 | model_name: Name of the model to load 1362 | model_info: Dictionary with model information 1363 | force_cpu: Whether to force CPU usage 1364 | 1365 | Returns: 1366 | A depth estimation model that implements the __call__ interface 1367 | """ 1368 | try: 1369 | logger.info(f"Attempting direct model loading for {model_name}") 1370 | 1371 | # Determine device 1372 | device_type = 'cpu' if force_cpu else ('cuda' if torch.cuda.is_available() else 'cpu') 1373 | device = torch.device(device_type) 1374 | 1375 | # Look in all possible model directories 1376 | # This is important to support various directory structures 1377 | model_found = False 1378 | model_path_local = None 1379 | 1380 | # Make a unique model cache directory for this specific model 1381 | model_subfolder = model_name.replace("-", "_").lower() 1382 | 1383 | # Check all possible locations for the model file 1384 | for base_path in existing_paths: 1385 | # Try different possible locations and filename patterns 1386 | possible_model_locations = [ 1387 | # Direct downloads in the model directory 1388 | os.path.join(base_path, model_subfolder), 1389 | 1390 | # Using the full HF directory structure 1391 | os.path.join(base_path, model_info.get("path", "").replace("/", "_")), 1392 | 1393 | # Directly in base directory 1394 | base_path, 1395 | ] 1396 | 1397 | # Add directory structure with model configs if V2 1398 | if model_info.get("model_type") == "v2": 1399 | v2_path = os.path.join(base_path, "v2") 1400 | possible_model_locations.append(v2_path) 1401 | possible_model_locations.append(os.path.join(v2_path, model_subfolder)) 1402 | 1403 | # Try all locations 1404 | logger.info(f"Searching for existing model in these directories: {possible_model_locations}") 1405 | 1406 | for location in possible_model_locations: 1407 | # Check for model file with various naming patterns 1408 | if os.path.exists(location): 1409 | # Check for common filenames 1410 | for filename in ["pytorch_model.bin", "model.pt", "model.pth", 1411 | f"{model_subfolder}.pt", f"{model_subfolder}.bin"]: 1412 | file_path = os.path.join(location, filename) 1413 | if os.path.exists(file_path): 1414 | logger.info(f"Found existing model file: {file_path}") 1415 | model_path_local = file_path 1416 | model_found = True 1417 | break 1418 | 1419 | if model_found: 1420 | break 1421 | 1422 | if model_found: 1423 | break 1424 | 1425 | # If model not found, use the first directory for downloading 1426 | cache_dir = os.path.join(existing_paths[0], model_subfolder) 1427 | os.makedirs(cache_dir, exist_ok=True) 1428 | 1429 | # Get model configuration 1430 | model_type = model_info.get("model_type", "v1") 1431 | encoder = model_info.get("encoder", "vits") 1432 | config = model_info.get("config", MODEL_CONFIGS.get(encoder, MODEL_CONFIGS["vits"])) 1433 | 1434 | # Step 1: If model not found locally, download it 1435 | # List of alternative URLs that don't require authentication 1436 | alternative_urls = { 1437 | "Depth-Anything-V2-Small": [ 1438 | "https://huggingface.co/depth-anything/Depth-Anything-V2-Small-hf/resolve/main/pytorch_model.bin", 1439 | "https://github.com/LiheYoung/Depth-Anything/releases/download/v2.0/depth_anything_v2_small.pt", 1440 | "https://huggingface.co/ckpt/depth-anything-v2/resolve/main/depth_anything_v2_small.pt" 1441 | ], 1442 | "Depth-Anything-V2-Base": [ 1443 | "https://huggingface.co/depth-anything/Depth-Anything-V2-Base-hf/resolve/main/pytorch_model.bin", 1444 | 
"https://github.com/LiheYoung/Depth-Anything/releases/download/v2.0/depth_anything_v2_base.pt", 1445 | "https://huggingface.co/ckpt/depth-anything-v2/resolve/main/depth_anything_v2_base.pt" 1446 | ], 1447 | "MiDaS-Base": [ 1448 | "https://github.com/intel-isl/MiDaS/releases/download/v3/dpt_hybrid-midas-501f0c75.pt" 1449 | ] 1450 | } 1451 | 1452 | # Get primary URL from model_info 1453 | direct_url = model_info.get("direct_url") 1454 | 1455 | # Add alternative URLs to try if the main one fails 1456 | urls_to_try = [direct_url] if direct_url else [] 1457 | 1458 | # Add alternative URLs for this model if available 1459 | if model_name in alternative_urls: 1460 | urls_to_try.extend(alternative_urls[model_name]) 1461 | 1462 | # Try downloading the model if not found locally 1463 | if not model_found and urls_to_try: 1464 | # Try each URL in sequence until one works 1465 | for url in urls_to_try: 1466 | if not url: 1467 | continue 1468 | 1469 | try: 1470 | model_filename = os.path.basename(url) 1471 | model_path_local = os.path.join(cache_dir, model_filename) 1472 | 1473 | if os.path.exists(model_path_local): 1474 | logger.info(f"Model already exists at {model_path_local}") 1475 | model_found = True 1476 | break 1477 | 1478 | logger.info(f"Attempting to download model from {url} to {model_path_local}") 1479 | 1480 | # Create parent directory if needed 1481 | os.makedirs(os.path.dirname(model_path_local), exist_ok=True) 1482 | 1483 | # Try different download methods 1484 | download_success = False 1485 | 1486 | # First try wget (more reliable for large files) 1487 | try: 1488 | logger.info(f"Downloading with wget: {url}") 1489 | wget.download(url, out=model_path_local) 1490 | logger.info(f"Downloaded model weights to {model_path_local}") 1491 | download_success = True 1492 | except Exception as wget_error: 1493 | logger.warning(f"wget download failed: {str(wget_error)}") 1494 | 1495 | # Fallback to requests 1496 | try: 1497 | logger.info(f"Downloading with requests: {url}") 1498 | response = requests.get(url, stream=True) 1499 | 1500 | if response.status_code == 200: 1501 | total_size = int(response.headers.get('content-length', 0)) 1502 | logger.info(f"File size: {total_size/1024/1024:.1f} MB") 1503 | 1504 | with open(model_path_local, 'wb') as f: 1505 | downloaded = 0 1506 | for data in response.iter_content(1024 * 1024): # 1MB chunks 1507 | f.write(data) 1508 | downloaded += len(data) 1509 | if total_size > 0 and downloaded % (10 * 1024 * 1024) == 0: # Log every 10MB 1510 | progress = (downloaded / total_size) * 100 1511 | logger.info(f"Downloaded {downloaded/1024/1024:.1f}MB of {total_size/1024/1024:.1f}MB ({progress:.1f}%)") 1512 | 1513 | logger.info(f"Download complete: {model_path_local}") 1514 | download_success = True 1515 | else: 1516 | logger.warning(f"Failed to download from {url}: HTTP status {response.status_code}") 1517 | except Exception as req_error: 1518 | logger.warning(f"Requests download failed: {str(req_error)}") 1519 | 1520 | # Try urllib as last resort 1521 | if not download_success: 1522 | try: 1523 | logger.info(f"Downloading with urllib: {url}") 1524 | urllib.request.urlretrieve(url, model_path_local) 1525 | logger.info(f"Downloaded model weights to {model_path_local}") 1526 | download_success = True 1527 | except Exception as urllib_error: 1528 | logger.warning(f"urllib download failed: {str(urllib_error)}") 1529 | 1530 | # Check if download succeeded 1531 | if download_success and os.path.exists(model_path_local) and os.path.getsize(model_path_local) > 0: 1532 | 
logger.info(f"Successfully downloaded model to {model_path_local}") 1533 | model_found = True 1534 | break 1535 | else: 1536 | logger.warning(f"Download appeared to succeed but file is empty or missing") 1537 | # Try to remove the failed download 1538 | if os.path.exists(model_path_local): 1539 | try: 1540 | os.remove(model_path_local) 1541 | except: 1542 | pass 1543 | 1544 | except Exception as dl_error: 1545 | logger.warning(f"Error downloading from {url}: {str(dl_error)}") 1546 | continue 1547 | 1548 | if not model_found: 1549 | logger.error("All download attempts failed") 1550 | 1551 | # Step 2: Create and load the appropriate model if found 1552 | if model_found and model_path_local and os.path.exists(model_path_local): 1553 | logger.info(f"Found model file at: {model_path_local}") 1554 | 1555 | # Handle V2 models with DepthAnythingV2 implementation 1556 | if model_type == "v2" and TIMM_AVAILABLE: 1557 | try: 1558 | logger.info(f"Loading as DepthAnythingV2 model with config: {config}") 1559 | 1560 | # Create model with the appropriate configuration 1561 | model = DepthAnythingV2(**config) 1562 | 1563 | # Load weights from checkpoint 1564 | logger.info(f"Loading weights from {model_path_local}") 1565 | state_dict = torch.load(model_path_local, map_location=device) 1566 | 1567 | # Convert state dict to float32 if needed 1568 | if any(v.dtype == torch.float64 for v in state_dict.values() if hasattr(v, 'dtype')): 1569 | logger.info("Converting state dict from float64 to float32") 1570 | state_dict = {k: v.float() if hasattr(v, 'dtype') else v for k, v in state_dict.items()} 1571 | 1572 | # Attempt to load the state dict (handles different formats) 1573 | try: 1574 | if "model" in state_dict: 1575 | model.load_state_dict(state_dict["model"]) 1576 | else: 1577 | model.load_state_dict(state_dict) 1578 | except Exception as e: 1579 | logger.warning(f"Error loading state dict: {str(e)}") 1580 | logger.warning("Trying to load with strict=False") 1581 | if "model" in state_dict: 1582 | model.load_state_dict(state_dict["model"], strict=False) 1583 | else: 1584 | model.load_state_dict(state_dict, strict=False) 1585 | 1586 | # Move model to the correct device and ensure float32 1587 | model = model.float().to(device) 1588 | model.device = device 1589 | model.eval() 1590 | 1591 | # Test the model 1592 | logger.info("Testing model with sample image") 1593 | test_img = Image.new("RGB", (64, 64), color=(128, 128, 128)) 1594 | 1595 | try: 1596 | _ = model(test_img) 1597 | logger.info("DepthAnythingV2 model loaded and tested successfully") 1598 | return model 1599 | except Exception as test_error: 1600 | logger.error(f"Error during model test: {str(test_error)}") 1601 | logger.debug(traceback.format_exc()) 1602 | except Exception as e: 1603 | logger.error(f"Error loading DepthAnythingV2: {str(e)}") 1604 | logger.debug(traceback.format_exc()) 1605 | 1606 | # Fallback to MiDaS model if V2 loading failed or for V1 models 1607 | try: 1608 | logger.info("Falling back to MiDaS model") 1609 | 1610 | # Determine the appropriate MiDaS model type 1611 | midas_model_type = "dpt_hybrid" 1612 | if "large" in model_name.lower(): 1613 | midas_model_type = "dpt_large" 1614 | elif "small" in model_name.lower(): 1615 | midas_model_type = "midas_v21_small" 1616 | 1617 | # Create and test the MiDaS model 1618 | midas_model = MiDaSWrapper(midas_model_type, device) 1619 | 1620 | # Test with a small image 1621 | test_img = Image.new("RGB", (64, 64), color=(128, 128, 128)) 1622 | _ = midas_model(test_img) 1623 | 1624 | 
logger.info("MiDaS model loaded and tested successfully") 1625 | return midas_model 1626 | 1627 | except Exception as e: 1628 | logger.error(f"Error loading MiDaS: {str(e)}") 1629 | logger.debug(traceback.format_exc()) 1630 | 1631 | # If all else fails, return None 1632 | return None 1633 | 1634 | except Exception as e: 1635 | logger.error(f"Direct model loading failed: {str(e)}") 1636 | logger.debug(traceback.format_exc()) 1637 | return None 1638 | 1639 | def process_image(self, image: Union[torch.Tensor, np.ndarray], input_size: int = 518) -> Image.Image: 1640 | """ 1641 | Converts input image to proper format for depth estimation and resizes it. 1642 | 1643 | Args: 1644 | image: Input image as tensor or numpy array 1645 | input_size: Target size for the longest dimension of the image 1646 | 1647 | Returns: 1648 | PIL Image ready for depth estimation 1649 | """ 1650 | try: 1651 | # Log input information for debugging 1652 | if torch.is_tensor(image): 1653 | logger.info(f"Processing tensor image with shape {image.shape}, dtype {image.dtype}") 1654 | elif isinstance(image, np.ndarray): 1655 | logger.info(f"Processing numpy image with shape {image.shape}, dtype {image.dtype}") 1656 | else: 1657 | logger.warning(f"Unexpected image type: {type(image)}") 1658 | 1659 | # Validate and normalize input_size with improved bounds checking 1660 | if not isinstance(input_size, (int, float)): 1661 | logger.warning(f"Invalid input_size type: {type(input_size)}. Using default 518.") 1662 | input_size = 518 1663 | else: 1664 | try: 1665 | # Convert to int and constrain to valid range 1666 | input_size = int(input_size) 1667 | except: 1668 | logger.warning(f"Error converting input_size to int. Using default 518.") 1669 | input_size = 518 1670 | 1671 | # Ensure input_size is within valid range 1672 | if input_size < 256: 1673 | logger.warning(f"Input size {input_size} is too small, using 256 instead") 1674 | input_size = 256 1675 | elif input_size > 1024: 1676 | logger.warning(f"Input size {input_size} is too large, using 1024 instead") 1677 | input_size = 1024 1678 | 1679 | # Process tensor input with comprehensive error handling 1680 | if torch.is_tensor(image): 1681 | try: 1682 | # Check tensor dtype and convert to float32 if needed 1683 | if image.dtype != torch.float32: 1684 | logger.info(f"Converting input tensor from {image.dtype} to torch.float32") 1685 | image = image.float() # Convert to FloatTensor for consistency 1686 | 1687 | # Check for NaN/Inf values in tensor 1688 | nan_count = torch.isnan(image).sum().item() 1689 | inf_count = torch.isinf(image).sum().item() 1690 | 1691 | if nan_count > 0 or inf_count > 0: 1692 | logger.warning(f"Input tensor contains {nan_count} NaN and {inf_count} Inf values. Replacing with valid values.") 1693 | image = torch.nan_to_num(image, nan=0.0, posinf=1.0, neginf=0.0) 1694 | 1695 | # Handle empty or invalid tensors 1696 | if image.numel() == 0: 1697 | logger.error("Input tensor is empty (zero elements)") 1698 | return Image.new('RGB', (512, 512), (128, 128, 128)) 1699 | 1700 | # Handle tensor with incorrect number of dimensions 1701 | # We need a 3D or 4D tensor to properly extract the image 1702 | if image.dim() < 3: 1703 | logger.warning(f"Input tensor has too few dimensions: {image.dim()}D. 
Adding dimensions.") 1704 | # Add dimensions until we have at least 3D tensor 1705 | while image.dim() < 3: 1706 | image = image.unsqueeze(0) 1707 | logger.info(f"Adjusted tensor shape to: {image.shape}") 1708 | 1709 | # Normalize values to [0, 1] range if needed 1710 | if image.max() > 1.0 + 1e-5: # Allow small floating point error 1711 | min_val, max_val = image.min().item(), image.max().item() 1712 | logger.info(f"Input tensor values outside [0,1] range: min={min_val}, max={max_val}. Normalizing.") 1713 | 1714 | # Common ranges and conversions 1715 | if min_val >= 0 and max_val <= 255: 1716 | # Assume [0-255] range for images 1717 | image = image / 255.0 1718 | else: 1719 | # General normalization 1720 | image = (image - min_val) / (max_val - min_val) 1721 | 1722 | # Extract first image from batch if we have a batch dimension 1723 | if image.dim() == 4: # [B, C, H, W] 1724 | # Use first image in batch 1725 | image_for_conversion = image[0] 1726 | else: # 3D tensor, assumed to be [C, H, W] 1727 | image_for_conversion = image 1728 | 1729 | # Move to CPU for numpy conversion 1730 | image_for_conversion = image_for_conversion.cpu() 1731 | 1732 | # Convert to numpy, handling different layouts 1733 | if image_for_conversion.shape[0] <= 3: # [C, H, W] format with 1-3 channels 1734 | # Standard CHW layout - convert to HWC for PIL 1735 | image_np = image_for_conversion.permute(1, 2, 0).numpy() 1736 | image_np = image_np * 255.0 # Scale to [0, 255] 1737 | else: 1738 | # Unusual channel count - probably not CHW format 1739 | logger.warning(f"Unusual channel count for CHW format: {image_for_conversion.shape[0]}. Using reshape logic.") 1740 | 1741 | # Try to infer the format and convert appropriately 1742 | if image_for_conversion.dim() == 3 and image_for_conversion.shape[-1] <= 3: 1743 | # Likely [H, W, C] format, no need to permute 1744 | image_np = image_for_conversion.numpy() * 255.0 1745 | else: 1746 | # Unknown format - try to reshape intelligently 1747 | logger.warning("Unable to determine tensor layout. Using first 3 channels.") 1748 | 1749 | # Default: take first 3 channels (or fewer if < 3 channels) 1750 | channels = min(3, image_for_conversion.shape[0]) 1751 | image_np = image_for_conversion[:channels].permute(1, 2, 0).numpy() * 255.0 1752 | 1753 | # Ensure we have proper RGB image (3 channels) 1754 | if len(image_np.shape) == 2: # Single channel image 1755 | image_np = np.stack([image_np] * 3, axis=-1) 1756 | elif image_np.shape[-1] == 1: # Single channel in last dimension 1757 | image_np = np.concatenate([image_np] * 3, axis=-1) 1758 | elif image_np.shape[-1] == 4: # RGBA image - drop alpha channel 1759 | image_np = image_np[..., :3] 1760 | elif image_np.shape[-1] > 4: # More than 4 channels - use first 3 1761 | logger.warning(f"Image has {image_np.shape[-1]} channels. 
Using first 3 channels.") 1762 | image_np = image_np[..., :3] 1763 | 1764 | # Ensure proper data type and range 1765 | image_np = np.clip(image_np, 0, 255).astype(np.uint8) 1766 | 1767 | except Exception as tensor_error: 1768 | logger.error(f"Error processing tensor image: {str(tensor_error)}") 1769 | logger.error(traceback.format_exc()) 1770 | # Create a fallback RGB gradient as placeholder 1771 | placeholder = np.zeros((512, 512, 3), dtype=np.uint8) 1772 | # Add a gradient pattern for visual distinction 1773 | for i in range(512): 1774 | v = int(i / 512 * 255) 1775 | placeholder[i, :, 0] = v # R channel gradient 1776 | placeholder[:, i, 1] = v # G channel gradient 1777 | return Image.fromarray(placeholder) 1778 | 1779 | # Process numpy array input with comprehensive error handling 1780 | elif isinstance(image, np.ndarray): 1781 | try: 1782 | # Check for NaN/Inf values in array 1783 | if np.isnan(image).any() or np.isinf(image).any(): 1784 | logger.warning("Input array contains NaN or Inf values. Replacing with zeros.") 1785 | image = np.nan_to_num(image, nan=0.0, posinf=1.0, neginf=0.0) 1786 | 1787 | # Handle empty or invalid arrays 1788 | if image.size == 0: 1789 | logger.error("Input array is empty (zero elements)") 1790 | return Image.new('RGB', (512, 512), (128, 128, 128)) 1791 | 1792 | # Convert high-precision types to float32 for consistency 1793 | if image.dtype == np.float64 or image.dtype == np.float16: 1794 | logger.info(f"Converting numpy array from {image.dtype} to float32") 1795 | image = image.astype(np.float32) 1796 | 1797 | # Normalize values to [0-1] range if float array 1798 | if np.issubdtype(image.dtype, np.floating): 1799 | # Check current range 1800 | min_val, max_val = image.min(), image.max() 1801 | 1802 | # Normalize if outside [0,1] range 1803 | if min_val < 0.0 - 1e-5 or max_val > 1.0 + 1e-5: # Allow small floating point error 1804 | logger.info(f"Normalizing array from range [{min_val:.2f}, {max_val:.2f}] to [0, 1]") 1805 | 1806 | # Common ranges and conversions 1807 | if min_val >= 0 and max_val <= 255: 1808 | # Assume [0-255] float range 1809 | image = image / 255.0 1810 | else: 1811 | # General min-max normalization 1812 | image = (image - min_val) / (max_val - min_val) 1813 | 1814 | # Convert normalized float to uint8 for PIL 1815 | image_np = (image * 255).astype(np.uint8) 1816 | elif np.issubdtype(image.dtype, np.integer): 1817 | # For integer types, check if normalization is needed 1818 | max_val = image.max() 1819 | if max_val > 255: 1820 | logger.info(f"Scaling integer array with max value {max_val} to 0-255 range") 1821 | # Scale the array to 0-255 range 1822 | scaled = (image.astype(np.float32) / max_val) * 255 1823 | image_np = scaled.astype(np.uint8) 1824 | else: 1825 | # Already in valid range 1826 | image_np = image.astype(np.uint8) 1827 | else: 1828 | # Unsupported dtype - convert through float32 1829 | logger.warning(f"Unsupported dtype: {image.dtype}. 
Converting through float32.") 1830 | image_np = (image.astype(np.float32) * 255).astype(np.uint8) 1831 | 1832 | # Handle different channel configurations and dimensions 1833 | if len(image_np.shape) == 2: 1834 | # Grayscale image - convert to RGB 1835 | logger.info("Converting grayscale image to RGB") 1836 | image_np = np.stack([image_np] * 3, axis=-1) 1837 | elif len(image_np.shape) == 3: 1838 | # Check channel dimension 1839 | if image_np.shape[-1] == 1: 1840 | # Single-channel 3D array - convert to RGB 1841 | image_np = np.concatenate([image_np] * 3, axis=-1) 1842 | elif image_np.shape[-1] == 4: 1843 | # RGBA image - drop alpha channel 1844 | image_np = image_np[..., :3] 1845 | elif image_np.shape[-1] > 4: 1846 | # More than 4 channels - use first 3 1847 | logger.warning(f"Image has {image_np.shape[-1]} channels. Using first 3 channels.") 1848 | image_np = image_np[..., :3] 1849 | elif image_np.shape[-1] < 3: 1850 | # Less than 3 channels but not 1 - unusual case 1851 | logger.warning(f"Unusual channel count: {image_np.shape[-1]}. Expanding to RGB.") 1852 | # Repeat the channels to get 3 1853 | channels = [image_np[..., i % image_np.shape[-1]] for i in range(3)] 1854 | image_np = np.stack(channels, axis=-1) 1855 | elif len(image_np.shape) > 3: 1856 | # More than 3 dimensions - attempt to extract a valid image 1857 | logger.warning(f"Array has {len(image_np.shape)} dimensions. Attempting to extract 3D slice.") 1858 | 1859 | # Try to get a 3D slice with channels in the last dimension 1860 | if image_np.shape[-1] <= 3: 1861 | # Extract first instance of higher dimensions 1862 | while len(image_np.shape) > 3: 1863 | image_np = image_np[0] 1864 | 1865 | # If we have less than 3 channels, expand to RGB 1866 | if image_np.shape[-1] < 3: 1867 | channels = [image_np[..., i % image_np.shape[-1]] for i in range(3)] 1868 | image_np = np.stack(channels, axis=-1) 1869 | else: 1870 | # Channels not in last dimension - reshape based on assumptions 1871 | logger.warning("Could not determine valid layout. 
Creating placeholder.") 1872 | image_np = np.zeros((512, 512, 3), dtype=np.uint8) 1873 | # Add gradient for visual distinction 1874 | for i in range(512): 1875 | v = int(i / 512 * 255) 1876 | image_np[i, :, 0] = v 1877 | 1878 | except Exception as numpy_error: 1879 | logger.error(f"Error processing numpy array: {str(numpy_error)}") 1880 | logger.error(traceback.format_exc()) 1881 | # Create a fallback pattern as placeholder 1882 | placeholder = np.zeros((512, 512, 3), dtype=np.uint8) 1883 | # Add checkboard pattern 1884 | for i in range(0, 512, 32): 1885 | for j in range(0, 512, 32): 1886 | if (i//32 + j//32) % 2 == 0: 1887 | placeholder[i:i+32, j:j+32] = 200 1888 | return Image.fromarray(placeholder) 1889 | 1890 | # Fallback for non-tensor, non-numpy inputs 1891 | else: 1892 | logger.error(f"Unsupported image type: {type(image)}") 1893 | return Image.new('RGB', (512, 512), (100, 100, 150)) # Distinct color for type errors 1894 | 1895 | # Convert to PIL image with error handling 1896 | try: 1897 | pil_image = Image.fromarray(image_np) 1898 | except Exception as pil_error: 1899 | logger.error(f"Error creating PIL image: {str(pil_error)}") 1900 | # Try shape correction if possible 1901 | try: 1902 | if len(image_np.shape) != 3 or image_np.shape[-1] not in [1, 3, 4]: 1903 | logger.warning(f"Invalid array shape for PIL: {image_np.shape}") 1904 | # Create valid RGB array as fallback 1905 | image_np = np.zeros((512, 512, 3), dtype=np.uint8) 1906 | pil_image = Image.fromarray(image_np) 1907 | else: 1908 | # Other error - use placeholder 1909 | pil_image = Image.new('RGB', (512, 512), (128, 128, 128)) 1910 | except: 1911 | # Ultimate fallback 1912 | pil_image = Image.new('RGB', (512, 512), (128, 128, 128)) 1913 | 1914 | # Resize the image while preserving aspect ratio and ensuring multiple of 32 1915 | # The multiple of 32 constraint helps prevent tensor dimension errors 1916 | width, height = pil_image.size 1917 | logger.info(f"Original PIL image size: {width}x{height}") 1918 | 1919 | # Determine which dimension to scale to input_size 1920 | if width > height: 1921 | new_width = input_size 1922 | new_height = int(height * (new_width / width)) 1923 | else: 1924 | new_height = input_size 1925 | new_width = int(width * (new_height / height)) 1926 | 1927 | # Ensure dimensions are multiples of 32 for better compatibility 1928 | new_width = ((new_width + 31) // 32) * 32 1929 | new_height = ((new_height + 31) // 32) * 32 1930 | 1931 | # Resize the image with antialiasing 1932 | try: 1933 | resized_image = pil_image.resize((new_width, new_height), Image.LANCZOS) 1934 | logger.info(f"Resized image from {width}x{height} to {new_width}x{new_height}") 1935 | 1936 | # Verify the resized image 1937 | if resized_image.size[0] <= 0 or resized_image.size[1] <= 0: 1938 | raise ValueError(f"Invalid resize dimensions: {resized_image.size}") 1939 | 1940 | return resized_image 1941 | except Exception as resize_error: 1942 | logger.error(f"Error during image resize: {str(resize_error)}") 1943 | 1944 | # Try a simpler resize method as fallback 1945 | try: 1946 | logger.info("Trying fallback resize method") 1947 | resized_image = pil_image.resize((new_width, new_height), Image.NEAREST) 1948 | return resized_image 1949 | except: 1950 | # Last resort - return original image or placeholder 1951 | if width > 0 and height > 0: 1952 | logger.warning("Fallback resize failed. Returning original image.") 1953 | return pil_image 1954 | else: 1955 | logger.warning("Invalid original image. 
Returning placeholder.") 1956 | return Image.new('RGB', (512, 512), (128, 128, 128)) 1957 | 1958 | except Exception as e: 1959 | # Global catch-all handler 1960 | logger.error(f"Error processing image: {str(e)}") 1961 | logger.error(traceback.format_exc()) 1962 | 1963 | # Create a visually distinct placeholder image 1964 | placeholder = Image.new('RGB', (512, 512), (120, 80, 80)) 1965 | 1966 | try: 1967 | # Add error text to the image for better user feedback 1968 | from PIL import ImageDraw, ImageFont 1969 | draw = ImageDraw.Draw(placeholder) 1970 | 1971 | # Try to get a font, fall back to default if needed 1972 | try: 1973 | font = ImageFont.truetype("arial.ttf", 20) 1974 | except: 1975 | font = ImageFont.load_default() 1976 | 1977 | # Add error text 1978 | error_text = str(e) 1979 | # Limit error text length 1980 | if len(error_text) > 60: 1981 | error_text = error_text[:57] + "..." 1982 | 1983 | # Draw error message 1984 | draw.text((10, 10), "Image Processing Error", fill=(255, 50, 50), font=font) 1985 | draw.text((10, 40), error_text, fill=(255, 255, 255), font=font) 1986 | except: 1987 | # Error adding text - just return the plain placeholder 1988 | pass 1989 | 1990 | return placeholder 1991 | 1992 | def _create_error_image(self, input_image=None): 1993 | """Create an error image placeholder based on input image if possible.""" 1994 | try: 1995 | if input_image is not None and isinstance(input_image, torch.Tensor) and input_image.shape[0] > 0: 1996 | # Check tensor type - if it's float64, log it for debugging 1997 | if input_image.dtype == torch.float64 or input_image.dtype == torch.double: 1998 | logger.info(f"Input tensor for error image is {input_image.dtype}, will create float32 error image") 1999 | 2000 | # Create gray error image with same dimensions as input 2001 | # Ensure tensor has the right shape for error display (BHWC) 2002 | if input_image.ndim == 4: 2003 | if input_image.shape[-1] != 3: # if not BHWC format 2004 | if input_image.shape[1] == 3: # if BCHW format 2005 | # Extract height and width from BCHW 2006 | h, w = input_image.shape[2], input_image.shape[3] 2007 | else: 2008 | # Default to dimensions from input 2009 | h, w = input_image.shape[2], input_image.shape[3] 2010 | else: 2011 | # Already in BHWC format 2012 | h, w = input_image.shape[1], input_image.shape[2] 2013 | else: 2014 | # Unexpected shape, use default 2015 | return self._create_basic_error_image() 2016 | 2017 | # Make sure dimensions aren't too small 2018 | if h <= 1 or w <= 1: 2019 | logger.warning(f"Input has invalid dimensions {h}x{w}, using default error image") 2020 | return self._create_basic_error_image() 2021 | 2022 | # Gray background with slight red tint to indicate error - explicitly use float32 2023 | placeholder = torch.ones((1, h, w, 3), dtype=torch.float32) * torch.tensor([0.5, 0.4, 0.4], dtype=torch.float32) 2024 | 2025 | if self.device is not None: 2026 | placeholder = placeholder.to(self.device) 2027 | 2028 | # Verify the placeholder is float32 2029 | if placeholder.dtype != torch.float32: 2030 | logger.warning(f"Error image has unexpected dtype {placeholder.dtype}, converting to float32") 2031 | placeholder = placeholder.float() 2032 | 2033 | return placeholder 2034 | else: 2035 | return self._create_basic_error_image() 2036 | except Exception as e: 2037 | logger.error(f"Error creating error image: {str(e)}") 2038 | return self._create_basic_error_image() 2039 | 2040 | def _create_basic_error_image(self): 2041 | """Create a basic error image when no input dimensions are 
available.""" 2042 | # Standard size error image (512x512) 2043 | h, w = 512, 512 2044 | # Gray background with slight red tint to indicate error - explicitly use float32 2045 | placeholder = torch.ones((1, h, w, 3), dtype=torch.float32) * torch.tensor([0.5, 0.4, 0.4], dtype=torch.float32) 2046 | 2047 | if self.device is not None: 2048 | placeholder = placeholder.to(self.device) 2049 | 2050 | # Double-check that we're returning a float32 tensor 2051 | if placeholder.dtype != torch.float32: 2052 | placeholder = placeholder.float() 2053 | 2054 | return placeholder 2055 | 2056 | def _add_error_text_to_image(self, image_tensor, error_text): 2057 | """Add error text to the image tensor for visual feedback.""" 2058 | try: 2059 | # Convert tensor to PIL for text rendering 2060 | if image_tensor is None: 2061 | return 2062 | 2063 | temp_img = self._tensor_to_pil(image_tensor) 2064 | 2065 | # Draw error text 2066 | draw = ImageDraw.Draw(temp_img) 2067 | 2068 | # Try to get a font, fall back to default if needed 2069 | try: 2070 | font = ImageFont.truetype("arial.ttf", 20) 2071 | except: 2072 | font = ImageFont.load_default() 2073 | 2074 | # Split text into multiple lines if too long 2075 | lines = [] 2076 | words = error_text.split() 2077 | current_line = words[0] if words else "Error" 2078 | 2079 | for word in words[1:]: 2080 | if len(current_line + " " + word) < 50: 2081 | current_line += " " + word 2082 | else: 2083 | lines.append(current_line) 2084 | current_line = word 2085 | 2086 | lines.append(current_line) 2087 | 2088 | # Draw title 2089 | draw.text((10, 10), "Depth Estimation Error", fill=(255, 50, 50), font=font) 2090 | 2091 | # Draw error message 2092 | y_position = 40 2093 | for line in lines: 2094 | draw.text((10, y_position), line, fill=(255, 255, 255), font=font) 2095 | y_position += 25 2096 | 2097 | # Convert back to tensor 2098 | result = self._pil_to_tensor(temp_img) 2099 | 2100 | # Copy to original tensor if shapes match 2101 | if image_tensor.shape == result.shape: 2102 | image_tensor.copy_(result) 2103 | return image_tensor 2104 | 2105 | except Exception as e: 2106 | logger.error(f"Error adding text to error image: {e}") 2107 | return image_tensor 2108 | 2109 | def _tensor_to_pil(self, tensor): 2110 | """Convert a tensor to PIL Image.""" 2111 | if tensor.shape[0] == 1: # Batch size 1 2112 | img_np = (tensor[0].cpu().numpy() * 255).astype(np.uint8) 2113 | return Image.fromarray(img_np) 2114 | return Image.new('RGB', (512, 512), color=(128, 100, 100)) 2115 | 2116 | def _pil_to_tensor(self, pil_img): 2117 | """Convert PIL Image back to tensor.""" 2118 | img_np = np.array(pil_img).astype(np.float32) / 255.0 2119 | tensor = torch.from_numpy(img_np).unsqueeze(0) 2120 | 2121 | if self.device is not None: 2122 | tensor = tensor.to(self.device) 2123 | 2124 | return tensor 2125 | 2126 | def estimate_depth(self, 2127 | image: torch.Tensor, 2128 | model_name: str, 2129 | input_size: int = 518, 2130 | blur_radius: float = 2.0, 2131 | median_size: str = "5", 2132 | apply_auto_contrast: bool = True, 2133 | apply_gamma: bool = True, 2134 | force_reload: bool = False, 2135 | force_cpu: bool = False) -> Tuple[torch.Tensor]: 2136 | """ 2137 | Estimates depth from input image with error handling and cleanup. 
2138 | 2139 | Args: 2140 | image: Input image tensor 2141 | model_name: Name of the depth model to use 2142 | input_size: Target size for the longest dimension of the image (between 256 and 1024) 2143 | blur_radius: Gaussian blur radius for smoothing 2144 | median_size: Size of median filter for noise reduction 2145 | apply_auto_contrast: Whether to enhance contrast automatically 2146 | apply_gamma: Whether to apply gamma correction 2147 | force_reload: Whether to force reload the model 2148 | force_cpu: Whether to force using CPU for inference 2149 | 2150 | Returns: 2151 | Tuple containing depth map tensor 2152 | """ 2153 | error_image = None 2154 | start_time = time.time() 2155 | 2156 | try: 2157 | # Sanity check inputs and log initial info 2158 | logger.info(f"Starting depth estimation with model: {model_name}, input_size: {input_size}, force_cpu: {force_cpu}") 2159 | 2160 | # Enhanced input validation with better error handling 2161 | if image is None: 2162 | logger.error("Input image is None") 2163 | error_image = self._create_basic_error_image() 2164 | self._add_error_text_to_image(error_image, "Input image is None") 2165 | return (error_image,) 2166 | 2167 | if image.numel() == 0: 2168 | logger.error("Input image is empty (zero elements)") 2169 | error_image = self._create_basic_error_image() 2170 | self._add_error_text_to_image(error_image, "Input image is empty (zero elements)") 2171 | return (error_image,) 2172 | 2173 | # Log tensor information before processing 2174 | logger.info(f"Input tensor shape: {image.shape}, dtype: {image.dtype}, device: {image.device}") 2175 | 2176 | # Verify tensor dimensions - support different input formats 2177 | if image.ndim != 4: 2178 | logger.warning(f"Expected 4D tensor for image, got {image.ndim}D. Attempting to reshape.") 2179 | try: 2180 | # Try to reshape based on common dimension patterns 2181 | if image.ndim == 3: 2182 | # Could be [C, H, W] or [H, W, C] format 2183 | if image.shape[0] <= 3 and image.shape[0] > 0: # Likely [C, H, W] 2184 | image = image.unsqueeze(0) # Add batch dim -> [1, C, H, W] 2185 | logger.info(f"Reshaped 3D tensor to 4D with shape: {image.shape}") 2186 | elif image.shape[-1] <= 3 and image.shape[-1] > 0: # Likely [H, W, C] 2187 | # Permute to [C, H, W] then add batch dim 2188 | image = image.permute(2, 0, 1).unsqueeze(0) 2189 | logger.info(f"Reshaped HWC tensor to BCHW with shape: {image.shape}") 2190 | else: 2191 | # Assume single channel image and add missing dimensions 2192 | image = image.unsqueeze(0).unsqueeze(0) 2193 | logger.info(f"Added batch and channel dimensions to 3D tensor: {image.shape}") 2194 | elif image.ndim == 2: 2195 | # Assume [H, W] format - add batch and channel dims 2196 | image = image.unsqueeze(0).unsqueeze(0) 2197 | logger.info(f"Reshaped 2D tensor to 4D with shape: {image.shape}") 2198 | elif image.ndim > 4: 2199 | # Too many dimensions, collapse extras 2200 | orig_shape = image.shape 2201 | image = image.reshape(1, orig_shape[1], orig_shape[-2], orig_shape[-1]) 2202 | logger.info(f"Collapsed >4D tensor to 4D with shape: {image.shape}") 2203 | else: 2204 | # Fallback for other unusual dimensions 2205 | logger.error(f"Cannot automatically reshape tensor with {image.ndim} dimensions") 2206 | error_image = self._create_basic_error_image() 2207 | self._add_error_text_to_image(error_image, f"Unsupported tensor dimensions: {image.ndim}D") 2208 | return (error_image,) 2209 | except Exception as reshape_error: 2210 | logger.error(f"Error reshaping tensor: {str(reshape_error)}") 2211 | error_image 
= self._create_basic_error_image() 2212 | self._add_error_text_to_image(error_image, f"Error reshaping tensor: {str(reshape_error)[:100]}") 2213 | return (error_image,) 2214 | 2215 | # Comprehensive type checking and conversion - verify at multiple points 2216 | # 1. Initial type check and convert if needed 2217 | if image.dtype != torch.float32: 2218 | logger.info(f"Converting input tensor from {image.dtype} to torch.float32") 2219 | 2220 | # Safe conversion that handles different input types 2221 | try: 2222 | if image.dtype == torch.uint8: 2223 | # Normalize [0-255] -> [0-1] for uint8 inputs 2224 | image = image.float() / 255.0 2225 | else: 2226 | # Standard conversion for other types 2227 | image = image.float() 2228 | except Exception as type_error: 2229 | logger.error(f"Error converting tensor type: {str(type_error)}") 2230 | error_image = self._create_basic_error_image() 2231 | self._add_error_text_to_image(error_image, f"Type conversion error: {str(type_error)[:100]}") 2232 | return (error_image,) 2233 | 2234 | # Create error image placeholder based on input dimensions 2235 | error_image = self._create_error_image(image) 2236 | 2237 | # 2. Check for NaN/Inf values and fix them 2238 | nan_count = torch.isnan(image).sum().item() 2239 | inf_count = torch.isinf(image).sum().item() 2240 | 2241 | if nan_count > 0 or inf_count > 0: 2242 | logger.warning(f"Input contains {nan_count} NaN and {inf_count} Inf values. Fixing problematic values.") 2243 | image = torch.nan_to_num(image, nan=0.0, posinf=1.0, neginf=0.0) 2244 | 2245 | # 3. Check value range and normalize if needed 2246 | min_val, max_val = image.min().item(), image.max().item() 2247 | if max_val > 1.0 + 1e-5: # Allow small floating point error 2248 | logger.info(f"Input values outside [0,1] range: min={min_val}, max={max_val}. Normalizing.") 2249 | 2250 | # Ensure values are in [0,1] range - handle different input formats 2251 | if min_val >= 0 and max_val <= 255: 2252 | # Likely [0-255] range 2253 | image = image / 255.0 2254 | else: 2255 | # General min-max normalization 2256 | image = (image - min_val) / (max_val - min_val) 2257 | 2258 | # Parameter validation with safer defaults 2259 | # Validate and normalize median_size parameter 2260 | median_size_str = str(median_size) # Convert to string regardless of input type 2261 | if median_size_str not in self.MEDIAN_SIZES: 2262 | logger.warning(f"Invalid median_size: '{median_size}' (type: {type(median_size)}). Using default '5'.") 2263 | median_size_str = "5" 2264 | 2265 | # Validate input_size with stricter bounds 2266 | if not isinstance(input_size, (int, float)): 2267 | logger.warning(f"Invalid input_size type: {type(input_size)}. Using default 518.") 2268 | input_size = 518 2269 | else: 2270 | # Convert to int and constrain to valid range 2271 | try: 2272 | input_size = int(input_size) 2273 | input_size = max(256, min(input_size, 1024)) # Clamp between 256 and 1024 2274 | except: 2275 | logger.warning(f"Error converting input_size to int. 
Using default 518.") 2276 | input_size = 518 2277 | 2278 | # Try loading the model with graceful fallback 2279 | try: 2280 | self.ensure_model_loaded(model_name, force_reload, force_cpu) 2281 | logger.info(f"Model '{model_name}' loaded successfully") 2282 | except Exception as model_error: 2283 | error_msg = f"Failed to load model '{model_name}': {str(model_error)}" 2284 | logger.error(error_msg) 2285 | 2286 | # Try a more reliable fallback model before giving up 2287 | fallback_models = ["MiDaS-Small", "MiDaS-Base"] 2288 | for fallback_model in fallback_models: 2289 | if fallback_model != model_name: 2290 | logger.info(f"Attempting to load fallback model: {fallback_model}") 2291 | try: 2292 | self.ensure_model_loaded(fallback_model, True, True) # Force reload and CPU for reliability 2293 | logger.info(f"Fallback model '{fallback_model}' loaded successfully") 2294 | # Update model_name to reflect the fallback 2295 | model_name = fallback_model 2296 | # Break the loop since we successfully loaded a fallback 2297 | break 2298 | except Exception as fallback_error: 2299 | logger.warning(f"Fallback model '{fallback_model}' also failed: {str(fallback_error)}") 2300 | continue 2301 | 2302 | # If we still don't have a model loaded, return error image 2303 | if self.depth_estimator is None: 2304 | self._add_error_text_to_image(error_image, f"Model Error: {str(model_error)[:100]}...") 2305 | return (error_image,) 2306 | 2307 | # Process input image with enhanced error recovery 2308 | try: 2309 | # Convert to PIL with robust error handling 2310 | pil_image = self.process_image(image, input_size) 2311 | # Store original dimensions for later resizing 2312 | original_width, original_height = pil_image.size 2313 | logger.info(f"Image processed to size: {pil_image.size} (will preserve these dimensions in output)") 2314 | except Exception as img_error: 2315 | logger.error(f"Image processing error: {str(img_error)}") 2316 | logger.error(traceback.format_exc()) 2317 | 2318 | # Try a more basic approach if the standard processing fails 2319 | try: 2320 | logger.info("Attempting basic image conversion as fallback") 2321 | # Simple conversion fallback 2322 | if image.shape[1] > 3: # BCHW format with unusual channel count 2323 | # Select first 3 channels or average if > 3 2324 | logger.warning(f"Unusual channel count: {image.shape[1]}. 
Using first 3 channels.") 2325 | if image.shape[1] > 3: 2326 | image = image[:, :3, :, :] 2327 | 2328 | # Convert to CPU numpy array 2329 | img_np = image.squeeze(0).cpu().numpy() 2330 | 2331 | # Handle different layouts 2332 | if img_np.shape[0] <= 3: # [C, H, W] 2333 | img_np = np.transpose(img_np, (1, 2, 0)) # -> [H, W, C] 2334 | 2335 | # Ensure 3 channels 2336 | if len(img_np.shape) == 2: # Grayscale 2337 | img_np = np.stack([img_np] * 3, axis=-1) 2338 | elif img_np.shape[-1] == 1: # Single channel 2339 | img_np = np.concatenate([img_np] * 3, axis=-1) 2340 | 2341 | # Normalize values 2342 | if img_np.max() > 1.0: 2343 | img_np = img_np / 255.0 2344 | 2345 | # Convert to PIL 2346 | img_np = (img_np * 255).astype(np.uint8) 2347 | pil_image = Image.fromarray(img_np) 2348 | 2349 | # Store original dimensions 2350 | original_width, original_height = pil_image.size 2351 | 2352 | # Resize to appropriate dimensions 2353 | if input_size > 0: 2354 | w, h = pil_image.size 2355 | # Determine which dimension to scale to input_size 2356 | if w > h: 2357 | new_w = input_size 2358 | new_h = int(h * (new_w / w)) 2359 | else: 2360 | new_h = input_size 2361 | new_w = int(w * (new_h / h)) 2362 | 2363 | # Ensure dimensions are multiples of 32 2364 | new_h = ((new_h + 31) // 32) * 32 2365 | new_w = ((new_w + 31) // 32) * 32 2366 | 2367 | pil_image = pil_image.resize((new_w, new_h), Image.LANCZOS) 2368 | 2369 | logger.info(f"Fallback image processing succeeded with size: {pil_image.size}") 2370 | except Exception as fallback_error: 2371 | logger.error(f"Fallback image processing also failed: {str(fallback_error)}") 2372 | self._add_error_text_to_image(error_image, f"Image Error: {str(img_error)[:100]}...") 2373 | return (error_image,) 2374 | 2375 | # Depth estimation with comprehensive error handling 2376 | try: 2377 | # Use inference_mode for better memory usage 2378 | with torch.inference_mode(): 2379 | logger.info(f"Running inference on image of size {pil_image.size}") 2380 | 2381 | # Ensure the depth estimator is in eval mode 2382 | if hasattr(self.depth_estimator, 'eval'): 2383 | self.depth_estimator.eval() 2384 | 2385 | # Add timing for performance analysis 2386 | inference_start = time.time() 2387 | 2388 | # Perform inference with better error handling 2389 | try: 2390 | depth_result = self.depth_estimator(pil_image) 2391 | inference_time = time.time() - inference_start 2392 | logger.info(f"Depth inference completed in {inference_time:.2f} seconds") 2393 | except Exception as inference_error: 2394 | logger.error(f"Depth estimator inference error: {str(inference_error)}") 2395 | 2396 | # Detailed error analysis for better debugging 2397 | error_str = str(inference_error).lower() 2398 | 2399 | # Try fallback for common error types 2400 | if "cuda" in error_str and "out of memory" in error_str: 2401 | logger.warning("CUDA out of memory detected. Attempting CPU fallback.") 2402 | # Already on CPU? Check and handle 2403 | if force_cpu: 2404 | logger.error("Already using CPU but still encountered memory error") 2405 | self._add_error_text_to_image(error_image, "Memory error even on CPU. 
Try smaller input size.") 2406 | return (error_image,) 2407 | else: 2408 | # Try CPU fallback 2409 | logger.info("Switching to CPU processing") 2410 | return self.estimate_depth( 2411 | image.cpu(), model_name, input_size, blur_radius, median_size_str, 2412 | apply_auto_contrast, apply_gamma, True, True # Force CPU 2413 | ) 2414 | 2415 | # Type mismatch errors 2416 | elif "input type" in error_str and "weight type" in error_str: 2417 | logger.warning("Tensor type mismatch detected. Attempting explicit type conversion.") 2418 | # Try with explicit CPU conversion 2419 | return self.estimate_depth( 2420 | image.float().cpu(), model_name, input_size, blur_radius, median_size_str, 2421 | apply_auto_contrast, apply_gamma, True, True # Force CPU and reload 2422 | ) 2423 | 2424 | # Dimension mismatch errors 2425 | elif "dimensions" in error_str or "dimension" in error_str or "shape" in error_str: 2426 | logger.warning("Tensor dimension mismatch detected. Trying alternate approach.") 2427 | # Fall back to CPU MiDaS model which has more robust dimension handling 2428 | logger.info("Falling back to MiDaS model on CPU") 2429 | try: 2430 | # Clean up current model 2431 | self.cleanup() 2432 | # Try to load MiDaS model 2433 | self.ensure_model_loaded("MiDaS-Small", True, True) 2434 | # Retry with the new model 2435 | return self.estimate_depth( 2436 | image.cpu(), "MiDaS-Small", input_size, blur_radius, median_size_str, 2437 | apply_auto_contrast, apply_gamma, False, True # Already reloaded, force CPU 2438 | ) 2439 | except Exception as midas_error: 2440 | logger.error(f"MiDaS fallback also failed: {str(midas_error)}") 2441 | self._add_error_text_to_image(error_image, f"Inference Error: {str(inference_error)[:100]}...") 2442 | return (error_image,) 2443 | 2444 | # Other errors - just return error image 2445 | self._add_error_text_to_image(error_image, f"Inference Error: {str(inference_error)[:100]}...") 2446 | return (error_image,) 2447 | 2448 | # Verify depth result and convert to float32 2449 | if not isinstance(depth_result, dict) or "predicted_depth" not in depth_result: 2450 | logger.error(f"Invalid depth result format: {type(depth_result)}") 2451 | self._add_error_text_to_image(error_image, "Invalid depth result format") 2452 | return (error_image,) 2453 | 2454 | # Extract and validate predicted depth 2455 | predicted_depth = depth_result["predicted_depth"] 2456 | 2457 | # Ensure correct tensor type 2458 | if not torch.is_tensor(predicted_depth): 2459 | logger.error(f"Predicted depth is not a tensor: {type(predicted_depth)}") 2460 | self._add_error_text_to_image(error_image, "Predicted depth is not a tensor") 2461 | return (error_image,) 2462 | 2463 | # Convert to float32 if needed 2464 | if predicted_depth.dtype != torch.float32: 2465 | logger.info(f"Converting predicted depth from {predicted_depth.dtype} to float32") 2466 | predicted_depth = predicted_depth.float() 2467 | 2468 | # Convert to CPU for post-processing 2469 | depth_map = predicted_depth.squeeze().cpu().numpy() 2470 | except RuntimeError as rt_error: 2471 | # Handle runtime errors separately for clearer error messages 2472 | error_msg = str(rt_error) 2473 | logger.error(f"Runtime error during depth estimation: {error_msg}") 2474 | 2475 | # Check for specific error types 2476 | if "CUDA out of memory" in error_msg: 2477 | logger.warning("CUDA out of memory. 
Trying CPU fallback.") 2478 | 2479 | # Only try CPU fallback if not already using CPU 2480 | if not force_cpu: 2481 | try: 2482 | logger.info("Switching to CPU processing") 2483 | return self.estimate_depth( 2484 | image.cpu(), model_name, input_size, blur_radius, median_size_str, 2485 | apply_auto_contrast, apply_gamma, True, True # Force reload and CPU 2486 | ) 2487 | except Exception as cpu_error: 2488 | logger.error(f"CPU fallback failed: {str(cpu_error)}") 2489 | 2490 | self._add_error_text_to_image(error_image, "CUDA Out of Memory. Try a smaller model or image size.") 2491 | else: 2492 | # Generic runtime error 2493 | self._add_error_text_to_image(error_image, f"Runtime Error: {error_msg[:100]}...") 2494 | 2495 | return (error_image,) 2496 | except Exception as e: 2497 | # Handle other exceptions 2498 | error_msg = f"Depth estimation failed: {str(e)}" 2499 | logger.error(error_msg) 2500 | logger.error(traceback.format_exc()) 2501 | self._add_error_text_to_image(error_image, f"Error: {str(e)[:100]}...") 2502 | return (error_image,) 2503 | 2504 | # Validate depth map 2505 | # Check for NaN/Inf values 2506 | if np.isnan(depth_map).any() or np.isinf(depth_map).any(): 2507 | logger.warning("Depth map contains NaN or Inf values. Replacing with zeros.") 2508 | depth_map = np.nan_to_num(depth_map, nan=0.0, posinf=1.0, neginf=0.0) 2509 | 2510 | # Check for empty or invalid depth map 2511 | if depth_map.size == 0: 2512 | logger.error("Depth map is empty") 2513 | self._add_error_text_to_image(error_image, "Empty depth map returned") 2514 | return (error_image,) 2515 | 2516 | # Post-processing with enhanced error handling 2517 | try: 2518 | # Ensure depth values have reasonable range for normalization 2519 | depth_min, depth_max = depth_map.min(), depth_map.max() 2520 | 2521 | # Handle constant depth maps (avoid division by zero) 2522 | if np.isclose(depth_max, depth_min): 2523 | logger.warning("Constant depth map detected (min = max). Using values directly.") 2524 | # Just use a normalized constant value 2525 | depth_map = np.ones_like(depth_map) * 0.5 2526 | else: 2527 | # Normalize to [0, 1] range first - safer for later operations 2528 | depth_map = (depth_map - depth_min) / (depth_max - depth_min) 2529 | 2530 | # Scale to [0, 255] for PIL operations 2531 | depth_map_uint8 = (depth_map * 255.0).astype(np.uint8) 2532 | 2533 | # Log the depth map shape coming from the model 2534 | logger.info(f"Depth map shape from model: {depth_map_uint8.shape}") 2535 | 2536 | # Create PIL image explicitly with L mode (grayscale) 2537 | try: 2538 | depth_pil = Image.fromarray(depth_map_uint8, mode='L') 2539 | logger.info(f"Depth PIL image size before resize: {depth_pil.size}") 2540 | except Exception as pil_error: 2541 | logger.error(f"Error creating PIL image: {str(pil_error)}") 2542 | # Try to reshape the array if dimensions are wrong 2543 | if len(depth_map_uint8.shape) != 2: 2544 | logger.warning(f"Unexpected depth map shape: {depth_map_uint8.shape}. 
Attempting to fix.") 2545 | # Try to extract first channel if multi-channel 2546 | if len(depth_map_uint8.shape) > 2: 2547 | depth_map_uint8 = depth_map_uint8[..., 0] 2548 | elif len(depth_map_uint8.shape) == 1: 2549 | # 1D array - try to reshape to 2D 2550 | h = int(np.sqrt(depth_map_uint8.size)) 2551 | w = depth_map_uint8.size // h 2552 | depth_map_uint8 = depth_map_uint8.reshape(h, w) 2553 | 2554 | depth_pil = Image.fromarray(depth_map_uint8, mode='L') 2555 | 2556 | # Resize depth map to original dimensions from input image before post-processing 2557 | # This is the key fix for the resolution issue 2558 | try: 2559 | logger.info(f"Resizing depth map to original dimensions: {original_width}x{original_height}") 2560 | depth_pil = depth_pil.resize((original_width, original_height), Image.BICUBIC) 2561 | logger.info(f"Depth PIL image size after resize: {depth_pil.size}") 2562 | except Exception as resize_error: 2563 | logger.error(f"Error resizing depth map: {str(resize_error)}") 2564 | logger.error(traceback.format_exc()) 2565 | 2566 | # Apply post-processing with parameter validation 2567 | # Apply blur if radius is positive 2568 | if blur_radius > 0: 2569 | try: 2570 | depth_pil = depth_pil.filter(ImageFilter.GaussianBlur(radius=blur_radius)) 2571 | except Exception as blur_error: 2572 | logger.warning(f"Error applying blur: {str(blur_error)}. Skipping.") 2573 | 2574 | # Apply median filter if size is valid 2575 | try: 2576 | median_size_int = int(median_size_str) 2577 | if median_size_int > 0: 2578 | depth_pil = depth_pil.filter(ImageFilter.MedianFilter(size=median_size_int)) 2579 | except Exception as median_error: 2580 | logger.warning(f"Error applying median filter: {str(median_error)}. Skipping.") 2581 | 2582 | # Apply auto contrast if requested 2583 | if apply_auto_contrast: 2584 | try: 2585 | depth_pil = ImageOps.autocontrast(depth_pil) 2586 | except Exception as contrast_error: 2587 | logger.warning(f"Error applying auto contrast: {str(contrast_error)}. Skipping.") 2588 | 2589 | # Apply gamma correction if requested 2590 | if apply_gamma: 2591 | try: 2592 | depth_array = np.array(depth_pil).astype(np.float32) / 255.0 2593 | mean_luminance = np.mean(depth_array) 2594 | 2595 | # Avoid division by zero or negative values 2596 | if mean_luminance > 0.001: 2597 | # Calculate gamma based on mean luminance for adaptive correction 2598 | gamma = np.log(0.5) / np.log(mean_luminance) 2599 | 2600 | # Clamp gamma to reasonable range to avoid extreme corrections 2601 | gamma = max(0.1, min(gamma, 3.0)) 2602 | logger.info(f"Applying gamma correction with value: {gamma:.2f}") 2603 | 2604 | # Apply gamma correction 2605 | corrected = np.power(depth_array, 1.0/gamma) * 255.0 2606 | depth_pil = Image.fromarray(corrected.astype(np.uint8), mode='L') 2607 | else: 2608 | logger.warning(f"Mean luminance too low: {mean_luminance}. Skipping gamma correction.") 2609 | except Exception as gamma_error: 2610 | logger.warning(f"Error applying gamma correction: {str(gamma_error)}. 
Skipping.") 2611 | 2612 | # Convert processed image back to tensor 2613 | # Convert to numpy array for tensor conversion 2614 | depth_array = np.array(depth_pil).astype(np.float32) / 255.0 2615 | 2616 | # Final validation checks 2617 | # Check for invalid dimensions 2618 | h, w = depth_array.shape 2619 | if h <= 1 or w <= 1: 2620 | logger.error(f"Invalid depth map dimensions: {h}x{w}") 2621 | self._add_error_text_to_image(error_image, "Invalid depth map dimensions (too small)") 2622 | return (error_image,) 2623 | 2624 | # Log final dimensions for debugging 2625 | logger.info(f"Final depth map dimensions: {h}x{w}") 2626 | 2627 | # Create RGB depth map by stacking the same grayscale image three times 2628 | # Stack to create a 3-channel image compatible with ComfyUI 2629 | depth_rgb = np.stack([depth_array] * 3, axis=-1) # Shape becomes (h, w, 3) 2630 | 2631 | # Convert to tensor and add batch dimension 2632 | depth_tensor = torch.from_numpy(depth_rgb).unsqueeze(0).float() # Shape becomes (1, h, w, 3) 2633 | 2634 | # Optional: move to device if not using CPU 2635 | if self.device is not None and not force_cpu: 2636 | depth_tensor = depth_tensor.to(self.device) 2637 | 2638 | # Validate output tensor 2639 | if torch.isnan(depth_tensor).any() or torch.isinf(depth_tensor).any(): 2640 | logger.warning("Final tensor contains NaN or Inf values. Fixing.") 2641 | depth_tensor = torch.nan_to_num(depth_tensor, nan=0.0, posinf=1.0, neginf=0.0) 2642 | 2643 | # Ensure values are in [0, 1] range 2644 | min_val, max_val = depth_tensor.min().item(), depth_tensor.max().item() 2645 | if min_val < 0.0 or max_val > 1.0: 2646 | logger.warning(f"Depth tensor values outside [0,1] range: min={min_val}, max={max_val}. Normalizing.") 2647 | depth_tensor = torch.clamp(depth_tensor, 0.0, 1.0) 2648 | 2649 | # Log completion info 2650 | processing_time = time.time() - start_time 2651 | logger.info(f"Depth processing completed in {processing_time:.2f} seconds") 2652 | logger.info(f"Output tensor: shape={depth_tensor.shape}, dtype={depth_tensor.dtype}, device={depth_tensor.device}") 2653 | 2654 | return (depth_tensor,) 2655 | 2656 | except Exception as post_error: 2657 | # Handle post-processing errors 2658 | error_msg = f"Error during depth map post-processing: {str(post_error)}" 2659 | logger.error(error_msg) 2660 | logger.error(traceback.format_exc()) 2661 | self._add_error_text_to_image(error_image, f"Post-processing Error: {str(post_error)[:100]}...") 2662 | return (error_image,) 2663 | 2664 | except Exception as e: 2665 | # Global catch-all error handler 2666 | error_msg = f"Depth estimation failed: {str(e)}" 2667 | logger.error(error_msg) 2668 | logger.error(traceback.format_exc()) 2669 | 2670 | # Create error image if needed 2671 | if error_image is None: 2672 | error_image = self._create_basic_error_image() 2673 | 2674 | self._add_error_text_to_image(error_image, f"Unexpected Error: {str(e)[:100]}...") 2675 | return (error_image,) 2676 | finally: 2677 | # Always clean up resources regardless of success or failure 2678 | torch.cuda.empty_cache() 2679 | gc.collect() 2680 | 2681 | def gamma_correction(self, img: Image.Image, gamma: float = 1.0) -> Image.Image: 2682 | """Applies gamma correction to the image.""" 2683 | # Convert to numpy array 2684 | img_array = np.array(img) 2685 | 2686 | # Apply gamma correction directly with numpy 2687 | corrected = np.power(img_array.astype(np.float32) / 255.0, 1.0/gamma) * 255.0 2688 | 2689 | # Ensure uint8 type and create image with explicit mode 2690 | return 
Image.fromarray(corrected.astype(np.uint8), mode='L') 2691 | 2692 | # Node registration 2693 | NODE_CLASS_MAPPINGS = { 2694 | "DepthEstimationNode": DepthEstimationNode 2695 | } 2696 | 2697 | NODE_DISPLAY_NAME_MAPPINGS = { 2698 | "DepthEstimationNode": "Depth Estimation (V2)" 2699 | } -------------------------------------------------------------------------------- /images/depth-estimation-icon.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Limbicnation/ComfyUIDepthEstimation/1c3035351a2269874d6200073b3bb20ac61f2513/images/depth-estimation-icon.png -------------------------------------------------------------------------------- /images/depth-estimation-icon.svg: -------------------------------------------------------------------------------- (SVG markup stripped in this listing; no text content survives) -------------------------------------------------------------------------------- /images/depth-estimation-logo-with-smaller-z.svg: -------------------------------------------------------------------------------- (SVG markup stripped in this listing; only the "Z" glyph text survives) -------------------------------------------------------------------------------- /images/depth-estimation-node.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Limbicnation/ComfyUIDepthEstimation/1c3035351a2269874d6200073b3bb20ac61f2513/images/depth-estimation-node.png -------------------------------------------------------------------------------- /images/depth_map_generator_showcase.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Limbicnation/ComfyUIDepthEstimation/1c3035351a2269874d6200073b3bb20ac61f2513/images/depth_map_generator_showcase.jpg -------------------------------------------------------------------------------- /publish.yaml: -------------------------------------------------------------------------------- 1 | name: Publish to Comfy registry 2 | on: 3 | workflow_dispatch: 4 | push: 5 | branches: 6 | - main 7 | - master 8 | paths: 9 | - "pyproject.toml" 10 | 11 | jobs: 12 | publish-node: 13 | name: Publish Custom Node to registry 14 | runs-on: ubuntu-latest 15 | steps: 16 | - name: Check out code 17 | uses: actions/checkout@v4 18 | - name: Publish Custom Node 19 | uses: Comfy-Org/publish-node-action@main 20 | with: 21 | personal_access_token: ${{ secrets.REGISTRY_ACCESS_TOKEN }} 22 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [project] 2 | name = "comfyuidepthestimation" 3 | description = "A robust custom depth estimation node for ComfyUI using Depth-Anything models. It integrates depth estimation with configurable post-processing options including blur, median filtering, contrast enhancement, and gamma correction."
4 | version = "1.1.1" 5 | license = { file = "LICENSE" } 6 | dependencies = [ 7 | "transformers>=4.20.0", 8 | "tokenizers>=0.13.3", 9 | "timm>=0.6.12", 10 | "huggingface-hub>=0.16.0", 11 | "protobuf==3.20.3", 12 | "requests>=2.27.0", 13 | "wget>=3.2" 14 | ] 15 | 16 | [project.urls] 17 | Repository = "https://github.com/Limbicnation/ComfyUIDepthEstimation" 18 | # Used by Comfy Registry https://comfyregistry.org 19 | 20 | [tool.comfy] 21 | PublisherId = "limbicnation" 22 | DisplayName = "Depth Estimation Node" 23 | Icon = "images/depth-estimation-icon.png" -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | # requirements.txt for ComfyUI-DepthEstimation Node 2 | # Note: These are minimum requirements. ComfyUI's environment may provide newer versions. 3 | 4 | # Fix for protobuf errors with transformers 5 | protobuf==3.20.3 6 | 7 | # Core dependencies 8 | tokenizers>=0.13.3 # Pre-built version compatible with most platforms 9 | transformers>=4.20.0 # Required for Depth Anything models, but ComfyUI may have a specific version 10 | 11 | # Pillow (PIL Fork) - Compatibility with other ComfyUI nodes 12 | # Don't specify Pillow version to avoid conflicts with ComfyUI environment 13 | 14 | # NumPy - Using version that properly supports numpy.dtypes 15 | # Don't specify version to avoid conflicts with ComfyUI environment 16 | 17 | # Additional dependencies specific to depth estimation node 18 | timm>=0.6.12 # Required for Depth Anything models 19 | huggingface-hub>=0.16.0 # For model downloading 20 | wget>=3.2 # For reliable model downloading 21 | 22 | # Torch version requirements 23 | # Note: PyTorch dependencies are handled by ComfyUI's core installation 24 | # If you're installing this node directly, ensure torch>=2.0.0 is available 25 | 26 | # Network dependencies 27 | requests>=2.27.0 # For model downloading -------------------------------------------------------------------------------- /workflows/Depth_Map_Generator_V1.json: -------------------------------------------------------------------------------- 1 | {"last_node_id":11,"last_link_id":6,"nodes":[{"id":10,"type":"LoadImage","pos":[-67.18539428710938,435.0294494628906],"size":[315,314],"flags":{},"order":0,"mode":0,"inputs":[],"outputs":[{"name":"IMAGE","type":"IMAGE","links":[5],"slot_index":0},{"name":"MASK","type":"MASK","links":null}],"properties":{"cnr_id":"comfy-core","ver":"0.3.15","Node name for S&R":"LoadImage","enableTabs":false,"tabWidth":65,"tabXOffset":10,"hasSecondTab":false,"secondTabText":"Send Back","secondTabOffset":80,"secondTabWidth":65},"widgets_values":["Fluix_Redux_NeoTokyo_CyberPortrait_00043_.png","image"]},{"id":11,"type":"PreviewImage","pos":[807.515869140625,435.0294494628906],"size":[210,246],"flags":{},"order":2,"mode":0,"inputs":[{"name":"images","type":"IMAGE","link":6}],"outputs":[],"properties":{"cnr_id":"comfy-core","ver":"0.3.15","Node name for S&R":"PreviewImage","enableTabs":false,"tabWidth":65,"tabXOffset":10,"hasSecondTab":false,"secondTabText":"Send 
Back","secondTabOffset":80,"secondTabWidth":65},"widgets_values":[]},{"id":9,"type":"DepthEstimationNode","pos":[370.5541076660156,435.0294494628906],"size":[315,154],"flags":{},"order":1,"mode":0,"inputs":[{"name":"image","type":"IMAGE","link":5}],"outputs":[{"name":"IMAGE","type":"IMAGE","links":[6],"slot_index":0}],"properties":{"aux_id":"Limbicnation/ComfyUIDepthEstimation","ver":"1fac38f691c69b81dc7b5f5e0c83bf61583220be","Node name for S&R":"DepthEstimationNode","enableTabs":false,"tabWidth":65,"tabXOffset":10,"hasSecondTab":false,"secondTabText":"Send Back","secondTabOffset":80,"secondTabWidth":65,"cnr_id":"comfyuidepthestimation"},"widgets_values":["Depth-Anything-V2-Base",0.5,"3",true,true]}],"links":[[5,10,0,9,0,"IMAGE"],[6,9,0,11,0,"IMAGE"]],"groups":[],"config":{},"extra":{"ds":{"scale":2.593742460100004,"offset":{"0":118.01049041748047,"1":-242.42494201660156}},"ue_links":[]},"version":0.4} -------------------------------------------------------------------------------- /workflows/Depth_Map_Generator_V1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Limbicnation/ComfyUIDepthEstimation/1c3035351a2269874d6200073b3bb20ac61f2513/workflows/Depth_Map_Generator_V1.png --------------------------------------------------------------------------------