├── assets
│   ├── sacre_coeur_A.jpg
│   ├── sacre_coeur_B.jpg
│   ├── 0_d_00d1ae6aab6ccd59.jpg
│   ├── 2_a_02a270519bdb90dd.jpg
│   ├── sacre_coeur_A_compare.png
│   └── sacre_coeur_B_compare.png
├── romatch
│   ├── utils
│   │   ├── __init__.py
│   │   ├── kde.py
│   │   ├── local_correlation.py
│   │   ├── transforms.py
│   │   └── utils.py
│   └── models
│       └── transformer
│           ├── layers
│           │   ├── __init__.py
│           │   ├── layer_scale.py
│           │   ├── drop_path.py
│           │   ├── mlp.py
│           │   ├── swiglu_ffn.py
│           │   ├── dino_head.py
│           │   ├── attention.py
│           │   ├── patch_embed.py
│           │   └── block.py
│           ├── __init__.py
│           └── dinov2.py
├── LICENSE.txt
├── sky_seg.py
├── vis_feats.py
├── README.md
├── relight.py
├── get_data.py
└── warp_utils.py
/assets/sacre_coeur_A.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Sharpiless/L2M/HEAD/assets/sacre_coeur_A.jpg
--------------------------------------------------------------------------------
/assets/sacre_coeur_B.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Sharpiless/L2M/HEAD/assets/sacre_coeur_B.jpg
--------------------------------------------------------------------------------
/assets/0_d_00d1ae6aab6ccd59.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Sharpiless/L2M/HEAD/assets/0_d_00d1ae6aab6ccd59.jpg
--------------------------------------------------------------------------------
/assets/2_a_02a270519bdb90dd.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Sharpiless/L2M/HEAD/assets/2_a_02a270519bdb90dd.jpg
--------------------------------------------------------------------------------
/assets/sacre_coeur_A_compare.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Sharpiless/L2M/HEAD/assets/sacre_coeur_A_compare.png
--------------------------------------------------------------------------------
/assets/sacre_coeur_B_compare.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Sharpiless/L2M/HEAD/assets/sacre_coeur_B_compare.png
--------------------------------------------------------------------------------
/romatch/utils/__init__.py:
--------------------------------------------------------------------------------
1 | from .utils import (
2 | pose_auc,
3 | get_pose,
4 | compute_relative_pose,
5 | compute_pose_error,
6 | estimate_pose,
7 | estimate_pose_uncalibrated,
8 | rotate_intrinsic,
9 | get_tuple_transform_ops,
10 | get_depth_tuple_transform_ops,
11 | warp_kpts,
12 | numpy_to_pil,
13 | tensor_to_pil,
14 | recover_pose,
15 | signed_left_to_right_epipolar_distance,
16 | )
17 |
--------------------------------------------------------------------------------
/romatch/utils/kde.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 |
4 | def kde(x, std = 0.1, half = True, down = None):
5 | # use a gaussian kernel to estimate density
6 | if half:
7 | x = x.half() # Do it in half precision TODO: remove hardcoding
8 | if down is not None:
9 | scores = (-torch.cdist(x,x[::down])**2/(2*std**2)).exp()
10 | else:
11 | scores = (-torch.cdist(x,x)**2/(2*std**2)).exp()
12 | density = scores.sum(dim=-1)
13 | return density
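14 | 
15 | 
16 | # Example (illustrative): for N candidate matches stacked as rows of a 2D tensor,
17 | # kde returns a soft neighbour count per row, e.g.
18 | #   x = torch.rand(10_000, 4); density = kde(x, std=0.1)  # density.shape == (10_000,)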
--------------------------------------------------------------------------------
/romatch/models/transformer/layers/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | from .dino_head import DINOHead
8 | from .mlp import Mlp
9 | from .patch_embed import PatchEmbed
10 | from .swiglu_ffn import SwiGLUFFN, SwiGLUFFNFused
11 | from .block import NestedTensorBlock
12 | from .attention import MemEffAttention
13 |
--------------------------------------------------------------------------------
/romatch/models/transformer/layers/layer_scale.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | # Modified from: https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/vision_transformer.py#L103-L110
8 |
9 | from typing import Union
10 |
11 | import torch
12 | from torch import Tensor
13 | from torch import nn
14 |
15 |
16 | class LayerScale(nn.Module):
17 | def __init__(
18 | self,
19 | dim: int,
20 | init_values: Union[float, Tensor] = 1e-5,
21 | inplace: bool = False,
22 | ) -> None:
23 | super().__init__()
24 | self.inplace = inplace
25 | self.gamma = nn.Parameter(init_values * torch.ones(dim))
26 |
27 | def forward(self, x: Tensor) -> Tensor:
28 | return x.mul_(self.gamma) if self.inplace else x * self.gamma
29 |
--------------------------------------------------------------------------------
/LICENSE.txt:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2024 Xuelun Shen
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/romatch/models/transformer/layers/drop_path.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | # References:
8 | # https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
9 | # https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/drop.py
10 |
11 |
12 | from torch import nn
13 |
14 |
15 | def drop_path(x, drop_prob: float = 0.0, training: bool = False):
16 | if drop_prob == 0.0 or not training:
17 | return x
18 | keep_prob = 1 - drop_prob
19 | shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets
20 | random_tensor = x.new_empty(shape).bernoulli_(keep_prob)
21 | if keep_prob > 0.0:
22 | random_tensor.div_(keep_prob)
23 | output = x * random_tensor
24 | return output
25 |
26 |
27 | class DropPath(nn.Module):
28 | """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks)."""
29 |
30 | def __init__(self, drop_prob=None):
31 | super(DropPath, self).__init__()
32 | self.drop_prob = drop_prob
33 |
34 | def forward(self, x):
35 | return drop_path(x, self.drop_prob, self.training)
36 |
--------------------------------------------------------------------------------
/romatch/models/transformer/layers/mlp.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | # References:
8 | # https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
9 | # https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/mlp.py
10 |
11 |
12 | from typing import Callable, Optional
13 |
14 | from torch import Tensor, nn
15 |
16 |
17 | class Mlp(nn.Module):
18 | def __init__(
19 | self,
20 | in_features: int,
21 | hidden_features: Optional[int] = None,
22 | out_features: Optional[int] = None,
23 | act_layer: Callable[..., nn.Module] = nn.GELU,
24 | drop: float = 0.0,
25 | bias: bool = True,
26 | ) -> None:
27 | super().__init__()
28 | out_features = out_features or in_features
29 | hidden_features = hidden_features or in_features
30 | self.fc1 = nn.Linear(in_features, hidden_features, bias=bias)
31 | self.act = act_layer()
32 | self.fc2 = nn.Linear(hidden_features, out_features, bias=bias)
33 | self.drop = nn.Dropout(drop)
34 |
35 | def forward(self, x: Tensor) -> Tensor:
36 | x = self.fc1(x)
37 | x = self.act(x)
38 | x = self.drop(x)
39 | x = self.fc2(x)
40 | x = self.drop(x)
41 | return x
42 |
--------------------------------------------------------------------------------
/romatch/utils/local_correlation.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.functional as F
3 |
4 | def local_correlation(
5 | feature0,
6 | feature1,
7 | local_radius,
8 | padding_mode="zeros",
9 | flow = None,
10 | sample_mode = "bilinear",
11 | ):
12 | r = local_radius
13 | K = (2*r+1)**2
14 | B, c, h, w = feature0.size()
15 | corr = torch.empty((B,K,h,w), device = feature0.device, dtype=feature0.dtype)
16 | if flow is None:
17 | # If flow is None, assume feature0 and feature1 are aligned
18 | coords = torch.meshgrid(
19 | (
20 | torch.linspace(-1 + 1 / h, 1 - 1 / h, h, device=feature0.device),
21 | torch.linspace(-1 + 1 / w, 1 - 1 / w, w, device=feature0.device),
22 | ),
23 | indexing = 'ij'
24 | )
25 | coords = torch.stack((coords[1], coords[0]), dim=-1)[
26 | None
27 | ].expand(B, h, w, 2)
28 | else:
29 | coords = flow.permute(0,2,3,1) # If using flow, sample around flow target.
30 | local_window = torch.meshgrid(
31 | (
32 | torch.linspace(-2*local_radius/h, 2*local_radius/h, 2*r+1, device=feature0.device),
33 | torch.linspace(-2*local_radius/w, 2*local_radius/w, 2*r+1, device=feature0.device),
34 | ),
35 | indexing = 'ij'
36 | )
37 | local_window = torch.stack((local_window[1], local_window[0]), dim=-1)[
38 | None
39 | ].expand(1, 2*r+1, 2*r+1, 2).reshape(1, (2*r+1)**2, 2)
40 | for _ in range(B):
41 | with torch.no_grad():
42 | local_window_coords = (coords[_,:,:,None]+local_window[:,None,None]).reshape(1,h,w*(2*r+1)**2,2)
43 | window_feature = F.grid_sample(
44 | feature1[_:_+1], local_window_coords, padding_mode=padding_mode, align_corners=False, mode = sample_mode, #
45 | )
46 | window_feature = window_feature.reshape(c,h,w,(2*r+1)**2)
47 | corr[_] = (feature0[_,...,None]/(c**.5)*window_feature).sum(dim=0).permute(2,0,1)
48 | return corr
49 |
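50 | # Example (illustrative): correlating a feature map with itself around the identity flow,
51 | #   f = torch.randn(1, 64, 32, 32); corr = local_correlation(f, f, local_radius=4)
52 | # gives corr of shape (1, (2*4+1)**2, 32, 32): one window of similarity scores per pixel.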
--------------------------------------------------------------------------------
/romatch/models/transformer/layers/swiglu_ffn.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | from typing import Callable, Optional
8 |
9 | from torch import Tensor, nn
10 | import torch.nn.functional as F
11 |
12 |
13 | class SwiGLUFFN(nn.Module):
14 | def __init__(
15 | self,
16 | in_features: int,
17 | hidden_features: Optional[int] = None,
18 | out_features: Optional[int] = None,
19 | act_layer: Callable[..., nn.Module] = None,
20 | drop: float = 0.0,
21 | bias: bool = True,
22 | ) -> None:
23 | super().__init__()
24 | out_features = out_features or in_features
25 | hidden_features = hidden_features or in_features
26 | self.w12 = nn.Linear(in_features, 2 * hidden_features, bias=bias)
27 | self.w3 = nn.Linear(hidden_features, out_features, bias=bias)
28 |
29 | def forward(self, x: Tensor) -> Tensor:
30 | x12 = self.w12(x)
31 | x1, x2 = x12.chunk(2, dim=-1)
32 | hidden = F.silu(x1) * x2
33 | return self.w3(hidden)
34 |
35 |
36 | try:
37 | from xformers.ops import SwiGLU
38 |
39 | XFORMERS_AVAILABLE = True
40 | except ImportError:
41 | SwiGLU = SwiGLUFFN
42 | XFORMERS_AVAILABLE = False
43 |
44 |
45 | class SwiGLUFFNFused(SwiGLU):
46 | def __init__(
47 | self,
48 | in_features: int,
49 | hidden_features: Optional[int] = None,
50 | out_features: Optional[int] = None,
51 | act_layer: Callable[..., nn.Module] = None,
52 | drop: float = 0.0,
53 | bias: bool = True,
54 | ) -> None:
55 | out_features = out_features or in_features
56 | hidden_features = hidden_features or in_features
57 | hidden_features = (int(hidden_features * 2 / 3) + 7) // 8 * 8
58 | super().__init__(
59 | in_features=in_features,
60 | hidden_features=hidden_features,
61 | out_features=out_features,
62 | bias=bias,
63 | )
64 |
--------------------------------------------------------------------------------
/romatch/models/transformer/layers/dino_head.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | import torch
8 | import torch.nn as nn
9 | from torch.nn.init import trunc_normal_
10 | from torch.nn.utils import weight_norm
11 |
12 |
13 | class DINOHead(nn.Module):
14 | def __init__(
15 | self,
16 | in_dim,
17 | out_dim,
18 | use_bn=False,
19 | nlayers=3,
20 | hidden_dim=2048,
21 | bottleneck_dim=256,
22 | mlp_bias=True,
23 | ):
24 | super().__init__()
25 | nlayers = max(nlayers, 1)
26 | self.mlp = _build_mlp(nlayers, in_dim, bottleneck_dim, hidden_dim=hidden_dim, use_bn=use_bn, bias=mlp_bias)
27 | self.apply(self._init_weights)
28 | self.last_layer = weight_norm(nn.Linear(bottleneck_dim, out_dim, bias=False))
29 | self.last_layer.weight_g.data.fill_(1)
30 |
31 | def _init_weights(self, m):
32 | if isinstance(m, nn.Linear):
33 | trunc_normal_(m.weight, std=0.02)
34 | if isinstance(m, nn.Linear) and m.bias is not None:
35 | nn.init.constant_(m.bias, 0)
36 |
37 | def forward(self, x):
38 | x = self.mlp(x)
39 | eps = 1e-6 if x.dtype == torch.float16 else 1e-12
40 | x = nn.functional.normalize(x, dim=-1, p=2, eps=eps)
41 | x = self.last_layer(x)
42 | return x
43 |
44 |
45 | def _build_mlp(nlayers, in_dim, bottleneck_dim, hidden_dim=None, use_bn=False, bias=True):
46 | if nlayers == 1:
47 | return nn.Linear(in_dim, bottleneck_dim, bias=bias)
48 | else:
49 | layers = [nn.Linear(in_dim, hidden_dim, bias=bias)]
50 | if use_bn:
51 | layers.append(nn.BatchNorm1d(hidden_dim))
52 | layers.append(nn.GELU())
53 | for _ in range(nlayers - 2):
54 | layers.append(nn.Linear(hidden_dim, hidden_dim, bias=bias))
55 | if use_bn:
56 | layers.append(nn.BatchNorm1d(hidden_dim))
57 | layers.append(nn.GELU())
58 | layers.append(nn.Linear(hidden_dim, bottleneck_dim, bias=bias))
59 | return nn.Sequential(*layers)
60 |
--------------------------------------------------------------------------------
/romatch/models/transformer/__init__.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 |
5 | from romatch.utils.utils import get_grid, get_autocast_params
6 | from .layers.block import Block
7 | from .layers.attention import MemEffAttention
8 | from .dinov2 import vit_large, vit_base
9 |
10 | class TransformerDecoder(nn.Module):
11 | def __init__(self, blocks, hidden_dim, out_dim, is_classifier = False, *args,
12 | amp = False, pos_enc = True, learned_embeddings = False, embedding_dim = None, amp_dtype = torch.float16, **kwargs) -> None:
13 | super().__init__(*args, **kwargs)
14 | self.blocks = blocks
15 | self.to_out = nn.Linear(hidden_dim, out_dim)
16 | self.hidden_dim = hidden_dim
17 | self.out_dim = out_dim
18 | self._scales = [16, 18]
19 | self.is_classifier = is_classifier
20 | self.amp = amp
21 | self.amp_dtype = amp_dtype
22 | self.pos_enc = pos_enc
23 | self.learned_embeddings = learned_embeddings
24 | if self.learned_embeddings:
25 | self.learned_pos_embeddings = nn.Parameter(nn.init.kaiming_normal_(torch.empty((1, hidden_dim, embedding_dim, embedding_dim))))
26 |
27 | def scales(self):
28 | return self._scales.copy()
29 |
30 | def forward(self, gp_posterior, features, old_stuff, new_scale):
31 | autocast_device, autocast_enabled, autocast_dtype = get_autocast_params(gp_posterior.device, enabled=self.amp, dtype=self.amp_dtype)
32 | with torch.autocast(autocast_device, enabled=autocast_enabled, dtype = autocast_dtype):
33 | B,C,H,W = gp_posterior.shape
34 | x = torch.cat((gp_posterior, features), dim = 1)
35 | B,C,H,W = x.shape
36 | grid = get_grid(B, H, W, x.device).reshape(B,H*W,2)
37 | if self.learned_embeddings:
38 | pos_enc = F.interpolate(self.learned_pos_embeddings, size = (H,W), mode = 'bilinear', align_corners = False).permute(0,2,3,1).reshape(1,H*W,C)
39 | else:
40 | pos_enc = 0
41 | tokens = x.reshape(B,C,H*W).permute(0,2,1) + pos_enc
42 | z = self.blocks(tokens)
43 | out = self.to_out(z)
44 | out = out.permute(0,2,1).reshape(B, self.out_dim, H, W)
45 | warp, certainty = out[:, :-1], out[:, -1:]
46 | return warp, certainty, None
47 |
48 |
49 |
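50 | # Example (illustrative; block count and dims are arbitrary): hidden_dim must equal
51 | # the channel count of cat(gp_posterior, features); with out_dim = 3 the decoder
52 | # returns a 2-channel warp and a 1-channel certainty map, e.g.
53 | #   blocks = nn.Sequential(*[Block(dim=256, num_heads=8, attn_class=MemEffAttention) for _ in range(5)])
54 | #   decoder = TransformerDecoder(blocks, hidden_dim=256, out_dim=3)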
--------------------------------------------------------------------------------
/romatch/models/transformer/layers/attention.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | # References:
8 | # https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
9 | # https://github.com/rwightman/pytorch-image-models/tree/master/timm/models/vision_transformer.py
10 |
11 | import logging
12 |
13 | from torch import Tensor
14 | from torch import nn
15 |
16 |
17 | logger = logging.getLogger("dinov2")
18 |
19 |
20 | try:
21 | from xformers.ops import memory_efficient_attention, unbind, fmha
22 |
23 | XFORMERS_AVAILABLE = True
24 | except ImportError:
25 | # logger.warning("xFormers not available")
26 | XFORMERS_AVAILABLE = False
27 |
28 |
29 | class Attention(nn.Module):
30 | def __init__(
31 | self,
32 | dim: int,
33 | num_heads: int = 8,
34 | qkv_bias: bool = False,
35 | proj_bias: bool = True,
36 | attn_drop: float = 0.0,
37 | proj_drop: float = 0.0,
38 | ) -> None:
39 | super().__init__()
40 | self.num_heads = num_heads
41 | head_dim = dim // num_heads
42 | self.scale = head_dim**-0.5
43 |
44 | self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
45 | self.attn_drop = nn.Dropout(attn_drop)
46 | self.proj = nn.Linear(dim, dim, bias=proj_bias)
47 | self.proj_drop = nn.Dropout(proj_drop)
48 |
49 | def forward(self, x: Tensor) -> Tensor:
50 | B, N, C = x.shape
51 | qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
52 |
53 | q, k, v = qkv[0] * self.scale, qkv[1], qkv[2]
54 | attn = q @ k.transpose(-2, -1)
55 |
56 | attn = attn.softmax(dim=-1)
57 | attn = self.attn_drop(attn)
58 |
59 | x = (attn @ v).transpose(1, 2).reshape(B, N, C)
60 | x = self.proj(x)
61 | x = self.proj_drop(x)
62 | return x
63 |
64 |
65 | class MemEffAttention(Attention):
66 | def forward(self, x: Tensor, attn_bias=None) -> Tensor:
67 | if not XFORMERS_AVAILABLE:
68 | assert attn_bias is None, "xFormers is required for nested tensors usage"
69 | return super().forward(x)
70 |
71 | B, N, C = x.shape
72 | qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads)
73 |
74 | q, k, v = unbind(qkv, 2)
75 |
76 | x = memory_efficient_attention(q, k, v, attn_bias=attn_bias)
77 | x = x.reshape([B, N, C])
78 |
79 | x = self.proj(x)
80 | x = self.proj_drop(x)
81 | return x
82 |
--------------------------------------------------------------------------------
/romatch/models/transformer/layers/patch_embed.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | # References:
8 | # https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
9 | # https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/patch_embed.py
10 |
11 | from typing import Callable, Optional, Tuple, Union
12 |
13 | from torch import Tensor
14 | import torch.nn as nn
15 |
16 |
17 | def make_2tuple(x):
18 | if isinstance(x, tuple):
19 | assert len(x) == 2
20 | return x
21 |
22 | assert isinstance(x, int)
23 | return (x, x)
24 |
25 |
26 | class PatchEmbed(nn.Module):
27 | """
28 | 2D image to patch embedding: (B,C,H,W) -> (B,N,D)
29 |
30 | Args:
31 | img_size: Image size.
32 | patch_size: Patch token size.
33 | in_chans: Number of input image channels.
34 | embed_dim: Number of linear projection output channels.
35 | norm_layer: Normalization layer.
36 | """
37 |
38 | def __init__(
39 | self,
40 | img_size: Union[int, Tuple[int, int]] = 224,
41 | patch_size: Union[int, Tuple[int, int]] = 16,
42 | in_chans: int = 3,
43 | embed_dim: int = 768,
44 | norm_layer: Optional[Callable] = None,
45 | flatten_embedding: bool = True,
46 | ) -> None:
47 | super().__init__()
48 |
49 | image_HW = make_2tuple(img_size)
50 | patch_HW = make_2tuple(patch_size)
51 | patch_grid_size = (
52 | image_HW[0] // patch_HW[0],
53 | image_HW[1] // patch_HW[1],
54 | )
55 |
56 | self.img_size = image_HW
57 | self.patch_size = patch_HW
58 | self.patches_resolution = patch_grid_size
59 | self.num_patches = patch_grid_size[0] * patch_grid_size[1]
60 |
61 | self.in_chans = in_chans
62 | self.embed_dim = embed_dim
63 |
64 | self.flatten_embedding = flatten_embedding
65 |
66 | self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_HW, stride=patch_HW)
67 | self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()
68 |
69 | def forward(self, x: Tensor) -> Tensor:
70 | _, _, H, W = x.shape
71 | patch_H, patch_W = self.patch_size
72 |
73 | assert H % patch_H == 0, f"Input image height {H} is not a multiple of patch height {patch_H}"
74 | assert W % patch_W == 0, f"Input image width {W} is not a multiple of patch width: {patch_W}"
75 |
76 | x = self.proj(x) # B C H W
77 | H, W = x.size(2), x.size(3)
78 | x = x.flatten(2).transpose(1, 2) # B HW C
79 | x = self.norm(x)
80 | if not self.flatten_embedding:
81 | x = x.reshape(-1, H, W, self.embed_dim) # B H W C
82 | return x
83 |
84 | def flops(self) -> float:
85 | Ho, Wo = self.patches_resolution
86 | flops = Ho * Wo * self.embed_dim * self.in_chans * (self.patch_size[0] * self.patch_size[1])
87 | if self.norm is not None:
88 | flops += Ho * Wo * self.embed_dim
89 | return flops
90 |
--------------------------------------------------------------------------------
/sky_seg.py:
--------------------------------------------------------------------------------
1 | import os
2 | import copy
3 | import time
4 | import argparse
5 | import cv2
6 | import numpy as np
7 | import onnxruntime
8 | from tqdm import tqdm
9 | import imutils
10 |
11 | def run_inference(onnx_session, input_size, image):
12 | # Pre process:Resize, BGR->RGB, Transpose, PyTorch standardization, float32 cast
13 | temp_image = copy.deepcopy(image)
14 | resize_image = cv2.resize(temp_image, dsize=(input_size[0], input_size[1]))
15 | x = cv2.cvtColor(resize_image, cv2.COLOR_BGR2RGB)
16 | x = np.array(x, dtype=np.float32)
17 | mean = [0.485, 0.456, 0.406]
18 | std = [0.229, 0.224, 0.225]
19 | x = (x / 255 - mean) / std
20 | x = x.transpose(2, 0, 1)
21 | x = x.reshape(-1, 3, input_size[0], input_size[1]).astype('float32')
22 |
23 | # Inference
24 | input_name = onnx_session.get_inputs()[0].name
25 | output_name = onnx_session.get_outputs()[0].name
26 | onnx_result = onnx_session.run([output_name], {input_name: x})
27 |
28 | # Post process
29 | onnx_result = np.array(onnx_result).squeeze()
30 | min_value = np.min(onnx_result)
31 | max_value = np.max(onnx_result)
32 | onnx_result = (onnx_result - min_value) / (max_value - min_value)
33 | onnx_result *= 255
34 | onnx_result = onnx_result.astype('uint8')
35 |
36 | return onnx_result
37 |
38 | from argparse import ArgumentParser
39 |
40 | if __name__ == "__main__":
41 | from argparse import ArgumentParser
42 | parser = ArgumentParser()
43 | parser.add_argument("--base", type=str)
44 | parser.add_argument("--debug", action='store_true', default=False)
45 | args, _ = parser.parse_known_args()
46 |
47 | onnx_session = onnxruntime.InferenceSession("checkpoints/skyseg.onnx")
48 | base = args.base
49 | img_base = os.path.join(base, "image1")
50 | out_base = os.path.join(base, "debug2")
51 | depth_base = os.path.join(base, "depth1")
52 |
53 | if not os.path.exists(out_base):
54 | os.mkdir(out_base)
55 |
56 | count = 0
57 |
58 | for img in tqdm(os.listdir(img_base)):
59 |
60 | # image = cv2.imread("../zeb/eth3do/playground-DSC_0585.png")
61 | image = cv2.imread(os.path.join(img_base, img))
62 | H, W = image.shape[:2]
63 |
64 | while(image.shape[0] >= 640 and image.shape[1] >= 640):
65 | image = cv2.pyrDown(image)
66 | result_map = run_inference(onnx_session,[320,320],image)
67 |         result_map = cv2.resize(result_map, (W, H))  # resize the sky mask back to the original (H, W) resolution
68 |
69 | # cv2.imwrite(os.path.join(out_base, img), result_map)
70 | depth = np.load(os.path.join(depth_base, img[:-4]+".npy"))
71 |
72 | disp_ = 1 / (depth + 1)
73 | disp2_ = 1 / (depth + 1)
74 |
75 | depth[result_map > 255 * 0.5] = 0.0
76 | np.save(os.path.join(depth_base, img[:-4]+".npy"), depth)
77 |
78 | if args.debug and count < 100:
79 |
80 | disp2_[result_map > 255 * 0.5] = 0.0
81 |
82 | disp_ = ((disp_ - disp_.min()) / (disp_.max() - disp_.min()) * 255).astype(np.uint8)
83 | disp2_ = ((disp2_ - disp2_.min()) / (disp2_.max() - disp2_.min()) * 255).astype(np.uint8)
84 |
85 | cv2.imwrite(os.path.join(out_base, img), np.hstack([disp_, disp2_, result_map]))
86 |
87 | # import IPython
88 | # IPython.embed()
89 | # exit()
90 | count += 1
--------------------------------------------------------------------------------
/romatch/utils/transforms.py:
--------------------------------------------------------------------------------
1 | from typing import Dict
2 | import numpy as np
3 | import torch
4 | import kornia.augmentation as K
5 | from kornia.geometry.transform import warp_perspective
6 |
7 | # Adapted from Kornia
8 | class GeometricSequential:
9 | def __init__(self, *transforms, align_corners=True) -> None:
10 | self.transforms = transforms
11 | self.align_corners = align_corners
12 |
13 | def __call__(self, x, mode="bilinear"):
14 | b, c, h, w = x.shape
15 | M = torch.eye(3, device=x.device)[None].expand(b, 3, 3)
16 | for t in self.transforms:
17 | if np.random.rand() < t.p:
18 | M = M.matmul(
19 | t.compute_transformation(x, t.generate_parameters((b, c, h, w)), None)
20 | )
21 | return (
22 | warp_perspective(
23 | x, M, dsize=(h, w), mode=mode, align_corners=self.align_corners
24 | ),
25 | M,
26 | )
27 |
28 | def apply_transform(self, x, M, mode="bilinear"):
29 | b, c, h, w = x.shape
30 | return warp_perspective(
31 | x, M, dsize=(h, w), align_corners=self.align_corners, mode=mode
32 | )
33 |
34 |
35 | class RandomPerspective(K.RandomPerspective):
36 | def generate_parameters(self, batch_shape: torch.Size) -> Dict[str, torch.Tensor]:
37 | distortion_scale = torch.as_tensor(
38 | self.distortion_scale, device=self._device, dtype=self._dtype
39 | )
40 | return self.random_perspective_generator(
41 | batch_shape[0],
42 | batch_shape[-2],
43 | batch_shape[-1],
44 | distortion_scale,
45 | self.same_on_batch,
46 | self.device,
47 | self.dtype,
48 | )
49 |
50 | def random_perspective_generator(
51 | self,
52 | batch_size: int,
53 | height: int,
54 | width: int,
55 | distortion_scale: torch.Tensor,
56 | same_on_batch: bool = False,
57 | device: torch.device = torch.device("cpu"),
58 | dtype: torch.dtype = torch.float32,
59 | ) -> Dict[str, torch.Tensor]:
60 | r"""Get parameters for ``perspective`` for a random perspective transform.
61 |
62 | Args:
63 | batch_size (int): the tensor batch size.
64 | height (int) : height of the image.
65 | width (int): width of the image.
66 | distortion_scale (torch.Tensor): it controls the degree of distortion and ranges from 0 to 1.
67 | same_on_batch (bool): apply the same transformation across the batch. Default: False.
68 | device (torch.device): the device on which the random numbers will be generated. Default: cpu.
69 | dtype (torch.dtype): the data type of the generated random numbers. Default: float32.
70 |
71 | Returns:
72 | params Dict[str, torch.Tensor]: parameters to be passed for transformation.
73 | - start_points (torch.Tensor): element-wise perspective source areas with a shape of (B, 4, 2).
74 | - end_points (torch.Tensor): element-wise perspective target areas with a shape of (B, 4, 2).
75 |
76 | Note:
77 | The generated random numbers are not reproducible across different devices and dtypes.
78 | """
79 | if not (distortion_scale.dim() == 0 and 0 <= distortion_scale <= 1):
80 | raise AssertionError(
81 | f"'distortion_scale' must be a scalar within [0, 1]. Got {distortion_scale}."
82 | )
83 | if not (
84 | type(height) is int and height > 0 and type(width) is int and width > 0
85 | ):
86 | raise AssertionError(
87 | f"'height' and 'width' must be integers. Got {height}, {width}."
88 | )
89 |
90 | start_points: torch.Tensor = torch.tensor(
91 | [[[0.0, 0], [width - 1, 0], [width - 1, height - 1], [0, height - 1]]],
92 | device=distortion_scale.device,
93 | dtype=distortion_scale.dtype,
94 | ).expand(batch_size, -1, -1)
95 |
96 | # generate random offset not larger than half of the image
97 | fx = distortion_scale * width / 2
98 | fy = distortion_scale * height / 2
99 |
100 | factor = torch.stack([fx, fy], dim=0).view(-1, 1, 2)
101 | offset = (torch.rand_like(start_points) - 0.5) * 2
102 | end_points = start_points + factor * offset
103 |
104 | return dict(start_points=start_points, end_points=end_points)
105 |
106 |
107 |
108 | class RandomErasing:
109 | def __init__(self, p = 0., scale = 0.) -> None:
110 | self.p = p
111 | self.scale = scale
112 | self.random_eraser = K.RandomErasing(scale = (0.02, scale), p = p)
113 | def __call__(self, image, depth):
114 | if self.p > 0:
115 | image = self.random_eraser(image)
116 | depth = self.random_eraser(depth, params=self.random_eraser._params)
117 | return image, depth
118 | 
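119 | 
120 | # Example (illustrative): compose homography augmentations and reuse the sampled
121 | # transform on a paired tensor such as depth,
122 | #   aug = GeometricSequential(RandomPerspective(0.4, p=1.0))
123 | #   warped, M = aug(images)                      # images: (B, 3, H, W), M: (B, 3, 3)
124 | #   warped_depth = aug.apply_transform(depth, M)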
--------------------------------------------------------------------------------
/vis_feats.py:
--------------------------------------------------------------------------------
1 |
2 | import os
3 | import torch
4 | import torchvision.transforms as T
5 | import matplotlib.pyplot as plt
6 | import numpy as np
7 | from PIL import Image
8 | from sklearn.decomposition import PCA
9 | import argparse
10 | from romatch.models.transformer import vit_base
11 |
12 | device = "cuda" if torch.cuda.is_available() else "cpu"
13 |
14 | # -------------------- Visualization helpers --------------------
15 | def vis_feat_map_batch(features_list, patch_h, patch_w, resize_hw=(560, 560)):
16 | all_feats = np.concatenate([f.reshape(patch_h * patch_w, -1) for f in features_list], axis=0)
17 | pca = PCA(n_components=3)
18 | pca.fit(all_feats)
19 |
20 | images = []
21 | for features in features_list:
22 | f = features.reshape(patch_h * patch_w, -1)
23 | pca_feats = pca.transform(f)
24 | pca_feats = (pca_feats - pca_feats.mean(0)) / (pca_feats.std(0) + 1e-5)
25 | pca_feats = np.clip(pca_feats * 0.5 + 0.5, 0, 1)
26 | img = pca_feats.reshape(patch_h, patch_w, 3)
27 | img = (img * 255).astype(np.uint8)
28 | img = Image.fromarray(img).resize(resize_hw, Image.BICUBIC)
29 | images.append(img)
30 | return images
31 |
32 | def save_combined_visualization(
33 | feats_dino, feats_fit3d, feats_L2M,
34 | patch_h, patch_w, base_name, save_dir, original_images
35 | ):
36 | os.makedirs(save_dir, exist_ok=True)
37 |
38 | imgs_dino = vis_feat_map_batch(feats_dino, patch_h, patch_w)
39 | imgs_fit3d = vis_feat_map_batch(feats_fit3d, patch_h, patch_w)
40 | imgs_L2M = vis_feat_map_batch(feats_L2M, patch_h, patch_w)
41 |
42 |     # Combined figure: one image set per row, 2 rows x 4 columns in total
43 | fig, axs = plt.subplots(2, 4, figsize=(16, 8))
44 | titles = ["Original", "DINOv2", "Fit3D", "L2M (Ours)"]
45 | for i in range(2): # row
46 | row_imgs = [original_images[i], imgs_dino[i], imgs_fit3d[i], imgs_L2M[i]]
47 | for j in range(4):
48 | axs[i, j].imshow(row_imgs[j])
49 | axs[i, j].set_title(titles[j], fontsize=12)
50 | axs[i, j].axis("off")
51 | plt.tight_layout()
52 | plt.savefig(os.path.join(save_dir, f"{base_name}_compare.png"))
53 | plt.close()
54 |
55 | # -------------------- Feature extraction --------------------
56 | def extract_features(model, image_tensor):
57 | with torch.no_grad():
58 | return model.forward_features(image_tensor)["x_norm_patchtokens"].squeeze(0).cpu().numpy()
59 |
60 | # -------------------- Main script --------------------
61 | def main(args):
62 | os.makedirs(args.save_dir, exist_ok=True)
63 |
64 | patch_h, patch_w = 37, 37
65 | img_size = patch_h * 14
66 |
67 | transform = T.Compose([
68 | T.Resize((img_size, img_size)),
69 | T.CenterCrop((img_size, img_size)),
70 | T.ToTensor(),
71 | T.Normalize(mean=(0.485, 0.456, 0.406),
72 | std=(0.229, 0.224, 0.225)),
73 | ])
74 |
75 |     # Initialize the models
76 | vit_kwargs = dict(
77 | img_size=img_size,
78 | patch_size=14,
79 | init_values=1.0,
80 | ffn_layer="mlp",
81 | block_chunks=0
82 | )
83 |
84 | def load_model(ckpt_path):
85 | model = vit_base(**vit_kwargs).eval().to(device)
86 | raw = torch.load(ckpt_path, map_location="cpu")
87 | if "model" in raw:
88 | raw = raw["model"]
89 | ckpt = {k.replace("model.", ""): v for k, v in raw.items()}
90 | model.load_state_dict(ckpt, strict=False)
91 | return model
92 |
93 | dino = load_model(args.ckpt_dino)
94 | fit3d = load_model(args.ckpt_fit3d)
95 | L2M = load_model(args.ckpt_L2M)
96 |
97 | feats_dino, feats_fit3d, feats_L2M = [], [], []
98 | original_images = []
99 |
100 | for img_path in args.img_paths:
101 | img = Image.open(img_path).convert("RGB")
102 | x = transform(img).unsqueeze(0).to(device)
103 |
104 | feats_dino.append(extract_features(dino, x))
105 | feats_fit3d.append(extract_features(fit3d, x))
106 | feats_L2M.append(extract_features(L2M, x))
107 | original_images.append(img)
108 |
109 | base_name = "multi" if len(args.img_paths) > 1 else os.path.splitext(os.path.basename(args.img_paths[0]))[0]
110 | save_combined_visualization(
111 | feats_dino, feats_fit3d, feats_L2M,
112 | patch_h, patch_w, base_name, args.save_dir, original_images
113 | )
114 |
115 | print(f"Saved 2-row comparison to {os.path.join(args.save_dir, f'{base_name}_compare.png')}")
116 |
117 | if __name__ == "__main__":
118 | parser = argparse.ArgumentParser()
119 | parser.add_argument(
120 | "--img_paths",
121 | nargs="+",
122 | default=[
123 | "assets/sacre_coeur_A.jpg",
124 | "assets/sacre_coeur_B.jpg"
125 | ],
126 | help="List of image paths"
127 | )
128 | parser.add_argument(
129 | "--ckpt_fit3d",
130 | default="ckpts/fit3d.pth",
131 | help="Original Fit3D checkpoint"
132 | )
133 | parser.add_argument(
134 | "--ckpt_L2M",
135 | default="ckpts/l2m_vit_base.pth",
136 | help="L2M Fit3D checkpoint"
137 | )
138 | parser.add_argument(
139 | "--ckpt_dino",
140 | default="ckpts/dinov2.pth",
141 | help="dino checkpoint"
142 | )
143 | parser.add_argument(
144 | "--save_dir",
145 | default="outputs_vis_feat",
146 | help="Directory to save visualizations"
147 | )
148 | args = parser.parse_args()
149 | main(args)
150 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Learning Dense Feature Matching via Lifting Single 2D Image to 3D Space
2 |
3 | 
4 |
5 | Welcome to the **L2M** repository! This is the official implementation of our ICCV'25 paper titled "Learning Dense Feature Matching via Lifting Single 2D Image to 3D Space".
6 |
7 | *Accepted to ICCV 2025 Conference*
8 |
9 | ---
10 |
11 | > 🚨 **Important Notice:**
12 | > This repository is the **official implementation** of the ICCV 2025 paper authored by Sharpiless.
13 | >
14 | > Please be aware that the repository at [https://github.com/chelseaaxy/L2M](https://github.com/chelseaaxy/L2M) is **NOT an official implementation** and is not authorized by the original authors.
15 | >
16 | > Always refer to this repository for the authentic and up-to-date code.
17 |
18 |
19 | ## 🧠 Overview
20 |
21 | **Lift to Match (L2M)** is a two-stage framework for **dense feature matching** that lifts 2D images into 3D space to enhance feature generalization and robustness. Unlike traditional methods that depend on multi-view image pairs, L2M is trained on large-scale, diverse single-view image collections.
22 |
23 | - **Stage 1:** Learn a **3D-aware ViT-based encoder** using multi-view image synthesis and 3D Gaussian feature representation.
24 | - **Stage 2:** Learn a **feature decoder** through novel-view rendering and synthetic data, enabling robust matching across diverse scenarios.
25 |
26 | > 🚧 Code is still under construction.
27 |
28 | ---
29 |
30 | ## 🧪 Feature Visualization
31 |
32 | We compare the 3D-aware ViT encoder from L2M (Stage 1) with other recent methods:
33 |
34 | - **DINOv2**: Learning Robust Visual Features without Supervision
35 | - **FiT3D**: Improving 2D Feature Representations by 3D-Aware Fine-Tuning
36 | - **Ours: L2M Encoder**
37 |
38 | You can download the pretrained checkpoints from the [Releases](https://github.com/Sharpiless/L2M/releases/tag/checkpoints) page.
39 |
40 |
41 | 
42 |
43 |
44 |
45 |
46 | 
47 |
48 |
49 |
50 | ---
51 |
52 | To reproduce the comparison, make sure the checkpoints and image files are at the expected paths, then run:
53 | ```
54 | python vis_feats.py \
55 | --img_paths assets/sacre_coeur_A.jpg assets/sacre_coeur_B.jpg \
56 | --ckpt_dino ckpts/dinov2.pth \
57 | --ckpt_fit3d ckpts/fit3d.pth \
58 | --ckpt_L2M ckpts/l2m_vit_base.pth \
59 | --save_dir outputs_vis_feat
60 | ```
61 |
62 | ## 🏗️ Data Generation
63 |
64 | To enable training from single-view images, we simulate diverse multi-view observations and their corresponding dense correspondence labels in a fully automatic manner.
65 |
66 | #### Stage 2.1: Novel View Synthesis
67 | We lift a single-view image to a coarse 3D structure and then render novel views from different camera poses. These synthesized multi-view images are used to supervise the feature encoder with dense matching consistency.
68 |
69 | Run the following to generate novel-view images with ground-truth dense correspondences:
70 | ```
71 | python get_data.py \
72 | --output_path [PATH-to-SAVE] \
73 | --data_path [PATH-to-IMAGES] \
74 | --disp_path [PATH-to-MONO-DEPTH]
75 | ```
76 |
77 | This script provides an example of novel-view generation with dense-matching ground truth.
78 | 
79 | The `--disp_path` directory should contain grayscale disparity maps predicted by Depth Anything V2 or another monocular depth estimator.
80 |
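For reference, a minimal sketch of how such disparity maps can be produced with Depth Anything V2 is shown below. The checkpoint path, the `vitl` configuration, and the input/output folders are illustrative; the `infer_image` call mirrors the usage in `relight.py`:

```
import glob
import os

import cv2
import numpy as np
import torch
from depth_anything_v2.dpt import DepthAnythingV2

# Illustrative ViT-L configuration; adjust to the checkpoint you actually use.
model = DepthAnythingV2(encoder="vitl", features=256, out_channels=[256, 512, 1024, 1024])
model.load_state_dict(torch.load("checkpoints/depth_anything_v2_vitl.pth", map_location="cpu"))
model = model.to("cuda" if torch.cuda.is_available() else "cpu").eval()

os.makedirs("disparity", exist_ok=True)
for path in glob.glob("images/*.jpg"):
    image = cv2.imread(path)
    # Same call as in relight.py; the raw prediction is relative inverse depth (disparity).
    disparity = model.infer_image(image[:, :, ::-1], 518).astype(np.float32)
    disparity = (disparity - disparity.min()) / (disparity.max() - disparity.min() + 1e-6)
    name = os.path.splitext(os.path.basename(path))[0]
    cv2.imwrite(os.path.join("disparity", name + ".png"), (disparity * 255).astype(np.uint8))
```
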
81 | Below are examples of synthesized novel views with ground-truth dense correspondences, generated in Stage 2.1:
82 |
83 |
84 |
85 | 
86 |
87 |
88 | These examples demonstrate both the geometric diversity and the high-quality pixel-level correspondence labels used for supervision.
89 |
90 | For novel-view inpainting, we also provide an improved inpainting model fine-tuned from Stable-Diffusion-2.0-Inpainting:
91 |
92 | ```
93 | from diffusers import StableDiffusionInpaintPipeline
94 | import torch
95 | from diffusers.utils import load_image, make_image_grid
96 | import PIL
97 |
98 | # Specify the model path
99 | model_path = "Liangyingping/L2M-Inpainting"  # replace with your own model path if needed
100 |
101 | # Load the inpainting pipeline
102 | pipe = StableDiffusionInpaintPipeline.from_pretrained(
103 | model_path, torch_dtype=torch.float16
104 | )
105 | pipe.to("cuda")  # move the pipeline to the GPU if one is available
106 |
107 | init_image = load_image("assets/debug_masked_image.png")
108 | mask_image = load_image("assets/debug_mask.png")
109 | W, H = init_image.size
110 |
111 | prompt = "a photo of a person"
112 | image = pipe(
113 | prompt=prompt,
114 | image=init_image,
115 | mask_image=mask_image,
116 |     height=512, width=512
117 | ).images[0].resize((W, H))
118 |
119 | print(image.size, init_image.size)
120 |
121 | image2save = make_image_grid([init_image, mask_image, image], rows=1, cols=3)
122 | image2save.save("image2save_ours.png")
123 | ```
124 |
125 | Alternatively, you can download the model manually from [Hugging Face](https://huggingface.co/Liangyingping/L2M-Inpainting).
126 |
127 |
128 |
129 |
130 |
131 | #### Stage 2.2: Relighting for Appearance Diversity
132 | To improve feature robustness under varying lighting conditions, we apply a physics-inspired relighting pipeline to the synthesized 3D scenes.
133 |
134 | Run the following to generate relit image pairs for training the decoder:
135 | ```
136 | python relight.py
137 | ```
138 | All outputs are saved under the configured output directory, including the original view, the novel views, and their camera parameters together with dense depth. Note that `relight.py` relies on Blender's `bpy` Python module and renders with the Eevee engine.
139 |
140 |
141 |
142 |
143 | #### Stage 2.3: Sky Masking (Optional)
144 |
145 | If desired, you can run `sky_seg.py` to mask out sky regions, which are typically textureless and not useful for matching. This helps reduce noise and focuses training on geometrically meaningful regions. The script expects the directory passed via `--base` to contain `image1/` and `depth1/` sub-folders.
146 |
147 | ```
148 | python sky_seg.py --base [PATH-to-DATA]
149 | ```
150 |
151 | 
152 |
153 |
154 | ## 🙋♂️ Acknowledgements
155 |
156 | We build upon recent advances in [ROMA](https://github.com/Parskatt/RoMa), [GIM](https://github.com/xuelunshen/gim), and [FiT3D](https://github.com/ywyue/FiT3D).
157 |
--------------------------------------------------------------------------------
/romatch/models/transformer/layers/block.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | # References:
8 | # https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
9 | # https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/patch_embed.py
10 |
11 | import logging
12 | from typing import Callable, List, Any, Tuple, Dict
13 |
14 | import torch
15 | from torch import nn, Tensor
16 |
17 | from .attention import Attention, MemEffAttention
18 | from .drop_path import DropPath
19 | from .layer_scale import LayerScale
20 | from .mlp import Mlp
21 |
22 |
23 | logger = logging.getLogger("dinov2")
24 |
25 |
26 | try:
27 | from xformers.ops import fmha
28 | from xformers.ops import scaled_index_add, index_select_cat
29 |
30 | XFORMERS_AVAILABLE = True
31 | except ImportError:
32 | # logger.warning("xFormers not available")
33 | XFORMERS_AVAILABLE = False
34 |
35 |
36 | class Block(nn.Module):
37 | def __init__(
38 | self,
39 | dim: int,
40 | num_heads: int,
41 | mlp_ratio: float = 4.0,
42 | qkv_bias: bool = False,
43 | proj_bias: bool = True,
44 | ffn_bias: bool = True,
45 | drop: float = 0.0,
46 | attn_drop: float = 0.0,
47 | init_values=None,
48 | drop_path: float = 0.0,
49 | act_layer: Callable[..., nn.Module] = nn.GELU,
50 | norm_layer: Callable[..., nn.Module] = nn.LayerNorm,
51 | attn_class: Callable[..., nn.Module] = Attention,
52 | ffn_layer: Callable[..., nn.Module] = Mlp,
53 | ) -> None:
54 | super().__init__()
55 | # print(f"biases: qkv: {qkv_bias}, proj: {proj_bias}, ffn: {ffn_bias}")
56 | self.norm1 = norm_layer(dim)
57 | self.attn = attn_class(
58 | dim,
59 | num_heads=num_heads,
60 | qkv_bias=qkv_bias,
61 | proj_bias=proj_bias,
62 | attn_drop=attn_drop,
63 | proj_drop=drop,
64 | )
65 | self.ls1 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
66 | self.drop_path1 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
67 |
68 | self.norm2 = norm_layer(dim)
69 | mlp_hidden_dim = int(dim * mlp_ratio)
70 | self.mlp = ffn_layer(
71 | in_features=dim,
72 | hidden_features=mlp_hidden_dim,
73 | act_layer=act_layer,
74 | drop=drop,
75 | bias=ffn_bias,
76 | )
77 | self.ls2 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
78 | self.drop_path2 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
79 |
80 | self.sample_drop_ratio = drop_path
81 |
82 | def forward(self, x: Tensor) -> Tensor:
83 | def attn_residual_func(x: Tensor) -> Tensor:
84 | return self.ls1(self.attn(self.norm1(x)))
85 |
86 | def ffn_residual_func(x: Tensor) -> Tensor:
87 | return self.ls2(self.mlp(self.norm2(x)))
88 |
89 | if self.training and self.sample_drop_ratio > 0.1:
90 | # the overhead is compensated only for a drop path rate larger than 0.1
91 | x = drop_add_residual_stochastic_depth(
92 | x,
93 | residual_func=attn_residual_func,
94 | sample_drop_ratio=self.sample_drop_ratio,
95 | )
96 | x = drop_add_residual_stochastic_depth(
97 | x,
98 | residual_func=ffn_residual_func,
99 | sample_drop_ratio=self.sample_drop_ratio,
100 | )
101 | elif self.training and self.sample_drop_ratio > 0.0:
102 | x = x + self.drop_path1(attn_residual_func(x))
103 | x = x + self.drop_path1(ffn_residual_func(x)) # FIXME: drop_path2
104 | else:
105 | x = x + attn_residual_func(x)
106 | x = x + ffn_residual_func(x)
107 | return x
108 |
109 |
110 | def drop_add_residual_stochastic_depth(
111 | x: Tensor,
112 | residual_func: Callable[[Tensor], Tensor],
113 | sample_drop_ratio: float = 0.0,
114 | ) -> Tensor:
115 | # 1) extract subset using permutation
116 | b, n, d = x.shape
117 | sample_subset_size = max(int(b * (1 - sample_drop_ratio)), 1)
118 | brange = (torch.randperm(b, device=x.device))[:sample_subset_size]
119 | x_subset = x[brange]
120 |
121 | # 2) apply residual_func to get residual
122 | residual = residual_func(x_subset)
123 |
124 | x_flat = x.flatten(1)
125 | residual = residual.flatten(1)
126 |
127 | residual_scale_factor = b / sample_subset_size
128 |
129 | # 3) add the residual
130 | x_plus_residual = torch.index_add(x_flat, 0, brange, residual.to(dtype=x.dtype), alpha=residual_scale_factor)
131 | return x_plus_residual.view_as(x)
132 |
133 |
134 | def get_branges_scales(x, sample_drop_ratio=0.0):
135 | b, n, d = x.shape
136 | sample_subset_size = max(int(b * (1 - sample_drop_ratio)), 1)
137 | brange = (torch.randperm(b, device=x.device))[:sample_subset_size]
138 | residual_scale_factor = b / sample_subset_size
139 | return brange, residual_scale_factor
140 |
141 |
142 | def add_residual(x, brange, residual, residual_scale_factor, scaling_vector=None):
143 | if scaling_vector is None:
144 | x_flat = x.flatten(1)
145 | residual = residual.flatten(1)
146 | x_plus_residual = torch.index_add(x_flat, 0, brange, residual.to(dtype=x.dtype), alpha=residual_scale_factor)
147 | else:
148 | x_plus_residual = scaled_index_add(
149 | x, brange, residual.to(dtype=x.dtype), scaling=scaling_vector, alpha=residual_scale_factor
150 | )
151 | return x_plus_residual
152 |
153 |
154 | attn_bias_cache: Dict[Tuple, Any] = {}
155 |
156 |
157 | def get_attn_bias_and_cat(x_list, branges=None):
158 | """
159 | this will perform the index select, cat the tensors, and provide the attn_bias from cache
160 | """
161 | batch_sizes = [b.shape[0] for b in branges] if branges is not None else [x.shape[0] for x in x_list]
162 | all_shapes = tuple((b, x.shape[1]) for b, x in zip(batch_sizes, x_list))
163 | if all_shapes not in attn_bias_cache.keys():
164 | seqlens = []
165 | for b, x in zip(batch_sizes, x_list):
166 | for _ in range(b):
167 | seqlens.append(x.shape[1])
168 | attn_bias = fmha.BlockDiagonalMask.from_seqlens(seqlens)
169 | attn_bias._batch_sizes = batch_sizes
170 | attn_bias_cache[all_shapes] = attn_bias
171 |
172 | if branges is not None:
173 | cat_tensors = index_select_cat([x.flatten(1) for x in x_list], branges).view(1, -1, x_list[0].shape[-1])
174 | else:
175 | tensors_bs1 = tuple(x.reshape([1, -1, *x.shape[2:]]) for x in x_list)
176 | cat_tensors = torch.cat(tensors_bs1, dim=1)
177 |
178 | return attn_bias_cache[all_shapes], cat_tensors
179 |
180 |
181 | def drop_add_residual_stochastic_depth_list(
182 | x_list: List[Tensor],
183 | residual_func: Callable[[Tensor, Any], Tensor],
184 | sample_drop_ratio: float = 0.0,
185 | scaling_vector=None,
186 | ) -> Tensor:
187 | # 1) generate random set of indices for dropping samples in the batch
188 | branges_scales = [get_branges_scales(x, sample_drop_ratio=sample_drop_ratio) for x in x_list]
189 | branges = [s[0] for s in branges_scales]
190 | residual_scale_factors = [s[1] for s in branges_scales]
191 |
192 | # 2) get attention bias and index+concat the tensors
193 | attn_bias, x_cat = get_attn_bias_and_cat(x_list, branges)
194 |
195 | # 3) apply residual_func to get residual, and split the result
196 | residual_list = attn_bias.split(residual_func(x_cat, attn_bias=attn_bias)) # type: ignore
197 |
198 | outputs = []
199 | for x, brange, residual, residual_scale_factor in zip(x_list, branges, residual_list, residual_scale_factors):
200 | outputs.append(add_residual(x, brange, residual, residual_scale_factor, scaling_vector).view_as(x))
201 | return outputs
202 |
203 |
204 | class NestedTensorBlock(Block):
205 | def forward_nested(self, x_list: List[Tensor]) -> List[Tensor]:
206 | """
207 | x_list contains a list of tensors to nest together and run
208 | """
209 | assert isinstance(self.attn, MemEffAttention)
210 |
211 | if self.training and self.sample_drop_ratio > 0.0:
212 |
213 | def attn_residual_func(x: Tensor, attn_bias=None) -> Tensor:
214 | return self.attn(self.norm1(x), attn_bias=attn_bias)
215 |
216 | def ffn_residual_func(x: Tensor, attn_bias=None) -> Tensor:
217 | return self.mlp(self.norm2(x))
218 |
219 | x_list = drop_add_residual_stochastic_depth_list(
220 | x_list,
221 | residual_func=attn_residual_func,
222 | sample_drop_ratio=self.sample_drop_ratio,
223 | scaling_vector=self.ls1.gamma if isinstance(self.ls1, LayerScale) else None,
224 | )
225 | x_list = drop_add_residual_stochastic_depth_list(
226 | x_list,
227 | residual_func=ffn_residual_func,
228 | sample_drop_ratio=self.sample_drop_ratio,
229 | scaling_vector=self.ls2.gamma if isinstance(self.ls1, LayerScale) else None,
230 | )
231 | return x_list
232 | else:
233 |
234 | def attn_residual_func(x: Tensor, attn_bias=None) -> Tensor:
235 | return self.ls1(self.attn(self.norm1(x), attn_bias=attn_bias))
236 |
237 | def ffn_residual_func(x: Tensor, attn_bias=None) -> Tensor:
238 | return self.ls2(self.mlp(self.norm2(x)))
239 |
240 | attn_bias, x = get_attn_bias_and_cat(x_list)
241 | x = x + attn_residual_func(x, attn_bias=attn_bias)
242 | x = x + ffn_residual_func(x)
243 | return attn_bias.split(x)
244 |
245 | def forward(self, x_or_x_list):
246 | if isinstance(x_or_x_list, Tensor):
247 | return super().forward(x_or_x_list)
248 | elif isinstance(x_or_x_list, list):
249 | assert XFORMERS_AVAILABLE, "Please install xFormers for nested tensors usage"
250 | return self.forward_nested(x_or_x_list)
251 | else:
252 | raise AssertionError
253 |
--------------------------------------------------------------------------------
/relight.py:
--------------------------------------------------------------------------------
1 | import bpy
2 | import numpy as np
3 | import cv2
4 | import math
5 | import os
6 | import imutils
7 | import scipy.ndimage
8 | from depth_anything_v2.dpt import DepthAnythingV2
9 |
10 | # conda activate pgdvs
11 |
12 |
13 | def resize_and_center_crop(image, disparity):
14 |     # Get the sizes of the image and the disparity map
15 | h, w = image.shape[:2]
16 |
17 |     # Length of the shortest edge
18 | shortest_edge = min(h, w)
19 |
20 |     # Scale according to the shortest edge
21 | if h < w:
22 | new_h = shortest_edge
23 | new_w = int(shortest_edge * (w / h))
24 | else:
25 | new_w = shortest_edge
26 | new_h = int(shortest_edge * (h / w))
27 |
28 |     # Resize the image and the disparity map
29 | image_resized = cv2.resize(image, (new_w, new_h))
30 | disparity_resized = cv2.resize(disparity, (new_w, new_h))
31 |
32 |     # Compute the crop region that makes the image square
33 |     crop_size = min(image_resized.shape[:2])  # use the shortest edge of the resized image as the crop size
34 | start_x = (new_w - crop_size) // 2
35 | start_y = (new_h - crop_size) // 2
36 |
37 |     # Crop the image and the disparity map
38 | image_cropped = image_resized[
39 | start_y : start_y + crop_size, start_x : start_x + crop_size
40 | ]
41 | disparity_cropped = disparity_resized[
42 | start_y : start_y + crop_size, start_x : start_x + crop_size
43 | ]
44 |
45 | return image_cropped, disparity_cropped
46 |
47 |
48 | def create_mesh(vertices, faces, colors, mesh_name="GeneratedMesh"):
49 |     # Create the mesh data and object
50 | mesh_data = bpy.data.meshes.new(mesh_name)
51 | mesh_object = bpy.data.objects.new(mesh_name, mesh_data)
52 | bpy.context.collection.objects.link(mesh_object)
53 |
54 |     # Set vertices and faces
55 | mesh_data.from_pydata(vertices.tolist(), [], faces.tolist())
56 | mesh_data.update()
57 |
58 |     # Add vertex-color data
59 | if not mesh_data.vertex_colors:
60 | mesh_data.vertex_colors.new(name="Col")
61 | color_layer = mesh_data.vertex_colors["Col"]
62 |
63 |     # Set the color of each vertex
64 | for poly in mesh_data.polygons:
65 | for loop_index in poly.loop_indices:
66 | vertex_index = mesh_data.loops[loop_index].vertex_index
67 | color_layer.data[loop_index].color = (*colors[vertex_index], 1.0) # RGBA
68 |
69 |     # Add a material and enable a diffuse shader
70 | if not mesh_object.data.materials:
71 | material = bpy.data.materials.new("MeshMaterial")
72 |         material.use_nodes = True  # use node-based shading
73 | mesh_object.data.materials.append(material)
74 |
75 |     # Get the material's node tree
76 | nodes = material.node_tree.nodes
77 | links = material.node_tree.links
78 |
79 |     # Remove the default Principled BSDF node
80 | for node in nodes:
81 | if node.type == "BSDF_PRINCIPLED":
82 | nodes.remove(node)
83 |
84 |     # Add a Diffuse BSDF node
85 | diffuse_bsdf = nodes.new(type="ShaderNodeBsdfDiffuse")
86 |
87 |     # Create the vertex-color attribute node
88 | vertex_color_node = nodes.new("ShaderNodeAttribute")
89 |     vertex_color_node.attribute_name = "Col"  # name of the vertex-color layer
90 | links.new(vertex_color_node.outputs[0], diffuse_bsdf.inputs["Color"])
91 |
92 |     # Connect the Diffuse BSDF node to the material's Surface input
93 | material_output = nodes.get("Material Output")
94 | links.new(diffuse_bsdf.outputs[0], material_output.inputs["Surface"])
95 |
96 | return mesh_object
97 |
98 |
99 | def generate_faces(image_width, image_height):
100 | indices = np.arange(image_width * image_height).reshape(image_height, image_width)
101 | lower_left = indices[:-1, :-1].ravel()
102 | lower_right = indices[:-1, 1:].ravel()
103 | upper_left = indices[1:, :-1].ravel()
104 | upper_right = indices[1:, 1:].ravel()
105 | faces = np.column_stack(
106 | (
107 | np.column_stack((lower_left, lower_right, upper_left)),
108 | np.column_stack((lower_right, upper_right, upper_left)),
109 | )
110 | ).reshape(-1, 3)
111 | return faces
112 |
113 |
114 | def setup_mesh(image_path, depth_anything_model):
115 |     # Load the image and predict disparity
116 | image = cv2.imread(image_path)
117 |
118 | disparity = depth_anything_model.infer_image(image[:, :, ::-1], 518).astype(
119 | np.float32
120 | )
121 |
122 | disparity = (disparity - disparity.min()) / (disparity.max() - disparity.min())
123 |
124 | H, W = image.shape[:2]
125 |
126 | print(disparity.shape, image.shape)
127 |
128 | if H > W:
129 | image = imutils.resize(image, width=512)
130 | disparity = imutils.resize(disparity, width=512)
131 | elif H < W:
132 | image = imutils.resize(image, height=512)
133 | disparity = imutils.resize(disparity, height=512)
134 | else:
135 | image = imutils.resize(image, height=512, width=512)
136 | disparity = imutils.resize(disparity, height=512, width=512)
137 |
138 | print(disparity.shape, image.shape)
139 |
140 | disparity = np.clip(disparity, 0.01, 1)
141 |
142 | disparity = scipy.ndimage.gaussian_filter(disparity, sigma=2)
143 |
144 | print(disparity.shape, image.shape)
145 |
146 | cv2.imwrite("image.png", image)
147 |
148 | image_height, image_width = image.shape[:2]
149 |
150 |     # Camera intrinsics
151 | focal = 0.58 * image_width
152 | cx = image_width / 2
153 | cy = image_height / 2
154 | fx = focal
155 | fy = focal
156 |
157 | camera_setup = {
158 | "focal": focal,
159 | "fx": fx,
160 | "fy": fy,
161 | "cx": cx,
162 | "cy": cy,
163 | "image_height": image_height,
164 | "image_width": image_width,
165 | }
166 |
167 |     # Convert disparity to depth
168 | depth = 1.0 / disparity
169 |
170 | depth = depth * (np.random.random() * 0.6 + 0.7) + np.random.random() * 0.2
171 |
172 | print(depth.shape, image.shape)
173 |
174 | subdivision = 2
175 |
176 | # Generate the pixel grid
177 | u_coords, v_coords = np.meshgrid(
178 | np.arange(image_width * subdivision), np.arange(image_height * subdivision)
179 | )
180 |
181 | # Convert pixel coordinates to the camera coordinate system
182 | Y = imutils.resize(depth, width=subdivision * 512)
183 | X = (u_coords / subdivision - cx) * Y / fx
184 | Z = (v_coords / subdivision - cy) * Y / fy * -1
185 |
186 | # Stack into a point cloud
187 | points = np.stack((X, Y, Z), axis=-1).reshape(-1, 3)
188 | colors = (
189 | imutils.resize(image, width=subdivision * 512).reshape(-1, 3)[:, ::-1] / 255.0
190 | )
191 |
192 |
193 | # Generate the triangle faces
194 | faces = generate_faces(image_width * subdivision, image_height * subdivision)
195 | # Create the mesh in Blender
196 | mesh_object = create_mesh(points, faces, colors)
197 |
198 | return mesh_object, camera_setup
199 |
200 |
201 | def setup_gpu_rendering():
202 | bpy.context.scene.render.engine = "BLENDER_EEVEE"
203 | # Get the Eevee render settings
204 | eevee = bpy.context.scene.eevee
205 |
206 | # Enable some advanced render settings
207 | eevee.use_ssr = True  # enable screen-space reflections
208 |
209 | print("Eevee render settings configured to use GPU.")
210 |
211 |
212 | def setup_scene(camera_setup, mesh_object, output_path):
213 | # Clean up the existing scene (remove all objects, lights, cameras, etc.)
214 |
215 | # Make sure we are in object mode
216 | if bpy.context.object:
217 | bpy.ops.object.mode_set(mode="OBJECT")
218 |
219 | # Link the generated mesh to the scene
220 | # bpy.context.collection.objects.link(mesh_object)
221 |
222 | for _ in range(np.random.randint(2, 6)):
223 | # Create and configure a light source (a point light here)
224 | light_data = bpy.data.lights.new(name="PointLight", type="POINT")
225 | light_data.energy = 400 + 1600 * np.random.rand()  # light intensity (400-2000)
226 | light_data.color = (
227 | np.random.rand() * 0.5 + 0.5 if np.random.rand() > 0.5 else 0.5,
228 | np.random.rand() * 0.5 + 0.5 if np.random.rand() > 0.5 else 0.5,
229 | np.random.rand() * 0.5 + 0.5 if np.random.rand() > 0.5 else 0.5,
230 | )  # set the light color
231 |
232 | light_obj = bpy.data.objects.new(
233 | name="PointLightObject", object_data=light_data
234 | )
235 | light_obj.location = (
236 | (
237 | np.random.rand() * 3
238 | if np.random.rand() > 0.5
239 | else -1 * np.random.rand() * 2
240 | ),  # roughly -2 to 3
241 | np.random.randint(1, 3),  # 1 or 2
242 | (
243 | np.random.rand() * 2
244 | if np.random.rand() > 0.5
245 | else -1 * np.random.rand() * 2
246 | ),  # roughly -2 to 2
247 | )  # set the light position
248 | bpy.context.scene.collection.objects.link(light_obj)
249 | print(light_data.color, light_obj.location)
250 |
251 | # Set the ambient light (configure the world background via nodes)
252 | bpy.context.scene.world.use_nodes = True
253 | bg_node = bpy.context.scene.world.node_tree.nodes["Background"]
254 | bg_node.inputs[1].default_value = 1.6 + np.random.rand()  # boost the ambient light strength
255 |
256 | # Create the camera
257 | camera_data = bpy.data.cameras.new(name="Camera")
258 |
259 | camera_data.lens = camera_setup["focal"] / 10
260 |
261 | camera_obj = bpy.data.objects.new(name="CameraObject", object_data=camera_data)
262 |
263 | # Get the camera sensor width (in millimetres)
264 | sensor_width = camera_obj.data.sensor_width
265 | print(f"Sensor Width: {sensor_width} mm")
266 |
267 | sensor_fit = camera_obj.data.sensor_fit
268 | print(f"Sensor Fit: {sensor_fit}")
269 |
270 | # Image width, assumed to be in pixels
271 | image_width_pixels = camera_setup["image_width"]
272 |
273 | # Compute the focal length in millimetres
274 | focal_mm = (camera_setup["focal"] / image_width_pixels) * sensor_width
275 |
276 | camera_data.lens = focal_mm
277 |
278 | # Camera position and rotation
279 | camera_obj.location = (0, 0, 0)  # keep the camera in front of the model
280 | camera_obj.rotation_euler = (math.radians(90), 0, 0)  # point the camera at the model
281 | bpy.context.scene.collection.objects.link(camera_obj)
282 |
283 | # Set the active scene camera
284 | bpy.context.scene.camera = camera_obj
285 |
286 | # Configure the render engine (Eevee; see setup_gpu_rendering)
287 | setup_gpu_rendering()
288 |
289 | # Render settings
290 | bpy.context.scene.render.resolution_x = camera_setup[
291 | "image_width"
292 | ]  # render resolution width
293 | bpy.context.scene.render.resolution_y = camera_setup[
294 | "image_height"
295 | ]  # render resolution height
296 | bpy.context.scene.render.film_transparent = True  # transparent background for easier post-processing
297 |
298 | # Set the sample count (adjustable)
299 | bpy.context.scene.cycles.samples = (
300 | 32  # a fairly low sample count; note this only applies to Cycles, while the engine above is Eevee
301 | )
302 |
303 | # Output image path
304 | bpy.context.scene.render.filepath = output_path
305 |
306 | # Render the current scene and write the image
307 | bpy.ops.render.render(write_still=True)
308 |
309 | # Save the render result
310 | bpy.data.images["Render Result"].save_render(filepath=output_path)
311 | print(f"Render saved to {output_path}")
312 |
313 |
314 | if __name__ == "__main__":
315 |
316 | bpy.ops.object.select_all(action="SELECT")
317 | bpy.ops.object.delete(use_global=False)
318 |
319 | import torch
320 |
321 | DEVICE = (
322 | "cuda"
323 | if torch.cuda.is_available()
324 | else "mps" if torch.backends.mps.is_available() else "cpu"
325 | )
326 |
327 | model_configs = {
328 | "vits": {
329 | "encoder": "vits",
330 | "features": 64,
331 | "out_channels": [48, 96, 192, 384],
332 | },
333 | "vitb": {
334 | "encoder": "vitb",
335 | "features": 128,
336 | "out_channels": [96, 192, 384, 768],
337 | },
338 | "vitl": {
339 | "encoder": "vitl",
340 | "features": 256,
341 | "out_channels": [256, 512, 1024, 1024],
342 | },
343 | "vitg": {
344 | "encoder": "vitg",
345 | "features": 384,
346 | "out_channels": [1536, 1536, 1536, 1536],
347 | },
348 | }
349 |
350 | depth_anything = DepthAnythingV2(**model_configs["vitl"]).half()
351 | depth_anything.load_state_dict(
352 | torch.load(f"checkpoints/depth_anything_v2_vitl.pth", map_location="cpu")
353 | )
354 |
355 | depth_anything = depth_anything.to(DEVICE).eval()
356 |
357 | # Input parameters
358 | image_path = "634.jpg"
359 | rendered_image_path = "634_relight.png"
360 |
361 | mesh_object, camera_setup = setup_mesh(image_path, depth_anything)
362 |
363 | setup_scene(camera_setup, mesh_object, rendered_image_path)
364 |
--------------------------------------------------------------------------------
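Note: setup_mesh above back-projects every pixel through a simple pinhole model (focal = 0.58 * image_width, principal point at the image centre), treating depth as the forward Y axis and flipping the vertical axis for Blender. A minimal NumPy sketch of that mapping, with illustrative function and variable names that are not part of the repo:

    import numpy as np

    def backproject_pixels(depth, fx, fy, cx, cy):
        # depth: (H, W) map; returns an (H*W, 3) point cloud in the same
        # convention as setup_mesh: Y is depth, X follows the pixel column,
        # Z is the flipped pixel row.
        h, w = depth.shape
        u, v = np.meshgrid(np.arange(w), np.arange(h))
        X = (u - cx) * depth / fx
        Z = -(v - cy) * depth / fy
        Y = depth
        return np.stack((X, Y, Z), axis=-1).reshape(-1, 3)

    # A flat depth plane just gives a regular grid at Y = 2.
    pts = backproject_pixels(np.full((4, 6), 2.0), fx=0.58 * 6, fy=0.58 * 6, cx=3.0, cy=2.0)
    print(pts.shape)  # (24, 3)
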
/romatch/models/transformer/dinov2.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | # References:
8 | # https://github.com/facebookresearch/dino/blob/main/vision_transformer.py
9 | # https://github.com/rwightman/pytorch-image-models/tree/master/timm/models/vision_transformer.py
10 |
11 | from functools import partial
12 | import math
13 | import logging
14 | from typing import Sequence, Tuple, Union, Callable
15 |
16 | import torch
17 | import torch.nn as nn
18 | import torch.utils.checkpoint
19 | from torch.nn.init import trunc_normal_
20 |
21 | from .layers import Mlp, PatchEmbed, SwiGLUFFNFused, MemEffAttention, NestedTensorBlock as Block
22 |
23 |
24 |
25 | def named_apply(fn: Callable, module: nn.Module, name="", depth_first=True, include_root=False) -> nn.Module:
26 | if not depth_first and include_root:
27 | fn(module=module, name=name)
28 | for child_name, child_module in module.named_children():
29 | child_name = ".".join((name, child_name)) if name else child_name
30 | named_apply(fn=fn, module=child_module, name=child_name, depth_first=depth_first, include_root=True)
31 | if depth_first and include_root:
32 | fn(module=module, name=name)
33 | return module
34 |
35 |
36 | class BlockChunk(nn.ModuleList):
37 | def forward(self, x):
38 | for b in self:
39 | x = b(x)
40 | return x
41 |
42 |
43 | class DinoVisionTransformer(nn.Module):
44 | def __init__(
45 | self,
46 | img_size=224,
47 | patch_size=16,
48 | in_chans=3,
49 | embed_dim=768,
50 | depth=12,
51 | num_heads=12,
52 | mlp_ratio=4.0,
53 | qkv_bias=True,
54 | ffn_bias=True,
55 | proj_bias=True,
56 | drop_path_rate=0.0,
57 | drop_path_uniform=False,
58 | init_values=None, # for layerscale: None or 0 => no layerscale
59 | embed_layer=PatchEmbed,
60 | act_layer=nn.GELU,
61 | block_fn=Block,
62 | ffn_layer="mlp",
63 | block_chunks=1,
64 | ):
65 | """
66 | Args:
67 | img_size (int, tuple): input image size
68 | patch_size (int, tuple): patch size
69 | in_chans (int): number of input channels
70 | embed_dim (int): embedding dimension
71 | depth (int): depth of transformer
72 | num_heads (int): number of attention heads
73 | mlp_ratio (int): ratio of mlp hidden dim to embedding dim
74 | qkv_bias (bool): enable bias for qkv if True
75 | proj_bias (bool): enable bias for proj in attn if True
76 | ffn_bias (bool): enable bias for ffn if True
77 | drop_path_rate (float): stochastic depth rate
78 | drop_path_uniform (bool): apply uniform drop rate across blocks
79 | weight_init (str): weight init scheme
80 | init_values (float): layer-scale init values
81 | embed_layer (nn.Module): patch embedding layer
82 | act_layer (nn.Module): MLP activation layer
83 | block_fn (nn.Module): transformer block class
84 | ffn_layer (str): "mlp", "swiglu", "swiglufused" or "identity"
85 | block_chunks: (int) split block sequence into block_chunks units for FSDP wrap
86 | """
87 | super().__init__()
88 | norm_layer = partial(nn.LayerNorm, eps=1e-6)
89 |
90 | self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models
91 | self.num_tokens = 1
92 | self.n_blocks = depth
93 | self.num_heads = num_heads
94 | self.patch_size = patch_size
95 |
96 | self.patch_embed = embed_layer(img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim)
97 | num_patches = self.patch_embed.num_patches
98 |
99 | self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
100 | self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + self.num_tokens, embed_dim))
101 |
102 | if drop_path_uniform is True:
103 | dpr = [drop_path_rate] * depth
104 | else:
105 | dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule
106 |
107 | if ffn_layer == "mlp":
108 | ffn_layer = Mlp
109 | elif ffn_layer == "swiglufused" or ffn_layer == "swiglu":
110 | ffn_layer = SwiGLUFFNFused
111 | elif ffn_layer == "identity":
112 |
113 | def f(*args, **kwargs):
114 | return nn.Identity()
115 |
116 | ffn_layer = f
117 | else:
118 | raise NotImplementedError
119 |
120 | blocks_list = [
121 | block_fn(
122 | dim=embed_dim,
123 | num_heads=num_heads,
124 | mlp_ratio=mlp_ratio,
125 | qkv_bias=qkv_bias,
126 | proj_bias=proj_bias,
127 | ffn_bias=ffn_bias,
128 | drop_path=dpr[i],
129 | norm_layer=norm_layer,
130 | act_layer=act_layer,
131 | ffn_layer=ffn_layer,
132 | init_values=init_values,
133 | )
134 | for i in range(depth)
135 | ]
136 | if block_chunks > 0:
137 | self.chunked_blocks = True
138 | chunked_blocks = []
139 | chunksize = depth // block_chunks
140 | for i in range(0, depth, chunksize):
141 | # this is to keep the block index consistent if we chunk the block list
142 | chunked_blocks.append([nn.Identity()] * i + blocks_list[i : i + chunksize])
143 | self.blocks = nn.ModuleList([BlockChunk(p) for p in chunked_blocks])
144 | else:
145 | self.chunked_blocks = False
146 | self.blocks = nn.ModuleList(blocks_list)
147 |
148 | self.norm = norm_layer(embed_dim)
149 | self.head = nn.Identity()
150 |
151 | self.mask_token = nn.Parameter(torch.zeros(1, embed_dim))
152 |
153 | self.init_weights()
154 | for param in self.parameters():
155 | param.requires_grad = False
156 |
157 | @property
158 | def device(self):
159 | return self.cls_token.device
160 |
161 | def init_weights(self):
162 | trunc_normal_(self.pos_embed, std=0.02)
163 | nn.init.normal_(self.cls_token, std=1e-6)
164 | named_apply(init_weights_vit_timm, self)
165 |
166 | def interpolate_pos_encoding(self, x, w, h):
167 | previous_dtype = x.dtype
168 | npatch = x.shape[1] - 1
169 | N = self.pos_embed.shape[1] - 1
170 | if npatch == N and w == h:
171 | return self.pos_embed
172 | pos_embed = self.pos_embed.float()
173 | class_pos_embed = pos_embed[:, 0]
174 | patch_pos_embed = pos_embed[:, 1:]
175 | dim = x.shape[-1]
176 | w0 = w // self.patch_size
177 | h0 = h // self.patch_size
178 | # we add a small number to avoid floating point error in the interpolation
179 | # see discussion at https://github.com/facebookresearch/dino/issues/8
180 | w0, h0 = w0 + 0.1, h0 + 0.1
181 |
182 | patch_pos_embed = nn.functional.interpolate(
183 | patch_pos_embed.reshape(1, int(math.sqrt(N)), int(math.sqrt(N)), dim).permute(0, 3, 1, 2),
184 | scale_factor=(w0 / math.sqrt(N), h0 / math.sqrt(N)),
185 | mode="bicubic",
186 | )
187 |
188 | assert int(w0) == patch_pos_embed.shape[-2] and int(h0) == patch_pos_embed.shape[-1]
189 | patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
190 | return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1).to(previous_dtype)
191 |
192 | def prepare_tokens_with_masks(self, x, masks=None):
193 | B, nc, w, h = x.shape
194 | x = self.patch_embed(x)
195 | if masks is not None:
196 | x = torch.where(masks.unsqueeze(-1), self.mask_token.to(x.dtype).unsqueeze(0), x)
197 |
198 | x = torch.cat((self.cls_token.expand(x.shape[0], -1, -1), x), dim=1)
199 | x = x + self.interpolate_pos_encoding(x, w, h)
200 |
201 | return x
202 |
203 | def forward_features_list(self, x_list, masks_list):
204 | x = [self.prepare_tokens_with_masks(x, masks) for x, masks in zip(x_list, masks_list)]
205 | for blk in self.blocks:
206 | x = blk(x)
207 |
208 | all_x = x
209 | output = []
210 | for x, masks in zip(all_x, masks_list):
211 | x_norm = self.norm(x)
212 | output.append(
213 | {
214 | "x_norm_clstoken": x_norm[:, 0],
215 | "x_norm_patchtokens": x_norm[:, 1:],
216 | "x_prenorm": x,
217 | "masks": masks,
218 | }
219 | )
220 | return output
221 |
222 | def forward_features(self, x, masks=None):
223 | if isinstance(x, list):
224 | return self.forward_features_list(x, masks)
225 |
226 | x = self.prepare_tokens_with_masks(x, masks)
227 |
228 | for blk in self.blocks:
229 | x = blk(x)
230 |
231 | x_norm = self.norm(x)
232 | return {
233 | "x_norm_clstoken": x_norm[:, 0],
234 | "x_norm_patchtokens": x_norm[:, 1:],
235 | "x_prenorm": x,
236 | "masks": masks,
237 | }
238 |
239 | def _get_intermediate_layers_not_chunked(self, x, n=1):
240 | x = self.prepare_tokens_with_masks(x)
241 | # If n is an int, take the n last blocks. If it's a list, take them
242 | output, total_block_len = [], len(self.blocks)
243 | blocks_to_take = range(total_block_len - n, total_block_len) if isinstance(n, int) else n
244 | for i, blk in enumerate(self.blocks):
245 | x = blk(x)
246 | if i in blocks_to_take:
247 | output.append(x)
248 | assert len(output) == len(blocks_to_take), f"only {len(output)} / {len(blocks_to_take)} blocks found"
249 | return output
250 |
251 | def _get_intermediate_layers_chunked(self, x, n=1):
252 | x = self.prepare_tokens_with_masks(x)
253 | output, i, total_block_len = [], 0, len(self.blocks[-1])
254 | # If n is an int, take the n last blocks. If it's a list, take them
255 | blocks_to_take = range(total_block_len - n, total_block_len) if isinstance(n, int) else n
256 | for block_chunk in self.blocks:
257 | for blk in block_chunk[i:]: # Passing the nn.Identity()
258 | x = blk(x)
259 | if i in blocks_to_take:
260 | output.append(x)
261 | i += 1
262 | assert len(output) == len(blocks_to_take), f"only {len(output)} / {len(blocks_to_take)} blocks found"
263 | return output
264 |
265 | def get_intermediate_layers(
266 | self,
267 | x: torch.Tensor,
268 | n: Union[int, Sequence] = 1, # Layers or n last layers to take
269 | reshape: bool = False,
270 | return_class_token: bool = False,
271 | norm=True,
272 | ) -> Tuple[Union[torch.Tensor, Tuple[torch.Tensor]]]:
273 | if self.chunked_blocks:
274 | outputs = self._get_intermediate_layers_chunked(x, n)
275 | else:
276 | outputs = self._get_intermediate_layers_not_chunked(x, n)
277 | if norm:
278 | outputs = [self.norm(out) for out in outputs]
279 | class_tokens = [out[:, 0] for out in outputs]
280 | outputs = [out[:, 1:] for out in outputs]
281 | if reshape:
282 | B, _, w, h = x.shape
283 | outputs = [
284 | out.reshape(B, w // self.patch_size, h // self.patch_size, -1).permute(0, 3, 1, 2).contiguous()
285 | for out in outputs
286 | ]
287 | if return_class_token:
288 | return tuple(zip(outputs, class_tokens))
289 | return tuple(outputs)
290 |
291 | def forward(self, *args, is_training=False, **kwargs):
292 | ret = self.forward_features(*args, **kwargs)
293 | if is_training:
294 | return ret
295 | else:
296 | return self.head(ret["x_norm_clstoken"])
297 |
298 |
299 | def init_weights_vit_timm(module: nn.Module, name: str = ""):
300 | """ViT weight initialization, original timm impl (for reproducibility)"""
301 | if isinstance(module, nn.Linear):
302 | trunc_normal_(module.weight, std=0.02)
303 | if module.bias is not None:
304 | nn.init.zeros_(module.bias)
305 |
306 |
307 | def vit_small(patch_size=16, **kwargs):
308 | model = DinoVisionTransformer(
309 | patch_size=patch_size,
310 | embed_dim=384,
311 | depth=12,
312 | num_heads=6,
313 | mlp_ratio=4,
314 | block_fn=partial(Block, attn_class=MemEffAttention),
315 | **kwargs,
316 | )
317 | return model
318 |
319 |
320 | def vit_base(patch_size=16, **kwargs):
321 | model = DinoVisionTransformer(
322 | patch_size=patch_size,
323 | embed_dim=768,
324 | depth=12,
325 | num_heads=12,
326 | mlp_ratio=4,
327 | block_fn=partial(Block, attn_class=MemEffAttention),
328 | **kwargs,
329 | )
330 | return model
331 |
332 |
333 | def vit_large(patch_size=16, **kwargs):
334 | model = DinoVisionTransformer(
335 | patch_size=patch_size,
336 | embed_dim=1024,
337 | depth=24,
338 | num_heads=16,
339 | mlp_ratio=4,
340 | block_fn=partial(Block, attn_class=MemEffAttention),
341 | **kwargs,
342 | )
343 | return model
344 |
345 |
346 | def vit_giant2(patch_size=16, **kwargs):
347 | """
348 | Close to ViT-giant, with embed-dim 1536 and 24 heads => embed-dim per head 64
349 | """
350 | model = DinoVisionTransformer(
351 | patch_size=patch_size,
352 | embed_dim=1536,
353 | depth=40,
354 | num_heads=24,
355 | mlp_ratio=4,
356 | block_fn=partial(Block, attn_class=MemEffAttention),
357 | **kwargs,
358 | )
359 | return model
--------------------------------------------------------------------------------
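Note: a minimal sketch of pulling patch features out of the frozen backbone above, assuming the module is importable as romatch.models.transformer.dinov2 and that the bundled attention layer falls back to plain attention when xFormers is unavailable. The weights below are random; in practice the pretrained DINOv2 checkpoint is loaded.

    import torch
    from romatch.models.transformer.dinov2 import vit_small  # assumed import path

    model = vit_small(patch_size=16).eval()
    x = torch.randn(1, 3, 224, 224)  # spatial size must be divisible by patch_size
    with torch.no_grad():
        (feat,) = model.get_intermediate_layers(x, n=1, reshape=True)
    print(feat.shape)  # torch.Size([1, 384, 14, 14]) -- last-block patch tokens as a feature map
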
/get_data.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import sys
3 | import os
4 | import cv2
5 | import glob
6 | import math
7 | import torch
8 | import torch.nn.functional as F
9 | from torch.utils.data.dataset import Dataset
10 | from torch.utils.data.dataloader import DataLoader, default_collate
11 | from torchvision.utils import save_image
12 | import numpy as np
13 | from random import random
14 | from torchvision import transforms
15 | import imutils
16 |
17 | from warp_utils import (
18 | RGBDRenderer,
19 | image_to_tensor,
20 | disparity_to_tensor,
21 | transformation_from_parameters,
22 | )
23 |
24 |
25 | def resize_and_center_crop(image, disparity):
26 | # Get the image and disparity-map sizes
27 | h, w = image.shape[:2]
28 |
29 | # Length of the shortest edge
30 | shortest_edge = min(h, w)
31 |
32 | # Scale by the shortest edge
33 | if h < w:
34 | new_h = shortest_edge
35 | new_w = int(shortest_edge * (w / h))
36 | else:
37 | new_w = shortest_edge
38 | new_h = int(shortest_edge * (h / w))
39 |
40 | # Resize the image and the disparity map
41 | image_resized = cv2.resize(image, (new_w, new_h))
42 | disparity_resized = cv2.resize(disparity, (new_w, new_h))
43 |
44 | # Compute the crop region so the result is square
45 | crop_size = min(image_resized.shape[:2])  # shortest edge of the resized image as the crop size
46 | start_x = (new_w - crop_size) // 2
47 | start_y = (new_h - crop_size) // 2
48 |
49 | # Crop the image and the disparity map
50 | image_cropped = image_resized[
51 | start_y : start_y + crop_size, start_x : start_x + crop_size
52 | ]
53 | disparity_cropped = disparity_resized[
54 | start_y : start_y + crop_size, start_x : start_x + crop_size
55 | ]
56 |
57 | return image_cropped, disparity_cropped
58 |
59 |
60 | class WarpBackStage1Dataset(Dataset):
61 | def __init__(
62 | self,
63 | data_root,
64 | disp_root,
65 | width=512,
66 | height=512,
67 | device="cuda", # device of mesh renderer
68 | trans_range={"x": 0.4, "y": 0.4, "z": 0.8, "a": 18, "b": 18, "c": 18},
69 | # trans_range={"x": -1, "y": -1, "z": -1, "a": -1, "b": -1, "c": -1},
70 | ):
71 | self.data_root = data_root
72 | self.disp_root = disp_root
73 |
74 | self.renderer = RGBDRenderer(device)
75 | self.width = width
76 | self.height = height
77 | self.device = device
78 | self.trans_range = trans_range
79 | self.image_path_list = [
80 | os.path.join(data_root, img) for img in os.listdir(data_root)
81 | ]
82 | self.img2tensor = transforms.ToTensor()
83 |
84 | def __len__(self):
85 | return len(self.image_path_list)
86 |
87 | def rand_tensor(self, r, l):
88 | if (
89 | r < 0
90 | ):  # we can set a negative value in self.trans_range to disable the random transformation
91 | return torch.zeros((l, 1, 1))
92 | rand = torch.rand((l, 1, 1))
93 | sign = 2 * (torch.randn_like(rand) > 0).float() - 1
94 | return sign * (r / 2 + r / 2 * rand)
95 |
96 | def get_rand_ext(self, bs):
97 | x, y, z = self.trans_range["x"], self.trans_range["y"], self.trans_range["z"]
98 | a, b, c = self.trans_range["a"], self.trans_range["b"], self.trans_range["c"]
99 | cix = self.rand_tensor(x, bs)
100 | ciy = self.rand_tensor(y, bs)
101 | ciz = self.rand_tensor(z, bs)
102 |
103 | aix = self.rand_tensor(math.pi / a, bs)
104 | aiy = self.rand_tensor(math.pi / b, bs)
105 | aiz = self.rand_tensor(math.pi / c, bs)
106 |
107 | axisangle = torch.cat([aix, aiy, aiz], dim=-1) # [b,1,3]
108 | translation = torch.cat([cix, ciy, ciz], dim=-1)
109 |
110 | cam_ext = transformation_from_parameters(axisangle, translation) # [b,4,4]
111 | cam_ext_inv = torch.inverse(cam_ext) # [b,4,4]
112 |
113 | print(axisangle, translation)
114 |
115 | return cam_ext[:, :-1], cam_ext_inv[:, :-1]
116 |
117 | def __getitem__(self, idx):
118 | image_path = self.image_path_list[idx]
119 | image_name = os.path.splitext(os.path.basename(image_path))[0]
120 |
121 | disp_path = os.path.join(self.disp_root, "%s.npy" % image_name)
122 |
123 | image = cv2.imread(image_path, cv2.IMREAD_COLOR)[:, :, ::-1]  # [h,w,3], BGR -> RGB
124 | disp = np.load(disp_path)  # [h,w]
125 |
126 | H, W = image.shape[:2]
127 |
128 | if H > W:
129 | image = imutils.resize(image, width=512)
130 | disp = imutils.resize(disp, width=512)
131 | else:
132 | image = imutils.resize(image, height=512)
133 | disp = imutils.resize(disp, height=512)
134 |
135 | image, disp = resize_and_center_crop(image, disp)
136 |
137 | max_d, min_d = disp.max(), disp.min()
138 | disp = (disp - min_d) / (max_d - min_d)
139 |
140 | image = torch.tensor(image).permute(2, 0, 1) / 255
141 | disp = torch.tensor(disp).unsqueeze(0) + 0.001
142 |
143 | self.focal = 0.45 + np.random.random() * 0.3
144 | # set intrinsics
145 | self.K = torch.tensor(
146 | [[self.focal, 0, 0.5], [0, self.focal, 0.5], [0, 0, 1]]
147 | ).to(self.device)
148 |
149 | image = image.to(self.device).unsqueeze(0).float()
150 | disp = disp.to(self.device).unsqueeze(0).float()
151 | rgbd = torch.cat([image, disp], dim=1) # [b,4,h,w]
152 | b = image.shape[0]
153 |
154 | cam_int = self.K.repeat(b, 1, 1) # [b,3,3]
155 |
156 | # warp to a random novel view
157 | mesh = self.renderer.construct_mesh(
158 | rgbd, cam_int, torch.ones_like(disp), normalize_depth=True
159 | )
160 | cam_ext, cam_ext_inv = self.get_rand_ext(b) # [b,3,4]
161 | cam_ext = cam_ext.to(self.device)
162 | cam_ext_inv = cam_ext_inv.to(self.device)
163 |
164 | warp_image, warp_disp, warp_mask, object_mask = self.renderer.render_mesh(
165 | mesh, cam_int, cam_ext
166 | )
167 | warp_mask = (warp_mask < 0.5).float()
168 |
169 | warp_image = torch.clip(warp_image, 0, 1)
170 |
171 | cam_int[0, :2, :] *= 512
172 |
173 | return {
174 | "rgb": image,
175 | "disp": disp,
176 | "warp_mask": warp_mask,
177 | "warp_rgb": warp_image,
178 | "warp_disp": warp_disp,
179 | "image_name": image_name,
180 | "cam_int": cam_int[0],
181 | "cam_ext": cam_ext[0],
182 | }
183 |
184 |
185 | def setup_seed(seed):
186 | torch.manual_seed(seed)
187 | torch.cuda.manual_seed_all(seed)
188 | np.random.seed(seed)
189 | torch.backends.cudnn.deterministic = True
190 |
191 |
192 | if __name__ == "__main__":
193 |
194 | from diffusers import StableDiffusionInpaintPipeline
195 | import torch
196 | import torchvision
197 | import matplotlib.pyplot as plt
198 |
199 | def project_point_to_3d(x, y, depth, K):
200 | """根据相机内参和深度值,计算像素在相机坐标系中的 3D 坐标"""
201 | inv_K = torch.linalg.inv(K)
202 | pixel = torch.tensor([x, y, 1.0]).to(K.device)
203 | normalized_coords = inv_K @ pixel * depth
204 | return normalized_coords
205 |
206 | def transform_to_another_camera(point_3d, T):
207 | """根据变换矩阵将 3D 点从相机 1 转到相机 2"""
208 | point_3d_homogeneous = torch.cat([point_3d, torch.tensor([1.0]).to(T.device)])
209 | transformed_point = T @ point_3d_homogeneous
210 | return transformed_point[:3]
211 |
212 | def project_to_image_plane(point_3d, K):
213 | """将 3D 点投影到图像平面"""
214 | point_2d_homogeneous = K @ point_3d
215 | point_2d = point_2d_homogeneous[:2] / point_2d_homogeneous[2]
216 | return point_2d
217 |
218 | parser = argparse.ArgumentParser(
219 | formatter_class=argparse.ArgumentDefaultsHelpFormatter
220 | )
221 | parser.add_argument(
222 | "--data_path",
223 | type=str,
224 | default="data/davis/raw_images",
225 | )
226 | parser.add_argument(
227 | "--disp_path",
228 | type=str,
229 | default="data/davis/disps",
230 | )
231 | parser.add_argument("--output_path", type=str, default="data/mixed_datasets/l2m_davis")
232 | opt, _ = parser.parse_known_args()
233 |
234 | # Specify the model path
235 | model_path = "stabilityai/stable-diffusion-2-inpainting"
236 |
237 | # Load the inpainting model
238 | pipe = StableDiffusionInpaintPipeline.from_pretrained(
239 | model_path, torch_dtype=torch.float16
240 | )
241 | pipe.to("cuda")  # move the model to the GPU if one is available
242 | prompt = "a realistic photo"
243 |
244 | setup_seed(0)
245 |
246 | output = opt.output_path
247 | from tqdm import tqdm
248 |
249 | if not os.path.exists(output):
250 | os.makedirs(output, exist_ok=True)
251 | os.mkdir(os.path.join(output, "image1"))
252 | os.mkdir(os.path.join(output, "image2"))
253 | os.mkdir(os.path.join(output, "depth1"))
254 | os.mkdir(os.path.join(output, "depth2"))
255 | os.mkdir(os.path.join(output, "cams"))
256 | os.mkdir(os.path.join(output, "ext"))
257 | os.mkdir(os.path.join(output, "debug"))
258 |
259 | data = WarpBackStage1Dataset(data_root=opt.data_path, disp_root=opt.disp_path)
260 |
261 | for loop in range(2):
262 | for idx in tqdm(range(len(data))):
263 |
264 | batch = data.__getitem__(idx)
265 |
266 | image, disp = batch["rgb"], batch["disp"]
267 | w_image, w_disp = batch["warp_rgb"], batch["warp_disp"]
268 | warp_mask = batch["warp_mask"]
269 |
270 | w_disp = torch.clip(w_disp, 0.01, 100)
271 |
272 | init_image = torchvision.transforms.functional.to_pil_image(w_image[0])
273 | mask_image = torchvision.transforms.functional.to_pil_image(warp_mask[0])
274 | image = torchvision.transforms.functional.to_pil_image(image[0])
275 |
276 | W, H = init_image.size
277 |
278 | inpaint_image = pipe(
279 | prompt=prompt, image=init_image, mask_image=mask_image, height=512, width=512
280 | ).images[0]
281 |
282 | image.save(os.path.join(output, "image1", batch["image_name"] + ".png"))
283 | inpaint_image.save(
284 | os.path.join(output, "image2", batch["image_name"] + ".png")
285 | )
286 |
287 | np.save(
288 | os.path.join(output, "depth1", batch["image_name"] + ".npy"),
289 | 1 / disp.squeeze().cpu().numpy(),
290 | )
291 | np.save(
292 | os.path.join(output, "depth2", batch["image_name"] + ".npy"),
293 | 1 / w_disp.squeeze().cpu().numpy(),
294 | )
295 |
296 | w_depth = 1 / (w_disp + 1e-4) * (1 - warp_mask)
297 |
298 | cam_int = batch["cam_int"].cpu().numpy()
299 | cam_ext = batch["cam_ext"].cpu().numpy()
300 |
301 | cam_ext = np.concatenate(
302 | [cam_ext, np.array([[0.0000, 0.0000, 0.0000, 1.0000]])], 0
303 | )
304 |
305 | np.save(os.path.join(output, "cams", batch["image_name"] + ".npy"), cam_int)
306 | np.save(os.path.join(output, "ext", batch["image_name"] + ".npy"), cam_ext)
307 |
308 | # Visualise the first 20 generated samples
309 | if idx > 19:
310 | continue
311 |
312 | im_A_depth = 1 / disp.squeeze()
313 | im_B_depth = 1 / w_disp.squeeze()
314 |
315 | K1 = batch["cam_int"].float()
316 | K2 = batch["cam_int"].float()
317 | T_1to2 = torch.tensor(cam_ext).cuda().float()
318 |
319 | im_A_cv = np.array(image)
320 | im_B_cv = np.array(inpaint_image)
321 |
322 | # Concatenate the two images for display
323 | im_combined = np.hstack(
324 | (im_A_cv.astype(np.uint8), im_B_cv.astype(np.uint8))
325 | )
326 |
327 | # Sample a subset of pixels to compute correspondences
328 | matches_A = []
329 | matches_B = []
330 |
331 | # Use only pixels with non-zero depth
332 | for y in range(0, im_A_depth.shape[0], 10):  # stride of 10 to reduce computation
333 | for x in range(0, im_A_depth.shape[1], 10):
334 | depth_A = im_A_depth[y, x].item()
335 | if depth_A > 0:  # only process pixels with valid depth
336 | # 3D point in camera 1
337 | point_3d_A = project_point_to_3d(x, y, depth_A, K1)
338 | # Transform into the camera-2 frame
339 | point_3d_B = transform_to_another_camera(
340 | point_3d_A.float(), T_1to2
341 | )
342 | # Project onto camera 2's image plane
343 | point_2d_B = project_to_image_plane(point_3d_B, K2)
344 |
345 | # Store the 2D match
346 | matches_A.append((x, y))
347 | matches_B.append((point_2d_B[0].item(), point_2d_B[1].item()))
348 |
349 |
350 | # Convert to numpy arrays for plotting
351 | matches_A = np.array(matches_A)
352 | matches_B = np.array(matches_B)
353 |
354 | H, W = im_combined.shape[:2]
355 |
356 | selected_index = np.random.choice(range(len(matches_A)), 20, replace=False)
357 |
358 | # Draw the matches and connecting lines
359 | for i in selected_index:
360 |
361 | # Draw matched points on the combined image
362 | x_A = int(matches_A[i, 0])
363 | y_A = int(matches_A[i, 1])
364 | x_B = int(matches_B[i, 0]) + im_A_cv.shape[1]  # offset by the width of image A
365 | y_B = int(matches_B[i, 1])
366 |
367 | if y_B > H or x_B > W or y_B < 0:
368 | continue
369 |
370 | # Draw the matched points
371 | cv2.circle(
372 | im_combined, (x_A, y_A), 5, (0, 255, 0), -1
373 | )  # match point in image A
374 | cv2.circle(
375 | im_combined, (x_B, y_B), 5, (0, 0, 255), -1
376 | )  # match point in image B
377 |
378 | # Draw a line between the matched points
379 | cv2.line(im_combined, (x_A, y_A), (x_B, y_B), (255, 0, 0), 1)
380 |
381 | # Show the stitched image
382 | plt.figure(figsize=(12, 6))
383 | plt.imshow(im_combined)
384 | plt.title("Image A and B with Matches")
385 | plt.axis("off")
386 | plt.savefig(os.path.join(output, "debug", batch["image_name"] + ".png"))
--------------------------------------------------------------------------------
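Note: the debug visualisation above maps pixels from image A into image B by back-projecting with K1 and the depth, applying the relative pose, and re-projecting with K2. The same chain, written as one self-contained helper with illustrative names (not part of the repo):

    import torch

    def reproject(x, y, depth, K1, K2, T_1to2):
        # Pixel (x, y) with depth in camera 1 -> (u, v) in camera 2.
        p_cam1 = torch.linalg.inv(K1) @ torch.tensor([x, y, 1.0]) * depth   # 3D point in camera 1
        p_cam2 = (T_1to2 @ torch.cat([p_cam1, torch.ones(1)]))[:3]          # move into camera 2
        uv = K2 @ p_cam2                                                    # homogeneous image coords
        return uv[:2] / uv[2]

    # Sanity check: an identity pose keeps the pixel where it was.
    K = torch.tensor([[512.0, 0.0, 256.0], [0.0, 512.0, 256.0], [0.0, 0.0, 1.0]])
    print(reproject(100.0, 200.0, 2.0, K, K, torch.eye(4)))  # tensor([100., 200.])
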
/warp_utils.py:
--------------------------------------------------------------------------------
1 | import cv2
2 | from PIL import Image
3 | import torch
4 | import torch.nn.functional as F
5 | from torchvision.utils import save_image
6 | from torchvision import transforms
7 | from pytorch3d.renderer.mesh import rasterize_meshes
8 | from pytorch3d.structures import Meshes
9 | from pytorch3d.ops import interpolate_face_attributes
10 | import numpy as np
11 | from functools import reduce
12 |
13 | def vis_depth_discontinuity(
14 | depth, depth_threshold, vis_diff=False, label=False, mask=None
15 | ):
16 | if label == False:
17 | disp = 1.0 / depth
18 | u_diff = (disp[1:, :] - disp[:-1, :])[:-1, 1:-1]
19 | b_diff = (disp[:-1, :] - disp[1:, :])[1:, 1:-1]
20 | l_diff = (disp[:, 1:] - disp[:, :-1])[1:-1, :-1]
21 | r_diff = (disp[:, :-1] - disp[:, 1:])[1:-1, 1:]
22 | if mask is not None:
23 | u_mask = (mask[1:, :] * mask[:-1, :])[:-1, 1:-1]
24 | b_mask = (mask[:-1, :] * mask[1:, :])[1:, 1:-1]
25 | l_mask = (mask[:, 1:] * mask[:, :-1])[1:-1, :-1]
26 | r_mask = (mask[:, :-1] * mask[:, 1:])[1:-1, 1:]
27 | u_diff = u_diff * u_mask
28 | b_diff = b_diff * b_mask
29 | l_diff = l_diff * l_mask
30 | r_diff = r_diff * r_mask
31 | u_over = (np.abs(u_diff) > depth_threshold).astype(np.float32)
32 | b_over = (np.abs(b_diff) > depth_threshold).astype(np.float32)
33 | l_over = (np.abs(l_diff) > depth_threshold).astype(np.float32)
34 | r_over = (np.abs(r_diff) > depth_threshold).astype(np.float32)
35 | else:
36 | disp = depth
37 | u_diff = (disp[1:, :] * disp[:-1, :])[:-1, 1:-1]
38 | b_diff = (disp[:-1, :] * disp[1:, :])[1:, 1:-1]
39 | l_diff = (disp[:, 1:] * disp[:, :-1])[1:-1, :-1]
40 | r_diff = (disp[:, :-1] * disp[:, 1:])[1:-1, 1:]
41 | if mask is not None:
42 | u_mask = (mask[1:, :] * mask[:-1, :])[:-1, 1:-1]
43 | b_mask = (mask[:-1, :] * mask[1:, :])[1:, 1:-1]
44 | l_mask = (mask[:, 1:] * mask[:, :-1])[1:-1, :-1]
45 | r_mask = (mask[:, :-1] * mask[:, 1:])[1:-1, 1:]
46 | u_diff = u_diff * u_mask
47 | b_diff = b_diff * b_mask
48 | l_diff = l_diff * l_mask
49 | r_diff = r_diff * r_mask
50 | u_over = (np.abs(u_diff) > 0).astype(np.float32)
51 | b_over = (np.abs(b_diff) > 0).astype(np.float32)
52 | l_over = (np.abs(l_diff) > 0).astype(np.float32)
53 | r_over = (np.abs(r_diff) > 0).astype(np.float32)
54 | u_over = np.pad(u_over, 1, mode="constant")
55 | b_over = np.pad(b_over, 1, mode="constant")
56 | l_over = np.pad(l_over, 1, mode="constant")
57 | r_over = np.pad(r_over, 1, mode="constant")
58 | u_diff = np.pad(u_diff, 1, mode="constant")
59 | b_diff = np.pad(b_diff, 1, mode="constant")
60 | l_diff = np.pad(l_diff, 1, mode="constant")
61 | r_diff = np.pad(r_diff, 1, mode="constant")
62 |
63 | if vis_diff:
64 | return [u_over, b_over, l_over, r_over], [u_diff, b_diff, l_diff, r_diff]
65 | else:
66 | return [u_over, b_over, l_over, r_over]
67 |
68 | def rolling_window(a, window, strides):
69 | assert (
70 | len(a.shape) == len(window) == len(strides)
71 | ), "'a', 'window', 'strides' dimension mismatch"
72 | shape_fn = lambda i, w, s: (a.shape[i] - w) // s + 1
73 | shape = [shape_fn(i, w, s) for i, (w, s) in enumerate(zip(window, strides))] + list(
74 | window
75 | )
76 |
77 | def acc_shape(i):
78 | if i + 1 >= len(a.shape):
79 | return 1
80 | else:
81 | return reduce(lambda x, y: x * y, a.shape[i + 1 :])
82 |
83 | _strides = [acc_shape(i) * s * a.itemsize for i, s in enumerate(strides)] + list(
84 | a.strides
85 | )
86 |
87 | return np.lib.stride_tricks.as_strided(a, shape=shape, strides=_strides)
88 |
89 | def bilateral_filter(
90 | depth,
91 | sigma_s,
92 | sigma_r,
93 | window_size,
94 | discontinuity_map=None,
95 | HR=False,
96 | mask=None,
97 | ):
98 |
99 | midpt = window_size // 2
100 | ax = np.arange(-midpt, midpt + 1.0)
101 | xx, yy = np.meshgrid(ax, ax)
102 | if discontinuity_map is not None:
103 | spatial_term = np.exp(-(xx ** 2 + yy ** 2) / (2.0 * sigma_s ** 2))
104 |
105 | # padding
106 | depth = depth[1:-1, 1:-1]
107 | depth = np.pad(depth, ((1, 1), (1, 1)), "edge")
108 | pad_depth = np.pad(depth, (midpt, midpt), "edge")
109 | if discontinuity_map is not None:
110 | discontinuity_map = discontinuity_map[1:-1, 1:-1]
111 | discontinuity_map = np.pad(discontinuity_map, ((1, 1), (1, 1)), "edge")
112 | pad_discontinuity_map = np.pad(discontinuity_map, (midpt, midpt), "edge")
113 | pad_discontinuity_hole = 1 - pad_discontinuity_map
114 | # filtering
115 | output = depth.copy()
116 | pad_depth_patches = rolling_window(pad_depth, [window_size, window_size], [1, 1])
117 | if discontinuity_map is not None:
118 | pad_discontinuity_patches = rolling_window(
119 | pad_discontinuity_map, [window_size, window_size], [1, 1]
120 | )
121 | pad_discontinuity_hole_patches = rolling_window(
122 | pad_discontinuity_hole, [window_size, window_size], [1, 1]
123 | )
124 |
125 | if mask is not None:
126 | pad_mask = np.pad(mask, (midpt, midpt), "constant")
127 | pad_mask_patches = rolling_window(pad_mask, [window_size, window_size], [1, 1])
128 | from itertools import product
129 |
130 | if discontinuity_map is not None:
131 | pH, pW = pad_depth_patches.shape[:2]
132 | for pi in range(pH):
133 | for pj in range(pW):
134 | if mask is not None and mask[pi, pj] == 0:
135 | continue
136 | if discontinuity_map is not None:
137 | if bool(pad_discontinuity_patches[pi, pj].any()) is False:
138 | continue
139 | discontinuity_patch = pad_discontinuity_patches[pi, pj]
140 | discontinuity_holes = pad_discontinuity_hole_patches[pi, pj]
141 | depth_patch = pad_depth_patches[pi, pj]
142 | depth_order = depth_patch.ravel().argsort()
143 | patch_midpt = depth_patch[window_size // 2, window_size // 2]
144 | if discontinuity_map is not None:
145 | coef = discontinuity_holes.astype(np.float32)
146 | if mask is not None:
147 | coef = coef * pad_mask_patches[pi, pj]
148 | else:
149 | range_term = np.exp(
150 | -((depth_patch - patch_midpt) ** 2) / (2.0 * sigma_r ** 2)
151 | )
152 | coef = spatial_term * range_term
153 | if coef.max() == 0:
154 | output[pi, pj] = patch_midpt
155 | continue
156 | if discontinuity_map is not None and (coef.max() == 0):
157 | output[pi, pj] = patch_midpt
158 | else:
159 | coef = coef / (coef.sum())
160 | coef_order = coef.ravel()[depth_order]
161 | cum_coef = np.cumsum(coef_order)
162 | ind = np.digitize(0.5, cum_coef)
163 | output[pi, pj] = depth_patch.ravel()[depth_order][ind]
164 | else:
165 | pH, pW = pad_depth_patches.shape[:2]
166 | for pi in range(pH):
167 | for pj in range(pW):
168 | if discontinuity_map is not None:
169 | if (
170 | pad_discontinuity_patches[pi, pj][
171 | window_size // 2, window_size // 2
172 | ]
173 | == 1
174 | ):
175 | continue
176 | discontinuity_patch = pad_discontinuity_patches[pi, pj]
177 | discontinuity_holes = 1.0 - discontinuity_patch
178 | depth_patch = pad_depth_patches[pi, pj]
179 | depth_order = depth_patch.ravel().argsort()
180 | patch_midpt = depth_patch[window_size // 2, window_size // 2]
181 | range_term = np.exp(
182 | -((depth_patch - patch_midpt) ** 2) / (2.0 * sigma_r ** 2)
183 | )
184 | if discontinuity_map is not None:
185 | coef = spatial_term * range_term * discontinuity_holes
186 | else:
187 | coef = spatial_term * range_term
188 | if coef.sum() == 0:
189 | output[pi, pj] = patch_midpt
190 | continue
191 | if discontinuity_map is not None and (coef.sum() == 0):
192 | output[pi, pj] = patch_midpt
193 | else:
194 | coef = coef / (coef.sum())
195 | coef_order = coef.ravel()[depth_order]
196 | cum_coef = np.cumsum(coef_order)
197 | ind = np.digitize(0.5, cum_coef)
198 | output[pi, pj] = depth_patch.ravel()[depth_order][ind]
199 |
200 | return output
201 |
202 | class RGBDRenderer:
203 | def __init__(self, device):
204 | self.device = device
205 | self.eps = 0.1
206 | self.near_z = 1e-2
207 | self.far_z = 1e4
208 |
209 | def render_mesh(self, mesh_dict, cam_int, cam_ext):
210 | vertice = mesh_dict["vertice"] # [b,h*w,3]
211 | faces = mesh_dict["faces"] # [b,nface,3]
212 | attributes = mesh_dict["attributes"] # [b,h*w,5]
213 | h, w = mesh_dict["size"]
214 |
215 | ############
216 | # to NDC space
217 | vertice_homo = self.lift_to_homo(vertice) # [b,h*w,4]
218 | # [b,1,3,4] x [b,h*w,4,1] = [b,h*w,3,1]
219 | vertice_world = torch.matmul(cam_ext.unsqueeze(1), vertice_homo[..., None]).squeeze(-1) # [b,h*w,3]
220 | vertice_depth = vertice_world[..., -1:] # [b,h*w,1]
221 | attributes = torch.cat([attributes, vertice_depth], dim=-1) # [b,h*w,6]
222 | # [b,1,3,3] x [b,h*w,3,1] = [b,h*w,3,1]
223 | vertice_world_homo = self.lift_to_homo(vertice_world)
224 | persp = self.get_perspective_from_intrinsic(cam_int) # [b,4,4]
225 |
226 | # [b,1,4,4] x [b,h*w,4,1] = [b,h*w,4,1]
227 | vertice_ndc = torch.matmul(persp.unsqueeze(1), vertice_world_homo[..., None]).squeeze(-1) # [b,h*w,4]
228 | vertice_ndc = vertice_ndc[..., :-1] / vertice_ndc[..., -1:]
229 | vertice_ndc[..., :-1] *= -1
230 | vertice_ndc[..., 0] *= w / h
231 |
232 | ############
233 | # render
234 | mesh = Meshes(vertice_ndc, faces)
235 | pix_to_face, _, bary_coords, _ = rasterize_meshes(mesh, (h, w), faces_per_pixel=1, blur_radius=1e-6) # [b,h,w,1] [b,h,w,1,3]
236 |
237 | b, nf, _ = faces.size()
238 | faces = faces.reshape(b, nf * 3, 1).repeat(1, 1, 6) # [b,3f,6]
239 | face_attributes = torch.gather(attributes, dim=1, index=faces) # [b,3f,6]
240 | face_attributes = face_attributes.reshape(b * nf, 3, 6)
241 | output = interpolate_face_attributes(pix_to_face, bary_coords, face_attributes)
242 | output = output.squeeze(-2).permute(0, 3, 1, 2)
243 |
244 | render = output[:, :3]
245 | mask = output[:, 3:4]
246 | object_mask = output[:, 4:5]
247 | disparity = torch.reciprocal(output[:, 5:] + 1e-4)
248 |
249 | return render * mask, disparity * mask, mask, object_mask
250 |
251 | def construct_mesh(self, rgbd, cam_int, obj_mask, normalize_depth=False):
252 | b, _, h, w = rgbd.size()
253 |
254 | ############
255 | # get pixel coordinates
256 | pixel_2d = self.get_screen_pixel_coord(h, w) # [1,h,w,2]
257 | pixel_2d_homo = self.lift_to_homo(pixel_2d) # [1,h,w,3]
258 |
259 | ############
260 | # project pixels to 3D space
261 | rgbd = rgbd.permute(0, 2, 3, 1) # [b,h,w,4]
262 | disparity = rgbd[..., -1:] # [b,h,w,1]
263 | depth = torch.reciprocal(disparity + 1e-4) # [b,h,w,1]
264 | obj_mask = obj_mask.permute(0, 2, 3, 1).to(depth.device)
265 | # In [2]: depth.max()
266 | # Out[2]: 3.0927802771530017
267 |
268 | # In [3]: depth.min()
269 | # Out[3]: 1.466965406649775
270 | cam_int_inv = torch.inverse(cam_int) # [b,3,3]
271 | # [b,1,1,3,3] x [1,h,w,3,1] = [b,h,w,3,1]
272 | pixel_3d = torch.matmul(cam_int_inv[:, None, None, :, :], pixel_2d_homo[..., None]).squeeze(-1) # [b,h,w,3]
273 |
274 | pixel_3d = pixel_3d * depth # [b,h,w,3]
275 | vertice = pixel_3d.reshape(b, h * w, 3) # [b,h*w,3]
276 | ############
277 | # construct faces
278 | faces = self.get_faces(h, w) # [1,nface,3]
279 | faces = faces.repeat(b, 1, 1).long() # [b,nface,3]
280 |
281 | ############
282 | # compute attributes
283 | attr_color = rgbd[..., :-1].reshape(b, h * w, 3) # [b,h*w,3]
284 | attr_object = obj_mask.reshape(b, h * w, 1).to(attr_color.device) # [b,h*w,1]
285 | attr_mask = self.get_visible_mask(disparity, alpha_threshold=0.1).reshape(b, h * w, 1) # [b,h*w,1]
286 | attr = torch.cat([attr_color, attr_mask, attr_object], dim=-1) # [b,h*w,5]
287 | mesh_dict = {
288 | "vertice": vertice,
289 | "faces": faces,
290 | "attributes": attr,
291 | "size": [h, w],
292 | }
293 | return mesh_dict
294 |
295 | def get_screen_pixel_coord(self, h, w):
296 | '''
297 | get normalized pixel coordinates on the screen
298 | x to the right, y down
299 |
300 | e.g.
301 | [0,0][1,0][2,0]
302 | [0,1][1,1][2,1]
303 | output:
304 | pixel_coord: [1,h,w,2]
305 | '''
306 | x = torch.arange(w).to(self.device) # [w]
307 | y = torch.arange(h).to(self.device) # [h]
308 | x = (x + 0.5) / w
309 | y = (y + 0.5) / h
310 | x = x[None, None, ..., None].repeat(1, h, 1, 1) # [1,h,w,1]
311 | y = y[None, ..., None, None].repeat(1, 1, w, 1) # [1,h,w,1]
312 | pixel_coord = torch.cat([x, y], dim=-1) # [1,h,w,2]
313 | return pixel_coord
314 |
315 | def lift_to_homo(self, coord):
316 | '''
317 | return the homo version of coord
318 | input: coord [..., k]
319 | output: homo_coord [...,k+1]
320 | '''
321 | ones = torch.ones_like(coord[..., -1:])
322 | return torch.cat([coord, ones], dim=-1)
323 |
324 | def get_faces(self, h, w):
325 | x = torch.arange(w - 1).to(self.device) # [w-1]
326 | y = torch.arange(h - 1).to(self.device) # [h-1]
327 | x = x[None, None, ..., None].repeat(1, h - 1, 1, 1) # [1,h-1,w-1,1]
328 | y = y[None, ..., None, None].repeat(1, 1, w - 1, 1) # [1,h-1,w-1,1]
329 |
330 | tl = y * w + x
331 | tr = y * w + x + 1
332 | bl = (y + 1) * w + x
333 | br = (y + 1) * w + x + 1
334 |
335 | faces_l = torch.cat([tl, bl, br], dim=-1).reshape(1, -1, 3) # [1,(h-1)(w-1),3]
336 | faces_r = torch.cat([br, tr, tl], dim=-1).reshape(1, -1, 3) # [1,(h-1)(w-1),3]
337 |
338 | return torch.cat([faces_l, faces_r], dim=1) # [1,nface,3]
339 |
340 | def get_visible_mask(self, disparity, beta=10, alpha_threshold=0.3):
341 | b, h, w, _ = disparity.size()
342 | disparity = disparity.reshape(b, 1, h, w) # [b,1,h,w]
343 | kernel_x = torch.tensor([[-1, 0, 1], [-2, 0, 2], [-1, 0, 1]]).unsqueeze(0).unsqueeze(0).float().to(self.device)
344 | kernel_y = torch.tensor([[-1, -2, -1], [0, 0, 0], [1, 2, 1]]).unsqueeze(0).unsqueeze(0).float().to(self.device)
345 | sobel_x = F.conv2d(disparity, kernel_x, padding=(1, 1)) # [b,1,h,w]
346 | sobel_y = F.conv2d(disparity, kernel_y, padding=(1, 1)) # [b,1,h,w]
347 | sobel_mag = torch.sqrt(sobel_x ** 2 + sobel_y ** 2).reshape(b, h, w, 1) # [b,h,w,1]
348 | alpha = torch.exp(-1.0 * beta * sobel_mag) # [b,h,w,1]
349 | vis_mask = torch.greater(alpha, alpha_threshold).float()
350 | return vis_mask
351 |
352 | def get_perspective_from_intrinsic(self, cam_int):
353 | '''
354 | input:
355 | cam_int: [b,3,3]
356 |
357 | output:
358 | persp: [b,4,4]
359 | '''
360 | fx, fy = cam_int[:, 0, 0], cam_int[:, 1, 1] # [b]
361 | cx, cy = cam_int[:, 0, 2], cam_int[:, 1, 2] # [b]
362 |
363 | one = torch.ones_like(cx) # [b]
364 | zero = torch.zeros_like(cx) # [b]
365 |
366 | near_z, far_z = self.near_z * one, self.far_z * one
367 | a = (near_z + far_z) / (far_z - near_z)
368 | b = -2.0 * near_z * far_z / (far_z - near_z)
369 |
370 | matrix = [[2.0 * fx, zero, 2.0 * cx - 1.0, zero],
371 | [zero, 2.0 * fy, 2.0 * cy - 1.0, zero],
372 | [zero, zero, a, b],
373 | [zero, zero, one, zero]]
374 | # -> [[b,4],[b,4],[b,4],[b,4]] -> [b,4,4]
375 | persp = torch.stack([torch.stack(row, dim=-1) for row in matrix], dim=-2) # [b,4,4]
376 | # print(fx, cx, cy, a, b)
377 | return persp
378 |
379 |
380 | #######################
381 | # some helper I/O functions
382 | #######################
383 | def image_to_tensor(img_path, unsqueeze=True):
384 | rgb = transforms.ToTensor()(Image.open(img_path))
385 | if unsqueeze:
386 | rgb = rgb.unsqueeze(0)
387 | return rgb
388 |
389 | def sparse_bilateral_filtering(
390 | depth,
391 | filter_size,
392 | sigma_r=0.5,
393 | sigma_s=4.0,
394 | depth_threshold=0.04,
395 | HR=False,
396 | mask=None,
397 | num_iter=None,
398 | ):
399 |
400 | save_discontinuities = []
401 | vis_depth = depth.copy()
402 | for i in range(num_iter):
403 | u_over, b_over, l_over, r_over = vis_depth_discontinuity(
404 | vis_depth, depth_threshold, mask=mask
405 | )
406 |
407 | discontinuity_map = (u_over + b_over + l_over + r_over).clip(0.0, 1.0)
408 | discontinuity_map[depth == 0] = 1
409 | save_discontinuities.append(discontinuity_map)
410 | if mask is not None:
411 | discontinuity_map[mask == 0] = 0
412 | vis_depth = bilateral_filter(
413 | vis_depth,
414 | sigma_r=sigma_r,
415 | sigma_s=sigma_s,
416 | discontinuity_map=discontinuity_map,
417 | HR=HR,
418 | mask=mask,
419 | window_size=filter_size[i],
420 | )
421 |
422 | return vis_depth
423 |
424 | def disparity_to_tensor(disp_path, unsqueeze=True):
425 | disp = cv2.imread(disp_path, -1) / (2 ** 16 - 1)
426 | disp = sparse_bilateral_filtering(disp + 1e-4, filter_size=[5, 5], num_iter=2)
427 | disp = torch.from_numpy(disp)[None, ...]
428 | if unsqueeze:
429 | disp = disp.unsqueeze(0)
430 | return disp.float()
431 |
432 |
433 | #######################
434 | # some helper geometry functions
435 | # adapt from https://github.com/mattpoggi/depthstillation
436 | #######################
437 | def transformation_from_parameters(axisangle, translation, invert=False):
438 | R = rot_from_axisangle(axisangle)
439 | t = translation.clone()
440 |
441 | if invert:
442 | R = R.transpose(1, 2)
443 | t *= -1
444 |
445 | T = get_translation_matrix(t)
446 |
447 | if invert:
448 | M = torch.matmul(R, T)
449 | else:
450 | M = torch.matmul(T, R)
451 |
452 | return M
453 |
454 |
455 | def get_translation_matrix(translation_vector):
456 | T = torch.zeros(translation_vector.shape[0], 4, 4).to(device=translation_vector.device)
457 | t = translation_vector.contiguous().view(-1, 3, 1)
458 | T[:, 0, 0] = 1
459 | T[:, 1, 1] = 1
460 | T[:, 2, 2] = 1
461 | T[:, 3, 3] = 1
462 | T[:, :3, 3, None] = t
463 | return T
464 |
465 |
466 | def rot_from_axisangle(vec):
467 | angle = torch.norm(vec, 2, 2, True)
468 | axis = vec / (angle + 1e-7)
469 |
470 | ca = torch.cos(angle)
471 | sa = torch.sin(angle)
472 | C = 1 - ca
473 |
474 | x = axis[..., 0].unsqueeze(1)
475 | y = axis[..., 1].unsqueeze(1)
476 | z = axis[..., 2].unsqueeze(1)
477 |
478 | xs = x * sa
479 | ys = y * sa
480 | zs = z * sa
481 | xC = x * C
482 | yC = y * C
483 | zC = z * C
484 | xyC = x * yC
485 | yzC = y * zC
486 | zxC = z * xC
487 |
488 | rot = torch.zeros((vec.shape[0], 4, 4)).to(device=vec.device)
489 |
490 | rot[:, 0, 0] = torch.squeeze(x * xC + ca)
491 | rot[:, 0, 1] = torch.squeeze(xyC - zs)
492 | rot[:, 0, 2] = torch.squeeze(zxC + ys)
493 | rot[:, 1, 0] = torch.squeeze(xyC + zs)
494 | rot[:, 1, 1] = torch.squeeze(y * yC + ca)
495 | rot[:, 1, 2] = torch.squeeze(yzC - xs)
496 | rot[:, 2, 0] = torch.squeeze(zxC - ys)
497 | rot[:, 2, 1] = torch.squeeze(yzC + xs)
498 | rot[:, 2, 2] = torch.squeeze(z * zC + ca)
499 | rot[:, 3, 3] = 1
500 |
501 | return rot
502 |
503 |
--------------------------------------------------------------------------------
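Note: a minimal end-to-end sketch of RGBDRenderer above: build a mesh from an RGB-D tensor with normalised intrinsics, then render it from a slightly translated viewpoint. It assumes a CUDA device and PyTorch3D are available; the tensor shapes follow the shape comments in the class, while the project itself feeds real images and predicted disparity:

    import torch
    from warp_utils import RGBDRenderer, transformation_from_parameters

    device = "cuda"
    renderer = RGBDRenderer(device)

    b, h, w = 1, 256, 256
    rgb = torch.rand(b, 3, h, w, device=device)
    disp = torch.rand(b, 1, h, w, device=device) * 0.9 + 0.1          # keep disparity away from zero
    rgbd = torch.cat([rgb, disp], dim=1)                              # [b,4,h,w]
    K = torch.tensor([[0.58, 0.0, 0.5], [0.0, 0.58, 0.5], [0.0, 0.0, 1.0]], device=device).repeat(b, 1, 1)

    mesh = renderer.construct_mesh(rgbd, K, torch.ones_like(disp))
    # Small sideways translation, no rotation; keep the top 3x4 of the 4x4 extrinsic.
    ext = transformation_from_parameters(torch.zeros(b, 1, 3), torch.tensor([[[0.05, 0.0, 0.0]]]))[:, :3].to(device)
    warped_rgb, warped_disp, mask, obj_mask = renderer.render_mesh(mesh, K, ext)
    print(warped_rgb.shape, mask.shape)  # [b,3,h,w], [b,1,h,w]
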
/romatch/utils/utils.py:
--------------------------------------------------------------------------------
1 | import warnings
2 | import numpy as np
3 | import cv2
4 | import math
5 | import torch
6 | from torchvision import transforms
7 | from torchvision.transforms.functional import InterpolationMode
8 | import torch.nn.functional as F
9 | from PIL import Image
10 | import kornia
11 |
12 | def recover_pose(E, kpts0, kpts1, K0, K1, mask):
13 | best_num_inliers = 0
14 | K0inv = np.linalg.inv(K0[:2,:2])
15 | K1inv = np.linalg.inv(K1[:2,:2])
16 |
17 | kpts0_n = (K0inv @ (kpts0-K0[None,:2,2]).T).T
18 | kpts1_n = (K1inv @ (kpts1-K1[None,:2,2]).T).T
19 |
20 | for _E in np.split(E, len(E) / 3):
21 | n, R, t, _ = cv2.recoverPose(_E, kpts0_n, kpts1_n, np.eye(3), 1e9, mask=mask)
22 | if n > best_num_inliers:
23 | best_num_inliers = n
24 | ret = (R, t, mask.ravel() > 0)
25 | return ret
26 |
27 |
28 |
29 | # Code taken from https://github.com/PruneTruong/DenseMatching/blob/40c29a6b5c35e86b9509e65ab0cd12553d998e5f/validation/utils_pose_estimation.py
30 | # --- GEOMETRY ---
31 | def estimate_pose(kpts0, kpts1, K0, K1, norm_thresh, conf=0.99999):
32 | if len(kpts0) < 5:
33 | return None
34 | K0inv = np.linalg.inv(K0[:2,:2])
35 | K1inv = np.linalg.inv(K1[:2,:2])
36 |
37 | kpts0 = (K0inv @ (kpts0-K0[None,:2,2]).T).T
38 | kpts1 = (K1inv @ (kpts1-K1[None,:2,2]).T).T
39 | E, mask = cv2.findEssentialMat(
40 | kpts0, kpts1, np.eye(3), threshold=norm_thresh, prob=conf
41 | )
42 |
43 | ret = None
44 | if E is not None:
45 | best_num_inliers = 0
46 |
47 | for _E in np.split(E, len(E) / 3):
48 | n, R, t, _ = cv2.recoverPose(_E, kpts0, kpts1, np.eye(3), 1e9, mask=mask)
49 | if n > best_num_inliers:
50 | best_num_inliers = n
51 | ret = (R, t, mask.ravel() > 0)
52 | return ret
53 |
54 | def estimate_pose_uncalibrated(kpts0, kpts1, K0, K1, norm_thresh, conf=0.99999):
55 | if len(kpts0) < 5:
56 | return None
57 | method = cv2.USAC_ACCURATE
58 | F, mask = cv2.findFundamentalMat(
59 | kpts0, kpts1, ransacReprojThreshold=norm_thresh, confidence=conf, method=method, maxIters=10000
60 | )
61 | E = K1.T@F@K0
62 | ret = None
63 | if E is not None:
64 | best_num_inliers = 0
65 | K0inv = np.linalg.inv(K0[:2,:2])
66 | K1inv = np.linalg.inv(K1[:2,:2])
67 |
68 | kpts0_n = (K0inv @ (kpts0-K0[None,:2,2]).T).T
69 | kpts1_n = (K1inv @ (kpts1-K1[None,:2,2]).T).T
70 |
71 | for _E in np.split(E, len(E) / 3):
72 | n, R, t, _ = cv2.recoverPose(_E, kpts0_n, kpts1_n, np.eye(3), 1e9, mask=mask)
73 | if n > best_num_inliers:
74 | best_num_inliers = n
75 | ret = (R, t, mask.ravel() > 0)
76 | return ret
77 |
78 | def unnormalize_coords(x_n,h,w):
79 | x = torch.stack(
80 | (w * (x_n[..., 0] + 1) / 2, h * (x_n[..., 1] + 1) / 2), dim=-1
81 | ) # [-1+1/h, 1-1/h] -> [0.5, h-0.5]
82 | return x
83 |
84 |
85 | def rotate_intrinsic(K, n):
86 | base_rot = np.array([[0, 1, 0], [-1, 0, 0], [0, 0, 1]])
87 | rot = np.linalg.matrix_power(base_rot, n)
88 | return rot @ K
89 |
90 |
91 | def rotate_pose_inplane(i_T_w, rot):
92 | rotation_matrices = [
93 | np.array(
94 | [
95 | [np.cos(r), -np.sin(r), 0.0, 0.0],
96 | [np.sin(r), np.cos(r), 0.0, 0.0],
97 | [0.0, 0.0, 1.0, 0.0],
98 | [0.0, 0.0, 0.0, 1.0],
99 | ],
100 | dtype=np.float32,
101 | )
102 | for r in [np.deg2rad(d) for d in (0, 270, 180, 90)]
103 | ]
104 | return np.dot(rotation_matrices[rot], i_T_w)
105 |
106 |
107 | def scale_intrinsics(K, scales):
108 | scales = np.diag([1.0 / scales[0], 1.0 / scales[1], 1.0])
109 | return np.dot(scales, K)
110 |
111 |
112 | def to_homogeneous(points):
113 | return np.concatenate([points, np.ones_like(points[:, :1])], axis=-1)
114 |
115 |
116 | def angle_error_mat(R1, R2):
117 | cos = (np.trace(np.dot(R1.T, R2)) - 1) / 2
118 | cos = np.clip(cos, -1.0, 1.0)  # numerical errors can make it out of bounds
119 | return np.rad2deg(np.abs(np.arccos(cos)))
120 |
121 |
122 | def angle_error_vec(v1, v2):
123 | n = np.linalg.norm(v1) * np.linalg.norm(v2)
124 | return np.rad2deg(np.arccos(np.clip(np.dot(v1, v2) / n, -1.0, 1.0)))
125 |
126 |
127 | def compute_pose_error(T_0to1, R, t):
128 | R_gt = T_0to1[:3, :3]
129 | t_gt = T_0to1[:3, 3]
130 | error_t = angle_error_vec(t.squeeze(), t_gt)
131 | error_t = np.minimum(error_t, 180 - error_t) # ambiguity of E estimation
132 | error_R = angle_error_mat(R, R_gt)
133 | return error_t, error_R
134 |
135 |
136 | def pose_auc(errors, thresholds):
137 | sort_idx = np.argsort(errors)
138 | errors = np.array(errors.copy())[sort_idx]
139 | recall = (np.arange(len(errors)) + 1) / len(errors)
140 | errors = np.r_[0.0, errors]
141 | recall = np.r_[0.0, recall]
142 | aucs = []
143 | for t in thresholds:
144 | last_index = np.searchsorted(errors, t)
145 | r = np.r_[recall[:last_index], recall[last_index - 1]]
146 | e = np.r_[errors[:last_index], t]
147 | aucs.append(np.trapz(r, x=e) / t)
148 | return aucs
149 |
150 |
151 | # From Patch2Pix https://github.com/GrumpyZhou/patch2pix
152 | def get_depth_tuple_transform_ops_nearest_exact(resize=None):
153 | ops = []
154 | if resize:
155 | ops.append(TupleResizeNearestExact(resize))
156 | return TupleCompose(ops)
157 |
158 | def get_depth_tuple_transform_ops(resize=None, normalize=True, unscale=False):
159 | ops = []
160 | if resize:
161 | ops.append(TupleResize(resize, mode=InterpolationMode.BILINEAR))
162 | return TupleCompose(ops)
163 |
164 |
165 | def get_tuple_transform_ops(resize=None, normalize=True, unscale=False, clahe = False, colorjiggle_params = None):
166 | ops = []
167 | if resize:
168 | ops.append(TupleResize(resize))
169 | ops.append(TupleToTensorScaled())
170 | if normalize:
171 | ops.append(
172 | TupleNormalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
173 | ) # Imagenet mean/std
174 | return TupleCompose(ops)
175 |
176 | class ToTensorScaled(object):
177 | """Convert a RGB PIL Image to a CHW ordered Tensor, scale the range to [0, 1]"""
178 |
179 | def __call__(self, im):
180 | if not isinstance(im, torch.Tensor):
181 | im = np.array(im, dtype=np.float32).transpose((2, 0, 1))
182 | im /= 255.0
183 | return torch.from_numpy(im)
184 | else:
185 | return im
186 |
187 | def __repr__(self):
188 | return "ToTensorScaled(./255)"
189 |
190 |
191 | class TupleToTensorScaled(object):
192 | def __init__(self):
193 | self.to_tensor = ToTensorScaled()
194 |
195 | def __call__(self, im_tuple):
196 | return [self.to_tensor(im) for im in im_tuple]
197 |
198 | def __repr__(self):
199 | return "TupleToTensorScaled(./255)"
200 |
201 |
202 | class ToTensorUnscaled(object):
203 | """Convert a RGB PIL Image to a CHW ordered Tensor"""
204 |
205 | def __call__(self, im):
206 | return torch.from_numpy(np.array(im, dtype=np.float32).transpose((2, 0, 1)))
207 |
208 | def __repr__(self):
209 | return "ToTensorUnscaled()"
210 |
211 |
212 | class TupleToTensorUnscaled(object):
213 | """Convert a RGB PIL Image to a CHW ordered Tensor"""
214 |
215 | def __init__(self):
216 | self.to_tensor = ToTensorUnscaled()
217 |
218 | def __call__(self, im_tuple):
219 | return [self.to_tensor(im) for im in im_tuple]
220 |
221 | def __repr__(self):
222 | return "TupleToTensorUnscaled()"
223 |
224 | class TupleResizeNearestExact:
225 | def __init__(self, size):
226 | self.size = size
227 | def __call__(self, im_tuple):
228 | return [F.interpolate(im, size = self.size, mode = 'nearest-exact') for im in im_tuple]
229 |
230 | def __repr__(self):
231 | return "TupleResizeNearestExact(size={})".format(self.size)
232 |
233 |
234 | class TupleResize(object):
235 | def __init__(self, size, mode=InterpolationMode.BICUBIC):
236 | self.size = size
237 | self.resize = transforms.Resize(size, mode)
238 | def __call__(self, im_tuple):
239 | return [self.resize(im) for im in im_tuple]
240 |
241 | def __repr__(self):
242 | return "TupleResize(size={})".format(self.size)
243 |
244 | class Normalize:
245 | def __call__(self,im):
246 | mean = im.mean(dim=(1,2), keepdims=True)
247 | std = im.std(dim=(1,2), keepdims=True)
248 | return (im-mean)/std
249 |
250 |
251 | class TupleNormalize(object):
252 | def __init__(self, mean, std):
253 | self.mean = mean
254 | self.std = std
255 | self.normalize = transforms.Normalize(mean=mean, std=std)
256 |
257 | def __call__(self, im_tuple):
258 | c,h,w = im_tuple[0].shape
259 | if c > 3:
260 | warnings.warn(f"Number of channels c={c} > 3, assuming first 3 are rgb")
261 | return [self.normalize(im[:3]) for im in im_tuple]
262 |
263 | def __repr__(self):
264 | return "TupleNormalize(mean={}, std={})".format(self.mean, self.std)
265 |
266 |
267 | class TupleCompose(object):
268 | def __init__(self, transforms):
269 | self.transforms = transforms
270 |
271 | def __call__(self, im_tuple):
272 | for t in self.transforms:
273 | im_tuple = t(im_tuple)
274 | return im_tuple
275 |
276 | def __repr__(self):
277 | format_string = self.__class__.__name__ + "("
278 | for t in self.transforms:
279 | format_string += "\n"
280 | format_string += " {0}".format(t)
281 | format_string += "\n)"
282 | return format_string
283 |
284 | @torch.no_grad()
285 | def cls_to_flow(cls, deterministic_sampling = True):
286 | B,C,H,W = cls.shape
287 | device = cls.device
288 | res = round(math.sqrt(C))
289 | G = torch.meshgrid(
290 | *[torch.linspace(-1+1/res, 1-1/res, steps = res, device = device) for _ in range(2)],
291 | indexing = 'ij'
292 | )
293 | G = torch.stack([G[1],G[0]],dim=-1).reshape(C,2)
294 | if deterministic_sampling:
295 | sampled_cls = cls.max(dim=1).indices
296 | else:
297 | sampled_cls = torch.multinomial(cls.permute(0,2,3,1).reshape(B*H*W,C).softmax(dim=-1), 1).reshape(B,H,W)
298 | flow = G[sampled_cls]
299 | return flow
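# Shape sketch for cls_to_flow (values are assumptions for illustration): the C = res**2
# classes per pixel index a res x res grid of normalized (x, y) coordinates in (-1, 1),
# so selecting a class (argmax or multinomial sampling) selects a coarse match location.
#
#   cls = torch.randn(1, 64 * 64, 38, 38)   # (B, C=res**2, H, W)
#   flow = cls_to_flow(cls)                  # (1, 38, 38, 2), values in (-1, 1)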
300 |
301 | @torch.no_grad()
302 | def cls_to_flow_refine(cls):
303 | B,C,H,W = cls.shape
304 | device = cls.device
305 | res = round(math.sqrt(C))
306 | G = torch.meshgrid(
307 | *[torch.linspace(-1+1/res, 1-1/res, steps = res, device = device) for _ in range(2)],
308 | indexing = 'ij'
309 | )
310 | G = torch.stack([G[1],G[0]],dim=-1).reshape(C,2)
311 |     # FIXME: the softmax below misbehaves on mps (cause unknown); use log_softmax().exp() as a workaround.
312 | if device.type == 'mps':
313 | cls = cls.log_softmax(dim=1).exp()
314 | else:
315 | cls = cls.softmax(dim=1)
316 | mode = cls.max(dim=1).indices
317 |
318 | index = torch.stack((mode-1, mode, mode+1, mode - res, mode + res), dim = 1).clamp(0,C - 1).long()
319 | neighbours = torch.gather(cls, dim = 1, index = index)[...,None]
320 |     flow = (neighbours * G[index]).sum(dim=1)  # probability-weighted sum over the argmax cell and its 4 neighbours
321 | tot_prob = neighbours.sum(dim=1)
322 | flow = flow / tot_prob
323 | return flow
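# cls_to_flow_refine additionally averages the grid coordinates of the argmax cell and its
# four axis-aligned neighbours, weighted by their softmaxed probabilities, giving a
# sub-cell flow estimate instead of snapping to the nearest grid centre.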
324 |
325 |
326 | def get_gt_warp(depth1, depth2, T_1to2, K1, K2, depth_interpolation_mode = 'bilinear', relative_depth_error_threshold = 0.05, H = None, W = None):
327 |
328 | if H is None:
329 | B,H,W = depth1.shape
330 | else:
331 | B = depth1.shape[0]
332 | with torch.no_grad():
333 | x1_n = torch.meshgrid(
334 | *[
335 | torch.linspace(
336 | -1 + 1 / n, 1 - 1 / n, n, device=depth1.device
337 | )
338 | for n in (B, H, W)
339 | ],
340 | indexing = 'ij'
341 | )
342 | x1_n = torch.stack((x1_n[2], x1_n[1]), dim=-1).reshape(B, H * W, 2)
343 | mask, x2 = warp_kpts(
344 | x1_n.double(),
345 | depth1.double(),
346 | depth2.double(),
347 | T_1to2.double(),
348 | K1.double(),
349 | K2.double(),
350 | depth_interpolation_mode = depth_interpolation_mode,
351 | relative_depth_error_threshold = relative_depth_error_threshold,
352 | )
353 | prob = mask.float().reshape(B, H, W)
354 | x2 = x2.reshape(B, H, W, 2)
355 | return x2, prob
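# Usage sketch (shapes are assumptions): given depth maps, a relative pose and intrinsics,
# get_gt_warp returns a dense ground-truth warp into image 2 plus a validity map.
#
#   # depth1, depth2: (B, H, W); T_1to2: (B, 4, 4); K1, K2: (B, 3, 3)
#   x2, prob = get_gt_warp(depth1, depth2, T_1to2, K1, K2, H=64, W=64)
#   # x2: (B, 64, 64, 2) in normalized (-1, 1) coords of image 2; prob: (B, 64, 64) in {0, 1}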
356 |
357 | @torch.no_grad()
358 | def warp_kpts(kpts0, depth0, depth1, T_0to1, K0, K1, smooth_mask = False, return_relative_depth_error = False, depth_interpolation_mode = "bilinear", relative_depth_error_threshold = 0.05):
359 | """Warp kpts0 from I0 to I1 with depth, K and Rt
360 | Also check covisibility and depth consistency.
361 |     Depth is consistent if the relative error is below ``relative_depth_error_threshold`` (default 0.05).
362 |     Adapted from https://github.com/zju3dv/LoFTR/blob/94e98b695be18acb43d5d3250f52226a8e36f839/src/loftr/utils/geometry.py
363 |     Args:
364 |         kpts0 (torch.Tensor): [N, L, 2] - <x, y>, should be normalized in (-1, 1)
365 | depth0 (torch.Tensor): [N, H, W],
366 | depth1 (torch.Tensor): [N, H, W],
367 | T_0to1 (torch.Tensor): [N, 3, 4],
368 | K0 (torch.Tensor): [N, 3, 3],
369 | K1 (torch.Tensor): [N, 3, 3],
370 | Returns:
371 | calculable_mask (torch.Tensor): [N, L]
372 | warped_keypoints0 (torch.Tensor): [N, L, 2]
373 | """
374 | (
375 | n,
376 | h,
377 | w,
378 | ) = depth0.shape
379 | if depth_interpolation_mode == "combined":
380 |         # Inspired by the approach in InLoc: fill holes left by bilinear interpolation with nearest-neighbour interpolation
381 | if smooth_mask:
382 | raise NotImplementedError("Combined bilinear and NN warp not implemented")
383 | valid_bilinear, warp_bilinear = warp_kpts(kpts0, depth0, depth1, T_0to1, K0, K1,
384 | smooth_mask = smooth_mask,
385 | return_relative_depth_error = return_relative_depth_error,
386 | depth_interpolation_mode = "bilinear",
387 | relative_depth_error_threshold = relative_depth_error_threshold)
388 | valid_nearest, warp_nearest = warp_kpts(kpts0, depth0, depth1, T_0to1, K0, K1,
389 | smooth_mask = smooth_mask,
390 | return_relative_depth_error = return_relative_depth_error,
391 | depth_interpolation_mode = "nearest-exact",
392 | relative_depth_error_threshold = relative_depth_error_threshold)
393 | nearest_valid_bilinear_invalid = (~valid_bilinear).logical_and(valid_nearest)
394 | warp = warp_bilinear.clone()
395 | warp[nearest_valid_bilinear_invalid] = warp_nearest[nearest_valid_bilinear_invalid]
396 | valid = valid_bilinear | valid_nearest
397 | return valid, warp
398 |
399 |
400 | kpts0_depth = F.grid_sample(depth0[:, None], kpts0[:, :, None], mode = depth_interpolation_mode, align_corners=False)[
401 | :, 0, :, 0
402 | ]
403 | kpts0 = torch.stack(
404 | (w * (kpts0[..., 0] + 1) / 2, h * (kpts0[..., 1] + 1) / 2), dim=-1
405 | ) # [-1+1/h, 1-1/h] -> [0.5, h-0.5]
406 | # Sample depth, get calculable_mask on depth != 0
407 | nonzero_mask = kpts0_depth != 0
408 |
409 | # Unproject
410 | kpts0_h = (
411 | torch.cat([kpts0, torch.ones_like(kpts0[:, :, [0]])], dim=-1)
412 | * kpts0_depth[..., None]
413 | ) # (N, L, 3)
414 | kpts0_n = K0.inverse() @ kpts0_h.transpose(2, 1) # (N, 3, L)
415 | kpts0_cam = kpts0_n
416 |
417 | # Rigid Transform
418 | w_kpts0_cam = T_0to1[:, :3, :3] @ kpts0_cam + T_0to1[:, :3, [3]] # (N, 3, L)
419 | w_kpts0_depth_computed = w_kpts0_cam[:, 2, :]
420 |
421 | # Project
422 | w_kpts0_h = (K1 @ w_kpts0_cam).transpose(2, 1) # (N, L, 3)
423 | w_kpts0 = w_kpts0_h[:, :, :2] / (
424 | w_kpts0_h[:, :, [2]] + 1e-4
425 | ) # (N, L, 2), +1e-4 to avoid zero depth
426 |
427 | # Covisible Check
428 | h, w = depth1.shape[1:3]
429 | covisible_mask = (
430 | (w_kpts0[:, :, 0] > 0)
431 | * (w_kpts0[:, :, 0] < w - 1)
432 | * (w_kpts0[:, :, 1] > 0)
433 | * (w_kpts0[:, :, 1] < h - 1)
434 | )
435 | w_kpts0 = torch.stack(
436 | (2 * w_kpts0[..., 0] / w - 1, 2 * w_kpts0[..., 1] / h - 1), dim=-1
437 | ) # from [0.5,h-0.5] -> [-1+1/h, 1-1/h]
438 |     # w_kpts0[~covisible_mask, :] = -5
439 |
440 | w_kpts0_depth = F.grid_sample(
441 | depth1[:, None], w_kpts0[:, :, None], mode=depth_interpolation_mode, align_corners=False
442 | )[:, 0, :, 0]
443 |
444 | relative_depth_error = (
445 | (w_kpts0_depth - w_kpts0_depth_computed) / w_kpts0_depth
446 | ).abs()
447 | if not smooth_mask:
448 | consistent_mask = relative_depth_error < relative_depth_error_threshold
449 | else:
450 | consistent_mask = (-relative_depth_error/smooth_mask).exp()
451 | valid_mask = nonzero_mask * covisible_mask * consistent_mask
452 | if return_relative_depth_error:
453 | return relative_depth_error, w_kpts0
454 | else:
455 | return valid_mask, w_kpts0
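# Usage sketch for warp_kpts (tensor shapes as in the docstring; contents are assumptions):
#
#   # kpts0: (N, L, 2) keypoints in normalized (-1, 1) coords of image 0
#   valid, kpts0_in_1 = warp_kpts(kpts0, depth0, depth1, T_0to1, K0, K1)
#   # valid: (N, L) mask where depth is nonzero, the point is covisible and depths agree
#   # kpts0_in_1: (N, L, 2) the same points in normalized coords of image 1
#
# With depth_interpolation_mode="combined", holes left by bilinear depth sampling are filled
# by a second nearest-neighbour pass, as handled in the branch above.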
456 |
457 | imagenet_mean = torch.tensor([0.485, 0.456, 0.406])
458 | imagenet_std = torch.tensor([0.229, 0.224, 0.225])
459 |
460 |
461 | def numpy_to_pil(x: np.ndarray):
462 | """
463 | Args:
464 | x: Assumed to be of shape (h,w,c)
465 | """
466 | if isinstance(x, torch.Tensor):
467 | x = x.detach().cpu().numpy()
468 | if x.max() <= 1.01:
469 | x *= 255
470 | x = x.astype(np.uint8)
471 | return Image.fromarray(x)
472 |
473 |
474 | def tensor_to_pil(x, unnormalize=False):
475 | if unnormalize:
476 | x = x * (imagenet_std[:, None, None].to(x.device)) + (imagenet_mean[:, None, None].to(x.device))
477 | x = x.detach().permute(1, 2, 0).cpu().numpy()
478 | x = np.clip(x, 0.0, 1.0)
479 | return numpy_to_pil(x)
480 |
481 |
482 | def to_cuda(batch):
483 | for key, value in batch.items():
484 | if isinstance(value, torch.Tensor):
485 | batch[key] = value.cuda()
486 | return batch
487 |
488 |
489 | def to_cpu(batch):
490 | for key, value in batch.items():
491 | if isinstance(value, torch.Tensor):
492 | batch[key] = value.cpu()
493 | return batch
494 |
495 |
496 | def get_pose(calib):
497 | w, h = np.array(calib["imsize"])[0]
498 | return np.array(calib["K"]), np.array(calib["R"]), np.array(calib["T"]).T, h, w
499 |
500 |
501 | def compute_relative_pose(R1, t1, R2, t2):
502 | rots = R2 @ (R1.T)
503 | trans = -rots @ t1 + t2
504 | return rots, trans
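# For reference: with world-to-camera poses X_c1 = R1 @ X_w + t1 and X_c2 = R2 @ X_w + t2,
# the camera-1-to-camera-2 transform is
#   R_rel = R2 @ R1.T,  t_rel = t2 - R_rel @ t1,
# which is what compute_relative_pose returns (trans = -rots @ t1 + t2).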
505 |
506 | @torch.no_grad()
507 | def reset_opt(opt):
508 | for group in opt.param_groups:
509 | for p in group['params']:
510 | if p.requires_grad:
511 | state = opt.state[p]
512 | # State initialization
513 |
514 | # Exponential moving average of gradient values
515 | state['exp_avg'] = torch.zeros_like(p)
516 | # Exponential moving average of squared gradient values
517 | state['exp_avg_sq'] = torch.zeros_like(p)
518 | # Exponential moving average of gradient difference
519 | state['exp_avg_diff'] = torch.zeros_like(p)
520 |
521 |
522 | def flow_to_pixel_coords(flow, h1, w1):
523 | flow = (
524 | torch.stack(
525 | (
526 | w1 * (flow[..., 0] + 1) / 2,
527 | h1 * (flow[..., 1] + 1) / 2,
528 | ),
529 | axis=-1,
530 | )
531 | )
532 | return flow
533 |
534 | to_pixel_coords = flow_to_pixel_coords # just an alias
535 |
536 | def flow_to_normalized_coords(flow, h1, w1):
537 | flow = (
538 | torch.stack(
539 | (
540 | 2 * (flow[..., 0]) / w1 - 1,
541 | 2 * (flow[..., 1]) / h1 - 1,
542 | ),
543 | axis=-1,
544 | )
545 | )
546 | return flow
547 |
548 | to_normalized_coords = flow_to_normalized_coords # just an alias
549 |
550 | def warp_to_pixel_coords(warp, h1, w1, h2, w2):
551 | warp1 = warp[..., :2]
552 | warp1 = (
553 | torch.stack(
554 | (
555 | w1 * (warp1[..., 0] + 1) / 2,
556 | h1 * (warp1[..., 1] + 1) / 2,
557 | ),
558 | axis=-1,
559 | )
560 | )
561 | warp2 = warp[..., 2:]
562 | warp2 = (
563 | torch.stack(
564 | (
565 | w2 * (warp2[..., 0] + 1) / 2,
566 | h2 * (warp2[..., 1] + 1) / 2,
567 | ),
568 | axis=-1,
569 | )
570 | )
571 | return torch.cat((warp1,warp2), dim=-1)
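# Convention sketch (example shapes are assumptions): these helpers convert between the
# normalized (-1, 1) coordinates used throughout this file and pixel coordinates.
#
#   warp = torch.zeros(1, 4, 4, 4)                        # (B, H, W, 4) warp in (-1, 1)
#   warp_px = warp_to_pixel_coords(warp, 480, 640, 480, 640)
#   # first two channels are now image-1 pixel coords, last two are image-2 pixel coords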
572 |
573 |
574 |
575 | def signed_point_line_distance(point, line, eps: float = 1e-9):
576 | r"""Return the distance from points to lines.
577 |
578 | Args:
579 | point: (possibly homogeneous) points :math:`(*, N, 2 or 3)`.
580 | line: lines coefficients :math:`(a, b, c)` with shape :math:`(*, N, 3)`, where :math:`ax + by + c = 0`.
581 |         eps: Small epsilon added to the denominator to avoid division by zero.
582 |
583 | Returns:
584 |         the computed signed distance with shape :math:`(*, N)`.
585 | """
586 |
587 | if not point.shape[-1] in (2, 3):
588 | raise ValueError(f"pts must be a (*, 2 or 3) tensor. Got {point.shape}")
589 |
590 | if not line.shape[-1] == 3:
591 | raise ValueError(f"lines must be a (*, 3) tensor. Got {line.shape}")
592 |
593 | numerator = (line[..., 0] * point[..., 0] + line[..., 1] * point[..., 1] + line[..., 2])
594 | denominator = line[..., :2].norm(dim=-1)
595 |
596 | return numerator / (denominator + eps)
597 |
598 |
599 | def signed_left_to_right_epipolar_distance(pts1, pts2, Fm):
600 | r"""Return one-sided epipolar distance for correspondences given the fundamental matrix.
601 |
602 |     This method measures the signed distance from points in the right image to the epipolar
603 |     lines induced in the right image by the corresponding points in the left image.
604 |
605 | Args:
606 | pts1: correspondences from the left images with shape
607 | :math:`(*, N, 2 or 3)`. If they are not homogeneous, converted automatically.
608 | pts2: correspondences from the right images with shape
609 | :math:`(*, N, 2 or 3)`. If they are not homogeneous, converted automatically.
610 | Fm: Fundamental matrices with shape :math:`(*, 3, 3)`. Called Fm to
611 | avoid ambiguity with torch.nn.functional.
612 |
613 | Returns:
614 |         the computed signed distance with shape :math:`(*, N)`.
615 | """
616 | import kornia
617 | if (len(Fm.shape) < 3) or not Fm.shape[-2:] == (3, 3):
618 | raise ValueError(f"Fm must be a (*, 3, 3) tensor. Got {Fm.shape}")
619 |
620 | if pts1.shape[-1] == 2:
621 | pts1 = kornia.geometry.convert_points_to_homogeneous(pts1)
622 |
623 | F_t = Fm.transpose(dim0=-2, dim1=-1)
624 | line1_in_2 = pts1 @ F_t
625 |
626 | return signed_point_line_distance(pts2, line1_in_2)
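# Usage sketch (assumes a fundamental matrix Fm, e.g. built from pose and intrinsics):
#
#   # pts1, pts2: (B, N, 2) matched pixel coordinates; Fm: (B, 3, 3)
#   d = signed_left_to_right_epipolar_distance(pts1, pts2, Fm)  # (B, N), signed, in pixels
#
# Only the left-to-right direction is measured; for a symmetric error, also evaluate the
# opposite direction (swap the point sets and transpose Fm) and combine the two.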
627 |
628 | def get_grid(b, h, w, device):
629 | grid = torch.meshgrid(
630 | *[
631 | torch.linspace(-1 + 1 / n, 1 - 1 / n, n, device=device)
632 | for n in (b, h, w)
633 | ],
634 | indexing = 'ij'
635 | )
636 | grid = torch.stack((grid[2], grid[1]), dim=-1).reshape(b, h, w, 2)
637 | return grid
638 |
639 |
640 | def get_autocast_params(device=None, enabled=False, dtype=None):
641 |     if device is None:
642 |         autocast_device = "cuda" if torch.cuda.is_available() else "cpu"
643 |     else:
644 |         # strip the device index, e.g. "cuda:0" -> "cuda"
645 |         autocast_device = str(device).split(":")[0]
646 |     if "cuda" in autocast_device:
647 |         out_dtype = dtype
648 |         enabled = True
649 |     else:
650 |         out_dtype = torch.bfloat16
651 |         enabled = False
652 |         # mps autocast is not supported, fall back to cpu
653 |         autocast_device = "cpu"
654 |     return autocast_device, enabled, out_dtype
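# Usage sketch (torch.autocast as in recent PyTorch; the device string, `model` and `batch`
# are placeholders, not names from this repository):
#
#   amp_device, amp_enabled, amp_dtype = get_autocast_params("cuda:0", dtype=torch.float16)
#   with torch.autocast(amp_device, enabled=amp_enabled, dtype=amp_dtype):
#       out = model(batch)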
655 |
656 | def check_not_i16(im):
657 | if im.mode == "I;16":
658 | raise NotImplementedError("Can't handle 16 bit images")
659 |
660 | def check_rgb(im):
661 | if im.mode != "RGB":
662 | raise NotImplementedError("Can't handle non-RGB images")
663 |
--------------------------------------------------------------------------------