├── .gitattributes
├── .github
│   └── workflows
│       └── publish.yml
├── .gitignore
├── LICENSE
├── README.md
├── __init__.py
├── depth_estimation_node.py
├── images
│   ├── depth-estimation-icon.png
│   ├── depth-estimation-icon.svg
│   ├── depth-estimation-logo-with-smaller-z.svg
│   ├── depth-estimation-node.png
│   └── depth_map_generator_showcase.jpg
├── publish.yaml
├── pyproject.toml
├── requirements.txt
└── workflows
    ├── Depth_Map_Generator_V1.json
    └── Depth_Map_Generator_V1.png
/.gitattributes:
--------------------------------------------------------------------------------
1 | # Handle line endings (CRLF for Windows, LF for Unix)
2 | * text=auto
3 |
4 | # Ensure Python files use LF for Unix compatibility
5 | *.py text eol=lf
6 |
7 | # Treat Markdown and YAML files as text
8 | *.md text
9 | *.yaml text
10 | *.yml text
11 |
12 | # Handle images and binaries as binary
13 | *.png binary
14 | *.jpg binary
15 | *.zip binary
16 |
17 | # Handle compiled files as binary (just in case you have any)
18 | *.exe binary
19 | *.dll binary
20 | *.so binary
21 |
22 | # Ensure JSON files use LF line endings
23 | *.json text eol=lf
24 |
25 | # Handle Node.js package lock file (if relevant)
26 | package-lock.json merge=union
27 |
28 |
--------------------------------------------------------------------------------
/.github/workflows/publish.yml:
--------------------------------------------------------------------------------
1 | name: Publish to Comfy registry
2 | on:
3 | workflow_dispatch:
4 | push:
5 | branches:
6 | - main
7 | - master
8 | paths:
9 | - "pyproject.toml"
10 |
11 | permissions:
12 | issues: write
13 |
14 | jobs:
15 | publish-node:
16 | name: Publish Custom Node to registry
17 | runs-on: ubuntu-latest
18 | if: ${{ github.repository_owner == 'Limbicnation' }}
19 | steps:
20 | - name: Check out code
21 | uses: actions/checkout@v4
22 | - name: Publish Custom Node
23 | uses: Comfy-Org/publish-node-action@v1
24 | with:
25 | ## Add your own personal access token to your Github Repository secrets and reference it here.
26 | personal_access_token: ${{ secrets.REGISTRY_ACCESS_TOKEN }}
27 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | share/python-wheels/
24 | *.egg-info/
25 | .installed.cfg
26 | *.egg
27 | MANIFEST
28 |
29 | # PyInstaller
30 | # Usually these files are written by a python script from a template
31 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
32 | *.manifest
33 | *.spec
34 |
35 | # Installer logs
36 | pip-log.txt
37 | pip-delete-this-directory.txt
38 |
39 | # Unit test / coverage reports
40 | htmlcov/
41 | .tox/
42 | .nox/
43 | .coverage
44 | .coverage.*
45 | .cache
46 | nosetests.xml
47 | coverage.xml
48 | *.cover
49 | *.py,cover
50 | .hypothesis/
51 | .pytest_cache/
52 | cover/
53 |
54 | # Translations
55 | *.mo
56 | *.pot
57 |
58 | # Django stuff:
59 | *.log
60 | local_settings.py
61 | db.sqlite3
62 | db.sqlite3-journal
63 |
64 | # Flask stuff:
65 | instance/
66 | .webassets-cache
67 |
68 | # Scrapy stuff:
69 | .scrapy
70 |
71 | # Sphinx documentation
72 | docs/_build/
73 |
74 | # PyBuilder
75 | .pybuilder/
76 | target/
77 |
78 | # Jupyter Notebook
79 | .ipynb_checkpoints
80 |
81 | # IPython
82 | profile_default/
83 | ipython_config.py
84 |
85 | # pyenv
86 | # For a library or package, you might want to ignore these files since the code is
87 | # intended to run in multiple environments; otherwise, check them in:
88 | # .python-version
89 |
90 | # pipenv
91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
94 | # install all needed dependencies.
95 | #Pipfile.lock
96 |
97 | # poetry
98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
99 | # This is especially recommended for binary packages to ensure reproducibility, and is more
100 | # commonly ignored for libraries.
101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
102 | #poetry.lock
103 |
104 | # pdm
105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
106 | #pdm.lock
107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
108 | # in version control.
109 | # https://pdm.fming.dev/latest/usage/project/#working-with-version-control
110 | .pdm.toml
111 | .pdm-python
112 | .pdm-build/
113 |
114 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
115 | __pypackages__/
116 |
117 | # Celery stuff
118 | celerybeat-schedule
119 | celerybeat.pid
120 |
121 | # SageMath parsed files
122 | *.sage.py
123 |
124 | # Environments
125 | .env
126 | .venv
127 | env/
128 | venv/
129 | ENV/
130 | env.bak/
131 | venv.bak/
132 |
133 | # Spyder project settings
134 | .spyderproject
135 | .spyproject
136 |
137 | # Rope project settings
138 | .ropeproject
139 |
140 | # mkdocs documentation
141 | /site
142 |
143 | # mypy
144 | .mypy_cache/
145 | .dmypy.json
146 | dmypy.json
147 |
148 | # Pyre type checker
149 | .pyre/
150 |
151 | # pytype static type analyzer
152 | .pytype/
153 |
154 | # Cython debug symbols
155 | cython_debug/
156 |
157 | # PyCharm
158 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
159 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
160 | # and can be added to the global gitignore or merged into this file. For a more nuclear
161 | # option (not recommended) you can uncomment the following to ignore the entire idea folder.
162 | #.idea/
163 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | Apache License
2 | Version 2.0, January 2004
3 | http://www.apache.org/licenses/
4 |
5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6 |
7 | 1. Definitions.
8 |
9 | "License" shall mean the terms and conditions for use, reproduction,
10 | and distribution as defined by Sections 1 through 9 of this document.
11 |
12 | "Licensor" shall mean the copyright owner or entity authorized by
13 | the copyright owner that is granting the License.
14 |
15 | "Legal Entity" shall mean the union of the acting entity and all
16 | other entities that control, are controlled by, or are under common
17 | control with that entity. For the purposes of this definition,
18 | "control" means (i) the power, direct or indirect, to cause the
19 | direction or management of such entity, whether by contract or
20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
21 | outstanding shares, or (iii) beneficial ownership of such entity.
22 |
23 | "You" (or "Your") shall mean an individual or Legal Entity
24 | exercising permissions granted by this License.
25 |
26 | "Source" form shall mean the preferred form for making modifications,
27 | including but not limited to software source code, documentation
28 | source, and configuration files.
29 |
30 | "Object" form shall mean any form resulting from mechanical
31 | transformation or translation of a Source form, including but
32 | not limited to compiled object code, generated documentation,
33 | and conversions to other media types.
34 |
35 | "Work" shall mean the work of authorship, whether in Source or
36 | Object form, made available under the License, as indicated by a
37 | copyright notice that is included in or attached to the work
38 | (an example is provided in the Appendix below).
39 |
40 | "Derivative Works" shall mean any work, whether in Source or Object
41 | form, that is based on (or derived from) the Work and for which the
42 | editorial revisions, annotations, elaborations, or other modifications
43 | represent, as a whole, an original work of authorship. For the purposes
44 | of this License, Derivative Works shall not include works that remain
45 | separable from, or merely link (or bind by name) to the interfaces of,
46 | the Work and Derivative Works thereof.
47 |
48 | "Contribution" shall mean any work of authorship, including
49 | the original version of the Work and any modifications or additions
50 | to that Work or Derivative Works thereof, that is intentionally
51 | submitted to Licensor for inclusion in the Work by the copyright owner
52 | or by an individual or Legal Entity authorized to submit on behalf of
53 | the copyright owner. For the purposes of this definition, "submitted"
54 | means any form of electronic, verbal, or written communication sent
55 | to the Licensor or its representatives, including but not limited to
56 | communication on electronic mailing lists, source code control systems,
57 | and issue tracking systems that are managed by, or on behalf of, the
58 | Licensor for the purpose of discussing and improving the Work, but
59 | excluding communication that is conspicuously marked or otherwise
60 | designated in writing by the copyright owner as "Not a Contribution."
61 |
62 | "Contributor" shall mean Licensor and any individual or Legal Entity
63 | on behalf of whom a Contribution has been received by Licensor and
64 | subsequently incorporated within the Work.
65 |
66 | 2. Grant of Copyright License. Subject to the terms and conditions of
67 | this License, each Contributor hereby grants to You a perpetual,
68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69 | copyright license to reproduce, prepare Derivative Works of,
70 | publicly display, publicly perform, sublicense, and distribute the
71 | Work and such Derivative Works in Source or Object form.
72 |
73 | 3. Grant of Patent License. Subject to the terms and conditions of
74 | this License, each Contributor hereby grants to You a perpetual,
75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76 | (except as stated in this section) patent license to make, have made,
77 | use, offer to sell, sell, import, and otherwise transfer the Work,
78 | where such license applies only to those patent claims licensable
79 | by such Contributor that are necessarily infringed by their
80 | Contribution(s) alone or by combination of their Contribution(s)
81 | with the Work to which such Contribution(s) was submitted. If You
82 | institute patent litigation against any entity (including a
83 | cross-claim or counterclaim in a lawsuit) alleging that the Work
84 | or a Contribution incorporated within the Work constitutes direct
85 | or contributory patent infringement, then any patent licenses
86 | granted to You under this License for that Work shall terminate
87 | as of the date such litigation is filed.
88 |
89 | 4. Redistribution. You may reproduce and distribute copies of the
90 | Work or Derivative Works thereof in any medium, with or without
91 | modifications, and in Source or Object form, provided that You
92 | meet the following conditions:
93 |
94 | (a) You must give any other recipients of the Work or
95 | Derivative Works a copy of this License; and
96 |
97 | (b) You must cause any modified files to carry prominent notices
98 | stating that You changed the files; and
99 |
100 | (c) You must retain, in the Source form of any Derivative Works
101 | that You distribute, all copyright, patent, trademark, and
102 | attribution notices from the Source form of the Work,
103 | excluding those notices that do not pertain to any part of
104 | the Derivative Works; and
105 |
106 | (d) If the Work includes a "NOTICE" text file as part of its
107 | distribution, then any Derivative Works that You distribute must
108 | include a readable copy of the attribution notices contained
109 | within such NOTICE file, excluding those notices that do not
110 | pertain to any part of the Derivative Works, in at least one
111 | of the following places: within a NOTICE text file distributed
112 | as part of the Derivative Works; within the Source form or
113 | documentation, if provided along with the Derivative Works; or,
114 | within a display generated by the Derivative Works, if and
115 | wherever such third-party notices normally appear. The contents
116 | of the NOTICE file are for informational purposes only and
117 | do not modify the License. You may add Your own attribution
118 | notices within Derivative Works that You distribute, alongside
119 | or as an addendum to the NOTICE text from the Work, provided
120 | that such additional attribution notices cannot be construed
121 | as modifying the License.
122 |
123 | You may add Your own copyright statement to Your modifications and
124 | may provide additional or different license terms and conditions
125 | for use, reproduction, or distribution of Your modifications, or
126 | for any such Derivative Works as a whole, provided Your use,
127 | reproduction, and distribution of the Work otherwise complies with
128 | the conditions stated in this License.
129 |
130 | 5. Submission of Contributions. Unless You explicitly state otherwise,
131 | any Contribution intentionally submitted for inclusion in the Work
132 | by You to the Licensor shall be under the terms and conditions of
133 | this License, without any additional terms or conditions.
134 | Notwithstanding the above, nothing herein shall supersede or modify
135 | the terms of any separate license agreement you may have executed
136 | with Licensor regarding such Contributions.
137 |
138 | 6. Trademarks. This License does not grant permission to use the trade
139 | names, trademarks, service marks, or product names of the Licensor,
140 | except as required for reasonable and customary use in describing the
141 | origin of the Work and reproducing the content of the NOTICE file.
142 |
143 | 7. Disclaimer of Warranty. Unless required by applicable law or
144 | agreed to in writing, Licensor provides the Work (and each
145 | Contributor provides its Contributions) on an "AS IS" BASIS,
146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147 | implied, including, without limitation, any warranties or conditions
148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149 | PARTICULAR PURPOSE. You are solely responsible for determining the
150 | appropriateness of using or redistributing the Work and assume any
151 | risks associated with Your exercise of permissions under this License.
152 |
153 | 8. Limitation of Liability. In no event and under no legal theory,
154 | whether in tort (including negligence), contract, or otherwise,
155 | unless required by applicable law (such as deliberate and grossly
156 | negligent acts) or agreed to in writing, shall any Contributor be
157 | liable to You for damages, including any direct, indirect, special,
158 | incidental, or consequential damages of any character arising as a
159 | result of this License or out of the use or inability to use the
160 | Work (including but not limited to damages for loss of goodwill,
161 | work stoppage, computer failure or malfunction, or any and all
162 | other commercial damages or losses), even if such Contributor
163 | has been advised of the possibility of such damages.
164 |
165 | 9. Accepting Warranty or Additional Liability. While redistributing
166 | the Work or Derivative Works thereof, You may choose to offer,
167 | and charge a fee for, acceptance of support, warranty, indemnity,
168 | or other liability obligations and/or rights consistent with this
169 | License. However, in accepting such obligations, You may act only
170 | on Your own behalf and on Your sole responsibility, not on behalf
171 | of any other Contributor, and only if You agree to indemnify,
172 | defend, and hold each Contributor harmless for any liability
173 | incurred by, or claims asserted against, such Contributor by reason
174 | of your accepting any such warranty or additional liability.
175 |
176 | END OF TERMS AND CONDITIONS
177 |
178 | APPENDIX: How to apply the Apache License to your work.
179 |
180 | To apply the Apache License to your work, attach the following
181 | boilerplate notice, with the fields enclosed by brackets "[]"
182 | replaced with your own identifying information. (Don't include
183 | the brackets!) The text should be enclosed in the appropriate
184 | comment syntax for the file format. We also recommend that a
185 | file or class name and description of purpose be included on the
186 | same "printed page" as the copyright notice for easier
187 | identification within third-party archives.
188 |
189 | Copyright [yyyy] [name of copyright owner]
190 |
191 | Licensed under the Apache License, Version 2.0 (the "License");
192 | you may not use this file except in compliance with the License.
193 | You may obtain a copy of the License at
194 |
195 | http://www.apache.org/licenses/LICENSE-2.0
196 |
197 | Unless required by applicable law or agreed to in writing, software
198 | distributed under the License is distributed on an "AS IS" BASIS,
199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200 | See the License for the specific language governing permissions and
201 | limitations under the License.
202 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | ## ComfyUI Depth Estimation Node
2 |
3 |
4 |
5 |
6 |
7 | A custom ComfyUI node that generates depth maps from images using Depth-Anything models.
8 |
9 | ## Features
10 | - Multiple model options:
11 | - Depth-Anything-Small
12 | - Depth-Anything-Base
13 | - Depth-Anything-Large
14 | - Depth-Anything-V2-Small
15 | - Depth-Anything-V2-Base
16 | - Post-processing options:
17 | - Gaussian blur (adjustable radius)
18 | - Median filtering (configurable size)
19 | - Automatic contrast enhancement
20 | - Gamma correction
21 | - Advanced options:
22 | - Force CPU processing for compatibility
23 | - Force model reload for troubleshooting
24 |
25 | ## Installation
26 |
27 | ### Method 1: Install via ComfyUI Manager (Recommended)
28 | 1. Open ComfyUI and install the ComfyUI Manager if you haven't already
29 | 2. Go to the Manager tab
30 | 3. Search for "Depth Estimation" and install the node
31 |
32 | ### Method 2: Manual Installation
33 | 1. Navigate to your ComfyUI custom nodes directory:
34 | ```bash
35 | cd ComfyUI/custom_nodes/
36 | ```
37 |
38 | 2. Clone the repository:
39 | ```bash
40 | git clone https://github.com/Limbicnation/ComfyUIDepthEstimation.git
41 | ```
42 |
43 | 3. Install the required dependencies:
44 | ```bash
45 | cd ComfyUIDepthEstimation
46 | pip install -r requirements.txt
47 | ```
48 |
49 | 4. Restart ComfyUI to load the new custom node.
50 |
51 | > **Note**: On first use, the node will download the selected model from Hugging Face. This may take some time depending on your internet connection.
52 |
53 | ## Usage
54 |
55 |
56 |
57 |
58 |
59 |
60 |
61 |
62 | ### Node Parameters
63 |
64 | #### Required Parameters
65 | - **image**: Input image (IMAGE type)
66 | - **model_name**: Select from available Depth-Anything models
67 | - **blur_radius**: Gaussian blur radius (0.0 - 10.0, default: 2.0)
68 | - **median_size**: Median filter size (3, 5, 7, 9, 11)
69 | - **apply_auto_contrast**: Enable automatic contrast enhancement
70 | - **apply_gamma**: Enable gamma correction
71 |
72 | #### Optional Parameters
73 | - **force_reload**: Force the model to reload (useful for troubleshooting)
74 | - **force_cpu**: Use CPU for processing instead of GPU (slower but more compatible)
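
> **Note**: The node is normally driven entirely from the ComfyUI graph. For a quick smoke test it can also be called from a Python session inside a working ComfyUI environment (the module imports `folder_paths` and `comfy.model_management` at load time). The sketch below assumes the usual ComfyUI convention that the method named in `FUNCTION` (`estimate_depth`) accepts the `INPUT_TYPES` keys as keyword arguments; it is not a documented API.

```python
import torch
from depth_estimation_node import DepthEstimationNode

# Hypothetical standalone call; ComfyUI normally supplies these inputs from the graph.
node = DepthEstimationNode()
image = torch.rand(1, 512, 512, 3)  # ComfyUI IMAGE layout: (batch, height, width, channels), values in [0, 1]

(depth_map,) = node.estimate_depth(
    image=image,
    model_name="Depth-Anything-V2-Small",
    input_size=512,
    blur_radius=2.0,
    median_size="3",            # a string, matching the dropdown values
    apply_auto_contrast=True,
    apply_gamma=True,
    force_cpu=True,             # optional: keep the smoke test off the GPU
)
print(depth_map.shape)          # expected: (1, height, width, 3)
```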
75 |
76 | ### Example Usage
77 | 1. Add the `Depth Estimation` node to your ComfyUI workflow
78 | 2. Connect an image source to the node's image input
79 | 3. Configure the parameters:
80 | - Select a model (e.g., "Depth-Anything-V2-Small" is fastest)
81 | - Adjust blur_radius (0-10) for depth map smoothing
82 | - Choose median_size (3-11) for noise reduction
83 | - Toggle auto_contrast and gamma correction as needed
84 | 4. Connect the output to a Preview Image node or other image processing nodes
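
The post-processing parameters above (blur_radius, median_size, auto-contrast, gamma) correspond to standard Pillow operations. The following is a rough sketch of that chain applied to an already-computed grayscale depth map; it approximates, rather than copies, the node's internal implementation, and the gamma value is an assumption:

```python
from PIL import Image, ImageFilter, ImageOps

def postprocess_depth(depth_img: Image.Image,
                      blur_radius: float = 2.0,
                      median_size: int = 3,
                      apply_auto_contrast: bool = True,
                      apply_gamma: bool = True,
                      gamma: float = 1.0 / 2.2) -> Image.Image:
    """Approximate the node's post-processing chain on an 8-bit grayscale depth map."""
    result = depth_img.convert("L")
    if blur_radius > 0:
        result = result.filter(ImageFilter.GaussianBlur(radius=blur_radius))
    if median_size >= 3:
        result = result.filter(ImageFilter.MedianFilter(size=median_size))  # size must be odd
    if apply_auto_contrast:
        result = ImageOps.autocontrast(result)
    if apply_gamma:
        # Gamma correction via a 256-entry lookup table (the gamma value here is illustrative only).
        lut = [round(255 * ((i / 255) ** gamma)) for i in range(256)]
        result = result.point(lut)
    return result

# Example: postprocess_depth(Image.open("depth.png"), blur_radius=2.0, median_size=3).save("depth_smoothed.png")
```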
85 |
86 | ## Model Information
87 |
88 | | Model Name | Quality | VRAM Usage | Speed |
89 | |------------|---------|------------|-------|
90 | | Depth-Anything-V2-Small | Good | ~1.5 GB | Fast |
91 | | Depth-Anything-Small | Good | ~1.5 GB | Fast |
92 | | Depth-Anything-V2-Base | Better | ~2.5 GB | Medium |
93 | | Depth-Anything-Base | Better | ~2.5 GB | Medium |
94 | | Depth-Anything-Large | Best | ~4.0 GB | Slow |
95 |
96 | ## Troubleshooting Guide
97 |
98 | ### Common Issues and Solutions
99 |
100 | #### Model Download Issues
101 | - **Error**: "Failed to load model" or "Model not found"
102 | - **Solution**:
103 | 1. Check your internet connection
104 | 2. Try authenticating with Hugging Face:
105 | ```bash
106 | pip install huggingface_hub
107 | huggingface-cli login
108 | ```
109 | 3. Try a different model (e.g., switch to Depth-Anything-V2-Small)
110 | 4. Check the ComfyUI console for detailed error messages
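
If the automatic download keeps failing, the checkpoint can also be fetched ahead of time with `huggingface_hub` and placed where the node looks for models. A minimal sketch; the target directory assumes the default `ComfyUI/models/depth_anything` location used by this node:

```python
from huggingface_hub import snapshot_download

# Pre-download a Depth-Anything checkpoint into ComfyUI's models folder.
# Adjust local_dir to match your ComfyUI installation path.
snapshot_download(
    repo_id="depth-anything/Depth-Anything-V2-Small-hf",
    local_dir="ComfyUI/models/depth_anything/depth_anything_v2_small",
)
```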
111 |
112 | #### CUDA Out of Memory Errors
113 | - **Error**: "CUDA out of memory" or node shows red error image
114 | - **Solution**:
115 | 1. Try a smaller model (Depth-Anything-V2-Small uses the least memory)
116 | 2. Enable the `force_cpu` option (slower but uses less VRAM)
117 | 3. Reduce the size of your input image
118 | 4. Close other VRAM-intensive applications
119 |
120 | #### Node Not Appearing in ComfyUI
121 | - **Solution**:
122 | 1. Check your ComfyUI console for error messages
123 | 2. Verify that all dependencies are installed:
124 | ```bash
125 | pip install "transformers>=4.20.0" "Pillow>=9.1.0" "numpy>=1.23.0" "timm>=0.6.12"
126 | ```
127 | 3. Try restarting ComfyUI
128 | 4. Check that the node files are in the correct directory
129 |
130 | #### Node Returns Original Image or Black Image
131 | - **Solution**:
132 | 1. Try enabling the `force_reload` option
133 | 2. Check the ComfyUI console for error messages
134 | 3. Try using a different model
135 | 4. Make sure your input image is valid (not corrupted or empty)
136 | 5. Try restarting ComfyUI
137 |
138 | #### Slow Performance
139 | - **Solution**:
140 | 1. Use a smaller model (Depth-Anything-V2-Small is fastest)
141 | 2. Reduce input image size
142 | 3. If using CPU mode, consider using GPU if available
143 | 4. Close other applications that might be using GPU resources
144 |
145 | ### Where to Get Help
146 | - Create an issue on the [GitHub repository](https://github.com/Limbicnation/ComfyUIDepthEstimation/issues)
147 | - Check the ComfyUI console for detailed error messages
148 | - Visit the ComfyUI Discord for community support
149 |
150 | ## License
151 |
152 | This project is licensed under the Apache License 2.0. See the [LICENSE](LICENSE) file for details.
153 |
154 | ---
--------------------------------------------------------------------------------
/__init__.py:
--------------------------------------------------------------------------------
1 | """
2 | ComfyUI Depth Estimation Node
3 | A custom node for depth map estimation using Depth-Anything models.
4 | """
5 |
6 | import os
7 | import logging
8 | import importlib.util
9 |
10 | # Setup logging
11 | logging.basicConfig(level=logging.INFO)
12 | logger = logging.getLogger("DepthEstimation")
13 |
14 | # Version info
15 | __version__ = "1.1.1"
16 |
17 | # Node class mappings - will be populated based on dependency checks
18 | NODE_CLASS_MAPPINGS = {}
19 | NODE_DISPLAY_NAME_MAPPINGS = {}
20 |
21 | # Web extension info for ComfyUI
22 | WEB_DIRECTORY = "./js"
23 |
24 | # Graceful dependency checking
25 | required_dependencies = {
26 | "torch": "2.0.0",
27 | "transformers": "4.20.0",
28 | "numpy": "1.23.0",
29 | "PIL": "9.2.0", # Pillow is imported as PIL
30 | "timm": "0.6.12",
31 | "huggingface_hub": "0.16.0"
32 | }
33 |
34 | missing_dependencies = []
35 |
36 | # Check each dependency
37 | for module_name, min_version in required_dependencies.items():
38 | try:
39 | if module_name == "PIL":
40 | # Special case for Pillow/PIL
41 | import PIL
42 | module_version = PIL.__version__
43 | else:
44 | module = __import__(module_name)
45 | module_version = getattr(module, "__version__", "unknown")
46 |
47 | logger.info(f"Found {module_name} version {module_version}")
48 | except ImportError:
49 | missing_dependencies.append(f"{module_name}>={min_version}")
50 | logger.warning(f"Missing required dependency: {module_name}>={min_version}")
51 |
52 | if missing_dependencies:
53 | # Create placeholder node with dependency error
54 | class DependencyErrorNode:
55 | """Placeholder node that shows dependency installation instructions."""
56 |
57 | @classmethod
58 | def INPUT_TYPES(cls):
59 | return {"required": {}}
60 |
61 | RETURN_TYPES = ("STRING",)
62 | FUNCTION = "error_message"
63 | CATEGORY = "depth"
64 |
65 | def error_message(self):
66 | missing = ", ".join(missing_dependencies)
67 | message = f"Dependencies missing: {missing}. Please install with: pip install {' '.join(missing_dependencies)}"
68 | print(f"DepthEstimation Node Error: {message}")
69 | return (message,)
70 |
71 | # Register the error node instead of the real node
72 | NODE_CLASS_MAPPINGS = {
73 | "DepthEstimationNode": DependencyErrorNode
74 | }
75 |
76 | NODE_DISPLAY_NAME_MAPPINGS = {
77 | "DepthEstimationNode": "Depth Estimation (Missing Dependencies)"
78 | }
79 | else:
80 | # All dependencies are available, try to import the actual node
81 | try:
82 | from .depth_estimation_node import DepthEstimationNode
83 |
84 | # Register the actual depth estimation node
85 | NODE_CLASS_MAPPINGS = {
86 | "DepthEstimationNode": DepthEstimationNode
87 | }
88 |
89 | NODE_DISPLAY_NAME_MAPPINGS = {
90 | "DepthEstimationNode": "Depth Estimation"
91 | }
92 | except Exception as e:
93 | # Capture any import errors that might occur with transformers
94 | logger.error(f"Error importing depth estimation node: {str(e)}")
95 | import_error_text = str(e)  # capture the message now; 'e' is cleared when this except block exits
96 | # Create a more specific error node
97 | class TransformersErrorNode:
98 | @classmethod
99 | def INPUT_TYPES(cls):
100 | return {"required": {}}
101 |
102 | RETURN_TYPES = ("STRING",)
103 | FUNCTION = "error_message"
104 | CATEGORY = "depth"
105 |
106 | def error_message(self):
107 | if "Descriptors cannot be created directly" in import_error_text:
108 | message = "Protobuf version conflict. Run: pip install protobuf==3.20.3"
109 | else:
110 | message = f"Error loading depth estimation: {import_error_text}"
111 | return (message,)
112 |
113 | NODE_CLASS_MAPPINGS = {
114 | "DepthEstimationNode": TransformersErrorNode
115 | }
116 |
117 | NODE_DISPLAY_NAME_MAPPINGS = {
118 | "DepthEstimationNode": "Depth Estimation (Error)"
119 | }
120 |
121 | # Module exports
122 | __all__ = [
123 | "NODE_CLASS_MAPPINGS",
124 | "NODE_DISPLAY_NAME_MAPPINGS",
125 | "__version__",
126 | "WEB_DIRECTORY"
127 | ]
--------------------------------------------------------------------------------
/depth_estimation_node.py:
--------------------------------------------------------------------------------
1 | import os
2 | import numpy as np
3 | import torch
4 | import traceback
5 | import time
6 | import requests
7 | import urllib.request
8 | import wget
9 | from pathlib import Path
10 | from transformers import pipeline
11 | from PIL import Image, ImageFilter, ImageOps, ImageDraw, ImageFont
12 | import folder_paths
13 | from comfy.model_management import get_torch_device, get_free_memory
14 | import gc
15 | import logging
16 | import torch.nn as nn
17 | import torch.nn.functional as F
18 | from typing import Tuple, List, Dict, Any, Optional, Union
19 |
20 | # Try to import timm (for vision transformers)
21 | try:
22 | import timm
23 | TIMM_AVAILABLE = True
24 | except ImportError:
25 | TIMM_AVAILABLE = False
26 | print("Warning: timm not available. Direct loading of Depth Anything models may not work.")
27 |
28 | # Setup logging
29 | logging.basicConfig(level=logging.INFO)
30 | logger = logging.getLogger("DepthEstimation")
31 |
32 | # Depth Anything V2 Implementation
33 | class DepthAnythingV2(nn.Module):
34 | """Direct implementation of Depth Anything V2 model"""
35 | def __init__(self, encoder='vits', features=64, out_channels=[48, 96, 192, 384]):
36 | super().__init__()
37 | self.encoder = encoder
38 | self.features = features
39 | self.out_channels = out_channels
40 | self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
41 |
42 | # Create encoder based on specification
43 | if TIMM_AVAILABLE:
44 | if encoder == 'vits':
45 | self.backbone = timm.create_model('vit_small_patch16_224', pretrained=False)
46 | self.embed_dim = 384
47 | elif encoder == 'vitb':
48 | self.backbone = timm.create_model('vit_base_patch16_224', pretrained=False)
49 | self.embed_dim = 768
50 | elif encoder == 'vitl':
51 | self.backbone = timm.create_model('vit_large_patch16_224', pretrained=False)
52 | self.embed_dim = 1024
53 | else: # fallback to vits
54 | self.backbone = timm.create_model('vit_small_patch16_224', pretrained=False)
55 | self.embed_dim = 384
56 |
57 | # Implement the rest of the model architecture
58 | self.initialize_decoder()
59 | else:
60 | # Fallback if timm is not available
61 | from torchvision.models import resnet50
62 | self.backbone = resnet50(pretrained=False)
63 | self.embed_dim = 2048
64 | logger.warning("Using fallback ResNet50 model (timm not available)")
65 |
66 | def initialize_decoder(self):
67 | """Initialize the decoder layers"""
68 | self.neck = nn.Sequential(
69 | nn.Conv2d(self.embed_dim, self.features, 1, 1, 0),
70 | nn.Conv2d(self.features, self.features, 3, 1, 1),
71 | )
72 |
73 | # Create decoders for each level
74 | self.decoders = nn.ModuleList([
75 | self.create_decoder_level(self.features, self.out_channels[0]),
76 | self.create_decoder_level(self.out_channels[0], self.out_channels[1]),
77 | self.create_decoder_level(self.out_channels[1], self.out_channels[2]),
78 | self.create_decoder_level(self.out_channels[2], self.out_channels[3])
79 | ])
80 |
81 | # Final depth head
82 | self.depth_head = nn.Sequential(
83 | nn.Conv2d(self.out_channels[3], self.out_channels[3], 3, 1, 1),
84 | nn.BatchNorm2d(self.out_channels[3]),
85 | nn.ReLU(True),
86 | nn.Conv2d(self.out_channels[3], 1, 1)
87 | )
88 |
89 | def create_decoder_level(self, in_channels, out_channels):
90 | """Create a decoder level"""
91 | return nn.Sequential(
92 | nn.Conv2d(in_channels, out_channels, 3, 1, 1),
93 | nn.BatchNorm2d(out_channels),
94 | nn.ReLU(True),
95 | nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
96 | )
97 |
98 | def forward(self, x):
99 | """Forward pass of the model"""
100 | # For timm ViT models
101 | if hasattr(self.backbone, 'forward_features'):
102 | features = self.backbone.forward_features(x)
103 |
104 | # Reshape features based on model type
105 | if 'vit' in self.encoder:
106 | # Reshape transformer output to spatial features
107 | # Exact reshape depends on the model details
108 | h = w = int(features.shape[1]**0.5)
109 | features = features.reshape(-1, h, w, self.embed_dim).permute(0, 3, 1, 2)
110 |
111 | # Process through decoder
112 | x = self.neck(features)
113 |
114 | # Apply decoder stages
115 | for decoder in self.decoders:
116 | x = decoder(x)
117 |
118 | # Final depth prediction
119 | depth = self.depth_head(x)
120 |
121 | return depth
122 | else:
123 | # Fallback for ResNet
124 | x = self.backbone.conv1(x)
125 | x = self.backbone.bn1(x)
126 | x = self.backbone.relu(x)
127 | x = self.backbone.maxpool(x)
128 |
129 | x = self.backbone.layer1(x)
130 | x = self.backbone.layer2(x)
131 | x = self.backbone.layer3(x)
132 | x = self.backbone.layer4(x)
133 |
134 | # Process through simple decoder
135 | x = F.interpolate(x, scale_factor=2, mode='bilinear', align_corners=True)
136 | x = self.depth_head(x)
137 |
138 | return x
139 |
140 | def infer_image(self, image):
141 | """Process an image and return the depth map
142 |
143 | Args:
144 | image: A numpy image in BGR format (OpenCV) or RGB PIL Image
145 |
146 | Returns:
147 | depth: A numpy array containing the depth map
148 | """
149 | # Convert input to tensor
150 | if isinstance(image, np.ndarray):
151 | # Convert BGR to RGB
152 | if image.shape[2] == 3:
153 | image = image[:, :, ::-1]
154 | # Normalize
155 | image = image.astype(np.float32) / 255.0
156 | # HWC to CHW
157 | image = image.transpose(2, 0, 1)
158 | # Add batch dimension
159 | image = torch.from_numpy(image).unsqueeze(0)
160 | elif isinstance(image, Image.Image):
161 | # Convert PIL image to numpy
162 | image = np.array(image).astype(np.float32) / 255.0
163 | # HWC to CHW
164 | image = image.transpose(2, 0, 1)
165 | # Add batch dimension
166 | image = torch.from_numpy(image).unsqueeze(0)
167 |
168 | # Move to device
169 | image = image.to(self.device)
170 |
171 | # Set model to eval mode
172 | self.eval()
173 |
174 | # Get prediction
175 | with torch.no_grad():
176 | depth = self.forward(image)
177 |
178 | # Convert to numpy
179 | depth = depth.squeeze().cpu().numpy()
180 |
181 | return depth
182 |
183 | def __call__(self, image):
184 | """Compatible interface with the pipeline API"""
185 | if isinstance(image, Image.Image):
186 | # Convert to numpy for processing
187 | depth = self.infer_image(image)
188 | # Return in the format expected by the node
189 | return {"predicted_depth": torch.from_numpy(depth).unsqueeze(0)}
190 | else:
191 | # Already a tensor, process directly
192 | self.eval()
193 | with torch.no_grad():
194 | depth = self.forward(image)
195 | return {"predicted_depth": depth}
196 |
197 | # Configure model paths
198 | if not hasattr(folder_paths, "models_dir"):
199 | folder_paths.models_dir = os.path.join(folder_paths.base_path, "models")
200 |
201 | # Register depth models path - support multiple possible directory structures
202 | DEPTH_DIR = "depth_anything"
203 | DEPTH_ANYTHING_DIR = "depthanything"
204 |
205 | # Check which directory structure exists
206 | possible_paths = [
207 | os.path.join(folder_paths.models_dir, DEPTH_DIR),
208 | os.path.join(folder_paths.models_dir, DEPTH_ANYTHING_DIR),
209 | os.path.join(folder_paths.models_dir, DEPTH_ANYTHING_DIR, DEPTH_DIR),
210 | os.path.join(folder_paths.models_dir, "checkpoints", DEPTH_DIR),
211 | os.path.join(folder_paths.models_dir, "checkpoints", DEPTH_ANYTHING_DIR),
212 | ]
213 |
214 | # Filter to only paths that exist
215 | existing_paths = [p for p in possible_paths if os.path.exists(p)]
216 | if not existing_paths:
217 | # If none exists, create the default one
218 | existing_paths = [os.path.join(folder_paths.models_dir, DEPTH_DIR)]
219 | os.makedirs(existing_paths[0], exist_ok=True)
220 | logger.info(f"Created model directory: {existing_paths[0]}")
221 |
222 | # Log all found paths for debugging
223 | logger.info(f"Found depth model directories: {existing_paths}")
224 |
225 | # Register all possible paths for model loading
226 | folder_paths.folder_names_and_paths[DEPTH_DIR] = (existing_paths, folder_paths.supported_pt_extensions)
227 |
228 | # Set primary models directory to the first available path
229 | MODELS_DIR = existing_paths[0]
230 | logger.info(f"Using primary models directory: {MODELS_DIR}")
231 |
232 | # Set Hugging Face cache to the models directory to ensure models are saved there
233 | os.environ["TRANSFORMERS_CACHE"] = MODELS_DIR
234 | os.environ["HF_HOME"] = MODELS_DIR
235 |
236 | # Define model configurations for direct loading
237 | MODEL_CONFIGS = {
238 | 'vits': {'encoder': 'vits', 'features': 64, 'out_channels': [48, 96, 192, 384]},
239 | 'vitb': {'encoder': 'vitb', 'features': 128, 'out_channels': [96, 192, 384, 768]},
240 | 'vitl': {'encoder': 'vitl', 'features': 256, 'out_channels': [256, 512, 1024, 1024]},
241 | 'vitg': {'encoder': 'vitg', 'features': 384, 'out_channels': [1536, 1536, 1536, 1536]}
242 | }
243 |
244 | # Define all models mentioned in the README with memory requirements
245 | DEPTH_MODELS = {
246 | "Depth-Anything-Small": {
247 | "path": "LiheYoung/depth-anything-small-hf", # Correct HF path for V1
248 | "vram_mb": 1500,
249 | "direct_url": "https://github.com/LiheYoung/Depth-Anything/releases/download/v1.0/depth_anything_vitb14.pt",
250 | "model_type": "v1",
251 | "encoder": "vitb"
252 | },
253 | "Depth-Anything-Base": {
254 | "path": "LiheYoung/depth-anything-base-hf", # Correct HF path for V1
255 | "vram_mb": 2500,
256 | "direct_url": "https://github.com/LiheYoung/Depth-Anything/releases/download/v1.0/depth_anything_vitl14.pt",
257 | "model_type": "v1",
258 | "encoder": "vitl"
259 | },
260 | "Depth-Anything-Large": {
261 | "path": "LiheYoung/depth-anything-large-hf", # Correct HF path for V1
262 | "vram_mb": 4000,
263 | "direct_url": "https://github.com/LiheYoung/Depth-Anything/releases/download/v1.0/depth_anything_vitl14.pt",
264 | "model_type": "v1",
265 | "encoder": "vitl"
266 | },
267 | "Depth-Anything-V2-Small": {
268 | "path": "depth-anything/Depth-Anything-V2-Small-hf", # Updated corrected path as shown in example
269 | "vram_mb": 1500,
270 | "direct_url": "https://huggingface.co/depth-anything/Depth-Anything-V2-Small-hf/resolve/main/pytorch_model.bin",
271 | "model_type": "v2",
272 | "encoder": "vits",
273 | "config": MODEL_CONFIGS["vits"]
274 | },
275 | "Depth-Anything-V2-Base": {
276 | "path": "depth-anything/Depth-Anything-V2-Base-hf", # Updated corrected path
277 | "vram_mb": 2500,
278 | "direct_url": "https://huggingface.co/depth-anything/Depth-Anything-V2-Base-hf/resolve/main/pytorch_model.bin",
279 | "model_type": "v2",
280 | "encoder": "vitb",
281 | "config": MODEL_CONFIGS["vitb"]
282 | },
283 | # Add MiDaS models as dedicated options with direct download URLs
284 | "MiDaS-Small": {
285 | "path": "Intel/dpt-hybrid-midas",
286 | "vram_mb": 1000,
287 | "midas_type": "MiDaS_small",
288 | "direct_url": "https://github.com/intel-isl/MiDaS/releases/download/v2_1/midas_v21_small_256.pt"
289 | },
290 | "MiDaS-Base": {
291 | "path": "Intel/dpt-hybrid-midas",
292 | "vram_mb": 1200,
293 | "midas_type": "DPT_Hybrid",
294 | "direct_url": "https://github.com/intel-isl/MiDaS/releases/download/v3/dpt_hybrid-midas-501f0c75.pt"
295 | }
296 | }
297 |
298 | class MiDaSWrapper:
299 | def __init__(self, model_type, device):
300 | self.device = device
301 |
302 | try:
303 | # Import required libraries
304 | import torch.nn.functional as F
305 |
306 | # Use a more reliable approach to loading MiDaS models
307 | if model_type == "DPT_Hybrid" or model_type == "dpt_hybrid":
308 | # Use direct URL download for MiDaS models
309 | midas_url = "https://github.com/intel-isl/MiDaS/releases/download/v3/dpt_hybrid-midas-501f0c75.pt"
310 | local_path = os.path.join(MODELS_DIR, "dpt_hybrid_midas.pt")
311 |
312 | if not os.path.exists(local_path):
313 | logger.info(f"Downloading MiDaS model from {midas_url}")
314 | try:
315 | response = requests.get(midas_url, stream=True)
316 | if response.status_code == 200:
317 | with open(local_path, 'wb') as f:
318 | for chunk in response.iter_content(chunk_size=8192):
319 | f.write(chunk)
320 | logger.info(f"Downloaded MiDaS model to {local_path}")
321 | else:
322 | logger.error(f"Failed to download model: {response.status_code}")
323 | except Exception as e:
324 | logger.error(f"Error downloading MiDaS model: {e}")
325 |
326 | # Load pretrained model
327 | try:
328 | # Create a simple model architecture
329 | from torchvision.models import resnet50
330 | self.model = resnet50()
331 | self.model.fc = torch.nn.Linear(2048, 1)
332 |
333 | # Load state dict if available
334 | if os.path.exists(local_path):
335 | logger.info(f"Loading MiDaS model from {local_path}")
336 | state_dict = torch.load(local_path, map_location=device)
337 | # Convert all parameters to float
338 | floated_state_dict = {k: v.float() for k, v in state_dict.items()}
339 | self.model.load_state_dict(floated_state_dict)
340 |
341 | except Exception as e:
342 | logger.error(f"Error loading MiDaS model state dict: {e}")
343 | # Fallback to ResNet
344 | self.model = resnet50(pretrained=True)
345 | self.model.fc = torch.nn.Linear(2048, 1)
346 |
347 | else: # Other model types or fallback
348 | from torchvision.models import resnet50
349 | self.model = resnet50(pretrained=True)
350 | self.model.fc = torch.nn.Linear(2048, 1)
351 |
352 | # Ensure model parameters are float
353 | for param in self.model.parameters():
354 | param.data = param.data.float()
355 |
356 | # Explicitly convert model to FloatTensor
357 | self.model = self.model.float()
358 |
359 | # Move model to device and set to eval mode
360 | self.model = self.model.to(device)
361 | self.model.eval()
362 |
363 | except Exception as e:
364 | logger.error(f"Failed to load MiDaS model: {e}")
365 | logger.error(traceback.format_exc())
366 | # Create a minimal model as absolute fallback
367 | from torchvision.models import resnet18
368 | self.model = resnet18(pretrained=True).float().to(device)
369 | self.model.fc = torch.nn.Linear(512, 1).float().to(device)
370 | self.model.eval()
371 |
372 | def __call__(self, image):
373 | """Process an image and return the depth map"""
374 | try:
375 | # Convert PIL image to tensor for processing
376 | if isinstance(image, Image.Image):
377 | # Get original dimensions
378 | original_width, original_height = image.size
379 |
380 | # Ensure dimensions are multiple of 32 (required for some models)
381 | # This helps prevent tensor dimension mismatches
382 | target_height = ((original_height + 31) // 32) * 32
383 | target_width = ((original_width + 31) // 32) * 32
384 |
385 | # Keep original dimensions - don't force 384x384
386 | # The caller should already have resized to the requested input_size
387 |
388 | # Log resize information if needed
389 | if (target_width != original_width) or (target_height != original_height):
390 | logger.info(f"Adjusting dimensions from {original_width}x{original_height} to {target_width}x{target_height} (multiples of 32)")
391 | img_resized = image.resize((target_width, target_height), Image.LANCZOS)
392 | else:
393 | img_resized = image
394 |
395 | # Convert to numpy array
396 | img_np = np.array(img_resized).astype(np.float32) / 255.0
397 |
398 | # Check for NaN values and replace them with zeros
399 | if np.isnan(img_np).any():
400 | logger.warning("Input image contains NaN values. Replacing with zeros.")
401 | img_np = np.nan_to_num(img_np, nan=0.0)
402 |
403 | # Convert to tensor with proper shape (B,C,H,W)
404 | if len(img_np.shape) == 3:
405 | # RGB image
406 | img_np = img_np.transpose(2, 0, 1) # (H,W,C) -> (C,H,W)
407 | else:
408 | # Grayscale image - add channel dimension
409 | img_np = np.expand_dims(img_np, axis=0)
410 |
411 | # Add batch dimension and ensure float32
412 | input_tensor = torch.from_numpy(img_np).unsqueeze(0).float()
413 | else:
414 | # Already a tensor - ensure float32 by explicitly converting
415 | # This is the key fix for the "Input type (torch.cuda.DoubleTensor) and weight type (torch.cuda.FloatTensor)" error
416 | input_tensor = None
417 |
418 | # Handle potential error cases with clearer messages
419 | if not torch.is_tensor(image):
420 | logger.error(f"Expected tensor or PIL image, got {type(image)}")
421 | # Create dummy tensor as fallback
422 | input_tensor = torch.ones((1, 3, 512, 512), dtype=torch.float32)
423 | elif image.numel() == 0:
424 | logger.error("Input tensor is empty")
425 | # Create dummy tensor as fallback
426 | input_tensor = torch.ones((1, 3, 512, 512), dtype=torch.float32)
427 | else:
428 | # Check for NaN values
429 | if torch.isnan(image).any():
430 | logger.warning("Input tensor contains NaN values. Replacing with zeros.")
431 | image = torch.nan_to_num(image, nan=0.0)
432 |
433 | # Always convert to float32 to prevent type mismatches
434 | input_tensor = image.float() # Convert any tensor to FloatTensor
435 |
436 | # Handle tensor shape issues with more robust dimension checking
437 | if input_tensor.dim() == 2: # [H, W]
438 | # Single channel 2D tensor
439 | input_tensor = input_tensor.unsqueeze(0).unsqueeze(0) # Add batch and channel dims [1, 1, H, W]
440 | logger.info(f"Converted 2D tensor to 4D with shape: {input_tensor.shape}")
441 | elif input_tensor.dim() == 3:
442 | # Could be [C, H, W] or [B, H, W] or [H, W, C]
443 | shape = input_tensor.shape
444 | if shape[-1] == 3 or shape[-1] == 1: # [H, W, C] format
445 | # Convert from HWC to BCHW
446 | input_tensor = input_tensor.permute(2, 0, 1).unsqueeze(0) # [H, W, C] -> [1, C, H, W]
447 | logger.info(f"Converted HWC tensor to BCHW with shape: {input_tensor.shape}")
448 | elif shape[0] <= 3: # Likely [C, H, W]
449 | input_tensor = input_tensor.unsqueeze(0) # Add batch dim [1, C, H, W]
450 | logger.info(f"Added batch dimension to CHW tensor: {input_tensor.shape}")
451 | else: # Likely [B, H, W]
452 | input_tensor = input_tensor.unsqueeze(1) # Add channel dim [B, 1, H, W]
453 | logger.info(f"Added channel dimension to BHW tensor: {input_tensor.shape}")
454 |
455 | # Ensure proper shape after corrections
456 | if input_tensor.dim() != 4:
457 | logger.warning(f"Tensor still has incorrect dimensions ({input_tensor.dim()}). Forcing reshape.")
458 | # Force reshape to 4D
459 | orig_shape = input_tensor.shape
460 | if input_tensor.dim() > 4:
461 | # Too many dimensions, collapse extras
462 | input_tensor = input_tensor.reshape(1, -1, orig_shape[-2], orig_shape[-1])
463 | else:
464 | # Create a standard 4D tensor as fallback
465 | input_tensor = torch.ones((1, 3, 512, 512), dtype=torch.float32)
466 |
467 | # Move to device and ensure float type
468 | input_tensor = input_tensor.to(self.device).float()
469 |
470 | # Log tensor shape for debugging
471 | logger.info(f"MiDaS input tensor shape: {input_tensor.shape}, dtype: {input_tensor.dtype}")
472 |
473 | # Run inference with better error handling
474 | with torch.no_grad():
475 | try:
476 | # Make sure input is float32 and model weights are float32
477 | output = self.model(input_tensor)
478 |
479 | # Handle various output shapes
480 | if output.dim() == 1: # [B*H*W] flattened output
481 | # Reshape based on input dimensions
482 | b, _, h, w = input_tensor.shape
483 | output = output.reshape(b, 1, h, w)
484 | elif output.dim() == 2: # [B, H*W] or similar
485 | # Could be flattened spatial dimensions
486 | b = output.shape[0]
487 | if b == input_tensor.shape[0]: # Batch size matches
488 | h = int(np.sqrt(output.shape[1])) # Estimate height assuming square
489 | w = h
490 | if h * w == output.shape[1]: # Perfect square
491 | output = output.reshape(b, 1, h, w)
492 | else:
493 | # Not a perfect square, use input dimensions
494 | _, _, h, w = input_tensor.shape
495 | output = output.reshape(b, 1, h, w)
496 | else:
497 | # Add dimensions to make 4D
498 | output = output.unsqueeze(1).unsqueeze(1)
499 |
500 | # Ensure output has standard 4D shape (B,C,H,W) for interpolation
501 | if output.dim() != 4:
502 | logger.warning(f"Output has non-standard dimensions: {output.shape}, adding dimensions")
503 | # Add dimensions until we have 4D
504 | while output.dim() < 4:
505 | output = output.unsqueeze(-1)
506 |
507 | # Resize to match input resolution
508 | if isinstance(image, Image.Image):
509 | w, h = image.size
510 |
511 | # Log the shape for debugging
512 | logger.info(f"Resizing output tensor from shape {output.shape} to size ({h}, {w})")
513 |
514 | # Ensure output tensor has correct number of dimensions for interpolation
515 | # Standard interpolation requires 4D tensor (B,C,H,W)
516 | try:
517 | # Now interpolate with proper dimensions
518 | output = torch.nn.functional.interpolate(
519 | output,
520 | size=(h, w),
521 | mode="bicubic",
522 | align_corners=False
523 | )
524 | except RuntimeError as resize_err:
525 | logger.error(f"Interpolation error: {resize_err}. Attempting to fix tensor shape.")
526 |
527 | # Last resort: create compatible tensor from output data
528 | try:
529 | # Get data and reshape to simple 2D first
530 | output_data = output.view(-1).cpu().numpy()
531 | output_reshaped = torch.from_numpy(
532 | np.resize(output_data, (h * w))
533 | ).reshape(1, 1, h, w).to(self.device).float()
534 |
535 | logger.info(f"Corrected output shape to {output_reshaped.shape}")
536 | output = output_reshaped
537 | except Exception as reshape_err:
538 | logger.error(f"Reshape fix failed: {reshape_err}. Using fallback tensor.")
539 | # Create a basic gradient as fallback
540 | output = torch.ones((1, 1, h, w), device=self.device, dtype=torch.float32)
541 | y_coords = torch.linspace(0, 1, h).reshape(-1, 1).repeat(1, w)
542 | output[0, 0, :, :] = y_coords.to(self.device)
543 |
544 | except Exception as model_err:
545 | logger.error(f"Model inference error: {model_err}")
546 | logger.error(traceback.format_exc())
547 |
548 | # Create a visually distinguishable gradient pattern fallback
549 | if isinstance(image, Image.Image):
550 | w, h = image.size
551 | else:
552 | # Extract dimensions from input tensor
553 | _, _, h, w = input_tensor.shape if input_tensor.dim() >= 4 else (1, 1, 512, 512)
554 |
555 | # Create gradient depth map as fallback
556 | output = torch.ones((1, 1, h, w), device=self.device, dtype=torch.float32)
557 | y_coords = torch.linspace(0, 1, h).reshape(-1, 1).repeat(1, w)
558 | output[0, 0, :, :] = y_coords.to(self.device)
559 |
560 | # Final validation - ensure output is float32 and has no NaNs
561 | output = output.float()
562 | if torch.isnan(output).any():
563 | logger.warning("Output contains NaN values. Replacing with zeros.")
564 | output = torch.nan_to_num(output, nan=0.0)
565 |
566 | # Use same interface as the pipeline
567 | return {"predicted_depth": output}
568 |
569 | except Exception as e:
570 | logger.error(f"Error in MiDaS inference: {e}")
571 | logger.error(traceback.format_exc())
572 |
573 | # Return a placeholder depth map
574 | if isinstance(image, Image.Image):
575 | w, h = image.size
576 | dummy_tensor = torch.ones((1, 1, h, w), device=self.device)
577 | else:
578 | # Try to get shape from tensor
579 | shape = image.shape
580 | if len(shape) >= 3:
581 | if shape[0] == 3: # CHW format
582 | h, w = shape[1], shape[2]
583 | else: # HWC format
584 | h, w = shape[0], shape[1]
585 | else:
586 | h, w = 512, 512
587 | dummy_tensor = torch.ones((1, 1, h, w), device=self.device)
588 |
589 | return {"predicted_depth": dummy_tensor}
590 |
591 | class DepthEstimationNode:
592 | """
593 | ComfyUI node for depth estimation using Depth Anything models.
594 |
595 | This node provides depth map generation from images using various Depth Anything models
596 | with configurable post-processing options like blur, median filtering, contrast enhancement,
597 | and gamma correction.
598 | """
599 |
600 | MEDIAN_SIZES = ["3", "5", "7", "9", "11"]
601 |
602 | def __init__(self):
603 | self.device = None
604 | self.depth_estimator = None
605 | self.current_model = None
606 | logger.info("Initialized DepthEstimationNode")
607 |
608 | @classmethod
609 | def INPUT_TYPES(cls) -> Dict[str, Dict[str, Any]]:
610 | """Define the input types for the node."""
611 | return {
612 | "required": {
613 | "image": ("IMAGE",),
614 | "model_name": (list(DEPTH_MODELS.keys()),),
615 | # Ensure minimum size is enforced by the UI
616 | "input_size": ("INT", {"default": 1024, "min": 256, "max": 1024, "step": 1}),
617 | "blur_radius": ("FLOAT", {"default": 2.0, "min": 0.0, "max": 10.0, "step": 0.1}),
618 | # Define median_size as a dropdown with specific string values
619 | "median_size": (cls.MEDIAN_SIZES, {"default": "3"}),
620 | "apply_auto_contrast": ("BOOLEAN", {"default": True}),
621 | "apply_gamma": ("BOOLEAN", {"default": True})
622 | },
623 | "optional": {
624 | "force_reload": ("BOOLEAN", {"default": False}),
625 | "force_cpu": ("BOOLEAN", {"default": False})
626 | }
627 | }
628 |
629 | RETURN_TYPES = ("IMAGE",)
630 | FUNCTION = "estimate_depth"
631 | CATEGORY = "depth"
632 |
633 | def cleanup(self) -> None:
634 | """Clean up resources and free VRAM."""
635 | try:
636 | if self.depth_estimator is not None:
637 | # Save model name before deletion for logging
638 | model_name = self.current_model
639 |
640 | # Delete the estimator
641 | del self.depth_estimator
642 | self.depth_estimator = None
643 | self.current_model = None
644 |
645 | # Force CUDA cache clearing
646 | if torch.cuda.is_available():
647 | torch.cuda.empty_cache()
648 | gc.collect()
649 |
650 | logger.info(f"Cleaned up model resources for {model_name}")
651 |
652 | # Log available memory after cleanup if CUDA is available
653 | if torch.cuda.is_available():
654 | try:
655 | free_mem_info = get_free_memory(get_torch_device())
656 | # Handle return value whether it's a tuple or a single value
657 | if isinstance(free_mem_info, tuple):
658 | free_mem, total_mem = free_mem_info
659 | logger.info(f"Available VRAM after cleanup: {free_mem/1024:.2f}MB of {total_mem/1024:.2f}MB")
660 | else:
661 | logger.info(f"Available VRAM after cleanup: {free_mem_info/1024:.2f}MB")
662 | except Exception as e:
663 | logger.warning(f"Error getting memory info: {e}")
664 | except Exception as e:
665 | logger.warning(f"Error during cleanup: {e}")
666 | logger.debug(traceback.format_exc())
667 |
668 | def ensure_model_loaded(self, model_name: str, force_reload: bool = False, force_cpu: bool = False) -> None:
669 | """
670 | Ensures the correct model is loaded with proper VRAM management and fallback options.
671 |
672 | Args:
673 | model_name: The name of the model to load
674 | force_reload: If True, reload the model even if it's already loaded
675 | force_cpu: If True, force loading on CPU regardless of GPU availability
676 |
677 | Raises:
678 | RuntimeError: If the model fails to load after all fallback attempts
679 | """
680 | try:
681 | # Check for valid model name with more helpful fallback
682 | if model_name not in DEPTH_MODELS:
683 | # Find the most similar model name if possible
684 | available_models = list(DEPTH_MODELS.keys())
685 |
686 | if len(available_models) > 0:
687 | # First try to find a model with a similar name
688 | model_name_lower = model_name.lower()
689 |
690 | # Prioritized fallback selection logic:
691 | # 1. Try to match on similar name
692 | # 2. Prefer V2 models if V2 was requested
693 | # 3. Prefer smaller models (more reliable)
694 | if "v2" in model_name_lower and "small" in model_name_lower:
695 | fallback_model = "Depth-Anything-V2-Small"
696 | elif "v2" in model_name_lower and "base" in model_name_lower:
697 | fallback_model = "Depth-Anything-V2-Base"
698 | elif "v2" in model_name_lower:
699 | fallback_model = "Depth-Anything-V2-Small"
700 | elif "small" in model_name_lower:
701 | fallback_model = "Depth-Anything-Small"
702 | elif "midas" in model_name_lower:
703 | fallback_model = "MiDaS-Small"
704 | else:
705 | # Default to the first model if no better match found
706 | fallback_model = available_models[0]
707 |
708 | logger.warning(f"Unknown model: {model_name}. Falling back to {fallback_model}")
709 | model_name = fallback_model
710 | else:
711 | raise ValueError(f"No depth models available. Please check your installation.")
712 |
713 | # Get model info and validate
714 | model_info = DEPTH_MODELS[model_name]
715 |
716 | # Handle model_info as string or dict with better defaults
717 | if isinstance(model_info, dict):
718 | model_path = model_info.get("path", "")
719 | required_vram = model_info.get("vram_mb", 2000) * 1024 # Convert to KB
720 | model_type = model_info.get("model_type", "v1") # v1 or v2
721 | encoder = model_info.get("encoder", "vits") # Model encoder type
722 | config = model_info.get("config", None) # Model config for direct loading
723 | direct_url = model_info.get("direct_url", None) # Direct download URL
724 | else:
725 | model_path = str(model_info)
726 | required_vram = 2000 * 1024 # Default 2GB
727 | model_type = "v1"
728 | encoder = "vits"
729 | config = None
730 | direct_url = None
731 |
732 | # Only reload if needed or forced
733 | if not force_reload and self.depth_estimator is not None and self.current_model == model_path:
734 | logger.info(f"Model '{model_name}' already loaded")
735 | return
736 |
737 | # Clean up any existing model to free memory before loading new one
738 | self.cleanup()
739 |
740 | # Set up device for model
741 | if self.device is None:
742 | self.device = get_torch_device()
743 |
744 | logger.info(f"Loading depth model: {model_name} on {'CPU' if force_cpu else self.device}")
745 |
746 | # Enhanced VRAM check with better error handling
747 | if torch.cuda.is_available() and not force_cpu:
748 | try:
749 | free_mem_info = get_free_memory(self.device)
750 |
751 | # Process different return formats from get_free_memory
752 | if isinstance(free_mem_info, tuple):
753 | free_mem, total_mem = free_mem_info
754 | else:
755 | free_mem = free_mem_info
756 | total_mem = free_mem * 2 # Estimate total when only free memory is reported
757 |
758 | logger.info(f"Available VRAM: {free_mem/1024:.2f}MB, Required: {required_vram/1024:.2f}MB")
759 |
760 | # Add buffer to required memory to avoid OOM issues
761 | required_vram_with_buffer = required_vram * 1.2 # Add 20% buffer
762 |
763 | # If not enough memory, fall back to CPU with warning
764 | if free_mem < required_vram_with_buffer:
765 | logger.warning(
766 | f"Insufficient VRAM for {model_name} (need ~{required_vram/1024:.1f}MB, " +
767 | f"have {free_mem/1024:.1f}MB). Using CPU instead."
768 | )
769 | force_cpu = True
770 | except Exception as mem_error:
771 | logger.warning(f"Error checking VRAM: {str(mem_error)}. Using CPU to be safe.")
772 | force_cpu = True
773 |
774 | # Determine optimal device configuration
775 | device_type = 'cpu' if force_cpu else ('cuda' if torch.cuda.is_available() else 'cpu')
776 |
777 | # Use appropriate dtype based on device and model
778 | # FP16 for CUDA saves VRAM but doesn't work well for all models
779 | if 'cuda' in str(self.device) and not force_cpu:
780 | # V2 models have issues with FP16 - use FP32 for them
781 | if model_type == "v2":
782 | dtype = torch.float32
783 | else:
784 | # Other models can use FP16 to save VRAM
785 | dtype = torch.float16
786 | else:
787 | # CPU always uses FP32
788 | dtype = torch.float32
789 |
790 | # Create model-specific cache directory
791 | # Use consistent naming to improve cache hits
792 | model_cache_name = model_name.replace("-", "_").lower()
793 | cache_dir = os.path.join(MODELS_DIR, model_cache_name)
794 | os.makedirs(cache_dir, exist_ok=True)
795 |
796 | # Prioritized loading strategy:
797 | # 1. Check for locally cached model files
798 | # 2. Try direct download from URLs that don't require authentication
799 | # 3. Try loading from Hugging Face using transformers pipeline
800 | # 4. Fall back to direct model implementation
801 | # 5. Fall back to MiDaS model
802 |
803 | # Step 1: First check if we already have a local model file
804 | local_model_file = None
805 |
806 | # Search all valid model directories for existing files
807 | for base_path in existing_paths:
808 | # Check multiple possible locations and naming patterns
809 | locations_to_check = [
810 | os.path.join(base_path, model_cache_name),
811 | os.path.join(base_path, model_path.replace("/", "_")),
812 | base_path,
813 | os.path.join(base_path, "v2") if model_type == "v2" else None,
814 | ]
815 |
816 | # Filter out None values
817 | locations_to_check = [loc for loc in locations_to_check if loc is not None]
818 |
819 | # Common model filenames to check
820 | model_filenames = [
821 | "pytorch_model.bin",
822 | "model.pt",
823 | "model.pth",
824 | f"{model_cache_name}.pt",
825 | f"{model_cache_name}.bin",
826 | f"depth_anything_{encoder}.pt", # Common naming for Depth Anything models
827 | f"depth_anything_v2_{encoder}.pt", # V2 naming format
828 | ]
829 |
830 | # Search all locations and filenames
831 | for location in locations_to_check:
832 | if os.path.exists(location):
833 | for filename in model_filenames:
834 | file_path = os.path.join(location, filename)
835 | if os.path.exists(file_path) and os.path.getsize(file_path) > 1000000: # >1MB to avoid empty files
836 | local_model_file = file_path
837 | logger.info(f"Found existing model file: {local_model_file}")
838 | break
839 |
840 | if local_model_file:
841 | break
842 |
843 | if local_model_file:
844 | break
845 |
846 | # Step 2: If no local file was found for a V2 model, try downloading from direct URLs
847 | # These URLs don't require authentication and are more reliable than gated Hugging Face downloads
848 | if not local_model_file and model_type == "v2":
849 | # Comprehensive list of URLs to try for V2 models
850 | alternative_urls = {
851 | "Depth-Anything-V2-Small": [
852 | "https://huggingface.co/depth-anything/Depth-Anything-V2-Small-hf/resolve/main/pytorch_model.bin",
853 | "https://github.com/LiheYoung/Depth-Anything/releases/download/v2.0/depth_anything_v2_small.pt",
854 | "https://huggingface.co/ckpt/depth-anything-v2/resolve/main/depth_anything_v2_small.pt"
855 | ],
856 | "Depth-Anything-V2-Base": [
857 | "https://huggingface.co/depth-anything/Depth-Anything-V2-Base-hf/resolve/main/pytorch_model.bin",
858 | "https://github.com/LiheYoung/Depth-Anything/releases/download/v2.0/depth_anything_v2_base.pt",
859 | "https://huggingface.co/ckpt/depth-anything-v2/resolve/main/depth_anything_v2_base.pt"
860 | ],
861 | "MiDaS-Small": [
862 | "https://github.com/intel-isl/MiDaS/releases/download/v2_1/midas_v21_small_256.pt"
863 | ],
864 | "MiDaS-Base": [
865 | "https://github.com/intel-isl/MiDaS/releases/download/v3/dpt_hybrid-midas-501f0c75.pt"
866 | ]
867 | }
868 |
869 | # Get URLs to try (including the direct_url from model_info)
870 | urls_to_try = []
871 | if direct_url:
872 | urls_to_try.append(direct_url)
873 |
874 | # Add alternative URLs for this specific model if available
875 | if model_name in alternative_urls:
876 | urls_to_try.extend(alternative_urls[model_name])
877 |
878 | # Try downloading the model if not found locally
879 | if urls_to_try:
880 | for url in urls_to_try:
881 | try:
882 | # Determine output filename and path
883 | model_filename = os.path.basename(url)
884 | download_path = os.path.join(cache_dir, model_filename)
885 |
886 | # Check if already downloaded
887 | if os.path.exists(download_path) and os.path.getsize(download_path) > 1000000: # >1MB to avoid empty files
888 | logger.info(f"Found existing downloaded model at {download_path}")
889 | local_model_file = download_path
890 | break
891 |
892 | # Create parent directory if needed
893 | os.makedirs(os.path.dirname(download_path), exist_ok=True)
894 |
895 | # Download using a reliable method with multiple retries
896 | logger.info(f"Downloading model from {url} to {download_path}")
897 |
898 | success = False
899 |
900 | # Try wget first (most reliable for large files)
901 | try:
902 | import wget
903 | wget.download(url, out=download_path)
904 | if os.path.exists(download_path) and os.path.getsize(download_path) > 1000000:
905 | logger.info(f"Successfully downloaded model with wget to {download_path}")
906 | success = True
907 | except Exception as wget_error:
908 | logger.warning(f"wget download failed: {str(wget_error)}")
909 |
910 | # Try requests if wget failed
911 | if not success:
912 | try:
913 | # Download with progress reporting
914 | with requests.get(url, stream=True, timeout=60) as response:
915 | response.raise_for_status()
916 | total_size = int(response.headers.get('content-length', 0))
917 |
918 | with open(download_path, 'wb') as f:
919 | downloaded = 0
920 | for chunk in response.iter_content(chunk_size=8192):
921 | f.write(chunk)
922 | downloaded += len(chunk)
923 |
924 | if total_size > 0 and downloaded % (20 * 1024 * 1024) == 0: # Log every 20MB
925 | percent = int(100 * downloaded / total_size)
926 | logger.info(f"Download progress: {downloaded/1024/1024:.1f}MB of {total_size/1024/1024:.1f}MB ({percent}%)")
927 |
928 | if os.path.exists(download_path) and os.path.getsize(download_path) > 1000000:
929 | logger.info(f"Successfully downloaded model with requests to {download_path}")
930 | success = True
931 | except Exception as req_error:
932 | logger.warning(f"requests download failed: {str(req_error)}")
933 |
934 | # Try urllib as last resort
935 | if not success:
936 | try:
937 | import urllib.request
938 | urllib.request.urlretrieve(url, download_path)
939 | if os.path.exists(download_path) and os.path.getsize(download_path) > 1000000:
940 | logger.info(f"Successfully downloaded model with urllib to {download_path}")
941 | success = True
942 | except Exception as urllib_error:
943 | logger.warning(f"urllib download failed: {str(urllib_error)}")
944 |
945 | # Set the local_model_file if download succeeded
946 | if success:
947 | local_model_file = download_path
948 | break
949 | else:
950 | # Clean up failed download
951 | if os.path.exists(download_path):
952 | try:
953 | os.remove(download_path)
954 | except:
955 | pass
956 |
957 | except Exception as download_error:
958 | logger.warning(f"Failed to download from {url}: {str(download_error)}")
959 | continue
960 |
961 | # Step 3: Try loading with transformers pipeline
962 | # This is the most feature-complete approach but may fail with auth issues
963 | logger.info(f"Trying to load model '{model_name}' using transformers pipeline")
964 |
965 | # Priority-ordered list of model paths to try
966 | # Ordered from most to least likely to work
967 | model_paths_to_try = []
968 |
969 | # Start with the specific model requested
970 | model_paths_to_try.append(model_path)
971 |
972 | # Add V2-specific paths for V2 models
973 | if model_type == "v2":
974 | if "small" in model_name.lower():
975 | model_paths_to_try.append("depth-anything/Depth-Anything-V2-Small-hf")
976 | elif "base" in model_name.lower():
977 | model_paths_to_try.append("depth-anything/Depth-Anything-V2-Base-hf")
978 |
979 | # Add variants with and without -hf suffix
980 | model_paths_to_try.append(model_path.replace("-hf", ""))
981 | if "-hf" not in model_path:
982 | model_paths_to_try.append(model_path + "-hf")
983 |
984 | # Try both organization name formats
985 | model_paths_to_try.append(model_path.replace("LiheYoung", "depth-anything"))
986 | model_paths_to_try.append(model_path.replace("depth-anything", "LiheYoung"))
987 | else:
988 | # For V1 models, add common variants
989 | model_paths_to_try.append(model_path.replace("-hf", ""))
990 | if "-hf" not in model_path:
991 | model_paths_to_try.append(model_path + "-hf")
992 |
993 | # Add MiDaS fallbacks only if not already trying MiDaS
994 | if "midas" not in model_name.lower():
995 | model_paths_to_try.append("Intel/dpt-hybrid-midas")
996 |
997 | # Remove duplicates while preserving order
998 | model_paths_to_try = list(dict.fromkeys(model_paths_to_try))
999 |
1000 | # Log all paths we're going to try
1001 | logger.info(f"Will try loading from these paths in order: {model_paths_to_try}")
1002 |
1003 | # Try loading with transformers pipeline
1004 | pipeline_success = False
1005 | pipeline_error = None
1006 |
1007 | for path in model_paths_to_try:
1008 | # Skip empty paths
1009 | if not path.strip():
1010 | continue
1011 |
1012 | logger.info(f"Attempting to load model from: {path}")
1013 |
1014 | # First try online loading (allows downloading new models)
1015 | try:
1016 | logger.info(f"Loading with online mode: model={path}, device={device_type}, dtype={dtype}")
1017 |
1018 | # Try standard pipeline creation first
1019 | try:
1020 | from transformers import pipeline
1021 |
1022 | # Create pipeline with timeout and error handling
1023 | self.depth_estimator = pipeline(
1024 | "depth-estimation",
1025 | model=path,
1026 | cache_dir=cache_dir,
1027 | local_files_only=False, # Try online first
1028 | device_map=device_type,
1029 | torch_dtype=dtype
1030 | )
1031 |
1032 | # Verify that pipeline was created
1033 | if self.depth_estimator is None:
1034 | raise RuntimeError(f"Pipeline initialization returned None for {path}")
1035 |
1036 | # Validate by running a test inference
1037 | test_img = Image.new("RGB", (64, 64), color=(128, 128, 128))
1038 | try:
1039 | test_result = self.depth_estimator(test_img)
1040 |
1041 | # Further verify the output format
1042 | if not isinstance(test_result, dict) or "predicted_depth" not in test_result:
1043 | raise RuntimeError("Invalid output format from pipeline")
1044 |
1045 | # Success - log and break
1046 | logger.info(f"Successfully loaded model from {path} with online mode")
1047 | pipeline_success = True
1048 | break
1049 |
1050 | except Exception as test_error:
1051 | logger.warning(f"Pipeline created but test failed: {str(test_error)}")
1052 | raise
1053 |
1054 | except TypeError as type_error:
1055 | # Handle unpacking errors which are common with older transformers versions
1056 | logger.warning(f"TypeError when creating pipeline: {str(type_error)}")
1057 |
1058 | # Try alternative approach with manual component loading
1059 | logger.info("Trying manual component loading as alternative...")
1060 |
1061 | try:
1062 | from transformers import AutoModelForDepthEstimation, AutoImageProcessor
1063 |
1064 | # Load components separately
1065 | processor = AutoImageProcessor.from_pretrained(
1066 | path, cache_dir=cache_dir, local_files_only=False
1067 | )
1068 |
1069 | model = AutoModelForDepthEstimation.from_pretrained(
1070 | path, cache_dir=cache_dir, local_files_only=False,
1071 | torch_dtype=dtype
1072 | )
1073 |
1074 | # Move model to correct device
1075 | if not force_cpu and 'cuda' in device_type:
1076 | model = model.to(self.device)
1077 |
1078 | # Create custom pipeline class
1079 | class CustomDepthEstimator:
1080 | def __init__(self, model, processor, device):
1081 | self.model = model
1082 | self.processor = processor
1083 | self.device = device
1084 |
1085 | def __call__(self, image):
1086 | # Process image and run model
1087 | inputs = self.processor(images=image, return_tensors="pt")
1088 |
1089 | # Move inputs to correct device
1090 | if torch.cuda.is_available() and not force_cpu:
1091 | inputs = {k: v.to(self.device) for k, v in inputs.items()}
1092 |
1093 | # Run model
1094 | with torch.no_grad():
1095 | outputs = self.model(**inputs)
1096 |
1097 | # Return results in standard format
1098 | return {"predicted_depth": outputs.predicted_depth}
1099 |
1100 | # Create custom pipeline
1101 | self.depth_estimator = CustomDepthEstimator(model, processor, self.device)
1102 |
1103 | # Test the pipeline
1104 | test_img = Image.new("RGB", (64, 64), color=(128, 128, 128))
1105 | test_result = self.depth_estimator(test_img)
1106 |
1107 | # Verify output format
1108 | if not isinstance(test_result, dict) or "predicted_depth" not in test_result:
1109 | raise RuntimeError("Invalid output format from custom pipeline")
1110 |
1111 | logger.info(f"Successfully loaded model from {path} with custom pipeline")
1112 | pipeline_success = True
1113 | break
1114 |
1115 | except Exception as custom_error:
1116 | logger.warning(f"Custom pipeline creation failed: {str(custom_error)}")
1117 | raise
1118 |
1119 | except Exception as online_error:
1120 | logger.warning(f"Online loading failed for {path}: {str(online_error)}")
1121 | pipeline_error = online_error
1122 |
1123 | # Try local-only mode if online fails (faster and often works with cached files)
1124 | try:
1125 | logger.info(f"Trying local-only mode for {path}")
1126 |
1127 | from transformers import pipeline
1128 |
1129 | self.depth_estimator = pipeline(
1130 | "depth-estimation",
1131 | model=path,
1132 | cache_dir=cache_dir,
1133 | local_files_only=True, # Only use local files
1134 | device_map=device_type,
1135 | torch_dtype=dtype
1136 | )
1137 |
1138 | # Verify pipeline
1139 | if self.depth_estimator is None:
1140 | raise RuntimeError(f"Local pipeline initialization returned None for {path}")
1141 |
1142 | # Test with small image
1143 | test_img = Image.new("RGB", (64, 64), color=(128, 128, 128))
1144 | test_result = self.depth_estimator(test_img)
1145 |
1146 | logger.info(f"Successfully loaded model from {path} with local-only mode")
1147 | pipeline_success = True
1148 | break
1149 |
1150 | except Exception as local_error:
1151 | logger.warning(f"Local-only mode failed for {path}: {str(local_error)}")
1152 | pipeline_error = local_error
1153 | # Continue to next path
1154 |
1155 | # Step 4: If transformers pipeline failed, try direct model loading
1156 | if not pipeline_success:
1157 | logger.info("Pipeline loading failed. Trying direct model implementation.")
1158 |
1159 | # If we have a local model file from previous steps, use it with direct loading
1160 | direct_success = False
1161 |
1162 | if local_model_file:
1163 | logger.info(f"Loading model directly from: {local_model_file}")
1164 |
1165 | # For V2 models, use the DepthAnythingV2 implementation
1166 | if model_type == "v2" and TIMM_AVAILABLE:
1167 | try:
1168 | logger.info(f"Using DepthAnythingV2 implementation with config: {config}")
1169 |
1170 | # Create model instance with appropriate config
1171 | if config:
1172 | model_instance = DepthAnythingV2(**config)
1173 | else:
1174 | # Use default config for this encoder type
1175 | default_config = MODEL_CONFIGS.get(encoder, MODEL_CONFIGS["vits"])
1176 | model_instance = DepthAnythingV2(**default_config)
1177 |
1178 | # Load state dict from file
1179 | logger.info(f"Loading weights from {local_model_file}")
1180 | state_dict = torch.load(local_model_file, map_location="cpu")
1181 |
1182 | # Convert state dict to float32
1183 | if any(v.dtype == torch.float64 for v in state_dict.values() if hasattr(v, 'dtype')):
1184 | logger.info("Converting state dict from float64 to float32")
1185 | state_dict = {k: v.float() if hasattr(v, 'dtype') else v for k, v in state_dict.items()}
1186 |
1187 | # Try loading with different state dict formats
1188 | try:
1189 | if "model" in state_dict:
1190 | model_instance.load_state_dict(state_dict["model"])
1191 | else:
1192 | model_instance.load_state_dict(state_dict)
1193 | except Exception as e:
1194 | logger.warning(f"Strict loading failed: {str(e)}. Trying non-strict loading.")
1195 | if "model" in state_dict:
1196 | model_instance.load_state_dict(state_dict["model"], strict=False)
1197 | else:
1198 | model_instance.load_state_dict(state_dict, strict=False)
1199 |
1200 | # Move to correct device and set eval mode
1201 | model_instance = model_instance.to(self.device).float().eval()
1202 |
1203 | # Test the model
1204 | test_img = Image.new("RGB", (64, 64), color=(128, 128, 128))
1205 | _ = model_instance(test_img)
1206 |
1207 | # Success - assign model
1208 | self.depth_estimator = model_instance
1209 | direct_success = True
1210 | logger.info("Successfully loaded model with direct implementation")
1211 |
1212 | except Exception as v2_error:
1213 | logger.warning(f"DepthAnythingV2 loading failed: {str(v2_error)}")
1214 |
1215 | # If V2 direct loading failed or isn't applicable, try MiDaS wrapper
1216 | if not direct_success:
1217 | try:
1218 | logger.info("Falling back to MiDaS wrapper implementation")
1219 |
1220 | # Determine MiDaS model type based on model name
1221 | midas_type = "DPT_Hybrid" # Default
1222 |
1223 | if "small" in model_name.lower():
1224 | midas_type = "MiDaS_small"
1225 | elif "large" in model_name.lower():
1226 | midas_type = "DPT_Large"
1227 |
1228 | # Create MiDaS wrapper with model type
1229 | midas_model = MiDaSWrapper(midas_type, self.device)
1230 |
1231 | # Test the model
1232 | test_img = Image.new("RGB", (64, 64), color=(128, 128, 128))
1233 | _ = midas_model(test_img)
1234 |
1235 | # Success - assign model
1236 | self.depth_estimator = midas_model
1237 | direct_success = True
1238 | logger.info("Successfully loaded MiDaS fallback model")
1239 |
1240 | except Exception as midas_error:
1241 | logger.warning(f"MiDaS wrapper loading failed: {str(midas_error)}")
1242 |
1243 | # Step 5: If all previous attempts failed, try one last MiDaS fallback
1244 | if not direct_success and not pipeline_success:
1245 | try:
1246 | logger.info("All model loading attempts failed. Trying basic MiDaS fallback.")
1247 |
1248 | # Create simple MiDaS wrapper with default settings
1249 | midas_model = MiDaSWrapper("dpt_hybrid", self.device)
1250 |
1251 | # Test with simple image
1252 | test_img = Image.new("RGB", (64, 64), color=(128, 128, 128))
1253 | _ = midas_model(test_img)
1254 |
1255 | # Assign model
1256 | self.depth_estimator = midas_model
1257 | logger.info("Successfully loaded basic MiDaS fallback model")
1258 |
1259 | except Exception as final_error:
1260 | # If we get here, all attempts have failed
1261 | error_msg = f"All model loading attempts failed for {model_name}. Last error: {str(final_error)}"
1262 | logger.error(error_msg)
1263 |
1264 | # Create a helpful error message with instructions
1265 | all_model_dirs = "\n".join(existing_paths)
1266 |
1267 | # Determine error type for better help message
1268 | if pipeline_error:
1269 | error_str = str(pipeline_error).lower()
1270 |
1271 | if "unauthorized" in error_str or "401" in error_str or "authentication" in error_str:
1272 | # Authentication error - guide for manual download
1273 | error_solution = f"""
1274 | AUTHENTICATION ERROR: The model couldn't be downloaded due to Hugging Face authentication requirements.
1275 |
1276 | SOLUTION:
1277 | 1. Use force_cpu=True in the node settings
1278 | 2. Try a different model like MiDaS-Small
1279 | 3. Download the model manually using one of these direct links:
1280 | - https://huggingface.co/depth-anything/Depth-Anything-V2-Small-hf/resolve/main/pytorch_model.bin
1281 | - https://github.com/LiheYoung/Depth-Anything/releases/download/v2.0/depth_anything_v2_small.pt
1282 | - https://huggingface.co/ckpt/depth-anything-v2/resolve/main/depth_anything_v2_small.pt
1283 |
1284 | Save the file to one of these directories:
1285 | {all_model_dirs}
1286 | """
1287 | elif "cuda" in error_str or "gpu" in error_str or "vram" in error_str or "memory" in error_str:
1288 | # GPU/memory error
1289 | error_solution = """
1290 | GPU ERROR: The model failed to load on your GPU.
1291 |
1292 | SOLUTION:
1293 | 1. Use force_cpu=True to use CPU processing instead
1294 | 2. Lower the input_size parameter to 384 to reduce memory requirements
1295 | 3. Try a smaller model like MiDaS-Small
1296 | 4. Ensure you have the latest GPU drivers installed
1297 | """
1298 | else:
1299 | # Generic error
1300 | error_solution = f"""
1301 | Failed to load any depth estimation model.
1302 |
1303 | SOLUTION:
1304 | 1. Use force_cpu=True to use CPU for processing
1305 | 2. Try using a different model like MiDaS-Small
1306 | 3. Download a model file manually and place it in one of these directories:
1307 | {all_model_dirs}
1308 | 4. Restart ComfyUI to ensure clean state
1309 | """
1310 | else:
1311 | # Generic error when pipeline_error isn't set
1312 | error_solution = f"""
1313 | Failed to load any depth estimation model.
1314 |
1315 | SOLUTION:
1316 | 1. Use force_cpu=True to use CPU for processing
1317 | 2. Try using a different model like MiDaS-Small
1318 | 3. Download a model file manually and place it in one of these directories:
1319 | {all_model_dirs}
1320 | 4. Restart ComfyUI to ensure clean state
1321 | """
1322 |
1323 | # Raise helpful error
1324 | raise RuntimeError(f"MODEL LOADING ERROR: {error_solution}")
1325 |
1326 | # Ensure the model is on the correct device
1327 | if hasattr(self.depth_estimator, 'model') and hasattr(self.depth_estimator.model, 'to'):
1328 | if force_cpu:
1329 | self.depth_estimator.model = self.depth_estimator.model.to('cpu')
1330 | else:
1331 | self.depth_estimator.model = self.depth_estimator.model.to(self.device)
1332 |
1333 | # Set model to eval mode if applicable
1334 | if hasattr(self.depth_estimator, 'eval'):
1335 | self.depth_estimator.eval()
1336 | elif hasattr(self.depth_estimator, 'model') and hasattr(self.depth_estimator.model, 'eval'):
1337 | self.depth_estimator.model.eval()
1338 |
1339 | # Store current model info
1340 | self.current_model = model_path
1341 | logger.info(f"Model '{model_name}' loaded successfully")
1342 |
1343 | except Exception as e:
1344 | # Clean up on failure
1345 | self.cleanup()
1346 |
1347 | # Log detailed error info
1348 | error_msg = f"Failed to load model {model_name}: {str(e)}"
1349 | logger.error(error_msg)
1350 | logger.error(traceback.format_exc())
1351 |
1352 | # Re-raise with clear message
1353 | raise RuntimeError(error_msg)
1354 |
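# Illustrative usage sketch for the loader above. The node class name below is an
# assumption; any instance that owns `depth_estimator`, `current_model`, and `device`
# behaves the same way:
#
#     node = DepthEstimationNode()                                   # hypothetical name
#     node.ensure_model_loaded("Depth-Anything-V2-Small", force_cpu=True)
#     depth = node.depth_estimator(Image.open("photo.jpg"))["predicted_depth"]
#
# Whichever path succeeds (transformers pipeline, manual component loading, direct
# checkpoint, or the MiDaS fallback), the resulting estimator is callable on a PIL
# image; the pipeline and custom paths return a dict containing "predicted_depth".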
1355 | def load_model_direct(self, model_name, model_info, force_cpu=False):
1356 | """
1357 | Directly loads a depth model without using transformers pipeline.
1358 | This is a fallback method when the normal pipeline loading fails.
1359 |
1360 | Args:
1361 | model_name: Name of the model to load
1362 | model_info: Dictionary with model information
1363 | force_cpu: Whether to force CPU usage
1364 |
1365 | Returns:
1366 | A depth estimation model that implements the __call__ interface
1367 | """
1368 | try:
1369 | logger.info(f"Attempting direct model loading for {model_name}")
1370 |
1371 | # Determine device
1372 | device_type = 'cpu' if force_cpu else ('cuda' if torch.cuda.is_available() else 'cpu')
1373 | device = torch.device(device_type)
1374 |
1375 | # Look in all possible model directories
1376 | # This is important to support various directory structures
1377 | model_found = False
1378 | model_path_local = None
1379 |
1380 | # Make a unique model cache directory for this specific model
1381 | model_subfolder = model_name.replace("-", "_").lower()
1382 |
1383 | # Check all possible locations for the model file
1384 | for base_path in existing_paths:
1385 | # Try different possible locations and filename patterns
1386 | possible_model_locations = [
1387 | # Direct downloads in the model directory
1388 | os.path.join(base_path, model_subfolder),
1389 |
1390 | # Using the full HF directory structure
1391 | os.path.join(base_path, model_info.get("path", "").replace("/", "_")),
1392 |
1393 | # Directly in base directory
1394 | base_path,
1395 | ]
1396 |
1397 | # Add directory structure with model configs if V2
1398 | if model_info.get("model_type") == "v2":
1399 | v2_path = os.path.join(base_path, "v2")
1400 | possible_model_locations.append(v2_path)
1401 | possible_model_locations.append(os.path.join(v2_path, model_subfolder))
1402 |
1403 | # Try all locations
1404 | logger.info(f"Searching for existing model in these directories: {possible_model_locations}")
1405 |
1406 | for location in possible_model_locations:
1407 | # Check for model file with various naming patterns
1408 | if os.path.exists(location):
1409 | # Check for common filenames
1410 | for filename in ["pytorch_model.bin", "model.pt", "model.pth",
1411 | f"{model_subfolder}.pt", f"{model_subfolder}.bin"]:
1412 | file_path = os.path.join(location, filename)
1413 | if os.path.exists(file_path):
1414 | logger.info(f"Found existing model file: {file_path}")
1415 | model_path_local = file_path
1416 | model_found = True
1417 | break
1418 |
1419 | if model_found:
1420 | break
1421 |
1422 | if model_found:
1423 | break
1424 |
1425 | # Cache directory for downloads lives under the first known model directory
1426 | cache_dir = os.path.join(existing_paths[0], model_subfolder)
1427 | os.makedirs(cache_dir, exist_ok=True)
1428 |
1429 | # Get model configuration
1430 | model_type = model_info.get("model_type", "v1")
1431 | encoder = model_info.get("encoder", "vits")
1432 | config = model_info.get("config", MODEL_CONFIGS.get(encoder, MODEL_CONFIGS["vits"]))
1433 |
1434 | # Step 1: If model not found locally, download it
1435 | # List of alternative URLs that don't require authentication
1436 | alternative_urls = {
1437 | "Depth-Anything-V2-Small": [
1438 | "https://huggingface.co/depth-anything/Depth-Anything-V2-Small-hf/resolve/main/pytorch_model.bin",
1439 | "https://github.com/LiheYoung/Depth-Anything/releases/download/v2.0/depth_anything_v2_small.pt",
1440 | "https://huggingface.co/ckpt/depth-anything-v2/resolve/main/depth_anything_v2_small.pt"
1441 | ],
1442 | "Depth-Anything-V2-Base": [
1443 | "https://huggingface.co/depth-anything/Depth-Anything-V2-Base-hf/resolve/main/pytorch_model.bin",
1444 | "https://github.com/LiheYoung/Depth-Anything/releases/download/v2.0/depth_anything_v2_base.pt",
1445 | "https://huggingface.co/ckpt/depth-anything-v2/resolve/main/depth_anything_v2_base.pt"
1446 | ],
1447 | "MiDaS-Base": [
1448 | "https://github.com/intel-isl/MiDaS/releases/download/v3/dpt_hybrid-midas-501f0c75.pt"
1449 | ]
1450 | }
1451 |
1452 | # Get primary URL from model_info
1453 | direct_url = model_info.get("direct_url")
1454 |
1455 | # Add alternative URLs to try if the main one fails
1456 | urls_to_try = [direct_url] if direct_url else []
1457 |
1458 | # Add alternative URLs for this model if available
1459 | if model_name in alternative_urls:
1460 | urls_to_try.extend(alternative_urls[model_name])
1461 |
1462 | # Try downloading the model if not found locally
1463 | if not model_found and urls_to_try:
1464 | # Try each URL in sequence until one works
1465 | for url in urls_to_try:
1466 | if not url:
1467 | continue
1468 |
1469 | try:
1470 | model_filename = os.path.basename(url)
1471 | model_path_local = os.path.join(cache_dir, model_filename)
1472 |
1473 | if os.path.exists(model_path_local):
1474 | logger.info(f"Model already exists at {model_path_local}")
1475 | model_found = True
1476 | break
1477 |
1478 | logger.info(f"Attempting to download model from {url} to {model_path_local}")
1479 |
1480 | # Create parent directory if needed
1481 | os.makedirs(os.path.dirname(model_path_local), exist_ok=True)
1482 |
1483 | # Try different download methods
1484 | download_success = False
1485 |
1486 | # First try wget (more reliable for large files)
1487 | try:
1488 | logger.info(f"Downloading with wget: {url}")
1489 | wget.download(url, out=model_path_local)
1490 | logger.info(f"Downloaded model weights to {model_path_local}")
1491 | download_success = True
1492 | except Exception as wget_error:
1493 | logger.warning(f"wget download failed: {str(wget_error)}")
1494 |
1495 | # Fallback to requests
1496 | try:
1497 | logger.info(f"Downloading with requests: {url}")
1498 | response = requests.get(url, stream=True)
1499 |
1500 | if response.status_code == 200:
1501 | total_size = int(response.headers.get('content-length', 0))
1502 | logger.info(f"File size: {total_size/1024/1024:.1f} MB")
1503 |
1504 | with open(model_path_local, 'wb') as f:
1505 | downloaded = 0
1506 | for data in response.iter_content(1024 * 1024): # 1MB chunks
1507 | f.write(data)
1508 | downloaded += len(data)
1509 | if total_size > 0 and downloaded % (10 * 1024 * 1024) == 0: # Log every 10MB
1510 | progress = (downloaded / total_size) * 100
1511 | logger.info(f"Downloaded {downloaded/1024/1024:.1f}MB of {total_size/1024/1024:.1f}MB ({progress:.1f}%)")
1512 |
1513 | logger.info(f"Download complete: {model_path_local}")
1514 | download_success = True
1515 | else:
1516 | logger.warning(f"Failed to download from {url}: HTTP status {response.status_code}")
1517 | except Exception as req_error:
1518 | logger.warning(f"Requests download failed: {str(req_error)}")
1519 |
1520 | # Try urllib as last resort
1521 | if not download_success:
1522 | try:
1523 | logger.info(f"Downloading with urllib: {url}")
1524 | urllib.request.urlretrieve(url, model_path_local)
1525 | logger.info(f"Downloaded model weights to {model_path_local}")
1526 | download_success = True
1527 | except Exception as urllib_error:
1528 | logger.warning(f"urllib download failed: {str(urllib_error)}")
1529 |
1530 | # Check if download succeeded
1531 | if download_success and os.path.exists(model_path_local) and os.path.getsize(model_path_local) > 0:
1532 | logger.info(f"Successfully downloaded model to {model_path_local}")
1533 | model_found = True
1534 | break
1535 | else:
1536 | logger.warning("Download appeared to succeed but file is empty or missing")
1537 | # Try to remove the failed download
1538 | if os.path.exists(model_path_local):
1539 | try:
1540 | os.remove(model_path_local)
1541 | except:
1542 | pass
1543 |
1544 | except Exception as dl_error:
1545 | logger.warning(f"Error downloading from {url}: {str(dl_error)}")
1546 | continue
1547 |
1548 | if not model_found:
1549 | logger.error("All download attempts failed")
1550 |
1551 | # Step 2: Create and load the appropriate model if found
1552 | if model_found and model_path_local and os.path.exists(model_path_local):
1553 | logger.info(f"Found model file at: {model_path_local}")
1554 |
1555 | # Handle V2 models with DepthAnythingV2 implementation
1556 | if model_type == "v2" and TIMM_AVAILABLE:
1557 | try:
1558 | logger.info(f"Loading as DepthAnythingV2 model with config: {config}")
1559 |
1560 | # Create model with the appropriate configuration
1561 | model = DepthAnythingV2(**config)
1562 |
1563 | # Load weights from checkpoint
1564 | logger.info(f"Loading weights from {model_path_local}")
1565 | state_dict = torch.load(model_path_local, map_location=device)
1566 |
1567 | # Convert state dict to float32 if needed
1568 | if any(v.dtype == torch.float64 for v in state_dict.values() if hasattr(v, 'dtype')):
1569 | logger.info("Converting state dict from float64 to float32")
1570 | state_dict = {k: v.float() if hasattr(v, 'dtype') else v for k, v in state_dict.items()}
1571 |
1572 | # Attempt to load the state dict (handles different formats)
1573 | try:
1574 | if "model" in state_dict:
1575 | model.load_state_dict(state_dict["model"])
1576 | else:
1577 | model.load_state_dict(state_dict)
1578 | except Exception as e:
1579 | logger.warning(f"Error loading state dict: {str(e)}")
1580 | logger.warning("Trying to load with strict=False")
1581 | if "model" in state_dict:
1582 | model.load_state_dict(state_dict["model"], strict=False)
1583 | else:
1584 | model.load_state_dict(state_dict, strict=False)
1585 |
1586 | # Move model to the correct device and ensure float32
1587 | model = model.float().to(device)
1588 | model.device = device
1589 | model.eval()
1590 |
1591 | # Test the model
1592 | logger.info("Testing model with sample image")
1593 | test_img = Image.new("RGB", (64, 64), color=(128, 128, 128))
1594 |
1595 | try:
1596 | _ = model(test_img)
1597 | logger.info("DepthAnythingV2 model loaded and tested successfully")
1598 | return model
1599 | except Exception as test_error:
1600 | logger.error(f"Error during model test: {str(test_error)}")
1601 | logger.debug(traceback.format_exc())
1602 | except Exception as e:
1603 | logger.error(f"Error loading DepthAnythingV2: {str(e)}")
1604 | logger.debug(traceback.format_exc())
1605 |
1606 | # Fallback to MiDaS model if V2 loading failed or for V1 models
1607 | try:
1608 | logger.info("Falling back to MiDaS model")
1609 |
1610 | # Determine the appropriate MiDaS model type
1611 | midas_model_type = "dpt_hybrid"
1612 | if "large" in model_name.lower():
1613 | midas_model_type = "dpt_large"
1614 | elif "small" in model_name.lower():
1615 | midas_model_type = "midas_v21_small"
1616 |
1617 | # Create and test the MiDaS model
1618 | midas_model = MiDaSWrapper(midas_model_type, device)
1619 |
1620 | # Test with a small image
1621 | test_img = Image.new("RGB", (64, 64), color=(128, 128, 128))
1622 | _ = midas_model(test_img)
1623 |
1624 | logger.info("MiDaS model loaded and tested successfully")
1625 | return midas_model
1626 |
1627 | except Exception as e:
1628 | logger.error(f"Error loading MiDaS: {str(e)}")
1629 | logger.debug(traceback.format_exc())
1630 |
1631 | # If all else fails, return None
1632 | return None
1633 |
1634 | except Exception as e:
1635 | logger.error(f"Direct model loading failed: {str(e)}")
1636 | logger.debug(traceback.format_exc())
1637 | return None
1638 |
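# Rough call pattern for the direct loader above (illustrative; names are the ones
# used elsewhere in this file):
#
#     model = self.load_model_direct(model_name, DEPTH_MODELS[model_name], force_cpu=True)
#     if model is not None:
#         _ = model(pil_image)   # returned wrappers are exercised above with a PIL image
#
# Returning None instead of raising signals to the caller that every direct-loading
# and MiDaS fallback attempt failed, so it can surface its own user-facing error.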
1639 | def process_image(self, image: Union[torch.Tensor, np.ndarray], input_size: int = 518) -> Image.Image:
1640 | """
1641 | Converts input image to proper format for depth estimation and resizes it.
1642 |
1643 | Args:
1644 | image: Input image as tensor or numpy array
1645 | input_size: Target size for the longest dimension of the image
1646 |
1647 | Returns:
1648 | PIL Image ready for depth estimation
1649 | """
1650 | try:
1651 | # Log input information for debugging
1652 | if torch.is_tensor(image):
1653 | logger.info(f"Processing tensor image with shape {image.shape}, dtype {image.dtype}")
1654 | elif isinstance(image, np.ndarray):
1655 | logger.info(f"Processing numpy image with shape {image.shape}, dtype {image.dtype}")
1656 | else:
1657 | logger.warning(f"Unexpected image type: {type(image)}")
1658 |
1659 | # Validate and normalize input_size with improved bounds checking
1660 | if not isinstance(input_size, (int, float)):
1661 | logger.warning(f"Invalid input_size type: {type(input_size)}. Using default 518.")
1662 | input_size = 518
1663 | else:
1664 | try:
1665 | # Convert to int and constrain to valid range
1666 | input_size = int(input_size)
1667 | except (TypeError, ValueError):
1668 | logger.warning("Error converting input_size to int. Using default 518.")
1669 | input_size = 518
1670 |
1671 | # Ensure input_size is within valid range
1672 | if input_size < 256:
1673 | logger.warning(f"Input size {input_size} is too small, using 256 instead")
1674 | input_size = 256
1675 | elif input_size > 1024:
1676 | logger.warning(f"Input size {input_size} is too large, using 1024 instead")
1677 | input_size = 1024
1678 |
1679 | # Process tensor input with comprehensive error handling
1680 | if torch.is_tensor(image):
1681 | try:
1682 | # Check tensor dtype and convert to float32 if needed
1683 | if image.dtype != torch.float32:
1684 | logger.info(f"Converting input tensor from {image.dtype} to torch.float32")
1685 | image = image.float() # Convert to FloatTensor for consistency
1686 |
1687 | # Check for NaN/Inf values in tensor
1688 | nan_count = torch.isnan(image).sum().item()
1689 | inf_count = torch.isinf(image).sum().item()
1690 |
1691 | if nan_count > 0 or inf_count > 0:
1692 | logger.warning(f"Input tensor contains {nan_count} NaN and {inf_count} Inf values. Replacing with valid values.")
1693 | image = torch.nan_to_num(image, nan=0.0, posinf=1.0, neginf=0.0)
1694 |
1695 | # Handle empty or invalid tensors
1696 | if image.numel() == 0:
1697 | logger.error("Input tensor is empty (zero elements)")
1698 | return Image.new('RGB', (512, 512), (128, 128, 128))
1699 |
1700 | # Handle tensor with incorrect number of dimensions
1701 | # We need a 3D or 4D tensor to properly extract the image
1702 | if image.dim() < 3:
1703 | logger.warning(f"Input tensor has too few dimensions: {image.dim()}D. Adding dimensions.")
1704 | # Add dimensions until we have at least 3D tensor
1705 | while image.dim() < 3:
1706 | image = image.unsqueeze(0)
1707 | logger.info(f"Adjusted tensor shape to: {image.shape}")
1708 |
1709 | # Normalize values to [0, 1] range if needed
1710 | if image.max() > 1.0 + 1e-5: # Allow small floating point error
1711 | min_val, max_val = image.min().item(), image.max().item()
1712 | logger.info(f"Input tensor values outside [0,1] range: min={min_val}, max={max_val}. Normalizing.")
1713 |
1714 | # Common ranges and conversions
1715 | if min_val >= 0 and max_val <= 255:
1716 | # Assume [0-255] range for images
1717 | image = image / 255.0
1718 | else:
1719 | # General normalization
1720 | image = (image - min_val) / (max_val - min_val)
1721 |
1722 | # Extract first image from batch if we have a batch dimension
1723 | if image.dim() == 4: # Batched tensor (BCHW or ComfyUI-style BHWC)
1724 | # Use first image in batch; the layout is detected below
1725 | image_for_conversion = image[0]
1726 | else: # 3D tensor, assumed to be [C, H, W]
1727 | image_for_conversion = image
1728 |
1729 | # Move to CPU for numpy conversion
1730 | image_for_conversion = image_for_conversion.cpu()
1731 |
1732 | # Convert to numpy, handling different layouts
1733 | if image_for_conversion.shape[0] <= 3: # [C, H, W] format with 1-3 channels
1734 | # Standard CHW layout - convert to HWC for PIL
1735 | image_np = image_for_conversion.permute(1, 2, 0).numpy()
1736 | image_np = image_np * 255.0 # Scale to [0, 255]
1737 | else:
1738 | # Unusual channel count - probably not CHW format
1739 | logger.warning(f"Unusual channel count for CHW format: {image_for_conversion.shape[0]}. Using reshape logic.")
1740 |
1741 | # Try to infer the format and convert appropriately
1742 | if image_for_conversion.dim() == 3 and image_for_conversion.shape[-1] <= 3:
1743 | # Likely [H, W, C] format, no need to permute
1744 | image_np = image_for_conversion.numpy() * 255.0
1745 | else:
1746 | # Unknown format - try to reshape intelligently
1747 | logger.warning("Unable to determine tensor layout. Using first 3 channels.")
1748 |
1749 | # Default: take first 3 channels (or fewer if < 3 channels)
1750 | channels = min(3, image_for_conversion.shape[0])
1751 | image_np = image_for_conversion[:channels].permute(1, 2, 0).numpy() * 255.0
1752 |
1753 | # Ensure we have proper RGB image (3 channels)
1754 | if len(image_np.shape) == 2: # Single channel image
1755 | image_np = np.stack([image_np] * 3, axis=-1)
1756 | elif image_np.shape[-1] == 1: # Single channel in last dimension
1757 | image_np = np.concatenate([image_np] * 3, axis=-1)
1758 | elif image_np.shape[-1] == 4: # RGBA image - drop alpha channel
1759 | image_np = image_np[..., :3]
1760 | elif image_np.shape[-1] > 4: # More than 4 channels - use first 3
1761 | logger.warning(f"Image has {image_np.shape[-1]} channels. Using first 3 channels.")
1762 | image_np = image_np[..., :3]
1763 |
1764 | # Ensure proper data type and range
1765 | image_np = np.clip(image_np, 0, 255).astype(np.uint8)
1766 |
1767 | except Exception as tensor_error:
1768 | logger.error(f"Error processing tensor image: {str(tensor_error)}")
1769 | logger.error(traceback.format_exc())
1770 | # Create a fallback RGB gradient as placeholder
1771 | placeholder = np.zeros((512, 512, 3), dtype=np.uint8)
1772 | # Add a gradient pattern for visual distinction
1773 | for i in range(512):
1774 | v = int(i / 512 * 255)
1775 | placeholder[i, :, 0] = v # R channel gradient
1776 | placeholder[:, i, 1] = v # G channel gradient
1777 | return Image.fromarray(placeholder)
1778 |
1779 | # Process numpy array input with comprehensive error handling
1780 | elif isinstance(image, np.ndarray):
1781 | try:
1782 | # Check for NaN/Inf values in array
1783 | if np.isnan(image).any() or np.isinf(image).any():
1784 | logger.warning("Input array contains NaN or Inf values. Replacing with zeros.")
1785 | image = np.nan_to_num(image, nan=0.0, posinf=1.0, neginf=0.0)
1786 |
1787 | # Handle empty or invalid arrays
1788 | if image.size == 0:
1789 | logger.error("Input array is empty (zero elements)")
1790 | return Image.new('RGB', (512, 512), (128, 128, 128))
1791 |
1792 | # Convert high-precision types to float32 for consistency
1793 | if image.dtype == np.float64 or image.dtype == np.float16:
1794 | logger.info(f"Converting numpy array from {image.dtype} to float32")
1795 | image = image.astype(np.float32)
1796 |
1797 | # Normalize values to [0-1] range if float array
1798 | if np.issubdtype(image.dtype, np.floating):
1799 | # Check current range
1800 | min_val, max_val = image.min(), image.max()
1801 |
1802 | # Normalize if outside [0,1] range
1803 | if min_val < 0.0 - 1e-5 or max_val > 1.0 + 1e-5: # Allow small floating point error
1804 | logger.info(f"Normalizing array from range [{min_val:.2f}, {max_val:.2f}] to [0, 1]")
1805 |
1806 | # Common ranges and conversions
1807 | if min_val >= 0 and max_val <= 255:
1808 | # Assume [0-255] float range
1809 | image = image / 255.0
1810 | else:
1811 | # General min-max normalization
1812 | image = (image - min_val) / (max_val - min_val)
1813 |
1814 | # Convert normalized float to uint8 for PIL
1815 | image_np = (image * 255).astype(np.uint8)
1816 | elif np.issubdtype(image.dtype, np.integer):
1817 | # For integer types, check if normalization is needed
1818 | max_val = image.max()
1819 | if max_val > 255:
1820 | logger.info(f"Scaling integer array with max value {max_val} to 0-255 range")
1821 | # Scale the array to 0-255 range
1822 | scaled = (image.astype(np.float32) / max_val) * 255
1823 | image_np = scaled.astype(np.uint8)
1824 | else:
1825 | # Already in valid range
1826 | image_np = image.astype(np.uint8)
1827 | else:
1828 | # Unsupported dtype - convert through float32
1829 | logger.warning(f"Unsupported dtype: {image.dtype}. Converting through float32.")
1830 | image_np = (image.astype(np.float32) * 255).astype(np.uint8)
1831 |
1832 | # Handle different channel configurations and dimensions
1833 | if len(image_np.shape) == 2:
1834 | # Grayscale image - convert to RGB
1835 | logger.info("Converting grayscale image to RGB")
1836 | image_np = np.stack([image_np] * 3, axis=-1)
1837 | elif len(image_np.shape) == 3:
1838 | # Check channel dimension
1839 | if image_np.shape[-1] == 1:
1840 | # Single-channel 3D array - convert to RGB
1841 | image_np = np.concatenate([image_np] * 3, axis=-1)
1842 | elif image_np.shape[-1] == 4:
1843 | # RGBA image - drop alpha channel
1844 | image_np = image_np[..., :3]
1845 | elif image_np.shape[-1] > 4:
1846 | # More than 4 channels - use first 3
1847 | logger.warning(f"Image has {image_np.shape[-1]} channels. Using first 3 channels.")
1848 | image_np = image_np[..., :3]
1849 | elif image_np.shape[-1] < 3:
1850 | # Less than 3 channels but not 1 - unusual case
1851 | logger.warning(f"Unusual channel count: {image_np.shape[-1]}. Expanding to RGB.")
1852 | # Repeat the channels to get 3
1853 | channels = [image_np[..., i % image_np.shape[-1]] for i in range(3)]
1854 | image_np = np.stack(channels, axis=-1)
1855 | elif len(image_np.shape) > 3:
1856 | # More than 3 dimensions - attempt to extract a valid image
1857 | logger.warning(f"Array has {len(image_np.shape)} dimensions. Attempting to extract 3D slice.")
1858 |
1859 | # Try to get a 3D slice with channels in the last dimension
1860 | if image_np.shape[-1] <= 3:
1861 | # Extract first instance of higher dimensions
1862 | while len(image_np.shape) > 3:
1863 | image_np = image_np[0]
1864 |
1865 | # If we have less than 3 channels, expand to RGB
1866 | if image_np.shape[-1] < 3:
1867 | channels = [image_np[..., i % image_np.shape[-1]] for i in range(3)]
1868 | image_np = np.stack(channels, axis=-1)
1869 | else:
1870 | # Channels not in last dimension - reshape based on assumptions
1871 | logger.warning("Could not determine valid layout. Creating placeholder.")
1872 | image_np = np.zeros((512, 512, 3), dtype=np.uint8)
1873 | # Add gradient for visual distinction
1874 | for i in range(512):
1875 | v = int(i / 512 * 255)
1876 | image_np[i, :, 0] = v
1877 |
1878 | except Exception as numpy_error:
1879 | logger.error(f"Error processing numpy array: {str(numpy_error)}")
1880 | logger.error(traceback.format_exc())
1881 | # Create a fallback pattern as placeholder
1882 | placeholder = np.zeros((512, 512, 3), dtype=np.uint8)
1883 | # Add checkerboard pattern
1884 | for i in range(0, 512, 32):
1885 | for j in range(0, 512, 32):
1886 | if (i//32 + j//32) % 2 == 0:
1887 | placeholder[i:i+32, j:j+32] = 200
1888 | return Image.fromarray(placeholder)
1889 |
1890 | # Fallback for non-tensor, non-numpy inputs
1891 | else:
1892 | logger.error(f"Unsupported image type: {type(image)}")
1893 | return Image.new('RGB', (512, 512), (100, 100, 150)) # Distinct color for type errors
1894 |
1895 | # Convert to PIL image with error handling
1896 | try:
1897 | pil_image = Image.fromarray(image_np)
1898 | except Exception as pil_error:
1899 | logger.error(f"Error creating PIL image: {str(pil_error)}")
1900 | # Try shape correction if possible
1901 | try:
1902 | if len(image_np.shape) != 3 or image_np.shape[-1] not in [1, 3, 4]:
1903 | logger.warning(f"Invalid array shape for PIL: {image_np.shape}")
1904 | # Create valid RGB array as fallback
1905 | image_np = np.zeros((512, 512, 3), dtype=np.uint8)
1906 | pil_image = Image.fromarray(image_np)
1907 | else:
1908 | # Other error - use placeholder
1909 | pil_image = Image.new('RGB', (512, 512), (128, 128, 128))
1910 | except:
1911 | # Ultimate fallback
1912 | pil_image = Image.new('RGB', (512, 512), (128, 128, 128))
1913 |
1914 | # Resize the image while preserving aspect ratio and ensuring multiple of 32
1915 | # The multiple of 32 constraint helps prevent tensor dimension errors
1916 | width, height = pil_image.size
1917 | logger.info(f"Original PIL image size: {width}x{height}")
1918 |
1919 | # Determine which dimension to scale to input_size
1920 | if width > height:
1921 | new_width = input_size
1922 | new_height = int(height * (new_width / width))
1923 | else:
1924 | new_height = input_size
1925 | new_width = int(width * (new_height / height))
1926 |
1927 | # Ensure dimensions are multiples of 32 for better compatibility
1928 | new_width = ((new_width + 31) // 32) * 32
1929 | new_height = ((new_height + 31) // 32) * 32
1930 |
1931 | # Resize the image with antialiasing
1932 | try:
1933 | resized_image = pil_image.resize((new_width, new_height), Image.LANCZOS)
1934 | logger.info(f"Resized image from {width}x{height} to {new_width}x{new_height}")
1935 |
1936 | # Verify the resized image
1937 | if resized_image.size[0] <= 0 or resized_image.size[1] <= 0:
1938 | raise ValueError(f"Invalid resize dimensions: {resized_image.size}")
1939 |
1940 | return resized_image
1941 | except Exception as resize_error:
1942 | logger.error(f"Error during image resize: {str(resize_error)}")
1943 |
1944 | # Try a simpler resize method as fallback
1945 | try:
1946 | logger.info("Trying fallback resize method")
1947 | resized_image = pil_image.resize((new_width, new_height), Image.NEAREST)
1948 | return resized_image
1949 | except:
1950 | # Last resort - return original image or placeholder
1951 | if width > 0 and height > 0:
1952 | logger.warning("Fallback resize failed. Returning original image.")
1953 | return pil_image
1954 | else:
1955 | logger.warning("Invalid original image. Returning placeholder.")
1956 | return Image.new('RGB', (512, 512), (128, 128, 128))
1957 |
1958 | except Exception as e:
1959 | # Global catch-all handler
1960 | logger.error(f"Error processing image: {str(e)}")
1961 | logger.error(traceback.format_exc())
1962 |
1963 | # Create a visually distinct placeholder image
1964 | placeholder = Image.new('RGB', (512, 512), (120, 80, 80))
1965 |
1966 | try:
1967 | # Add error text to the image for better user feedback
1968 | from PIL import ImageDraw, ImageFont
1969 | draw = ImageDraw.Draw(placeholder)
1970 |
1971 | # Try to get a font, fall back to default if needed
1972 | try:
1973 | font = ImageFont.truetype("arial.ttf", 20)
1974 | except:
1975 | font = ImageFont.load_default()
1976 |
1977 | # Add error text
1978 | error_text = str(e)
1979 | # Limit error text length
1980 | if len(error_text) > 60:
1981 | error_text = error_text[:57] + "..."
1982 |
1983 | # Draw error message
1984 | draw.text((10, 10), "Image Processing Error", fill=(255, 50, 50), font=font)
1985 | draw.text((10, 40), error_text, fill=(255, 255, 255), font=font)
1986 | except:
1987 | # Error adding text - just return the plain placeholder
1988 | pass
1989 |
1990 | return placeholder
1991 |
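# Example of the expected input/output shapes for process_image (illustrative values):
#
#     batch = torch.rand(1, 720, 1280, 3)            # ComfyUI-style BHWC float in [0, 1]
#     pil = self.process_image(batch, input_size=518)
#     # -> RGB PIL image; the longest side is scaled toward 518 and both dimensions are
#     #    rounded up to multiples of 32 (here 544x320)
#
# The multiple-of-32 rounding exists because transformer/CNN depth backbones typically
# require spatial sizes divisible by their patch or stride size.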
1992 | def _create_error_image(self, input_image=None):
1993 | """Create an error image placeholder based on input image if possible."""
1994 | try:
1995 | if input_image is not None and isinstance(input_image, torch.Tensor) and input_image.shape[0] > 0:
1996 | # Check tensor type - if it's float64, log it for debugging
1997 | if input_image.dtype == torch.float64 or input_image.dtype == torch.double:
1998 | logger.info(f"Input tensor for error image is {input_image.dtype}, will create float32 error image")
1999 |
2000 | # Create gray error image with same dimensions as input
2001 | # Ensure tensor has the right shape for error display (BHWC)
2002 | if input_image.ndim == 4:
2003 | if input_image.shape[-1] != 3: # if not BHWC format
2004 | if input_image.shape[1] == 3: # if BCHW format
2005 | # Extract height and width from BCHW
2006 | h, w = input_image.shape[2], input_image.shape[3]
2007 | else:
2008 | # Unknown layout - fall back to the same BCHW-style indexing
2009 | h, w = input_image.shape[2], input_image.shape[3]
2010 | else:
2011 | # Already in BHWC format
2012 | h, w = input_image.shape[1], input_image.shape[2]
2013 | else:
2014 | # Unexpected shape, use default
2015 | return self._create_basic_error_image()
2016 |
2017 | # Make sure dimensions aren't too small
2018 | if h <= 1 or w <= 1:
2019 | logger.warning(f"Input has invalid dimensions {h}x{w}, using default error image")
2020 | return self._create_basic_error_image()
2021 |
2022 | # Gray background with slight red tint to indicate error - explicitly use float32
2023 | placeholder = torch.ones((1, h, w, 3), dtype=torch.float32) * torch.tensor([0.5, 0.4, 0.4], dtype=torch.float32)
2024 |
2025 | if self.device is not None:
2026 | placeholder = placeholder.to(self.device)
2027 |
2028 | # Verify the placeholder is float32
2029 | if placeholder.dtype != torch.float32:
2030 | logger.warning(f"Error image has unexpected dtype {placeholder.dtype}, converting to float32")
2031 | placeholder = placeholder.float()
2032 |
2033 | return placeholder
2034 | else:
2035 | return self._create_basic_error_image()
2036 | except Exception as e:
2037 | logger.error(f"Error creating error image: {str(e)}")
2038 | return self._create_basic_error_image()
2039 |
2040 | def _create_basic_error_image(self):
2041 | """Create a basic error image when no input dimensions are available."""
2042 | # Standard size error image (512x512)
2043 | h, w = 512, 512
2044 | # Gray background with slight red tint to indicate error - explicitly use float32
2045 | placeholder = torch.ones((1, h, w, 3), dtype=torch.float32) * torch.tensor([0.5, 0.4, 0.4], dtype=torch.float32)
2046 |
2047 | if self.device is not None:
2048 | placeholder = placeholder.to(self.device)
2049 |
2050 | # Double-check that we're returning a float32 tensor
2051 | if placeholder.dtype != torch.float32:
2052 | placeholder = placeholder.float()
2053 |
2054 | return placeholder
2055 |
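# Both error-image helpers intentionally produce tensors in ComfyUI's IMAGE layout -
# [batch, height, width, channels], float32, values in [0, 1] - so a failed run still
# yields something downstream nodes can display:
#
#     err = self._create_basic_error_image()   # torch.Size([1, 512, 512, 3]), float32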
2056 | def _add_error_text_to_image(self, image_tensor, error_text):
2057 | """Add error text to the image tensor for visual feedback."""
2058 | try:
2059 | # Convert tensor to PIL for text rendering
2060 | if image_tensor is None:
2061 | return
2062 |
2063 | temp_img = self._tensor_to_pil(image_tensor)
2064 |
2065 | # Draw error text
2066 | draw = ImageDraw.Draw(temp_img)
2067 |
2068 | # Try to get a font, fall back to default if needed
2069 | try:
2070 | font = ImageFont.truetype("arial.ttf", 20)
2071 | except:
2072 | font = ImageFont.load_default()
2073 |
2074 | # Split text into multiple lines if too long
2075 | lines = []
2076 | words = error_text.split()
2077 | current_line = words[0] if words else "Error"
2078 |
2079 | for word in words[1:]:
2080 | if len(current_line + " " + word) < 50:
2081 | current_line += " " + word
2082 | else:
2083 | lines.append(current_line)
2084 | current_line = word
2085 |
2086 | lines.append(current_line)
2087 |
2088 | # Draw title
2089 | draw.text((10, 10), "Depth Estimation Error", fill=(255, 50, 50), font=font)
2090 |
2091 | # Draw error message
2092 | y_position = 40
2093 | for line in lines:
2094 | draw.text((10, y_position), line, fill=(255, 255, 255), font=font)
2095 | y_position += 25
2096 |
2097 | # Convert back to tensor
2098 | result = self._pil_to_tensor(temp_img)
2099 |
2100 | # Copy to original tensor if shapes match
2101 | if image_tensor.shape == result.shape:
2102 | image_tensor.copy_(result)
2103 | return image_tensor
2104 |
2105 | except Exception as e:
2106 | logger.error(f"Error adding text to error image: {e}")
2107 | return image_tensor
2108 |
2109 | def _tensor_to_pil(self, tensor):
2110 | """Convert a tensor to PIL Image."""
2111 | if tensor.shape[0] == 1: # Batch size 1
2112 | img_np = (tensor[0].cpu().numpy() * 255).astype(np.uint8)
2113 | return Image.fromarray(img_np)
2114 | return Image.new('RGB', (512, 512), color=(128, 100, 100))
2115 |
2116 | def _pil_to_tensor(self, pil_img):
2117 | """Convert PIL Image back to tensor."""
2118 | img_np = np.array(pil_img).astype(np.float32) / 255.0
2119 | tensor = torch.from_numpy(img_np).unsqueeze(0)
2120 |
2121 | if self.device is not None:
2122 | tensor = tensor.to(self.device)
2123 |
2124 | return tensor
2125 |
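# Conversion conventions for the two small helpers above (illustrative round trip):
#
#     pil = self._tensor_to_pil(tensor)    # expects [1, H, W, 3] in [0, 1] -> uint8 PIL image
#     back = self._pil_to_tensor(pil)      # uint8 PIL image -> [1, H, W, 3] float32 in [0, 1]
#
# The round trip only loses precision through 8-bit quantization, which is acceptable
# here because these helpers are used for error placeholders rather than depth data.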
2126 | def estimate_depth(self,
2127 | image: torch.Tensor,
2128 | model_name: str,
2129 | input_size: int = 518,
2130 | blur_radius: float = 2.0,
2131 | median_size: str = "5",
2132 | apply_auto_contrast: bool = True,
2133 | apply_gamma: bool = True,
2134 | force_reload: bool = False,
2135 | force_cpu: bool = False) -> Tuple[torch.Tensor]:
2136 | """
2137 | Estimates depth from input image with error handling and cleanup.
2138 |
2139 | Args:
2140 | image: Input image tensor
2141 | model_name: Name of the depth model to use
2142 | input_size: Target size for the longest dimension of the image (between 256 and 1024)
2143 | blur_radius: Gaussian blur radius for smoothing
2144 | median_size: Size of median filter for noise reduction
2145 | apply_auto_contrast: Whether to enhance contrast automatically
2146 | apply_gamma: Whether to apply gamma correction
2147 | force_reload: Whether to force reload the model
2148 | force_cpu: Whether to force using CPU for inference
2149 |
2150 | Returns:
2151 | Tuple containing depth map tensor
2152 | """
2153 | error_image = None
2154 | start_time = time.time()
2155 |
2156 | try:
2157 | # Sanity check inputs and log initial info
2158 | logger.info(f"Starting depth estimation with model: {model_name}, input_size: {input_size}, force_cpu: {force_cpu}")
2159 |
2160 | # Enhanced input validation with better error handling
2161 | if image is None:
2162 | logger.error("Input image is None")
2163 | error_image = self._create_basic_error_image()
2164 | self._add_error_text_to_image(error_image, "Input image is None")
2165 | return (error_image,)
2166 |
2167 | if image.numel() == 0:
2168 | logger.error("Input image is empty (zero elements)")
2169 | error_image = self._create_basic_error_image()
2170 | self._add_error_text_to_image(error_image, "Input image is empty (zero elements)")
2171 | return (error_image,)
2172 |
2173 | # Log tensor information before processing
2174 | logger.info(f"Input tensor shape: {image.shape}, dtype: {image.dtype}, device: {image.device}")
2175 |
2176 | # Verify tensor dimensions - support different input formats
2177 | if image.ndim != 4:
2178 | logger.warning(f"Expected 4D tensor for image, got {image.ndim}D. Attempting to reshape.")
2179 | try:
2180 | # Try to reshape based on common dimension patterns
2181 | if image.ndim == 3:
2182 | # Could be [C, H, W] or [H, W, C] format
2183 | if image.shape[0] <= 3 and image.shape[0] > 0: # Likely [C, H, W]
2184 | image = image.unsqueeze(0) # Add batch dim -> [1, C, H, W]
2185 | logger.info(f"Reshaped 3D tensor to 4D with shape: {image.shape}")
2186 | elif image.shape[-1] <= 3 and image.shape[-1] > 0: # Likely [H, W, C]
2187 | # Permute to [C, H, W] then add batch dim
2188 | image = image.permute(2, 0, 1).unsqueeze(0)
2189 | logger.info(f"Reshaped HWC tensor to BCHW with shape: {image.shape}")
2190 | else:
2191 | # Assume a [B, H, W] single-channel batch and add the missing channel dimension
2192 | image = image.unsqueeze(1)
2193 | logger.info(f"Added channel dimension to 3D tensor, new shape: {image.shape}")
2194 | elif image.ndim == 2:
2195 | # Assume [H, W] format - add batch and channel dims
2196 | image = image.unsqueeze(0).unsqueeze(0)
2197 | logger.info(f"Reshaped 2D tensor to 4D with shape: {image.shape}")
2198 | elif image.ndim > 4:
2199 | # Too many dimensions, collapse extras
2200 | orig_shape = image.shape
2201 | image = image.reshape(-1, *orig_shape[-3:])[:1]  # flatten leading dims into the batch axis, keep the first sample
2202 | logger.info(f"Collapsed >4D tensor to 4D with shape: {image.shape}")
2203 | else:
2204 | # Fallback for other unusual dimensions
2205 | logger.error(f"Cannot automatically reshape tensor with {image.ndim} dimensions")
2206 | error_image = self._create_basic_error_image()
2207 | self._add_error_text_to_image(error_image, f"Unsupported tensor dimensions: {image.ndim}D")
2208 | return (error_image,)
2209 | except Exception as reshape_error:
2210 | logger.error(f"Error reshaping tensor: {str(reshape_error)}")
2211 | error_image = self._create_basic_error_image()
2212 | self._add_error_text_to_image(error_image, f"Error reshaping tensor: {str(reshape_error)[:100]}")
2213 | return (error_image,)
2214 |
2215 | # Comprehensive type checking and conversion - verify at multiple points
2216 | # 1. Initial type check and convert if needed
2217 | if image.dtype != torch.float32:
2218 | logger.info(f"Converting input tensor from {image.dtype} to torch.float32")
2219 |
2220 | # Safe conversion that handles different input types
2221 | try:
2222 | if image.dtype == torch.uint8:
2223 | # Normalize [0-255] -> [0-1] for uint8 inputs
2224 | image = image.float() / 255.0
2225 | else:
2226 | # Standard conversion for other types
2227 | image = image.float()
2228 | except Exception as type_error:
2229 | logger.error(f"Error converting tensor type: {str(type_error)}")
2230 | error_image = self._create_basic_error_image()
2231 | self._add_error_text_to_image(error_image, f"Type conversion error: {str(type_error)[:100]}")
2232 | return (error_image,)
2233 |
2234 | # Create error image placeholder based on input dimensions
2235 | error_image = self._create_error_image(image)
2236 |
2237 | # 2. Check for NaN/Inf values and fix them
2238 | nan_count = torch.isnan(image).sum().item()
2239 | inf_count = torch.isinf(image).sum().item()
2240 |
2241 | if nan_count > 0 or inf_count > 0:
2242 | logger.warning(f"Input contains {nan_count} NaN and {inf_count} Inf values. Fixing problematic values.")
2243 | image = torch.nan_to_num(image, nan=0.0, posinf=1.0, neginf=0.0)
2244 |
2245 | # 3. Check value range and normalize if needed
2246 | min_val, max_val = image.min().item(), image.max().item()
2247 | if max_val > 1.0 + 1e-5 or min_val < -1e-5: # Allow small floating point error
2248 | logger.info(f"Input values outside [0,1] range: min={min_val}, max={max_val}. Normalizing.")
2249 |
2250 | # Ensure values are in [0,1] range - handle different input formats
2251 | if min_val >= 0 and max_val <= 255:
2252 | # Likely [0-255] range
2253 | image = image / 255.0
2254 | else:
2255 | # General min-max normalization
2256 | image = (image - min_val) / (max_val - min_val)
2257 |
2258 | # Parameter validation with safer defaults
2259 | # Validate and normalize median_size parameter
2260 | median_size_str = str(median_size) # Convert to string regardless of input type
2261 | if median_size_str not in self.MEDIAN_SIZES:
2262 | logger.warning(f"Invalid median_size: '{median_size}' (type: {type(median_size)}). Using default '5'.")
2263 | median_size_str = "5"
2264 |
2265 | # Validate input_size with stricter bounds
2266 | if not isinstance(input_size, (int, float)):
2267 | logger.warning(f"Invalid input_size type: {type(input_size)}. Using default 518.")
2268 | input_size = 518
2269 | else:
2270 | # Convert to int and constrain to valid range
2271 | try:
2272 | input_size = int(input_size)
2273 | input_size = max(256, min(input_size, 1024)) # Clamp between 256 and 1024
2274 | except (TypeError, ValueError, OverflowError):
2275 | logger.warning("Error converting input_size to int. Using default 518.")
2276 | input_size = 518
2277 |
2278 | # Try loading the model with graceful fallback
2279 | try:
2280 | self.ensure_model_loaded(model_name, force_reload, force_cpu)
2281 | logger.info(f"Model '{model_name}' loaded successfully")
2282 | except Exception as model_error:
2283 | error_msg = f"Failed to load model '{model_name}': {str(model_error)}"
2284 | logger.error(error_msg)
2285 |
2286 | # Try a more reliable fallback model before giving up
2287 | fallback_models = ["MiDaS-Small", "MiDaS-Base"]
2288 | for fallback_model in fallback_models:
2289 | if fallback_model != model_name:
2290 | logger.info(f"Attempting to load fallback model: {fallback_model}")
2291 | try:
2292 | self.ensure_model_loaded(fallback_model, True, True) # Force reload and CPU for reliability
2293 | logger.info(f"Fallback model '{fallback_model}' loaded successfully")
2294 | # Update model_name to reflect the fallback
2295 | model_name = fallback_model
2296 | # Break the loop since we successfully loaded a fallback
2297 | break
2298 | except Exception as fallback_error:
2299 | logger.warning(f"Fallback model '{fallback_model}' also failed: {str(fallback_error)}")
2300 | continue
2301 |
2302 | # If we still don't have a model loaded, return error image
2303 | if self.depth_estimator is None:
2304 | self._add_error_text_to_image(error_image, f"Model Error: {str(model_error)[:100]}...")
2305 | return (error_image,)
2306 |
2307 | # Process input image with enhanced error recovery
2308 | try:
2309 | # Convert to PIL with robust error handling
2310 | pil_image = self.process_image(image, input_size)
2311 | # Store original dimensions for later resizing
2312 | original_width, original_height = pil_image.size
2313 | logger.info(f"Image processed to size: {pil_image.size} (will preserve these dimensions in output)")
2314 | except Exception as img_error:
2315 | logger.error(f"Image processing error: {str(img_error)}")
2316 | logger.error(traceback.format_exc())
2317 |
2318 | # Try a more basic approach if the standard processing fails
2319 | try:
2320 | logger.info("Attempting basic image conversion as fallback")
2321 | # Simple conversion fallback
2322 | if image.shape[1] > 3: # BCHW format with unusual channel count
2323 | # Keep only the first three channels
2324 | logger.warning(f"Unusual channel count: {image.shape[1]}. Using first 3 channels.")
2325 | image = image[:, :3, :, :]
2326 |
2327 |
2328 | # Convert to CPU numpy array
2329 | img_np = image.squeeze(0).cpu().numpy()
2330 |
2331 | # Handle different layouts
2332 | if img_np.shape[0] <= 3: # [C, H, W]
2333 | img_np = np.transpose(img_np, (1, 2, 0)) # -> [H, W, C]
2334 |
2335 | # Ensure 3 channels
2336 | if len(img_np.shape) == 2: # Grayscale
2337 | img_np = np.stack([img_np] * 3, axis=-1)
2338 | elif img_np.shape[-1] == 1: # Single channel
2339 | img_np = np.concatenate([img_np] * 3, axis=-1)
2340 |
2341 | # Normalize values
2342 | if img_np.max() > 1.0:
2343 | img_np = img_np / 255.0
2344 |
2345 | # Convert to PIL
2346 | img_np = (img_np * 255).astype(np.uint8)
2347 | pil_image = Image.fromarray(img_np)
2348 |
2349 | # Store original dimensions
2350 | original_width, original_height = pil_image.size
2351 |
2352 | # Resize to appropriate dimensions
2353 | if input_size > 0:
2354 | w, h = pil_image.size
2355 | # Determine which dimension to scale to input_size
2356 | if w > h:
2357 | new_w = input_size
2358 | new_h = int(h * (new_w / w))
2359 | else:
2360 | new_h = input_size
2361 | new_w = int(w * (new_h / h))
2362 |
2363 | # Ensure dimensions are multiples of 32
2364 | new_h = ((new_h + 31) // 32) * 32
2365 | new_w = ((new_w + 31) // 32) * 32
2366 |
2367 | pil_image = pil_image.resize((new_w, new_h), Image.LANCZOS)
2368 |
2369 | logger.info(f"Fallback image processing succeeded with size: {pil_image.size}")
2370 | except Exception as fallback_error:
2371 | logger.error(f"Fallback image processing also failed: {str(fallback_error)}")
2372 | self._add_error_text_to_image(error_image, f"Image Error: {str(img_error)[:100]}...")
2373 | return (error_image,)
2374 |
2375 | # Depth estimation with comprehensive error handling
2376 | try:
2377 | # Use inference_mode for better memory usage
2378 | with torch.inference_mode():
2379 | logger.info(f"Running inference on image of size {pil_image.size}")
2380 |
2381 | # Ensure the depth estimator is in eval mode
2382 | if hasattr(self.depth_estimator, 'eval'):
2383 | self.depth_estimator.eval()
2384 |
2385 | # Add timing for performance analysis
2386 | inference_start = time.time()
2387 |
2388 | # Perform inference with better error handling
2389 | try:
2390 | depth_result = self.depth_estimator(pil_image)
2391 | inference_time = time.time() - inference_start
2392 | logger.info(f"Depth inference completed in {inference_time:.2f} seconds")
2393 | except Exception as inference_error:
2394 | logger.error(f"Depth estimator inference error: {str(inference_error)}")
2395 |
2396 | # Detailed error analysis for better debugging
2397 | error_str = str(inference_error).lower()
2398 |
2399 | # Try fallback for common error types
2400 | if "cuda" in error_str and "out of memory" in error_str:
2401 | logger.warning("CUDA out of memory detected. Attempting CPU fallback.")
2402 | # Already on CPU? Check and handle
2403 | if force_cpu:
2404 | logger.error("Already using CPU but still encountered memory error")
2405 | self._add_error_text_to_image(error_image, "Memory error even on CPU. Try smaller input size.")
2406 | return (error_image,)
2407 | else:
2408 | # Try CPU fallback
2409 | logger.info("Switching to CPU processing")
2410 | return self.estimate_depth(
2411 | image.cpu(), model_name, input_size, blur_radius, median_size_str,
2412 | apply_auto_contrast, apply_gamma, True, True # Force CPU
2413 | )
2414 |
2415 | # Type mismatch errors
2416 | elif "input type" in error_str and "weight type" in error_str:
2417 | logger.warning("Tensor type mismatch detected. Attempting explicit type conversion.")
2418 | # Try with explicit CPU conversion
2419 | return self.estimate_depth(
2420 | image.float().cpu(), model_name, input_size, blur_radius, median_size_str,
2421 | apply_auto_contrast, apply_gamma, True, True # Force CPU and reload
2422 | )
2423 |
2424 | # Dimension mismatch errors
2425 | elif "dimensions" in error_str or "dimension" in error_str or "shape" in error_str:
2426 | logger.warning("Tensor dimension mismatch detected. Trying alternate approach.")
2427 | # Fall back to CPU MiDaS model which has more robust dimension handling
2428 | logger.info("Falling back to MiDaS model on CPU")
2429 | try:
2430 | # Clean up current model
2431 | self.cleanup()
2432 | # Try to load MiDaS model
2433 | self.ensure_model_loaded("MiDaS-Small", True, True)
2434 | # Retry with the new model
2435 | return self.estimate_depth(
2436 | image.cpu(), "MiDaS-Small", input_size, blur_radius, median_size_str,
2437 | apply_auto_contrast, apply_gamma, False, True # Already reloaded, force CPU
2438 | )
2439 | except Exception as midas_error:
2440 | logger.error(f"MiDaS fallback also failed: {str(midas_error)}")
2441 | self._add_error_text_to_image(error_image, f"Inference Error: {str(inference_error)[:100]}...")
2442 | return (error_image,)
2443 |
2444 | # Other errors - just return error image
2445 | self._add_error_text_to_image(error_image, f"Inference Error: {str(inference_error)[:100]}...")
2446 | return (error_image,)
2447 |
2448 | # Verify depth result and convert to float32
2449 | if not isinstance(depth_result, dict) or "predicted_depth" not in depth_result:
2450 | logger.error(f"Invalid depth result format: {type(depth_result)}")
2451 | self._add_error_text_to_image(error_image, "Invalid depth result format")
2452 | return (error_image,)
2453 |
2454 | # Extract and validate predicted depth
2455 | predicted_depth = depth_result["predicted_depth"]
2456 |
2457 | # Ensure correct tensor type
2458 | if not torch.is_tensor(predicted_depth):
2459 | logger.error(f"Predicted depth is not a tensor: {type(predicted_depth)}")
2460 | self._add_error_text_to_image(error_image, "Predicted depth is not a tensor")
2461 | return (error_image,)
2462 |
2463 | # Convert to float32 if needed
2464 | if predicted_depth.dtype != torch.float32:
2465 | logger.info(f"Converting predicted depth from {predicted_depth.dtype} to float32")
2466 | predicted_depth = predicted_depth.float()
2467 |
2468 | # Convert to CPU for post-processing
2469 | depth_map = predicted_depth.squeeze().cpu().numpy()
2470 | except RuntimeError as rt_error:
2471 | # Handle runtime errors separately for clearer error messages
2472 | error_msg = str(rt_error)
2473 | logger.error(f"Runtime error during depth estimation: {error_msg}")
2474 |
2475 | # Check for specific error types
2476 | if "CUDA out of memory" in error_msg:
2477 | logger.warning("CUDA out of memory. Trying CPU fallback.")
2478 |
2479 | # Only try CPU fallback if not already using CPU
2480 | if not force_cpu:
2481 | try:
2482 | logger.info("Switching to CPU processing")
2483 | return self.estimate_depth(
2484 | image.cpu(), model_name, input_size, blur_radius, median_size_str,
2485 | apply_auto_contrast, apply_gamma, True, True # Force reload and CPU
2486 | )
2487 | except Exception as cpu_error:
2488 | logger.error(f"CPU fallback failed: {str(cpu_error)}")
2489 |
2490 | self._add_error_text_to_image(error_image, "CUDA Out of Memory. Try a smaller model or image size.")
2491 | else:
2492 | # Generic runtime error
2493 | self._add_error_text_to_image(error_image, f"Runtime Error: {error_msg[:100]}...")
2494 |
2495 | return (error_image,)
2496 | except Exception as e:
2497 | # Handle other exceptions
2498 | error_msg = f"Depth estimation failed: {str(e)}"
2499 | logger.error(error_msg)
2500 | logger.error(traceback.format_exc())
2501 | self._add_error_text_to_image(error_image, f"Error: {str(e)[:100]}...")
2502 | return (error_image,)
2503 |
2504 | # Validate depth map
2505 | # Check for NaN/Inf values
2506 | if np.isnan(depth_map).any() or np.isinf(depth_map).any():
2507 | logger.warning("Depth map contains NaN or Inf values. Replacing with zeros.")
2508 | depth_map = np.nan_to_num(depth_map, nan=0.0, posinf=1.0, neginf=0.0)
2509 |
2510 | # Check for empty or invalid depth map
2511 | if depth_map.size == 0:
2512 | logger.error("Depth map is empty")
2513 | self._add_error_text_to_image(error_image, "Empty depth map returned")
2514 | return (error_image,)
2515 |
2516 | # Post-processing with enhanced error handling
2517 | try:
2518 | # Ensure depth values have reasonable range for normalization
2519 | depth_min, depth_max = depth_map.min(), depth_map.max()
2520 |
2521 | # Handle constant depth maps (avoid division by zero)
2522 | if np.isclose(depth_max, depth_min):
2523 | logger.warning("Constant depth map detected (min = max). Using values directly.")
2524 | # Just use a normalized constant value
2525 | depth_map = np.ones_like(depth_map) * 0.5
2526 | else:
2527 | # Normalize to [0, 1] range first - safer for later operations
2528 | depth_map = (depth_map - depth_min) / (depth_max - depth_min)
2529 |
2530 | # Scale to [0, 255] for PIL operations
2531 | depth_map_uint8 = (depth_map * 255.0).astype(np.uint8)
2532 |
2533 | # Log the depth map shape coming from the model
2534 | logger.info(f"Depth map shape from model: {depth_map_uint8.shape}")
2535 |
2536 | # Create PIL image explicitly with L mode (grayscale)
2537 | try:
2538 | depth_pil = Image.fromarray(depth_map_uint8, mode='L')
2539 | logger.info(f"Depth PIL image size before resize: {depth_pil.size}")
2540 | except Exception as pil_error:
2541 | logger.error(f"Error creating PIL image: {str(pil_error)}")
2542 | # Try to reshape the array if dimensions are wrong
2543 | if len(depth_map_uint8.shape) != 2:
2544 | logger.warning(f"Unexpected depth map shape: {depth_map_uint8.shape}. Attempting to fix.")
2545 | # Try to extract first channel if multi-channel
2546 | if len(depth_map_uint8.shape) > 2:
2547 | depth_map_uint8 = depth_map_uint8[..., 0]
2548 | elif len(depth_map_uint8.shape) == 1:
2549 | # 1D array - try to reshape to 2D
2550 | h = int(np.sqrt(depth_map_uint8.size))
2551 | w = depth_map_uint8.size // h
2552 | depth_map_uint8 = depth_map_uint8[:h * w].reshape(h, w)  # truncate so the reshape always fits
2553 |
2554 | depth_pil = Image.fromarray(depth_map_uint8, mode='L')
2555 |
2556 | # Resize depth map to original dimensions from input image before post-processing
2557 | # This is the key fix for the resolution issue
2558 | try:
2559 | logger.info(f"Resizing depth map to original dimensions: {original_width}x{original_height}")
2560 | depth_pil = depth_pil.resize((original_width, original_height), Image.BICUBIC)
2561 | logger.info(f"Depth PIL image size after resize: {depth_pil.size}")
2562 | except Exception as resize_error:
2563 | logger.error(f"Error resizing depth map: {str(resize_error)}")
2564 | logger.error(traceback.format_exc())
2565 |
2566 | # Apply post-processing with parameter validation
2567 | # Apply blur if radius is positive
2568 | if blur_radius > 0:
2569 | try:
2570 | depth_pil = depth_pil.filter(ImageFilter.GaussianBlur(radius=blur_radius))
2571 | except Exception as blur_error:
2572 | logger.warning(f"Error applying blur: {str(blur_error)}. Skipping.")
2573 |
2574 | # Apply median filter if size is valid
2575 | try:
2576 | median_size_int = int(median_size_str)
2577 | if median_size_int > 0:
2578 | depth_pil = depth_pil.filter(ImageFilter.MedianFilter(size=median_size_int))
2579 | except Exception as median_error:
2580 | logger.warning(f"Error applying median filter: {str(median_error)}. Skipping.")
2581 |
2582 | # Apply auto contrast if requested
2583 | if apply_auto_contrast:
2584 | try:
2585 | depth_pil = ImageOps.autocontrast(depth_pil)
2586 | except Exception as contrast_error:
2587 | logger.warning(f"Error applying auto contrast: {str(contrast_error)}. Skipping.")
2588 |
2589 | # Apply gamma correction if requested
2590 | if apply_gamma:
2591 | try:
2592 | depth_array = np.array(depth_pil).astype(np.float32) / 255.0
2593 | mean_luminance = np.mean(depth_array)
2594 |
2595 | # Avoid division by zero (mean luminance near 0 or 1 makes the log ratio blow up)
2596 | if 0.001 < mean_luminance < 0.999:
2597 | # Calculate gamma based on mean luminance for adaptive correction
2598 | gamma = np.log(0.5) / np.log(mean_luminance)
2599 |
2600 | # Clamp gamma to reasonable range to avoid extreme corrections
2601 | gamma = max(0.1, min(gamma, 3.0))
2602 | logger.info(f"Applying gamma correction with value: {gamma:.2f}")
2603 |
2604 | # Apply gamma correction
2605 | corrected = np.power(depth_array, 1.0/gamma) * 255.0
2606 | depth_pil = Image.fromarray(corrected.astype(np.uint8), mode='L')
2607 | else:
2608 | logger.warning(f"Mean luminance {mean_luminance:.3f} outside usable range. Skipping gamma correction.")
2609 | except Exception as gamma_error:
2610 | logger.warning(f"Error applying gamma correction: {str(gamma_error)}. Skipping.")
2611 |
2612 | # Convert processed image back to tensor
2613 | # Convert to numpy array for tensor conversion
2614 | depth_array = np.array(depth_pil).astype(np.float32) / 255.0
2615 |
2616 | # Final validation checks
2617 | # Check for invalid dimensions
2618 | h, w = depth_array.shape
2619 | if h <= 1 or w <= 1:
2620 | logger.error(f"Invalid depth map dimensions: {h}x{w}")
2621 | self._add_error_text_to_image(error_image, "Invalid depth map dimensions (too small)")
2622 | return (error_image,)
2623 |
2624 | # Log final dimensions for debugging
2625 | logger.info(f"Final depth map dimensions: {h}x{w}")
2626 |
2627 | # Create RGB depth map by stacking the same grayscale image three times
2628 | # Stack to create a 3-channel image compatible with ComfyUI
2629 | depth_rgb = np.stack([depth_array] * 3, axis=-1) # Shape becomes (h, w, 3)
2630 |
2631 | # Convert to tensor and add batch dimension
2632 | depth_tensor = torch.from_numpy(depth_rgb).unsqueeze(0).float() # Shape becomes (1, h, w, 3)
2633 |
2634 | # Optional: move to device if not using CPU
2635 | if self.device is not None and not force_cpu:
2636 | depth_tensor = depth_tensor.to(self.device)
2637 |
2638 | # Validate output tensor
2639 | if torch.isnan(depth_tensor).any() or torch.isinf(depth_tensor).any():
2640 | logger.warning("Final tensor contains NaN or Inf values. Fixing.")
2641 | depth_tensor = torch.nan_to_num(depth_tensor, nan=0.0, posinf=1.0, neginf=0.0)
2642 |
2643 | # Ensure values are in [0, 1] range
2644 | min_val, max_val = depth_tensor.min().item(), depth_tensor.max().item()
2645 | if min_val < 0.0 or max_val > 1.0:
2646 | logger.warning(f"Depth tensor values outside [0,1] range: min={min_val}, max={max_val}. Normalizing.")
2647 | depth_tensor = torch.clamp(depth_tensor, 0.0, 1.0)
2648 |
2649 | # Log completion info
2650 | processing_time = time.time() - start_time
2651 | logger.info(f"Depth processing completed in {processing_time:.2f} seconds")
2652 | logger.info(f"Output tensor: shape={depth_tensor.shape}, dtype={depth_tensor.dtype}, device={depth_tensor.device}")
2653 |
2654 | return (depth_tensor,)
2655 |
2656 | except Exception as post_error:
2657 | # Handle post-processing errors
2658 | error_msg = f"Error during depth map post-processing: {str(post_error)}"
2659 | logger.error(error_msg)
2660 | logger.error(traceback.format_exc())
2661 | self._add_error_text_to_image(error_image, f"Post-processing Error: {str(post_error)[:100]}...")
2662 | return (error_image,)
2663 |
2664 | except Exception as e:
2665 | # Global catch-all error handler
2666 | error_msg = f"Depth estimation failed: {str(e)}"
2667 | logger.error(error_msg)
2668 | logger.error(traceback.format_exc())
2669 |
2670 | # Create error image if needed
2671 | if error_image is None:
2672 | error_image = self._create_basic_error_image()
2673 |
2674 | self._add_error_text_to_image(error_image, f"Unexpected Error: {str(e)[:100]}...")
2675 | return (error_image,)
2676 | finally:
2677 | # Always clean up resources regardless of success or failure
2678 | torch.cuda.empty_cache()
2679 | gc.collect()
2680 |
2681 | def gamma_correction(self, img: Image.Image, gamma: float = 1.0) -> Image.Image:
2682 | """Applies gamma correction to the image."""
2683 | # Convert to numpy array
2684 | img_array = np.array(img)
2685 |
2686 | # Apply gamma correction directly with numpy
2687 | corrected = np.power(img_array.astype(np.float32) / 255.0, 1.0/gamma) * 255.0
2688 |
2689 | # Ensure uint8 type and create image with explicit mode
2690 | return Image.fromarray(corrected.astype(np.uint8), mode='L')
2691 |
2692 | # Node registration
2693 | NODE_CLASS_MAPPINGS = {
2694 | "DepthEstimationNode": DepthEstimationNode
2695 | }
2696 |
2697 | NODE_DISPLAY_NAME_MAPPINGS = {
2698 | "DepthEstimationNode": "Depth Estimation (V2)"
2699 | }
--------------------------------------------------------------------------------
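A minimal sketch of exercising the node outside a running ComfyUI instance, useful as a quick smoke test. It assumes the module imports standalone (its own dependencies resolve without ComfyUI) and that DepthEstimationNode takes no constructor arguments; the (B, H, W, C) float layout in [0, 1] and the parameter names follow estimate_depth above, and the model key "Depth-Anything-V2-Base" is the one used in the bundled workflow.

import torch
from depth_estimation_node import DepthEstimationNode

# Dummy BHWC float image in [0, 1], the layout estimate_depth expects
image = torch.rand(1, 512, 512, 3, dtype=torch.float32)

node = DepthEstimationNode()
(depth,) = node.estimate_depth(
    image,
    model_name="Depth-Anything-V2-Base",
    input_size=518,
    blur_radius=2.0,
    median_size="5",
    apply_auto_contrast=True,
    apply_gamma=True,
    force_cpu=True,  # keep the smoke test off the GPU
)

# On success this is a (1, H, W, 3) float32 tensor in [0, 1];
# on failure the node returns an annotated error image with the same layout.
print(depth.shape, depth.dtype, depth.min().item(), depth.max().item())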
/images/depth-estimation-icon.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Limbicnation/ComfyUIDepthEstimation/1c3035351a2269874d6200073b3bb20ac61f2513/images/depth-estimation-icon.png
--------------------------------------------------------------------------------
/images/depth-estimation-icon.svg:
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
/images/depth-estimation-logo-with-smaller-z.svg:
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
/images/depth-estimation-node.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Limbicnation/ComfyUIDepthEstimation/1c3035351a2269874d6200073b3bb20ac61f2513/images/depth-estimation-node.png
--------------------------------------------------------------------------------
/images/depth_map_generator_showcase.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Limbicnation/ComfyUIDepthEstimation/1c3035351a2269874d6200073b3bb20ac61f2513/images/depth_map_generator_showcase.jpg
--------------------------------------------------------------------------------
/publish.yaml:
--------------------------------------------------------------------------------
1 | name: Publish to Comfy registry
2 | on:
3 | workflow_dispatch:
4 | push:
5 | branches:
6 | - main
7 | - master
8 | paths:
9 | - "pyproject.toml"
10 |
11 | jobs:
12 | publish-node:
13 | name: Publish Custom Node to registry
14 | runs-on: ubuntu-latest
15 | steps:
16 | - name: Check out code
17 | uses: actions/checkout@v4
18 | - name: Publish Custom Node
19 | uses: Comfy-Org/publish-node-action@main
20 | with:
21 | personal_access_token: ${{ secrets.REGISTRY_ACCESS_TOKEN }}
22 |
--------------------------------------------------------------------------------
/pyproject.toml:
--------------------------------------------------------------------------------
1 | [project]
2 | name = "comfyuidepthestimation"
3 | description = "A robust custom depth estimation node for ComfyUI using Depth-Anything models. It integrates depth estimation with configurable post-processing options including blur, median filtering, contrast enhancement, and gamma correction."
4 | version = "1.1.1"
5 | license = { file = "LICENSE" }
6 | dependencies = [
7 | "transformers>=4.20.0",
8 | "tokenizers>=0.13.3",
9 | "timm>=0.6.12",
10 | "huggingface-hub>=0.16.0",
11 | "protobuf==3.20.3",
12 | "requests>=2.27.0",
13 | "wget>=3.2"
14 | ]
15 |
16 | [project.urls]
17 | Repository = "https://github.com/Limbicnation/ComfyUIDepthEstimation"
18 | # Used by Comfy Registry https://comfyregistry.org
19 |
20 | [tool.comfy]
21 | PublisherId = "limbicnation"
22 | DisplayName = "Depth Estimation Node"
23 | Icon = "images/depth-estimation-icon.png"
--------------------------------------------------------------------------------
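The project metadata and the [tool.comfy] table above can also be read programmatically, e.g. for release sanity checks. A minimal sketch, assuming Python 3.11+ (whose standard library ships tomllib) and a path relative to the repository root:

import tomllib

with open("pyproject.toml", "rb") as f:
    meta = tomllib.load(f)

# Fields defined in the file above
print(meta["project"]["name"], meta["project"]["version"])
print(meta["tool"]["comfy"]["PublisherId"], meta["tool"]["comfy"]["DisplayName"])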
/requirements.txt:
--------------------------------------------------------------------------------
1 | # requirements.txt for ComfyUI-DepthEstimation Node
2 | # Note: These are minimum requirements. ComfyUI's environment may provide newer versions.
3 |
4 | # Fix for protobuf errors with transformers
5 | protobuf==3.20.3
6 |
7 | # Core dependencies
8 | tokenizers>=0.13.3 # Pre-built version compatible with most platforms
9 | transformers>=4.20.0 # Required for Depth Anything models, but ComfyUI may have a specific version
10 |
11 | # Pillow (PIL Fork) - Compatibility with other ComfyUI nodes
12 | # Don't specify Pillow version to avoid conflicts with ComfyUI environment
13 |
14 | # NumPy - Using version that properly supports numpy.dtypes
15 | # Don't specify version to avoid conflicts with ComfyUI environment
16 |
17 | # Additional dependencies specific to depth estimation node
18 | timm>=0.6.12 # Required for Depth Anything models
19 | huggingface-hub>=0.16.0 # For model downloading
20 | wget>=3.2 # For reliable model downloading
21 |
22 | # Torch version requirements
23 | # Note: PyTorch dependencies are handled by ComfyUI's core installation
24 | # If you're installing this node directly, ensure torch>=2.0.0 is available
25 |
26 | # Network dependencies
27 | requests>=2.27.0 # For model downloading
--------------------------------------------------------------------------------
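Because several of these packages are expected to come from ComfyUI's own environment, a quick importability check can save a debugging round-trip. A minimal sketch; the module names correspond to the packages listed above (protobuf imports as google.protobuf), and torch is included per the note about ComfyUI's core installation:

import importlib

# One import name per package pinned or referenced in requirements.txt
for module in ("google.protobuf", "tokenizers", "transformers", "timm",
               "huggingface_hub", "wget", "requests", "torch"):
    try:
        importlib.import_module(module)
        print(f"OK       {module}")
    except ImportError as exc:
        print(f"MISSING  {module}: {exc}")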
/workflows/Depth_Map_Generator_V1.json:
--------------------------------------------------------------------------------
1 | {"last_node_id":11,"last_link_id":6,"nodes":[{"id":10,"type":"LoadImage","pos":[-67.18539428710938,435.0294494628906],"size":[315,314],"flags":{},"order":0,"mode":0,"inputs":[],"outputs":[{"name":"IMAGE","type":"IMAGE","links":[5],"slot_index":0},{"name":"MASK","type":"MASK","links":null}],"properties":{"cnr_id":"comfy-core","ver":"0.3.15","Node name for S&R":"LoadImage","enableTabs":false,"tabWidth":65,"tabXOffset":10,"hasSecondTab":false,"secondTabText":"Send Back","secondTabOffset":80,"secondTabWidth":65},"widgets_values":["Fluix_Redux_NeoTokyo_CyberPortrait_00043_.png","image"]},{"id":11,"type":"PreviewImage","pos":[807.515869140625,435.0294494628906],"size":[210,246],"flags":{},"order":2,"mode":0,"inputs":[{"name":"images","type":"IMAGE","link":6}],"outputs":[],"properties":{"cnr_id":"comfy-core","ver":"0.3.15","Node name for S&R":"PreviewImage","enableTabs":false,"tabWidth":65,"tabXOffset":10,"hasSecondTab":false,"secondTabText":"Send Back","secondTabOffset":80,"secondTabWidth":65},"widgets_values":[]},{"id":9,"type":"DepthEstimationNode","pos":[370.5541076660156,435.0294494628906],"size":[315,154],"flags":{},"order":1,"mode":0,"inputs":[{"name":"image","type":"IMAGE","link":5}],"outputs":[{"name":"IMAGE","type":"IMAGE","links":[6],"slot_index":0}],"properties":{"aux_id":"Limbicnation/ComfyUIDepthEstimation","ver":"1fac38f691c69b81dc7b5f5e0c83bf61583220be","Node name for S&R":"DepthEstimationNode","enableTabs":false,"tabWidth":65,"tabXOffset":10,"hasSecondTab":false,"secondTabText":"Send Back","secondTabOffset":80,"secondTabWidth":65,"cnr_id":"comfyuidepthestimation"},"widgets_values":["Depth-Anything-V2-Base",0.5,"3",true,true]}],"links":[[5,10,0,9,0,"IMAGE"],[6,9,0,11,0,"IMAGE"]],"groups":[],"config":{},"extra":{"ds":{"scale":2.593742460100004,"offset":{"0":118.01049041748047,"1":-242.42494201660156}},"ue_links":[]},"version":0.4}
--------------------------------------------------------------------------------
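The workflow above is a standard ComfyUI graph export, so the LoadImage -> DepthEstimationNode -> PreviewImage chain can be inspected directly from the JSON. A minimal sketch (path relative to the repository root; each entry in "links" is [id, src_node, src_slot, dst_node, dst_slot, type] as in the file above):

import json

with open("workflows/Depth_Map_Generator_V1.json") as f:
    graph = json.load(f)

for node in graph["nodes"]:
    print(node["id"], node["type"])

for link_id, src_node, src_slot, dst_node, dst_slot, link_type in graph["links"]:
    print(f"link {link_id}: node {src_node}[{src_slot}] -> node {dst_node}[{dst_slot}] ({link_type})")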
/workflows/Depth_Map_Generator_V1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Limbicnation/ComfyUIDepthEstimation/1c3035351a2269874d6200073b3bb20ac61f2513/workflows/Depth_Map_Generator_V1.png
--------------------------------------------------------------------------------