├── .gitignore ├── LICENSE ├── README.md ├── assets ├── dog.png └── figures │ ├── badge-website.svg │ ├── dfm-cover.png │ ├── radio.png │ └── sota-comparison.jpg ├── checkpoints └── README.md ├── depthfm ├── __init__.py ├── dfm.py └── unet │ ├── __init__.py │ ├── attention.py │ ├── openaimodel.py │ └── util.py ├── environment.yml ├── inference.ipynb ├── inference.py └── requirements.txt /.gitignore: -------------------------------------------------------------------------------- 1 | *__pycache__* 2 | sandbox 3 | *.ckpt 4 | *-depth.png 5 | evaluation -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2024 CompVis - Computer Vision and Learning LMU Munich 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 |
2 | 3 |
5 | Ming Gui* · Johannes Schusterbauer* · Ulrich Prestel · Pingchuan Ma 6 |
7 | Dmytro Kotovenko · Olga Grebenkova · Stefan A. Baumann · Vincent Tao Hu · Björn Ommer 8 |
9 | 10 | CompVis Group @ LMU Munich 11 |
12 | 13 | AAAI 2025 Oral 14 |
15 |* equal contribution
16 | 17 | 18 | 19 | 20 | [Project page](https://depthfm.github.io) 21 | [arXiv](https://arxiv.org/abs/2403.13788) 22 | 23 | 24 | 25 | 26 | 27 | ## 📻 Overview 28 | 29 | We present **DepthFM**, a state-of-the-art, versatile, and fast monocular depth estimation model. DepthFM is efficient and can synthesize realistic depth maps within *a single inference* step. Beyond conventional depth estimation tasks, DepthFM also demonstrates state-of-the-art capabilities in downstream tasks such as depth inpainting and depth-conditional synthesis. 30 | 31 | With our work, we demonstrate the successful transfer of strong image priors from a foundation image synthesis diffusion model (Stable Diffusion v2-1) to a flow matching model. Instead of starting from noise, we map directly from the input image to the depth map. 32 | 33 | 34 | ## 🛠️ Setup 35 | 36 | This setup was tested with `Ubuntu 22.04.4 LTS`, `CUDA Version: 12.4`, and `Python 3.10.12`. 37 | 38 | First, clone the GitHub repository: 39 | 40 | ```bash 41 | git clone git@github.com:CompVis/depth-fm.git 42 | cd depth-fm 43 | ``` 44 | 45 | Then download the weights via 46 | 47 | ```bash 48 | wget https://ommer-lab.com/files/depthfm/depthfm-v1.ckpt -P checkpoints/ 49 | ``` 50 | 51 | Next, either set up a virtual environment and install all required packages with `pip` 52 | 53 | ```bash 54 | pip install -r requirements.txt 55 | ``` 56 | 57 | or, if you prefer `conda`, create the environment via 58 | 59 | ```bash 60 | conda env create -f environment.yml 61 | ``` 62 | 63 | Now you should be able to listen to DepthFM! 📻 🎶 64 | 65 | 66 | ## 🚀 Usage 67 | 68 | You can either use the notebook `inference.ipynb` or run the Python script `inference.py` as follows: 69 | 70 | ```bash 71 | python inference.py \ 72 | --num_steps 2 \ 73 | --ensemble_size 4 \ 74 | --img assets/dog.png \ 75 | --ckpt checkpoints/depthfm-v1.ckpt 76 | ``` 77 | 78 | The argument `--num_steps` allows you to set the number of function evaluations.
We find that our model already gives very good results with as few as one or two steps. Ensembling also improves performance; set the number of ensemble members via the `--ensemble_size` argument. Currently, the inference code only supports a batch size of one for ensembling. 79 | 80 | ## 📈 Results 81 | 82 | Our quantitative analysis shows that, despite being substantially more efficient, DepthFM performs on par with or even outperforms the current state-of-the-art generative depth estimator Marigold **zero-shot** on a range of benchmark datasets. Below you can find a quantitative comparison of DepthFM against other affine-invariant depth estimators on several benchmarks. 83 | 84 | 85 | 86 | 87 | 88 | ## Trend 89 | 90 | [Star History Chart](https://star-history.com/#CompVis/depth-fm&Date) 91 | 92 | 93 | 94 | 95 | ## 🎓 Citation 96 | 97 | Please cite our paper: 98 | 99 | ```bibtex 100 | @misc{gui2024depthfm, 101 | title={DepthFM: Fast Monocular Depth Estimation with Flow Matching}, 102 | author={Ming Gui and Johannes Schusterbauer and Ulrich Prestel and Pingchuan Ma and Dmytro Kotovenko and Olga Grebenkova and Stefan Andreas Baumann and Vincent Tao Hu and Björn Ommer}, 103 | year={2024}, 104 | eprint={2403.13788}, 105 | archivePrefix={arXiv}, 106 | primaryClass={cs.CV} 107 | } 108 | ``` 109 | -------------------------------------------------------------------------------- /assets/dog.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CompVis/depth-fm/da632c8118e41cc64ca191a13314e2db1a7722d9/assets/dog.png -------------------------------------------------------------------------------- /assets/figures/badge-website.svg: -------------------------------------------------------------------------------- 1 | 2 | 130 | -------------------------------------------------------------------------------- /assets/figures/dfm-cover.png: --------------------------------------------------------------------------------
https://raw.githubusercontent.com/CompVis/depth-fm/da632c8118e41cc64ca191a13314e2db1a7722d9/assets/figures/dfm-cover.png -------------------------------------------------------------------------------- /assets/figures/radio.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CompVis/depth-fm/da632c8118e41cc64ca191a13314e2db1a7722d9/assets/figures/radio.png -------------------------------------------------------------------------------- /assets/figures/sota-comparison.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CompVis/depth-fm/da632c8118e41cc64ca191a13314e2db1a7722d9/assets/figures/sota-comparison.jpg -------------------------------------------------------------------------------- /checkpoints/README.md: -------------------------------------------------------------------------------- 1 | Download the weights in this specific folder via 2 | 3 | ```bash 4 | wget https://ommer-lab.com/files/depthfm/depthfm-v1.ckpt 5 | ``` -------------------------------------------------------------------------------- /depthfm/__init__.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | sys.path.append(os.path.dirname(os.path.abspath(__file__))) 4 | from dfm import DepthFM 5 | from unet import UNetModel 6 | -------------------------------------------------------------------------------- /depthfm/dfm.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import einops 3 | import numpy as np 4 | import torch.nn as nn 5 | from torch import Tensor 6 | from functools import partial 7 | from torchdiffeq import odeint 8 | 9 | from unet import UNetModel 10 | from diffusers import AutoencoderKL 11 | 12 | 13 | def exists(val): 14 | return val is not None 15 | 16 | 17 | class DepthFM(nn.Module): 18 | def __init__(self, ckpt_path: str): 19 
| super().__init__() 20 | vae_id = "runwayml/stable-diffusion-v1-5" 21 | self.vae = AutoencoderKL.from_pretrained(vae_id, subfolder="vae") 22 | self.scale_factor = 0.18215 23 | 24 | # set with checkpoint 25 | ckpt = torch.load(ckpt_path, map_location="cpu") 26 | self.noising_step = ckpt['noising_step'] 27 | self.empty_text_embed = ckpt['empty_text_embedding'] 28 | self.model = UNetModel(**ckpt['ldm_hparams']) 29 | self.model.load_state_dict(ckpt['state_dict']) 30 | 31 | def ode_fn(self, t: Tensor, x: Tensor, **kwargs): 32 | if t.numel() == 1: 33 | t = t.expand(x.size(0)) 34 | return self.model(x=x, t=t, **kwargs) 35 | 36 | def generate(self, z: Tensor, num_steps: int = 4, n_intermediates: int = 0, **kwargs): 37 | """ 38 | ODE solving from z0 (ims) to z1 (depth). 39 | """ 40 | ode_kwargs = dict(method="euler", rtol=1e-5, atol=1e-5, options=dict(step_size=1.0 / num_steps)) 41 | 42 | # t specifies which intermediate times should the solver return 43 | # e.g. t = [0, 0.5, 1] means return the solution at t=0, t=0.5 and t=1 44 | # but it also specifies the number of steps for fixed step size methods 45 | t = torch.linspace(0, 1, n_intermediates + 2, device=z.device, dtype=z.dtype) 46 | # t = torch.tensor([0., 1.], device=z.device, dtype=z.dtype) 47 | 48 | # allow conditioning information for model 49 | ode_fn = partial(self.ode_fn, **kwargs) 50 | 51 | ode_results = odeint(ode_fn, z, t, **ode_kwargs) 52 | 53 | if n_intermediates > 0: 54 | return ode_results 55 | return ode_results[-1] 56 | 57 | def forward(self, ims: Tensor, num_steps: int = 4, ensemble_size: int = 1): 58 | """ 59 | Args: 60 | ims: Tensor of shape (b, 3, h, w) in range [-1, 1] 61 | Returns: 62 | depth: Tensor of shape (b, 1, h, w) in range [0, 1] 63 | """ 64 | if ensemble_size > 1: 65 | assert ims.shape[0] == 1, "Ensemble mode only supported with batch size 1" 66 | ims = ims.repeat(ensemble_size, 1, 1, 1) 67 | 68 | bs, dev = ims.shape[0], ims.device 69 | 70 | ims_z = self.encode(ims, 
sample_posterior=False) 71 | 72 | conditioning = torch.tensor(self.empty_text_embed).to(dev).repeat(bs, 1, 1) 73 | context = ims_z 74 | 75 | x_source = ims_z 76 | 77 | if self.noising_step > 0: 78 | x_source = q_sample(x_source, self.noising_step) 79 | 80 | # solve ODE 81 | depth_z = self.generate(x_source, num_steps=num_steps, context=context, context_ca=conditioning) 82 | 83 | depth = self.decode(depth_z) 84 | depth = depth.mean(dim=1, keepdim=True) 85 | 86 | if ensemble_size > 1: 87 | depth = depth.mean(dim=0, keepdim=True) 88 | 89 | # normalize depth maps to range [0, 1] 90 | depth = per_sample_min_max_normalization(depth.exp()) 91 | 92 | return depth 93 | 94 | @torch.no_grad() 95 | def predict_depth(self, ims: Tensor, num_steps: int = 4, ensemble_size: int = 1): 96 | """ Inference method for DepthFM. """ 97 | return self.forward(ims, num_steps, ensemble_size) 98 | 99 | @torch.no_grad() 100 | def encode(self, x: Tensor, sample_posterior: bool = True): 101 | posterior = self.vae.encode(x) 102 | if sample_posterior: 103 | z = posterior.latent_dist.sample() 104 | else: 105 | z = posterior.latent_dist.mode() 106 | # normalize latent code 107 | z = z * self.scale_factor 108 | return z 109 | 110 | @torch.no_grad() 111 | def decode(self, z: Tensor): 112 | z = 1.0 / self.scale_factor * z 113 | return self.vae.decode(z).sample 114 | 115 | 116 | def sigmoid(x): 117 | return 1 / (1 + np.exp(-x)) 118 | 119 | 120 | def cosine_log_snr(t, eps=0.00001): 121 | """ 122 | Returns log Signal-to-Noise ratio for time step t and image size 64 123 | eps: avoids taking the log of zero at t = 0 124 | """ 125 | return -2 * np.log(np.tan((np.pi * t) / 2) + eps) 126 | 127 | 128 | def cosine_alpha_bar(t): 129 | return sigmoid(cosine_log_snr(t)) 130 | 131 | 132 | def q_sample(x_start: torch.Tensor, t: int, noise: torch.Tensor = None, n_diffusion_timesteps: int = 1000): 133 | """ 134 | Diffuse the data for a given number of diffusion steps. In other 135 | words sample from q(x_t | x_0).
136 | """ 137 | dev = x_start.device 138 | dtype = x_start.dtype 139 | 140 | if noise is None: 141 | noise = torch.randn_like(x_start) 142 | 143 | alpha_bar_t = cosine_alpha_bar(t / n_diffusion_timesteps) 144 | alpha_bar_t = torch.tensor(alpha_bar_t).to(dev).to(dtype) 145 | 146 | return torch.sqrt(alpha_bar_t) * x_start + torch.sqrt(1 - alpha_bar_t) * noise 147 | 148 | 149 | def per_sample_min_max_normalization(x): 150 | """ Normalize each sample in a batch independently 151 | with min-max normalization to [0, 1] """ 152 | bs, *shape = x.shape 153 | x_ = einops.rearrange(x, "b ... -> b (...)") 154 | min_val = einops.reduce(x_, "b ... -> b", "min")[..., None] 155 | max_val = einops.reduce(x_, "b ... -> b", "max")[..., None] 156 | x_ = (x_ - min_val) / (max_val - min_val) 157 | return x_.reshape(bs, *shape) 158 | -------------------------------------------------------------------------------- /depthfm/unet/__init__.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | sys.path.append(os.path.dirname(os.path.abspath(__file__))) 4 | from openaimodel import UNetModel -------------------------------------------------------------------------------- /depthfm/unet/attention.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | from torch import nn 4 | from einops import rearrange 5 | from inspect import isfunction 6 | import torch.nn.functional as F 7 | from typing import Optional, Any 8 | 9 | from util import checkpoint 10 | 11 | 12 | try: 13 | import xformers 14 | import xformers.ops 15 | XFORMERS_IS_AVAILBLE = True 16 | except: 17 | print("WARNING: xformers is not available, inference might be slow.") 18 | XFORMERS_IS_AVAILBLE = False 19 | 20 | # CrossAttn precision handling 21 | import os 22 | 23 | _ATTN_PRECISION = os.environ.get("ATTN_PRECISION", "fp32") 24 | 25 | 26 | def exists(val): 27 | return val is not None 28 | 29 | 30 | def 
uniq(arr): 31 | return {el: True for el in arr}.keys() 32 | 33 | 34 | def default(val, d): 35 | if exists(val): 36 | return val 37 | return d() if isfunction(d) else d 38 | 39 | 40 | def max_neg_value(t): 41 | return -torch.finfo(t.dtype).max 42 | 43 | 44 | def init_(tensor): 45 | dim = tensor.shape[-1] 46 | std = 1 / math.sqrt(dim) 47 | tensor.uniform_(-std, std) 48 | return tensor 49 | 50 | 51 | # feedforward 52 | class GEGLU(nn.Module): 53 | def __init__(self, dim_in, dim_out): 54 | super().__init__() 55 | self.proj = nn.Linear(dim_in, dim_out * 2) 56 | 57 | def forward(self, x): 58 | x, gate = self.proj(x).chunk(2, dim=-1) 59 | return x * F.gelu(gate) 60 | 61 | 62 | class FeedForward(nn.Module): 63 | def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.0): 64 | super().__init__() 65 | inner_dim = int(dim * mult) 66 | dim_out = default(dim_out, dim) 67 | project_in = ( 68 | nn.Sequential(nn.Linear(dim, inner_dim), nn.GELU()) 69 | if not glu 70 | else GEGLU(dim, inner_dim) 71 | ) 72 | 73 | self.net = nn.Sequential( 74 | project_in, nn.Dropout(dropout), nn.Linear(inner_dim, dim_out) 75 | ) 76 | 77 | def forward(self, x): 78 | return self.net(x) 79 | 80 | 81 | def zero_module(module): 82 | """ 83 | Zero out the parameters of a module and return it. 
84 | """ 85 | for p in module.parameters(): 86 | p.detach().zero_() 87 | return module 88 | 89 | 90 | def Normalize(in_channels): 91 | return torch.nn.GroupNorm( 92 | num_groups=32, num_channels=in_channels, eps=1e-6, affine=True 93 | ) 94 | 95 | 96 | class SpatialSelfAttention(nn.Module): 97 | def __init__(self, in_channels): 98 | super().__init__() 99 | self.in_channels = in_channels 100 | 101 | self.norm = Normalize(in_channels) 102 | self.q = torch.nn.Conv2d( 103 | in_channels, in_channels, kernel_size=1, stride=1, padding=0 104 | ) 105 | self.k = torch.nn.Conv2d( 106 | in_channels, in_channels, kernel_size=1, stride=1, padding=0 107 | ) 108 | self.v = torch.nn.Conv2d( 109 | in_channels, in_channels, kernel_size=1, stride=1, padding=0 110 | ) 111 | self.proj_out = torch.nn.Conv2d( 112 | in_channels, in_channels, kernel_size=1, stride=1, padding=0 113 | ) 114 | 115 | def forward(self, x): 116 | h_ = x 117 | h_ = self.norm(h_) 118 | q = self.q(h_) 119 | k = self.k(h_) 120 | v = self.v(h_) 121 | 122 | # compute attention 123 | b, c, h, w = q.shape 124 | q = rearrange(q, "b c h w -> b (h w) c") 125 | k = rearrange(k, "b c h w -> b c (h w)") 126 | w_ = torch.einsum("bij,bjk->bik", q, k) 127 | 128 | w_ = w_ * (int(c) ** (-0.5)) 129 | w_ = torch.nn.functional.softmax(w_, dim=2) 130 | 131 | # attend to values 132 | v = rearrange(v, "b c h w -> b c (h w)") 133 | w_ = rearrange(w_, "b i j -> b j i") 134 | h_ = torch.einsum("bij,bjk->bik", v, w_) 135 | h_ = rearrange(h_, "b c (h w) -> b c h w", h=h) 136 | h_ = self.proj_out(h_) 137 | 138 | return x + h_ 139 | 140 | 141 | class CrossAttention(nn.Module): 142 | def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.0): 143 | super().__init__() 144 | inner_dim = dim_head * heads 145 | context_dim = default(context_dim, query_dim) 146 | 147 | self.dim_head = dim_head 148 | 149 | self.scale = dim_head**-0.5 150 | self.heads = heads 151 | 152 | self.to_q = nn.Linear(query_dim, inner_dim, bias=False) 153 
| self.to_k = nn.Linear(context_dim, inner_dim, bias=False) 154 | self.to_v = nn.Linear(context_dim, inner_dim, bias=False) 155 | 156 | self.to_out = nn.Sequential( 157 | nn.Linear(inner_dim, query_dim), nn.Dropout(dropout) 158 | ) 159 | 160 | def forward(self, x, context=None, mask=None, rescale_attention=True): 161 | 162 | is_self_attention = context is None 163 | 164 | n_tokens = x.shape[1] 165 | 166 | h = self.heads 167 | 168 | q = self.to_q(x) 169 | context = default(context, x) 170 | k = self.to_k(context) 171 | v = self.to_v(context) 172 | 173 | q, k, v = map(lambda t: rearrange(t, "b n (h d) -> (b h) n d", h=h), (q, k, v)) 174 | 175 | if rescale_attention: 176 | out = F.scaled_dot_product_attention(q, k, v, scale=(math.log(n_tokens) / math.log(n_tokens*4) / self.dim_head)**0.5 if is_self_attention else None) 177 | else: 178 | out = F.scaled_dot_product_attention(q, k, v) 179 | 180 | out = rearrange(out, "(b h) n d -> b n (h d)", h=h) 181 | return self.to_out(out) 182 | 183 | 184 | class MemoryEfficientCrossAttention(nn.Module): 185 | # https://github.com/MatthieuTPHR/diffusers/blob/d80b531ff8060ec1ea982b65a1b8df70f73aa67c/src/diffusers/models/attention.py#L223 186 | def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.0): 187 | super().__init__() 188 | # print( 189 | # f"Setting up {self.__class__.__name__}. Query dim is {query_dim}, context_dim is {context_dim} and using " 190 | # f"{heads} heads." 
191 | # ) 192 | inner_dim = dim_head * heads 193 | context_dim = default(context_dim, query_dim) 194 | 195 | self.heads = heads 196 | self.dim_head = dim_head 197 | 198 | self.to_q = nn.Linear(query_dim, inner_dim, bias=False) 199 | self.to_k = nn.Linear(context_dim, inner_dim, bias=False) 200 | self.to_v = nn.Linear(context_dim, inner_dim, bias=False) 201 | 202 | self.to_out = nn.Sequential( 203 | nn.Linear(inner_dim, query_dim), nn.Dropout(dropout) 204 | ) 205 | self.attention_op: Optional[Any] = None 206 | 207 | def forward(self, x, context=None, mask=None): 208 | q = self.to_q(x) 209 | context = default(context, x) 210 | k = self.to_k(context) 211 | v = self.to_v(context) 212 | 213 | b, _, _ = q.shape 214 | q, k, v = map( 215 | lambda t: t.unsqueeze(3) 216 | .reshape(b, t.shape[1], self.heads, self.dim_head) 217 | .permute(0, 2, 1, 3) 218 | .reshape(b * self.heads, t.shape[1], self.dim_head) 219 | .contiguous(), 220 | (q, k, v), 221 | ) 222 | 223 | # actually compute the attention, what we cannot get enough of 224 | out = xformers.ops.memory_efficient_attention( 225 | q, k, v, attn_bias=None, op=self.attention_op 226 | ) 227 | 228 | if exists(mask): 229 | raise NotImplementedError 230 | out = ( 231 | out.unsqueeze(0) 232 | .reshape(b, self.heads, out.shape[1], self.dim_head) 233 | .permute(0, 2, 1, 3) 234 | .reshape(b, out.shape[1], self.heads * self.dim_head) 235 | ) 236 | return self.to_out(out) 237 | 238 | 239 | class BasicTransformerBlock(nn.Module): 240 | ATTENTION_MODES = { 241 | "softmax": CrossAttention, # vanilla attention 242 | "softmax-xformers": MemoryEfficientCrossAttention, 243 | } 244 | 245 | def __init__( 246 | self, 247 | dim, 248 | n_heads, 249 | d_head, 250 | dropout=0.0, 251 | context_dim=None, 252 | gated_ff=True, 253 | checkpoint=True, 254 | disable_self_attn=False, 255 | ): 256 | super().__init__() 257 | attn_mode = "softmax-xformers" if XFORMERS_IS_AVAILBLE else "softmax" 258 | assert attn_mode in self.ATTENTION_MODES 259 | attn_cls = 
self.ATTENTION_MODES[attn_mode] 260 | self.disable_self_attn = disable_self_attn 261 | self.attn1 = attn_cls( 262 | query_dim=dim, 263 | heads=n_heads, 264 | dim_head=d_head, 265 | dropout=dropout, 266 | context_dim=context_dim if self.disable_self_attn else None, 267 | ) # is a self-attention if not self.disable_self_attn 268 | self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff) 269 | self.attn2 = attn_cls( 270 | query_dim=dim, 271 | context_dim=context_dim, 272 | heads=n_heads, 273 | dim_head=d_head, 274 | dropout=dropout, 275 | ) # is self-attn if context is none 276 | self.norm1 = nn.LayerNorm(dim) 277 | self.norm2 = nn.LayerNorm(dim) 278 | self.norm3 = nn.LayerNorm(dim) 279 | self.checkpoint = checkpoint 280 | 281 | def forward(self, x, context=None): 282 | return checkpoint( 283 | self._forward, (x, context), self.parameters(), self.checkpoint 284 | ) 285 | 286 | def _forward(self, x, context=None): 287 | x = ( 288 | self.attn1( 289 | self.norm1(x), context=context if self.disable_self_attn else None 290 | ) 291 | + x 292 | ) 293 | x = self.attn2(self.norm2(x), context=context) + x 294 | x = self.ff(self.norm3(x)) + x 295 | return x 296 | 297 | 298 | class SpatialTransformer(nn.Module): 299 | """ 300 | Transformer block for image-like data. 301 | First, project the input (aka embedding) 302 | and reshape to b, t, d. 303 | Then apply standard transformer action. 
304 | Finally, reshape to image 305 | NEW: use_linear for more efficiency instead of the 1x1 convs 306 | """ 307 | 308 | def __init__( 309 | self, 310 | in_channels, 311 | n_heads, 312 | d_head, 313 | depth=1, 314 | dropout=0.0, 315 | context_dim=None, 316 | disable_self_attn=False, 317 | use_linear=False, 318 | use_checkpoint=True, 319 | ): 320 | super().__init__() 321 | if exists(context_dim) and not isinstance(context_dim, list): 322 | context_dim = [context_dim] 323 | self.in_channels = in_channels 324 | inner_dim = n_heads * d_head 325 | self.norm = Normalize(in_channels) 326 | if not use_linear: 327 | self.proj_in = nn.Conv2d( 328 | in_channels, inner_dim, kernel_size=1, stride=1, padding=0 329 | ) 330 | else: 331 | self.proj_in = nn.Linear(in_channels, inner_dim) 332 | 333 | self.transformer_blocks = nn.ModuleList( 334 | [ 335 | BasicTransformerBlock( 336 | inner_dim, 337 | n_heads, 338 | d_head, 339 | dropout=dropout, 340 | context_dim=context_dim[d], 341 | disable_self_attn=disable_self_attn, 342 | checkpoint=use_checkpoint, 343 | ) 344 | for d in range(depth) 345 | ] 346 | ) 347 | if not use_linear: 348 | self.proj_out = zero_module( 349 | nn.Conv2d(inner_dim, in_channels, kernel_size=1, stride=1, padding=0) 350 | ) 351 | else: 352 | self.proj_out = zero_module(nn.Linear(inner_dim, in_channels)) 353 | self.use_linear = use_linear 354 | 355 | def forward(self, x, context=None): 356 | # note: if no context is given, cross-attention defaults to self-attention 357 | if not isinstance(context, list): 358 | context = [context] 359 | b, c, h, w = x.shape 360 | x_in = x 361 | x = self.norm(x) 362 | if not self.use_linear: 363 | x = self.proj_in(x) 364 | x = rearrange(x, "b c h w -> b (h w) c").contiguous() 365 | if self.use_linear: 366 | x = self.proj_in(x) 367 | for i, block in enumerate(self.transformer_blocks): 368 | x = block(x, context=context[i]) 369 | if self.use_linear: 370 | x = self.proj_out(x) 371 | x = rearrange(x, "b (h w) c -> b c h w", h=h,
w=w).contiguous() 372 | if not self.use_linear: 373 | x = self.proj_out(x) 374 | return x + x_in 375 | -------------------------------------------------------------------------------- /depthfm/unet/openaimodel.py: -------------------------------------------------------------------------------- 1 | import math 2 | import numpy as np 3 | import torch as th 4 | import torch.nn as nn 5 | from abc import abstractmethod 6 | import torch.nn.functional as F 7 | 8 | from util import ( 9 | checkpoint, 10 | conv_nd, 11 | linear, 12 | avg_pool_nd, 13 | zero_module, 14 | normalization, 15 | timestep_embedding, 16 | ) 17 | from attention import SpatialTransformer 18 | 19 | 20 | def exists(x): 21 | return x is not None 22 | 23 | # dummy replace 24 | def convert_module_to_f16(x): 25 | pass 26 | 27 | def convert_module_to_f32(x): 28 | pass 29 | 30 | 31 | ## go 32 | class AttentionPool2d(nn.Module): 33 | """ 34 | Adapted from CLIP: https://github.com/openai/CLIP/blob/main/clip/model.py 35 | """ 36 | 37 | def __init__( 38 | self, 39 | spacial_dim: int, 40 | embed_dim: int, 41 | num_heads_channels: int, 42 | output_dim: int = None, 43 | ): 44 | super().__init__() 45 | self.positional_embedding = nn.Parameter(th.randn(embed_dim, spacial_dim ** 2 + 1) / embed_dim ** 0.5) 46 | self.qkv_proj = conv_nd(1, embed_dim, 3 * embed_dim, 1) 47 | self.c_proj = conv_nd(1, embed_dim, output_dim or embed_dim, 1) 48 | self.num_heads = embed_dim // num_heads_channels 49 | self.attention = QKVAttention(self.num_heads) 50 | 51 | def forward(self, x): 52 | b, c, *_spatial = x.shape 53 | x = x.reshape(b, c, -1) # NC(HW) 54 | x = th.cat([x.mean(dim=-1, keepdim=True), x], dim=-1) # NC(HW+1) 55 | x = x + self.positional_embedding[None, :, :].to(x.dtype) # NC(HW+1) 56 | x = self.qkv_proj(x) 57 | x = self.attention(x) 58 | x = self.c_proj(x) 59 | return x[:, :, 0] 60 | 61 | 62 | class TimestepBlock(nn.Module): 63 | """ 64 | Any module where forward() takes timestep embeddings as a second argument. 
65 | """ 66 | 67 | @abstractmethod 68 | def forward(self, x, emb): 69 | """ 70 | Apply the module to `x` given `emb` timestep embeddings. 71 | """ 72 | 73 | 74 | class TimestepEmbedSequential(nn.Sequential, TimestepBlock): 75 | """ 76 | A sequential module that passes timestep embeddings to the children that 77 | support it as an extra input. 78 | """ 79 | 80 | def forward(self, x, emb, context=None): 81 | for layer in self: 82 | if isinstance(layer, TimestepBlock): 83 | x = layer(x, emb) 84 | elif isinstance(layer, SpatialTransformer): 85 | x = layer(x, context) 86 | else: 87 | x = layer(x) 88 | return x 89 | 90 | 91 | class Upsample(nn.Module): 92 | """ 93 | An upsampling layer with an optional convolution. 94 | :param channels: channels in the inputs and outputs. 95 | :param use_conv: a bool determining if a convolution is applied. 96 | :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then 97 | upsampling occurs in the inner-two dimensions. 98 | """ 99 | 100 | def __init__(self, channels, use_conv, dims=2, out_channels=None, padding=1): 101 | super().__init__() 102 | self.channels = channels 103 | self.out_channels = out_channels or channels 104 | self.use_conv = use_conv 105 | self.dims = dims 106 | if use_conv: 107 | self.conv = conv_nd(dims, self.channels, self.out_channels, 3, padding=padding) 108 | 109 | def forward(self, x): 110 | assert x.shape[1] == self.channels 111 | if self.dims == 3: 112 | x = F.interpolate( 113 | x, (x.shape[2], x.shape[3] * 2, x.shape[4] * 2), mode="nearest" 114 | ) 115 | else: 116 | x = F.interpolate(x, scale_factor=2, mode="nearest") 117 | if self.use_conv: 118 | x = self.conv(x) 119 | return x 120 | 121 | class TransposedUpsample(nn.Module): 122 | 'Learned 2x upsampling without padding' 123 | def __init__(self, channels, out_channels=None, ks=5): 124 | super().__init__() 125 | self.channels = channels 126 | self.out_channels = out_channels or channels 127 | 128 | self.up = 
nn.ConvTranspose2d(self.channels,self.out_channels,kernel_size=ks,stride=2) 129 | 130 | def forward(self,x): 131 | return self.up(x) 132 | 133 | 134 | class Downsample(nn.Module): 135 | """ 136 | A downsampling layer with an optional convolution. 137 | :param channels: channels in the inputs and outputs. 138 | :param use_conv: a bool determining if a convolution is applied. 139 | :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then 140 | downsampling occurs in the inner-two dimensions. 141 | """ 142 | 143 | def __init__(self, channels, use_conv, dims=2, out_channels=None,padding=1): 144 | super().__init__() 145 | self.channels = channels 146 | self.out_channels = out_channels or channels 147 | self.use_conv = use_conv 148 | self.dims = dims 149 | stride = 2 if dims != 3 else (1, 2, 2) 150 | if use_conv: 151 | self.op = conv_nd( 152 | dims, self.channels, self.out_channels, 3, stride=stride, padding=padding 153 | ) 154 | else: 155 | assert self.channels == self.out_channels 156 | self.op = avg_pool_nd(dims, kernel_size=stride, stride=stride) 157 | 158 | def forward(self, x): 159 | assert x.shape[1] == self.channels 160 | return self.op(x) 161 | 162 | 163 | class ResBlock(TimestepBlock): 164 | """ 165 | A residual block that can optionally change the number of channels. 166 | :param channels: the number of input channels. 167 | :param emb_channels: the number of timestep embedding channels. 168 | :param dropout: the rate of dropout. 169 | :param out_channels: if specified, the number of out channels. 170 | :param use_conv: if True and out_channels is specified, use a spatial 171 | convolution instead of a smaller 1x1 convolution to change the 172 | channels in the skip connection. 173 | :param dims: determines if the signal is 1D, 2D, or 3D. 174 | :param use_checkpoint: if True, use gradient checkpointing on this module. 175 | :param up: if True, use this block for upsampling. 176 | :param down: if True, use this block for downsampling. 
177 | """ 178 | 179 | def __init__( 180 | self, 181 | channels, 182 | emb_channels, 183 | dropout, 184 | out_channels=None, 185 | use_conv=False, 186 | use_scale_shift_norm=False, 187 | dims=2, 188 | use_checkpoint=False, 189 | up=False, 190 | down=False, 191 | ): 192 | super().__init__() 193 | self.channels = channels 194 | self.emb_channels = emb_channels 195 | self.dropout = dropout 196 | self.out_channels = out_channels or channels 197 | self.use_conv = use_conv 198 | self.use_checkpoint = use_checkpoint 199 | self.use_scale_shift_norm = use_scale_shift_norm 200 | 201 | self.in_layers = nn.Sequential( 202 | normalization(channels), 203 | nn.SiLU(), 204 | conv_nd(dims, channels, self.out_channels, 3, padding=1), 205 | ) 206 | 207 | self.updown = up or down 208 | 209 | if up: 210 | self.h_upd = Upsample(channels, False, dims) 211 | self.x_upd = Upsample(channels, False, dims) 212 | elif down: 213 | self.h_upd = Downsample(channels, False, dims) 214 | self.x_upd = Downsample(channels, False, dims) 215 | else: 216 | self.h_upd = self.x_upd = nn.Identity() 217 | 218 | self.emb_layers = nn.Sequential( 219 | nn.SiLU(), 220 | linear( 221 | emb_channels, 222 | 2 * self.out_channels if use_scale_shift_norm else self.out_channels, 223 | ), 224 | ) 225 | self.out_layers = nn.Sequential( 226 | normalization(self.out_channels), 227 | nn.SiLU(), 228 | nn.Dropout(p=dropout), 229 | zero_module( 230 | conv_nd(dims, self.out_channels, self.out_channels, 3, padding=1) 231 | ), 232 | ) 233 | 234 | if self.out_channels == channels: 235 | self.skip_connection = nn.Identity() 236 | elif use_conv: 237 | self.skip_connection = conv_nd( 238 | dims, channels, self.out_channels, 3, padding=1 239 | ) 240 | else: 241 | self.skip_connection = conv_nd(dims, channels, self.out_channels, 1) 242 | 243 | def forward(self, x, emb): 244 | """ 245 | Apply the block to a Tensor, conditioned on a timestep embedding. 246 | :param x: an [N x C x ...] Tensor of features. 
247 | :param emb: an [N x emb_channels] Tensor of timestep embeddings. 248 | :return: an [N x C x ...] Tensor of outputs. 249 | """ 250 | return checkpoint( 251 | self._forward, (x, emb), self.parameters(), self.use_checkpoint 252 | ) 253 | 254 | 255 | def _forward(self, x, emb): 256 | if self.updown: 257 | in_rest, in_conv = self.in_layers[:-1], self.in_layers[-1] 258 | h = in_rest(x) 259 | h = self.h_upd(h) 260 | x = self.x_upd(x) 261 | h = in_conv(h) 262 | else: 263 | h = self.in_layers(x) 264 | emb_out = self.emb_layers(emb).type(h.dtype) 265 | while len(emb_out.shape) < len(h.shape): 266 | emb_out = emb_out[..., None] 267 | if self.use_scale_shift_norm: 268 | out_norm, out_rest = self.out_layers[0], self.out_layers[1:] 269 | scale, shift = th.chunk(emb_out, 2, dim=1) 270 | h = out_norm(h) * (1 + scale) + shift 271 | h = out_rest(h) 272 | else: 273 | h = h + emb_out 274 | h = self.out_layers(h) 275 | return self.skip_connection(x) + h 276 | 277 | 278 | class AttentionBlock(nn.Module): 279 | """ 280 | An attention block that allows spatial positions to attend to each other. 281 | Originally ported from here, but adapted to the N-d case. 282 | https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/models/unet.py#L66. 
283 | """ 284 | 285 | def __init__( 286 | self, 287 | channels, 288 | num_heads=1, 289 | num_head_channels=-1, 290 | use_checkpoint=False, 291 | use_new_attention_order=False, 292 | ): 293 | super().__init__() 294 | self.channels = channels 295 | if num_head_channels == -1: 296 | self.num_heads = num_heads 297 | else: 298 | assert ( 299 | channels % num_head_channels == 0 300 | ), f"q,k,v channels {channels} is not divisible by num_head_channels {num_head_channels}" 301 | self.num_heads = channels // num_head_channels 302 | self.use_checkpoint = use_checkpoint 303 | self.norm = normalization(channels) 304 | self.qkv = conv_nd(1, channels, channels * 3, 1) 305 | if use_new_attention_order: 306 | # split qkv before split heads 307 | self.attention = QKVAttention(self.num_heads) 308 | else: 309 | # split heads before split qkv 310 | self.attention = QKVAttentionLegacy(self.num_heads) 311 | 312 | self.proj_out = zero_module(conv_nd(1, channels, channels, 1)) 313 | 314 | def forward(self, x): 315 | return checkpoint(self._forward, (x,), self.parameters(), True) # TODO: check checkpoint usage, is True # TODO: fix the .half call!!! 316 | #return pt_checkpoint(self._forward, x) # pytorch 317 | 318 | def _forward(self, x): 319 | b, c, *spatial = x.shape 320 | x = x.reshape(b, c, -1) 321 | qkv = self.qkv(self.norm(x)) 322 | h = self.attention(qkv) 323 | h = self.proj_out(h) 324 | return (x + h).reshape(b, c, *spatial) 325 | 326 | 327 | def count_flops_attn(model, _x, y): 328 | """ 329 | A counter for the `thop` package to count the operations in an 330 | attention operation. 331 | Meant to be used like: 332 | macs, params = thop.profile( 333 | model, 334 | inputs=(inputs, timestamps), 335 | custom_ops={QKVAttention: QKVAttention.count_flops}, 336 | ) 337 | """ 338 | b, c, *spatial = y[0].shape 339 | num_spatial = int(np.prod(spatial)) 340 | # We perform two matmuls with the same number of ops. 
341 | # The first computes the weight matrix, the second computes 342 | # the combination of the value vectors. 343 | matmul_ops = 2 * b * (num_spatial ** 2) * c 344 | model.total_ops += th.DoubleTensor([matmul_ops]) 345 | 346 | 347 | class QKVAttentionLegacy(nn.Module): 348 | """ 349 | A module which performs QKV attention. Matches legacy QKVAttention + input/output heads shaping. 350 | """ 351 | 352 | def __init__(self, n_heads): 353 | super().__init__() 354 | self.n_heads = n_heads 355 | 356 | def forward(self, qkv): 357 | """ 358 | Apply QKV attention. 359 | :param qkv: an [N x (H * 3 * C) x T] tensor of Qs, Ks, and Vs. 360 | :return: an [N x (H * C) x T] tensor after attention. 361 | """ 362 | bs, width, length = qkv.shape 363 | assert width % (3 * self.n_heads) == 0 364 | ch = width // (3 * self.n_heads) 365 | q, k, v = qkv.reshape(bs * self.n_heads, ch * 3, length).split(ch, dim=1) 366 | scale = 1 / math.sqrt(math.sqrt(ch)) 367 | weight = th.einsum( 368 | "bct,bcs->bts", q * scale, k * scale 369 | ) # More stable with f16 than dividing afterwards 370 | weight = th.softmax(weight.float(), dim=-1).type(weight.dtype) 371 | a = th.einsum("bts,bcs->bct", weight, v) 372 | return a.reshape(bs, -1, length) 373 | 374 | @staticmethod 375 | def count_flops(model, _x, y): 376 | return count_flops_attn(model, _x, y) 377 | 378 | 379 | class QKVAttention(nn.Module): 380 | """ 381 | A module which performs QKV attention and splits in a different order. 382 | """ 383 | 384 | def __init__(self, n_heads): 385 | super().__init__() 386 | self.n_heads = n_heads 387 | 388 | def forward(self, qkv): 389 | """ 390 | Apply QKV attention. 391 | :param qkv: an [N x (3 * H * C) x T] tensor of Qs, Ks, and Vs. 392 | :return: an [N x (H * C) x T] tensor after attention. 
393 | """ 394 | bs, width, length = qkv.shape 395 | assert width % (3 * self.n_heads) == 0 396 | ch = width // (3 * self.n_heads) 397 | q, k, v = qkv.chunk(3, dim=1) 398 | scale = 1 / math.sqrt(math.sqrt(ch)) 399 | weight = th.einsum( 400 | "bct,bcs->bts", 401 | (q * scale).view(bs * self.n_heads, ch, length), 402 | (k * scale).view(bs * self.n_heads, ch, length), 403 | ) # More stable with f16 than dividing afterwards 404 | weight = th.softmax(weight.float(), dim=-1).type(weight.dtype) 405 | a = th.einsum("bts,bcs->bct", weight, v.reshape(bs * self.n_heads, ch, length)) 406 | return a.reshape(bs, -1, length) 407 | 408 | @staticmethod 409 | def count_flops(model, _x, y): 410 | return count_flops_attn(model, _x, y) 411 | 412 | 413 | class Timestep(nn.Module): 414 | def __init__(self, dim): 415 | super().__init__() 416 | self.dim = dim 417 | 418 | def forward(self, t): 419 | return timestep_embedding(t, self.dim) 420 | 421 | 422 | class UNetModel(nn.Module): 423 | """ 424 | The full UNet model with attention and timestep embedding. 425 | :param in_channels: channels in the input Tensor. 426 | :param model_channels: base channel count for the model. 427 | :param out_channels: channels in the output Tensor. 428 | :param num_res_blocks: number of residual blocks per downsample. 429 | :param attention_resolutions: a collection of downsample rates at which 430 | attention will take place. May be a set, list, or tuple. 431 | For example, if this contains 4, then at 4x downsampling, attention 432 | will be used. 433 | :param dropout: the dropout probability. 434 | :param channel_mult: channel multiplier for each level of the UNet. 435 | :param conv_resample: if True, use learned convolutions for upsampling and 436 | downsampling. 437 | :param dims: determines if the signal is 1D, 2D, or 3D. 438 | :param num_classes: if specified (as an int), then this model will be 439 | class-conditional with `num_classes` classes. 
440 | :param use_checkpoint: use gradient checkpointing to reduce memory usage. 441 | :param num_heads: the number of attention heads in each attention layer. 442 | :param num_head_channels: if specified, ignore num_heads and instead use 443 | a fixed channel width per attention head. 444 | :param num_heads_upsample: works with num_heads to set a different number 445 | of heads for upsampling. Deprecated. 446 | :param use_scale_shift_norm: use a FiLM-like conditioning mechanism. 447 | :param resblock_updown: use residual blocks for up/downsampling. 448 | :param use_new_attention_order: use a different attention pattern for potentially 449 | increased efficiency. 450 | """ 451 | 452 | def __init__( 453 | self, 454 | image_size, 455 | in_channels, 456 | model_channels, 457 | out_channels, 458 | num_res_blocks, 459 | attention_resolutions, 460 | dropout=0, 461 | channel_mult=(1, 2, 4, 8), 462 | conv_resample=True, 463 | dims=2, 464 | num_classes=None, 465 | use_checkpoint=False, 466 | use_fp16=False, 467 | use_bf16=False, 468 | num_heads=-1, 469 | num_head_channels=-1, 470 | num_heads_upsample=-1, 471 | use_scale_shift_norm=False, 472 | resblock_updown=False, 473 | use_new_attention_order=False, 474 | use_spatial_transformer=False, # custom transformer support 475 | transformer_depth=1, # custom transformer support 476 | context_dim=None, # custom transformer support 477 | n_embed=None, # custom support for prediction of discrete ids into codebook of first stage vq model 478 | legacy=True, 479 | disable_self_attentions=None, 480 | num_attention_blocks=None, 481 | disable_middle_self_attn=False, 482 | use_linear_in_transformer=False, 483 | adm_in_channels=None, 484 | load_from_ckpt=None, 485 | ): 486 | super().__init__() 487 | if use_spatial_transformer: 488 | assert context_dim is not None, 'Fool!! You forgot to include the dimension of your cross-attention conditioning...' 489 | 490 | if context_dim is not None: 491 | assert use_spatial_transformer, 'Fool!! 
You forgot to use the spatial transformer for your cross-attention conditioning...' 492 | from omegaconf.listconfig import ListConfig 493 | if type(context_dim) == ListConfig: 494 | context_dim = list(context_dim) 495 | 496 | if num_heads_upsample == -1: 497 | num_heads_upsample = num_heads 498 | 499 | if num_heads == -1: 500 | assert num_head_channels != -1, 'Either num_heads or num_head_channels has to be set' 501 | 502 | if num_head_channels == -1: 503 | assert num_heads != -1, 'Either num_heads or num_head_channels has to be set' 504 | 505 | self.image_size = image_size 506 | self.in_channels = in_channels 507 | self.model_channels = model_channels 508 | self.out_channels = out_channels 509 | if isinstance(num_res_blocks, int): 510 | self.num_res_blocks = len(channel_mult) * [num_res_blocks] 511 | else: 512 | if len(num_res_blocks) != len(channel_mult): 513 | raise ValueError("provide num_res_blocks either as an int (globally constant) or " 514 | "as a list/tuple (per-level) with the same length as channel_mult") 515 | self.num_res_blocks = num_res_blocks 516 | if disable_self_attentions is not None: 517 | # should be a list of booleans, indicating whether to disable self-attention in TransformerBlocks or not 518 | assert len(disable_self_attentions) == len(channel_mult) 519 | if num_attention_blocks is not None: 520 | assert len(num_attention_blocks) == len(self.num_res_blocks) 521 | assert all(map(lambda i: self.num_res_blocks[i] >= num_attention_blocks[i], range(len(num_attention_blocks)))) 522 | print(f"Constructor of UNetModel received num_attention_blocks={num_attention_blocks}. 
" 523 | f"This option has LESS priority than attention_resolutions {attention_resolutions}, " 524 | f"i.e., in cases where num_attention_blocks[i] > 0 but 2**i not in attention_resolutions, " 525 | f"attention will still not be set.") 526 | 527 | self.attention_resolutions = attention_resolutions 528 | self.dropout = dropout 529 | self.channel_mult = channel_mult 530 | self.conv_resample = conv_resample 531 | self.num_classes = num_classes 532 | self.use_checkpoint = use_checkpoint 533 | self.dtype = th.float16 if use_fp16 else th.float32 534 | self.dtype = th.bfloat16 if use_bf16 else self.dtype 535 | self.num_heads = num_heads 536 | self.num_head_channels = num_head_channels 537 | self.num_heads_upsample = num_heads_upsample 538 | self.predict_codebook_ids = n_embed is not None 539 | 540 | time_embed_dim = model_channels * 4 541 | self.time_embed = nn.Sequential( 542 | linear(model_channels, time_embed_dim), 543 | nn.SiLU(), 544 | linear(time_embed_dim, time_embed_dim), 545 | ) 546 | 547 | if self.num_classes is not None: 548 | if isinstance(self.num_classes, int): 549 | self.label_emb = nn.Embedding(num_classes, time_embed_dim) 550 | elif self.num_classes == "continuous": 551 | print("setting up linear c_adm embedding layer") 552 | self.label_emb = nn.Linear(1, time_embed_dim) 553 | elif self.num_classes == "sequential": 554 | assert adm_in_channels is not None 555 | self.label_emb = nn.Sequential( 556 | nn.Sequential( 557 | linear(adm_in_channels, time_embed_dim), 558 | nn.SiLU(), 559 | linear(time_embed_dim, time_embed_dim), 560 | ) 561 | ) 562 | else: 563 | raise ValueError() 564 | 565 | self.input_blocks = nn.ModuleList( 566 | [ 567 | TimestepEmbedSequential( 568 | conv_nd(dims, in_channels, model_channels, 3, padding=1) 569 | ) 570 | ] 571 | ) 572 | self._feature_size = model_channels 573 | input_block_chans = [model_channels] 574 | ch = model_channels 575 | ds = 1 576 | for level, mult in enumerate(channel_mult): 577 | for nr in 
range(self.num_res_blocks[level]): 578 | layers = [ 579 | ResBlock( 580 | ch, 581 | time_embed_dim, 582 | dropout, 583 | out_channels=mult * model_channels, 584 | dims=dims, 585 | use_checkpoint=use_checkpoint, 586 | use_scale_shift_norm=use_scale_shift_norm, 587 | ) 588 | ] 589 | ch = mult * model_channels 590 | if ds in attention_resolutions: 591 | if num_head_channels == -1: 592 | dim_head = ch // num_heads 593 | else: 594 | num_heads = ch // num_head_channels 595 | dim_head = num_head_channels 596 | if legacy: 597 | #num_heads = 1 598 | dim_head = ch // num_heads if use_spatial_transformer else num_head_channels 599 | if exists(disable_self_attentions): 600 | disabled_sa = disable_self_attentions[level] 601 | else: 602 | disabled_sa = False 603 | 604 | if not exists(num_attention_blocks) or nr < num_attention_blocks[level]: 605 | layers.append( 606 | AttentionBlock( 607 | ch, 608 | use_checkpoint=use_checkpoint, 609 | num_heads=num_heads, 610 | num_head_channels=dim_head, 611 | use_new_attention_order=use_new_attention_order, 612 | ) if not use_spatial_transformer else SpatialTransformer( 613 | ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim, 614 | disable_self_attn=disabled_sa, use_linear=use_linear_in_transformer, 615 | use_checkpoint=use_checkpoint 616 | ) 617 | ) 618 | self.input_blocks.append(TimestepEmbedSequential(*layers)) 619 | self._feature_size += ch 620 | input_block_chans.append(ch) 621 | if level != len(channel_mult) - 1: 622 | out_ch = ch 623 | self.input_blocks.append( 624 | TimestepEmbedSequential( 625 | ResBlock( 626 | ch, 627 | time_embed_dim, 628 | dropout, 629 | out_channels=out_ch, 630 | dims=dims, 631 | use_checkpoint=use_checkpoint, 632 | use_scale_shift_norm=use_scale_shift_norm, 633 | down=True, 634 | ) 635 | if resblock_updown 636 | else Downsample( 637 | ch, conv_resample, dims=dims, out_channels=out_ch 638 | ) 639 | ) 640 | ) 641 | ch = out_ch 642 | input_block_chans.append(ch) 643 | ds *= 2 644 | 
self._feature_size += ch 645 | 646 | if num_head_channels == -1: 647 | dim_head = ch // num_heads 648 | else: 649 | num_heads = ch // num_head_channels 650 | dim_head = num_head_channels 651 | if legacy: 652 | #num_heads = 1 653 | dim_head = ch // num_heads if use_spatial_transformer else num_head_channels 654 | self.middle_block = TimestepEmbedSequential( 655 | ResBlock( 656 | ch, 657 | time_embed_dim, 658 | dropout, 659 | dims=dims, 660 | use_checkpoint=use_checkpoint, 661 | use_scale_shift_norm=use_scale_shift_norm, 662 | ), 663 | AttentionBlock( 664 | ch, 665 | use_checkpoint=use_checkpoint, 666 | num_heads=num_heads, 667 | num_head_channels=dim_head, 668 | use_new_attention_order=use_new_attention_order, 669 | ) if not use_spatial_transformer else SpatialTransformer( # always uses a self-attn 670 | ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim, 671 | disable_self_attn=disable_middle_self_attn, use_linear=use_linear_in_transformer, 672 | use_checkpoint=use_checkpoint 673 | ), 674 | ResBlock( 675 | ch, 676 | time_embed_dim, 677 | dropout, 678 | dims=dims, 679 | use_checkpoint=use_checkpoint, 680 | use_scale_shift_norm=use_scale_shift_norm, 681 | ), 682 | ) 683 | self._feature_size += ch 684 | 685 | self.output_blocks = nn.ModuleList([]) 686 | for level, mult in list(enumerate(channel_mult))[::-1]: 687 | for i in range(self.num_res_blocks[level] + 1): 688 | ich = input_block_chans.pop() 689 | layers = [ 690 | ResBlock( 691 | ch + ich, 692 | time_embed_dim, 693 | dropout, 694 | out_channels=model_channels * mult, 695 | dims=dims, 696 | use_checkpoint=use_checkpoint, 697 | use_scale_shift_norm=use_scale_shift_norm, 698 | ) 699 | ] 700 | ch = model_channels * mult 701 | if ds in attention_resolutions: 702 | if num_head_channels == -1: 703 | dim_head = ch // num_heads 704 | else: 705 | num_heads = ch // num_head_channels 706 | dim_head = num_head_channels 707 | if legacy: 708 | #num_heads = 1 709 | dim_head = ch // num_heads if 
use_spatial_transformer else num_head_channels 710 | if exists(disable_self_attentions): 711 | disabled_sa = disable_self_attentions[level] 712 | else: 713 | disabled_sa = False 714 | 715 | if not exists(num_attention_blocks) or i < num_attention_blocks[level]: 716 | layers.append( 717 | AttentionBlock( 718 | ch, 719 | use_checkpoint=use_checkpoint, 720 | num_heads=num_heads_upsample, 721 | num_head_channels=dim_head, 722 | use_new_attention_order=use_new_attention_order, 723 | ) if not use_spatial_transformer else SpatialTransformer( 724 | ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim, 725 | disable_self_attn=disabled_sa, use_linear=use_linear_in_transformer, 726 | use_checkpoint=use_checkpoint 727 | ) 728 | ) 729 | if level and i == self.num_res_blocks[level]: 730 | out_ch = ch 731 | layers.append( 732 | ResBlock( 733 | ch, 734 | time_embed_dim, 735 | dropout, 736 | out_channels=out_ch, 737 | dims=dims, 738 | use_checkpoint=use_checkpoint, 739 | use_scale_shift_norm=use_scale_shift_norm, 740 | up=True, 741 | ) 742 | if resblock_updown 743 | else Upsample(ch, conv_resample, dims=dims, out_channels=out_ch) 744 | ) 745 | ds //= 2 746 | self.output_blocks.append(TimestepEmbedSequential(*layers)) 747 | self._feature_size += ch 748 | 749 | self.out = nn.Sequential( 750 | normalization(ch), 751 | nn.SiLU(), 752 | zero_module(conv_nd(dims, model_channels, out_channels, 3, padding=1)), 753 | ) 754 | if self.predict_codebook_ids: 755 | self.id_predictor = nn.Sequential( 756 | normalization(ch), 757 | conv_nd(dims, model_channels, n_embed, 1), 758 | #nn.LogSoftmax(dim=1) # change to cross_entropy and produce non-normalized logits 759 | ) 760 | 761 | if load_from_ckpt is not None: 762 | self.load_from_ckpt(load_from_ckpt) 763 | 764 | def load_from_ckpt(self, ckpt_path): 765 | input_ch = self.state_dict()["input_blocks.0.0.weight"].shape[1] 766 | assert input_ch >= 4 and input_ch // 4 * 4 == input_ch, "Input channels must be a multiple of 4 to 
load from SD ckpt" 767 | output_ch = self.state_dict()["out.2.weight"].shape[0] 768 | assert output_ch >= 4 and output_ch // 4 * 4 == output_ch, "Output channels must be a multiple of 4 to load from SD ckpt" 769 | sd = th.load(ckpt_path) 770 | sd_ = {} 771 | for k,v in sd["state_dict"].items(): 772 | if k.startswith("model.diffusion_model"): 773 | sd_[k.replace("model.diffusion_model.", "")] = v 774 | 775 | if input_ch > 4: 776 | # Scaling for input channels so that the gradients are not too large 777 | scale = input_ch // 4 778 | sd_["input_blocks.0.0.weight"] = sd_["input_blocks.0.0.weight"] / scale 779 | sd_["input_blocks.0.0.weight"] = sd_["input_blocks.0.0.weight"].repeat(1, scale, 1, 1) 780 | 781 | if output_ch > 4: 782 | # No scaling for output channels 783 | scale = output_ch // 4 784 | sd_["out.2.weight"] = sd_["out.2.weight"].repeat(scale, 1, 1, 1) 785 | sd_["out.2.bias"] = sd_["out.2.bias"].repeat(scale) 786 | 787 | missing, unexpected = self.load_state_dict(sd_, strict=False) 788 | 789 | if len(missing) > 0: 790 | print(f"Load model weights - missing keys: {len(missing)}") 791 | print(missing) 792 | if len(unexpected) > 0: 793 | print(f"Load model weights - unexpected keys: {len(unexpected)}") 794 | print(unexpected) 795 | 796 | 797 | def convert_to_fp16(self): 798 | """ 799 | Convert the torso of the model to float16. 800 | """ 801 | self.input_blocks.apply(convert_module_to_f16) 802 | self.middle_block.apply(convert_module_to_f16) 803 | self.output_blocks.apply(convert_module_to_f16) 804 | 805 | def convert_to_fp32(self): 806 | """ 807 | Convert the torso of the model to float32. 808 | """ 809 | self.input_blocks.apply(convert_module_to_f32) 810 | self.middle_block.apply(convert_module_to_f32) 811 | self.output_blocks.apply(convert_module_to_f32) 812 | 813 | def forward(self, x, t=None, context=None, context_ca=None, y=None,**kwargs): 814 | """ 815 | Apply the model to an input batch. 816 | :param x: an [N x C x ...] Tensor of inputs. 
817 | :param t: a 1-D batch of timesteps. 818 | :param context: conditioning plugged in via crossattn 819 | :param y: an [N] Tensor of labels, if class-conditional. 820 | :return: an [N x C x ...] Tensor of outputs. 821 | """ 822 | assert (y is not None) == ( 823 | self.num_classes is not None 824 | ), "must specify y if and only if the model is class-conditional" 825 | hs = [] 826 | t_emb = timestep_embedding(t, self.model_channels, repeat_only=False) 827 | emb = self.time_embed(t_emb) 828 | 829 | if self.num_classes is not None: 830 | assert y.shape[0] == x.shape[0] 831 | emb = emb + self.label_emb(y) 832 | 833 | h = x.type(self.dtype) 834 | if context is not None: 835 | h = th.cat([h, context], dim=1) 836 | for module in self.input_blocks: 837 | h = module(h, emb, context_ca) 838 | hs.append(h) 839 | h = self.middle_block(h, emb, context_ca) 840 | for module in self.output_blocks: 841 | h = th.cat([h, hs.pop()], dim=1) 842 | h = module(h, emb, context_ca) 843 | h = h.type(x.dtype) 844 | if self.predict_codebook_ids: 845 | return self.id_predictor(h) 846 | else: 847 | return self.out(h) 848 | 849 | def get_midblock_features(self, x, t=None, context=None, context_ca=None, y=None, **kwargs): 850 | """ 851 | Apply the model to an input batch and return the features from the middle block. 852 | :param x: an [N x C x ...] Tensor of inputs. 853 | :param t: a 1-D batch of timesteps. 
854 | :param context: conditioning plugged in via crossattn 855 | :param y: an [N] Tensor of labels, if class-conditional 856 | """ 857 | assert (y is not None) == ( 858 | self.num_classes is not None 859 | ), "must specify y if and only if the model is class-conditional" 860 | hs = [] 861 | t_emb = timestep_embedding(t, self.model_channels, repeat_only=False) 862 | emb = self.time_embed(t_emb) 863 | 864 | if self.num_classes is not None: 865 | assert y.shape[0] == x.shape[0] 866 | emb = emb + self.label_emb(y) 867 | 868 | h = x.type(self.dtype) 869 | if context is not None: 870 | h = th.cat([h, context], dim=1) 871 | for module in self.input_blocks: 872 | h = module(h, emb, context_ca) 873 | hs.append(h) 874 | h = self.middle_block(h, emb, context_ca) 875 | return h 876 | 877 | if __name__ == "__main__": 878 | unet = UNetModel( 879 | image_size=32, 880 | in_channels=8, 881 | model_channels=320, 882 | out_channels=4, 883 | num_res_blocks=2, 884 | attention_resolutions=(4,2,1), 885 | dropout=0.0, 886 | channel_mult=(1, 2, 4, 4), 887 | num_heads=8, 888 | use_spatial_transformer=True, 889 | context_dim=768, 890 | transformer_depth=1, 891 | legacy=False, 892 | load_from_ckpt="/export/scratch/ra97ram/checkpoints/sd/v1-5-pruned.ckpt" 893 | ) 894 | print(f"UNetModel has {sum(p.numel() for p in unet.parameters())} parameters") 895 | -------------------------------------------------------------------------------- /depthfm/unet/util.py: -------------------------------------------------------------------------------- 1 | # adopted from 2 | # https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py 3 | # and 4 | # https://github.com/lucidrains/denoising-diffusion-pytorch/blob/7706bdfc6f527f58d33f84b7b522e61e6e3164b3/denoising_diffusion_pytorch/denoising_diffusion_pytorch.py 5 | # and 6 | # https://github.com/openai/guided-diffusion/blob/0ba878e517b276c45d1195eb29f6f5f72659a05b/guided_diffusion/nn.py 7 | # 8 | # thanks! 
9 | 10 | 11 | import os 12 | import math 13 | import torch 14 | import torch.nn as nn 15 | import numpy as np 16 | from einops import repeat 17 | 18 | 19 | def extract_into_tensor(a, t, x_shape): 20 | b, *_ = t.shape 21 | out = a.gather(-1, t) 22 | return out.reshape(b, *((1,) * (len(x_shape) - 1))) 23 | 24 | 25 | def checkpoint(func, inputs, params, flag): 26 | """ 27 | Evaluate a function without caching intermediate activations, allowing for 28 | reduced memory at the expense of extra compute in the backward pass. 29 | :param func: the function to evaluate. 30 | :param inputs: the argument sequence to pass to `func`. 31 | :param params: a sequence of parameters `func` depends on but does not 32 | explicitly take as arguments. 33 | :param flag: if False, disable gradient checkpointing. 34 | """ 35 | if flag: 36 | args = tuple(inputs) + tuple(params) 37 | return CheckpointFunction.apply(func, len(inputs), *args) 38 | else: 39 | return func(*inputs) 40 | 41 | 42 | class CheckpointFunction(torch.autograd.Function): 43 | @staticmethod 44 | def forward(ctx, run_function, length, *args): 45 | ctx.run_function = run_function 46 | ctx.input_tensors = list(args[:length]) 47 | ctx.input_params = list(args[length:]) 48 | ctx.gpu_autocast_kwargs = {"enabled": torch.is_autocast_enabled(), 49 | "dtype": torch.get_autocast_gpu_dtype(), 50 | "cache_enabled": torch.is_autocast_cache_enabled()} 51 | with torch.no_grad(): 52 | output_tensors = ctx.run_function(*ctx.input_tensors) 53 | return output_tensors 54 | 55 | @staticmethod 56 | def backward(ctx, *output_grads): 57 | ctx.input_tensors = [x.detach().requires_grad_(True) for x in ctx.input_tensors] 58 | with torch.enable_grad(), \ 59 | torch.cuda.amp.autocast(**ctx.gpu_autocast_kwargs): 60 | # Fixes a bug where the first op in run_function modifies the 61 | # Tensor storage in place, which is not allowed for detach()'d 62 | # Tensors. 
63 | shallow_copies = [x.view_as(x) for x in ctx.input_tensors] 64 | output_tensors = ctx.run_function(*shallow_copies) 65 | input_grads = torch.autograd.grad( 66 | output_tensors, 67 | ctx.input_tensors + ctx.input_params, 68 | output_grads, 69 | allow_unused=True, 70 | ) 71 | del ctx.input_tensors 72 | del ctx.input_params 73 | del output_tensors 74 | return (None, None) + input_grads 75 | 76 | 77 | def timestep_embedding(timesteps, dim, max_period=10000, repeat_only=False): 78 | """ 79 | Create sinusoidal timestep embeddings. 80 | :param timesteps: a 1-D Tensor of N indices, one per batch element. 81 | These may be fractional. 82 | :param dim: the dimension of the output. 83 | :param max_period: controls the minimum frequency of the embeddings. 84 | :return: an [N x dim] Tensor of positional embeddings. 85 | """ 86 | if not repeat_only: 87 | half = dim // 2 88 | freqs = torch.exp( 89 | -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half 90 | ).to(device=timesteps.device) 91 | args = timesteps[:, None].float() * freqs[None] 92 | embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) 93 | if dim % 2: 94 | embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1) 95 | else: 96 | embedding = repeat(timesteps, 'b -> b d', d=dim) 97 | return embedding 98 | 99 | 100 | def zero_module(module): 101 | """ 102 | Zero out the parameters of a module and return it. 103 | """ 104 | for p in module.parameters(): 105 | p.detach().zero_() 106 | return module 107 | 108 | 109 | def scale_module(module, scale): 110 | """ 111 | Scale the parameters of a module and return it. 112 | """ 113 | for p in module.parameters(): 114 | p.detach().mul_(scale) 115 | return module 116 | 117 | 118 | def mean_flat(tensor): 119 | """ 120 | Take the mean over all non-batch dimensions. 
121 | """ 122 | return tensor.mean(dim=list(range(1, len(tensor.shape)))) 123 | 124 | 125 | def normalization(channels): 126 | """ 127 | Make a standard normalization layer. 128 | :param channels: number of input channels. 129 | :return: an nn.Module for normalization. 130 | """ 131 | return GroupNorm32(32, channels) 132 | 133 | 134 | # PyTorch 1.7 has SiLU, but we support PyTorch 1.5. 135 | class SiLU(nn.Module): 136 | def forward(self, x): 137 | return x * torch.sigmoid(x) 138 | 139 | 140 | class GroupNorm32(nn.GroupNorm): 141 | def forward(self, x): 142 | return super().forward(x.float()).type(x.dtype) 143 | 144 | 145 | def conv_nd(dims, *args, **kwargs): 146 | """ 147 | Create a 1D, 2D, or 3D convolution module. 148 | """ 149 | if dims == 1: 150 | return nn.Conv1d(*args, **kwargs) 151 | elif dims == 2: 152 | return nn.Conv2d(*args, **kwargs) 153 | elif dims == 3: 154 | return nn.Conv3d(*args, **kwargs) 155 | raise ValueError(f"unsupported dimensions: {dims}") 156 | 157 | 158 | def linear(*args, **kwargs): 159 | """ 160 | Create a linear module. 161 | """ 162 | return nn.Linear(*args, **kwargs) 163 | 164 | 165 | def avg_pool_nd(dims, *args, **kwargs): 166 | """ 167 | Create a 1D, 2D, or 3D average pooling module. 
168 | """ 169 | if dims == 1: 170 | return nn.AvgPool1d(*args, **kwargs) 171 | elif dims == 2: 172 | return nn.AvgPool2d(*args, **kwargs) 173 | elif dims == 3: 174 | return nn.AvgPool3d(*args, **kwargs) 175 | raise ValueError(f"unsupported dimensions: {dims}") 176 | -------------------------------------------------------------------------------- /environment.yml: -------------------------------------------------------------------------------- 1 | name: dfm 2 | channels: 3 | - pytorch 4 | - nvidia 5 | - conda-forge 6 | - defaults 7 | dependencies: 8 | - omegaconf>=2.3.0 9 | - accelerate>=0.27.2 10 | - diffusers=0.24.0 11 | - matplotlib>=3.8.1 12 | - python=3.11.5 13 | - pytorch=2.1.0 14 | - pytorch-cuda=12.1 15 | - pip=23.3 16 | - pip: 17 | - transformers==4.35.0 18 | - einops==0.7.0 19 | - torchdiffeq==0.2.3 20 | - xformers==0.0.22.post7 -------------------------------------------------------------------------------- /inference.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import einops 4 | import argparse 5 | import numpy as np 6 | from PIL import Image 7 | from PIL.Image import Resampling 8 | from depthfm import DepthFM 9 | import matplotlib.pyplot as plt 10 | 11 | def get_dtype_from_str(dtype_str): 12 | return {"fp32": torch.float32, "fp16": torch.float16, "bf16": torch.bfloat16}[dtype_str] 13 | 14 | def resize_max_res( 15 | img: Image.Image, max_edge_resolution: int, resample_method=Resampling.BILINEAR 16 | ) -> tuple[Image.Image, tuple[int, int]]: 17 | """ 18 | Resize image so its longer edge matches the maximum edge length, approximately keeping aspect ratio. 19 | 20 | Args: 21 | img (`Image.Image`): 22 | Image to be resized. 23 | max_edge_resolution (`int`): 24 | Maximum edge length (pixel). 25 | resample_method (`PIL.Image.Resampling`): 26 | Resampling method used to resize images. 27 | 28 | Returns: 29 | `tuple[Image.Image, tuple[int, int]]`: Resized image and the original `(width, height)`. 
30 | """ 31 | original_width, original_height = img.size 32 | downscale_factor = min( max_edge_resolution / original_width, max_edge_resolution / original_height) 33 | 34 | new_width = int(original_width * downscale_factor) 35 | new_height = int(original_height * downscale_factor) 36 | 37 | new_width = round(new_width / 64) * 64 38 | new_height = round(new_height / 64) * 64 39 | 40 | print(f"Resizing image from {original_width}x{original_height} to {new_width}x{new_height}") 41 | 42 | resized_img = img.resize((new_width, new_height), resample=resample_method) 43 | return resized_img, (original_width, original_height) 44 | 45 | def load_im(fp, processing_res=-1): 46 | assert os.path.exists(fp), f"File not found: {fp}" 47 | im = Image.open(fp).convert('RGB') 48 | if processing_res < 0: 49 | processing_res = max(im.size) 50 | im, orig_res = resize_max_res(im, processing_res) 51 | x = np.array(im) 52 | x = einops.rearrange(x, 'h w c -> c h w') 53 | x = x / 127.5 - 1 54 | x = torch.tensor(x, dtype=torch.float32)[None] 55 | return x, orig_res 56 | 57 | 58 | def main(args): 59 | print(f"{'Input':<10}: {args.img}") 60 | print(f"{'Steps':<10}: {args.num_steps}") 61 | print(f"{'Ensemble':<10}: {args.ensemble_size}") 62 | 63 | # Load the model 64 | model = DepthFM(args.ckpt) 65 | model.cuda(args.device).eval() 66 | 67 | # Load an image 68 | im, orig_res = load_im(args.img, args.processing_res) 69 | im = im.cuda(args.device) 70 | 71 | # Generate depth 72 | dtype = get_dtype_from_str(args.dtype) 73 | model.model.dtype = dtype 74 | with torch.autocast(device_type="cuda", dtype=dtype): 75 | depth = model.predict_depth(im, num_steps=args.num_steps, ensemble_size=args.ensemble_size) 76 | depth = depth.squeeze(0).squeeze(0).cpu().numpy() # (h, w) in [0, 1] 77 | 78 | # Convert depth to [0, 255] range 79 | if args.no_color: 80 | depth = (depth * 255).astype(np.uint8) 81 | else: 82 | depth = plt.get_cmap('magma')(depth, bytes=True)[..., :3] 83 | 84 | # Save the depth map 85 | depth_fp 
= args.img + '_depth.png' 86 | depth_img = Image.fromarray(depth) 87 | if depth_img.size != orig_res: 88 | depth_img = depth_img.resize(orig_res, Resampling.BILINEAR) 89 | depth_img.save(depth_fp) 90 | print(f"==> Saved depth map to {depth_fp}") 91 | 92 | 93 | if __name__ == "__main__": 94 | parser = argparse.ArgumentParser("DepthFM Inference") 95 | parser.add_argument("--img", type=str, default="assets/dog.png", 96 | help="Path to the input image") 97 | parser.add_argument("--ckpt", type=str, default="checkpoints/depthfm-v1.ckpt", 98 | help="Path to the model checkpoint") 99 | parser.add_argument("--num_steps", type=int, default=2, 100 | help="Number of steps for ODE solver") 101 | parser.add_argument("--ensemble_size", type=int, default=4, 102 | help="Number of ensemble members") 103 | parser.add_argument("--no_color", action="store_true", 104 | help="If set, the depth map will be grayscale") 105 | parser.add_argument("--device", type=int, default=0, 106 | help="GPU to use") 107 | parser.add_argument("--processing_res", type=int, default=-1, 108 | help="Longer edge of the image will be resized to this resolution. -1 keeps the original longer edge; both edges are still rounded to a multiple of 64.") 109 | parser.add_argument("--dtype", type=str, choices=["fp32", "bf16", "fp16"], default="fp16", 110 | help="Numeric precision for inference. Lower precision (fp16/bf16) is faster with a slight loss in accuracy.") 111 | args = parser.parse_args() 112 | 113 | main(args) 114 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | numpy==1.26.0 2 | einops 3 | omegaconf 4 | matplotlib 5 | accelerate>=0.22.0 6 | torch==2.1.0 7 | torchdiffeq>=0.2.3 8 | diffusers==0.26.3 9 | huggingface_hub==0.25.0 10 | xformers --------------------------------------------------------------------------------
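The size arithmetic in `resize_max_res` (inference.py) can be checked without loading an image. The sketch below is illustrative only: `target_size` is a hypothetical helper that is not part of the repository, and it assumes nothing beyond the scaling and multiple-of-64 rounding that appear in that function.

```python
def target_size(width, height, max_edge, multiple=64):
    """Hypothetical helper (not in the repo): mirrors the size math of resize_max_res."""
    # Scale so that the longer edge matches max_edge.
    factor = min(max_edge / width, max_edge / height)
    new_w = int(width * factor)
    new_h = int(height * factor)
    # Snap both edges to the nearest multiple of `multiple`.
    # Python's round() sends exact halves to the even neighbour.
    new_w = round(new_w / multiple) * multiple
    new_h = round(new_h / multiple) * multiple
    return new_w, new_h


print(target_size(2048, 1024, 512))  # (512, 256)
print(target_size(640, 480, 640))    # (640, 512), since 480 / 64 = 7.5 rounds to 8
```

The second call shows that snapping to multiples of 64 can shift the aspect ratio slightly. In the script, `main` resizes the predicted depth map back to the original `(width, height)` before saving, so the output file still matches the input resolution.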