├── assets ├── sacre_coeur_A.jpg ├── sacre_coeur_B.jpg ├── 0_d_00d1ae6aab6ccd59.jpg ├── 2_a_02a270519bdb90dd.jpg ├── sacre_coeur_A_compare.png └── sacre_coeur_B_compare.png ├── romatch ├── utils │ ├── __init__.py │ ├── kde.py │ ├── local_correlation.py │ ├── transforms.py │ └── utils.py └── models │ └── transformer │ ├── layers │ ├── __init__.py │ ├── layer_scale.py │ ├── drop_path.py │ ├── mlp.py │ ├── swiglu_ffn.py │ ├── dino_head.py │ ├── attention.py │ ├── patch_embed.py │ └── block.py │ ├── __init__.py │ └── dinov2.py ├── LICENSE.txt ├── sky_seg.py ├── vis_feats.py ├── README.md ├── relight.py ├── get_data.py └── warp_utils.py /assets/sacre_coeur_A.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Sharpiless/L2M/HEAD/assets/sacre_coeur_A.jpg -------------------------------------------------------------------------------- /assets/sacre_coeur_B.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Sharpiless/L2M/HEAD/assets/sacre_coeur_B.jpg -------------------------------------------------------------------------------- /assets/0_d_00d1ae6aab6ccd59.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Sharpiless/L2M/HEAD/assets/0_d_00d1ae6aab6ccd59.jpg -------------------------------------------------------------------------------- /assets/2_a_02a270519bdb90dd.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Sharpiless/L2M/HEAD/assets/2_a_02a270519bdb90dd.jpg -------------------------------------------------------------------------------- /assets/sacre_coeur_A_compare.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Sharpiless/L2M/HEAD/assets/sacre_coeur_A_compare.png -------------------------------------------------------------------------------- /assets/sacre_coeur_B_compare.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Sharpiless/L2M/HEAD/assets/sacre_coeur_B_compare.png -------------------------------------------------------------------------------- /romatch/utils/__init__.py: -------------------------------------------------------------------------------- 1 | from .utils import ( 2 | pose_auc, 3 | get_pose, 4 | compute_relative_pose, 5 | compute_pose_error, 6 | estimate_pose, 7 | estimate_pose_uncalibrated, 8 | rotate_intrinsic, 9 | get_tuple_transform_ops, 10 | get_depth_tuple_transform_ops, 11 | warp_kpts, 12 | numpy_to_pil, 13 | tensor_to_pil, 14 | recover_pose, 15 | signed_left_to_right_epipolar_distance, 16 | ) 17 | -------------------------------------------------------------------------------- /romatch/utils/kde.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | def kde(x, std = 0.1, half = True, down = None): 5 | # use a gaussian kernel to estimate density 6 | if half: 7 | x = x.half() # Do it in half precision TODO: remove hardcoding 8 | if down is not None: 9 | scores = (-torch.cdist(x,x[::down])**2/(2*std**2)).exp() 10 | else: 11 | scores = (-torch.cdist(x,x)**2/(2*std**2)).exp() 12 | density = scores.sum(dim=-1) 13 | return density -------------------------------------------------------------------------------- /romatch/models/transformer/layers/__init__.py: 
-------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | from .dino_head import DINOHead 8 | from .mlp import Mlp 9 | from .patch_embed import PatchEmbed 10 | from .swiglu_ffn import SwiGLUFFN, SwiGLUFFNFused 11 | from .block import NestedTensorBlock 12 | from .attention import MemEffAttention 13 | -------------------------------------------------------------------------------- /romatch/models/transformer/layers/layer_scale.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | # Modified from: https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/vision_transformer.py#L103-L110 8 | 9 | from typing import Union 10 | 11 | import torch 12 | from torch import Tensor 13 | from torch import nn 14 | 15 | 16 | class LayerScale(nn.Module): 17 | def __init__( 18 | self, 19 | dim: int, 20 | init_values: Union[float, Tensor] = 1e-5, 21 | inplace: bool = False, 22 | ) -> None: 23 | super().__init__() 24 | self.inplace = inplace 25 | self.gamma = nn.Parameter(init_values * torch.ones(dim)) 26 | 27 | def forward(self, x: Tensor) -> Tensor: 28 | return x.mul_(self.gamma) if self.inplace else x * self.gamma 29 | -------------------------------------------------------------------------------- /LICENSE.txt: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2024 Xuelun Shen 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /romatch/models/transformer/layers/drop_path.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
6 | 7 | # References: 8 | # https://github.com/facebookresearch/dino/blob/master/vision_transformer.py 9 | # https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/drop.py 10 | 11 | 12 | from torch import nn 13 | 14 | 15 | def drop_path(x, drop_prob: float = 0.0, training: bool = False): 16 | if drop_prob == 0.0 or not training: 17 | return x 18 | keep_prob = 1 - drop_prob 19 | shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets 20 | random_tensor = x.new_empty(shape).bernoulli_(keep_prob) 21 | if keep_prob > 0.0: 22 | random_tensor.div_(keep_prob) 23 | output = x * random_tensor 24 | return output 25 | 26 | 27 | class DropPath(nn.Module): 28 | """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).""" 29 | 30 | def __init__(self, drop_prob=None): 31 | super(DropPath, self).__init__() 32 | self.drop_prob = drop_prob 33 | 34 | def forward(self, x): 35 | return drop_path(x, self.drop_prob, self.training) 36 | -------------------------------------------------------------------------------- /romatch/models/transformer/layers/mlp.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | # References: 8 | # https://github.com/facebookresearch/dino/blob/master/vision_transformer.py 9 | # https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/mlp.py 10 | 11 | 12 | from typing import Callable, Optional 13 | 14 | from torch import Tensor, nn 15 | 16 | 17 | class Mlp(nn.Module): 18 | def __init__( 19 | self, 20 | in_features: int, 21 | hidden_features: Optional[int] = None, 22 | out_features: Optional[int] = None, 23 | act_layer: Callable[..., nn.Module] = nn.GELU, 24 | drop: float = 0.0, 25 | bias: bool = True, 26 | ) -> None: 27 | super().__init__() 28 | out_features = out_features or in_features 29 | hidden_features = hidden_features or in_features 30 | self.fc1 = nn.Linear(in_features, hidden_features, bias=bias) 31 | self.act = act_layer() 32 | self.fc2 = nn.Linear(hidden_features, out_features, bias=bias) 33 | self.drop = nn.Dropout(drop) 34 | 35 | def forward(self, x: Tensor) -> Tensor: 36 | x = self.fc1(x) 37 | x = self.act(x) 38 | x = self.drop(x) 39 | x = self.fc2(x) 40 | x = self.drop(x) 41 | return x 42 | -------------------------------------------------------------------------------- /romatch/utils/local_correlation.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | 4 | def local_correlation( 5 | feature0, 6 | feature1, 7 | local_radius, 8 | padding_mode="zeros", 9 | flow = None, 10 | sample_mode = "bilinear", 11 | ): 12 | r = local_radius 13 | K = (2*r+1)**2 14 | B, c, h, w = feature0.size() 15 | corr = torch.empty((B,K,h,w), device = feature0.device, dtype=feature0.dtype) 16 | if flow is None: 17 | # If flow is None, assume feature0 and feature1 are aligned 18 | coords = torch.meshgrid( 19 | ( 20 | torch.linspace(-1 + 1 / h, 1 - 1 / h, h, device=feature0.device), 21 | torch.linspace(-1 + 1 / w, 1 - 1 / w, w, device=feature0.device), 22 | ), 23 | indexing = 'ij' 24 | ) 25 | coords = torch.stack((coords[1], coords[0]), dim=-1)[ 26 | None 27 | ].expand(B, h, w, 2) 28 | else: 29 | coords = flow.permute(0,2,3,1) # If using flow, sample 
around flow target. 30 | local_window = torch.meshgrid( 31 | ( 32 | torch.linspace(-2*local_radius/h, 2*local_radius/h, 2*r+1, device=feature0.device), 33 | torch.linspace(-2*local_radius/w, 2*local_radius/w, 2*r+1, device=feature0.device), 34 | ), 35 | indexing = 'ij' 36 | ) 37 | local_window = torch.stack((local_window[1], local_window[0]), dim=-1)[ 38 | None 39 | ].expand(1, 2*r+1, 2*r+1, 2).reshape(1, (2*r+1)**2, 2) 40 | for _ in range(B): 41 | with torch.no_grad(): 42 | local_window_coords = (coords[_,:,:,None]+local_window[:,None,None]).reshape(1,h,w*(2*r+1)**2,2) 43 | window_feature = F.grid_sample( 44 | feature1[_:_+1], local_window_coords, padding_mode=padding_mode, align_corners=False, mode = sample_mode, # 45 | ) 46 | window_feature = window_feature.reshape(c,h,w,(2*r+1)**2) 47 | corr[_] = (feature0[_,...,None]/(c**.5)*window_feature).sum(dim=0).permute(2,0,1) 48 | return corr 49 | -------------------------------------------------------------------------------- /romatch/models/transformer/layers/swiglu_ffn.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | from typing import Callable, Optional 8 | 9 | from torch import Tensor, nn 10 | import torch.nn.functional as F 11 | 12 | 13 | class SwiGLUFFN(nn.Module): 14 | def __init__( 15 | self, 16 | in_features: int, 17 | hidden_features: Optional[int] = None, 18 | out_features: Optional[int] = None, 19 | act_layer: Callable[..., nn.Module] = None, 20 | drop: float = 0.0, 21 | bias: bool = True, 22 | ) -> None: 23 | super().__init__() 24 | out_features = out_features or in_features 25 | hidden_features = hidden_features or in_features 26 | self.w12 = nn.Linear(in_features, 2 * hidden_features, bias=bias) 27 | self.w3 = nn.Linear(hidden_features, out_features, bias=bias) 28 | 29 | def forward(self, x: Tensor) -> Tensor: 30 | x12 = self.w12(x) 31 | x1, x2 = x12.chunk(2, dim=-1) 32 | hidden = F.silu(x1) * x2 33 | return self.w3(hidden) 34 | 35 | 36 | try: 37 | from xformers.ops import SwiGLU 38 | 39 | XFORMERS_AVAILABLE = True 40 | except ImportError: 41 | SwiGLU = SwiGLUFFN 42 | XFORMERS_AVAILABLE = False 43 | 44 | 45 | class SwiGLUFFNFused(SwiGLU): 46 | def __init__( 47 | self, 48 | in_features: int, 49 | hidden_features: Optional[int] = None, 50 | out_features: Optional[int] = None, 51 | act_layer: Callable[..., nn.Module] = None, 52 | drop: float = 0.0, 53 | bias: bool = True, 54 | ) -> None: 55 | out_features = out_features or in_features 56 | hidden_features = hidden_features or in_features 57 | hidden_features = (int(hidden_features * 2 / 3) + 7) // 8 * 8 58 | super().__init__( 59 | in_features=in_features, 60 | hidden_features=hidden_features, 61 | out_features=out_features, 62 | bias=bias, 63 | ) 64 | -------------------------------------------------------------------------------- /romatch/models/transformer/layers/dino_head.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
6 | 7 | import torch 8 | import torch.nn as nn 9 | from torch.nn.init import trunc_normal_ 10 | from torch.nn.utils import weight_norm 11 | 12 | 13 | class DINOHead(nn.Module): 14 | def __init__( 15 | self, 16 | in_dim, 17 | out_dim, 18 | use_bn=False, 19 | nlayers=3, 20 | hidden_dim=2048, 21 | bottleneck_dim=256, 22 | mlp_bias=True, 23 | ): 24 | super().__init__() 25 | nlayers = max(nlayers, 1) 26 | self.mlp = _build_mlp(nlayers, in_dim, bottleneck_dim, hidden_dim=hidden_dim, use_bn=use_bn, bias=mlp_bias) 27 | self.apply(self._init_weights) 28 | self.last_layer = weight_norm(nn.Linear(bottleneck_dim, out_dim, bias=False)) 29 | self.last_layer.weight_g.data.fill_(1) 30 | 31 | def _init_weights(self, m): 32 | if isinstance(m, nn.Linear): 33 | trunc_normal_(m.weight, std=0.02) 34 | if isinstance(m, nn.Linear) and m.bias is not None: 35 | nn.init.constant_(m.bias, 0) 36 | 37 | def forward(self, x): 38 | x = self.mlp(x) 39 | eps = 1e-6 if x.dtype == torch.float16 else 1e-12 40 | x = nn.functional.normalize(x, dim=-1, p=2, eps=eps) 41 | x = self.last_layer(x) 42 | return x 43 | 44 | 45 | def _build_mlp(nlayers, in_dim, bottleneck_dim, hidden_dim=None, use_bn=False, bias=True): 46 | if nlayers == 1: 47 | return nn.Linear(in_dim, bottleneck_dim, bias=bias) 48 | else: 49 | layers = [nn.Linear(in_dim, hidden_dim, bias=bias)] 50 | if use_bn: 51 | layers.append(nn.BatchNorm1d(hidden_dim)) 52 | layers.append(nn.GELU()) 53 | for _ in range(nlayers - 2): 54 | layers.append(nn.Linear(hidden_dim, hidden_dim, bias=bias)) 55 | if use_bn: 56 | layers.append(nn.BatchNorm1d(hidden_dim)) 57 | layers.append(nn.GELU()) 58 | layers.append(nn.Linear(hidden_dim, bottleneck_dim, bias=bias)) 59 | return nn.Sequential(*layers) 60 | -------------------------------------------------------------------------------- /romatch/models/transformer/__init__.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | from romatch.utils.utils import get_grid, get_autocast_params 6 | from .layers.block import Block 7 | from .layers.attention import MemEffAttention 8 | from .dinov2 import vit_large, vit_base 9 | 10 | class TransformerDecoder(nn.Module): 11 | def __init__(self, blocks, hidden_dim, out_dim, is_classifier = False, *args, 12 | amp = False, pos_enc = True, learned_embeddings = False, embedding_dim = None, amp_dtype = torch.float16, **kwargs) -> None: 13 | super().__init__(*args, **kwargs) 14 | self.blocks = blocks 15 | self.to_out = nn.Linear(hidden_dim, out_dim) 16 | self.hidden_dim = hidden_dim 17 | self.out_dim = out_dim 18 | self._scales = [16, 18] 19 | self.is_classifier = is_classifier 20 | self.amp = amp 21 | self.amp_dtype = amp_dtype 22 | self.pos_enc = pos_enc 23 | self.learned_embeddings = learned_embeddings 24 | if self.learned_embeddings: 25 | self.learned_pos_embeddings = nn.Parameter(nn.init.kaiming_normal_(torch.empty((1, hidden_dim, embedding_dim, embedding_dim)))) 26 | 27 | def scales(self): 28 | return self._scales.copy() 29 | 30 | def forward(self, gp_posterior, features, old_stuff, new_scale): 31 | autocast_device, autocast_enabled, autocast_dtype = get_autocast_params(gp_posterior.device, enabled=self.amp, dtype=self.amp_dtype) 32 | with torch.autocast(autocast_device, enabled=autocast_enabled, dtype = autocast_dtype): 33 | B,C,H,W = gp_posterior.shape 34 | x = torch.cat((gp_posterior, features), dim = 1) 35 | B,C,H,W = x.shape 36 | grid = get_grid(B, H, W, x.device).reshape(B,H*W,2) 37 | 
if self.learned_embeddings: 38 | pos_enc = F.interpolate(self.learned_pos_embeddings, size = (H,W), mode = 'bilinear', align_corners = False).permute(0,2,3,1).reshape(1,H*W,C) 39 | else: 40 | pos_enc = 0 41 | tokens = x.reshape(B,C,H*W).permute(0,2,1) + pos_enc 42 | z = self.blocks(tokens) 43 | out = self.to_out(z) 44 | out = out.permute(0,2,1).reshape(B, self.out_dim, H, W) 45 | warp, certainty = out[:, :-1], out[:, -1:] 46 | return warp, certainty, None 47 | 48 | 49 | -------------------------------------------------------------------------------- /romatch/models/transformer/layers/attention.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | # References: 8 | # https://github.com/facebookresearch/dino/blob/master/vision_transformer.py 9 | # https://github.com/rwightman/pytorch-image-models/tree/master/timm/models/vision_transformer.py 10 | 11 | import logging 12 | 13 | from torch import Tensor 14 | from torch import nn 15 | 16 | 17 | logger = logging.getLogger("dinov2") 18 | 19 | 20 | try: 21 | from xformers.ops import memory_efficient_attention, unbind, fmha 22 | 23 | XFORMERS_AVAILABLE = True 24 | except ImportError: 25 | # logger.warning("xFormers not available") 26 | XFORMERS_AVAILABLE = False 27 | 28 | 29 | class Attention(nn.Module): 30 | def __init__( 31 | self, 32 | dim: int, 33 | num_heads: int = 8, 34 | qkv_bias: bool = False, 35 | proj_bias: bool = True, 36 | attn_drop: float = 0.0, 37 | proj_drop: float = 0.0, 38 | ) -> None: 39 | super().__init__() 40 | self.num_heads = num_heads 41 | head_dim = dim // num_heads 42 | self.scale = head_dim**-0.5 43 | 44 | self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) 45 | self.attn_drop = nn.Dropout(attn_drop) 46 | self.proj = nn.Linear(dim, dim, bias=proj_bias) 47 | self.proj_drop = nn.Dropout(proj_drop) 48 | 49 | def forward(self, x: Tensor) -> Tensor: 50 | B, N, C = x.shape 51 | qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) 52 | 53 | q, k, v = qkv[0] * self.scale, qkv[1], qkv[2] 54 | attn = q @ k.transpose(-2, -1) 55 | 56 | attn = attn.softmax(dim=-1) 57 | attn = self.attn_drop(attn) 58 | 59 | x = (attn @ v).transpose(1, 2).reshape(B, N, C) 60 | x = self.proj(x) 61 | x = self.proj_drop(x) 62 | return x 63 | 64 | 65 | class MemEffAttention(Attention): 66 | def forward(self, x: Tensor, attn_bias=None) -> Tensor: 67 | if not XFORMERS_AVAILABLE: 68 | assert attn_bias is None, "xFormers is required for nested tensors usage" 69 | return super().forward(x) 70 | 71 | B, N, C = x.shape 72 | qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads) 73 | 74 | q, k, v = unbind(qkv, 2) 75 | 76 | x = memory_efficient_attention(q, k, v, attn_bias=attn_bias) 77 | x = x.reshape([B, N, C]) 78 | 79 | x = self.proj(x) 80 | x = self.proj_drop(x) 81 | return x 82 | -------------------------------------------------------------------------------- /romatch/models/transformer/layers/patch_embed.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
6 | 7 | # References: 8 | # https://github.com/facebookresearch/dino/blob/master/vision_transformer.py 9 | # https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/patch_embed.py 10 | 11 | from typing import Callable, Optional, Tuple, Union 12 | 13 | from torch import Tensor 14 | import torch.nn as nn 15 | 16 | 17 | def make_2tuple(x): 18 | if isinstance(x, tuple): 19 | assert len(x) == 2 20 | return x 21 | 22 | assert isinstance(x, int) 23 | return (x, x) 24 | 25 | 26 | class PatchEmbed(nn.Module): 27 | """ 28 | 2D image to patch embedding: (B,C,H,W) -> (B,N,D) 29 | 30 | Args: 31 | img_size: Image size. 32 | patch_size: Patch token size. 33 | in_chans: Number of input image channels. 34 | embed_dim: Number of linear projection output channels. 35 | norm_layer: Normalization layer. 36 | """ 37 | 38 | def __init__( 39 | self, 40 | img_size: Union[int, Tuple[int, int]] = 224, 41 | patch_size: Union[int, Tuple[int, int]] = 16, 42 | in_chans: int = 3, 43 | embed_dim: int = 768, 44 | norm_layer: Optional[Callable] = None, 45 | flatten_embedding: bool = True, 46 | ) -> None: 47 | super().__init__() 48 | 49 | image_HW = make_2tuple(img_size) 50 | patch_HW = make_2tuple(patch_size) 51 | patch_grid_size = ( 52 | image_HW[0] // patch_HW[0], 53 | image_HW[1] // patch_HW[1], 54 | ) 55 | 56 | self.img_size = image_HW 57 | self.patch_size = patch_HW 58 | self.patches_resolution = patch_grid_size 59 | self.num_patches = patch_grid_size[0] * patch_grid_size[1] 60 | 61 | self.in_chans = in_chans 62 | self.embed_dim = embed_dim 63 | 64 | self.flatten_embedding = flatten_embedding 65 | 66 | self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_HW, stride=patch_HW) 67 | self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity() 68 | 69 | def forward(self, x: Tensor) -> Tensor: 70 | _, _, H, W = x.shape 71 | patch_H, patch_W = self.patch_size 72 | 73 | assert H % patch_H == 0, f"Input image height {H} is not a multiple of patch height {patch_H}" 74 | assert W % patch_W == 0, f"Input image width {W} is not a multiple of patch width: {patch_W}" 75 | 76 | x = self.proj(x) # B C H W 77 | H, W = x.size(2), x.size(3) 78 | x = x.flatten(2).transpose(1, 2) # B HW C 79 | x = self.norm(x) 80 | if not self.flatten_embedding: 81 | x = x.reshape(-1, H, W, self.embed_dim) # B H W C 82 | return x 83 | 84 | def flops(self) -> float: 85 | Ho, Wo = self.patches_resolution 86 | flops = Ho * Wo * self.embed_dim * self.in_chans * (self.patch_size[0] * self.patch_size[1]) 87 | if self.norm is not None: 88 | flops += Ho * Wo * self.embed_dim 89 | return flops 90 | -------------------------------------------------------------------------------- /sky_seg.py: -------------------------------------------------------------------------------- 1 | import os 2 | import copy 3 | import time 4 | import argparse 5 | import cv2 6 | import numpy as np 7 | import onnxruntime 8 | from tqdm import tqdm 9 | import imutils 10 | 11 | def run_inference(onnx_session, input_size, image): 12 | # Pre process:Resize, BGR->RGB, Transpose, PyTorch standardization, float32 cast 13 | temp_image = copy.deepcopy(image) 14 | resize_image = cv2.resize(temp_image, dsize=(input_size[0], input_size[1])) 15 | x = cv2.cvtColor(resize_image, cv2.COLOR_BGR2RGB) 16 | x = np.array(x, dtype=np.float32) 17 | mean = [0.485, 0.456, 0.406] 18 | std = [0.229, 0.224, 0.225] 19 | x = (x / 255 - mean) / std 20 | x = x.transpose(2, 0, 1) 21 | x = x.reshape(-1, 3, input_size[0], input_size[1]).astype('float32') 22 | 23 | # Inference 24 | input_name = 
onnx_session.get_inputs()[0].name 25 | output_name = onnx_session.get_outputs()[0].name 26 | onnx_result = onnx_session.run([output_name], {input_name: x}) 27 | 28 | # Post process 29 | onnx_result = np.array(onnx_result).squeeze() 30 | min_value = np.min(onnx_result) 31 | max_value = np.max(onnx_result) 32 | onnx_result = (onnx_result - min_value) / (max_value - min_value) 33 | onnx_result *= 255 34 | onnx_result = onnx_result.astype('uint8') 35 | 36 | return onnx_result 37 | 38 | from argparse import ArgumentParser 39 | 40 | if __name__ == "__main__": 41 | from argparse import ArgumentParser 42 | parser = ArgumentParser() 43 | parser.add_argument("--base", type=str) 44 | parser.add_argument("--debug", action='store_true', default=False) 45 | args, _ = parser.parse_known_args() 46 | 47 | onnx_session = onnxruntime.InferenceSession("checkpoints/skyseg.onnx") 48 | base = args.base 49 | img_base = os.path.join(base, "image1") 50 | out_base = os.path.join(base, "debug2") 51 | depth_base = os.path.join(base, "depth1") 52 | 53 | if not os.path.exists(out_base): 54 | os.mkdir(out_base) 55 | 56 | count = 0 57 | 58 | for img in tqdm(os.listdir(img_base)): 59 | 60 | # image = cv2.imread("../zeb/eth3do/playground-DSC_0585.png") 61 | image = cv2.imread(os.path.join(img_base, img)) 62 | H, W = image.shape[:2] 63 | 64 | while(image.shape[0] >= 640 and image.shape[1] >= 640): 65 | image = cv2.pyrDown(image) 66 | result_map = run_inference(onnx_session,[320,320],image) 67 | result_map = imutils.resize(result_map, height=H, width=W) 68 | 69 | # cv2.imwrite(os.path.join(out_base, img), result_map) 70 | depth = np.load(os.path.join(depth_base, img[:-4]+".npy")) 71 | 72 | disp_ = 1 / (depth + 1) 73 | disp2_ = 1 / (depth + 1) 74 | 75 | depth[result_map > 255 * 0.5] = 0.0 76 | np.save(os.path.join(depth_base, img[:-4]+".npy"), depth) 77 | 78 | if args.debug and count < 100: 79 | 80 | disp2_[result_map > 255 * 0.5] = 0.0 81 | 82 | disp_ = ((disp_ - disp_.min()) / (disp_.max() - disp_.min()) * 255).astype(np.uint8) 83 | disp2_ = ((disp2_ - disp2_.min()) / (disp2_.max() - disp2_.min()) * 255).astype(np.uint8) 84 | 85 | cv2.imwrite(os.path.join(out_base, img), np.hstack([disp_, disp2_, result_map])) 86 | 87 | # import IPython 88 | # IPython.embed() 89 | # exit() 90 | count += 1 -------------------------------------------------------------------------------- /romatch/utils/transforms.py: -------------------------------------------------------------------------------- 1 | from typing import Dict 2 | import numpy as np 3 | import torch 4 | import kornia.augmentation as K 5 | from kornia.geometry.transform import warp_perspective 6 | 7 | # Adapted from Kornia 8 | class GeometricSequential: 9 | def __init__(self, *transforms, align_corners=True) -> None: 10 | self.transforms = transforms 11 | self.align_corners = align_corners 12 | 13 | def __call__(self, x, mode="bilinear"): 14 | b, c, h, w = x.shape 15 | M = torch.eye(3, device=x.device)[None].expand(b, 3, 3) 16 | for t in self.transforms: 17 | if np.random.rand() < t.p: 18 | M = M.matmul( 19 | t.compute_transformation(x, t.generate_parameters((b, c, h, w)), None) 20 | ) 21 | return ( 22 | warp_perspective( 23 | x, M, dsize=(h, w), mode=mode, align_corners=self.align_corners 24 | ), 25 | M, 26 | ) 27 | 28 | def apply_transform(self, x, M, mode="bilinear"): 29 | b, c, h, w = x.shape 30 | return warp_perspective( 31 | x, M, dsize=(h, w), align_corners=self.align_corners, mode=mode 32 | ) 33 | 34 | 35 | class RandomPerspective(K.RandomPerspective): 36 | def 
generate_parameters(self, batch_shape: torch.Size) -> Dict[str, torch.Tensor]: 37 | distortion_scale = torch.as_tensor( 38 | self.distortion_scale, device=self._device, dtype=self._dtype 39 | ) 40 | return self.random_perspective_generator( 41 | batch_shape[0], 42 | batch_shape[-2], 43 | batch_shape[-1], 44 | distortion_scale, 45 | self.same_on_batch, 46 | self.device, 47 | self.dtype, 48 | ) 49 | 50 | def random_perspective_generator( 51 | self, 52 | batch_size: int, 53 | height: int, 54 | width: int, 55 | distortion_scale: torch.Tensor, 56 | same_on_batch: bool = False, 57 | device: torch.device = torch.device("cpu"), 58 | dtype: torch.dtype = torch.float32, 59 | ) -> Dict[str, torch.Tensor]: 60 | r"""Get parameters for ``perspective`` for a random perspective transform. 61 | 62 | Args: 63 | batch_size (int): the tensor batch size. 64 | height (int) : height of the image. 65 | width (int): width of the image. 66 | distortion_scale (torch.Tensor): it controls the degree of distortion and ranges from 0 to 1. 67 | same_on_batch (bool): apply the same transformation across the batch. Default: False. 68 | device (torch.device): the device on which the random numbers will be generated. Default: cpu. 69 | dtype (torch.dtype): the data type of the generated random numbers. Default: float32. 70 | 71 | Returns: 72 | params Dict[str, torch.Tensor]: parameters to be passed for transformation. 73 | - start_points (torch.Tensor): element-wise perspective source areas with a shape of (B, 4, 2). 74 | - end_points (torch.Tensor): element-wise perspective target areas with a shape of (B, 4, 2). 75 | 76 | Note: 77 | The generated random numbers are not reproducible across different devices and dtypes. 78 | """ 79 | if not (distortion_scale.dim() == 0 and 0 <= distortion_scale <= 1): 80 | raise AssertionError( 81 | f"'distortion_scale' must be a scalar within [0, 1]. Got {distortion_scale}." 82 | ) 83 | if not ( 84 | type(height) is int and height > 0 and type(width) is int and width > 0 85 | ): 86 | raise AssertionError( 87 | f"'height' and 'width' must be integers. Got {height}, {width}." 88 | ) 89 | 90 | start_points: torch.Tensor = torch.tensor( 91 | [[[0.0, 0], [width - 1, 0], [width - 1, height - 1], [0, height - 1]]], 92 | device=distortion_scale.device, 93 | dtype=distortion_scale.dtype, 94 | ).expand(batch_size, -1, -1) 95 | 96 | # generate random offset not larger than half of the image 97 | fx = distortion_scale * width / 2 98 | fy = distortion_scale * height / 2 99 | 100 | factor = torch.stack([fx, fy], dim=0).view(-1, 1, 2) 101 | offset = (torch.rand_like(start_points) - 0.5) * 2 102 | end_points = start_points + factor * offset 103 | 104 | return dict(start_points=start_points, end_points=end_points) 105 | 106 | 107 | 108 | class RandomErasing: 109 | def __init__(self, p = 0., scale = 0.) 
-> None: 110 | self.p = p 111 | self.scale = scale 112 | self.random_eraser = K.RandomErasing(scale = (0.02, scale), p = p) 113 | def __call__(self, image, depth): 114 | if self.p > 0: 115 | image = self.random_eraser(image) 116 | depth = self.random_eraser(depth, params=self.random_eraser._params) 117 | return image, depth 118 | -------------------------------------------------------------------------------- /vis_feats.py: -------------------------------------------------------------------------------- 1 | 2 | import os 3 | import torch 4 | import torchvision.transforms as T 5 | import matplotlib.pyplot as plt 6 | import numpy as np 7 | from PIL import Image 8 | from sklearn.decomposition import PCA 9 | import argparse 10 | from romatch.models.transformer import vit_base 11 | 12 | device = "cuda" if torch.cuda.is_available() else "cpu" 13 | 14 | # -------------------- Visualization helpers -------------------- 15 | def vis_feat_map_batch(features_list, patch_h, patch_w, resize_hw=(560, 560)): 16 | all_feats = np.concatenate([f.reshape(patch_h * patch_w, -1) for f in features_list], axis=0) 17 | pca = PCA(n_components=3) 18 | pca.fit(all_feats) 19 | 20 | images = [] 21 | for features in features_list: 22 | f = features.reshape(patch_h * patch_w, -1) 23 | pca_feats = pca.transform(f) 24 | pca_feats = (pca_feats - pca_feats.mean(0)) / (pca_feats.std(0) + 1e-5) 25 | pca_feats = np.clip(pca_feats * 0.5 + 0.5, 0, 1) 26 | img = pca_feats.reshape(patch_h, patch_w, 3) 27 | img = (img * 255).astype(np.uint8) 28 | img = Image.fromarray(img).resize(resize_hw, Image.BICUBIC) 29 | images.append(img) 30 | return images 31 | 32 | def save_combined_visualization( 33 | feats_dino, feats_fit3d, feats_L2M, 34 | patch_h, patch_w, base_name, save_dir, original_images 35 | ): 36 | os.makedirs(save_dir, exist_ok=True) 37 | 38 | imgs_dino = vis_feat_map_batch(feats_dino, patch_h, patch_w) 39 | imgs_fit3d = vis_feat_map_batch(feats_fit3d, patch_h, patch_w) 40 | imgs_L2M = vis_feat_map_batch(feats_L2M, patch_h, patch_w) 41 | 42 | # Assemble the comparison grid: one input image per row, 2 rows x 4 columns in total 43 | fig, axs = plt.subplots(2, 4, figsize=(16, 8)) 44 | titles = ["Original", "DINOv2", "Fit3D", "L2M (Ours)"] 45 | for i in range(2): # row 46 | row_imgs = [original_images[i], imgs_dino[i], imgs_fit3d[i], imgs_L2M[i]] 47 | for j in range(4): 48 | axs[i, j].imshow(row_imgs[j]) 49 | axs[i, j].set_title(titles[j], fontsize=12) 50 | axs[i, j].axis("off") 51 | plt.tight_layout() 52 | plt.savefig(os.path.join(save_dir, f"{base_name}_compare.png")) 53 | plt.close() 54 | 55 | # -------------------- Feature extraction -------------------- 56 | def extract_features(model, image_tensor): 57 | with torch.no_grad(): 58 | return model.forward_features(image_tensor)["x_norm_patchtokens"].squeeze(0).cpu().numpy() 59 | 60 | # -------------------- Main script -------------------- 61 | def main(args): 62 | os.makedirs(args.save_dir, exist_ok=True) 63 | 64 | patch_h, patch_w = 37, 37 65 | img_size = patch_h * 14 66 | 67 | transform = T.Compose([ 68 | T.Resize((img_size, img_size)), 69 | T.CenterCrop((img_size, img_size)), 70 | T.ToTensor(), 71 | T.Normalize(mean=(0.485, 0.456, 0.406), 72 | std=(0.229, 0.224, 0.225)), 73 | ]) 74 | 75 | # Initialize the models 76 | vit_kwargs = dict( 77 | img_size=img_size, 78 | patch_size=14, 79 | init_values=1.0, 80 | ffn_layer="mlp", 81 | block_chunks=0 82 | ) 83 | 84 | def load_model(ckpt_path): 85 | model = vit_base(**vit_kwargs).eval().to(device) 86 | raw = torch.load(ckpt_path, map_location="cpu") 87 | if "model" in raw: 88 | raw = raw["model"] 89 | ckpt = {k.replace("model.", ""): v for k, v in
raw.items()} 90 | model.load_state_dict(ckpt, strict=False) 91 | return model 92 | 93 | dino = load_model(args.ckpt_dino) 94 | fit3d = load_model(args.ckpt_fit3d) 95 | L2M = load_model(args.ckpt_L2M) 96 | 97 | feats_dino, feats_fit3d, feats_L2M = [], [], [] 98 | original_images = [] 99 | 100 | for img_path in args.img_paths: 101 | img = Image.open(img_path).convert("RGB") 102 | x = transform(img).unsqueeze(0).to(device) 103 | 104 | feats_dino.append(extract_features(dino, x)) 105 | feats_fit3d.append(extract_features(fit3d, x)) 106 | feats_L2M.append(extract_features(L2M, x)) 107 | original_images.append(img) 108 | 109 | base_name = "multi" if len(args.img_paths) > 1 else os.path.splitext(os.path.basename(args.img_paths[0]))[0] 110 | save_combined_visualization( 111 | feats_dino, feats_fit3d, feats_L2M, 112 | patch_h, patch_w, base_name, args.save_dir, original_images 113 | ) 114 | 115 | print(f"Saved 2-row comparison to {os.path.join(args.save_dir, f'{base_name}_compare.png')}") 116 | 117 | if __name__ == "__main__": 118 | parser = argparse.ArgumentParser() 119 | parser.add_argument( 120 | "--img_paths", 121 | nargs="+", 122 | default=[ 123 | "assets/sacre_coeur_A.jpg", 124 | "assets/sacre_coeur_B.jpg" 125 | ], 126 | help="List of image paths" 127 | ) 128 | parser.add_argument( 129 | "--ckpt_fit3d", 130 | default="ckpts/fit3d.pth", 131 | help="Original Fit3D checkpoint" 132 | ) 133 | parser.add_argument( 134 | "--ckpt_L2M", 135 | default="ckpts/l2m_vit_base.pth", 136 | help="L2M Fit3D checkpoint" 137 | ) 138 | parser.add_argument( 139 | "--ckpt_dino", 140 | default="ckpts/dinov2.pth", 141 | help="dino checkpoint" 142 | ) 143 | parser.add_argument( 144 | "--save_dir", 145 | default="outputs_vis_feat", 146 | help="Directory to save visualizations" 147 | ) 148 | args = parser.parse_args() 149 | main(args) 150 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Learning Dense Feature Matching via Lifting Single 2D Image to 3D Space 2 | 3 | ![L2M Logo](https://img.shields.io/badge/L2M-Official%20Implementation-blue) 4 | 5 | Welcome to the **L2M** repository! This is the official implementation of our ICCV'25 paper titled "Learning Dense Feature Matching via Lifting Single 2D Image to 3D Space". 6 | 7 | *Accepted to ICCV 2025 Conference* 8 | 9 | --- 10 | 11 | > 🚨 **Important Notice:** 12 | > This repository is the **official implementation** of the ICCV 2025 paper authored by Sharpiless. 13 | > 14 | > Please be aware that the repository at [https://github.com/chelseaaxy/L2M](https://github.com/chelseaaxy/L2M) is **NOT an official implementation** and is not authorized by the original authors. 15 | > 16 | > Always refer to this repository for the authentic and up-to-date code. 17 | 18 | 19 | ## 🧠 Overview 20 | 21 | **Lift to Match (L2M)** is a two-stage framework for **dense feature matching** that lifts 2D images into 3D space to enhance feature generalization and robustness. Unlike traditional methods that depend on multi-view image pairs, L2M is trained on large-scale, diverse single-view image collections. 22 | 23 | - **Stage 1:** Learn a **3D-aware ViT-based encoder** using multi-view image synthesis and 3D Gaussian feature representation. 24 | - **Stage 2:** Learn a **feature decoder** through novel-view rendering and synthetic data, enabling robust matching across diverse scenarios. 25 | 26 | > 🚧 Code is still under construction. 
27 | 28 | --- 29 | 30 | ## 🧪 Feature Visualization 31 | 32 | We compare the 3D-aware ViT encoder from L2M (Stage 1) with other recent methods: 33 | 34 | - **DINOv2**: Learning Robust Visual Features without Supervision 35 | - **FiT3D**: Improving 2D Feature Representations by 3D-Aware Fine-Tuning 36 | - **Ours: L2M Encoder** 37 | 38 | You can download them from the [Releases](https://github.com/Sharpiless/L2M/releases/tag/checkpoints) page. 39 | 40 |
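If you only need dense patch features from one of these checkpoints, the snippet below is a minimal sketch distilled from `vis_feats.py` (it assumes the default checkpoint paths used there, e.g. `ckpts/l2m_vit_base.pth`, and an input resolution that is a multiple of the 14-pixel patch size):

```
import torch
import torchvision.transforms as T
from PIL import Image
from romatch.models.transformer import vit_base

device = "cuda" if torch.cuda.is_available() else "cpu"
img_size = 37 * 14  # 37 x 37 patches at patch size 14

# Build the ViT-B backbone with the same settings as vis_feats.py
model = vit_base(img_size=img_size, patch_size=14, init_values=1.0,
                 ffn_layer="mlp", block_chunks=0).eval().to(device)

# Load a checkpoint (DINOv2, FiT3D, or L2M); strip an optional "model." prefix
raw = torch.load("ckpts/l2m_vit_base.pth", map_location="cpu")
raw = raw.get("model", raw)
model.load_state_dict({k.replace("model.", ""): v for k, v in raw.items()}, strict=False)

transform = T.Compose([
    T.Resize((img_size, img_size)),
    T.ToTensor(),
    T.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
])
x = transform(Image.open("assets/sacre_coeur_A.jpg").convert("RGB")).unsqueeze(0).to(device)

with torch.no_grad():
    feats = model.forward_features(x)["x_norm_patchtokens"]  # (1, 37*37, C)
print(feats.shape)
```

`vis_feats.py` in this repository runs the same extraction for all three checkpoints and projects the patch features to RGB with a shared PCA before assembling the comparison figure.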
41 | <!-- Feature comparison figures: assets/sacre_coeur_A_compare.png and assets/sacre_coeur_B_compare.png --> 42 | 43 | 44 | 45 | 46 | 47 | 48 |
49 | 50 | --- 51 | 52 | To get the results, make sure your checkpoints and image files are in the correct paths, then run: 53 | ``` 54 | python vis_feats.py \ 55 | --img_paths assets/sacre_coeur_A.jpg assets/sacre_coeur_B.jpg \ 56 | --ckpt_dino ckpts/dinov2.pth \ 57 | --ckpt_fit3d ckpts/fit3d.pth \ 58 | --ckpt_L2M ckpts/l2m_vit_base.pth \ 59 | --save_dir outputs_vis_feat 60 | ``` 61 | 62 | ## 🏗️ Data Generation 63 | 64 | To enable training from single-view images, we simulate diverse multi-view observations and their corresponding dense correspondence labels in a fully automatic manner. 65 | 66 | #### Stage 2.1: Novel View Synthesis 67 | We lift a single-view image to a coarse 3D structure and then render novel views from different camera poses. These synthesized multi-view images are used to supervise the feature encoder with dense matching consistency. 68 | 69 | Run the following to generate novel-view images with ground-truth dense correspondences: 70 | ``` 71 | python get_data.py \ 72 | --output_path [PATH-to-SAVE] \ 73 | --data_path [PATH-to-IMAGES] \ 74 | --disp_path [PATH-to-MONO-DEPTH] 75 | ``` 76 | 77 | This script provides an example of novel-view generation with dense matching ground truth. 78 | 79 | The `disp_path` directory should contain grayscale disparity maps predicted by Depth Anything V2 or another monocular depth estimator. 80 | 81 | The sketch below illustrates the geometric core of this step, followed by examples of synthesized novel views with ground-truth dense correspondences generated in Stage 2.1. 82 | 83 |
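A minimal sketch of that warping step, written as a stand-alone NumPy function (illustrative only, not the actual implementation in `get_data.py`/`warp_utils.py`; it assumes a pinhole intrinsic matrix `K` and a novel-view pose `R`, `t`, and it omits the z-buffer/occlusion handling the real pipeline needs):

```
import numpy as np

def forward_warp_coords(disparity, K, R, t, eps=1e-6):
    """For every source pixel, compute its location in a novel view.

    disparity: (H, W) relative disparity from a monocular estimator
               (e.g. Depth Anything V2).
    K:         (3, 3) pinhole intrinsics shared by both views.
    R, t:      rotation (3, 3) and translation (3,) of the novel camera
               relative to the source camera.
    Returns (H, W, 2) pixel coordinates in the novel view and the warped depth.
    """
    H, W = disparity.shape
    depth = 1.0 / np.clip(disparity, eps, None)        # coarse inverse-disparity depth
    u, v = np.meshgrid(np.arange(W), np.arange(H))
    pix = np.stack([u, v, np.ones_like(u)], axis=-1)   # homogeneous pixel grid (H, W, 3)

    # Unproject into the source camera frame, then move into the novel camera frame.
    rays = pix @ np.linalg.inv(K).T
    points = rays * depth[..., None]
    points_new = points @ R.T + t

    # Reproject; entries with non-positive depth fall behind the novel camera.
    proj = points_new @ K.T
    coords = proj[..., :2] / np.clip(proj[..., 2:3], eps, None)
    return coords, points_new[..., 2]
```

Pairing each source pixel with its projected location in `coords` (after an occlusion check) yields dense correspondence labels of the kind shown below: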

84 | 85 | ![test_000002809](https://github.com/user-attachments/assets/a9c62860-b153-40ab-95cb-fa14cb59490c) 86 | 87 | 88 | These demonstrate both the geometric diversity and high-quality pixel-level correspondence labels used for supervision. 89 | 90 | For novel-view inpainting, we also provide an improved inpainting model fine-tuned from Stable-Diffusion-2.0-Inpainting: 91 | 92 | ``` 93 | from diffusers import StableDiffusionInpaintPipeline 94 | import torch 95 | from diffusers.utils import load_image, make_image_grid 96 | import PIL 97 | 98 | # Path to the model weights 99 | model_path = "Liangyingping/L2M-Inpainting" # replace with your own model path 100 | 101 | # Load the model 102 | pipe = StableDiffusionInpaintPipeline.from_pretrained( 103 | model_path, torch_dtype=torch.float16 104 | ) 105 | pipe.to("cuda") # move the pipeline to the GPU if one is available 106 | 107 | init_image = load_image("assets/debug_masked_image.png") 108 | mask_image = load_image("assets/debug_mask.png") 109 | W, H = init_image.size 110 | 111 | prompt = "a photo of a person" 112 | image = pipe( 113 | prompt=prompt, 114 | image=init_image, 115 | mask_image=mask_image, 116 | height=512, width=512 117 | ).images[0].resize((W, H)) 118 | 119 | print(image.size, init_image.size) 120 | 121 | image2save = make_image_grid([init_image, mask_image, image], rows=1, cols=3) 122 | image2save.save("image2save_ours.png") 123 | ``` 124 | 125 | Alternatively, you can manually download the model from [Hugging Face](https://huggingface.co/Liangyingping/L2M-Inpainting). 126 | 127 | <!-- figure: novel-view-sup --> 128 | 129 | <!-- figure: novel-view-mpi --> 130 | 131 | #### Stage 2.2: Relighting for Appearance Diversity 132 | To improve feature robustness under varying lighting conditions, we apply a physics-inspired relighting pipeline to the synthesized 3D scenes. 133 | 134 | Run the following to generate relit image pairs for training the decoder: 135 | ``` 136 | python relight.py 137 | ``` 138 | All outputs will be saved under the configured output directory, including the original view, the novel views, and their camera matrices together with dense depth. 139 | 140 | <!-- figure: demo-data --> 141 | 142 | 143 | #### Stage 2.3: Sky Masking (Optional) 144 | 145 | If desired, you can run sky_seg.py to mask out sky regions, which are typically textureless and not useful for matching. This can help reduce noise and focus training on geometrically meaningful regions. 146 | 147 | ``` 148 | python sky_seg.py 149 | ``` 150 | 151 | ![ADE_train_00000971](https://github.com/user-attachments/assets/ef34c52c-7bff-4be6-94dd-9aeedeef0f60) 152 | 153 | 154 | ## 🙋‍♂️ Acknowledgements 155 | 156 | We build upon recent advances in [ROMA](https://github.com/Parskatt/RoMa), [GIM](https://github.com/xuelunshen/gim), and [FiT3D](https://github.com/ywyue/FiT3D). 157 | -------------------------------------------------------------------------------- /romatch/models/transformer/layers/block.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree.
6 | 7 | # References: 8 | # https://github.com/facebookresearch/dino/blob/master/vision_transformer.py 9 | # https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/patch_embed.py 10 | 11 | import logging 12 | from typing import Callable, List, Any, Tuple, Dict 13 | 14 | import torch 15 | from torch import nn, Tensor 16 | 17 | from .attention import Attention, MemEffAttention 18 | from .drop_path import DropPath 19 | from .layer_scale import LayerScale 20 | from .mlp import Mlp 21 | 22 | 23 | logger = logging.getLogger("dinov2") 24 | 25 | 26 | try: 27 | from xformers.ops import fmha 28 | from xformers.ops import scaled_index_add, index_select_cat 29 | 30 | XFORMERS_AVAILABLE = True 31 | except ImportError: 32 | # logger.warning("xFormers not available") 33 | XFORMERS_AVAILABLE = False 34 | 35 | 36 | class Block(nn.Module): 37 | def __init__( 38 | self, 39 | dim: int, 40 | num_heads: int, 41 | mlp_ratio: float = 4.0, 42 | qkv_bias: bool = False, 43 | proj_bias: bool = True, 44 | ffn_bias: bool = True, 45 | drop: float = 0.0, 46 | attn_drop: float = 0.0, 47 | init_values=None, 48 | drop_path: float = 0.0, 49 | act_layer: Callable[..., nn.Module] = nn.GELU, 50 | norm_layer: Callable[..., nn.Module] = nn.LayerNorm, 51 | attn_class: Callable[..., nn.Module] = Attention, 52 | ffn_layer: Callable[..., nn.Module] = Mlp, 53 | ) -> None: 54 | super().__init__() 55 | # print(f"biases: qkv: {qkv_bias}, proj: {proj_bias}, ffn: {ffn_bias}") 56 | self.norm1 = norm_layer(dim) 57 | self.attn = attn_class( 58 | dim, 59 | num_heads=num_heads, 60 | qkv_bias=qkv_bias, 61 | proj_bias=proj_bias, 62 | attn_drop=attn_drop, 63 | proj_drop=drop, 64 | ) 65 | self.ls1 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity() 66 | self.drop_path1 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity() 67 | 68 | self.norm2 = norm_layer(dim) 69 | mlp_hidden_dim = int(dim * mlp_ratio) 70 | self.mlp = ffn_layer( 71 | in_features=dim, 72 | hidden_features=mlp_hidden_dim, 73 | act_layer=act_layer, 74 | drop=drop, 75 | bias=ffn_bias, 76 | ) 77 | self.ls2 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity() 78 | self.drop_path2 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity() 79 | 80 | self.sample_drop_ratio = drop_path 81 | 82 | def forward(self, x: Tensor) -> Tensor: 83 | def attn_residual_func(x: Tensor) -> Tensor: 84 | return self.ls1(self.attn(self.norm1(x))) 85 | 86 | def ffn_residual_func(x: Tensor) -> Tensor: 87 | return self.ls2(self.mlp(self.norm2(x))) 88 | 89 | if self.training and self.sample_drop_ratio > 0.1: 90 | # the overhead is compensated only for a drop path rate larger than 0.1 91 | x = drop_add_residual_stochastic_depth( 92 | x, 93 | residual_func=attn_residual_func, 94 | sample_drop_ratio=self.sample_drop_ratio, 95 | ) 96 | x = drop_add_residual_stochastic_depth( 97 | x, 98 | residual_func=ffn_residual_func, 99 | sample_drop_ratio=self.sample_drop_ratio, 100 | ) 101 | elif self.training and self.sample_drop_ratio > 0.0: 102 | x = x + self.drop_path1(attn_residual_func(x)) 103 | x = x + self.drop_path1(ffn_residual_func(x)) # FIXME: drop_path2 104 | else: 105 | x = x + attn_residual_func(x) 106 | x = x + ffn_residual_func(x) 107 | return x 108 | 109 | 110 | def drop_add_residual_stochastic_depth( 111 | x: Tensor, 112 | residual_func: Callable[[Tensor], Tensor], 113 | sample_drop_ratio: float = 0.0, 114 | ) -> Tensor: 115 | # 1) extract subset using permutation 116 | b, n, d = x.shape 117 | sample_subset_size = max(int(b * (1 - 
sample_drop_ratio)), 1) 118 | brange = (torch.randperm(b, device=x.device))[:sample_subset_size] 119 | x_subset = x[brange] 120 | 121 | # 2) apply residual_func to get residual 122 | residual = residual_func(x_subset) 123 | 124 | x_flat = x.flatten(1) 125 | residual = residual.flatten(1) 126 | 127 | residual_scale_factor = b / sample_subset_size 128 | 129 | # 3) add the residual 130 | x_plus_residual = torch.index_add(x_flat, 0, brange, residual.to(dtype=x.dtype), alpha=residual_scale_factor) 131 | return x_plus_residual.view_as(x) 132 | 133 | 134 | def get_branges_scales(x, sample_drop_ratio=0.0): 135 | b, n, d = x.shape 136 | sample_subset_size = max(int(b * (1 - sample_drop_ratio)), 1) 137 | brange = (torch.randperm(b, device=x.device))[:sample_subset_size] 138 | residual_scale_factor = b / sample_subset_size 139 | return brange, residual_scale_factor 140 | 141 | 142 | def add_residual(x, brange, residual, residual_scale_factor, scaling_vector=None): 143 | if scaling_vector is None: 144 | x_flat = x.flatten(1) 145 | residual = residual.flatten(1) 146 | x_plus_residual = torch.index_add(x_flat, 0, brange, residual.to(dtype=x.dtype), alpha=residual_scale_factor) 147 | else: 148 | x_plus_residual = scaled_index_add( 149 | x, brange, residual.to(dtype=x.dtype), scaling=scaling_vector, alpha=residual_scale_factor 150 | ) 151 | return x_plus_residual 152 | 153 | 154 | attn_bias_cache: Dict[Tuple, Any] = {} 155 | 156 | 157 | def get_attn_bias_and_cat(x_list, branges=None): 158 | """ 159 | this will perform the index select, cat the tensors, and provide the attn_bias from cache 160 | """ 161 | batch_sizes = [b.shape[0] for b in branges] if branges is not None else [x.shape[0] for x in x_list] 162 | all_shapes = tuple((b, x.shape[1]) for b, x in zip(batch_sizes, x_list)) 163 | if all_shapes not in attn_bias_cache.keys(): 164 | seqlens = [] 165 | for b, x in zip(batch_sizes, x_list): 166 | for _ in range(b): 167 | seqlens.append(x.shape[1]) 168 | attn_bias = fmha.BlockDiagonalMask.from_seqlens(seqlens) 169 | attn_bias._batch_sizes = batch_sizes 170 | attn_bias_cache[all_shapes] = attn_bias 171 | 172 | if branges is not None: 173 | cat_tensors = index_select_cat([x.flatten(1) for x in x_list], branges).view(1, -1, x_list[0].shape[-1]) 174 | else: 175 | tensors_bs1 = tuple(x.reshape([1, -1, *x.shape[2:]]) for x in x_list) 176 | cat_tensors = torch.cat(tensors_bs1, dim=1) 177 | 178 | return attn_bias_cache[all_shapes], cat_tensors 179 | 180 | 181 | def drop_add_residual_stochastic_depth_list( 182 | x_list: List[Tensor], 183 | residual_func: Callable[[Tensor, Any], Tensor], 184 | sample_drop_ratio: float = 0.0, 185 | scaling_vector=None, 186 | ) -> Tensor: 187 | # 1) generate random set of indices for dropping samples in the batch 188 | branges_scales = [get_branges_scales(x, sample_drop_ratio=sample_drop_ratio) for x in x_list] 189 | branges = [s[0] for s in branges_scales] 190 | residual_scale_factors = [s[1] for s in branges_scales] 191 | 192 | # 2) get attention bias and index+concat the tensors 193 | attn_bias, x_cat = get_attn_bias_and_cat(x_list, branges) 194 | 195 | # 3) apply residual_func to get residual, and split the result 196 | residual_list = attn_bias.split(residual_func(x_cat, attn_bias=attn_bias)) # type: ignore 197 | 198 | outputs = [] 199 | for x, brange, residual, residual_scale_factor in zip(x_list, branges, residual_list, residual_scale_factors): 200 | outputs.append(add_residual(x, brange, residual, residual_scale_factor, scaling_vector).view_as(x)) 201 | return outputs 202 | 
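# Illustration (not part of the upstream DINOv2 source): the easiest way to see
# what drop_add_residual_stochastic_depth does is to pass an identity residual.
# With a batch of 4 and sample_drop_ratio=0.5, exactly 2 randomly chosen samples
# receive the residual, scaled by b / sample_subset_size = 2 so that the update
# is unbiased in expectation:
#
#     x = torch.randn(4, 16, 32)
#     y = drop_add_residual_stochastic_depth(x, lambda t: t, sample_drop_ratio=0.5)
#     # y[i] == 3 * x[i] for the 2 selected samples, y[i] == x[i] for the rest.
#
# NestedTensorBlock below relies on the list variant of this trick together with
# xFormers block-diagonal attention masks to process differently sized token
# sequences in one fused batch.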
203 | 204 | class NestedTensorBlock(Block): 205 | def forward_nested(self, x_list: List[Tensor]) -> List[Tensor]: 206 | """ 207 | x_list contains a list of tensors to nest together and run 208 | """ 209 | assert isinstance(self.attn, MemEffAttention) 210 | 211 | if self.training and self.sample_drop_ratio > 0.0: 212 | 213 | def attn_residual_func(x: Tensor, attn_bias=None) -> Tensor: 214 | return self.attn(self.norm1(x), attn_bias=attn_bias) 215 | 216 | def ffn_residual_func(x: Tensor, attn_bias=None) -> Tensor: 217 | return self.mlp(self.norm2(x)) 218 | 219 | x_list = drop_add_residual_stochastic_depth_list( 220 | x_list, 221 | residual_func=attn_residual_func, 222 | sample_drop_ratio=self.sample_drop_ratio, 223 | scaling_vector=self.ls1.gamma if isinstance(self.ls1, LayerScale) else None, 224 | ) 225 | x_list = drop_add_residual_stochastic_depth_list( 226 | x_list, 227 | residual_func=ffn_residual_func, 228 | sample_drop_ratio=self.sample_drop_ratio, 229 | scaling_vector=self.ls2.gamma if isinstance(self.ls1, LayerScale) else None, 230 | ) 231 | return x_list 232 | else: 233 | 234 | def attn_residual_func(x: Tensor, attn_bias=None) -> Tensor: 235 | return self.ls1(self.attn(self.norm1(x), attn_bias=attn_bias)) 236 | 237 | def ffn_residual_func(x: Tensor, attn_bias=None) -> Tensor: 238 | return self.ls2(self.mlp(self.norm2(x))) 239 | 240 | attn_bias, x = get_attn_bias_and_cat(x_list) 241 | x = x + attn_residual_func(x, attn_bias=attn_bias) 242 | x = x + ffn_residual_func(x) 243 | return attn_bias.split(x) 244 | 245 | def forward(self, x_or_x_list): 246 | if isinstance(x_or_x_list, Tensor): 247 | return super().forward(x_or_x_list) 248 | elif isinstance(x_or_x_list, list): 249 | assert XFORMERS_AVAILABLE, "Please install xFormers for nested tensors usage" 250 | return self.forward_nested(x_or_x_list) 251 | else: 252 | raise AssertionError 253 | -------------------------------------------------------------------------------- /relight.py: -------------------------------------------------------------------------------- 1 | import bpy 2 | import numpy as np 3 | import cv2 4 | import math 5 | import os 6 | import imutils 7 | import scipy.ndimage 8 | from depth_anything_v2.dpt import DepthAnythingV2 9 | 10 | # conda activate pgdvs 11 | 12 | 13 | def resize_and_center_crop(image, disparity): 14 | # 获取图像和视差图的尺寸 15 | h, w = image.shape[:2] 16 | 17 | # 计算最短边的尺寸 18 | shortest_edge = min(h, w) 19 | 20 | # 按最短边缩放 21 | if h < w: 22 | new_h = shortest_edge 23 | new_w = int(shortest_edge * (w / h)) 24 | else: 25 | new_w = shortest_edge 26 | new_h = int(shortest_edge * (h / w)) 27 | 28 | # 缩放图像 29 | image_resized = cv2.resize(image, (new_w, new_h)) 30 | disparity_resized = cv2.resize(disparity, (new_w, new_h)) 31 | 32 | # 计算裁剪区域,使得图像变为正方形 33 | crop_size = min(image_resized.shape[:2]) # 取缩放后图像的最短边作为裁剪大小 34 | start_x = (new_w - crop_size) // 2 35 | start_y = (new_h - crop_size) // 2 36 | 37 | # 裁剪图像和视差图 38 | image_cropped = image_resized[ 39 | start_y : start_y + crop_size, start_x : start_x + crop_size 40 | ] 41 | disparity_cropped = disparity_resized[ 42 | start_y : start_y + crop_size, start_x : start_x + crop_size 43 | ] 44 | 45 | return image_cropped, disparity_cropped 46 | 47 | 48 | def create_mesh(vertices, faces, colors, mesh_name="GeneratedMesh"): 49 | # 创建网格和对象 50 | mesh_data = bpy.data.meshes.new(mesh_name) 51 | mesh_object = bpy.data.objects.new(mesh_name, mesh_data) 52 | bpy.context.collection.objects.link(mesh_object) 53 | 54 | # 设置顶点和面 55 | mesh_data.from_pydata(vertices.tolist(), [], 
faces.tolist()) 56 | mesh_data.update() 57 | 58 | # 添加顶点颜色数据 59 | if not mesh_data.vertex_colors: 60 | mesh_data.vertex_colors.new(name="Col") 61 | color_layer = mesh_data.vertex_colors["Col"] 62 | 63 | # 设置每个顶点的颜色 64 | for poly in mesh_data.polygons: 65 | for loop_index in poly.loop_indices: 66 | vertex_index = mesh_data.loops[loop_index].vertex_index 67 | color_layer.data[loop_index].color = (*colors[vertex_index], 1.0) # RGBA 68 | 69 | # 添加材质并启用漫反射着色器 70 | if not mesh_object.data.materials: 71 | material = bpy.data.materials.new("MeshMaterial") 72 | material.use_nodes = True # 使用节点着色 73 | mesh_object.data.materials.append(material) 74 | 75 | # 获取材质的节点树 76 | nodes = material.node_tree.nodes 77 | links = material.node_tree.links 78 | 79 | # 删除默认的 Principled BSDF 节点 80 | for node in nodes: 81 | if node.type == "BSDF_PRINCIPLED": 82 | nodes.remove(node) 83 | 84 | # 添加 Diffuse BSDF 节点 85 | diffuse_bsdf = nodes.new(type="ShaderNodeBsdfDiffuse") 86 | 87 | # 创建顶点颜色节点 88 | vertex_color_node = nodes.new("ShaderNodeAttribute") 89 | vertex_color_node.attribute_name = "Col" # 顶点颜色的名称 90 | links.new(vertex_color_node.outputs[0], diffuse_bsdf.inputs["Color"]) 91 | 92 | # 将 Diffuse BSDF 节点连接到材质的 Surface 输入 93 | material_output = nodes.get("Material Output") 94 | links.new(diffuse_bsdf.outputs[0], material_output.inputs["Surface"]) 95 | 96 | return mesh_object 97 | 98 | 99 | def generate_faces(image_width, image_height): 100 | indices = np.arange(image_width * image_height).reshape(image_height, image_width) 101 | lower_left = indices[:-1, :-1].ravel() 102 | lower_right = indices[:-1, 1:].ravel() 103 | upper_left = indices[1:, :-1].ravel() 104 | upper_right = indices[1:, 1:].ravel() 105 | faces = np.column_stack( 106 | ( 107 | np.column_stack((lower_left, lower_right, upper_left)), 108 | np.column_stack((lower_right, upper_right, upper_left)), 109 | ) 110 | ).reshape(-1, 3) 111 | return faces 112 | 113 | 114 | def setup_mesh(image_path, depth_anything_model): 115 | # 加载图像和 disparity 116 | image = cv2.imread(image_path) 117 | 118 | disparity = depth_anything_model.infer_image(image[:, :, ::-1], 518).astype( 119 | np.float32 120 | ) 121 | 122 | disparity = (disparity - disparity.min()) / (disparity.max() - disparity.min()) 123 | 124 | H, W = image.shape[:2] 125 | 126 | print(disparity.shape, image.shape) 127 | 128 | if H > W: 129 | image = imutils.resize(image, width=512) 130 | disparity = imutils.resize(disparity, width=512) 131 | elif H < W: 132 | image = imutils.resize(image, height=512) 133 | disparity = imutils.resize(disparity, height=512) 134 | else: 135 | image = imutils.resize(image, height=512, width=512) 136 | disparity = imutils.resize(disparity, height=512, width=512) 137 | 138 | print(disparity.shape, image.shape) 139 | 140 | disparity = np.clip(disparity, 0.01, 1) 141 | 142 | disparity = scipy.ndimage.gaussian_filter(disparity, sigma=2) 143 | 144 | print(disparity.shape, image.shape) 145 | 146 | cv2.imwrite("image.png", image) 147 | 148 | image_height, image_width = image.shape[:2] 149 | 150 | # 相机内参 151 | focal = 0.58 * image_width 152 | cx = image_width / 2 153 | cy = image_height / 2 154 | fx = focal 155 | fy = focal 156 | 157 | camera_setup = { 158 | "focal": focal, 159 | "fx": fx, 160 | "fy": fy, 161 | "cx": cx, 162 | "cy": cy, 163 | "image_height": image_height, 164 | "image_width": image_width, 165 | } 166 | 167 | # 计算深度 168 | depth = 1.0 / disparity 169 | 170 | depth = depth * (np.random.random() * 0.6 + 0.7) + np.random.random() * 0.2 171 | 172 | print(depth.shape, image.shape) 173 
174 |     subdivision = 2
175 | 
176 |     # Generate the pixel grid
177 |     u_coords, v_coords = np.meshgrid(
178 |         np.arange(image_width * subdivision), np.arange(image_height * subdivision)
179 |     )
180 | 
181 |     # Convert pixel coordinates to the camera coordinate system
182 |     Y = imutils.resize(depth, width=subdivision * 512)
183 |     X = (u_coords / subdivision - cx) * Y / fx
184 |     Z = (v_coords / subdivision - cy) * Y / fy * -1
185 | 
186 |     # Merge into a point cloud
187 |     points = np.stack((X, Y, Z), axis=-1).reshape(-1, 3)
188 |     colors = (
189 |         imutils.resize(image, width=subdivision * 512).reshape(-1, 3)[:, ::-1] / 255.0
190 |     )
191 | 
192 | 
193 |     # Generate the triangle faces
194 |     faces = generate_faces(image_width * subdivision, image_height * subdivision)
195 |     # Create the mesh in Blender
196 |     mesh_object = create_mesh(points, faces, colors)
197 | 
198 |     return mesh_object, camera_setup
199 | 
200 | 
201 | def setup_gpu_rendering():
202 |     bpy.context.scene.render.engine = "BLENDER_EEVEE"
203 |     # Get the Eevee render settings
204 |     eevee = bpy.context.scene.eevee
205 | 
206 |     # Enable a few higher-quality render settings
207 |     eevee.use_ssr = True  # enable screen-space reflections
208 | 
209 |     print("Eevee render settings configured to use GPU.")
210 | 
211 | 
212 | def setup_scene(camera_setup, mesh_object, output_path):
213 |     # Clean up the existing scene (all objects, lights, cameras, etc.)
214 | 
215 |     # Make sure we are in object mode
216 |     if bpy.context.object:
217 |         bpy.ops.object.mode_set(mode="OBJECT")
218 | 
219 |     # Add the generated mesh to the scene
220 |     # bpy.context.collection.objects.link(mesh_object)
221 | 
222 |     for _ in range(np.random.randint(2, 6)):
223 |         # Create and configure a light source (a point light here)
224 |         light_data = bpy.data.lights.new(name="PointLight", type="POINT")
225 |         light_data.energy = 400 + 1600 * np.random.rand()  # light intensity (400~2000)
226 |         light_data.color = (
227 |             np.random.rand() * 0.5 + 0.5 if np.random.rand() > 0.5 else 0.5,
228 |             np.random.rand() * 0.5 + 0.5 if np.random.rand() > 0.5 else 0.5,
229 |             np.random.rand() * 0.5 + 0.5 if np.random.rand() > 0.5 else 0.5,
230 |         )  # set the light color
231 | 
232 |         light_obj = bpy.data.objects.new(
233 |             name="PointLightObject", object_data=light_data
234 |         )
235 |         light_obj.location = (
236 |             (
237 |                 np.random.rand() * 3
238 |                 if np.random.rand() > 0.5
239 |                 else -1 * np.random.rand() * 2
240 |             ),  # roughly -2~3
241 |             np.random.randint(1, 3),  # 1 or 2
242 |             (
243 |                 np.random.rand() * 2
244 |                 if np.random.rand() > 0.5
245 |                 else -1 * np.random.rand() * 2
246 |             ),  # roughly -2~2
247 |         )  # set the light position
248 |         bpy.context.scene.collection.objects.link(light_obj)
249 |         print(light_data.color, light_obj.location)
250 | 
251 |     # Ambient light (set the world background via nodes)
252 |     bpy.context.scene.world.use_nodes = True
253 |     bg_node = bpy.context.scene.world.node_tree.nodes["Background"]
254 |     bg_node.inputs[1].default_value = 1.6 + np.random.rand()  # boost the ambient brightness
255 | 
256 |     # Create the camera
257 |     camera_data = bpy.data.cameras.new(name="Camera")
258 | 
259 |     camera_data.lens = camera_setup["focal"] / 10  # provisional; overwritten with focal_mm below
260 | 
261 |     camera_obj = bpy.data.objects.new(name="CameraObject", object_data=camera_data)
262 | 
263 |     # Get the camera sensor width (in millimeters)
264 |     sensor_width = camera_obj.data.sensor_width
265 |     print(f"Sensor Width: {sensor_width} mm")
266 | 
267 |     sensor_fit = camera_obj.data.sensor_fit
268 |     print(f"Sensor Fit: {sensor_fit}")
269 | 
270 |     # Image width is assumed to be in pixels
271 |     image_width_pixels = camera_setup["image_width"]
272 | 
273 |     # Convert the focal length from pixels to millimeters
274 |     focal_mm = (camera_setup["focal"] / image_width_pixels) * sensor_width
275 | 
276 |     camera_data.lens = focal_mm
277 | 
278 |     # Camera position and rotation
279 |     camera_obj.location = (0, 0, 0)  # keep the camera in front of the mesh
280 |     camera_obj.rotation_euler = (math.radians(90), 0, 0)  # point the camera at the mesh
281 |     bpy.context.scene.collection.objects.link(camera_obj)
282 | 
283 |     # Set the active camera of the scene
284 |     bpy.context.scene.camera = camera_obj
285 | 
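    # Worked example of the pixel-to-millimeter focal conversion above, assuming the
    # usual 512-pixel width and Blender's default 36 mm sensor width:
    #     focal ~= 0.58 * 512 ~= 297 px  ->  focal_mm = 297 / 512 * 36 ~= 20.9 mm
    # which replaces the provisional `lens = focal / 10` value set when the camera
    # data block was created.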
286 |     # Configure the render engine (Eevee, see setup_gpu_rendering)
287 |     setup_gpu_rendering()
288 | 
289 |     # Render settings
290 |     bpy.context.scene.render.resolution_x = camera_setup[
291 |         "image_width"
292 |     ]  # render resolution width
293 |     bpy.context.scene.render.resolution_y = camera_setup[
294 |         "image_height"
295 |     ]  # render resolution height
296 |     bpy.context.scene.render.film_transparent = True  # transparent background, convenient for post-processing
297 | 
298 |     # Sample count (adjustable; only takes effect when rendering with Cycles)
299 |     bpy.context.scene.cycles.samples = (
300 |         32  # a fairly low sample count; increase for higher quality
301 |     )
302 | 
303 |     # Output image path
304 |     bpy.context.scene.render.filepath = output_path
305 | 
306 |     # Render the current scene and write the image
307 |     bpy.ops.render.render(write_still=True)
308 | 
309 |     # Save the render result
310 |     bpy.data.images["Render Result"].save_render(filepath=output_path)  # save the render result
311 |     print(f"Render saved to {output_path}")
312 | 
313 | 
314 | if __name__ == "__main__":
315 | 
316 |     bpy.ops.object.select_all(action="SELECT")
317 |     bpy.ops.object.delete(use_global=False)
318 | 
319 |     import torch
320 | 
321 |     DEVICE = (
322 |         "cuda"
323 |         if torch.cuda.is_available()
324 |         else "mps" if torch.backends.mps.is_available() else "cpu"
325 |     )
326 | 
327 |     model_configs = {
328 |         "vits": {
329 |             "encoder": "vits",
330 |             "features": 64,
331 |             "out_channels": [48, 96, 192, 384],
332 |         },
333 |         "vitb": {
334 |             "encoder": "vitb",
335 |             "features": 128,
336 |             "out_channels": [96, 192, 384, 768],
337 |         },
338 |         "vitl": {
339 |             "encoder": "vitl",
340 |             "features": 256,
341 |             "out_channels": [256, 512, 1024, 1024],
342 |         },
343 |         "vitg": {
344 |             "encoder": "vitg",
345 |             "features": 384,
346 |             "out_channels": [1536, 1536, 1536, 1536],
347 |         },
348 |     }
349 | 
350 |     depth_anything = DepthAnythingV2(**model_configs["vitl"]).half()
351 |     depth_anything.load_state_dict(
352 |         torch.load("checkpoints/depth_anything_v2_vitl.pth", map_location="cpu")
353 |     )
354 | 
355 |     depth_anything = depth_anything.to(DEVICE).eval()
356 | 
357 |     # Input parameters
358 |     image_path = "634.jpg"
359 |     rendered_image_path = "634_relight.png"
360 | 
361 |     mesh_object, camera_setup = setup_mesh(image_path, depth_anything)
362 | 
363 |     setup_scene(camera_setup, mesh_object, rendered_image_path)
364 | 
--------------------------------------------------------------------------------
/romatch/models/transformer/dinov2.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 | 7 | # References: 8 | # https://github.com/facebookresearch/dino/blob/main/vision_transformer.py 9 | # https://github.com/rwightman/pytorch-image-models/tree/master/timm/models/vision_transformer.py 10 | 11 | from functools import partial 12 | import math 13 | import logging 14 | from typing import Sequence, Tuple, Union, Callable 15 | 16 | import torch 17 | import torch.nn as nn 18 | import torch.utils.checkpoint 19 | from torch.nn.init import trunc_normal_ 20 | 21 | from .layers import Mlp, PatchEmbed, SwiGLUFFNFused, MemEffAttention, NestedTensorBlock as Block 22 | 23 | 24 | 25 | def named_apply(fn: Callable, module: nn.Module, name="", depth_first=True, include_root=False) -> nn.Module: 26 | if not depth_first and include_root: 27 | fn(module=module, name=name) 28 | for child_name, child_module in module.named_children(): 29 | child_name = ".".join((name, child_name)) if name else child_name 30 | named_apply(fn=fn, module=child_module, name=child_name, depth_first=depth_first, include_root=True) 31 | if depth_first and include_root: 32 | fn(module=module, name=name) 33 | return module 34 | 35 | 36 | class BlockChunk(nn.ModuleList): 37 | def forward(self, x): 38 | for b in self: 39 | x = b(x) 40 | return x 41 | 42 | 43 | class DinoVisionTransformer(nn.Module): 44 | def __init__( 45 | self, 46 | img_size=224, 47 | patch_size=16, 48 | in_chans=3, 49 | embed_dim=768, 50 | depth=12, 51 | num_heads=12, 52 | mlp_ratio=4.0, 53 | qkv_bias=True, 54 | ffn_bias=True, 55 | proj_bias=True, 56 | drop_path_rate=0.0, 57 | drop_path_uniform=False, 58 | init_values=None, # for layerscale: None or 0 => no layerscale 59 | embed_layer=PatchEmbed, 60 | act_layer=nn.GELU, 61 | block_fn=Block, 62 | ffn_layer="mlp", 63 | block_chunks=1, 64 | ): 65 | """ 66 | Args: 67 | img_size (int, tuple): input image size 68 | patch_size (int, tuple): patch size 69 | in_chans (int): number of input channels 70 | embed_dim (int): embedding dimension 71 | depth (int): depth of transformer 72 | num_heads (int): number of attention heads 73 | mlp_ratio (int): ratio of mlp hidden dim to embedding dim 74 | qkv_bias (bool): enable bias for qkv if True 75 | proj_bias (bool): enable bias for proj in attn if True 76 | ffn_bias (bool): enable bias for ffn if True 77 | drop_path_rate (float): stochastic depth rate 78 | drop_path_uniform (bool): apply uniform drop rate across blocks 79 | weight_init (str): weight init scheme 80 | init_values (float): layer-scale init values 81 | embed_layer (nn.Module): patch embedding layer 82 | act_layer (nn.Module): MLP activation layer 83 | block_fn (nn.Module): transformer block class 84 | ffn_layer (str): "mlp", "swiglu", "swiglufused" or "identity" 85 | block_chunks: (int) split block sequence into block_chunks units for FSDP wrap 86 | """ 87 | super().__init__() 88 | norm_layer = partial(nn.LayerNorm, eps=1e-6) 89 | 90 | self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models 91 | self.num_tokens = 1 92 | self.n_blocks = depth 93 | self.num_heads = num_heads 94 | self.patch_size = patch_size 95 | 96 | self.patch_embed = embed_layer(img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim) 97 | num_patches = self.patch_embed.num_patches 98 | 99 | self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) 100 | self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + self.num_tokens, embed_dim)) 101 | 102 | if drop_path_uniform is True: 103 | dpr = [drop_path_rate] * depth 104 | else: 105 | dpr = [x.item() for x in 
torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule 106 | 107 | if ffn_layer == "mlp": 108 | ffn_layer = Mlp 109 | elif ffn_layer == "swiglufused" or ffn_layer == "swiglu": 110 | ffn_layer = SwiGLUFFNFused 111 | elif ffn_layer == "identity": 112 | 113 | def f(*args, **kwargs): 114 | return nn.Identity() 115 | 116 | ffn_layer = f 117 | else: 118 | raise NotImplementedError 119 | 120 | blocks_list = [ 121 | block_fn( 122 | dim=embed_dim, 123 | num_heads=num_heads, 124 | mlp_ratio=mlp_ratio, 125 | qkv_bias=qkv_bias, 126 | proj_bias=proj_bias, 127 | ffn_bias=ffn_bias, 128 | drop_path=dpr[i], 129 | norm_layer=norm_layer, 130 | act_layer=act_layer, 131 | ffn_layer=ffn_layer, 132 | init_values=init_values, 133 | ) 134 | for i in range(depth) 135 | ] 136 | if block_chunks > 0: 137 | self.chunked_blocks = True 138 | chunked_blocks = [] 139 | chunksize = depth // block_chunks 140 | for i in range(0, depth, chunksize): 141 | # this is to keep the block index consistent if we chunk the block list 142 | chunked_blocks.append([nn.Identity()] * i + blocks_list[i : i + chunksize]) 143 | self.blocks = nn.ModuleList([BlockChunk(p) for p in chunked_blocks]) 144 | else: 145 | self.chunked_blocks = False 146 | self.blocks = nn.ModuleList(blocks_list) 147 | 148 | self.norm = norm_layer(embed_dim) 149 | self.head = nn.Identity() 150 | 151 | self.mask_token = nn.Parameter(torch.zeros(1, embed_dim)) 152 | 153 | self.init_weights() 154 | for param in self.parameters(): 155 | param.requires_grad = False 156 | 157 | @property 158 | def device(self): 159 | return self.cls_token.device 160 | 161 | def init_weights(self): 162 | trunc_normal_(self.pos_embed, std=0.02) 163 | nn.init.normal_(self.cls_token, std=1e-6) 164 | named_apply(init_weights_vit_timm, self) 165 | 166 | def interpolate_pos_encoding(self, x, w, h): 167 | previous_dtype = x.dtype 168 | npatch = x.shape[1] - 1 169 | N = self.pos_embed.shape[1] - 1 170 | if npatch == N and w == h: 171 | return self.pos_embed 172 | pos_embed = self.pos_embed.float() 173 | class_pos_embed = pos_embed[:, 0] 174 | patch_pos_embed = pos_embed[:, 1:] 175 | dim = x.shape[-1] 176 | w0 = w // self.patch_size 177 | h0 = h // self.patch_size 178 | # we add a small number to avoid floating point error in the interpolation 179 | # see discussion at https://github.com/facebookresearch/dino/issues/8 180 | w0, h0 = w0 + 0.1, h0 + 0.1 181 | 182 | patch_pos_embed = nn.functional.interpolate( 183 | patch_pos_embed.reshape(1, int(math.sqrt(N)), int(math.sqrt(N)), dim).permute(0, 3, 1, 2), 184 | scale_factor=(w0 / math.sqrt(N), h0 / math.sqrt(N)), 185 | mode="bicubic", 186 | ) 187 | 188 | assert int(w0) == patch_pos_embed.shape[-2] and int(h0) == patch_pos_embed.shape[-1] 189 | patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim) 190 | return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1).to(previous_dtype) 191 | 192 | def prepare_tokens_with_masks(self, x, masks=None): 193 | B, nc, w, h = x.shape 194 | x = self.patch_embed(x) 195 | if masks is not None: 196 | x = torch.where(masks.unsqueeze(-1), self.mask_token.to(x.dtype).unsqueeze(0), x) 197 | 198 | x = torch.cat((self.cls_token.expand(x.shape[0], -1, -1), x), dim=1) 199 | x = x + self.interpolate_pos_encoding(x, w, h) 200 | 201 | return x 202 | 203 | def forward_features_list(self, x_list, masks_list): 204 | x = [self.prepare_tokens_with_masks(x, masks) for x, masks in zip(x_list, masks_list)] 205 | for blk in self.blocks: 206 | x = blk(x) 207 | 208 | all_x = x 209 | output = [] 210 
| for x, masks in zip(all_x, masks_list): 211 | x_norm = self.norm(x) 212 | output.append( 213 | { 214 | "x_norm_clstoken": x_norm[:, 0], 215 | "x_norm_patchtokens": x_norm[:, 1:], 216 | "x_prenorm": x, 217 | "masks": masks, 218 | } 219 | ) 220 | return output 221 | 222 | def forward_features(self, x, masks=None): 223 | if isinstance(x, list): 224 | return self.forward_features_list(x, masks) 225 | 226 | x = self.prepare_tokens_with_masks(x, masks) 227 | 228 | for blk in self.blocks: 229 | x = blk(x) 230 | 231 | x_norm = self.norm(x) 232 | return { 233 | "x_norm_clstoken": x_norm[:, 0], 234 | "x_norm_patchtokens": x_norm[:, 1:], 235 | "x_prenorm": x, 236 | "masks": masks, 237 | } 238 | 239 | def _get_intermediate_layers_not_chunked(self, x, n=1): 240 | x = self.prepare_tokens_with_masks(x) 241 | # If n is an int, take the n last blocks. If it's a list, take them 242 | output, total_block_len = [], len(self.blocks) 243 | blocks_to_take = range(total_block_len - n, total_block_len) if isinstance(n, int) else n 244 | for i, blk in enumerate(self.blocks): 245 | x = blk(x) 246 | if i in blocks_to_take: 247 | output.append(x) 248 | assert len(output) == len(blocks_to_take), f"only {len(output)} / {len(blocks_to_take)} blocks found" 249 | return output 250 | 251 | def _get_intermediate_layers_chunked(self, x, n=1): 252 | x = self.prepare_tokens_with_masks(x) 253 | output, i, total_block_len = [], 0, len(self.blocks[-1]) 254 | # If n is an int, take the n last blocks. If it's a list, take them 255 | blocks_to_take = range(total_block_len - n, total_block_len) if isinstance(n, int) else n 256 | for block_chunk in self.blocks: 257 | for blk in block_chunk[i:]: # Passing the nn.Identity() 258 | x = blk(x) 259 | if i in blocks_to_take: 260 | output.append(x) 261 | i += 1 262 | assert len(output) == len(blocks_to_take), f"only {len(output)} / {len(blocks_to_take)} blocks found" 263 | return output 264 | 265 | def get_intermediate_layers( 266 | self, 267 | x: torch.Tensor, 268 | n: Union[int, Sequence] = 1, # Layers or n last layers to take 269 | reshape: bool = False, 270 | return_class_token: bool = False, 271 | norm=True, 272 | ) -> Tuple[Union[torch.Tensor, Tuple[torch.Tensor]]]: 273 | if self.chunked_blocks: 274 | outputs = self._get_intermediate_layers_chunked(x, n) 275 | else: 276 | outputs = self._get_intermediate_layers_not_chunked(x, n) 277 | if norm: 278 | outputs = [self.norm(out) for out in outputs] 279 | class_tokens = [out[:, 0] for out in outputs] 280 | outputs = [out[:, 1:] for out in outputs] 281 | if reshape: 282 | B, _, w, h = x.shape 283 | outputs = [ 284 | out.reshape(B, w // self.patch_size, h // self.patch_size, -1).permute(0, 3, 1, 2).contiguous() 285 | for out in outputs 286 | ] 287 | if return_class_token: 288 | return tuple(zip(outputs, class_tokens)) 289 | return tuple(outputs) 290 | 291 | def forward(self, *args, is_training=False, **kwargs): 292 | ret = self.forward_features(*args, **kwargs) 293 | if is_training: 294 | return ret 295 | else: 296 | return self.head(ret["x_norm_clstoken"]) 297 | 298 | 299 | def init_weights_vit_timm(module: nn.Module, name: str = ""): 300 | """ViT weight initialization, original timm impl (for reproducibility)""" 301 | if isinstance(module, nn.Linear): 302 | trunc_normal_(module.weight, std=0.02) 303 | if module.bias is not None: 304 | nn.init.zeros_(module.bias) 305 | 306 | 307 | def vit_small(patch_size=16, **kwargs): 308 | model = DinoVisionTransformer( 309 | patch_size=patch_size, 310 | embed_dim=384, 311 | depth=12, 312 | num_heads=6, 
313 |         mlp_ratio=4,
314 |         block_fn=partial(Block, attn_class=MemEffAttention),
315 |         **kwargs,
316 |     )
317 |     return model
318 | 
319 | 
320 | def vit_base(patch_size=16, **kwargs):
321 |     model = DinoVisionTransformer(
322 |         patch_size=patch_size,
323 |         embed_dim=768,
324 |         depth=12,
325 |         num_heads=12,
326 |         mlp_ratio=4,
327 |         block_fn=partial(Block, attn_class=MemEffAttention),
328 |         **kwargs,
329 |     )
330 |     return model
331 | 
332 | 
333 | def vit_large(patch_size=16, **kwargs):
334 |     model = DinoVisionTransformer(
335 |         patch_size=patch_size,
336 |         embed_dim=1024,
337 |         depth=24,
338 |         num_heads=16,
339 |         mlp_ratio=4,
340 |         block_fn=partial(Block, attn_class=MemEffAttention),
341 |         **kwargs,
342 |     )
343 |     return model
344 | 
345 | 
346 | def vit_giant2(patch_size=16, **kwargs):
347 |     """
348 |     Close to ViT-giant, with embed-dim 1536 and 24 heads => embed-dim per head 64
349 |     """
350 |     model = DinoVisionTransformer(
351 |         patch_size=patch_size,
352 |         embed_dim=1536,
353 |         depth=40,
354 |         num_heads=24,
355 |         mlp_ratio=4,
356 |         block_fn=partial(Block, attn_class=MemEffAttention),
357 |         **kwargs,
358 |     )
359 |     return model
--------------------------------------------------------------------------------
/get_data.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import sys
3 | import os
4 | import cv2
5 | import glob
6 | import math
7 | import torch
8 | import torch.nn.functional as F
9 | from torch.utils.data.dataset import Dataset
10 | from torch.utils.data.dataloader import DataLoader, default_collate
11 | from torchvision.utils import save_image
12 | import numpy as np
13 | from random import random
14 | from torchvision import transforms
15 | import imutils
16 | 
17 | from warp_utils import (
18 |     RGBDRenderer,
19 |     image_to_tensor,
20 |     disparity_to_tensor,
21 |     transformation_from_parameters,
22 | )
23 | 
24 | 
25 | def resize_and_center_crop(image, disparity):
26 |     # Get the image / disparity size
27 |     h, w = image.shape[:2]
28 | 
29 |     # Length of the shortest edge
30 |     shortest_edge = min(h, w)
31 | 
32 |     # Scale with respect to the shortest edge
33 |     if h < w:
34 |         new_h = shortest_edge
35 |         new_w = int(shortest_edge * (w / h))
36 |     else:
37 |         new_w = shortest_edge
38 |         new_h = int(shortest_edge * (h / w))
39 | 
40 |     # Resize the image and the disparity map
41 |     image_resized = cv2.resize(image, (new_w, new_h))
42 |     disparity_resized = cv2.resize(disparity, (new_w, new_h))
43 | 
44 |     # Compute the crop region that makes the image square
45 |     crop_size = min(image_resized.shape[:2])  # shortest edge of the resized image is the crop size
46 |     start_x = (new_w - crop_size) // 2
47 |     start_y = (new_h - crop_size) // 2
48 | 
49 |     # Crop the image and the disparity map
50 |     image_cropped = image_resized[
51 |         start_y : start_y + crop_size, start_x : start_x + crop_size
52 |     ]
53 |     disparity_cropped = disparity_resized[
54 |         start_y : start_y + crop_size, start_x : start_x + crop_size
55 |     ]
56 | 
57 |     return image_cropped, disparity_cropped
58 | 
59 | 
60 | class WarpBackStage1Dataset(Dataset):
61 |     def __init__(
62 |         self,
63 |         data_root,
64 |         disp_root,
65 |         width=512,
66 |         height=512,
67 |         device="cuda",  # device of mesh renderer
68 |         trans_range={"x": 0.4, "y": 0.4, "z": 0.8, "a": 18, "b": 18, "c": 18},
69 |         # trans_range={"x": -1, "y": -1, "z": -1, "a": -1, "b": -1, "c": -1},
70 |     ):
71 |         self.data_root = data_root
72 |         self.disp_root = disp_root
73 | 
74 |         self.renderer = RGBDRenderer(device)
75 |         self.width = width
76 |         self.height = height
77 |         self.device = device
78 |         self.trans_range = trans_range
79 |         self.image_path_list = [
80 |             os.path.join(data_root, img) for img in os.listdir(data_root)
81 |         ]
82 |         self.img2tensor = transforms.ToTensor()
83 | 
84 |     def __len__(self):
85 |         return len(self.image_path_list)
86 | 
87 |     def rand_tensor(self, r, l):
88 |         if (
89 |             r < 0
90 |         ):  # we can set a negative value in self.trans_range to avoid random transformation
91 |             return torch.zeros((l, 1, 1))
92 |         rand = torch.rand((l, 1, 1))
93 |         sign = 2 * (torch.randn_like(rand) > 0).float() - 1
94 |         return sign * (r / 2 + r / 2 * rand)
95 | 
96 |     def get_rand_ext(self, bs):
97 |         x, y, z = self.trans_range["x"], self.trans_range["y"], self.trans_range["z"]
98 |         a, b, c = self.trans_range["a"], self.trans_range["b"], self.trans_range["c"]
99 |         cix = self.rand_tensor(x, bs)
100 |         ciy = self.rand_tensor(y, bs)
101 |         ciz = self.rand_tensor(z, bs)
102 | 
103 |         aix = self.rand_tensor(math.pi / a, bs)
104 |         aiy = self.rand_tensor(math.pi / b, bs)
105 |         aiz = self.rand_tensor(math.pi / c, bs)
106 | 
107 |         axisangle = torch.cat([aix, aiy, aiz], dim=-1)  # [b,1,3]
108 |         translation = torch.cat([cix, ciy, ciz], dim=-1)
109 | 
110 |         cam_ext = transformation_from_parameters(axisangle, translation)  # [b,4,4]
111 |         cam_ext_inv = torch.inverse(cam_ext)  # [b,4,4]
112 | 
113 |         print(axisangle, translation)
114 | 
115 |         return cam_ext[:, :-1], cam_ext_inv[:, :-1]
116 | 
117 |     def __getitem__(self, idx):
118 |         image_path = self.image_path_list[idx]
119 |         image_name = os.path.splitext(os.path.basename(image_path))[0]
120 | 
121 |         disp_path = os.path.join(self.disp_root, "%s.npy" % image_name)
122 | 
123 |         image = cv2.imread(image_path, cv2.IMREAD_COLOR)[:, :, ::-1]  # [h,w,3]
124 |         disp = np.load(disp_path)  # [h,w]
125 | 
126 |         H, W = image.shape[:2]
127 | 
128 |         if H > W:
129 |             image = imutils.resize(image, width=512)
130 |             disp = imutils.resize(disp, width=512)
131 |         else:
132 |             image = imutils.resize(image, height=512)
133 |             disp = imutils.resize(disp, height=512)
134 | 
135 |         image, disp = resize_and_center_crop(image, disp)
136 | 
137 |         max_d, min_d = disp.max(), disp.min()
138 |         disp = (disp - min_d) / (max_d - min_d)
139 | 
140 |         image = torch.tensor(image).permute(2, 0, 1) / 255
141 |         disp = torch.tensor(disp).unsqueeze(0) + 0.001
142 | 
143 |         self.focal = 0.45 + np.random.random() * 0.3
144 |         # set intrinsics (normalized by the image size)
145 |         self.K = torch.tensor(
146 |             [[self.focal, 0, 0.5], [0, self.focal, 0.5], [0, 0, 1]]
147 |         ).to(self.device)
148 | 
149 |         image = image.to(self.device).unsqueeze(0).float()
150 |         disp = disp.to(self.device).unsqueeze(0).float()
151 |         rgbd = torch.cat([image, disp], dim=1)  # [b,4,h,w]
152 |         b = image.shape[0]
153 | 
154 |         cam_int = self.K.repeat(b, 1, 1)  # [b,3,3]
155 | 
156 |         # warp to a random novel view
157 |         mesh = self.renderer.construct_mesh(
158 |             rgbd, cam_int, torch.ones_like(disp), normalize_depth=True
159 |         )
160 |         cam_ext, cam_ext_inv = self.get_rand_ext(b)  # [b,3,4]
161 |         cam_ext = cam_ext.to(self.device)
162 |         cam_ext_inv = cam_ext_inv.to(self.device)
163 | 
164 |         warp_image, warp_disp, warp_mask, object_mask = self.renderer.render_mesh(
165 |             mesh, cam_int, cam_ext
166 |         )
167 |         warp_mask = (warp_mask < 0.5).float()
168 | 
169 |         warp_image = torch.clip(warp_image, 0, 1)
170 | 
171 |         cam_int[0, :2, :] *= 512  # convert the normalized intrinsics to pixel units
172 | 
173 |         return {
174 |             "rgb": image,
175 |             "disp": disp,
176 |             "warp_mask": warp_mask,
177 |             "warp_rgb": warp_image,
178 |             "warp_disp": warp_disp,
179 |             "image_name": image_name,
180 |             "cam_int": cam_int[0],
181 |             "cam_ext": cam_ext[0],
182 |         }
183 | 
184 | 
185 | def setup_seed(seed):
186 |     torch.manual_seed(seed)
187 |     torch.cuda.manual_seed_all(seed)
188 |     np.random.seed(seed)
189 |     torch.backends.cudnn.deterministic = True
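
# A small worked example of the normalized-intrinsics convention used in
# __getitem__ above (illustrative only; the numbers assume a hypothetical draw of
# self.focal = 0.6 and are not used anywhere in this script):
#
#     K_norm = [[0.60, 0.00, 0.50],
#               [0.00, 0.60, 0.50],
#               [0.00, 0.00, 1.00]]
#     # after `cam_int[0, :2, :] *= 512` for a 512 x 512 crop:
#     K_pix  = [[307.2,   0.0, 256.0],
#               [  0.0, 307.2, 256.0],
#               [  0.0,   0.0,   1.0]]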
190 | 
191 | 
192 | if __name__ == "__main__":
193 | 
194 |     from diffusers import StableDiffusionInpaintPipeline
195 |     import torch
196 |     import torchvision
197 |     import matplotlib.pyplot as plt
198 | 
199 |     def project_point_to_3d(x, y, depth, K):
200 |         """Back-project a pixel into the camera frame using the intrinsics and its depth."""
201 |         inv_K = torch.linalg.inv(K)
202 |         pixel = torch.tensor([x, y, 1.0]).to(K.device)
203 |         normalized_coords = inv_K @ pixel * depth
204 |         return normalized_coords
205 | 
206 |     def transform_to_another_camera(point_3d, T):
207 |         """Move a 3D point from the camera-1 frame to the camera-2 frame using the transform T."""
208 |         point_3d_homogeneous = torch.cat([point_3d, torch.tensor([1.0]).to(T.device)])
209 |         transformed_point = T @ point_3d_homogeneous
210 |         return transformed_point[:3]
211 | 
212 |     def project_to_image_plane(point_3d, K):
213 |         """Project a 3D point onto the image plane."""
214 |         point_2d_homogeneous = K @ point_3d
215 |         point_2d = point_2d_homogeneous[:2] / point_2d_homogeneous[2]
216 |         return point_2d
217 | 
218 |     parser = argparse.ArgumentParser(
219 |         formatter_class=argparse.ArgumentDefaultsHelpFormatter
220 |     )
221 |     parser.add_argument(
222 |         "--data_path",
223 |         type=str,
224 |         default="data/davis/raw_images",
225 |     )
226 |     parser.add_argument(
227 |         "--disp_path",
228 |         type=str,
229 |         default="data/davis/disps",
230 |     )
231 |     parser.add_argument("--output_path", type=str, default="data/mixed_datasets/l2m_davis")
232 |     opt, _ = parser.parse_known_args()
233 | 
234 |     # Identifier / path of the inpainting model
235 |     model_path = "stabilityai/stable-diffusion-2-inpainting"
236 | 
237 |     # Load the model
238 |     pipe = StableDiffusionInpaintPipeline.from_pretrained(
239 |         model_path, torch_dtype=torch.float16
240 |     )
241 |     pipe.to("cuda")  # move the model to the GPU if one is available
242 |     prompt = "a realistic photo"
243 | 
244 |     setup_seed(0)
245 | 
246 |     output = opt.output_path
247 |     from tqdm import tqdm
248 | 
249 |     if not os.path.exists(output):
250 |         os.makedirs(output, exist_ok=True)
251 |         os.mkdir(os.path.join(output, "image1"))
252 |         os.mkdir(os.path.join(output, "image2"))
253 |         os.mkdir(os.path.join(output, "depth1"))
254 |         os.mkdir(os.path.join(output, "depth2"))
255 |         os.mkdir(os.path.join(output, "cams"))
256 |         os.mkdir(os.path.join(output, "ext"))
257 |         os.mkdir(os.path.join(output, "debug"))
258 | 
259 |     data = WarpBackStage1Dataset(data_root=opt.data_path, disp_root=opt.disp_path)
260 | 
261 |     for loop in range(2):
262 |         for idx in tqdm(range(len(data))):
263 | 
264 |             batch = data.__getitem__(idx)
265 | 
266 |             image, disp = batch["rgb"], batch["disp"]
267 |             w_image, w_disp = batch["warp_rgb"], batch["warp_disp"]
268 |             warp_mask = batch["warp_mask"]
269 | 
270 |             w_disp = torch.clip(w_disp, 0.01, 100)
271 | 
272 |             init_image = torchvision.transforms.functional.to_pil_image(w_image[0])
273 |             mask_image = torchvision.transforms.functional.to_pil_image(warp_mask[0])
274 |             image = torchvision.transforms.functional.to_pil_image(image[0])
275 | 
276 |             W, H = init_image.size
277 | 
278 |             inpaint_image = pipe(
279 |                 prompt=prompt, image=init_image, mask_image=mask_image, height=512, width=512
280 |             ).images[0]
281 | 
282 |             image.save(os.path.join(output, "image1", batch["image_name"] + ".png"))
283 |             inpaint_image.save(
284 |                 os.path.join(output, "image2", batch["image_name"] + ".png")
285 |             )
286 | 
287 |             np.save(
288 |                 os.path.join(output, "depth1", batch["image_name"] + ".npy"),
289 |                 1 / disp.squeeze().cpu().numpy(),
290 |             )
291 |             np.save(
292 |                 os.path.join(output, "depth2", batch["image_name"] + ".npy"),
293 |                 1 / w_disp.squeeze().cpu().numpy(),
294 |             )
295 | 
296 |             w_depth = 1 / (w_disp + 1e-4) * (1 - warp_mask)
297 | 
298 |             cam_int = batch["cam_int"].cpu().numpy()
299 |             cam_ext = batch["cam_ext"].cpu().numpy()
300 | 
301 |             cam_ext = np.concatenate(
302 |                 [cam_ext, np.array([[0.0000, 0.0000, 0.0000, 1.0000]])], 0
303 |             )
304 | 
305 |             np.save(os.path.join(output, "cams", batch["image_name"] + ".npy"), cam_int)
306 |             np.save(os.path.join(output, "ext", batch["image_name"] + ".npy"), cam_ext)
307 | 
308 |             # Visualize only the first 20 generated samples
309 |             if idx > 19:
310 |                 continue
311 | 
312 |             im_A_depth = 1 / disp.squeeze()
313 |             im_B_depth = 1 / w_disp.squeeze()
314 | 
315 |             K1 = batch["cam_int"].float()
316 |             K2 = batch["cam_int"].float()
317 |             T_1to2 = torch.tensor(cam_ext).cuda().float()
318 | 
319 |             im_A_cv = np.array(image)
320 |             im_B_cv = np.array(inpaint_image)
321 | 
322 |             # Concatenate the two images for display
323 |             im_combined = np.hstack(
324 |                 (im_A_cv.astype(np.uint8), im_B_cv.astype(np.uint8))
325 |             )
326 | 
327 |             # Pick a subset of pixels to compute the mapping
328 |             matches_A = []
329 |             matches_B = []
330 | 
331 |             # Keep only pixels with non-zero depth
332 |             for y in range(0, im_A_depth.shape[0], 10):  # stride of 10 to reduce computation
333 |                 for x in range(0, im_A_depth.shape[1], 10):
334 |                     depth_A = im_A_depth[y, x].item()
335 |                     if depth_A > 0:  # only use pixels with valid depth
336 |                         # 3D coordinates in the camera-1 frame
337 |                         point_3d_A = project_point_to_3d(x, y, depth_A, K1)
338 |                         # Transform into the camera-2 frame
339 |                         point_3d_B = transform_to_another_camera(
340 |                             point_3d_A.float(), T_1to2
341 |                         )
342 |                         # Project onto camera 2's image plane
343 |                         point_2d_B = project_to_image_plane(point_3d_B, K2)
344 | 
345 |                         # Append the 2D correspondence to the lists
346 |                         matches_A.append((x, y))
347 |                         matches_B.append((point_2d_B[0].item(), point_2d_B[1].item()))
348 | 
349 | 
350 |             # Convert to numpy arrays for plotting
351 |             matches_A = np.array(matches_A)
352 |             matches_B = np.array(matches_B)
353 | 
354 |             H, W = im_combined.shape[:2]
355 | 
356 |             selected_index = np.random.choice(range(len(matches_A)), 20, replace=False)
357 | 
358 |             # Draw the matched points and the lines connecting them
359 |             for i in selected_index:
360 | 
361 |                 # Coordinates of the match in the combined image
362 |                 x_A = int(matches_A[i, 0])
363 |                 y_A = int(matches_A[i, 1])
364 |                 x_B = int(matches_B[i, 0]) + im_A_cv.shape[1]  # offset by the width of image A
365 |                 y_B = int(matches_B[i, 1])
366 | 
367 |                 if y_B > H or x_B > W or y_B < 0:
368 |                     continue
369 | 
370 |                 # Draw the matched points
371 |                 cv2.circle(
372 |                     im_combined, (x_A, y_A), 5, (0, 255, 0), -1
373 |                 )  # match point in image A
374 |                 cv2.circle(
375 |                     im_combined, (x_B, y_B), 5, (0, 0, 255), -1
376 |                 )  # match point in image B
377 | 
378 |                 # Draw the line between the two matched points
379 |                 cv2.line(im_combined, (x_A, y_A), (x_B, y_B), (255, 0, 0), 1)
380 | 
381 |             # Save the side-by-side visualization
382 |             plt.figure(figsize=(12, 6))
383 |             plt.imshow(im_combined)
384 |             plt.title("Image A and B with Matches")
385 |             plt.axis("off")
386 |             plt.savefig(os.path.join(output, "debug", batch["image_name"] + ".png"))
--------------------------------------------------------------------------------
/warp_utils.py:
--------------------------------------------------------------------------------
1 | import cv2
2 | from PIL import Image
3 | import torch
4 | import torch.nn.functional as F
5 | from torchvision.utils import save_image
6 | from torchvision import transforms
7 | from pytorch3d.renderer.mesh import rasterize_meshes
8 | from pytorch3d.structures import Meshes
9 | from pytorch3d.ops import interpolate_face_attributes
10 | import numpy as np
11 | from functools import reduce
12 | 
13 | def vis_depth_discontinuity(
14 |     depth, depth_threshold, vis_diff=False, label=False, mask=None
15 | ):
16 |     if label == False:
17 |         disp = 1.0 / depth
18 |         u_diff = (disp[1:, :] - disp[:-1, :])[:-1, 1:-1]
19 |         b_diff = (disp[:-1, :] - disp[1:, :])[1:, 1:-1]
20 |         l_diff = (disp[:, 1:] - disp[:, :-1])[1:-1, :-1]
21 |         r_diff = (disp[:, :-1] - disp[:, 1:])[1:-1, 1:]
22 |         if mask is not None:
23 | u_mask = (mask[1:, :] * mask[:-1, :])[:-1, 1:-1] 24 | b_mask = (mask[:-1, :] * mask[1:, :])[1:, 1:-1] 25 | l_mask = (mask[:, 1:] * mask[:, :-1])[1:-1, :-1] 26 | r_mask = (mask[:, :-1] * mask[:, 1:])[1:-1, 1:] 27 | u_diff = u_diff * u_mask 28 | b_diff = b_diff * b_mask 29 | l_diff = l_diff * l_mask 30 | r_diff = r_diff * r_mask 31 | u_over = (np.abs(u_diff) > depth_threshold).astype(np.float32) 32 | b_over = (np.abs(b_diff) > depth_threshold).astype(np.float32) 33 | l_over = (np.abs(l_diff) > depth_threshold).astype(np.float32) 34 | r_over = (np.abs(r_diff) > depth_threshold).astype(np.float32) 35 | else: 36 | disp = depth 37 | u_diff = (disp[1:, :] * disp[:-1, :])[:-1, 1:-1] 38 | b_diff = (disp[:-1, :] * disp[1:, :])[1:, 1:-1] 39 | l_diff = (disp[:, 1:] * disp[:, :-1])[1:-1, :-1] 40 | r_diff = (disp[:, :-1] * disp[:, 1:])[1:-1, 1:] 41 | if mask is not None: 42 | u_mask = (mask[1:, :] * mask[:-1, :])[:-1, 1:-1] 43 | b_mask = (mask[:-1, :] * mask[1:, :])[1:, 1:-1] 44 | l_mask = (mask[:, 1:] * mask[:, :-1])[1:-1, :-1] 45 | r_mask = (mask[:, :-1] * mask[:, 1:])[1:-1, 1:] 46 | u_diff = u_diff * u_mask 47 | b_diff = b_diff * b_mask 48 | l_diff = l_diff * l_mask 49 | r_diff = r_diff * r_mask 50 | u_over = (np.abs(u_diff) > 0).astype(np.float32) 51 | b_over = (np.abs(b_diff) > 0).astype(np.float32) 52 | l_over = (np.abs(l_diff) > 0).astype(np.float32) 53 | r_over = (np.abs(r_diff) > 0).astype(np.float32) 54 | u_over = np.pad(u_over, 1, mode="constant") 55 | b_over = np.pad(b_over, 1, mode="constant") 56 | l_over = np.pad(l_over, 1, mode="constant") 57 | r_over = np.pad(r_over, 1, mode="constant") 58 | u_diff = np.pad(u_diff, 1, mode="constant") 59 | b_diff = np.pad(b_diff, 1, mode="constant") 60 | l_diff = np.pad(l_diff, 1, mode="constant") 61 | r_diff = np.pad(r_diff, 1, mode="constant") 62 | 63 | if vis_diff: 64 | return [u_over, b_over, l_over, r_over], [u_diff, b_diff, l_diff, r_diff] 65 | else: 66 | return [u_over, b_over, l_over, r_over] 67 | 68 | def rolling_window(a, window, strides): 69 | assert ( 70 | len(a.shape) == len(window) == len(strides) 71 | ), "'a', 'window', 'strides' dimension mismatch" 72 | shape_fn = lambda i, w, s: (a.shape[i] - w) // s + 1 73 | shape = [shape_fn(i, w, s) for i, (w, s) in enumerate(zip(window, strides))] + list( 74 | window 75 | ) 76 | 77 | def acc_shape(i): 78 | if i + 1 >= len(a.shape): 79 | return 1 80 | else: 81 | return reduce(lambda x, y: x * y, a.shape[i + 1 :]) 82 | 83 | _strides = [acc_shape(i) * s * a.itemsize for i, s in enumerate(strides)] + list( 84 | a.strides 85 | ) 86 | 87 | return np.lib.stride_tricks.as_strided(a, shape=shape, strides=_strides) 88 | 89 | def bilateral_filter( 90 | depth, 91 | sigma_s, 92 | sigma_r, 93 | window_size, 94 | discontinuity_map=None, 95 | HR=False, 96 | mask=None, 97 | ): 98 | 99 | midpt = window_size // 2 100 | ax = np.arange(-midpt, midpt + 1.0) 101 | xx, yy = np.meshgrid(ax, ax) 102 | if discontinuity_map is not None: 103 | spatial_term = np.exp(-(xx ** 2 + yy ** 2) / (2.0 * sigma_s ** 2)) 104 | 105 | # padding 106 | depth = depth[1:-1, 1:-1] 107 | depth = np.pad(depth, ((1, 1), (1, 1)), "edge") 108 | pad_depth = np.pad(depth, (midpt, midpt), "edge") 109 | if discontinuity_map is not None: 110 | discontinuity_map = discontinuity_map[1:-1, 1:-1] 111 | discontinuity_map = np.pad(discontinuity_map, ((1, 1), (1, 1)), "edge") 112 | pad_discontinuity_map = np.pad(discontinuity_map, (midpt, midpt), "edge") 113 | pad_discontinuity_hole = 1 - pad_discontinuity_map 114 | # filtering 115 | output = depth.copy() 116 
| pad_depth_patches = rolling_window(pad_depth, [window_size, window_size], [1, 1]) 117 | if discontinuity_map is not None: 118 | pad_discontinuity_patches = rolling_window( 119 | pad_discontinuity_map, [window_size, window_size], [1, 1] 120 | ) 121 | pad_discontinuity_hole_patches = rolling_window( 122 | pad_discontinuity_hole, [window_size, window_size], [1, 1] 123 | ) 124 | 125 | if mask is not None: 126 | pad_mask = np.pad(mask, (midpt, midpt), "constant") 127 | pad_mask_patches = rolling_window(pad_mask, [window_size, window_size], [1, 1]) 128 | from itertools import product 129 | 130 | if discontinuity_map is not None: 131 | pH, pW = pad_depth_patches.shape[:2] 132 | for pi in range(pH): 133 | for pj in range(pW): 134 | if mask is not None and mask[pi, pj] == 0: 135 | continue 136 | if discontinuity_map is not None: 137 | if bool(pad_discontinuity_patches[pi, pj].any()) is False: 138 | continue 139 | discontinuity_patch = pad_discontinuity_patches[pi, pj] 140 | discontinuity_holes = pad_discontinuity_hole_patches[pi, pj] 141 | depth_patch = pad_depth_patches[pi, pj] 142 | depth_order = depth_patch.ravel().argsort() 143 | patch_midpt = depth_patch[window_size // 2, window_size // 2] 144 | if discontinuity_map is not None: 145 | coef = discontinuity_holes.astype(np.float32) 146 | if mask is not None: 147 | coef = coef * pad_mask_patches[pi, pj] 148 | else: 149 | range_term = np.exp( 150 | -((depth_patch - patch_midpt) ** 2) / (2.0 * sigma_r ** 2) 151 | ) 152 | coef = spatial_term * range_term 153 | if coef.max() == 0: 154 | output[pi, pj] = patch_midpt 155 | continue 156 | if discontinuity_map is not None and (coef.max() == 0): 157 | output[pi, pj] = patch_midpt 158 | else: 159 | coef = coef / (coef.sum()) 160 | coef_order = coef.ravel()[depth_order] 161 | cum_coef = np.cumsum(coef_order) 162 | ind = np.digitize(0.5, cum_coef) 163 | output[pi, pj] = depth_patch.ravel()[depth_order][ind] 164 | else: 165 | pH, pW = pad_depth_patches.shape[:2] 166 | for pi in range(pH): 167 | for pj in range(pW): 168 | if discontinuity_map is not None: 169 | if ( 170 | pad_discontinuity_patches[pi, pj][ 171 | window_size // 2, window_size // 2 172 | ] 173 | == 1 174 | ): 175 | continue 176 | discontinuity_patch = pad_discontinuity_patches[pi, pj] 177 | discontinuity_holes = 1.0 - discontinuity_patch 178 | depth_patch = pad_depth_patches[pi, pj] 179 | depth_order = depth_patch.ravel().argsort() 180 | patch_midpt = depth_patch[window_size // 2, window_size // 2] 181 | range_term = np.exp( 182 | -((depth_patch - patch_midpt) ** 2) / (2.0 * sigma_r ** 2) 183 | ) 184 | if discontinuity_map is not None: 185 | coef = spatial_term * range_term * discontinuity_holes 186 | else: 187 | coef = spatial_term * range_term 188 | if coef.sum() == 0: 189 | output[pi, pj] = patch_midpt 190 | continue 191 | if discontinuity_map is not None and (coef.sum() == 0): 192 | output[pi, pj] = patch_midpt 193 | else: 194 | coef = coef / (coef.sum()) 195 | coef_order = coef.ravel()[depth_order] 196 | cum_coef = np.cumsum(coef_order) 197 | ind = np.digitize(0.5, cum_coef) 198 | output[pi, pj] = depth_patch.ravel()[depth_order][ind] 199 | 200 | return output 201 | 202 | class RGBDRenderer: 203 | def __init__(self, device): 204 | self.device = device 205 | self.eps = 0.1 206 | self.near_z = 1e-2 207 | self.far_z = 1e4 208 | 209 | def render_mesh(self, mesh_dict, cam_int, cam_ext): 210 | vertice = mesh_dict["vertice"] # [b,h*w,3] 211 | faces = mesh_dict["faces"] # [b,nface,3] 212 | attributes = mesh_dict["attributes"] # [b,h*w,4] 213 | h, 
w = mesh_dict["size"] 214 | 215 | ############ 216 | # to NDC space 217 | vertice_homo = self.lift_to_homo(vertice) # [b,h*w,4] 218 | # [b,1,3,4] x [b,h*w,4,1] = [b,h*w,3,1] 219 | vertice_world = torch.matmul(cam_ext.unsqueeze(1), vertice_homo[..., None]).squeeze(-1) # [b,h*w,3] 220 | vertice_depth = vertice_world[..., -1:] # [b,h*w,1] 221 | attributes = torch.cat([attributes, vertice_depth], dim=-1) # [b,h*w,5] 222 | # [b,1,3,3] x [b,h*w,3,1] = [b,h*w,3,1] 223 | vertice_world_homo = self.lift_to_homo(vertice_world) 224 | persp = self.get_perspective_from_intrinsic(cam_int) # [b,4,4] 225 | 226 | # [b,1,4,4] x [b,h*w,4,1] = [b,h*w,4,1] 227 | vertice_ndc = torch.matmul(persp.unsqueeze(1), vertice_world_homo[..., None]).squeeze(-1) # [b,h*w,4] 228 | vertice_ndc = vertice_ndc[..., :-1] / vertice_ndc[..., -1:] 229 | vertice_ndc[..., :-1] *= -1 230 | vertice_ndc[..., 0] *= w / h 231 | 232 | ############ 233 | # render 234 | mesh = Meshes(vertice_ndc, faces) 235 | pix_to_face, _, bary_coords, _ = rasterize_meshes(mesh, (h, w), faces_per_pixel=1, blur_radius=1e-6) # [b,h,w,1] [b,h,w,1,3] 236 | 237 | b, nf, _ = faces.size() 238 | faces = faces.reshape(b, nf * 3, 1).repeat(1, 1, 6) # [b,3f,5] 239 | face_attributes = torch.gather(attributes, dim=1, index=faces) # [b,3f,5] 240 | face_attributes = face_attributes.reshape(b * nf, 3, 6) 241 | output = interpolate_face_attributes(pix_to_face, bary_coords, face_attributes) 242 | output = output.squeeze(-2).permute(0, 3, 1, 2) 243 | 244 | render = output[:, :3] 245 | mask = output[:, 3:4] 246 | object_mask = output[:, 4:5] 247 | disparity = torch.reciprocal(output[:, 5:] + 1e-4) 248 | 249 | return render * mask, disparity * mask, mask, object_mask 250 | 251 | def construct_mesh(self, rgbd, cam_int, obj_mask, normalize_depth=False): 252 | b, _, h, w = rgbd.size() 253 | 254 | ############ 255 | # get pixel coordinates 256 | pixel_2d = self.get_screen_pixel_coord(h, w) # [1,h,w,2] 257 | pixel_2d_homo = self.lift_to_homo(pixel_2d) # [1,h,w,3] 258 | 259 | ############ 260 | # project pixels to 3D space 261 | rgbd = rgbd.permute(0, 2, 3, 1) # [b,h,w,4] 262 | disparity = rgbd[..., -1:] # [b,h,w,1] 263 | depth = torch.reciprocal(disparity + + 1e-4) # [b,h,w,1] 264 | obj_mask = obj_mask.permute(0, 2, 3, 1).to(depth.device) 265 | # In [2]: depth.max() 266 | # Out[2]: 3.0927802771530017 267 | 268 | # In [3]: depth.min() 269 | # Out[3]: 1.466965406649775 270 | cam_int_inv = torch.inverse(cam_int) # [b,3,3] 271 | # [b,1,1,3,3] x [1,h,w,3,1] = [b,h,w,3,1] 272 | pixel_3d = torch.matmul(cam_int_inv[:, None, None, :, :], pixel_2d_homo[..., None]).squeeze(-1) # [b,h,w,3] 273 | 274 | pixel_3d = pixel_3d * depth # [b,h,w,3] 275 | vertice = pixel_3d.reshape(b, h * w, 3) # [b,h*w,3] 276 | ############ 277 | # construct faces 278 | faces = self.get_faces(h, w) # [1,nface,3] 279 | faces = faces.repeat(b, 1, 1).long() # [b,nface,3] 280 | 281 | ############ 282 | # compute attributes 283 | attr_color = rgbd[..., :-1].reshape(b, h * w, 3) # [b,h*w,3] 284 | attr_object = obj_mask.reshape(b, h * w, 1).to(attr_color.device) # [b,h*w,1] 285 | attr_mask = self.get_visible_mask(disparity, alpha_threshold=0.1).reshape(b, h * w, 1) # [b,h*w,1] 286 | attr = torch.cat([attr_color, attr_mask, attr_object], dim=-1) # [b,h*w,4] 287 | mesh_dict = { 288 | "vertice": vertice, 289 | "faces": faces, 290 | "attributes": attr, 291 | "size": [h, w], 292 | } 293 | return mesh_dict 294 | 295 | def get_screen_pixel_coord(self, h, w): 296 | ''' 297 | get normalized pixel coordinates on the screen 298 | x to 
left, y to down 299 | 300 | e.g. 301 | [0,0][1,0][2,0] 302 | [0,1][1,1][2,1] 303 | output: 304 | pixel_coord: [1,h,w,2] 305 | ''' 306 | x = torch.arange(w).to(self.device) # [w] 307 | y = torch.arange(h).to(self.device) # [h] 308 | x = (x + 0.5) / w 309 | y = (y + 0.5) / h 310 | x = x[None, None, ..., None].repeat(1, h, 1, 1) # [1,h,w,1] 311 | y = y[None, ..., None, None].repeat(1, 1, w, 1) # [1,h,w,1] 312 | pixel_coord = torch.cat([x, y], dim=-1) # [1,h,w,2] 313 | return pixel_coord 314 | 315 | def lift_to_homo(self, coord): 316 | ''' 317 | return the homo version of coord 318 | input: coord [..., k] 319 | output: homo_coord [...,k+1] 320 | ''' 321 | ones = torch.ones_like(coord[..., -1:]) 322 | return torch.cat([coord, ones], dim=-1) 323 | 324 | def get_faces(self, h, w): 325 | x = torch.arange(w - 1).to(self.device) # [w-1] 326 | y = torch.arange(h - 1).to(self.device) # [h-1] 327 | x = x[None, None, ..., None].repeat(1, h - 1, 1, 1) # [1,h-1,w-1,1] 328 | y = y[None, ..., None, None].repeat(1, 1, w - 1, 1) # [1,h-1,w-1,1] 329 | 330 | tl = y * w + x 331 | tr = y * w + x + 1 332 | bl = (y + 1) * w + x 333 | br = (y + 1) * w + x + 1 334 | 335 | faces_l = torch.cat([tl, bl, br], dim=-1).reshape(1, -1, 3) # [1,(h-1)(w-1),3] 336 | faces_r = torch.cat([br, tr, tl], dim=-1).reshape(1, -1, 3) # [1,(h-1)(w-1),3] 337 | 338 | return torch.cat([faces_l, faces_r], dim=1) # [1,nface,3] 339 | 340 | def get_visible_mask(self, disparity, beta=10, alpha_threshold=0.3): 341 | b, h, w, _ = disparity.size() 342 | disparity = disparity.reshape(b, 1, h, w) # [b,1,h,w] 343 | kernel_x = torch.tensor([[-1, 0, 1], [-2, 0, 2], [-1, 0, 1]]).unsqueeze(0).unsqueeze(0).float().to(self.device) 344 | kernel_y = torch.tensor([[-1, -2, -1], [0, 0, 0], [1, 2, 1]]).unsqueeze(0).unsqueeze(0).float().to(self.device) 345 | sobel_x = F.conv2d(disparity, kernel_x, padding=(1, 1)) # [b,1,h,w] 346 | sobel_y = F.conv2d(disparity, kernel_y, padding=(1, 1)) # [b,1,h,w] 347 | sobel_mag = torch.sqrt(sobel_x ** 2 + sobel_y ** 2).reshape(b, h, w, 1) # [b,h,w,1] 348 | alpha = torch.exp(-1.0 * beta * sobel_mag) # [b,h,w,1] 349 | vis_mask = torch.greater(alpha, alpha_threshold).float() 350 | return vis_mask 351 | 352 | def get_perspective_from_intrinsic(self, cam_int): 353 | ''' 354 | input: 355 | cam_int: [b,3,3] 356 | 357 | output: 358 | persp: [b,4,4] 359 | ''' 360 | fx, fy = cam_int[:, 0, 0], cam_int[:, 1, 1] # [b] 361 | cx, cy = cam_int[:, 0, 2], cam_int[:, 1, 2] # [b] 362 | 363 | one = torch.ones_like(cx) # [b] 364 | zero = torch.zeros_like(cx) # [b] 365 | 366 | near_z, far_z = self.near_z * one, self.far_z * one 367 | a = (near_z + far_z) / (far_z - near_z) 368 | b = -2.0 * near_z * far_z / (far_z - near_z) 369 | 370 | matrix = [[2.0 * fx, zero, 2.0 * cx - 1.0, zero], 371 | [zero, 2.0 * fy, 2.0 * cy - 1.0, zero], 372 | [zero, zero, a, b], 373 | [zero, zero, one, zero]] 374 | # -> [[b,4],[b,4],[b,4],[b,4]] -> [b,4,4] 375 | persp = torch.stack([torch.stack(row, dim=-1) for row in matrix], dim=-2) # [b,4,4] 376 | # print(fx, cx, cy, a, b) 377 | return persp 378 | 379 | 380 | ####################### 381 | # some helper I/O functions 382 | ####################### 383 | def image_to_tensor(img_path, unsqueeze=True): 384 | rgb = transforms.ToTensor()(Image.open(img_path)) 385 | if unsqueeze: 386 | rgb = rgb.unsqueeze(0) 387 | return rgb 388 | 389 | def sparse_bilateral_filtering( 390 | depth, 391 | filter_size, 392 | sigma_r=0.5, 393 | sigma_s=4.0, 394 | depth_threshold=0.04, 395 | HR=False, 396 | mask=None, 397 | num_iter=None, 398 | ): 399 
| 400 | save_discontinuities = [] 401 | vis_depth = depth.copy() 402 | for i in range(num_iter): 403 | u_over, b_over, l_over, r_over = vis_depth_discontinuity( 404 | vis_depth, depth_threshold, mask=mask 405 | ) 406 | 407 | discontinuity_map = (u_over + b_over + l_over + r_over).clip(0.0, 1.0) 408 | discontinuity_map[depth == 0] = 1 409 | save_discontinuities.append(discontinuity_map) 410 | if mask is not None: 411 | discontinuity_map[mask == 0] = 0 412 | vis_depth = bilateral_filter( 413 | vis_depth, 414 | sigma_r=sigma_r, 415 | sigma_s=sigma_s, 416 | discontinuity_map=discontinuity_map, 417 | HR=HR, 418 | mask=mask, 419 | window_size=filter_size[i], 420 | ) 421 | 422 | return vis_depth 423 | 424 | def disparity_to_tensor(disp_path, unsqueeze=True): 425 | disp = cv2.imread(disp_path, -1) / (2 ** 16 - 1) 426 | disp = sparse_bilateral_filtering(disp + 1e-4, filter_size=[5, 5], num_iter=2) 427 | disp = torch.from_numpy(disp)[None, ...] 428 | if unsqueeze: 429 | disp = disp.unsqueeze(0) 430 | return disp.float() 431 | 432 | 433 | ####################### 434 | # some helper geometry functions 435 | # adapt from https://github.com/mattpoggi/depthstillation 436 | ####################### 437 | def transformation_from_parameters(axisangle, translation, invert=False): 438 | R = rot_from_axisangle(axisangle) 439 | t = translation.clone() 440 | 441 | if invert: 442 | R = R.transpose(1, 2) 443 | t *= -1 444 | 445 | T = get_translation_matrix(t) 446 | 447 | if invert: 448 | M = torch.matmul(R, T) 449 | else: 450 | M = torch.matmul(T, R) 451 | 452 | return M 453 | 454 | 455 | def get_translation_matrix(translation_vector): 456 | T = torch.zeros(translation_vector.shape[0], 4, 4).to(device=translation_vector.device) 457 | t = translation_vector.contiguous().view(-1, 3, 1) 458 | T[:, 0, 0] = 1 459 | T[:, 1, 1] = 1 460 | T[:, 2, 2] = 1 461 | T[:, 3, 3] = 1 462 | T[:, :3, 3, None] = t 463 | return T 464 | 465 | 466 | def rot_from_axisangle(vec): 467 | angle = torch.norm(vec, 2, 2, True) 468 | axis = vec / (angle + 1e-7) 469 | 470 | ca = torch.cos(angle) 471 | sa = torch.sin(angle) 472 | C = 1 - ca 473 | 474 | x = axis[..., 0].unsqueeze(1) 475 | y = axis[..., 1].unsqueeze(1) 476 | z = axis[..., 2].unsqueeze(1) 477 | 478 | xs = x * sa 479 | ys = y * sa 480 | zs = z * sa 481 | xC = x * C 482 | yC = y * C 483 | zC = z * C 484 | xyC = x * yC 485 | yzC = y * zC 486 | zxC = z * xC 487 | 488 | rot = torch.zeros((vec.shape[0], 4, 4)).to(device=vec.device) 489 | 490 | rot[:, 0, 0] = torch.squeeze(x * xC + ca) 491 | rot[:, 0, 1] = torch.squeeze(xyC - zs) 492 | rot[:, 0, 2] = torch.squeeze(zxC + ys) 493 | rot[:, 1, 0] = torch.squeeze(xyC + zs) 494 | rot[:, 1, 1] = torch.squeeze(y * yC + ca) 495 | rot[:, 1, 2] = torch.squeeze(yzC - xs) 496 | rot[:, 2, 0] = torch.squeeze(zxC - ys) 497 | rot[:, 2, 1] = torch.squeeze(yzC + xs) 498 | rot[:, 2, 2] = torch.squeeze(z * zC + ca) 499 | rot[:, 3, 3] = 1 500 | 501 | return rot 502 | 503 | -------------------------------------------------------------------------------- /romatch/utils/utils.py: -------------------------------------------------------------------------------- 1 | import warnings 2 | import numpy as np 3 | import cv2 4 | import math 5 | import torch 6 | from torchvision import transforms 7 | from torchvision.transforms.functional import InterpolationMode 8 | import torch.nn.functional as F 9 | from PIL import Image 10 | import kornia 11 | 12 | def recover_pose(E, kpts0, kpts1, K0, K1, mask): 13 | best_num_inliers = 0 14 | K0inv = np.linalg.inv(K0[:2,:2]) 15 | K1inv = 
np.linalg.inv(K1[:2,:2]) 16 | 17 | kpts0_n = (K0inv @ (kpts0-K0[None,:2,2]).T).T 18 | kpts1_n = (K1inv @ (kpts1-K1[None,:2,2]).T).T 19 | 20 | for _E in np.split(E, len(E) / 3): 21 | n, R, t, _ = cv2.recoverPose(_E, kpts0_n, kpts1_n, np.eye(3), 1e9, mask=mask) 22 | if n > best_num_inliers: 23 | best_num_inliers = n 24 | ret = (R, t, mask.ravel() > 0) 25 | return ret 26 | 27 | 28 | 29 | # Code taken from https://github.com/PruneTruong/DenseMatching/blob/40c29a6b5c35e86b9509e65ab0cd12553d998e5f/validation/utils_pose_estimation.py 30 | # --- GEOMETRY --- 31 | def estimate_pose(kpts0, kpts1, K0, K1, norm_thresh, conf=0.99999): 32 | if len(kpts0) < 5: 33 | return None 34 | K0inv = np.linalg.inv(K0[:2,:2]) 35 | K1inv = np.linalg.inv(K1[:2,:2]) 36 | 37 | kpts0 = (K0inv @ (kpts0-K0[None,:2,2]).T).T 38 | kpts1 = (K1inv @ (kpts1-K1[None,:2,2]).T).T 39 | E, mask = cv2.findEssentialMat( 40 | kpts0, kpts1, np.eye(3), threshold=norm_thresh, prob=conf 41 | ) 42 | 43 | ret = None 44 | if E is not None: 45 | best_num_inliers = 0 46 | 47 | for _E in np.split(E, len(E) / 3): 48 | n, R, t, _ = cv2.recoverPose(_E, kpts0, kpts1, np.eye(3), 1e9, mask=mask) 49 | if n > best_num_inliers: 50 | best_num_inliers = n 51 | ret = (R, t, mask.ravel() > 0) 52 | return ret 53 | 54 | def estimate_pose_uncalibrated(kpts0, kpts1, K0, K1, norm_thresh, conf=0.99999): 55 | if len(kpts0) < 5: 56 | return None 57 | method = cv2.USAC_ACCURATE 58 | F, mask = cv2.findFundamentalMat( 59 | kpts0, kpts1, ransacReprojThreshold=norm_thresh, confidence=conf, method=method, maxIters=10000 60 | ) 61 | E = K1.T@F@K0 62 | ret = None 63 | if E is not None: 64 | best_num_inliers = 0 65 | K0inv = np.linalg.inv(K0[:2,:2]) 66 | K1inv = np.linalg.inv(K1[:2,:2]) 67 | 68 | kpts0_n = (K0inv @ (kpts0-K0[None,:2,2]).T).T 69 | kpts1_n = (K1inv @ (kpts1-K1[None,:2,2]).T).T 70 | 71 | for _E in np.split(E, len(E) / 3): 72 | n, R, t, _ = cv2.recoverPose(_E, kpts0_n, kpts1_n, np.eye(3), 1e9, mask=mask) 73 | if n > best_num_inliers: 74 | best_num_inliers = n 75 | ret = (R, t, mask.ravel() > 0) 76 | return ret 77 | 78 | def unnormalize_coords(x_n,h,w): 79 | x = torch.stack( 80 | (w * (x_n[..., 0] + 1) / 2, h * (x_n[..., 1] + 1) / 2), dim=-1 81 | ) # [-1+1/h, 1-1/h] -> [0.5, h-0.5] 82 | return x 83 | 84 | 85 | def rotate_intrinsic(K, n): 86 | base_rot = np.array([[0, 1, 0], [-1, 0, 0], [0, 0, 1]]) 87 | rot = np.linalg.matrix_power(base_rot, n) 88 | return rot @ K 89 | 90 | 91 | def rotate_pose_inplane(i_T_w, rot): 92 | rotation_matrices = [ 93 | np.array( 94 | [ 95 | [np.cos(r), -np.sin(r), 0.0, 0.0], 96 | [np.sin(r), np.cos(r), 0.0, 0.0], 97 | [0.0, 0.0, 1.0, 0.0], 98 | [0.0, 0.0, 0.0, 1.0], 99 | ], 100 | dtype=np.float32, 101 | ) 102 | for r in [np.deg2rad(d) for d in (0, 270, 180, 90)] 103 | ] 104 | return np.dot(rotation_matrices[rot], i_T_w) 105 | 106 | 107 | def scale_intrinsics(K, scales): 108 | scales = np.diag([1.0 / scales[0], 1.0 / scales[1], 1.0]) 109 | return np.dot(scales, K) 110 | 111 | 112 | def to_homogeneous(points): 113 | return np.concatenate([points, np.ones_like(points[:, :1])], axis=-1) 114 | 115 | 116 | def angle_error_mat(R1, R2): 117 | cos = (np.trace(np.dot(R1.T, R2)) - 1) / 2 118 | cos = np.clip(cos, -1.0, 1.0) # numercial errors can make it out of bounds 119 | return np.rad2deg(np.abs(np.arccos(cos))) 120 | 121 | 122 | def angle_error_vec(v1, v2): 123 | n = np.linalg.norm(v1) * np.linalg.norm(v2) 124 | return np.rad2deg(np.arccos(np.clip(np.dot(v1, v2) / n, -1.0, 1.0))) 125 | 126 | 127 | def compute_pose_error(T_0to1, R, t): 128 | R_gt 
= T_0to1[:3, :3] 129 | t_gt = T_0to1[:3, 3] 130 | error_t = angle_error_vec(t.squeeze(), t_gt) 131 | error_t = np.minimum(error_t, 180 - error_t) # ambiguity of E estimation 132 | error_R = angle_error_mat(R, R_gt) 133 | return error_t, error_R 134 | 135 | 136 | def pose_auc(errors, thresholds): 137 | sort_idx = np.argsort(errors) 138 | errors = np.array(errors.copy())[sort_idx] 139 | recall = (np.arange(len(errors)) + 1) / len(errors) 140 | errors = np.r_[0.0, errors] 141 | recall = np.r_[0.0, recall] 142 | aucs = [] 143 | for t in thresholds: 144 | last_index = np.searchsorted(errors, t) 145 | r = np.r_[recall[:last_index], recall[last_index - 1]] 146 | e = np.r_[errors[:last_index], t] 147 | aucs.append(np.trapz(r, x=e) / t) 148 | return aucs 149 | 150 | 151 | # From Patch2Pix https://github.com/GrumpyZhou/patch2pix 152 | def get_depth_tuple_transform_ops_nearest_exact(resize=None): 153 | ops = [] 154 | if resize: 155 | ops.append(TupleResizeNearestExact(resize)) 156 | return TupleCompose(ops) 157 | 158 | def get_depth_tuple_transform_ops(resize=None, normalize=True, unscale=False): 159 | ops = [] 160 | if resize: 161 | ops.append(TupleResize(resize, mode=InterpolationMode.BILINEAR)) 162 | return TupleCompose(ops) 163 | 164 | 165 | def get_tuple_transform_ops(resize=None, normalize=True, unscale=False, clahe = False, colorjiggle_params = None): 166 | ops = [] 167 | if resize: 168 | ops.append(TupleResize(resize)) 169 | ops.append(TupleToTensorScaled()) 170 | if normalize: 171 | ops.append( 172 | TupleNormalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) 173 | ) # Imagenet mean/std 174 | return TupleCompose(ops) 175 | 176 | class ToTensorScaled(object): 177 | """Convert a RGB PIL Image to a CHW ordered Tensor, scale the range to [0, 1]""" 178 | 179 | def __call__(self, im): 180 | if not isinstance(im, torch.Tensor): 181 | im = np.array(im, dtype=np.float32).transpose((2, 0, 1)) 182 | im /= 255.0 183 | return torch.from_numpy(im) 184 | else: 185 | return im 186 | 187 | def __repr__(self): 188 | return "ToTensorScaled(./255)" 189 | 190 | 191 | class TupleToTensorScaled(object): 192 | def __init__(self): 193 | self.to_tensor = ToTensorScaled() 194 | 195 | def __call__(self, im_tuple): 196 | return [self.to_tensor(im) for im in im_tuple] 197 | 198 | def __repr__(self): 199 | return "TupleToTensorScaled(./255)" 200 | 201 | 202 | class ToTensorUnscaled(object): 203 | """Convert a RGB PIL Image to a CHW ordered Tensor""" 204 | 205 | def __call__(self, im): 206 | return torch.from_numpy(np.array(im, dtype=np.float32).transpose((2, 0, 1))) 207 | 208 | def __repr__(self): 209 | return "ToTensorUnscaled()" 210 | 211 | 212 | class TupleToTensorUnscaled(object): 213 | """Convert a RGB PIL Image to a CHW ordered Tensor""" 214 | 215 | def __init__(self): 216 | self.to_tensor = ToTensorUnscaled() 217 | 218 | def __call__(self, im_tuple): 219 | return [self.to_tensor(im) for im in im_tuple] 220 | 221 | def __repr__(self): 222 | return "TupleToTensorUnscaled()" 223 | 224 | class TupleResizeNearestExact: 225 | def __init__(self, size): 226 | self.size = size 227 | def __call__(self, im_tuple): 228 | return [F.interpolate(im, size = self.size, mode = 'nearest-exact') for im in im_tuple] 229 | 230 | def __repr__(self): 231 | return "TupleResizeNearestExact(size={})".format(self.size) 232 | 233 | 234 | class TupleResize(object): 235 | def __init__(self, size, mode=InterpolationMode.BICUBIC): 236 | self.size = size 237 | self.resize = transforms.Resize(size, mode) 238 | def __call__(self, im_tuple): 
239 | return [self.resize(im) for im in im_tuple] 240 | 241 | def __repr__(self): 242 | return "TupleResize(size={})".format(self.size) 243 | 244 | class Normalize: 245 | def __call__(self,im): 246 | mean = im.mean(dim=(1,2), keepdims=True) 247 | std = im.std(dim=(1,2), keepdims=True) 248 | return (im-mean)/std 249 | 250 | 251 | class TupleNormalize(object): 252 | def __init__(self, mean, std): 253 | self.mean = mean 254 | self.std = std 255 | self.normalize = transforms.Normalize(mean=mean, std=std) 256 | 257 | def __call__(self, im_tuple): 258 | c,h,w = im_tuple[0].shape 259 | if c > 3: 260 | warnings.warn(f"Number of channels c={c} > 3, assuming first 3 are rgb") 261 | return [self.normalize(im[:3]) for im in im_tuple] 262 | 263 | def __repr__(self): 264 | return "TupleNormalize(mean={}, std={})".format(self.mean, self.std) 265 | 266 | 267 | class TupleCompose(object): 268 | def __init__(self, transforms): 269 | self.transforms = transforms 270 | 271 | def __call__(self, im_tuple): 272 | for t in self.transforms: 273 | im_tuple = t(im_tuple) 274 | return im_tuple 275 | 276 | def __repr__(self): 277 | format_string = self.__class__.__name__ + "(" 278 | for t in self.transforms: 279 | format_string += "\n" 280 | format_string += " {0}".format(t) 281 | format_string += "\n)" 282 | return format_string 283 | 284 | @torch.no_grad() 285 | def cls_to_flow(cls, deterministic_sampling = True): 286 | B,C,H,W = cls.shape 287 | device = cls.device 288 | res = round(math.sqrt(C)) 289 | G = torch.meshgrid( 290 | *[torch.linspace(-1+1/res, 1-1/res, steps = res, device = device) for _ in range(2)], 291 | indexing = 'ij' 292 | ) 293 | G = torch.stack([G[1],G[0]],dim=-1).reshape(C,2) 294 | if deterministic_sampling: 295 | sampled_cls = cls.max(dim=1).indices 296 | else: 297 | sampled_cls = torch.multinomial(cls.permute(0,2,3,1).reshape(B*H*W,C).softmax(dim=-1), 1).reshape(B,H,W) 298 | flow = G[sampled_cls] 299 | return flow 300 | 301 | @torch.no_grad() 302 | def cls_to_flow_refine(cls): 303 | B,C,H,W = cls.shape 304 | device = cls.device 305 | res = round(math.sqrt(C)) 306 | G = torch.meshgrid( 307 | *[torch.linspace(-1+1/res, 1-1/res, steps = res, device = device) for _ in range(2)], 308 | indexing = 'ij' 309 | ) 310 | G = torch.stack([G[1],G[0]],dim=-1).reshape(C,2) 311 | # FIXME: below softmax line causes mps to bug, don't know why. 
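    # Note: log_softmax(...).exp() is mathematically identical to softmax(...);
    # it is only used as a workaround for the MPS issue mentioned in the FIXME above.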
312 | if device.type == 'mps': 313 | cls = cls.log_softmax(dim=1).exp() 314 | else: 315 | cls = cls.softmax(dim=1) 316 | mode = cls.max(dim=1).indices 317 | 318 | index = torch.stack((mode-1, mode, mode+1, mode - res, mode + res), dim = 1).clamp(0,C - 1).long() 319 | neighbours = torch.gather(cls, dim = 1, index = index)[...,None] 320 | flow = neighbours[:,0] * G[index[:,0]] + neighbours[:,1] * G[index[:,1]] + neighbours[:,2] * G[index[:,2]] + neighbours[:,3] * G[index[:,3]] + neighbours[:,4] * G[index[:,4]] 321 | tot_prob = neighbours.sum(dim=1) 322 | flow = flow / tot_prob 323 | return flow 324 | 325 | 326 | def get_gt_warp(depth1, depth2, T_1to2, K1, K2, depth_interpolation_mode = 'bilinear', relative_depth_error_threshold = 0.05, H = None, W = None): 327 | 328 | if H is None: 329 | B,H,W = depth1.shape 330 | else: 331 | B = depth1.shape[0] 332 | with torch.no_grad(): 333 | x1_n = torch.meshgrid( 334 | *[ 335 | torch.linspace( 336 | -1 + 1 / n, 1 - 1 / n, n, device=depth1.device 337 | ) 338 | for n in (B, H, W) 339 | ], 340 | indexing = 'ij' 341 | ) 342 | x1_n = torch.stack((x1_n[2], x1_n[1]), dim=-1).reshape(B, H * W, 2) 343 | mask, x2 = warp_kpts( 344 | x1_n.double(), 345 | depth1.double(), 346 | depth2.double(), 347 | T_1to2.double(), 348 | K1.double(), 349 | K2.double(), 350 | depth_interpolation_mode = depth_interpolation_mode, 351 | relative_depth_error_threshold = relative_depth_error_threshold, 352 | ) 353 | prob = mask.float().reshape(B, H, W) 354 | x2 = x2.reshape(B, H, W, 2) 355 | return x2, prob 356 | 357 | @torch.no_grad() 358 | def warp_kpts(kpts0, depth0, depth1, T_0to1, K0, K1, smooth_mask = False, return_relative_depth_error = False, depth_interpolation_mode = "bilinear", relative_depth_error_threshold = 0.05): 359 | """Warp kpts0 from I0 to I1 with depth, K and Rt 360 | Also check covisibility and depth consistency. 361 | Depth is consistent if the relative error is below relative_depth_error_threshold (default 0.05).
362 | # Adapted from LoFTR: https://github.com/zju3dv/LoFTR/blob/94e98b695be18acb43d5d3250f52226a8e36f839/src/loftr/utils/geometry.py 363 | Args: 364 | kpts0 (torch.Tensor): [N, L, 2] - (x, y), should be normalized in (-1,1) 365 | depth0 (torch.Tensor): [N, H, W], 366 | depth1 (torch.Tensor): [N, H, W], 367 | T_0to1 (torch.Tensor): [N, 3, 4], 368 | K0 (torch.Tensor): [N, 3, 3], 369 | K1 (torch.Tensor): [N, 3, 3], 370 | Returns: 371 | calculable_mask (torch.Tensor): [N, L] 372 | warped_keypoints0 (torch.Tensor): [N, L, 2] 373 | """ 374 | ( 375 | n, 376 | h, 377 | w, 378 | ) = depth0.shape 379 | if depth_interpolation_mode == "combined": 380 | # Inspired by the approach in InLoc: fill holes left by bilinear interpolation with nearest-neighbour interpolation 381 | if smooth_mask: 382 | raise NotImplementedError("Combined bilinear and NN warp not implemented") 383 | valid_bilinear, warp_bilinear = warp_kpts(kpts0, depth0, depth1, T_0to1, K0, K1, 384 | smooth_mask = smooth_mask, 385 | return_relative_depth_error = return_relative_depth_error, 386 | depth_interpolation_mode = "bilinear", 387 | relative_depth_error_threshold = relative_depth_error_threshold) 388 | valid_nearest, warp_nearest = warp_kpts(kpts0, depth0, depth1, T_0to1, K0, K1, 389 | smooth_mask = smooth_mask, 390 | return_relative_depth_error = return_relative_depth_error, 391 | depth_interpolation_mode = "nearest-exact", 392 | relative_depth_error_threshold = relative_depth_error_threshold) 393 | nearest_valid_bilinear_invalid = (~valid_bilinear).logical_and(valid_nearest) 394 | warp = warp_bilinear.clone() 395 | warp[nearest_valid_bilinear_invalid] = warp_nearest[nearest_valid_bilinear_invalid] 396 | valid = valid_bilinear | valid_nearest 397 | return valid, warp 398 | 399 | 400 | kpts0_depth = F.grid_sample(depth0[:, None], kpts0[:, :, None], mode = depth_interpolation_mode, align_corners=False)[ 401 | :, 0, :, 0 402 | ] 403 | kpts0 = torch.stack( 404 | (w * (kpts0[..., 0] + 1) / 2, h * (kpts0[..., 1] + 1) / 2), dim=-1 405 | ) # [-1+1/h, 1-1/h] -> [0.5, h-0.5] 406 | # Sample depth, get calculable_mask on depth != 0 407 | nonzero_mask = kpts0_depth != 0 408 | 409 | # Unproject 410 | kpts0_h = ( 411 | torch.cat([kpts0, torch.ones_like(kpts0[:, :, [0]])], dim=-1) 412 | * kpts0_depth[..., None] 413 | ) # (N, L, 3) 414 | kpts0_n = K0.inverse() @ kpts0_h.transpose(2, 1) # (N, 3, L) 415 | kpts0_cam = kpts0_n 416 | 417 | # Rigid Transform 418 | w_kpts0_cam = T_0to1[:, :3, :3] @ kpts0_cam + T_0to1[:, :3, [3]] # (N, 3, L) 419 | w_kpts0_depth_computed = w_kpts0_cam[:, 2, :] 420 | 421 | # Project 422 | w_kpts0_h = (K1 @ w_kpts0_cam).transpose(2, 1) # (N, L, 3) 423 | w_kpts0 = w_kpts0_h[:, :, :2] / ( 424 | w_kpts0_h[:, :, [2]] + 1e-4 425 | ) # (N, L, 2), +1e-4 to avoid zero depth 426 | 427 | # Covisible Check 428 | h, w = depth1.shape[1:3] 429 | covisible_mask = ( 430 | (w_kpts0[:, :, 0] > 0) 431 | * (w_kpts0[:, :, 0] < w - 1) 432 | * (w_kpts0[:, :, 1] > 0) 433 | * (w_kpts0[:, :, 1] < h - 1) 434 | ) 435 | w_kpts0 = torch.stack( 436 | (2 * w_kpts0[..., 0] / w - 1, 2 * w_kpts0[..., 1] / h - 1), dim=-1 437 | ) # from [0.5,h-0.5] -> [-1+1/h, 1-1/h] 438 | # w_kpts0[~covisible_mask, :] = -5 # xd 439 | 440 | w_kpts0_depth = F.grid_sample( 441 | depth1[:, None], w_kpts0[:, :, None], mode=depth_interpolation_mode, align_corners=False 442 | )[:, 0, :, 0] 443 | 444 | relative_depth_error = ( 445 | (w_kpts0_depth - w_kpts0_depth_computed) / w_kpts0_depth 446 | ).abs() 447 | if not smooth_mask: 448 | consistent_mask = relative_depth_error <
relative_depth_error_threshold 449 | else: 450 | consistent_mask = (-relative_depth_error/smooth_mask).exp() 451 | valid_mask = nonzero_mask * covisible_mask * consistent_mask 452 | if return_relative_depth_error: 453 | return relative_depth_error, w_kpts0 454 | else: 455 | return valid_mask, w_kpts0 456 | 457 | imagenet_mean = torch.tensor([0.485, 0.456, 0.406]) 458 | imagenet_std = torch.tensor([0.229, 0.224, 0.225]) 459 | 460 | 461 | def numpy_to_pil(x: np.ndarray): 462 | """ 463 | Args: 464 | x: Assumed to be of shape (h,w,c) 465 | """ 466 | if isinstance(x, torch.Tensor): 467 | x = x.detach().cpu().numpy() 468 | if x.max() <= 1.01: 469 | x *= 255 470 | x = x.astype(np.uint8) 471 | return Image.fromarray(x) 472 | 473 | 474 | def tensor_to_pil(x, unnormalize=False): 475 | if unnormalize: 476 | x = x * (imagenet_std[:, None, None].to(x.device)) + (imagenet_mean[:, None, None].to(x.device)) 477 | x = x.detach().permute(1, 2, 0).cpu().numpy() 478 | x = np.clip(x, 0.0, 1.0) 479 | return numpy_to_pil(x) 480 | 481 | 482 | def to_cuda(batch): 483 | for key, value in batch.items(): 484 | if isinstance(value, torch.Tensor): 485 | batch[key] = value.cuda() 486 | return batch 487 | 488 | 489 | def to_cpu(batch): 490 | for key, value in batch.items(): 491 | if isinstance(value, torch.Tensor): 492 | batch[key] = value.cpu() 493 | return batch 494 | 495 | 496 | def get_pose(calib): 497 | w, h = np.array(calib["imsize"])[0] 498 | return np.array(calib["K"]), np.array(calib["R"]), np.array(calib["T"]).T, h, w 499 | 500 | 501 | def compute_relative_pose(R1, t1, R2, t2): 502 | rots = R2 @ (R1.T) 503 | trans = -rots @ t1 + t2 504 | return rots, trans 505 | 506 | @torch.no_grad() 507 | def reset_opt(opt): 508 | for group in opt.param_groups: 509 | for p in group['params']: 510 | if p.requires_grad: 511 | state = opt.state[p] 512 | # State initialization 513 | 514 | # Exponential moving average of gradient values 515 | state['exp_avg'] = torch.zeros_like(p) 516 | # Exponential moving average of squared gradient values 517 | state['exp_avg_sq'] = torch.zeros_like(p) 518 | # Exponential moving average of gradient difference 519 | state['exp_avg_diff'] = torch.zeros_like(p) 520 | 521 | 522 | def flow_to_pixel_coords(flow, h1, w1): 523 | flow = ( 524 | torch.stack( 525 | ( 526 | w1 * (flow[..., 0] + 1) / 2, 527 | h1 * (flow[..., 1] + 1) / 2, 528 | ), 529 | axis=-1, 530 | ) 531 | ) 532 | return flow 533 | 534 | to_pixel_coords = flow_to_pixel_coords # just an alias 535 | 536 | def flow_to_normalized_coords(flow, h1, w1): 537 | flow = ( 538 | torch.stack( 539 | ( 540 | 2 * (flow[..., 0]) / w1 - 1, 541 | 2 * (flow[..., 1]) / h1 - 1, 542 | ), 543 | axis=-1, 544 | ) 545 | ) 546 | return flow 547 | 548 | to_normalized_coords = flow_to_normalized_coords # just an alias 549 | 550 | def warp_to_pixel_coords(warp, h1, w1, h2, w2): 551 | warp1 = warp[..., :2] 552 | warp1 = ( 553 | torch.stack( 554 | ( 555 | w1 * (warp1[..., 0] + 1) / 2, 556 | h1 * (warp1[..., 1] + 1) / 2, 557 | ), 558 | axis=-1, 559 | ) 560 | ) 561 | warp2 = warp[..., 2:] 562 | warp2 = ( 563 | torch.stack( 564 | ( 565 | w2 * (warp2[..., 0] + 1) / 2, 566 | h2 * (warp2[..., 1] + 1) / 2, 567 | ), 568 | axis=-1, 569 | ) 570 | ) 571 | return torch.cat((warp1,warp2), dim=-1) 572 | 573 | 574 | 575 | def signed_point_line_distance(point, line, eps: float = 1e-9): 576 | r"""Return the distance from points to lines. 577 | 578 | Args: 579 | point: (possibly homogeneous) points :math:`(*, N, 2 or 3)`. 
580 | line: lines coefficients :math:`(a, b, c)` with shape :math:`(*, N, 3)`, where :math:`ax + by + c = 0`. 581 | eps: Small constant for safe sqrt. 582 | 583 | Returns: 584 | the computed distance with shape :math:`(*, N)`. 585 | """ 586 | 587 | if not point.shape[-1] in (2, 3): 588 | raise ValueError(f"pts must be a (*, 2 or 3) tensor. Got {point.shape}") 589 | 590 | if not line.shape[-1] == 3: 591 | raise ValueError(f"lines must be a (*, 3) tensor. Got {line.shape}") 592 | 593 | numerator = (line[..., 0] * point[..., 0] + line[..., 1] * point[..., 1] + line[..., 2]) 594 | denominator = line[..., :2].norm(dim=-1) 595 | 596 | return numerator / (denominator + eps) 597 | 598 | 599 | def signed_left_to_right_epipolar_distance(pts1, pts2, Fm): 600 | r"""Return one-sided epipolar distance for correspondences given the fundamental matrix. 601 | 602 | This method measures the distance from points in the right images to the epilines 603 | of the corresponding points in the left images as they reflect in the right images. 604 | 605 | Args: 606 | pts1: correspondences from the left images with shape 607 | :math:`(*, N, 2 or 3)`. If they are not homogeneous, converted automatically. 608 | pts2: correspondences from the right images with shape 609 | :math:`(*, N, 2 or 3)`. If they are not homogeneous, converted automatically. 610 | Fm: Fundamental matrices with shape :math:`(*, 3, 3)`. Called Fm to 611 | avoid ambiguity with torch.nn.functional. 612 | 613 | Returns: 614 | the computed Symmetrical distance with shape :math:`(*, N)`. 615 | """ 616 | import kornia 617 | if (len(Fm.shape) < 3) or not Fm.shape[-2:] == (3, 3): 618 | raise ValueError(f"Fm must be a (*, 3, 3) tensor. Got {Fm.shape}") 619 | 620 | if pts1.shape[-1] == 2: 621 | pts1 = kornia.geometry.convert_points_to_homogeneous(pts1) 622 | 623 | F_t = Fm.transpose(dim0=-2, dim1=-1) 624 | line1_in_2 = pts1 @ F_t 625 | 626 | return signed_point_line_distance(pts2, line1_in_2) 627 | 628 | def get_grid(b, h, w, device): 629 | grid = torch.meshgrid( 630 | *[ 631 | torch.linspace(-1 + 1 / n, 1 - 1 / n, n, device=device) 632 | for n in (b, h, w) 633 | ], 634 | indexing = 'ij' 635 | ) 636 | grid = torch.stack((grid[2], grid[1]), dim=-1).reshape(b, h, w, 2) 637 | return grid 638 | 639 | 640 | def get_autocast_params(device=None, enabled=False, dtype=None): 641 | if device is None: 642 | autocast_device = "cuda" if torch.cuda.is_available() else "cpu" 643 | else: 644 | #strip :X from device 645 | autocast_device = str(device).split(":")[0] 646 | if 'cuda' in str(device): 647 | out_dtype = dtype 648 | enabled = True 649 | else: 650 | out_dtype = torch.bfloat16 651 | enabled = False 652 | # mps is not supported 653 | autocast_device = "cpu" 654 | return autocast_device, enabled, out_dtype 655 | 656 | def check_not_i16(im): 657 | if im.mode == "I;16": 658 | raise NotImplementedError("Can't handle 16 bit images") 659 | 660 | def check_rgb(im): 661 | if im.mode != "RGB": 662 | raise NotImplementedError("Can't handle non-RGB images") 663 | --------------------------------------------------------------------------------
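A few usage sketches for the helpers defined in romatch/utils/utils.py above. These are editorial illustrations, not part of the file: all poses, error values and sizes below are made up. First, the pose helpers: compute_relative_pose turns two absolute poses into a relative one, and pose_auc integrates a recall-over-threshold curve of per-pair pose errors.

import numpy as np

# Hypothetical absolute world-to-camera poses (R_i, t_i).
R1, t1 = np.eye(3), np.zeros(3)
theta = np.deg2rad(10.0)
R2 = np.array([[np.cos(theta), -np.sin(theta), 0.0],
               [np.sin(theta),  np.cos(theta), 0.0],
               [0.0,            0.0,           1.0]])
t2 = np.array([1.0, 0.0, 0.0])

# Relative pose of camera 1 -> camera 2: R_rel = R2 @ R1.T, t_rel = t2 - R_rel @ t1.
R_rel, t_rel = compute_relative_pose(R1, t1, R2, t2)

# AUC of (made-up) per-pair pose errors at the usual 5/10/20 degree thresholds.
errors = [0.4, 1.8, 3.6, 7.2, 12.5, 24.0]
print(pose_auc(errors, thresholds=[5, 10, 20]))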
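get_tuple_transform_ops builds the preprocessing pipeline (resize, scale to [0, 1], ImageNet normalization) applied to every image in a tuple, and tensor_to_pil undoes the normalization for visualisation. A minimal sketch; the image paths and the 560x560 resize are arbitrary assumptions.

from PIL import Image

im_A = Image.open("path/to/image_A.jpg").convert("RGB")  # hypothetical paths
im_B = Image.open("path/to/image_B.jpg").convert("RGB")

transform = get_tuple_transform_ops(resize=(560, 560), normalize=True)
x_A, x_B = transform((im_A, im_B))  # each: float tensor of shape (3, 560, 560), ImageNet-normalized

# Recover a viewable PIL image by inverting the normalization.
preview = tensor_to_pil(x_A, unnormalize=True)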
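cls_to_flow and cls_to_flow_refine convert a (B, C, H, W) classification volume over a res x res grid of anchor coordinates (C = res**2) into a coarse flow field in normalized [-1, 1] coordinates; the refined variant averages the anchors around the argmax, weighted by their probabilities. A sketch with random logits; the sizes are arbitrary.

import torch

B, res, H, W = 2, 64, 40, 40
cls = torch.randn(B, res * res, H, W)  # hypothetical match-classification logits

flow_argmax = cls_to_flow(cls, deterministic_sampling=True)  # (B, H, W, 2), argmax anchor per pixel
flow_refined = cls_to_flow_refine(cls)                       # (B, H, W, 2), probability-weighted neighbourhood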
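get_gt_warp composes the normalized grid with warp_kpts to produce a dense ground-truth warp and a validity mask from two depth maps, a relative pose and intrinsics. A sketch with synthetic inputs: identical depth maps and an identity relative pose, so the warp is near-identity and most pixels remain valid.

import torch

B, H, W = 2, 48, 48
depth1 = torch.rand(B, H, W) + 1.0           # synthetic, strictly positive depth
depth2 = depth1.clone()
T_1to2 = torch.eye(4)[None].repeat(B, 1, 1)  # only the top 3x4 block is used
K = torch.tensor([[100.0, 0.0, W / 2],
                  [0.0, 100.0, H / 2],
                  [0.0,   0.0,   1.0]])[None].repeat(B, 1, 1)

warp, prob = get_gt_warp(depth1, depth2, T_1to2, K, K)
# warp: (B, H, W, 2) target coordinates in [-1, 1]; prob: (B, H, W) validity in {0, 1}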
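get_grid, flow_to_pixel_coords and flow_to_normalized_coords convert between the normalized [-1, 1] coordinates used internally and pixel coordinates (pixel centres at k + 0.5). A small round-trip sketch with arbitrary sizes.

import torch

b, h, w = 1, 48, 64
grid = get_grid(b, h, w, device="cpu")        # (1, 48, 64, 2), normalized pixel centres
pix = flow_to_pixel_coords(grid, h, w)        # x in [0.5, w - 0.5], y in [0.5, h - 0.5]
back = flow_to_normalized_coords(pix, h, w)

print(torch.allclose(back, grid, atol=1e-5))  # True: the two mappings are inverses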
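signed_left_to_right_epipolar_distance returns the signed distance from right-image points to the epipolar lines induced by their left-image correspondences; it imports kornia lazily, so kornia must be installed. A sketch assuming identity intrinsics and a pure x-translation, for which F = [t]_x and the epipolar lines are horizontal, so inliers get a distance of (near) zero.

import torch

# Fundamental matrix for identity intrinsics, R = I, t = (1, 0, 0): the cross-product matrix [t]_x.
Fm = torch.tensor([[[0.0, 0.0,  0.0],
                    [0.0, 0.0, -1.0],
                    [0.0, 1.0,  0.0]]])                   # (1, 3, 3)
pts1 = torch.tensor([[[100.0, 120.0], [300.0, 250.0]]])   # (1, N, 2)
pts2 = pts1 + torch.tensor([50.0, 0.0])                   # moved along the (horizontal) epipolar line

print(signed_left_to_right_epipolar_distance(pts1, pts2, Fm))  # ~0 for these inlier correspondences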