├── .gitignore
├── CONTRIBUTING.md
├── LICENSE.txt
├── README.md
├── assets
│   └── teaser.jpg
├── requirements.txt
├── restyle_image.py
├── restyle_scene.py
├── scene_transfer
│   ├── __init__.py
│   ├── attention_utils.py
│   ├── config.py
│   ├── ddpm_inversion.py
│   ├── depth_estimator.py
│   ├── image_utils.py
│   ├── latent_utils.py
│   ├── model_utils.py
│   ├── sd15_transfer.py
│   ├── sdxl_refiner.py
│   └── semantic_matching.py
├── scene_transfer_model.py
├── scripts
│   ├── download_data.py
│   ├── download_data.sh
│   └── download_weights.sh
├── third_party
│   ├── depth_anything_v2
│   │   ├── dinov2.py
│   │   ├── dinov2_layers
│   │   │   ├── __init__.py
│   │   │   ├── attention.py
│   │   │   ├── block.py
│   │   │   ├── drop_path.py
│   │   │   ├── layer_scale.py
│   │   │   ├── mlp.py
│   │   │   ├── patch_embed.py
│   │   │   └── swiglu_ffn.py
│   │   ├── dpt.py
│   │   └── util
│   │       ├── blocks.py
│   │       └── transform.py
│   └── dust3r
│       ├── LICENSE
│       ├── croco
│       │   ├── LICENSE
│       │   ├── NOTICE
│       │   ├── README.MD
│       │   ├── assets
│       │   │   ├── Chateau1.png
│       │   │   ├── Chateau2.png
│       │   │   └── arch.jpg
│       │   ├── croco-stereo-flow-demo.ipynb
│       │   ├── datasets
│       │   │   ├── __init__.py
│       │   │   ├── crops
│       │   │   │   ├── README.MD
│       │   │   │   └── extract_crops_from_images.py
│       │   │   ├── habitat_sim
│       │   │   │   ├── README.MD
│       │   │   │   ├── __init__.py
│       │   │   │   ├── generate_from_metadata.py
│       │   │   │   ├── generate_from_metadata_files.py
│       │   │   │   ├── generate_multiview_images.py
│       │   │   │   ├── multiview_habitat_sim_generator.py
│       │   │   │   ├── pack_metadata_files.py
│       │   │   │   └── paths.py
│       │   │   ├── pairs_dataset.py
│       │   │   └── transforms.py
│       │   ├── demo.py
│       │   ├── interactive_demo.ipynb
│       │   ├── models
│       │   │   ├── blocks.py
│       │   │   ├── criterion.py
│       │   │   ├── croco.py
│       │   │   ├── croco_downstream.py
│       │   │   ├── curope
│       │   │   │   ├── __init__.py
│       │   │   │   ├── build
│       │   │   │   │   └── temp.linux-x86_64-cpython-38
│       │   │   │   │       └── build.ninja
│       │   │   │   ├── curope.cpp
│       │   │   │   ├── curope2d.py
│       │   │   │   ├── kernels.cu
│       │   │   │   └── setup.py
│       │   │   ├── dpt_block.py
│       │   │   ├── head_downstream.py
│       │   │   ├── masking.py
│       │   │   └── pos_embed.py
│       │   ├── pretrain.py
│       │   ├── stereoflow
│       │   │   ├── README.MD
│       │   │   ├── augmentor.py
│       │   │   ├── criterion.py
│       │   │   ├── datasets_flow.py
│       │   │   ├── datasets_stereo.py
│       │   │   ├── download_model.sh
│       │   │   ├── engine.py
│       │   │   ├── test.py
│       │   │   └── train.py
│       │   └── utils
│       │       └── misc.py
│       ├── datasets_preprocess
│       │   ├── path_to_root.py
│       │   └── preprocess_co3d.py
│       └── dust3r
│           ├── __init__.py
│           ├── cloud_opt
│           │   ├── __init__.py
│           │   ├── base_opt.py
│           │   ├── commons.py
│           │   ├── init_im_poses.py
│           │   ├── optimizer.py
│           │   └── pair_viewer.py
│           ├── datasets
│           │   ├── __init__.py
│           │   ├── base
│           │   │   ├── __init__.py
│           │   │   ├── base_stereo_view_dataset.py
│           │   │   ├── batched_sampler.py
│           │   │   └── easy_dataset.py
│           │   ├── co3d.py
│           │   └── utils
│           │       ├── __init__.py
│           │       ├── cropping.py
│           │       └── transforms.py
│           ├── heads
│           │   ├── __init__.py
│           │   ├── dpt_head.py
│           │   ├── linear_head.py
│           │   └── postprocess.py
│           ├── image_pairs.py
│           ├── inference.py
│           ├── losses.py
│           ├── model.py
│           ├── optim_factory.py
│           ├── patch_embed.py
│           ├── post_process.py
│           ├── utils
│           │   ├── __init__.py
│           │   ├── device.py
│           │   ├── geometry.py
│           │   ├── image.py
│           │   ├── misc.py
│           │   └── path_to_croco.py
│           └── viz.py
├── utils
│   ├── adain.py
│   ├── logging.py
│   └── proj_utils.py
└── viewformer
    ├── UNet2DConditionalModel.py
    ├── __init__.py
    ├── image_utils.py
    ├── sdxl.py
    ├── stylelifter.py
    └── viewtransfer_pipeline.py
/.gitignore:
--------------------------------------------------------------------------------
1 | __pycache__
2 | output
3 | checkpoints
4 | data
5 | demo
--------------------------------------------------------------------------------
/CONTRIBUTING.md:
--------------------------------------------------------------------------------
1 | # How to contribute
2 |
3 | We'd love to accept your patches and contributions to this project.
4 |
5 | ## Before you begin
6 |
7 | ### Sign our Contributor License Agreement
8 |
9 | Contributions to this project must be accompanied by a
10 | [Contributor License Agreement](https://cla.developers.google.com/about) (CLA).
11 | You (or your employer) retain the copyright to your contribution; this simply
12 | gives us permission to use and redistribute your contributions as part of the
13 | project.
14 |
15 | If you or your current employer have already signed the Google CLA (even if it
16 | was for a different project), you probably don't need to do it again.
17 |
18 | Visit <https://cla.developers.google.com/> to see your current agreements or to
19 | sign a new one.
20 |
21 | ### Review our community guidelines
22 |
23 | This project follows
24 | [Google's Open Source Community Guidelines](https://opensource.google/conduct/).
25 |
26 | ## Contribution process
27 |
28 | ### Code reviews
29 |
30 | All submissions, including submissions by project members, require review. We
31 | use GitHub pull requests for this purpose. Consult
32 | [GitHub Help](https://help.github.com/articles/about-pull-requests/) for more
33 | information on using pull requests.
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # 🎨 ReStyle3D: Scene-Level Appearance Transfer with Semantic Correspondences
2 |
3 | ### ACM SIGGRAPH 2025
4 |
5 | [Project Page](https://restyle3d.github.io/) [arXiv](https://arxiv.org/abs/2502.10377) [Hugging Face](https://huggingface.co/gradient-spaces/ReStyle3D) [License](https://www.apache.org/licenses/LICENSE-2.0)
6 |
7 | Official implementation of the paper titled "Scene-level Appearance Transfer with Semantic Correspondences".
8 |
9 | [Liyuan Zhu](https://www.zhuliyuan.net/)<sup>1</sup>,
10 | [Shengqu Cai](https://primecai.github.io/)<sup>1,\*</sup>,
11 | [Shengyu Huang](https://shengyuh.github.io/)<sup>2,\*</sup>,
12 | [Gordon Wetzstein](https://stanford.edu/~gordonwz/)<sup>1</sup>,
13 | [Naji Khosravan](https://www.najikhosravan.com/)<sup>3</sup>,
14 | [Iro Armeni](https://ir0.github.io/)<sup>1</sup>
15 |
16 |
17 |
18 | <sup>1</sup>Stanford University, <sup>2</sup>NVIDIA Research, <sup>3</sup>Zillow Group | \* denotes equal contribution
19 |
20 |
21 | ```bibtex
22 | @inproceedings{zhu2025_restyle3d,
23 | author = {Liyuan Zhu and Shengqu Cai and Shengyu Huang and Gordon Wetzstein and Naji Khosravan and Iro Armeni},
24 | title = {Scene-level Appearance Transfer with Semantic Correspondences},
25 | booktitle = {ACM SIGGRAPH 2025 Conference Papers},
26 | year = {2025},
27 | }
28 | ```
29 |
30 | We introduce ReStyle3D, a novel framework for scene-level appearance
31 | transfer from a single style image to a real-world scene represented by
32 | multiple views. This method combines explicit semantic correspondences
33 | with multi-view consistency to achieve precise and coherent stylization.
34 |
35 |
36 |
37 |
38 |
39 |
40 |
41 | ## 🛠️ Setup
42 | ### ✅ Tested Environments
43 | - Ubuntu 22.04 LTS, Python 3.10.15, CUDA 12.2, GeForce RTX 4090/3090
44 |
45 | - CentOS Linux 7, Python 3.12.1, CUDA 12.4, NVIDIA A100
46 |
47 | ### 📦 Repository
48 | ```
49 | git clone git@github.com:GradientSpaces/ReStyle3D.git
50 | cd ReStyle3D
51 | ```
52 |
53 | ### 💻 Installation
54 | ```
55 | conda create -n restyle3d python=3.10
56 | conda activate restyle3d
57 | pip install -r requirements.txt
58 | ```
59 |
60 | ### 📦 Pretrained Checkpoints
61 | Download the pretrained models by running:
62 | ```
63 | bash scripts/download_weights.sh
64 | ```
65 |
66 |
67 | ## 🚀 Usage
68 |
69 | Download the dataset:
70 | ```
71 | bash scripts/download_data.sh
72 | ```
73 |
74 | ### 🎮 Demo (Single-view)
75 | We include 3 demo images to run semantic appearance transfer:
76 | ```
77 | python restyle_image.py
78 | ```
79 |
80 |
81 |
82 | ### 🎨 Stylizing Multi-view Scenes
83 | To run on a single scene and style:
84 | ```
85 | python restyle_scene.py \
86 | --scene_path demo/scene_transfer/bedroom/ \
87 | --scene_type bedroom \
88 | --style_path demo/design_styles/bedroom/pexels-itsterrymag-2631746
89 | ```
90 |
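The pipeline can also be driven from Python. A minimal sketch, assuming the `restyle_scene` function defined in `restyle_scene.py` (the paths simply reuse the demo command above):
```
# Minimal sketch: programmatic equivalent of the command above,
# using the restyle_scene() function defined in restyle_scene.py.
from restyle_scene import restyle_scene

restyle_scene(
    scene_path="demo/scene_transfer/bedroom/",
    style_path="demo/design_styles/bedroom/pexels-itsterrymag-2631746",
    scene_type="bedroom",
    output_root="output/demo_restyle3d",  # default used by the CLI
    downsample=4,                         # stride over the multi-view frames
)
```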
91 | ### 📂 Dataset: SceneTransfer
92 | We organize the data into two components:
93 |
94 | 1. Interior Scenes:
95 | Multi-view real-world scans with aligned images, depth, and semantic segmentations.
96 | ```
97 | 📁 data/
98 | └── interiors/
99 | ├── bedroom/
100 | │ ├── 0/
101 | │ │ ├── images/ # multi-view RGB images
102 | │ │ ├── depth/ # depth maps
103 | │ │ └── seg_dict/ # semantic segmentation dictionaries
104 | │ └── 1/
105 | │ └── ...
106 | ├── living_room/
107 | └── kitchen/
108 | ```
109 | 2. Design Styles:
110 | Style exemplars with precomputed semantic segmentation (a minimal loading sketch follows the layout below).
111 | ```
112 | 📁 data/
113 | └── design_styles/
114 | ├── bedroom/
115 | │ └── pexels-itsterrymag-2631746/
116 | │ ├── image.jpg # style reference image
117 | │ ├── seg_dict.pth # semantic segmentation dictionary
118 | │ └── seg.png # segmentation visualization
119 | ├── living_room/
120 | └── kitchen/
121 | ```
122 |
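For reference, a minimal sketch of loading a style exemplar. It assumes `seg_dict.pth` follows the layout consumed by `scene_transfer/semantic_matching.py`, i.e. `seg_dict["pred_tgt"] = [mask_tensor, objects_list]` with per-object `id`/`label` entries:
```
# Minimal loading sketch; the seg_dict layout is assumed from scene_transfer/semantic_matching.py.
from pathlib import Path

import torch
from PIL import Image

style_dir = Path("data/design_styles/bedroom/pexels-itsterrymag-2631746")

image = Image.open(style_dir / "image.jpg").convert("RGB")      # style reference image
seg_dict = torch.load(style_dir / "seg_dict.pth", map_location="cpu")

mask_tensor, objects = seg_dict["pred_tgt"]                     # per-pixel object IDs + object metadata
print(image.size, mask_tensor.shape, sorted({obj["label"] for obj in objects}))
```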
123 |
124 |
125 |
126 |
127 | ## 🚧 TODO
128 | - [ ] Release full dataset
129 | - [ ] Release evaluation code
130 | - [ ] Customize dataset
131 |
132 |
133 | ## 🙏 Acknowledgement
134 | Our codebase is built on top of the following works:
135 | - [Cross-image-attention](https://github.com/garibida/cross-image-attention)
136 | - [ODISE](https://github.com/NVlabs/ODISE)
137 | - [ViewCrafter](https://github.com/Drexubery/ViewCrafter)
138 | - [GenWarp](https://github.com/sony/genwarp)
139 | - [DUSt3R](https://github.com/naver/dust3r)
140 |
141 | We appreciate the open-source efforts from the authors.
142 |
143 | ## 📫 Contact
144 | If you encounter any issues or have questions, feel free to reach out: [Liyuan Zhu](mailto:liyzhu@stanford.edu).
145 |
146 |
147 |
148 |
--------------------------------------------------------------------------------
/assets/teaser.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/GradientSpaces/ReStyle3D/5a6440eeee53783530dbbda7c26898b8b4c351da/assets/teaser.jpg
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | # Use PyTorch CUDA wheels in addition to PyPI
2 | --extra-index-url https://download.pytorch.org/whl/cu121
3 |
4 | torch==2.5.0
5 | torchvision==0.20.0
6 | torchaudio==2.5.0
7 |
8 | # Everything below comes from the default PyPI index
9 | diffusers==0.31.0
10 | xformers==0.0.28.post2
11 | transformers==4.43.2
12 | accelerate==1.0.1
13 | einops
14 | roma
15 | open3d
16 | scikit-learn
17 | pyrallis
18 | jaxtyping
19 | opencv-python
20 | matplotlib
21 | huggingface_hub[cli]
22 | git+https://github.com/pesser/splatting
23 |
--------------------------------------------------------------------------------
/restyle_scene.py:
--------------------------------------------------------------------------------
1 | import os
2 | from pathlib import Path
3 | import argparse
4 |
5 | from restyle_image import generate_single_view_stylized
6 | from viewformer.stylelifter import StyleLifter
7 | from utils.logging import logger
8 |
9 | def restyle_scene(scene_path: str, style_path: str, scene_type: str, output_root: str, downsample: int):
10 | scene_path = Path(scene_path)
11 | style_path = Path(style_path)
12 | scene_id = scene_path.parts[-1]
13 | style_id = style_path.parts[-1]
14 |
15 | # Get structure image
16 | image_files = sorted((scene_path / "images").glob("*.jpg"))
17 | if not image_files:
18 | logger.error(f"No images found in {scene_path / 'images'}")
19 | return
20 | struct_img = image_files[0]
21 | frame_name = struct_img.stem
22 |
23 | struct_seg = scene_path / "seg_dict" / f"{frame_name}.pth"
24 |
25 | # Look for style image.* (jpg or png)
26 | image_candidates = list(style_path.glob("image.*"))
27 | if not image_candidates:
28 | logger.error(f"No image.* found in {style_path}")
29 | return
30 | style_img = image_candidates[0]
31 |
32 | style_seg = style_path / "seg_dict.pth"
33 |
34 | # Check required files
35 | for p in [struct_img, struct_seg, style_img, style_seg]:
36 | if not p.exists():
37 | logger.error(f"Missing required input: {p}")
38 | return
39 |
40 | # Step 1: generate single-view stylized image
41 | logger.info(f"Generating single-view stylization: {style_id} → {scene_path}")
42 | stylized_2d_output = Path("output/2d_results") / f"{scene_id}_style_{style_id}"
43 | generate_single_view_stylized(
44 | struct_img_path=struct_img,
45 | style_img_path=style_img,
46 | struct_seg_dict=struct_seg,
47 | style_seg_dict=style_seg,
48 | output_path=stylized_2d_output / "intermediate",
49 | scene_type=scene_type,
50 | )
51 |
52 | # Step 2: multi-view lifting
53 | logger.info(f"Starting multi-view style lifting...")
54 | stylelifter = StyleLifter(ckpt_path="checkpoints")
55 | output_3d_path = Path(output_root) / f"{scene_type}_{scene_id}" / style_id
56 |
57 | stylelifter(
58 | src_scene=str(scene_path),
59 | stylized_path=stylized_2d_output / "stylized.png",
60 | output_path=output_3d_path,
61 | downsample=downsample
62 | )
63 |
64 | logger.info(f"✅ Scene stylization complete for {scene_type}/{scene_id} using {style_id}.")
65 | logger.info(f"🖼️ Results saved to: {output_3d_path.resolve()}")
66 |
67 |
68 |
69 | if __name__ == "__main__":
70 | parser = argparse.ArgumentParser(description="ReStyle3D: Scene Stylization Pipeline")
71 | parser.add_argument("--scene_path", type=str, required=True, help="Path to scene directory (e.g., data/interiors/bedroom/0)")
72 | parser.add_argument("--style_path", type=str, required=True, help="Path to style folder (e.g., data/design_styles_v2/bedroom/pexels-xxx)")
73 | parser.add_argument("--scene_type", type=str, required=True, help="Scene type (e.g., bedroom, kitchen, living_room)")
74 | parser.add_argument("--output_root", type=str, default="output/demo_restyle3d", help="Root path to save results")
75 | parser.add_argument("--downsample", type=int, default=4, help="Downsampling stride for multi-view processing (default: 4)")
76 |
77 | args = parser.parse_args()
78 | restyle_scene(args.scene_path, args.style_path, args.scene_type, args.output_root, args.downsample)
79 |
80 |
--------------------------------------------------------------------------------
/scene_transfer/__init__.py:
--------------------------------------------------------------------------------
1 | OUT_INDEX = 0
2 | STYLE_INDEX = 1
3 | STRUCT_INDEX = 2
4 |
--------------------------------------------------------------------------------
/scene_transfer/attention_utils.py:
--------------------------------------------------------------------------------
1 | import math
2 | import torch
3 | from scene_transfer import OUT_INDEX
4 |
5 | def should_mix_keys_and_values(model, hidden_states: torch.Tensor) -> bool:
6 | """ Verify whether we should perform the mixing in the current timestep. """
7 | is_in_32_timestep_range = (
8 | model.config.cross_attn_32_range.start <= model.step < model.config.cross_attn_32_range.end
9 | )
10 | is_in_64_timestep_range = (
11 | model.config.cross_attn_64_range.start <= model.step < model.config.cross_attn_64_range.end
12 | )
13 | is_hidden_states_32_square = (hidden_states.shape[1] == 32 ** 2)
14 | is_hidden_states_64_square = (hidden_states.shape[1] == 64 ** 2)
15 | should_mix = (is_in_32_timestep_range and is_hidden_states_32_square) or \
16 | (is_in_64_timestep_range and is_hidden_states_64_square)
17 | return should_mix
18 |
19 |
20 | def compute_scaled_dot_product_attention(Q, K, V, edit_map=False, is_cross=False, contrast_strength=1.0, masks=None):
21 | """ Compute the scale dot product attention, potentially with our contrasting operation. """
22 | cost_volume = Q @ K.transpose(-2, -1) / math.sqrt(Q.size(-1))
23 | # attn_weight = torch.softmax((Q @ K.transpose(-2, -1) / math.sqrt(Q.size(-1))), dim=-1)
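    # Only the output branch (OUT_INDEX) of the attention logits is masked, using the pre-resized 32x32 / 64x64 masks.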
24 | if masks is not None:
25 | mask_64, mask_32 = masks
26 | if (Q.shape[-2] == 32 ** 2):
27 | cost_volume[OUT_INDEX] = cost_volume[OUT_INDEX] * mask_32
28 |
29 | if (Q.shape[-2] == 64 ** 2):
30 | cost_volume[OUT_INDEX] = cost_volume[OUT_INDEX] * mask_64
31 |
32 | attn_weight = torch.softmax(cost_volume, dim=-1)
33 | if edit_map and not is_cross:
34 | attn_weight[OUT_INDEX] = torch.stack([
35 | torch.clip(enhance_tensor(attn_weight[OUT_INDEX][head_idx], contrast_factor=contrast_strength),
36 | min=0.0, max=1.0)
37 | for head_idx in range(attn_weight.shape[1])
38 | ])
39 | return attn_weight @ V, attn_weight
40 |
41 | def enhance_tensor(tensor: torch.Tensor, contrast_factor: float = 1.67) -> torch.Tensor:
42 | """ Compute the attention map contrasting. """
43 | adjusted_tensor = (tensor - tensor.mean(dim=-1)) * contrast_factor + tensor.mean(dim=-1)
44 | return adjusted_tensor
--------------------------------------------------------------------------------
/scene_transfer/config.py:
--------------------------------------------------------------------------------
1 | from dataclasses import dataclass
2 | from pathlib import Path
3 | from typing import NamedTuple, Optional
4 |
5 |
6 | class Range(NamedTuple):
7 | start: int
8 | end: int
9 |
10 |
11 | @dataclass
12 | class RunConfig:
13 | # Appearance image path
14 | app_image_path: Path
15 | # Struct image path
16 | struct_image_path: Path
17 | # Domain name (e.g., buildings, animals)
18 | domain_name: Optional[str] = None
19 | # Output path
20 | output_path: Path = Path('./output/test')
21 | # Random seed
22 | seed: int = 42
23 | # Input prompt for inversion (will use domain name as default)
24 | prompt: Optional[str] = None
25 | # Number of timesteps
26 | num_timesteps: int = 120
27 | # Whether to use a binary mask for performing AdaIN
28 | use_masked_adain: bool = False
29 | # Timesteps to apply cross-attention on 64x64 layers
30 | cross_attn_64_range: Range = Range(start=10, end=70)
31 | # Timesteps to apply cross-attention on 32x32 layers
32 | cross_attn_32_range: Range = Range(start=10, end=70)
33 | # Timesteps to apply AdaIn
34 | adain_range: Range = Range(start=20, end=100)
35 | # Swap guidance scale
36 | swap_guidance_scale: float = 2.0
37 | # Attention contrasting strength
38 | contrast_strength: float = 1.67
39 | # Object nouns to use for self-segmentation (will use the domain name as default)
40 | object_noun: Optional[str] = None
41 | # Whether to load previously saved inverted latent codes
42 | load_latents: bool = True
43 | # Number of steps to skip in the denoising process (used value from original edit-friendly DDPM paper)
44 | skip_steps: int = 32
45 | # ControlNet guidance scale
46 | controlnet_guidance: float = 1.0
47 | # Predict depth
48 | pred_depth: bool = True
49 | # Appearance image depth path
50 |     app_depth_path: Optional[Path] = None
51 |     # Struct image depth path
52 |     struct_depth_path: Optional[Path] = None
53 |
54 | def config_exp(self):
55 |         # Create the output directory
56 | self.output_path.mkdir(parents=True, exist_ok=True)
57 |
58 | # Handle the domain name, prompt, and object nouns used for masking, etc.
59 | if self.use_masked_adain and self.domain_name is None:
60 | raise ValueError("Must provide --domain_name and --prompt when using masked AdaIN")
61 | if not self.use_masked_adain and self.domain_name is None:
62 | self.domain_name = "object"
63 | if self.prompt is None:
64 | self.prompt = f"A photo of a {self.domain_name}"
65 | if self.object_noun is None:
66 | self.object_noun = self.domain_name
67 |
68 | # Define the paths to store the inverted latents to
69 | self.latents_path = Path(self.output_path) / "latents"
70 | self.latents_path.mkdir(parents=True, exist_ok=True)
71 | self.app_latent_save_path = self.latents_path / f"{self.app_image_path.stem}.pt"
72 | self.struct_latent_save_path = self.latents_path / f"{self.struct_image_path.stem}.pt"
73 |
74 | if self.pred_depth:
75 | self.app_depth_path = self.output_path / "app_depth.png"
76 | self.struct_depth_path = self.output_path / "struct_depth.png"
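
# Example usage (illustrative sketch; the image paths below are placeholders, not files shipped with the repo):
#   cfg = RunConfig(app_image_path=Path("style.jpg"),
#                   struct_image_path=Path("struct.jpg"),
#                   domain_name="bedroom",
#                   output_path=Path("output/example"))
#   cfg.config_exp()  # derives the prompt ("A photo of a bedroom"), latent paths, and depth-map paths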
--------------------------------------------------------------------------------
/scene_transfer/depth_estimator.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import cv2
3 | import diffusers
4 |
5 |
6 | from scene_transfer.config import RunConfig
7 | from third_party.depth_anything_v2.dpt import DepthAnythingV2
8 |
9 |
10 | DEVICE = 'cuda' if torch.cuda.is_available() else 'mps' if torch.backends.mps.is_available() else 'cpu'
11 |
12 | def get_DepthAnyThing_model(encoder='vitl'):
13 | model_configs = {
14 | 'vits': {'encoder': 'vits', 'features': 64, 'out_channels': [48, 96, 192, 384]},
15 | 'vitb': {'encoder': 'vitb', 'features': 128, 'out_channels': [96, 192, 384, 768]},
16 | 'vitl': {'encoder': 'vitl', 'features': 256, 'out_channels': [256, 512, 1024, 1024]},
17 | 'vitg': {'encoder': 'vitg', 'features': 384, 'out_channels': [1536, 1536, 1536, 1536]}
18 | }
19 | model = DepthAnythingV2(**model_configs[encoder])
20 | model.load_state_dict(torch.load(f'checkpoints/depth_anything_v2_{encoder}.pth', map_location='cpu'))
21 | model = model.to(DEVICE).eval()
22 | return model
23 |
24 | def normalize_depthmap(depthmap):
25 | min_, max_ = depthmap.min(), depthmap.max()
26 | depthmap = (depthmap - min_) / (max_ - min_) * 255
27 | return depthmap
28 |
29 | def get_depthmaps(cfg: RunConfig, model='Depth-Anything'):
30 | if model == "Depth-Anything":
31 | model = get_DepthAnyThing_model()
32 |         app_image = cv2.imread(str(cfg.app_image_path))
33 |         struct_image = cv2.imread(str(cfg.struct_image_path))
34 |
35 | app_depth = model.infer_image(app_image)
36 | struct_depth = model.infer_image(struct_image)
37 |
38 | elif model == "Marigold":
39 | pipe = diffusers.MarigoldDepthPipeline.from_pretrained(
40 | "prs-eth/marigold-depth-lcm-v1-0", variant="fp16", torch_dtype=torch.float16
41 | ).to(DEVICE)
42 | app_image = diffusers.utils.load_image(str(cfg.app_image_path))
43 |         struct_image = diffusers.utils.load_image(str(cfg.struct_image_path))
44 | app_depth = pipe(app_image)[0].squeeze()
45 | struct_depth = pipe(struct_image)[0].squeeze()
46 | else:
47 | raise NotImplementedError("unknown depth estimator!")
48 |
49 | app_depth = normalize_depthmap(app_depth)
50 | struct_depth = normalize_depthmap(struct_depth)
51 |
52 | return app_depth, struct_depth
53 |
54 | def get_depthmap(image, model='Depth-Anything', normalize=True):
55 | if model == "Depth-Anything":
56 | model = get_DepthAnyThing_model()
57 |
58 | depth = model.infer_image(image)
59 |
60 | else:
61 | raise NotImplementedError("unknown depth estimator!")
62 |
63 | if normalize:
64 | depth = normalize_depthmap(depth)
65 |
66 | return depth
--------------------------------------------------------------------------------
/scene_transfer/image_utils.py:
--------------------------------------------------------------------------------
1 | import pathlib
2 | from typing import Optional, Tuple
3 |
4 | import numpy as np
5 | from PIL import Image
6 |
7 | from scene_transfer.config import RunConfig
8 |
9 |
10 | def load_images(cfg: RunConfig, save_path: Optional[pathlib.Path] = None) -> Tuple[np.ndarray, np.ndarray]:
11 | image_style = load_size(cfg.app_image_path)
12 | image_struct = load_size(cfg.struct_image_path)
13 |
14 | if save_path is not None:
15 | Image.fromarray(image_style).save(save_path / f"in_style.png")
16 | Image.fromarray(image_struct).save(save_path / f"in_struct.png")
17 | return image_style, image_struct
18 |
19 |
20 | def load_size(image_path: pathlib.Path,
21 | left: int = 0,
22 | right: int = 0,
23 | top: int = 0,
24 | bottom: int = 0,
25 | size: int = 512,
26 |               resize: bool = True) -> np.ndarray:
27 | if isinstance(image_path, (str, pathlib.Path)):
28 | image = np.array(Image.open(str(image_path)).convert('RGB'))
29 | else:
30 | image = image_path
31 |
32 | if resize:
33 | # Resize the image
34 | resized_image = Image.fromarray(image).resize((size, size))
35 |
36 | # Convert back to numpy array
37 | resized_array = np.array(resized_image)
38 | return resized_array
39 |
40 | h, w, _ = image.shape
41 |
42 | left = min(left, w - 1)
43 | right = min(right, w - left - 1)
44 | top = min(top, h - left - 1)
45 | bottom = min(bottom, h - top - 1)
46 | image = image[top:h - bottom, left:w - right]
47 |
48 | h, w, c = image.shape
49 |
50 | if h < w:
51 | offset = (w - h) // 2
52 | image = image[:, offset:offset + h]
53 | elif w < h:
54 | offset = (h - w) // 2
55 | image = image[offset:offset + w]
56 |
57 | image = np.array(Image.fromarray(image).resize((size, size)))
58 | return image
59 |
60 |
61 | def save_generated_masks(model, cfg: RunConfig):
62 | tensor2im(model.image_app_mask_32).save(cfg.output_path / f"mask_style_32.png")
63 | tensor2im(model.image_struct_mask_32).save(cfg.output_path / f"mask_struct_32.png")
64 | tensor2im(model.image_app_mask_64).save(cfg.output_path / f"mask_style_64.png")
65 | tensor2im(model.image_struct_mask_64).save(cfg.output_path / f"mask_struct_64.png")
66 |
67 |
68 | def tensor2im(x) -> Image.Image:
69 | return Image.fromarray(x.cpu().numpy().astype(np.uint8) * 255)
--------------------------------------------------------------------------------
/scene_transfer/latent_utils.py:
--------------------------------------------------------------------------------
1 | from pathlib import Path
2 | from typing import Tuple, Union
3 |
4 | import numpy as np
5 | import torch
6 | from PIL import Image
7 |
8 | from scene_transfer_model import SceneTransfer
9 | from scene_transfer.config import RunConfig
10 | from scene_transfer import image_utils
11 | from scene_transfer.ddpm_inversion import invert
12 | from scene_transfer.depth_estimator import get_depthmaps
13 | from utils.logging import logger
14 |
15 | def load_latents_or_invert_images(model: Union[SceneTransfer], cfg: RunConfig):
16 | if cfg.load_latents and cfg.app_latent_save_path.exists() and cfg.struct_latent_save_path.exists():
17 | logger.info("Loading existing latents...")
18 | latents_app, latents_struct = load_latents(cfg.app_latent_save_path, cfg.struct_latent_save_path)
19 | noise_app, noise_struct = load_noise(cfg.app_latent_save_path, cfg.struct_latent_save_path)
20 | else:
21 | logger.info("Inverting images...")
22 | app_image, struct_image = image_utils.load_images(cfg=cfg, save_path=cfg.output_path)
23 | # Load depth images
24 | if cfg.pred_depth:
25 | depth_app, depth_struct = get_depthmaps(cfg)
26 | depth_app, depth_struct = Image.fromarray(depth_app).convert("RGB"), Image.fromarray(depth_struct).convert("RGB")
27 | depth_app.save(cfg.app_depth_path)
28 | depth_struct.save(cfg.struct_depth_path)
29 | else:
30 | depth_app = Image.open(cfg.app_depth_path).convert("RGB")
31 | depth_struct = Image.open(cfg.struct_depth_path).convert("RGB")
32 |
33 | # Ensure depth images are the same size as the input images
34 |         depth_app = depth_app.resize((app_image.shape[1], app_image.shape[0]))  # PIL expects (width, height)
35 |         depth_struct = depth_struct.resize((struct_image.shape[1], struct_image.shape[0]))
36 |
37 | # Normalize depth images to [0, 1]
38 | depth_app = np.array(depth_app).astype(np.float32) / 255.0
39 | depth_struct = np.array(depth_struct).astype(np.float32) / 255.0
40 |
41 | # Convert back to PIL Image
42 | depth_app = Image.fromarray((depth_app * 255).astype(np.uint8))
43 | depth_struct = Image.fromarray((depth_struct * 255).astype(np.uint8))
44 |
45 | model.enable_edit = False # Deactivate the cross-image attention layers
46 | latents_app, latents_struct, noise_app, noise_struct = invert_images(app_image=app_image,
47 | struct_image=struct_image,
48 | sd_model=model.pipe,
49 | depth_app=depth_app,
50 | depth_struct=depth_struct,
51 | cfg=cfg)
52 | model.enable_edit = True
53 | return latents_app, latents_struct, noise_app, noise_struct
54 |
55 |
56 | def load_latents(app_latent_save_path: Path, struct_latent_save_path: Path) -> Tuple[torch.Tensor, torch.Tensor]:
57 | latents_app = torch.load(app_latent_save_path, weights_only=True)
58 | latents_struct = torch.load(struct_latent_save_path, weights_only=True)
59 | if type(latents_struct) == list:
60 | latents_app = [l.to("cuda") for l in latents_app]
61 | latents_struct = [l.to("cuda") for l in latents_struct]
62 | else:
63 | latents_app = latents_app.to("cuda")
64 | latents_struct = latents_struct.to("cuda")
65 | return latents_app, latents_struct
66 |
67 |
68 | def load_noise(app_latent_save_path: Path, struct_latent_save_path: Path) -> Tuple[torch.Tensor, torch.Tensor]:
69 | latents_app = torch.load(app_latent_save_path.parent / (app_latent_save_path.stem + "_ddpm_noise.pt"))
70 | latents_struct = torch.load(struct_latent_save_path.parent / (struct_latent_save_path.stem + "_ddpm_noise.pt"))
71 | latents_app = latents_app.to("cuda")
72 | latents_struct = latents_struct.to("cuda")
73 | return latents_app, latents_struct
74 |
75 |
76 | def invert_images(sd_model: Union[SceneTransfer], app_image: Image.Image, struct_image: Image.Image, depth_app: Image.Image, depth_struct: Image.Image, cfg: RunConfig):
77 | input_app = torch.from_numpy(np.array(app_image)).float() / 127.5 - 1.0
78 | input_struct = torch.from_numpy(np.array(struct_image)).float() / 127.5 - 1.0
79 |
80 | zs_app, latents_app = invert(x0=input_app.permute(2, 0, 1).unsqueeze(0).to('cuda'),
81 | pipe=sd_model,
82 | prompt_src=cfg.prompt,
83 | num_diffusion_steps=cfg.num_timesteps,
84 | cfg_scale_src=3.5,
85 | depth=depth_app)
86 |
87 | zs_struct, latents_struct = invert(x0=input_struct.permute(2, 0, 1).unsqueeze(0).to('cuda'),
88 | pipe=sd_model,
89 | prompt_src=cfg.prompt,
90 | num_diffusion_steps=cfg.num_timesteps,
91 | cfg_scale_src=3.5,
92 | depth=depth_struct)
93 |
94 | # Save the inverted latents and noises
95 | torch.save(latents_app, cfg.latents_path / f"{cfg.app_image_path.stem}.pt")
96 | torch.save(latents_struct, cfg.latents_path / f"{cfg.struct_image_path.stem}.pt")
97 | torch.save(zs_app, cfg.latents_path / f"{cfg.app_image_path.stem}_ddpm_noise.pt")
98 | torch.save(zs_struct, cfg.latents_path / f"{cfg.struct_image_path.stem}_ddpm_noise.pt")
99 | return latents_app, latents_struct, zs_app, zs_struct
100 |
101 |
102 | def get_init_latents_and_noises(model: Union[SceneTransfer], cfg: RunConfig) -> Tuple[torch.Tensor, torch.Tensor]:
103 | # If we stored all the latents along the diffusion process, select the desired one based on the skip_steps
104 | if model.latents_struct.dim() == 4 and model.latents_app.dim() == 4 and model.latents_app.shape[0] > 1:
105 | model.latents_struct = model.latents_struct[cfg.skip_steps]
106 | model.latents_app = model.latents_app[cfg.skip_steps]
107 | init_latents = torch.stack([model.latents_struct, model.latents_app, model.latents_struct])
108 | init_zs = [model.zs_struct[cfg.skip_steps:], model.zs_app[cfg.skip_steps:], model.zs_struct[cfg.skip_steps:]]
109 | return init_latents, init_zs
110 |
--------------------------------------------------------------------------------
/scene_transfer/model_utils.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from diffusers import DDIMScheduler, ControlNetModel
3 | from diffusers import EulerAncestralDiscreteScheduler, AutoencoderKL
4 | from typing import Optional
5 | from scene_transfer.sd15_transfer import SemanticAttentionSD15
6 | from scene_transfer.sdxl_refiner import StableDiffusionXLControlNetPipeline
7 | from utils.logging import logger
8 |
9 | def get_scene_transfer_sd15() -> SemanticAttentionSD15:
10 | logger.info("Loading SD1.5...")
11 | device = torch.device(f'cuda') if torch.cuda.is_available() else torch.device('cpu')
12 | pipe = SemanticAttentionSD15.from_pretrained("runwayml/stable-diffusion-v1-5",
13 | controlnet=ControlNetModel.from_pretrained('lllyasviel/sd-controlnet-depth'),
14 | safety_checker=None).to(device)
15 | # pipe.unet = FreeUUNet2DConditionModel.from_pretrained("runwayml/stable-diffusion-v1-5", subfolder="unet").to(device)
16 | pipe.unet.enable_freeu(s1=0.9, s2=0.2, b1=1.5, b2=1.6)
17 |     pipe.scheduler = DDIMScheduler.from_pretrained("runwayml/stable-diffusion-v1-5", subfolder="scheduler")
18 | return pipe
19 |
20 | def get_refining_pipe(precision : torch.dtype = torch.float16) -> StableDiffusionXLControlNetPipeline:
21 | logger.info("Loading SDXL...")
22 | # Setup device
23 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
24 |
25 | controlnet = ControlNetModel.from_pretrained(
26 | "diffusers/controlnet-depth-sdxl-1.0",
27 | variant="fp16",
28 | use_safetensors=True,
29 | torch_dtype=precision,
30 | )
31 | controlnet.enable_xformers_memory_efficient_attention()
32 | vae = AutoencoderKL.from_pretrained("madebyollin/sdxl-vae-fp16-fix", torch_dtype=precision)
33 | pipe = StableDiffusionXLControlNetPipeline.from_pretrained(
34 | "stabilityai/stable-diffusion-xl-base-1.0",
35 | controlnet=controlnet,
36 | vae=vae,
37 | variant="fp16",
38 | use_safetensors=True,
39 | torch_dtype=precision,
40 | )
41 | pipe.to(device)
42 | pipe.enable_model_cpu_offload()
43 | pipe.enable_xformers_memory_efficient_attention()
44 |
45 | return pipe
--------------------------------------------------------------------------------
/scene_transfer/semantic_matching.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from utils.logging import logger
3 |
4 | def get_mask_for_label(label_name, objects_list, mask_tensor):
5 | """
6 | Get the mask for a given label from a 2D mask tensor, including all instances of the label.
7 |
8 | Args:
9 | label_name (str): The label to find the mask for
10 | objects_list (list): List of dictionaries containing object information
11 | mask_tensor (torch.Tensor): 2D PyTorch tensor where pixel values correspond to object IDs
12 |
13 | Returns:
14 | torch.Tensor: Binary mask for the given label, including all instances
15 | """
16 | # Find all objects with the given label
17 | target_objects = [obj for obj in objects_list if obj['label'] == label_name]
18 |
19 | if not target_objects:
20 | raise ValueError(f"No object found with label '{label_name}'")
21 |
22 | # Get the IDs of all target objects
23 | target_ids = [obj['id'] for obj in target_objects]
24 |
25 | # Create a binary mask where any of the target objects are True and everything else is False
26 | binary_mask = torch.zeros_like(mask_tensor, dtype=torch.bool)
27 | for target_id in target_ids:
28 | binary_mask |= (mask_tensor == target_id)
29 |
30 | return binary_mask
31 |
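# Illustrative example of get_mask_for_label (hypothetical inputs):
#   objects = [{"id": 1, "label": "wall"}, {"id": 2, "label": "floor"}, {"id": 3, "label": "wall"}]
#   mask = torch.tensor([[1, 1, 2], [3, 2, 2]])
#   get_mask_for_label("wall", objects, mask)
#   -> tensor([[ True,  True, False],
#              [ True, False, False]])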
32 |
33 | def match_semantic_labels(src_dict, tgt_dict):
34 | """
35 | Match semantic labels between source and target images in seg_dict.
36 |
37 | Args:
38 |         src_dict, tgt_dict (dict): Segmentation prediction dictionaries for the source and target images.
39 |
40 | Returns:
41 |         list: List of (label, source_mask, target_mask) tuples for the matched semantic labels.
42 | """
43 | # Extract labels from source and target predictions
44 | source_labels = [obj['label'] for obj in src_dict['pred_tgt'][1]]
45 | target_labels = [obj['label'] for obj in tgt_dict['pred_tgt'][1]]
46 | # Find the intersection between source and target labels
47 | common_labels = list(set(source_labels) & set(target_labels))
48 |
49 | # Initialize a list to store matched labels and their masks
50 | matched_labels = []
51 |
52 | # Iterate through common labels
53 | for label in common_labels:
54 | # Get masks for the current label in both source and target
55 | source_mask = get_mask_for_label(label, src_dict['pred_tgt'][1], src_dict['pred_tgt'][0])
56 | target_mask = get_mask_for_label(label, tgt_dict['pred_tgt'][1], tgt_dict['pred_tgt'][0])
57 |
58 | # get the area of the source and target masks
59 | source_mask_area = source_mask.sum()
60 | target_mask_area = target_mask.sum()
61 | # get the area of the original images for both source and target
62 | source_img_area = src_dict['pred_tgt'][0].shape[-1] * src_dict['pred_tgt'][0].shape[-2]
63 | target_img_area = tgt_dict['pred_tgt'][0].shape[-1] * tgt_dict['pred_tgt'][0].shape[-2]
64 | # get the ratio of the source and target masks to the original images
65 | source_mask_ratio = source_mask_area / source_img_area
66 | target_mask_ratio = target_mask_area / target_img_area
67 |
68 | # Skip labels with mask ratios below 1%
69 | if source_mask_ratio < 0.01 or target_mask_ratio < 0.01:
70 | logger.info(f"[Semantic Matching] Skipping {label} mask due to small mask ratio")
71 | continue
72 |
73 | matched_labels.append((label, source_mask, target_mask))
74 |
75 | return matched_labels
76 |
77 | def merge_similar_labels(seg_dict, labels=['wall', 'floor']):
78 | """
79 | Merge similar labels in the objects list.
80 |
81 | Args:
82 |         seg_dict (dict): Segmentation dictionary whose object labels are merged in place
83 |
84 | Returns:
85 |         dict: The segmentation dictionary with merged labels
86 | """
87 | for obj in seg_dict['pred_src'][1]:
88 | for label in labels:
89 | if label in obj['label']:
90 | obj['label'] = label
91 | seg_dict['pred_src'][1][seg_dict['pred_src'][1].index(obj)]['label'] = label
92 |
93 | for obj in seg_dict['pred_tgt'][1]:
94 | for label in labels:
95 | if label in obj['label']:
96 | obj['label'] = label
97 | seg_dict['pred_tgt'][1][seg_dict['pred_tgt'][1].index(obj)]['label'] = label
98 |
99 | return seg_dict
100 |
101 |
--------------------------------------------------------------------------------
/scripts/download_data.py:
--------------------------------------------------------------------------------
1 | from huggingface_hub import hf_hub_download
2 |
3 | hf_hub_download(
4 | repo_id="gradient-spaces/SceneTransfer",
5 | filename="demo.zip",
6 | repo_type="dataset",
7 | local_dir=".", # Downloads to current directory
8 | )
9 |
10 |
11 | # Download SceneTransfer.zip
12 | hf_hub_download(
13 | repo_id="gradient-spaces/SceneTransfer",
14 | filename="SceneTransfer.zip",
15 | repo_type="dataset",
16 | local_dir=".", # Downloads to current directory
17 | )
--------------------------------------------------------------------------------
/scripts/download_data.sh:
--------------------------------------------------------------------------------
1 | python scripts/download_data.py
2 | unzip demo.zip
3 | rm -rf demo.zip
4 |
5 | unzip SceneTransfer.zip
6 | rm -rf SceneTransfer.zip
--------------------------------------------------------------------------------
/scripts/download_weights.sh:
--------------------------------------------------------------------------------
1 | mkdir -p checkpoints && cd checkpoints
2 | wget https://download.europe.naverlabs.com/ComputerVision/DUSt3R/DUSt3R_ViTLarge_BaseDecoder_512_dpt.pth
3 | wget https://huggingface.co/depth-anything/Depth-Anything-V2-Large/resolve/main/depth_anything_v2_vitl.pth
4 | cd ..
5 |
6 |
7 |
--------------------------------------------------------------------------------
/third_party/depth_anything_v2/dinov2_layers/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | from .mlp import Mlp
8 | from .patch_embed import PatchEmbed
9 | from .swiglu_ffn import SwiGLUFFN, SwiGLUFFNFused
10 | from .block import NestedTensorBlock
11 | from .attention import MemEffAttention
12 |
--------------------------------------------------------------------------------
/third_party/depth_anything_v2/dinov2_layers/attention.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | # References:
8 | # https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
9 | # https://github.com/rwightman/pytorch-image-models/tree/master/timm/models/vision_transformer.py
10 |
11 | import logging
12 |
13 | from torch import Tensor
14 | from torch import nn
15 |
16 |
17 | logger = logging.getLogger("dinov2")
18 |
19 |
20 | try:
21 | from xformers.ops import memory_efficient_attention, unbind, fmha
22 |
23 | XFORMERS_AVAILABLE = True
24 | except ImportError:
25 | logger.warning("xFormers not available")
26 | XFORMERS_AVAILABLE = False
27 |
28 |
29 | class Attention(nn.Module):
30 | def __init__(
31 | self,
32 | dim: int,
33 | num_heads: int = 8,
34 | qkv_bias: bool = False,
35 | proj_bias: bool = True,
36 | attn_drop: float = 0.0,
37 | proj_drop: float = 0.0,
38 | ) -> None:
39 | super().__init__()
40 | self.num_heads = num_heads
41 | head_dim = dim // num_heads
42 | self.scale = head_dim**-0.5
43 |
44 | self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
45 | self.attn_drop = nn.Dropout(attn_drop)
46 | self.proj = nn.Linear(dim, dim, bias=proj_bias)
47 | self.proj_drop = nn.Dropout(proj_drop)
48 |
49 | def forward(self, x: Tensor) -> Tensor:
50 | B, N, C = x.shape
51 | qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
52 |
53 | q, k, v = qkv[0] * self.scale, qkv[1], qkv[2]
54 | attn = q @ k.transpose(-2, -1)
55 |
56 | attn = attn.softmax(dim=-1)
57 | attn = self.attn_drop(attn)
58 |
59 | x = (attn @ v).transpose(1, 2).reshape(B, N, C)
60 | x = self.proj(x)
61 | x = self.proj_drop(x)
62 | return x
63 |
64 |
65 | class MemEffAttention(Attention):
66 | def forward(self, x: Tensor, attn_bias=None) -> Tensor:
67 | if not XFORMERS_AVAILABLE:
68 | assert attn_bias is None, "xFormers is required for nested tensors usage"
69 | return super().forward(x)
70 |
71 | B, N, C = x.shape
72 | qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads)
73 |
74 | q, k, v = unbind(qkv, 2)
75 |
76 | x = memory_efficient_attention(q, k, v, attn_bias=attn_bias)
77 | x = x.reshape([B, N, C])
78 |
79 | x = self.proj(x)
80 | x = self.proj_drop(x)
81 | return x
82 |
83 |
--------------------------------------------------------------------------------
/third_party/depth_anything_v2/dinov2_layers/drop_path.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | # References:
8 | # https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
9 | # https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/drop.py
10 |
11 |
12 | from torch import nn
13 |
14 |
15 | def drop_path(x, drop_prob: float = 0.0, training: bool = False):
16 | if drop_prob == 0.0 or not training:
17 | return x
18 | keep_prob = 1 - drop_prob
19 | shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets
20 | random_tensor = x.new_empty(shape).bernoulli_(keep_prob)
21 | if keep_prob > 0.0:
22 | random_tensor.div_(keep_prob)
23 | output = x * random_tensor
24 | return output
25 |
26 |
27 | class DropPath(nn.Module):
28 | """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks)."""
29 |
30 | def __init__(self, drop_prob=None):
31 | super(DropPath, self).__init__()
32 | self.drop_prob = drop_prob
33 |
34 | def forward(self, x):
35 | return drop_path(x, self.drop_prob, self.training)
36 |
--------------------------------------------------------------------------------
/third_party/depth_anything_v2/dinov2_layers/layer_scale.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | # Modified from: https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/vision_transformer.py#L103-L110
8 |
9 | from typing import Union
10 |
11 | import torch
12 | from torch import Tensor
13 | from torch import nn
14 |
15 |
16 | class LayerScale(nn.Module):
17 | def __init__(
18 | self,
19 | dim: int,
20 | init_values: Union[float, Tensor] = 1e-5,
21 | inplace: bool = False,
22 | ) -> None:
23 | super().__init__()
24 | self.inplace = inplace
25 | self.gamma = nn.Parameter(init_values * torch.ones(dim))
26 |
27 | def forward(self, x: Tensor) -> Tensor:
28 | return x.mul_(self.gamma) if self.inplace else x * self.gamma
29 |
--------------------------------------------------------------------------------
/third_party/depth_anything_v2/dinov2_layers/mlp.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | # References:
8 | # https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
9 | # https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/mlp.py
10 |
11 |
12 | from typing import Callable, Optional
13 |
14 | from torch import Tensor, nn
15 |
16 |
17 | class Mlp(nn.Module):
18 | def __init__(
19 | self,
20 | in_features: int,
21 | hidden_features: Optional[int] = None,
22 | out_features: Optional[int] = None,
23 | act_layer: Callable[..., nn.Module] = nn.GELU,
24 | drop: float = 0.0,
25 | bias: bool = True,
26 | ) -> None:
27 | super().__init__()
28 | out_features = out_features or in_features
29 | hidden_features = hidden_features or in_features
30 | self.fc1 = nn.Linear(in_features, hidden_features, bias=bias)
31 | self.act = act_layer()
32 | self.fc2 = nn.Linear(hidden_features, out_features, bias=bias)
33 | self.drop = nn.Dropout(drop)
34 |
35 | def forward(self, x: Tensor) -> Tensor:
36 | x = self.fc1(x)
37 | x = self.act(x)
38 | x = self.drop(x)
39 | x = self.fc2(x)
40 | x = self.drop(x)
41 | return x
42 |
--------------------------------------------------------------------------------
/third_party/depth_anything_v2/dinov2_layers/patch_embed.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | # References:
8 | # https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
9 | # https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/patch_embed.py
10 |
11 | from typing import Callable, Optional, Tuple, Union
12 |
13 | from torch import Tensor
14 | import torch.nn as nn
15 |
16 |
17 | def make_2tuple(x):
18 | if isinstance(x, tuple):
19 | assert len(x) == 2
20 | return x
21 |
22 | assert isinstance(x, int)
23 | return (x, x)
24 |
25 |
26 | class PatchEmbed(nn.Module):
27 | """
28 | 2D image to patch embedding: (B,C,H,W) -> (B,N,D)
29 |
30 | Args:
31 | img_size: Image size.
32 | patch_size: Patch token size.
33 | in_chans: Number of input image channels.
34 | embed_dim: Number of linear projection output channels.
35 | norm_layer: Normalization layer.
36 | """
37 |
38 | def __init__(
39 | self,
40 | img_size: Union[int, Tuple[int, int]] = 224,
41 | patch_size: Union[int, Tuple[int, int]] = 16,
42 | in_chans: int = 3,
43 | embed_dim: int = 768,
44 | norm_layer: Optional[Callable] = None,
45 | flatten_embedding: bool = True,
46 | ) -> None:
47 | super().__init__()
48 |
49 | image_HW = make_2tuple(img_size)
50 | patch_HW = make_2tuple(patch_size)
51 | patch_grid_size = (
52 | image_HW[0] // patch_HW[0],
53 | image_HW[1] // patch_HW[1],
54 | )
55 |
56 | self.img_size = image_HW
57 | self.patch_size = patch_HW
58 | self.patches_resolution = patch_grid_size
59 | self.num_patches = patch_grid_size[0] * patch_grid_size[1]
60 |
61 | self.in_chans = in_chans
62 | self.embed_dim = embed_dim
63 |
64 | self.flatten_embedding = flatten_embedding
65 |
66 | self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_HW, stride=patch_HW)
67 | self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()
68 |
69 | def forward(self, x: Tensor) -> Tensor:
70 | _, _, H, W = x.shape
71 | patch_H, patch_W = self.patch_size
72 |
73 | assert H % patch_H == 0, f"Input image height {H} is not a multiple of patch height {patch_H}"
74 | assert W % patch_W == 0, f"Input image width {W} is not a multiple of patch width: {patch_W}"
75 |
76 | x = self.proj(x) # B C H W
77 | H, W = x.size(2), x.size(3)
78 | x = x.flatten(2).transpose(1, 2) # B HW C
79 | x = self.norm(x)
80 | if not self.flatten_embedding:
81 | x = x.reshape(-1, H, W, self.embed_dim) # B H W C
82 | return x
83 |
84 | def flops(self) -> float:
85 | Ho, Wo = self.patches_resolution
86 | flops = Ho * Wo * self.embed_dim * self.in_chans * (self.patch_size[0] * self.patch_size[1])
87 | if self.norm is not None:
88 | flops += Ho * Wo * self.embed_dim
89 | return flops
90 |
--------------------------------------------------------------------------------
/third_party/depth_anything_v2/dinov2_layers/swiglu_ffn.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | from typing import Callable, Optional
8 |
9 | from torch import Tensor, nn
10 | import torch.nn.functional as F
11 |
12 |
13 | class SwiGLUFFN(nn.Module):
14 | def __init__(
15 | self,
16 | in_features: int,
17 | hidden_features: Optional[int] = None,
18 | out_features: Optional[int] = None,
19 | act_layer: Callable[..., nn.Module] = None,
20 | drop: float = 0.0,
21 | bias: bool = True,
22 | ) -> None:
23 | super().__init__()
24 | out_features = out_features or in_features
25 | hidden_features = hidden_features or in_features
26 | self.w12 = nn.Linear(in_features, 2 * hidden_features, bias=bias)
27 | self.w3 = nn.Linear(hidden_features, out_features, bias=bias)
28 |
29 | def forward(self, x: Tensor) -> Tensor:
30 | x12 = self.w12(x)
31 | x1, x2 = x12.chunk(2, dim=-1)
32 | hidden = F.silu(x1) * x2
33 | return self.w3(hidden)
34 |
35 |
36 | try:
37 | from xformers.ops import SwiGLU
38 |
39 | XFORMERS_AVAILABLE = True
40 | except ImportError:
41 | SwiGLU = SwiGLUFFN
42 | XFORMERS_AVAILABLE = False
43 |
44 |
45 | class SwiGLUFFNFused(SwiGLU):
46 | def __init__(
47 | self,
48 | in_features: int,
49 | hidden_features: Optional[int] = None,
50 | out_features: Optional[int] = None,
51 | act_layer: Callable[..., nn.Module] = None,
52 | drop: float = 0.0,
53 | bias: bool = True,
54 | ) -> None:
55 | out_features = out_features or in_features
56 | hidden_features = hidden_features or in_features
57 | hidden_features = (int(hidden_features * 2 / 3) + 7) // 8 * 8
58 | super().__init__(
59 | in_features=in_features,
60 | hidden_features=hidden_features,
61 | out_features=out_features,
62 | bias=bias,
63 | )
64 |
--------------------------------------------------------------------------------
/third_party/depth_anything_v2/util/blocks.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 |
3 |
4 | def _make_scratch(in_shape, out_shape, groups=1, expand=False):
5 | scratch = nn.Module()
6 |
7 | out_shape1 = out_shape
8 | out_shape2 = out_shape
9 | out_shape3 = out_shape
10 | if len(in_shape) >= 4:
11 | out_shape4 = out_shape
12 |
13 | if expand:
14 | out_shape1 = out_shape
15 | out_shape2 = out_shape * 2
16 | out_shape3 = out_shape * 4
17 | if len(in_shape) >= 4:
18 | out_shape4 = out_shape * 8
19 |
20 | scratch.layer1_rn = nn.Conv2d(in_shape[0], out_shape1, kernel_size=3, stride=1, padding=1, bias=False, groups=groups)
21 | scratch.layer2_rn = nn.Conv2d(in_shape[1], out_shape2, kernel_size=3, stride=1, padding=1, bias=False, groups=groups)
22 | scratch.layer3_rn = nn.Conv2d(in_shape[2], out_shape3, kernel_size=3, stride=1, padding=1, bias=False, groups=groups)
23 | if len(in_shape) >= 4:
24 | scratch.layer4_rn = nn.Conv2d(in_shape[3], out_shape4, kernel_size=3, stride=1, padding=1, bias=False, groups=groups)
25 |
26 | return scratch
27 |
28 |
29 | class ResidualConvUnit(nn.Module):
30 | """Residual convolution module.
31 | """
32 |
33 | def __init__(self, features, activation, bn):
34 | """Init.
35 |
36 | Args:
37 | features (int): number of features
38 | """
39 | super().__init__()
40 |
41 | self.bn = bn
42 |
43 | self.groups=1
44 |
45 | self.conv1 = nn.Conv2d(features, features, kernel_size=3, stride=1, padding=1, bias=True, groups=self.groups)
46 |
47 | self.conv2 = nn.Conv2d(features, features, kernel_size=3, stride=1, padding=1, bias=True, groups=self.groups)
48 |
49 | if self.bn == True:
50 | self.bn1 = nn.BatchNorm2d(features)
51 | self.bn2 = nn.BatchNorm2d(features)
52 |
53 | self.activation = activation
54 |
55 | self.skip_add = nn.quantized.FloatFunctional()
56 |
57 | def forward(self, x):
58 | """Forward pass.
59 |
60 | Args:
61 | x (tensor): input
62 |
63 | Returns:
64 | tensor: output
65 | """
66 |
67 | out = self.activation(x)
68 | out = self.conv1(out)
69 | if self.bn == True:
70 | out = self.bn1(out)
71 |
72 | out = self.activation(out)
73 | out = self.conv2(out)
74 | if self.bn == True:
75 | out = self.bn2(out)
76 |
77 | if self.groups > 1:
78 | out = self.conv_merge(out)
79 |
80 | return self.skip_add.add(out, x)
81 |
82 |
83 | class FeatureFusionBlock(nn.Module):
84 | """Feature fusion block.
85 | """
86 |
87 | def __init__(
88 | self,
89 | features,
90 | activation,
91 | deconv=False,
92 | bn=False,
93 | expand=False,
94 | align_corners=True,
95 | size=None
96 | ):
97 | """Init.
98 |
99 | Args:
100 | features (int): number of features
101 | """
102 | super(FeatureFusionBlock, self).__init__()
103 |
104 | self.deconv = deconv
105 | self.align_corners = align_corners
106 |
107 | self.groups=1
108 |
109 | self.expand = expand
110 | out_features = features
111 | if self.expand == True:
112 | out_features = features // 2
113 |
114 | self.out_conv = nn.Conv2d(features, out_features, kernel_size=1, stride=1, padding=0, bias=True, groups=1)
115 |
116 | self.resConfUnit1 = ResidualConvUnit(features, activation, bn)
117 | self.resConfUnit2 = ResidualConvUnit(features, activation, bn)
118 |
119 | self.skip_add = nn.quantized.FloatFunctional()
120 |
121 | self.size=size
122 |
123 | def forward(self, *xs, size=None):
124 | """Forward pass.
125 |
126 | Returns:
127 | tensor: output
128 | """
129 | output = xs[0]
130 |
131 | if len(xs) == 2:
132 | res = self.resConfUnit1(xs[1])
133 | output = self.skip_add.add(output, res)
134 |
135 | output = self.resConfUnit2(output)
136 |
137 | if (size is None) and (self.size is None):
138 | modifier = {"scale_factor": 2}
139 | elif size is None:
140 | modifier = {"size": self.size}
141 | else:
142 | modifier = {"size": size}
143 |
144 | output = nn.functional.interpolate(output, **modifier, mode="bilinear", align_corners=self.align_corners)
145 |
146 | output = self.out_conv(output)
147 |
148 | return output
149 |
--------------------------------------------------------------------------------
/third_party/depth_anything_v2/util/transform.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import cv2
3 |
4 |
5 | class Resize(object):
6 | """Resize sample to given size (width, height).
7 | """
8 |
9 | def __init__(
10 | self,
11 | width,
12 | height,
13 | resize_target=True,
14 | keep_aspect_ratio=False,
15 | ensure_multiple_of=1,
16 | resize_method="lower_bound",
17 | image_interpolation_method=cv2.INTER_AREA,
18 | ):
19 | """Init.
20 |
21 | Args:
22 | width (int): desired output width
23 | height (int): desired output height
24 | resize_target (bool, optional):
25 | True: Resize the full sample (image, mask, target).
26 | False: Resize image only.
27 | Defaults to True.
28 | keep_aspect_ratio (bool, optional):
29 | True: Keep the aspect ratio of the input sample.
30 | Output sample might not have the given width and height, and
31 | resize behaviour depends on the parameter 'resize_method'.
32 | Defaults to False.
33 | ensure_multiple_of (int, optional):
34 | Output width and height is constrained to be multiple of this parameter.
35 | Defaults to 1.
36 | resize_method (str, optional):
37 | "lower_bound": Output will be at least as large as the given size.
38 | "upper_bound": Output will be at max as large as the given size. (Output size might be smaller than given size.)
39 | "minimal": Scale as least as possible. (Output size might be smaller than given size.)
40 | Defaults to "lower_bound".
41 | """
42 | self.__width = width
43 | self.__height = height
44 |
45 | self.__resize_target = resize_target
46 | self.__keep_aspect_ratio = keep_aspect_ratio
47 | self.__multiple_of = ensure_multiple_of
48 | self.__resize_method = resize_method
49 | self.__image_interpolation_method = image_interpolation_method
50 |
51 | def constrain_to_multiple_of(self, x, min_val=0, max_val=None):
52 | y = (np.round(x / self.__multiple_of) * self.__multiple_of).astype(int)
53 |
54 | if max_val is not None and y > max_val:
55 | y = (np.floor(x / self.__multiple_of) * self.__multiple_of).astype(int)
56 |
57 | if y < min_val:
58 | y = (np.ceil(x / self.__multiple_of) * self.__multiple_of).astype(int)
59 |
60 | return y
61 |
62 | def get_size(self, width, height):
63 | # determine new height and width
64 | scale_height = self.__height / height
65 | scale_width = self.__width / width
66 |
67 | if self.__keep_aspect_ratio:
68 | if self.__resize_method == "lower_bound":
69 | # scale such that output size is lower bound
70 | if scale_width > scale_height:
71 | # fit width
72 | scale_height = scale_width
73 | else:
74 | # fit height
75 | scale_width = scale_height
76 | elif self.__resize_method == "upper_bound":
77 | # scale such that output size is upper bound
78 | if scale_width < scale_height:
79 | # fit width
80 | scale_height = scale_width
81 | else:
82 | # fit height
83 | scale_width = scale_height
84 | elif self.__resize_method == "minimal":
85 |                 # scale as little as possible
86 | if abs(1 - scale_width) < abs(1 - scale_height):
87 | # fit width
88 | scale_height = scale_width
89 | else:
90 | # fit height
91 | scale_width = scale_height
92 | else:
93 | raise ValueError(f"resize_method {self.__resize_method} not implemented")
94 |
95 | if self.__resize_method == "lower_bound":
96 | new_height = self.constrain_to_multiple_of(scale_height * height, min_val=self.__height)
97 | new_width = self.constrain_to_multiple_of(scale_width * width, min_val=self.__width)
98 | elif self.__resize_method == "upper_bound":
99 | new_height = self.constrain_to_multiple_of(scale_height * height, max_val=self.__height)
100 | new_width = self.constrain_to_multiple_of(scale_width * width, max_val=self.__width)
101 | elif self.__resize_method == "minimal":
102 | new_height = self.constrain_to_multiple_of(scale_height * height)
103 | new_width = self.constrain_to_multiple_of(scale_width * width)
104 | else:
105 | raise ValueError(f"resize_method {self.__resize_method} not implemented")
106 |
107 | return (new_width, new_height)
108 |
109 | def __call__(self, sample):
110 | width, height = self.get_size(sample["image"].shape[1], sample["image"].shape[0])
111 |
112 | # resize sample
113 | sample["image"] = cv2.resize(sample["image"], (width, height), interpolation=self.__image_interpolation_method)
114 |
115 | if self.__resize_target:
116 | if "depth" in sample:
117 | sample["depth"] = cv2.resize(sample["depth"], (width, height), interpolation=cv2.INTER_NEAREST)
118 |
119 | if "mask" in sample:
120 | sample["mask"] = cv2.resize(sample["mask"].astype(np.float32), (width, height), interpolation=cv2.INTER_NEAREST)
121 |
122 | return sample
123 |
124 |
125 | class NormalizeImage(object):
126 | """Normlize image by given mean and std.
127 | """
128 |
129 | def __init__(self, mean, std):
130 | self.__mean = mean
131 | self.__std = std
132 |
133 | def __call__(self, sample):
134 | sample["image"] = (sample["image"] - self.__mean) / self.__std
135 |
136 | return sample
137 |
138 |
139 | class PrepareForNet(object):
140 | """Prepare sample for usage as network input.
141 | """
142 |
143 | def __init__(self):
144 | pass
145 |
146 | def __call__(self, sample):
147 | image = np.transpose(sample["image"], (2, 0, 1))
148 | sample["image"] = np.ascontiguousarray(image).astype(np.float32)
149 |
150 | if "depth" in sample:
151 | depth = sample["depth"].astype(np.float32)
152 | sample["depth"] = np.ascontiguousarray(depth)
153 |
154 | if "mask" in sample:
155 | sample["mask"] = sample["mask"].astype(np.float32)
156 | sample["mask"] = np.ascontiguousarray(sample["mask"])
157 |
158 | return sample
--------------------------------------------------------------------------------
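The three transforms above operate on a dict-style `sample` and are meant to be chained. Below is a minimal sketch of a typical composition for a single RGB image; the import path, the 518-pixel target size and the multiple-of-14 constraint are assumptions for illustration, not taken from this file.

```python
import cv2
import numpy as np
from torchvision.transforms import Compose

# Assumed import path, based on the repository layout.
from third_party.depth_anything_v2.util.transform import Resize, NormalizeImage, PrepareForNet

transform = Compose([
    Resize(
        width=518, height=518,            # assumed target size for illustration
        resize_target=False,              # only the image is resized, not depth/mask
        keep_aspect_ratio=True,
        ensure_multiple_of=14,            # keep H and W divisible by the ViT patch size
        resize_method="lower_bound",      # output is at least 518 on the shorter side
        image_interpolation_method=cv2.INTER_CUBIC,
    ),
    NormalizeImage(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    PrepareForNet(),                      # HWC float -> contiguous CHW float32
])

raw = np.random.rand(480, 640, 3).astype(np.float32)  # stand-in for an RGB image in [0, 1]
sample = transform({"image": raw})
print(sample["image"].shape)                           # (3, 518, 686) for this input
```
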
/third_party/dust3r/LICENSE:
--------------------------------------------------------------------------------
1 | DUSt3R, Copyright (c) 2024-present Naver Corporation, is licensed under the Creative Commons Attribution-NonCommercial-ShareAlike 4.0 license.
2 |
3 | A summary of the CC BY-NC-SA 4.0 license is located here:
4 | https://creativecommons.org/licenses/by-nc-sa/4.0/
5 |
6 | The CC BY-NC-SA 4.0 license is located here:
7 | https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode
8 |
--------------------------------------------------------------------------------
/third_party/dust3r/croco/LICENSE:
--------------------------------------------------------------------------------
1 | CroCo, Copyright (c) 2022-present Naver Corporation, is licensed under the Creative Commons Attribution-NonCommercial-ShareAlike 4.0 license.
2 |
3 | A summary of the CC BY-NC-SA 4.0 license is located here:
4 | https://creativecommons.org/licenses/by-nc-sa/4.0/
5 |
6 | The CC BY-NC-SA 4.0 license is located here:
7 | https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode
8 |
9 |
10 | SEE NOTICE BELOW WITH RESPECT TO THE FILE: models/pos_embed.py, models/blocks.py
11 |
12 | ***************************
13 |
14 | NOTICE WITH RESPECT TO THE FILE: models/pos_embed.py
15 |
16 | This software is being redistributed in a modified form. The original form is available here:
17 |
18 | https://github.com/facebookresearch/mae/blob/main/util/pos_embed.py
19 |
20 | This software in this file incorporates parts of the following software available here:
21 |
22 | Transformer: https://github.com/tensorflow/models/blob/master/official/legacy/transformer/model_utils.py
23 | available under the following license: https://github.com/tensorflow/models/blob/master/LICENSE
24 |
25 | MoCo v3: https://github.com/facebookresearch/moco-v3
26 | available under the following license: https://github.com/facebookresearch/moco-v3/blob/main/LICENSE
27 |
28 | DeiT: https://github.com/facebookresearch/deit
29 | available under the following license: https://github.com/facebookresearch/deit/blob/main/LICENSE
30 |
31 |
32 | ORIGINAL COPYRIGHT NOTICE AND PERMISSION NOTICE AVAILABLE HERE IS REPRODUCED BELOW:
33 |
34 | https://github.com/facebookresearch/mae/blob/main/LICENSE
35 |
36 | Attribution-NonCommercial 4.0 International
37 |
38 | ***************************
39 |
40 | NOTICE WITH RESPECT TO THE FILE: models/blocks.py
41 |
42 | This software is being redistributed in a modified form. The original form is available here:
43 |
44 | https://github.com/rwightman/pytorch-image-models
45 |
46 | ORIGINAL COPYRIGHT NOTICE AND PERMISSION NOTICE AVAILABLE HERE IS REPRODUCED BELOW:
47 |
48 | https://github.com/rwightman/pytorch-image-models/blob/master/LICENSE
49 |
50 | Apache License
51 | Version 2.0, January 2004
52 | http://www.apache.org/licenses/
--------------------------------------------------------------------------------
/third_party/dust3r/croco/NOTICE:
--------------------------------------------------------------------------------
1 | CroCo
2 | Copyright 2022-present NAVER Corp.
3 |
4 | This project contains subcomponents with separate copyright notices and license terms.
5 | Your use of the source code for these subcomponents is subject to the terms and conditions of the following licenses.
6 |
7 | ====
8 |
9 | facebookresearch/mae
10 | https://github.com/facebookresearch/mae
11 |
12 | Attribution-NonCommercial 4.0 International
13 |
14 | ====
15 |
16 | rwightman/pytorch-image-models
17 | https://github.com/rwightman/pytorch-image-models
18 |
19 | Apache License
20 | Version 2.0, January 2004
21 | http://www.apache.org/licenses/
--------------------------------------------------------------------------------
/third_party/dust3r/croco/assets/Chateau1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/GradientSpaces/ReStyle3D/5a6440eeee53783530dbbda7c26898b8b4c351da/third_party/dust3r/croco/assets/Chateau1.png
--------------------------------------------------------------------------------
/third_party/dust3r/croco/assets/Chateau2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/GradientSpaces/ReStyle3D/5a6440eeee53783530dbbda7c26898b8b4c351da/third_party/dust3r/croco/assets/Chateau2.png
--------------------------------------------------------------------------------
/third_party/dust3r/croco/assets/arch.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/GradientSpaces/ReStyle3D/5a6440eeee53783530dbbda7c26898b8b4c351da/third_party/dust3r/croco/assets/arch.jpg
--------------------------------------------------------------------------------
/third_party/dust3r/croco/croco-stereo-flow-demo.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "markdown",
5 | "id": "9bca0f41",
6 | "metadata": {},
7 | "source": [
8 | "# Simple inference example with CroCo-Stereo or CroCo-Flow"
9 | ]
10 | },
11 | {
12 | "cell_type": "code",
13 | "execution_count": null,
14 | "id": "80653ef7",
15 | "metadata": {},
16 | "outputs": [],
17 | "source": [
18 | "# Copyright (C) 2022-present Naver Corporation. All rights reserved.\n",
19 | "# Licensed under CC BY-NC-SA 4.0 (non-commercial use only)."
20 | ]
21 | },
22 | {
23 | "cell_type": "markdown",
24 | "id": "4f033862",
25 | "metadata": {},
26 | "source": [
27 | "First download the model(s) of your choice by running\n",
28 | "```\n",
29 | "bash stereoflow/download_model.sh crocostereo.pth\n",
30 | "bash stereoflow/download_model.sh crocoflow.pth\n",
31 | "```"
32 | ]
33 | },
34 | {
35 | "cell_type": "code",
36 | "execution_count": null,
37 | "id": "1fb2e392",
38 | "metadata": {},
39 | "outputs": [],
40 | "source": [
41 | "import torch\n",
42 | "use_gpu = torch.cuda.is_available() and torch.cuda.device_count()>0\n",
43 | "device = torch.device('cuda:0' if use_gpu else 'cpu')\n",
44 | "import matplotlib.pylab as plt"
45 | ]
46 | },
47 | {
48 | "cell_type": "code",
49 | "execution_count": null,
50 | "id": "e0e25d77",
51 | "metadata": {},
52 | "outputs": [],
53 | "source": [
54 | "from stereoflow.test import _load_model_and_criterion\n",
55 | "from stereoflow.engine import tiled_pred\n",
56 | "from stereoflow.datasets_stereo import img_to_tensor, vis_disparity\n",
57 | "from stereoflow.datasets_flow import flowToColor\n",
58 | "tile_overlap=0.7 # recommended value, higher value can be slightly better but slower"
59 | ]
60 | },
61 | {
62 | "cell_type": "markdown",
63 | "id": "86a921f5",
64 | "metadata": {},
65 | "source": [
66 | "### CroCo-Stereo example"
67 | ]
68 | },
69 | {
70 | "cell_type": "code",
71 | "execution_count": null,
72 | "id": "64e483cb",
73 | "metadata": {},
74 | "outputs": [],
75 | "source": [
76 | "image1 = np.asarray(Image.open(''))\n",
77 | "image2 = np.asarray(Image.open(''))"
78 | ]
79 | },
80 | {
81 | "cell_type": "code",
82 | "execution_count": null,
83 | "id": "f0d04303",
84 | "metadata": {},
85 | "outputs": [],
86 | "source": [
87 | "model, _, cropsize, with_conf, task, tile_conf_mode = _load_model_and_criterion('stereoflow_models/crocostereo.pth', None, device)\n"
88 | ]
89 | },
90 | {
91 | "cell_type": "code",
92 | "execution_count": null,
93 | "id": "47dc14b5",
94 | "metadata": {},
95 | "outputs": [],
96 | "source": [
97 | "im1 = img_to_tensor(image1).to(device).unsqueeze(0)\n",
98 | "im2 = img_to_tensor(image2).to(device).unsqueeze(0)\n",
99 | "with torch.inference_mode():\n",
100 | " pred, _, _ = tiled_pred(model, None, im1, im2, None, conf_mode=tile_conf_mode, overlap=tile_overlap, crop=cropsize, with_conf=with_conf, return_time=False)\n",
101 | "pred = pred.squeeze(0).squeeze(0).cpu().numpy()"
102 | ]
103 | },
104 | {
105 | "cell_type": "code",
106 | "execution_count": null,
107 | "id": "583b9f16",
108 | "metadata": {},
109 | "outputs": [],
110 | "source": [
111 | "plt.imshow(vis_disparity(pred))\n",
112 | "plt.axis('off')"
113 | ]
114 | },
115 | {
116 | "cell_type": "markdown",
117 | "id": "d2df5d70",
118 | "metadata": {},
119 | "source": [
120 | "### CroCo-Flow example"
121 | ]
122 | },
123 | {
124 | "cell_type": "code",
125 | "execution_count": null,
126 | "id": "9ee257a7",
127 | "metadata": {},
128 | "outputs": [],
129 | "source": [
130 | "image1 = np.asarray(Image.open(''))\n",
131 | "image2 = np.asarray(Image.open(''))"
132 | ]
133 | },
134 | {
135 | "cell_type": "code",
136 | "execution_count": null,
137 | "id": "d5edccf0",
138 | "metadata": {},
139 | "outputs": [],
140 | "source": [
141 | "model, _, cropsize, with_conf, task, tile_conf_mode = _load_model_and_criterion('stereoflow_models/crocoflow.pth', None, device)\n"
142 | ]
143 | },
144 | {
145 | "cell_type": "code",
146 | "execution_count": null,
147 | "id": "b19692c3",
148 | "metadata": {},
149 | "outputs": [],
150 | "source": [
151 | "im1 = img_to_tensor(image1).to(device).unsqueeze(0)\n",
152 | "im2 = img_to_tensor(image2).to(device).unsqueeze(0)\n",
153 | "with torch.inference_mode():\n",
154 | " pred, _, _ = tiled_pred(model, None, im1, im2, None, conf_mode=tile_conf_mode, overlap=tile_overlap, crop=cropsize, with_conf=with_conf, return_time=False)\n",
155 | "pred = pred.squeeze(0).permute(1,2,0).cpu().numpy()"
156 | ]
157 | },
158 | {
159 | "cell_type": "code",
160 | "execution_count": null,
161 | "id": "26f79db3",
162 | "metadata": {},
163 | "outputs": [],
164 | "source": [
165 | "plt.imshow(flowToColor(pred))\n",
166 | "plt.axis('off')"
167 | ]
168 | }
169 | ],
170 | "metadata": {
171 | "kernelspec": {
172 | "display_name": "Python 3 (ipykernel)",
173 | "language": "python",
174 | "name": "python3"
175 | },
176 | "language_info": {
177 | "codemirror_mode": {
178 | "name": "ipython",
179 | "version": 3
180 | },
181 | "file_extension": ".py",
182 | "mimetype": "text/x-python",
183 | "name": "python",
184 | "nbconvert_exporter": "python",
185 | "pygments_lexer": "ipython3",
186 | "version": "3.9.7"
187 | }
188 | },
189 | "nbformat": 4,
190 | "nbformat_minor": 5
191 | }
192 |
--------------------------------------------------------------------------------
/third_party/dust3r/croco/datasets/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/GradientSpaces/ReStyle3D/5a6440eeee53783530dbbda7c26898b8b4c351da/third_party/dust3r/croco/datasets/__init__.py
--------------------------------------------------------------------------------
/third_party/dust3r/croco/datasets/crops/README.MD:
--------------------------------------------------------------------------------
1 | ## Generation of crops from the real datasets
2 |
3 | The instructions below let you generate the crops used for pre-training CroCo v2 from the following real-world datasets: ARKitScenes, MegaDepth, 3DStreetView and IndoorVL.
4 |
5 | ### Download the metadata of the crops to generate
6 |
7 | First, download the metadata and put them in `./data/`:
8 | ```
9 | mkdir -p data
10 | cd data/
11 | wget https://download.europe.naverlabs.com/ComputerVision/CroCo/data/crop_metadata.zip
12 | unzip crop_metadata.zip
13 | rm crop_metadata.zip
14 | cd ..
15 | ```
16 |
17 | ### Prepare the original datasets
18 |
19 | Second, download the original datasets in `./data/original_datasets/`.
20 | ```
21 | mkdir -p data/original_datasets
22 | ```
23 |
24 | ##### ARKitScenes
25 |
26 | Download the `raw` dataset from https://github.com/apple/ARKitScenes/blob/main/DATA.md and put it in `./data/original_datasets/ARKitScenes/`.
27 | The resulting file structure should look like this:
28 | ```
29 | ./data/original_datasets/ARKitScenes/
30 | └───Training
31 | └───40753679
32 | │ │ ultrawide
33 | │ │ ...
34 | └───40753686
35 | │
36 | ...
37 | ```
38 |
39 | ##### MegaDepth
40 |
41 | Download `MegaDepth v1 Dataset` from https://www.cs.cornell.edu/projects/megadepth/ and put it in `./data/original_datasets/MegaDepth/`.
42 | The resulting file structure should look like this:
43 |
44 | ```
45 | ./data/original_datasets/MegaDepth/
46 | └───0000
47 | │ └───images
48 | │ │ │ 1000557903_87fa96b8a4_o.jpg
49 | │ │ └ ...
50 | │ └─── ...
51 | └───0001
52 | │ │
53 | │ └ ...
54 | └─── ...
55 | ```
56 |
57 | ##### 3DStreetView
58 |
59 | Download `3D_Street_View` dataset from https://github.com/amir32002/3D_Street_View and put it in `./data/original_datasets/3DStreetView/`.
60 | The resulting file structure should look like this:
61 |
62 | ```
63 | ./data/original_datasets/3DStreetView/
64 | └───dataset_aligned
65 | │ └───0002
66 | │ │ │ 0000002_0000001_0000002_0000001.jpg
67 | │ │ └ ...
68 | │ └─── ...
69 | └───dataset_unaligned
70 | │ └───0003
71 | │ │ │ 0000003_0000001_0000002_0000001.jpg
72 | │ │ └ ...
73 | │ └─── ...
74 | ```
75 |
76 | ##### IndoorVL
77 |
78 | Download the `IndoorVL` datasets using [Kapture](https://github.com/naver/kapture).
79 |
80 | ```
81 | pip install kapture
82 | mkdir -p ./data/original_datasets/IndoorVL
83 | cd ./data/original_datasets/IndoorVL
84 | kapture_download_dataset.py update
85 | kapture_download_dataset.py install "HyundaiDepartmentStore_*"
86 | kapture_download_dataset.py install "GangnamStation_*"
87 | cd -
88 | ```
89 |
90 | ### Extract the crops
91 |
92 | Now, extract the crops for each of the datasets:
93 | ```
94 | for dataset in ARKitScenes MegaDepth 3DStreetView IndoorVL;
95 | do
96 | python3 datasets/crops/extract_crops_from_images.py --crops ./data/crop_metadata/${dataset}/crops_release.txt --root-dir ./data/original_datasets/${dataset}/ --output-dir ./data/${dataset}_crops/ --imsize 256 --nthread 8 --max-subdir-levels 5 --ideal-number-pairs-in-dir 500;
97 | done
98 | ```
99 |
100 | ##### Note for IndoorVL
101 |
102 | Due to some legal issues, we can only release 144,228 pairs out of the 1,593,689 pairs used in the paper.
103 | To compensate in terms of the number of pre-training iterations, the pre-training command in this repository uses 125 training epochs (including 12 warm-up epochs) and a cosine learning-rate schedule over 250 epochs, instead of 100, 10 and 200 respectively.
104 | The impact on the performance is negligible.
105 |
--------------------------------------------------------------------------------
/third_party/dust3r/croco/datasets/crops/extract_crops_from_images.py:
--------------------------------------------------------------------------------
1 | # Copyright (C) 2022-present Naver Corporation. All rights reserved.
2 | # Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
3 | #
4 | # --------------------------------------------------------
5 | # Extracting crops for pre-training
6 | # --------------------------------------------------------
7 |
8 | import os
9 | import argparse
10 | from tqdm import tqdm
11 | from PIL import Image
12 | import functools
13 | from multiprocessing import Pool
14 | import math
15 |
16 |
17 | def arg_parser():
18 | parser = argparse.ArgumentParser('Generate cropped image pairs from image crop list')
19 |
20 | parser.add_argument('--crops', type=str, required=True, help='crop file')
21 | parser.add_argument('--root-dir', type=str, required=True, help='root directory')
22 | parser.add_argument('--output-dir', type=str, required=True, help='output directory')
23 | parser.add_argument('--imsize', type=int, default=256, help='size of the crops')
24 | parser.add_argument('--nthread', type=int, required=True, help='number of simultaneous threads')
25 | parser.add_argument('--max-subdir-levels', type=int, default=5, help='maximum number of subdirectories')
26 | parser.add_argument('--ideal-number-pairs-in-dir', type=int, default=500, help='number of pairs stored in a dir')
27 | return parser
28 |
29 |
30 | def main(args):
31 | listing_path = os.path.join(args.output_dir, 'listing.txt')
32 |
33 | print(f'Loading list of crops ... ({args.nthread} threads)')
34 | crops, num_crops_to_generate = load_crop_file(args.crops)
35 |
36 | print(f'Preparing jobs ({len(crops)} candidate image pairs)...')
37 | num_levels = min(math.ceil(math.log(num_crops_to_generate, args.ideal_number_pairs_in_dir)), args.max_subdir_levels)
38 | num_pairs_in_dir = math.ceil(num_crops_to_generate ** (1/num_levels))
39 |
40 | jobs = prepare_jobs(crops, num_levels, num_pairs_in_dir)
41 | del crops
42 |
43 | os.makedirs(args.output_dir, exist_ok=True)
44 | mmap = Pool(args.nthread).imap_unordered if args.nthread > 1 else map
45 | call = functools.partial(save_image_crops, args)
46 |
47 | print(f"Generating cropped images to {args.output_dir} ...")
48 | with open(listing_path, 'w') as listing:
49 | listing.write('# pair_path\n')
50 | for results in tqdm(mmap(call, jobs), total=len(jobs)):
51 | for path in results:
52 | listing.write(f'{path}\n')
53 | print('Finished writing listing to', listing_path)
54 |
55 |
56 | def load_crop_file(path):
57 | data = open(path).read().splitlines()
58 | pairs = []
59 | num_crops_to_generate = 0
60 | for line in tqdm(data):
61 | if line.startswith('#'):
62 | continue
63 | line = line.split(', ')
64 | if len(line) < 8:
65 | img1, img2, rotation = line
66 | pairs.append((img1, img2, int(rotation), []))
67 | else:
68 | l1, r1, t1, b1, l2, r2, t2, b2 = map(int, line)
69 | rect1, rect2 = (l1, t1, r1, b1), (l2, t2, r2, b2)
70 | pairs[-1][-1].append((rect1, rect2))
71 | num_crops_to_generate += 1
72 | return pairs, num_crops_to_generate
73 |
74 |
75 | def prepare_jobs(pairs, num_levels, num_pairs_in_dir):
76 | jobs = []
77 | powers = [num_pairs_in_dir**level for level in reversed(range(num_levels))]
78 |
79 | def get_path(idx):
80 | idx_array = []
81 | d = idx
82 | for level in range(num_levels - 1):
83 | idx_array.append(idx // powers[level])
84 | idx = idx % powers[level]
85 | idx_array.append(d)
86 | return '/'.join(map(lambda x: hex(x)[2:], idx_array))
87 |
88 | idx = 0
89 | for pair_data in tqdm(pairs):
90 | img1, img2, rotation, crops = pair_data
91 | if -60 <= rotation and rotation <= 60:
92 | rotation = 0 # most likely not a true rotation
93 | paths = [get_path(idx + k) for k in range(len(crops))]
94 | idx += len(crops)
95 | jobs.append(((img1, img2), rotation, crops, paths))
96 | return jobs
97 |
98 |
99 | def load_image(path):
100 | try:
101 | return Image.open(path).convert('RGB')
102 | except Exception as e:
103 | print('skipping', path, e)
104 | raise OSError()
105 |
106 |
107 | def save_image_crops(args, data):
108 | # load images
109 | img_pair, rot, crops, paths = data
110 | try:
111 | img1, img2 = [load_image(os.path.join(args.root_dir, impath)) for impath in img_pair]
112 | except OSError as e:
113 | return []
114 |
115 | def area(sz):
116 | return sz[0] * sz[1]
117 |
118 | tgt_size = (args.imsize, args.imsize)
119 |
120 | def prepare_crop(img, rect, rot=0):
121 | # actual crop
122 | img = img.crop(rect)
123 |
124 | # resize to desired size
125 | interp = Image.Resampling.LANCZOS if area(img.size) > 4*area(tgt_size) else Image.Resampling.BICUBIC
126 | img = img.resize(tgt_size, resample=interp)
127 |
128 | # rotate the image
129 | rot90 = (round(rot/90) % 4) * 90
130 | if rot90 == 90:
131 | img = img.transpose(Image.Transpose.ROTATE_90)
132 | elif rot90 == 180:
133 | img = img.transpose(Image.Transpose.ROTATE_180)
134 | elif rot90 == 270:
135 | img = img.transpose(Image.Transpose.ROTATE_270)
136 | return img
137 |
138 | results = []
139 | for (rect1, rect2), path in zip(crops, paths):
140 | crop1 = prepare_crop(img1, rect1)
141 | crop2 = prepare_crop(img2, rect2, rot)
142 |
143 | fullpath1 = os.path.join(args.output_dir, path+'_1.jpg')
144 | fullpath2 = os.path.join(args.output_dir, path+'_2.jpg')
145 | os.makedirs(os.path.dirname(fullpath1), exist_ok=True)
146 |
147 | assert not os.path.isfile(fullpath1), fullpath1
148 | assert not os.path.isfile(fullpath2), fullpath2
149 | crop1.save(fullpath1)
150 | crop2.save(fullpath2)
151 | results.append(path)
152 |
153 | return results
154 |
155 |
156 | if __name__ == '__main__':
157 | args = arg_parser().parse_args()
158 | main(args)
159 |
160 |
--------------------------------------------------------------------------------
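For readers following `prepare_jobs` above: `get_path` spreads the generated crops over `num_levels - 1` directory levels and uses the full crop index, in hex, as the file stem. Below is a standalone re-statement of that logic with small, hypothetical values (the real script derives them from `--ideal-number-pairs-in-dir` and `--max-subdir-levels`).

```python
# Hypothetical values chosen for readability.
num_levels = 2
num_pairs_in_dir = 10
powers = [num_pairs_in_dir ** level for level in reversed(range(num_levels))]  # [10, 1]

def get_path(idx):
    idx_array = []
    d = idx
    for level in range(num_levels - 1):
        idx_array.append(idx // powers[level])  # directory index at this level
        idx = idx % powers[level]
    idx_array.append(d)                         # file stem: the full, undivided index
    return '/'.join(hex(x)[2:] for x in idx_array)

print(get_path(7))   # "0/7"  -> crops saved as 0/7_1.jpg and 0/7_2.jpg
print(get_path(37))  # "3/25" -> directory 3 (37 // 10), stem 25 (hex of 37)
```
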
/third_party/dust3r/croco/datasets/habitat_sim/README.MD:
--------------------------------------------------------------------------------
1 | ## Generation of synthetic image pairs using Habitat-Sim
2 |
3 | These instructions let you generate pre-training pairs from the Habitat simulator.
4 | As we did not save the metadata of the pairs used in the original paper, the generated pairs are not strictly the same, but they use the same settings and are equivalent.
5 |
6 | ### Download Habitat-Sim scenes
7 | Download Habitat-Sim scenes:
8 | - Download links can be found here: https://github.com/facebookresearch/habitat-sim/blob/main/DATASETS.md
9 | - We used scenes from the HM3D, habitat-test-scenes, Replica, ReplicaCad and ScanNet datasets.
10 | - Please put the scenes under `./data/habitat-sim-data/scene_datasets/` following the structure below, or manually update the paths in `paths.py`.
11 | ```
12 | ./data/
13 | └──habitat-sim-data/
14 | └──scene_datasets/
15 | ├──hm3d/
16 | ├──gibson/
17 | ├──habitat-test-scenes/
18 | ├──replica_cad_baked_lighting/
19 | ├──replica_cad/
20 | ├──ReplicaDataset/
21 | └──scannet/
22 | ```
23 |
24 | ### Image pairs generation
25 | We provide metadata to generate reproducible image pairs for pretraining and validation.
26 | The experiments described in the paper used similar data, but their generation was not reproducible at the time.
27 |
28 | Specifications:
29 | - 256x256 resolution images, with a 60-degree field of view.
30 | - Up to 1000 image pairs per scene.
31 | - Number of scenes considered / number of image pairs per dataset:
32 |   - ScanNet: 1097 scenes / 985,209 pairs
33 |   - HM3D:
34 |     - hm3d/train: 800 scenes / 800k pairs
35 | - hm3d/val: 100 scenes / 100k pairs
36 | - hm3d/minival: 10 scenes / 10k pairs
37 | - habitat-test-scenes: 3 scenes / 3k pairs
38 | - replica_cad_baked_lighting: 13 scenes / 13k pairs
39 |
40 | - Pairs from the hm3d/val and hm3d/minival scenes were not used for pre-training but kept for validation purposes.
41 |
42 | Download metadata and extract it:
43 | ```bash
44 | mkdir -p data/habitat_release_metadata/
45 | cd data/habitat_release_metadata/
46 | wget https://download.europe.naverlabs.com/ComputerVision/CroCo/data/habitat_release_metadata/multiview_habitat_metadata.tar.gz
47 | tar -xvf multiview_habitat_metadata.tar.gz
48 | cd ../..
49 | # Location of the metadata
50 | METADATA_DIR="./data/habitat_release_metadata/multiview_habitat_metadata"
51 | ```
52 |
53 | Generate image pairs from metadata:
54 | - The following command will print a list of commandlines to generate image pairs for each scene:
55 | ```bash
56 | # Target output directory
57 | PAIRS_DATASET_DIR="./data/habitat_release/"
58 | python datasets/habitat_sim/generate_from_metadata_files.py --input_dir=$METADATA_DIR --output_dir=$PAIRS_DATASET_DIR
59 | ```
60 | - One can launch several of these commands in parallel, e.g. using GNU Parallel:
61 | ```bash
62 | python datasets/habitat_sim/generate_from_metadata_files.py --input_dir=$METADATA_DIR --output_dir=$PAIRS_DATASET_DIR | parallel -j 16
63 | ```
64 |
65 | ## Metadata generation
66 |
67 | Image pairs were randomly sampled using the following commands, whose outputs contain randomness and are thus not exactly reproducible:
68 | ```bash
69 | # Print commandlines to generate image pairs from the different scenes available.
70 | PAIRS_DATASET_DIR=MY_CUSTOM_PATH
71 | python datasets/habitat_sim/generate_multiview_images.py --list_commands --output_dir=$PAIRS_DATASET_DIR
72 |
73 | # Once a dataset is generated, pack metadata files for reproducibility.
74 | METADATA_DIR=MY_CUSTOM_PATH
75 | python datasets/habitat_sim/pack_metadata_files.py $PAIRS_DATASET_DIR $METADATA_DIR
76 | ```
77 |
--------------------------------------------------------------------------------
/third_party/dust3r/croco/datasets/habitat_sim/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/GradientSpaces/ReStyle3D/5a6440eeee53783530dbbda7c26898b8b4c351da/third_party/dust3r/croco/datasets/habitat_sim/__init__.py
--------------------------------------------------------------------------------
/third_party/dust3r/croco/datasets/habitat_sim/generate_from_metadata.py:
--------------------------------------------------------------------------------
1 | # Copyright (C) 2022-present Naver Corporation. All rights reserved.
2 | # Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
3 |
4 | """
5 | Script to generate image pairs for a given scene reproducing poses provided in a metadata file.
6 | """
7 | import os
8 | from datasets.habitat_sim.multiview_habitat_sim_generator import MultiviewHabitatSimGenerator
9 | from datasets.habitat_sim.paths import SCENES_DATASET
10 | import argparse
11 | import quaternion
12 | import PIL.Image
13 | import cv2
14 | import json
15 | from tqdm import tqdm
16 |
17 | def generate_multiview_images_from_metadata(metadata_filename,
18 | output_dir,
19 | overload_params = dict(),
20 | scene_datasets_paths=None,
21 | exist_ok=False):
22 | """
23 | Generate images from a metadata file for reproducibility purposes.
24 | """
25 |     # Reorder paths by decreasing label length, to avoid collisions when testing if a string starts with such a label
26 | if scene_datasets_paths is not None:
27 | scene_datasets_paths = dict(sorted(scene_datasets_paths.items(), key= lambda x: len(x[0]), reverse=True))
28 |
29 | with open(metadata_filename, 'r') as f:
30 | input_metadata = json.load(f)
31 | metadata = dict()
32 | for key, value in input_metadata.items():
33 | # Optionally replace some paths
34 | if key in ("scene_dataset_config_file", "scene", "navmesh") and value != "":
35 | if scene_datasets_paths is not None:
36 | for dataset_label, dataset_path in scene_datasets_paths.items():
37 | if value.startswith(dataset_label):
38 | value = os.path.normpath(os.path.join(dataset_path, os.path.relpath(value, dataset_label)))
39 | break
40 | metadata[key] = value
41 |
42 | # Overload some parameters
43 | for key, value in overload_params.items():
44 | metadata[key] = value
45 |
46 | generation_entries = dict([(key, value) for key, value in metadata.items() if not (key in ('multiviews', 'output_dir', 'generate_depth'))])
47 | generate_depth = metadata["generate_depth"]
48 |
49 | os.makedirs(output_dir, exist_ok=exist_ok)
50 |
51 | generator = MultiviewHabitatSimGenerator(**generation_entries)
52 |
53 | # Generate views
54 | for idx_label, data in tqdm(metadata['multiviews'].items()):
55 | positions = data["positions"]
56 | orientations = data["orientations"]
57 | n = len(positions)
58 | for oidx in range(n):
59 | observation = generator.render_viewpoint(positions[oidx], quaternion.from_float_array(orientations[oidx]))
60 | observation_label = f"{oidx + 1}" # Leonid is indexing starting from 1
61 | # Color image saved using PIL
62 | img = PIL.Image.fromarray(observation['color'][:,:,:3])
63 | filename = os.path.join(output_dir, f"{idx_label}_{observation_label}.jpeg")
64 | img.save(filename)
65 | if generate_depth:
66 | # Depth image as EXR file
67 | filename = os.path.join(output_dir, f"{idx_label}_{observation_label}_depth.exr")
68 | cv2.imwrite(filename, observation['depth'], [cv2.IMWRITE_EXR_TYPE, cv2.IMWRITE_EXR_TYPE_HALF])
69 | # Camera parameters
70 | camera_params = dict([(key, observation[key].tolist()) for key in ("camera_intrinsics", "R_cam2world", "t_cam2world")])
71 | filename = os.path.join(output_dir, f"{idx_label}_{observation_label}_camera_params.json")
72 | with open(filename, "w") as f:
73 | json.dump(camera_params, f)
74 | # Save metadata
75 | with open(os.path.join(output_dir, "metadata.json"), "w") as f:
76 | json.dump(metadata, f)
77 |
78 | generator.close()
79 |
80 | if __name__ == "__main__":
81 | parser = argparse.ArgumentParser()
82 | parser.add_argument("--metadata_filename", required=True)
83 | parser.add_argument("--output_dir", required=True)
84 | args = parser.parse_args()
85 |
86 | generate_multiview_images_from_metadata(metadata_filename=args.metadata_filename,
87 | output_dir=args.output_dir,
88 | scene_datasets_paths=SCENES_DATASET,
89 | overload_params=dict(),
90 | exist_ok=True)
91 |
92 |
--------------------------------------------------------------------------------
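The metadata stores `scene`, `navmesh` and `scene_dataset_config_file` with a dataset label as prefix (see `pack_metadata_files.py` below); the loop at the top of `generate_multiview_images_from_metadata` maps them back to local paths. A minimal sketch of just that remapping step, with a hypothetical metadata value:

```python
import os

# Subset of SCENES_DATASET from paths.py; the scene value below is hypothetical.
scene_datasets_paths = {
    "hm3d": "./data/habitat-sim-data/scene_datasets/hm3d/",
    "habitat-test-scenes": "./data/habitat-sim/scene_datasets/habitat-test-scenes/",
}
# Longest labels first, so a longer label can never be shadowed by a shorter prefix.
scene_datasets_paths = dict(sorted(scene_datasets_paths.items(), key=lambda x: len(x[0]), reverse=True))

value = "hm3d/train/00001-abc/abc.basis.glb"
for dataset_label, dataset_path in scene_datasets_paths.items():
    if value.startswith(dataset_label):
        value = os.path.normpath(os.path.join(dataset_path, os.path.relpath(value, dataset_label)))
        break

print(value)  # data/habitat-sim-data/scene_datasets/hm3d/train/00001-abc/abc.basis.glb
```
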
/third_party/dust3r/croco/datasets/habitat_sim/generate_from_metadata_files.py:
--------------------------------------------------------------------------------
1 | # Copyright (C) 2022-present Naver Corporation. All rights reserved.
2 | # Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
3 |
4 | """
5 | Script generating commandlines to generate image pairs from metadata files.
6 | """
7 | import os
8 | import glob
9 | from tqdm import tqdm
10 | import argparse
11 |
12 | if __name__ == "__main__":
13 | parser = argparse.ArgumentParser()
14 | parser.add_argument("--input_dir", required=True)
15 | parser.add_argument("--output_dir", required=True)
16 | parser.add_argument("--prefix", default="", help="Commanline prefix, useful e.g. to setup environment.")
17 | args = parser.parse_args()
18 |
19 | input_metadata_filenames = glob.iglob(f"{args.input_dir}/**/metadata.json", recursive=True)
20 |
21 | for metadata_filename in tqdm(input_metadata_filenames):
22 | output_dir = os.path.join(args.output_dir, os.path.relpath(os.path.dirname(metadata_filename), args.input_dir))
23 | # Do not process the scene if the metadata file already exists
24 | if os.path.exists(os.path.join(output_dir, "metadata.json")):
25 | continue
26 | commandline = f"{args.prefix}python datasets/habitat_sim/generate_from_metadata.py --metadata_filename={metadata_filename} --output_dir={output_dir}"
27 | print(commandline)
28 |
--------------------------------------------------------------------------------
/third_party/dust3r/croco/datasets/habitat_sim/pack_metadata_files.py:
--------------------------------------------------------------------------------
1 | # Copyright (C) 2022-present Naver Corporation. All rights reserved.
2 | # Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
3 | """
4 | Utility script to pack metadata files of the dataset in order to be able to re-generate it elsewhere.
5 | """
6 | import os
7 | import glob
8 | from tqdm import tqdm
9 | import shutil
10 | import json
11 | from datasets.habitat_sim.paths import *
12 | import argparse
13 | import collections
14 |
15 | if __name__ == "__main__":
16 | parser = argparse.ArgumentParser()
17 | parser.add_argument("input_dir")
18 | parser.add_argument("output_dir")
19 | args = parser.parse_args()
20 |
21 | input_dirname = args.input_dir
22 | output_dirname = args.output_dir
23 |
24 | input_metadata_filenames = glob.iglob(f"{input_dirname}/**/metadata.json", recursive=True)
25 |
26 | images_count = collections.defaultdict(lambda : 0)
27 |
28 | os.makedirs(output_dirname)
29 | for input_filename in tqdm(input_metadata_filenames):
30 |         # Ignore metadata files with no views
31 | with open(input_filename, "r") as f:
32 | original_metadata = json.load(f)
33 | if "multiviews" not in original_metadata or len(original_metadata["multiviews"]) == 0:
34 | print("No views in", input_filename)
35 | continue
36 |
37 | relpath = os.path.relpath(input_filename, input_dirname)
38 | print(relpath)
39 |
40 | # Copy metadata, while replacing scene paths by generic keys depending on the dataset, for portability.
41 |         # Data paths are sorted by decreasing length to avoid potential bugs due to paths starting with the same string pattern.
42 | scenes_dataset_paths = dict(sorted(SCENES_DATASET.items(), key=lambda x: len(x[1]), reverse=True))
43 | metadata = dict()
44 | for key, value in original_metadata.items():
45 | if key in ("scene_dataset_config_file", "scene", "navmesh") and value != "":
46 | known_path = False
47 | for dataset, dataset_path in scenes_dataset_paths.items():
48 | if value.startswith(dataset_path):
49 | value = os.path.join(dataset, os.path.relpath(value, dataset_path))
50 | known_path = True
51 | break
52 | if not known_path:
53 | raise KeyError("Unknown path:" + value)
54 | metadata[key] = value
55 |
56 | # Compile some general statistics while packing data
57 | scene_split = metadata["scene"].split("/")
58 | upper_level = "/".join(scene_split[:2]) if scene_split[0] == "hm3d" else scene_split[0]
59 | images_count[upper_level] += len(metadata["multiviews"])
60 |
61 | output_filename = os.path.join(output_dirname, relpath)
62 | os.makedirs(os.path.dirname(output_filename), exist_ok=True)
63 | with open(output_filename, "w") as f:
64 | json.dump(metadata, f)
65 |
66 | # Print statistics
67 | print("Images count:")
68 | for upper_level, count in images_count.items():
69 | print(f"- {upper_level}: {count}")
--------------------------------------------------------------------------------
/third_party/dust3r/croco/datasets/habitat_sim/paths.py:
--------------------------------------------------------------------------------
1 | # Copyright (C) 2022-present Naver Corporation. All rights reserved.
2 | # Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
3 |
4 | """
5 | Paths to Habitat-Sim scenes
6 | """
7 |
8 | import os
9 | import json
10 | import collections
11 | from tqdm import tqdm
12 |
13 |
14 | # Hardcoded path to the different scene datasets
15 | SCENES_DATASET = {
16 | "hm3d": "./data/habitat-sim-data/scene_datasets/hm3d/",
17 | "gibson": "./data/habitat-sim-data/scene_datasets/gibson/",
18 | "habitat-test-scenes": "./data/habitat-sim/scene_datasets/habitat-test-scenes/",
19 | "replica_cad_baked_lighting": "./data/habitat-sim/scene_datasets/replica_cad_baked_lighting/",
20 | "replica_cad": "./data/habitat-sim/scene_datasets/replica_cad/",
21 | "replica": "./data/habitat-sim/scene_datasets/ReplicaDataset/",
22 | "scannet": "./data/habitat-sim/scene_datasets/scannet/"
23 | }
24 |
25 | SceneData = collections.namedtuple("SceneData", ["scene_dataset_config_file", "scene", "navmesh", "output_dir"])
26 |
27 | def list_replicacad_scenes(base_output_dir, base_path=SCENES_DATASET["replica_cad"]):
28 | scene_dataset_config_file = os.path.join(base_path, "replicaCAD.scene_dataset_config.json")
29 | scenes = [f"apt_{i}" for i in range(6)] + ["empty_stage"]
30 | navmeshes = [f"navmeshes/apt_{i}_static_furniture.navmesh" for i in range(6)] + ["empty_stage.navmesh"]
31 | scenes_data = []
32 | for idx in range(len(scenes)):
33 | output_dir = os.path.join(base_output_dir, "ReplicaCAD", scenes[idx])
34 | # Add scene
35 | data = SceneData(scene_dataset_config_file=scene_dataset_config_file,
36 | scene = scenes[idx] + ".scene_instance.json",
37 | navmesh = os.path.join(base_path, navmeshes[idx]),
38 | output_dir = output_dir)
39 | scenes_data.append(data)
40 | return scenes_data
41 |
42 | def list_replica_cad_baked_lighting_scenes(base_output_dir, base_path=SCENES_DATASET["replica_cad_baked_lighting"]):
43 | scene_dataset_config_file = os.path.join(base_path, "replicaCAD_baked.scene_dataset_config.json")
44 | scenes = sum([[f"Baked_sc{i}_staging_{j:02}" for i in range(5)] for j in range(21)], [])
45 | navmeshes = ""#[f"navmeshes/apt_{i}_static_furniture.navmesh" for i in range(6)] + ["empty_stage.navmesh"]
46 | scenes_data = []
47 | for idx in range(len(scenes)):
48 | output_dir = os.path.join(base_output_dir, "replica_cad_baked_lighting", scenes[idx])
49 | data = SceneData(scene_dataset_config_file=scene_dataset_config_file,
50 | scene = scenes[idx],
51 | navmesh = "",
52 | output_dir = output_dir)
53 | scenes_data.append(data)
54 | return scenes_data
55 |
56 | def list_replica_scenes(base_output_dir, base_path):
57 | scenes_data = []
58 | for scene_id in os.listdir(base_path):
59 | scene = os.path.join(base_path, scene_id, "mesh.ply")
60 | navmesh = os.path.join(base_path, scene_id, "habitat/mesh_preseg_semantic.navmesh") # Not sure if I should use it
61 | scene_dataset_config_file = ""
62 | output_dir = os.path.join(base_output_dir, scene_id)
63 | # Add scene only if it does not exist already, or if exist_ok
64 | data = SceneData(scene_dataset_config_file = scene_dataset_config_file,
65 | scene = scene,
66 | navmesh = navmesh,
67 | output_dir = output_dir)
68 | scenes_data.append(data)
69 | return scenes_data
70 |
71 |
72 | def list_scenes(base_output_dir, base_path):
73 | """
74 | Generic method iterating through a base_path folder to find scenes.
75 | """
76 | scenes_data = []
77 | for root, dirs, files in os.walk(base_path, followlinks=True):
78 | folder_scenes_data = []
79 | for file in files:
80 | name, ext = os.path.splitext(file)
81 | if ext == ".glb":
82 | scene = os.path.join(root, name + ".glb")
83 | navmesh = os.path.join(root, name + ".navmesh")
84 | if not os.path.exists(navmesh):
85 | navmesh = ""
86 | relpath = os.path.relpath(root, base_path)
87 | output_dir = os.path.abspath(os.path.join(base_output_dir, relpath, name))
88 | data = SceneData(scene_dataset_config_file="",
89 | scene = scene,
90 | navmesh = navmesh,
91 | output_dir = output_dir)
92 | folder_scenes_data.append(data)
93 |
94 | # Specific check for HM3D:
95 |         # When two meshes xxxx.basis.glb and xxxx.glb are present, use the 'basis' version.
96 | basis_scenes = [data.scene[:-len(".basis.glb")] for data in folder_scenes_data if data.scene.endswith(".basis.glb")]
97 | if len(basis_scenes) != 0:
98 | folder_scenes_data = [data for data in folder_scenes_data if not (data.scene[:-len(".glb")] in basis_scenes)]
99 |
100 | scenes_data.extend(folder_scenes_data)
101 | return scenes_data
102 |
103 | def list_scenes_available(base_output_dir, scenes_dataset_paths=SCENES_DATASET):
104 | scenes_data = []
105 |
106 | # HM3D
107 | for split in ("minival", "train", "val", "examples"):
108 | scenes_data += list_scenes(base_output_dir=os.path.join(base_output_dir, f"hm3d/{split}/"),
109 | base_path=f"{scenes_dataset_paths['hm3d']}/{split}")
110 |
111 | # Gibson
112 | scenes_data += list_scenes(base_output_dir=os.path.join(base_output_dir, "gibson"),
113 | base_path=scenes_dataset_paths["gibson"])
114 |
115 | # Habitat test scenes (just a few)
116 | scenes_data += list_scenes(base_output_dir=os.path.join(base_output_dir, "habitat-test-scenes"),
117 | base_path=scenes_dataset_paths["habitat-test-scenes"])
118 |
119 | # ReplicaCAD (baked lightning)
120 | scenes_data += list_replica_cad_baked_lighting_scenes(base_output_dir=base_output_dir)
121 |
122 | # ScanNet
123 | scenes_data += list_scenes(base_output_dir=os.path.join(base_output_dir, "scannet"),
124 | base_path=scenes_dataset_paths["scannet"])
125 |
126 | # Replica
127 |     scenes_data += list_replica_scenes(base_output_dir=os.path.join(base_output_dir, "replica"),
128 |                                        base_path=scenes_dataset_paths["replica"])
129 | return scenes_data
130 |
--------------------------------------------------------------------------------
/third_party/dust3r/croco/datasets/pairs_dataset.py:
--------------------------------------------------------------------------------
1 | # Copyright (C) 2022-present Naver Corporation. All rights reserved.
2 | # Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
3 |
4 | import os
5 | from torch.utils.data import Dataset
6 | from PIL import Image
7 |
8 | from datasets.transforms import get_pair_transforms
9 |
10 | def load_image(impath):
11 | return Image.open(impath)
12 |
13 | def load_pairs_from_cache_file(fname, root=''):
14 | assert os.path.isfile(fname), "cannot parse pairs from {:s}, file does not exist".format(fname)
15 | with open(fname, 'r') as fid:
16 | lines = fid.read().strip().splitlines()
17 | pairs = [ (os.path.join(root,l.split()[0]), os.path.join(root,l.split()[1])) for l in lines]
18 | return pairs
19 |
20 | def load_pairs_from_list_file(fname, root=''):
21 | assert os.path.isfile(fname), "cannot parse pairs from {:s}, file does not exist".format(fname)
22 | with open(fname, 'r') as fid:
23 | lines = fid.read().strip().splitlines()
24 | pairs = [ (os.path.join(root,l+'_1.jpg'), os.path.join(root,l+'_2.jpg')) for l in lines if not l.startswith('#')]
25 | return pairs
26 |
27 |
28 | def write_cache_file(fname, pairs, root=''):
29 | if len(root)>0:
30 | if not root.endswith('/'): root+='/'
31 | assert os.path.isdir(root)
32 | s = ''
33 | for im1, im2 in pairs:
34 | if len(root)>0:
35 | assert im1.startswith(root), im1
36 | assert im2.startswith(root), im2
37 | s += '{:s} {:s}\n'.format(im1[len(root):], im2[len(root):])
38 | with open(fname, 'w') as fid:
39 | fid.write(s[:-1])
40 |
41 | def parse_and_cache_all_pairs(dname, data_dir='./data/'):
42 | if dname=='habitat_release':
43 | dirname = os.path.join(data_dir, 'habitat_release')
44 | assert os.path.isdir(dirname), "cannot find folder for habitat_release pairs: "+dirname
45 | cache_file = os.path.join(dirname, 'pairs.txt')
46 | assert not os.path.isfile(cache_file), "cache file already exists: "+cache_file
47 |
48 | print('Parsing pairs for dataset: '+dname)
49 | pairs = []
50 | for root, dirs, files in os.walk(dirname):
51 | if 'val' in root: continue
52 | dirs.sort()
53 | pairs += [ (os.path.join(root,f), os.path.join(root,f[:-len('_1.jpeg')]+'_2.jpeg')) for f in sorted(files) if f.endswith('_1.jpeg')]
54 | print('Found {:,} pairs'.format(len(pairs)))
55 | print('Writing cache to: '+cache_file)
56 | write_cache_file(cache_file, pairs, root=dirname)
57 |
58 | else:
59 | raise NotImplementedError('Unknown dataset: '+dname)
60 |
61 | def dnames_to_image_pairs(dnames, data_dir='./data/'):
62 | """
63 | dnames: list of datasets with image pairs, separated by +
64 | """
65 | all_pairs = []
66 | for dname in dnames.split('+'):
67 | if dname=='habitat_release':
68 | dirname = os.path.join(data_dir, 'habitat_release')
69 | assert os.path.isdir(dirname), "cannot find folder for habitat_release pairs: "+dirname
70 | cache_file = os.path.join(dirname, 'pairs.txt')
71 | assert os.path.isfile(cache_file), "cannot find cache file for habitat_release pairs, please first create the cache file, see instructions. "+cache_file
72 | pairs = load_pairs_from_cache_file(cache_file, root=dirname)
73 | elif dname in ['ARKitScenes', 'MegaDepth', '3DStreetView', 'IndoorVL']:
74 | dirname = os.path.join(data_dir, dname+'_crops')
75 | assert os.path.isdir(dirname), "cannot find folder for {:s} pairs: {:s}".format(dname, dirname)
76 | list_file = os.path.join(dirname, 'listing.txt')
77 | assert os.path.isfile(list_file), "cannot find list file for {:s} pairs, see instructions. {:s}".format(dname, list_file)
78 | pairs = load_pairs_from_list_file(list_file, root=dirname)
79 | print(' {:s}: {:,} pairs'.format(dname, len(pairs)))
80 | all_pairs += pairs
81 | if '+' in dnames: print(' Total: {:,} pairs'.format(len(all_pairs)))
82 | return all_pairs
83 |
84 |
85 | class PairsDataset(Dataset):
86 |
87 | def __init__(self, dnames, trfs='', totensor=True, normalize=True, data_dir='./data/'):
88 | super().__init__()
89 | self.image_pairs = dnames_to_image_pairs(dnames, data_dir=data_dir)
90 | self.transforms = get_pair_transforms(transform_str=trfs, totensor=totensor, normalize=normalize)
91 |
92 | def __len__(self):
93 | return len(self.image_pairs)
94 |
95 | def __getitem__(self, index):
96 | im1path, im2path = self.image_pairs[index]
97 | im1 = load_image(im1path)
98 | im2 = load_image(im2path)
99 | if self.transforms is not None: im1, im2 = self.transforms(im1, im2)
100 | return im1, im2
101 |
102 |
103 | if __name__=="__main__":
104 | import argparse
105 | parser = argparse.ArgumentParser(prog="Computing and caching list of pairs for a given dataset")
106 | parser.add_argument('--data_dir', default='./data/', type=str, help="path where data are stored")
107 | parser.add_argument('--dataset', default='habitat_release', type=str, help="name of the dataset")
108 | args = parser.parse_args()
109 | parse_and_cache_all_pairs(dname=args.dataset, data_dir=args.data_dir)
110 |
--------------------------------------------------------------------------------
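A minimal usage sketch (assumed, and run from the croco root so the `datasets` package is importable): the `habitat_release` cache file has to be written once with `parse_and_cache_all_pairs`, after which `PairsDataset` can be wrapped in a standard DataLoader.

```python
import torch
from datasets.pairs_dataset import PairsDataset, parse_and_cache_all_pairs

# One-off step: writes ./data/habitat_release/pairs.txt (asserts if it already exists).
# parse_and_cache_all_pairs('habitat_release', data_dir='./data/')

dataset = PairsDataset('habitat_release', trfs='crop224+acolor', data_dir='./data/')
loader = torch.utils.data.DataLoader(dataset, batch_size=8, shuffle=True, num_workers=4)

im1, im2 = next(iter(loader))  # two batches of shape (8, 3, 224, 224), one per side of the pair
```
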
/third_party/dust3r/croco/datasets/transforms.py:
--------------------------------------------------------------------------------
1 | # Copyright (C) 2022-present Naver Corporation. All rights reserved.
2 | # Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
3 |
4 | import torch
5 | import torchvision.transforms
6 | import torchvision.transforms.functional as F
7 |
8 | # "Pair": apply a transform on a pair
9 | # "Both": apply the exact same transform to both images
10 |
11 | class ComposePair(torchvision.transforms.Compose):
12 | def __call__(self, img1, img2):
13 | for t in self.transforms:
14 | img1, img2 = t(img1, img2)
15 | return img1, img2
16 |
17 | class NormalizeBoth(torchvision.transforms.Normalize):
18 | def forward(self, img1, img2):
19 | img1 = super().forward(img1)
20 | img2 = super().forward(img2)
21 | return img1, img2
22 |
23 | class ToTensorBoth(torchvision.transforms.ToTensor):
24 | def __call__(self, img1, img2):
25 | img1 = super().__call__(img1)
26 | img2 = super().__call__(img2)
27 | return img1, img2
28 |
29 | class RandomCropPair(torchvision.transforms.RandomCrop):
30 | # the crop will be intentionally different for the two images with this class
31 | def forward(self, img1, img2):
32 | img1 = super().forward(img1)
33 | img2 = super().forward(img2)
34 | return img1, img2
35 |
36 | class ColorJitterPair(torchvision.transforms.ColorJitter):
37 |     # can be symmetric (same for both images) or asymmetric (different jitter params for each image) depending on assymetric_prob
38 | def __init__(self, assymetric_prob, **kwargs):
39 | super().__init__(**kwargs)
40 | self.assymetric_prob = assymetric_prob
41 | def jitter_one(self, img, fn_idx, brightness_factor, contrast_factor, saturation_factor, hue_factor):
42 | for fn_id in fn_idx:
43 | if fn_id == 0 and brightness_factor is not None:
44 | img = F.adjust_brightness(img, brightness_factor)
45 | elif fn_id == 1 and contrast_factor is not None:
46 | img = F.adjust_contrast(img, contrast_factor)
47 | elif fn_id == 2 and saturation_factor is not None:
48 | img = F.adjust_saturation(img, saturation_factor)
49 | elif fn_id == 3 and hue_factor is not None:
50 | img = F.adjust_hue(img, hue_factor)
51 | return img
52 |
53 | def forward(self, img1, img2):
54 |
55 | fn_idx, brightness_factor, contrast_factor, saturation_factor, hue_factor = self.get_params(
56 | self.brightness, self.contrast, self.saturation, self.hue
57 | )
58 | img1 = self.jitter_one(img1, fn_idx, brightness_factor, contrast_factor, saturation_factor, hue_factor)
59 |         if torch.rand(1) < self.assymetric_prob: # asymmetric:
60 | fn_idx, brightness_factor, contrast_factor, saturation_factor, hue_factor = self.get_params(
61 | self.brightness, self.contrast, self.saturation, self.hue
62 | )
63 | img2 = self.jitter_one(img2, fn_idx, brightness_factor, contrast_factor, saturation_factor, hue_factor)
64 | return img1, img2
65 |
66 | def get_pair_transforms(transform_str, totensor=True, normalize=True):
67 | # transform_str is eg crop224+color
68 | trfs = []
69 | for s in transform_str.split('+'):
70 | if s.startswith('crop'):
71 | size = int(s[len('crop'):])
72 | trfs.append(RandomCropPair(size))
73 | elif s=='acolor':
74 | trfs.append(ColorJitterPair(assymetric_prob=1.0, brightness=(0.6, 1.4), contrast=(0.6, 1.4), saturation=(0.6, 1.4), hue=0.0))
75 | elif s=='': # if transform_str was ""
76 | pass
77 | else:
78 | raise NotImplementedError('Unknown augmentation: '+s)
79 |
80 | if totensor:
81 | trfs.append( ToTensorBoth() )
82 | if normalize:
83 | trfs.append( NormalizeBoth(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) )
84 |
85 | if len(trfs)==0:
86 | return None
87 | elif len(trfs)==1:
88 |         return trfs[0]
89 | else:
90 | return ComposePair(trfs)
91 |
92 |
93 |
94 |
95 |
96 |
--------------------------------------------------------------------------------
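A small sketch (assumed usage) of `get_pair_transforms`: the returned object is called with both images of a pair at once and, with the defaults, hands back normalized tensors.

```python
from PIL import Image
from datasets.transforms import get_pair_transforms

pair_trf = get_pair_transforms('crop224+acolor')  # independent 224 crops + asymmetric color jitter
img1 = Image.new('RGB', (320, 240), 'gray')       # stand-ins for a real image pair
img2 = Image.new('RGB', (320, 240), 'gray')

t1, t2 = pair_trf(img1, img2)                     # ToTensorBoth + NormalizeBoth applied to both
print(t1.shape, t2.shape)                         # torch.Size([3, 224, 224]) twice
```
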
/third_party/dust3r/croco/demo.py:
--------------------------------------------------------------------------------
1 | # Copyright (C) 2022-present Naver Corporation. All rights reserved.
2 | # Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
3 |
4 | import torch
5 | from models.croco import CroCoNet
6 | from PIL import Image
7 | import torchvision.transforms
8 | from torchvision.transforms import ToTensor, Normalize, Compose
9 |
10 | def main():
11 | device = torch.device('cuda:0' if torch.cuda.is_available() and torch.cuda.device_count()>0 else 'cpu')
12 |
13 | # load 224x224 images and transform them to tensor
14 | imagenet_mean = [0.485, 0.456, 0.406]
15 | imagenet_mean_tensor = torch.tensor(imagenet_mean).view(1,3,1,1).to(device, non_blocking=True)
16 | imagenet_std = [0.229, 0.224, 0.225]
17 | imagenet_std_tensor = torch.tensor(imagenet_std).view(1,3,1,1).to(device, non_blocking=True)
18 | trfs = Compose([ToTensor(), Normalize(mean=imagenet_mean, std=imagenet_std)])
19 | image1 = trfs(Image.open('assets/Chateau1.png').convert('RGB')).to(device, non_blocking=True).unsqueeze(0)
20 | image2 = trfs(Image.open('assets/Chateau2.png').convert('RGB')).to(device, non_blocking=True).unsqueeze(0)
21 |
22 | # load model
23 | ckpt = torch.load('pretrained_models/CroCo_V2_ViTLarge_BaseDecoder.pth', 'cpu')
24 | model = CroCoNet( **ckpt.get('croco_kwargs',{})).to(device)
25 | model.eval()
26 | msg = model.load_state_dict(ckpt['model'], strict=True)
27 |
28 | # forward
29 | with torch.inference_mode():
30 | out, mask, target = model(image1, image2)
31 |
32 | # the output is normalized, thus use the mean/std of the actual image to go back to RGB space
33 | patchified = model.patchify(image1)
34 | mean = patchified.mean(dim=-1, keepdim=True)
35 | var = patchified.var(dim=-1, keepdim=True)
36 | decoded_image = model.unpatchify(out * (var + 1.e-6)**.5 + mean)
37 | # undo imagenet normalization, prepare masked image
38 | decoded_image = decoded_image * imagenet_std_tensor + imagenet_mean_tensor
39 | input_image = image1 * imagenet_std_tensor + imagenet_mean_tensor
40 | ref_image = image2 * imagenet_std_tensor + imagenet_mean_tensor
41 | image_masks = model.unpatchify(model.patchify(torch.ones_like(ref_image)) * mask[:,:,None])
42 | masked_input_image = ((1 - image_masks) * input_image)
43 |
44 | # make visualization
45 | visualization = torch.cat((ref_image, masked_input_image, decoded_image, input_image), dim=3) # 4*(B, 3, H, W) -> B, 3, H, W*4
46 | B, C, H, W = visualization.shape
47 | visualization = visualization.permute(1, 0, 2, 3).reshape(C, B*H, W)
48 | visualization = torchvision.transforms.functional.to_pil_image(torch.clamp(visualization, 0, 1))
49 | fname = "demo_output.png"
50 | visualization.save(fname)
51 |     print('Visualization saved in '+fname)
52 |
53 |
54 | if __name__=="__main__":
55 | main()
56 |
--------------------------------------------------------------------------------
/third_party/dust3r/croco/models/criterion.py:
--------------------------------------------------------------------------------
1 | # Copyright (C) 2022-present Naver Corporation. All rights reserved.
2 | # Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
3 | #
4 | # --------------------------------------------------------
5 | # Criterion to train CroCo
6 | # --------------------------------------------------------
7 | # References:
8 | # MAE: https://github.com/facebookresearch/mae
9 | # --------------------------------------------------------
10 |
11 | import torch
12 |
13 | class MaskedMSE(torch.nn.Module):
14 |
15 | def __init__(self, norm_pix_loss=False, masked=True):
16 | """
17 | norm_pix_loss: normalize each patch by their pixel mean and variance
18 | masked: compute loss over the masked patches only
19 | """
20 | super().__init__()
21 | self.norm_pix_loss = norm_pix_loss
22 | self.masked = masked
23 |
24 | def forward(self, pred, mask, target):
25 |
26 | if self.norm_pix_loss:
27 | mean = target.mean(dim=-1, keepdim=True)
28 | var = target.var(dim=-1, keepdim=True)
29 | target = (target - mean) / (var + 1.e-6)**.5
30 |
31 | loss = (pred - target) ** 2
32 | loss = loss.mean(dim=-1) # [N, L], mean loss per patch
33 | if self.masked:
34 | loss = (loss * mask).sum() / mask.sum() # mean loss on masked patches
35 | else:
36 | loss = loss.mean() # mean loss
37 | return loss
38 |
--------------------------------------------------------------------------------
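A short sketch of the expected tensor shapes (assumed here from the patchified-output convention used in `demo.py` and this criterion): predictions and targets are per-patch pixel vectors, and `mask` marks the patches the loss is averaged over.

```python
import torch
from models.criterion import MaskedMSE   # assumes the croco root is on PYTHONPATH

criterion = MaskedMSE(norm_pix_loss=True, masked=True)

N, L, D = 2, 196, 768                      # batch, 14x14 patches of a 224 image, 16*16*3 pixels per patch
pred   = torch.randn(N, L, D)
target = torch.randn(N, L, D)
mask   = (torch.rand(N, L) < 0.9).float()  # 1 where a patch was masked out of the input

loss = criterion(pred, mask, target)       # scalar MSE averaged over the masked patches only
print(loss.item())
```
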
/third_party/dust3r/croco/models/croco_downstream.py:
--------------------------------------------------------------------------------
1 | # Copyright (C) 2022-present Naver Corporation. All rights reserved.
2 | # Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
3 |
4 | # --------------------------------------------------------
5 | # CroCo model for downstream tasks
6 | # --------------------------------------------------------
7 |
8 | import torch
9 |
10 | from .croco import CroCoNet
11 |
12 |
13 | def croco_args_from_ckpt(ckpt):
14 | if 'croco_kwargs' in ckpt: # CroCo v2 released models
15 | return ckpt['croco_kwargs']
16 | elif 'args' in ckpt and hasattr(ckpt['args'], 'model'): # pretrained using the official code release
17 | s = ckpt['args'].model # eg "CroCoNet(enc_embed_dim=1024, enc_num_heads=16, enc_depth=24)"
18 | assert s.startswith('CroCoNet(')
19 | return eval('dict'+s[len('CroCoNet'):]) # transform it into the string of a dictionary and evaluate it
20 | else: # CroCo v1 released models
21 | return dict()
22 |
23 | class CroCoDownstreamMonocularEncoder(CroCoNet):
24 |
25 | def __init__(self,
26 | head,
27 | **kwargs):
28 | """ Build network for monocular downstream task, only using the encoder.
29 |         It takes an extra argument, head, that is called with the features
30 |         and a dictionary img_info containing 'width' and 'height' keys.
31 |         The head is set up with the CroCoNet arguments in this init function.
32 |         NOTE: It works by calling super().__init__() but with redefined setters
33 |
34 | """
35 | super(CroCoDownstreamMonocularEncoder, self).__init__(**kwargs)
36 | head.setup(self)
37 | self.head = head
38 |
39 | def _set_mask_generator(self, *args, **kwargs):
40 | """ No mask generator """
41 | return
42 |
43 | def _set_mask_token(self, *args, **kwargs):
44 | """ No mask token """
45 | self.mask_token = None
46 | return
47 |
48 | def _set_decoder(self, *args, **kwargs):
49 | """ No decoder """
50 | return
51 |
52 | def _set_prediction_head(self, *args, **kwargs):
53 | """ No 'prediction head' for downstream tasks."""
54 | return
55 |
56 | def forward(self, img):
57 | """
58 | img is of size batch_size x 3 x h x w
59 | """
60 | B, C, H, W = img.size()
61 | img_info = {'height': H, 'width': W}
62 | need_all_layers = hasattr(self.head, 'return_all_blocks') and self.head.return_all_blocks
63 | out, _, _ = self._encode_image(img, do_mask=False, return_all_blocks=need_all_layers)
64 | return self.head(out, img_info)
65 |
66 |
67 | class CroCoDownstreamBinocular(CroCoNet):
68 |
69 | def __init__(self,
70 | head,
71 | **kwargs):
72 | """ Build network for binocular downstream task
73 | It takes an extra argument head, which is called with the features
74 | and a dictionary img_info containing 'width' and 'height' keys.
75 | The head is set up with the croconet arguments in this init function.
76 | """
77 | super(CroCoDownstreamBinocular, self).__init__(**kwargs)
78 | head.setup(self)
79 | self.head = head
80 |
81 | def _set_mask_generator(self, *args, **kwargs):
82 | """ No mask generator """
83 | return
84 |
85 | def _set_mask_token(self, *args, **kwargs):
86 | """ No mask token """
87 | self.mask_token = None
88 | return
89 |
90 | def _set_prediction_head(self, *args, **kwargs):
91 | """ No prediction head for downstream tasks, define your own head """
92 | return
93 |
94 | def encode_image_pairs(self, img1, img2, return_all_blocks=False):
95 | """ run encoder for a pair of images
96 | it is actually ~5% faster to concatenate the images along the batch dimension
97 | than to encode them separately
98 | """
99 | ## the two commented lines below are the naive version with separate encoding
100 | #out, pos, _ = self._encode_image(img1, do_mask=False, return_all_blocks=return_all_blocks)
101 | #out2, pos2, _ = self._encode_image(img2, do_mask=False, return_all_blocks=False)
102 | ## and now the faster version
103 | out, pos, _ = self._encode_image( torch.cat( (img1,img2), dim=0), do_mask=False, return_all_blocks=return_all_blocks )
104 | if return_all_blocks:
105 | out,out2 = list(map(list, zip(*[o.chunk(2, dim=0) for o in out])))
106 | out2 = out2[-1]
107 | else:
108 | out,out2 = out.chunk(2, dim=0)
109 | pos,pos2 = pos.chunk(2, dim=0)
110 | return out, out2, pos, pos2
111 |
112 | def forward(self, img1, img2):
113 | B, C, H, W = img1.size()
114 | img_info = {'height': H, 'width': W}
115 | return_all_blocks = hasattr(self.head, 'return_all_blocks') and self.head.return_all_blocks
116 | out, out2, pos, pos2 = self.encode_image_pairs(img1, img2, return_all_blocks=return_all_blocks)
117 | if return_all_blocks:
118 | decout = self._decoder(out[-1], pos, None, out2, pos2, return_all_blocks=return_all_blocks)
119 | decout = out+decout
120 | else:
121 | decout = self._decoder(out, pos, None, out2, pos2, return_all_blocks=return_all_blocks)
122 | return self.head(decout, img_info)
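
The concatenate-then-chunk trick in encode_image_pairs is easy to verify with a stand-in encoder: any module that treats samples independently gives the same per-image features whether the two views are encoded separately or stacked along the batch dimension. A minimal sketch (a plain Conv2d stands in for the CroCo encoder here; BatchNorm-style layers would break this equivalence):

    import torch

    enc = torch.nn.Conv2d(3, 8, kernel_size=3, padding=1)  # stand-in for _encode_image

    img1 = torch.randn(4, 3, 32, 32)
    img2 = torch.randn(4, 3, 32, 32)

    sep1, sep2 = enc(img1), enc(img2)            # naive version: two forward passes
    out = enc(torch.cat((img1, img2), dim=0))    # batched version: one forward pass
    bat1, bat2 = out.chunk(2, dim=0)             # split the result back per view

    print(torch.allclose(sep1, bat1), torch.allclose(sep2, bat2))  # True True
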
--------------------------------------------------------------------------------
/third_party/dust3r/croco/models/curope/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (C) 2022-present Naver Corporation. All rights reserved.
2 | # Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
3 |
4 | from .curope2d import cuRoPE2D
5 |
--------------------------------------------------------------------------------
/third_party/dust3r/croco/models/curope/build/temp.linux-x86_64-cpython-38/build.ninja:
--------------------------------------------------------------------------------
1 | ninja_required_version = 1.3
2 | cxx = g++
3 | nvcc = /usr/local/cuda/bin/nvcc
4 |
5 | cflags = -pthread -B /root/anaconda3/envs/3studio/compiler_compat -Wl,--sysroot=/ -Wsign-compare -DNDEBUG -g -fwrapv -O3 -Wall -Wstrict-prototypes -fPIC -I/root/anaconda3/envs/3studio/lib/python3.8/site-packages/torch/include -I/root/anaconda3/envs/3studio/lib/python3.8/site-packages/torch/include/torch/csrc/api/include -I/root/anaconda3/envs/3studio/lib/python3.8/site-packages/torch/include/TH -I/root/anaconda3/envs/3studio/lib/python3.8/site-packages/torch/include/THC -I/usr/local/cuda/include -I/root/anaconda3/envs/3studio/include/python3.8 -c
6 | post_cflags = -O3 -DTORCH_API_INCLUDE_EXTENSION_H '-DPYBIND11_COMPILER_TYPE="_gcc"' '-DPYBIND11_STDLIB="_libstdcpp"' '-DPYBIND11_BUILD_ABI="_cxxabi1011"' -DTORCH_EXTENSION_NAME=curope -D_GLIBCXX_USE_CXX11_ABI=0 -std=c++17
7 | cuda_cflags = -I/root/anaconda3/envs/3studio/lib/python3.8/site-packages/torch/include -I/root/anaconda3/envs/3studio/lib/python3.8/site-packages/torch/include/torch/csrc/api/include -I/root/anaconda3/envs/3studio/lib/python3.8/site-packages/torch/include/TH -I/root/anaconda3/envs/3studio/lib/python3.8/site-packages/torch/include/THC -I/usr/local/cuda/include -I/root/anaconda3/envs/3studio/include/python3.8 -c
8 | cuda_post_cflags = -D__CUDA_NO_HALF_OPERATORS__ -D__CUDA_NO_HALF_CONVERSIONS__ -D__CUDA_NO_BFLOAT16_CONVERSIONS__ -D__CUDA_NO_HALF2_OPERATORS__ --expt-relaxed-constexpr --compiler-options ''"'"'-fPIC'"'"'' -O3 --ptxas-options=-v --use_fast_math -gencode arch=compute_37,code=sm_37 -gencode arch=compute_50,code=sm_50 -gencode arch=compute_60,code=sm_60 -gencode arch=compute_70,code=sm_70 -gencode arch=compute_75,code=sm_75 -gencode arch=compute_80,code=sm_80 -gencode arch=compute_86,code=sm_86 -DTORCH_API_INCLUDE_EXTENSION_H '-DPYBIND11_COMPILER_TYPE="_gcc"' '-DPYBIND11_STDLIB="_libstdcpp"' '-DPYBIND11_BUILD_ABI="_cxxabi1011"' -DTORCH_EXTENSION_NAME=curope -D_GLIBCXX_USE_CXX11_ABI=0 -std=c++17
9 | cuda_dlink_post_cflags =
10 | ldflags =
11 |
12 | rule compile
13 | command = $cxx -MMD -MF $out.d $cflags -c $in -o $out $post_cflags
14 | depfile = $out.d
15 | deps = gcc
16 |
17 | rule cuda_compile
18 | depfile = $out.d
19 | deps = gcc
20 | command = $nvcc $cuda_cflags -c $in -o $out $cuda_post_cflags
21 |
22 |
23 |
24 |
25 |
26 | build /apdcephfs_cq10/share_1290939/karmyu/dust3r-gaussian-splatting/dust3r/croco/models/curope/build/temp.linux-x86_64-cpython-38/curope.o: compile /apdcephfs_cq10/share_1290939/karmyu/dust3r-gaussian-splatting/dust3r/croco/models/curope/curope.cpp
27 | build /apdcephfs_cq10/share_1290939/karmyu/dust3r-gaussian-splatting/dust3r/croco/models/curope/build/temp.linux-x86_64-cpython-38/kernels.o: cuda_compile /apdcephfs_cq10/share_1290939/karmyu/dust3r-gaussian-splatting/dust3r/croco/models/curope/kernels.cu
28 |
29 |
30 |
31 |
32 |
33 |
34 |
35 |
--------------------------------------------------------------------------------
/third_party/dust3r/croco/models/curope/curope.cpp:
--------------------------------------------------------------------------------
1 | /*
2 | Copyright (C) 2022-present Naver Corporation. All rights reserved.
3 | Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
4 | */
5 |
6 | #include <torch/extension.h>
7 |
8 | // forward declaration
9 | void rope_2d_cuda( torch::Tensor tokens, const torch::Tensor pos, const float base, const float fwd );
10 |
11 | void rope_2d_cpu( torch::Tensor tokens, const torch::Tensor positions, const float base, const float fwd )
12 | {
13 | const int B = tokens.size(0);
14 | const int N = tokens.size(1);
15 | const int H = tokens.size(2);
16 | const int D = tokens.size(3) / 4;
17 |
18 | auto tok = tokens.accessor<float, 4>();
19 | auto pos = positions.accessor<int64_t, 3>();
20 |
21 | for (int b = 0; b < B; b++) {
22 | for (int x = 0; x < 2; x++) { // y and then x (2d)
23 | for (int n = 0; n < N; n++) {
24 |
25 | // grab the token position
26 | const int p = pos[b][n][x];
27 |
28 | for (int h = 0; h < H; h++) {
29 | for (int d = 0; d < D; d++) {
30 | // grab the two values
31 | float u = tok[b][n][h][d+0+x*2*D];
32 | float v = tok[b][n][h][d+D+x*2*D];
33 |
34 | // grab the cos,sin
35 | const float inv_freq = fwd * p / powf(base, d/float(D));
36 | float c = cosf(inv_freq);
37 | float s = sinf(inv_freq);
38 |
39 | // write the result
40 | tok[b][n][h][d+0+x*2*D] = u*c - v*s;
41 | tok[b][n][h][d+D+x*2*D] = v*c + u*s;
42 | }
43 | }
44 | }
45 | }
46 | }
47 | }
48 |
49 | void rope_2d( torch::Tensor tokens, // B,N,H,D
50 | const torch::Tensor positions, // B,N,2
51 | const float base,
52 | const float fwd )
53 | {
54 | TORCH_CHECK(tokens.dim() == 4, "tokens must have 4 dimensions");
55 | TORCH_CHECK(positions.dim() == 3, "positions must have 3 dimensions");
56 | TORCH_CHECK(tokens.size(0) == positions.size(0), "batch size differs between tokens & positions");
57 | TORCH_CHECK(tokens.size(1) == positions.size(1), "seq_length differs between tokens & positions");
58 | TORCH_CHECK(positions.size(2) == 2, "positions.shape[2] must be equal to 2");
59 | TORCH_CHECK(tokens.is_cuda() == positions.is_cuda(), "tokens and positions are not on the same device" );
60 |
61 | if (tokens.is_cuda())
62 | rope_2d_cuda( tokens, positions, base, fwd );
63 | else
64 | rope_2d_cpu( tokens, positions, base, fwd );
65 | }
66 |
67 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
68 | m.def("rope_2d", &rope_2d, "RoPE 2d forward/backward");
69 | }
70 |
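
The inner loop of rope_2d_cpu applies a plain 2D rotation to each (u, v) feature pair, with an angle that depends on the token position p and the pair index d. A small NumPy sketch of that rotation (not the extension itself) also shows why the backward pass in curope2d.py can reuse the same kernel with fwd = -F0: rotating by -theta undoes the forward rotation.

    import numpy as np

    def rope_pairs(u, v, p, base=100.0, fwd=1.0):
        # rotate each pair (u[d], v[d]) by theta_d = fwd * p / base**(d/D), as in rope_2d_cpu
        D = u.shape[-1]
        theta = fwd * p / base ** (np.arange(D) / D)
        return u * np.cos(theta) - v * np.sin(theta), v * np.cos(theta) + u * np.sin(theta)

    u, v = np.random.randn(8), np.random.randn(8)
    u_rot, v_rot = rope_pairs(u, v, p=3)
    u_back, v_back = rope_pairs(u_rot, v_rot, p=3, fwd=-1.0)  # the "backward" direction
    print(np.allclose(u, u_back), np.allclose(v, v_back))     # True True
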
--------------------------------------------------------------------------------
/third_party/dust3r/croco/models/curope/curope2d.py:
--------------------------------------------------------------------------------
1 | # Copyright (C) 2022-present Naver Corporation. All rights reserved.
2 | # Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
3 |
4 | import torch
5 |
6 | try:
7 | import curope as _kernels # run `python setup.py install`
8 | except ModuleNotFoundError:
9 | from . import curope as _kernels # run `python setup.py build_ext --inplace`
10 |
11 |
12 | class cuRoPE2D_func (torch.autograd.Function):
13 |
14 | @staticmethod
15 | def forward(ctx, tokens, positions, base, F0=1):
16 | ctx.save_for_backward(positions)
17 | ctx.saved_base = base
18 | ctx.saved_F0 = F0
19 | # tokens = tokens.clone() # uncomment this if inplace doesn't work
20 | _kernels.rope_2d( tokens, positions, base, F0 )
21 | ctx.mark_dirty(tokens)
22 | return tokens
23 |
24 | @staticmethod
25 | def backward(ctx, grad_res):
26 | positions, base, F0 = ctx.saved_tensors[0], ctx.saved_base, ctx.saved_F0
27 | _kernels.rope_2d( grad_res, positions, base, -F0 )
28 | ctx.mark_dirty(grad_res)
29 | return grad_res, None, None, None
30 |
31 |
32 | class cuRoPE2D(torch.nn.Module):
33 | def __init__(self, freq=100.0, F0=1.0):
34 | super().__init__()
35 | self.base = freq
36 | self.F0 = F0
37 |
38 | def forward(self, tokens, positions):
39 | cuRoPE2D_func.apply( tokens.transpose(1,2), positions, self.base, self.F0 )
40 | return tokens
--------------------------------------------------------------------------------
/third_party/dust3r/croco/models/curope/kernels.cu:
--------------------------------------------------------------------------------
1 | /*
2 | Copyright (C) 2022-present Naver Corporation. All rights reserved.
3 | Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
4 | */
5 |
6 | #include <torch/extension.h>
7 | #include <cuda.h>
8 | #include <cuda_runtime.h>
9 | #include <vector>
10 |
11 | #define CHECK_CUDA(tensor) {\
12 | TORCH_CHECK((tensor).is_cuda(), #tensor " is not in cuda memory"); \
13 | TORCH_CHECK((tensor).is_contiguous(), #tensor " is not contiguous"); }
14 | void CHECK_KERNEL() {auto error = cudaGetLastError(); TORCH_CHECK( error == cudaSuccess, cudaGetErrorString(error));}
15 |
16 |
17 | template < typename scalar_t >
18 | __global__ void rope_2d_cuda_kernel(
19 | //scalar_t* __restrict__ tokens,
20 | torch::PackedTensorAccessor32<scalar_t,4,torch::RestrictPtrTraits> tokens,
21 | const int64_t* __restrict__ pos,
22 | const float base,
23 | const float fwd )
24 | // const int N, const int H, const int D )
25 | {
26 | // tokens shape = (B, N, H, D)
27 | const int N = tokens.size(1);
28 | const int H = tokens.size(2);
29 | const int D = tokens.size(3);
30 |
31 | // each block updates a single token, for all heads
32 | // each thread takes care of a single output
33 | extern __shared__ float shared[];
34 | float* shared_inv_freq = shared + D;
35 |
36 | const int b = blockIdx.x / N;
37 | const int n = blockIdx.x % N;
38 |
39 | const int Q = D / 4;
40 | // one token = [0..Q : Q..2Q : 2Q..3Q : 3Q..D]
41 | // u_Y v_Y u_X v_X
42 |
43 | // shared memory: first, compute inv_freq
44 | if (threadIdx.x < Q)
45 | shared_inv_freq[threadIdx.x] = fwd / powf(base, threadIdx.x/float(Q));
46 | __syncthreads();
47 |
48 | // start of X or Y part
49 | const int X = threadIdx.x < D/2 ? 0 : 1;
50 | const int m = (X*D/2) + (threadIdx.x % Q); // index of u_Y or u_X
51 |
52 | // grab the cos,sin appropriate for me
53 | const float freq = pos[blockIdx.x*2+X] * shared_inv_freq[threadIdx.x % Q];
54 | const float cos = cosf(freq);
55 | const float sin = sinf(freq);
56 | /*
57 | float* shared_cos_sin = shared + D + D/4;
58 | if ((threadIdx.x % (D/2)) < Q)
59 | shared_cos_sin[m+0] = cosf(freq);
60 | else
61 | shared_cos_sin[m+Q] = sinf(freq);
62 | __syncthreads();
63 | const float cos = shared_cos_sin[m+0];
64 | const float sin = shared_cos_sin[m+Q];
65 | */
66 |
67 | for (int h = 0; h < H; h++)
68 | {
69 | // then, load all the token for this head in shared memory
70 | shared[threadIdx.x] = tokens[b][n][h][threadIdx.x];
71 | __syncthreads();
72 |
73 | const float u = shared[m];
74 | const float v = shared[m+Q];
75 |
76 | // write output
77 | if ((threadIdx.x % (D/2)) < Q)
78 | tokens[b][n][h][threadIdx.x] = u*cos - v*sin;
79 | else
80 | tokens[b][n][h][threadIdx.x] = v*cos + u*sin;
81 | }
82 | }
83 |
84 | void rope_2d_cuda( torch::Tensor tokens, const torch::Tensor pos, const float base, const float fwd )
85 | {
86 | const int B = tokens.size(0); // batch size
87 | const int N = tokens.size(1); // sequence length
88 | const int H = tokens.size(2); // number of heads
89 | const int D = tokens.size(3); // dimension per head
90 |
91 | TORCH_CHECK(tokens.stride(3) == 1 && tokens.stride(2) == D, "tokens are not contiguous");
92 | TORCH_CHECK(pos.is_contiguous(), "positions are not contiguous");
93 | TORCH_CHECK(pos.size(0) == B && pos.size(1) == N && pos.size(2) == 2, "bad pos.shape");
94 | TORCH_CHECK(D % 4 == 0, "token dim must be multiple of 4");
95 |
96 | // one block per token (B*N blocks), one thread per channel (D threads)
97 | const int THREADS_PER_BLOCK = D;
98 | const int N_BLOCKS = B * N; // each block takes care of H*D values
99 | const int SHARED_MEM = sizeof(float) * (D + D/4);
100 |
101 | AT_DISPATCH_FLOATING_TYPES_AND_HALF(tokens.type(), "rope_2d_cuda", ([&] {
102 | rope_2d_cuda_kernel<scalar_t> <<<N_BLOCKS, THREADS_PER_BLOCK, SHARED_MEM>>> (
103 | //tokens.data_ptr<scalar_t>(),
104 | tokens.packed_accessor32<scalar_t,4,torch::RestrictPtrTraits>(),
105 | pos.data_ptr<int64_t>(),
106 | base, fwd); //, N, H, D );
107 | }));
108 | }
109 |
--------------------------------------------------------------------------------
/third_party/dust3r/croco/models/curope/setup.py:
--------------------------------------------------------------------------------
1 | # Copyright (C) 2022-present Naver Corporation. All rights reserved.
2 | # Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
3 |
4 | from setuptools import setup
5 | from torch import cuda
6 | from torch.utils.cpp_extension import BuildExtension, CUDAExtension
7 |
8 | # compile for all possible CUDA architectures
9 | all_cuda_archs = cuda.get_gencode_flags().replace('compute=','arch=').split()
10 | # alternatively, you can list cuda archs that you want, eg:
11 | # all_cuda_archs = [
12 | # '-gencode', 'arch=compute_70,code=sm_70',
13 | # '-gencode', 'arch=compute_75,code=sm_75',
14 | # '-gencode', 'arch=compute_80,code=sm_80',
15 | # '-gencode', 'arch=compute_86,code=sm_86'
16 | # ]
17 |
18 | setup(
19 | name = 'curope',
20 | ext_modules = [
21 | CUDAExtension(
22 | name='curope',
23 | sources=[
24 | "curope.cpp",
25 | "kernels.cu",
26 | ],
27 | extra_compile_args = dict(
28 | nvcc=['-O3','--ptxas-options=-v',"--use_fast_math"]+all_cuda_archs,
29 | cxx=['-O3'])
30 | )
31 | ],
32 | cmdclass = {
33 | 'build_ext': BuildExtension
34 | })
35 |
--------------------------------------------------------------------------------
/third_party/dust3r/croco/models/head_downstream.py:
--------------------------------------------------------------------------------
1 | # Copyright (C) 2022-present Naver Corporation. All rights reserved.
2 | # Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
3 |
4 | # --------------------------------------------------------
5 | # Heads for downstream tasks
6 | # --------------------------------------------------------
7 |
8 | """
9 | A head is a module where the __init__ defines only the head hyperparameters.
10 | A method setup(croconet) takes a CroCoNet and sets all layers according to the head and croconet attributes.
11 | The forward takes the features as well as a dictionary img_info containing the keys 'width' and 'height'
12 | """
13 |
14 | import torch
15 | import torch.nn as nn
16 | from .dpt_block import DPTOutputAdapter
17 |
18 |
19 | class PixelwiseTaskWithDPT(nn.Module):
20 | """ DPT module for CroCo.
21 | by default, hooks_idx will be equal to:
22 | * for encoder-only: 4 equally spread layers
23 | * for encoder+decoder: last encoder + 3 equally spread layers of the decoder
24 | """
25 |
26 | def __init__(self, *, hooks_idx=None, layer_dims=[96,192,384,768],
27 | output_width_ratio=1, num_channels=1, postprocess=None, **kwargs):
28 | super(PixelwiseTaskWithDPT, self).__init__()
29 | self.return_all_blocks = True # backbone needs to return all layers
30 | self.postprocess = postprocess
31 | self.output_width_ratio = output_width_ratio
32 | self.num_channels = num_channels
33 | self.hooks_idx = hooks_idx
34 | self.layer_dims = layer_dims
35 |
36 | def setup(self, croconet):
37 | dpt_args = {'output_width_ratio': self.output_width_ratio, 'num_channels': self.num_channels}
38 | if self.hooks_idx is None:
39 | if hasattr(croconet, 'dec_blocks'): # encoder + decoder
40 | step = {8: 3, 12: 4, 24: 8}[croconet.dec_depth]
41 | hooks_idx = [croconet.dec_depth+croconet.enc_depth-1-i*step for i in range(3,-1,-1)]
42 | else: # encoder only
43 | step = croconet.enc_depth//4
44 | hooks_idx = [croconet.enc_depth-1-i*step for i in range(3,-1,-1)]
45 | self.hooks_idx = hooks_idx
46 | print(f' PixelwiseTaskWithDPT: automatically setting hook_idxs={self.hooks_idx}')
47 | dpt_args['hooks'] = self.hooks_idx
48 | dpt_args['layer_dims'] = self.layer_dims
49 | self.dpt = DPTOutputAdapter(**dpt_args)
50 | dim_tokens = [croconet.enc_embed_dim if hook<croconet.enc_depth else croconet.dec_embed_dim for hook in self.hooks_idx]
--------------------------------------------------------------------------------
/third_party/dust3r/croco/models/pos_embed.py:
--------------------------------------------------------------------------------
35 | if n_cls_token>0:
36 | pos_embed = np.concatenate([np.zeros([n_cls_token, embed_dim]), pos_embed], axis=0)
37 | return pos_embed
38 |
39 |
40 | def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
41 | assert embed_dim % 2 == 0
42 |
43 | # use half of dimensions to encode grid_h
44 | emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) # (H*W, D/2)
45 | emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) # (H*W, D/2)
46 |
47 | emb = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D)
48 | return emb
49 |
50 |
51 | def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
52 | """
53 | embed_dim: output dimension for each position
54 | pos: a list of positions to be encoded: size (M,)
55 | out: (M, D)
56 | """
57 | assert embed_dim % 2 == 0
58 | omega = np.arange(embed_dim // 2, dtype=float)
59 | omega /= embed_dim / 2.
60 | omega = 1. / 10000**omega # (D/2,)
61 |
62 | pos = pos.reshape(-1) # (M,)
63 | out = np.einsum('m,d->md', pos, omega) # (M, D/2), outer product
64 |
65 | emb_sin = np.sin(out) # (M, D/2)
66 | emb_cos = np.cos(out) # (M, D/2)
67 |
68 | emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D)
69 | return emb
70 |
71 |
72 | # --------------------------------------------------------
73 | # Interpolate position embeddings for high-resolution
74 | # References:
75 | # MAE: https://github.com/facebookresearch/mae/blob/main/util/pos_embed.py
76 | # DeiT: https://github.com/facebookresearch/deit
77 | # --------------------------------------------------------
78 | def interpolate_pos_embed(model, checkpoint_model):
79 | if 'pos_embed' in checkpoint_model:
80 | pos_embed_checkpoint = checkpoint_model['pos_embed']
81 | embedding_size = pos_embed_checkpoint.shape[-1]
82 | num_patches = model.patch_embed.num_patches
83 | num_extra_tokens = model.pos_embed.shape[-2] - num_patches
84 | # height (== width) for the checkpoint position embedding
85 | orig_size = int((pos_embed_checkpoint.shape[-2] - num_extra_tokens) ** 0.5)
86 | # height (== width) for the new position embedding
87 | new_size = int(num_patches ** 0.5)
88 | # class_token and dist_token are kept unchanged
89 | if orig_size != new_size:
90 | print("Position interpolate from %dx%d to %dx%d" % (orig_size, orig_size, new_size, new_size))
91 | extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens]
92 | # only the position tokens are interpolated
93 | pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:]
94 | pos_tokens = pos_tokens.reshape(-1, orig_size, orig_size, embedding_size).permute(0, 3, 1, 2)
95 | pos_tokens = torch.nn.functional.interpolate(
96 | pos_tokens, size=(new_size, new_size), mode='bicubic', align_corners=False)
97 | pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2)
98 | new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1)
99 | checkpoint_model['pos_embed'] = new_pos_embed
100 |
101 |
102 | #----------------------------------------------------------
103 | # RoPE2D: RoPE implementation in 2D
104 | #----------------------------------------------------------
105 |
106 | try:
107 | from models.curope import cuRoPE2D
108 | RoPE2D = cuRoPE2D
109 | except ImportError:
110 | print('Warning, cannot find cuda-compiled version of RoPE2D, using a slow pytorch version instead')
111 |
112 | class RoPE2D(torch.nn.Module):
113 |
114 | def __init__(self, freq=100.0, F0=1.0):
115 | super().__init__()
116 | self.base = freq
117 | self.F0 = F0
118 | self.cache = {}
119 |
120 | def get_cos_sin(self, D, seq_len, device, dtype):
121 | if (D,seq_len,device,dtype) not in self.cache:
122 | inv_freq = 1.0 / (self.base ** (torch.arange(0, D, 2).float().to(device) / D))
123 | t = torch.arange(seq_len, device=device, dtype=inv_freq.dtype)
124 | freqs = torch.einsum("i,j->ij", t, inv_freq).to(dtype)
125 | freqs = torch.cat((freqs, freqs), dim=-1)
126 | cos = freqs.cos() # (Seq, Dim)
127 | sin = freqs.sin()
128 | self.cache[D,seq_len,device,dtype] = (cos,sin)
129 | return self.cache[D,seq_len,device,dtype]
130 |
131 | @staticmethod
132 | def rotate_half(x):
133 | x1, x2 = x[..., : x.shape[-1] // 2], x[..., x.shape[-1] // 2 :]
134 | return torch.cat((-x2, x1), dim=-1)
135 |
136 | def apply_rope1d(self, tokens, pos1d, cos, sin):
137 | assert pos1d.ndim==2
138 | cos = torch.nn.functional.embedding(pos1d, cos)[:, None, :, :]
139 | sin = torch.nn.functional.embedding(pos1d, sin)[:, None, :, :]
140 | return (tokens * cos) + (self.rotate_half(tokens) * sin)
141 |
142 | def forward(self, tokens, positions):
143 | """
144 | input:
145 | * tokens: batch_size x nheads x ntokens x dim
146 | * positions: batch_size x ntokens x 2 (y and x position of each token)
147 | output:
148 | * tokens after applying RoPE2D (batch_size x nheads x ntokens x dim)
149 | """
150 | assert tokens.size(3)%2==0, "number of dimensions should be a multiple of two"
151 | D = tokens.size(3) // 2
152 | assert positions.ndim==3 and positions.shape[-1] == 2 # Batch, Seq, 2
153 | cos, sin = self.get_cos_sin(D, int(positions.max())+1, tokens.device, tokens.dtype)
154 | # split features into two along the feature dimension, and apply rope1d on each half
155 | y, x = tokens.chunk(2, dim=-1)
156 | y = self.apply_rope1d(y, positions[:,:,0], cos, sin)
157 | x = self.apply_rope1d(x, positions[:,:,1], cos, sin)
158 | tokens = torch.cat((y, x), dim=-1)
159 | return tokens
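
A minimal usage sketch of the pure-PyTorch RoPE2D fallback defined above, assuming that fallback is the class in scope (i.e. the CUDA extension is not installed) and that the croco root is on sys.path: tokens are laid out as batch x heads x tokens x dim, and positions carry the (y, x) coordinate of each patch token.

    import torch
    from models.pos_embed import RoPE2D  # assumes the croco root is on sys.path

    rope = RoPE2D(freq=100.0)

    B, nheads, dim = 2, 4, 32
    H = W = 3                                   # a 3x3 grid of patch tokens
    ys, xs = torch.meshgrid(torch.arange(H), torch.arange(W), indexing='ij')
    positions = torch.stack((ys, xs), dim=-1).reshape(1, H * W, 2).expand(B, -1, -1)  # (B, ntokens, 2)

    tokens = torch.randn(B, nheads, H * W, dim)
    out = rope(tokens, positions)
    print(out.shape)                            # torch.Size([2, 4, 9, 32])
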
--------------------------------------------------------------------------------
/third_party/dust3r/croco/stereoflow/download_model.sh:
--------------------------------------------------------------------------------
1 | # Copyright (C) 2022-present Naver Corporation. All rights reserved.
2 | # Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
3 |
4 | model=$1
5 | outfile="stereoflow_models/${model}"
6 | if [[ ! -f $outfile ]]
7 | then
8 | mkdir -p stereoflow_models/;
9 | wget https://download.europe.naverlabs.com/ComputerVision/CroCo/StereoFlow_models/$1 -P stereoflow_models/;
10 | else
11 | echo "Model ${model} already downloaded in ${outfile}."
12 | fi
--------------------------------------------------------------------------------
/third_party/dust3r/datasets_preprocess/path_to_root.py:
--------------------------------------------------------------------------------
1 | # Copyright (C) 2024-present Naver Corporation. All rights reserved.
2 | # Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
3 | #
4 | # --------------------------------------------------------
5 | # DUSt3R repo root import
6 | # --------------------------------------------------------
7 |
8 | import sys
9 | import os.path as path
10 | HERE_PATH = path.normpath(path.dirname(__file__))
11 | DUST3R_REPO_PATH = path.normpath(path.join(HERE_PATH, '../'))
12 | # workaround for sibling import
13 | sys.path.insert(0, DUST3R_REPO_PATH)
14 |
--------------------------------------------------------------------------------
/third_party/dust3r/dust3r/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (C) 2024-present Naver Corporation. All rights reserved.
2 | # Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
3 |
--------------------------------------------------------------------------------
/third_party/dust3r/dust3r/cloud_opt/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (C) 2024-present Naver Corporation. All rights reserved.
2 | # Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
3 | #
4 | # --------------------------------------------------------
5 | # global alignment optimization wrapper function
6 | # --------------------------------------------------------
7 | from enum import Enum
8 |
9 | from .optimizer import PointCloudOptimizer
10 | from .pair_viewer import PairViewer
11 |
12 |
13 | class GlobalAlignerMode(Enum):
14 | PointCloudOptimizer = "PointCloudOptimizer"
15 | PairViewer = "PairViewer"
16 |
17 |
18 | def global_aligner(dust3r_output, device, mode=GlobalAlignerMode.PointCloudOptimizer, **optim_kw):
19 | # extract all inputs
20 | view1, view2, pred1, pred2 = [dust3r_output[k] for k in 'view1 view2 pred1 pred2'.split()]
21 | # build the optimizer
22 | if mode == GlobalAlignerMode.PointCloudOptimizer:
23 | net = PointCloudOptimizer(view1, view2, pred1, pred2, **optim_kw).to(device)
24 | elif mode == GlobalAlignerMode.PairViewer:
25 | net = PairViewer(view1, view2, pred1, pred2, **optim_kw).to(device)
26 | else:
27 | raise NotImplementedError(f'Unknown mode {mode}')
28 |
29 | return net
30 |
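
A minimal end-to-end sketch of how this wrapper is typically used downstream of DUSt3R inference. The surrounding entry points (load_images, make_pairs, inference) and their arguments follow the upstream dust3r demo and are assumptions about this vendored revision; the image paths and the model-loading step are placeholders.

    import torch
    from dust3r.inference import inference
    from dust3r.image_pairs import make_pairs
    from dust3r.utils.image import load_images
    from dust3r.cloud_opt import global_aligner, GlobalAlignerMode

    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    model = ...  # placeholder: a pretrained DUSt3R model moved to `device` (see dust3r/model.py)

    images = load_images(['img1.jpg', 'img2.jpg'], size=512)             # placeholder paths
    pairs = make_pairs(images, scene_graph='complete', symmetrize=True)
    output = inference(pairs, model, device, batch_size=1)

    # with exactly two symmetrized views the lightweight PairViewer is enough;
    # with more images, switch to GlobalAlignerMode.PointCloudOptimizer
    scene = global_aligner(output, device=device, mode=GlobalAlignerMode.PairViewer)
    print(scene.get_im_poses().shape, scene.get_focals())
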
--------------------------------------------------------------------------------
/third_party/dust3r/dust3r/cloud_opt/commons.py:
--------------------------------------------------------------------------------
1 | # Copyright (C) 2024-present Naver Corporation. All rights reserved.
2 | # Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
3 | #
4 | # --------------------------------------------------------
5 | # utility functions for global alignment
6 | # --------------------------------------------------------
7 | import torch
8 | import torch.nn as nn
9 | import numpy as np
10 |
11 |
12 | def edge_str(i, j):
13 | return f'{i}_{j}'
14 |
15 |
16 | def i_j_ij(ij):
17 | return edge_str(*ij), ij
18 |
19 |
20 | def edge_conf(conf_i, conf_j, edge):
21 | return float(conf_i[edge].mean() * conf_j[edge].mean())
22 |
23 |
24 | def compute_edge_scores(edges, conf_i, conf_j):
25 | return {(i, j): edge_conf(conf_i, conf_j, e) for e, (i, j) in edges}
26 |
27 |
28 | def NoGradParamDict(x):
29 | assert isinstance(x, dict)
30 | return nn.ParameterDict(x).requires_grad_(False)
31 |
32 |
33 | def get_imshapes(edges, pred_i, pred_j):
34 | n_imgs = max(max(e) for e in edges) + 1
35 | imshapes = [None] * n_imgs
36 | for e, (i, j) in enumerate(edges):
37 | shape_i = tuple(pred_i[e].shape[0:2])
38 | shape_j = tuple(pred_j[e].shape[0:2])
39 | if imshapes[i]:
40 | assert imshapes[i] == shape_i, f'incorrect shape for image {i}'
41 | if imshapes[j]:
42 | assert imshapes[j] == shape_j, f'incorrect shape for image {j}'
43 | imshapes[i] = shape_i
44 | imshapes[j] = shape_j
45 | return imshapes
46 |
47 |
48 | def get_conf_trf(mode):
49 | if mode == 'log':
50 | def conf_trf(x): return x.log()
51 | elif mode == 'sqrt':
52 | def conf_trf(x): return x.sqrt()
53 | elif mode == 'm1':
54 | def conf_trf(x): return x-1
55 | elif mode in ('id', 'none'):
56 | def conf_trf(x): return x
57 | else:
58 | raise ValueError(f'bad mode for {mode=}')
59 | return conf_trf
60 |
61 |
62 | def l2_dist(a, b, weight):
63 | return ((a - b).square().sum(dim=-1) * weight)
64 |
65 |
66 | def l1_dist(a, b, weight):
67 | return ((a - b).norm(dim=-1) * weight)
68 |
69 |
70 | ALL_DISTS = dict(l1=l1_dist, l2=l2_dist)
71 |
72 |
73 | def signed_log1p(x):
74 | sign = torch.sign(x)
75 | return sign * torch.log1p(torch.abs(x))
76 |
77 |
78 | def signed_expm1(x):
79 | sign = torch.sign(x)
80 | return sign * torch.expm1(torch.abs(x))
81 |
82 |
83 | def cosine_schedule(t, lr_start, lr_end):
84 | assert 0 <= t <= 1
85 | return lr_end + (lr_start - lr_end) * (1+np.cos(t * np.pi))/2
86 |
87 |
88 | def linear_schedule(t, lr_start, lr_end):
89 | assert 0 <= t <= 1
90 | return lr_start + (lr_end - lr_start) * t
91 |
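
Both schedules above map a progress value t in [0, 1] to a learning rate between lr_start and lr_end; the cosine variant starts and ends flat while the linear one interpolates uniformly. A small sketch, assuming the dust3r package root is importable:

    from dust3r.cloud_opt.commons import cosine_schedule, linear_schedule

    for t in (0.0, 0.25, 0.5, 0.75, 1.0):
        print(f't={t:.2f}  cosine={cosine_schedule(t, 0.01, 1e-4):.5f}'
              f'  linear={linear_schedule(t, 0.01, 1e-4):.5f}')
    # t=0 gives lr_start (0.01) and t=1 gives lr_end (0.0001) for both;
    # in between, the cosine schedule stays closer to the endpoints
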
--------------------------------------------------------------------------------
/third_party/dust3r/dust3r/cloud_opt/pair_viewer.py:
--------------------------------------------------------------------------------
1 | # Copyright (C) 2024-present Naver Corporation. All rights reserved.
2 | # Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
3 | #
4 | # --------------------------------------------------------
5 | # Dummy optimizer for visualizing pairs
6 | # --------------------------------------------------------
7 | import numpy as np
8 | import torch
9 | import torch.nn as nn
10 | import cv2
11 |
12 | from dust3r.cloud_opt.base_opt import BasePCOptimizer
13 | from dust3r.utils.geometry import inv, geotrf, depthmap_to_absolute_camera_coordinates
14 | from dust3r.cloud_opt.commons import edge_str
15 | from dust3r.post_process import estimate_focal_knowing_depth
16 |
17 |
18 | class PairViewer (BasePCOptimizer):
19 | """
20 | This is a dummy optimizer.
21 | Use it only when the goal is to visualize the results for a pair of images (with is_symmetrized)
22 | """
23 |
24 | def __init__(self, *args, **kwargs):
25 | super().__init__(*args, **kwargs)
26 | assert self.is_symmetrized and self.n_edges == 2
27 | self.has_im_poses = True
28 |
29 | # compute all parameters directly from raw input
30 | self.focals = []
31 | self.pp = []
32 | rel_poses = []
33 | confs = []
34 | for i in range(self.n_imgs):
35 | conf = float(self.conf_i[edge_str(i, 1-i)].mean() * self.conf_j[edge_str(i, 1-i)].mean())
36 | print(f' - {conf=:.3} for edge {i}-{1-i}')
37 | confs.append(conf)
38 |
39 | H, W = self.imshapes[i]
40 | pts3d = self.pred_i[edge_str(i, 1-i)]
41 | pp = torch.tensor((W/2, H/2))
42 | focal = float(estimate_focal_knowing_depth(pts3d[None], pp, focal_mode='weiszfeld'))
43 | self.focals.append(focal)
44 | self.pp.append(pp)
45 |
46 | # estimate the pose of pts1 in image 2
47 | pixels = np.mgrid[:W, :H].T.astype(np.float32)
48 | pts3d = self.pred_j[edge_str(1-i, i)].numpy()
49 | assert pts3d.shape[:2] == (H, W)
50 | msk = self.get_masks()[i].numpy()
51 | K = np.float32([(focal, 0, pp[0]), (0, focal, pp[1]), (0, 0, 1)])
52 |
53 | try:
54 | res = cv2.solvePnPRansac(pts3d[msk], pixels[msk], K, None,
55 | iterationsCount=100, reprojectionError=5, flags=cv2.SOLVEPNP_SQPNP)
56 | success, R, T, inliers = res
57 | assert success
58 |
59 | R = cv2.Rodrigues(R)[0] # world to cam
60 | pose = inv(np.r_[np.c_[R, T], [(0, 0, 0, 1)]]) # cam to world
61 | except:
62 | pose = np.eye(4)
63 | rel_poses.append(torch.from_numpy(pose.astype(np.float32)))
64 |
65 | # let's use the pair with the most confidence
66 | if confs[0] > confs[1]:
67 | # ptcloud is expressed in camera1
68 | self.im_poses = [torch.eye(4), rel_poses[1]] # I, cam2-to-cam1
69 | self.depth = [self.pred_i['0_1'][..., 2], geotrf(inv(rel_poses[1]), self.pred_j['0_1'])[..., 2]]
70 | else:
71 | # ptcloud is expressed in camera2
72 | self.im_poses = [rel_poses[0], torch.eye(4)] # I, cam1-to-cam2
73 | self.depth = [geotrf(inv(rel_poses[0]), self.pred_j['1_0'])[..., 2], self.pred_i['1_0'][..., 2]]
74 |
75 | self.im_poses = nn.Parameter(torch.stack(self.im_poses, dim=0), requires_grad=False)
76 | self.focals = nn.Parameter(torch.tensor(self.focals), requires_grad=False)
77 | self.pp = nn.Parameter(torch.stack(self.pp, dim=0), requires_grad=False)
78 | self.depth = nn.ParameterList(self.depth)
79 | for p in self.parameters():
80 | p.requires_grad = False
81 |
82 | def _set_depthmap(self, idx, depth, force=False):
83 | print('_set_depthmap is ignored in PairViewer')
84 | return
85 |
86 | def get_depthmaps(self, raw=False):
87 | depth = [d.to(self.device) for d in self.depth]
88 | return depth
89 |
90 | def _set_focal(self, idx, focal, force=False):
91 | self.focals[idx] = focal
92 |
93 | def get_focals(self):
94 | return self.focals
95 |
96 | def get_known_focal_mask(self):
97 | return torch.tensor([not (p.requires_grad) for p in self.focals])
98 |
99 | def get_principal_points(self):
100 | return self.pp
101 |
102 | def get_intrinsics(self):
103 | focals = self.get_focals()
104 | pps = self.get_principal_points()
105 | K = torch.zeros((len(focals), 3, 3), device=self.device)
106 | for i in range(len(focals)):
107 | K[i, 0, 0] = K[i, 1, 1] = focals[i]
108 | K[i, :2, 2] = pps[i]
109 | K[i, 2, 2] = 1
110 | return K
111 |
112 | def get_im_poses(self):
113 | return self.im_poses
114 |
115 | def depth_to_pts3d(self):
116 | pts3d = []
117 | for d, intrinsics, im_pose in zip(self.depth, self.get_intrinsics(), self.get_im_poses()):
118 | pts, _ = depthmap_to_absolute_camera_coordinates(d.cpu().numpy(),
119 | intrinsics.cpu().numpy(),
120 | im_pose.cpu().numpy())
121 | pts3d.append(torch.from_numpy(pts).to(device=self.device))
122 | return pts3d
123 |
124 | def forward(self):
125 | return float('nan')
126 |
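
The solvePnP branch in __init__ returns a world-to-camera rotation vector and translation; PairViewer stacks them into a 4x4 matrix and inverts it to obtain the camera-to-world pose it stores. A standalone NumPy/OpenCV sketch of that assembly (np.linalg.inv stands in here for the dust3r inv helper used above, which plays the same role for these rigid 4x4 poses):

    import numpy as np
    import cv2

    rvec = np.float32([[0.1], [-0.2], [0.3]])   # rotation vector, as returned by solvePnPRansac
    tvec = np.float32([[0.5], [0.0], [2.0]])

    R = cv2.Rodrigues(rvec)[0]                          # 3x3 world-to-cam rotation
    world2cam = np.r_[np.c_[R, tvec], [(0, 0, 0, 1)]]   # 4x4 homogeneous transform
    cam2world = np.linalg.inv(world2cam)                # what PairViewer stores as the image pose

    # sanity check: the camera center in world coordinates is -R^T t
    print(np.allclose(cam2world[:3, 3], -R.T @ tvec.ravel()))  # True
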
--------------------------------------------------------------------------------
/third_party/dust3r/dust3r/datasets/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (C) 2024-present Naver Corporation. All rights reserved.
2 | # Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
3 | from .utils.transforms import *
4 | from .base.batched_sampler import BatchedRandomSampler # noqa: F401
5 | from .co3d import Co3d # noqa: F401
6 |
7 |
8 | def get_data_loader(dataset, batch_size, num_workers=8, shuffle=True, drop_last=True, pin_mem=True):
9 | import torch
10 | from croco.utils.misc import get_world_size, get_rank
11 |
12 | # pytorch dataset
13 | if isinstance(dataset, str):
14 | dataset = eval(dataset)
15 |
16 | world_size = get_world_size()
17 | rank = get_rank()
18 |
19 | try:
20 | sampler = dataset.make_sampler(batch_size, shuffle=shuffle, world_size=world_size,
21 | rank=rank, drop_last=drop_last)
22 | except (AttributeError, NotImplementedError):
23 | # not avail for this dataset
24 | if torch.distributed.is_initialized():
25 | sampler = torch.utils.data.DistributedSampler(
26 | dataset, num_replicas=world_size, rank=rank, shuffle=shuffle, drop_last=drop_last
27 | )
28 | elif shuffle:
29 | sampler = torch.utils.data.RandomSampler(dataset)
30 | else:
31 | sampler = torch.utils.data.SequentialSampler(dataset)
32 |
33 | data_loader = torch.utils.data.DataLoader(
34 | dataset,
35 | sampler=sampler,
36 | batch_size=batch_size,
37 | num_workers=num_workers,
38 | pin_memory=pin_mem,
39 | drop_last=drop_last,
40 | )
41 |
42 | return data_loader
43 |
--------------------------------------------------------------------------------
/third_party/dust3r/dust3r/datasets/base/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (C) 2024-present Naver Corporation. All rights reserved.
2 | # Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
3 |
--------------------------------------------------------------------------------
/third_party/dust3r/dust3r/datasets/base/batched_sampler.py:
--------------------------------------------------------------------------------
1 | # Copyright (C) 2024-present Naver Corporation. All rights reserved.
2 | # Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
3 | #
4 | # --------------------------------------------------------
5 | # Random sampling under a constraint
6 | # --------------------------------------------------------
7 | import numpy as np
8 | import torch
9 |
10 |
11 | class BatchedRandomSampler:
12 | """ Random sampling under a constraint: each sample in the batch has the same feature,
13 | which is chosen randomly from a known pool of 'features' for each batch.
14 |
15 | For instance, the 'feature' could be the image aspect-ratio.
16 |
17 | The index returned is a tuple (sample_idx, feat_idx).
18 | This sampler ensures that each series of `batch_size` indices has the same `feat_idx`.
19 | """
20 |
21 | def __init__(self, dataset, batch_size, pool_size, world_size=1, rank=0, drop_last=True):
22 | self.batch_size = batch_size
23 | self.pool_size = pool_size
24 |
25 | self.len_dataset = N = len(dataset)
26 | self.total_size = round_by(N, batch_size*world_size) if drop_last else N
27 | assert world_size == 1 or drop_last, 'must drop the last batch in distributed mode'
28 |
29 | # distributed sampler
30 | self.world_size = world_size
31 | self.rank = rank
32 | self.epoch = None
33 |
34 | def __len__(self):
35 | return self.total_size // self.world_size
36 |
37 | def set_epoch(self, epoch):
38 | self.epoch = epoch
39 |
40 | def __iter__(self):
41 | # prepare RNG
42 | if self.epoch is None:
43 | assert self.world_size == 1 and self.rank == 0, 'use set_epoch() if distributed mode is used'
44 | seed = int(torch.empty((), dtype=torch.int64).random_().item())
45 | else:
46 | seed = self.epoch + 777
47 | rng = np.random.default_rng(seed=seed)
48 |
49 | # random indices (will restart from 0 if not drop_last)
50 | sample_idxs = np.arange(self.total_size)
51 | rng.shuffle(sample_idxs)
52 |
53 | # random feat_idxs (same across each batch)
54 | n_batches = (self.total_size+self.batch_size-1) // self.batch_size
55 | feat_idxs = rng.integers(self.pool_size, size=n_batches)
56 | feat_idxs = np.broadcast_to(feat_idxs[:, None], (n_batches, self.batch_size))
57 | feat_idxs = feat_idxs.ravel()[:self.total_size]
58 |
59 | # put them together
60 | idxs = np.c_[sample_idxs, feat_idxs] # shape = (total_size, 2)
61 |
62 | # Distributed sampler: we select a subset of batches
63 | # make sure the slice for each node is aligned with batch_size
64 | size_per_proc = self.batch_size * ((self.total_size + self.world_size *
65 | self.batch_size-1) // (self.world_size * self.batch_size))
66 | idxs = idxs[self.rank*size_per_proc: (self.rank+1)*size_per_proc]
67 |
68 | yield from (tuple(idx) for idx in idxs)
69 |
70 |
71 | def round_by(total, multiple, up=False):
72 | if up:
73 | total = total + multiple-1
74 | return (total//multiple) * multiple
75 |
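
The guarantee in the docstring, namely that every run of batch_size consecutive indices shares one feat_idx, can be checked directly: the sampler only needs len(dataset), so a plain list works as a dummy dataset. A minimal sketch, assuming the dust3r package root is importable:

    from dust3r.datasets.base.batched_sampler import BatchedRandomSampler

    sampler = BatchedRandomSampler(list(range(20)), batch_size=4, pool_size=3)
    idxs = list(iter(sampler))                    # [(sample_idx, feat_idx), ...]

    for start in range(0, len(idxs), 4):
        batch = idxs[start:start + 4]
        feats = {feat for _, feat in batch}
        assert len(feats) == 1                    # one shared feature (e.g. aspect-ratio) per batch
        print(batch, '-> feat_idx', feats.pop())
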
--------------------------------------------------------------------------------
/third_party/dust3r/dust3r/datasets/base/easy_dataset.py:
--------------------------------------------------------------------------------
1 | # Copyright (C) 2024-present Naver Corporation. All rights reserved.
2 | # Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
3 | #
4 | # --------------------------------------------------------
5 | # A dataset base class that you can easily resize and combine.
6 | # --------------------------------------------------------
7 | import numpy as np
8 | from dust3r.datasets.base.batched_sampler import BatchedRandomSampler
9 |
10 |
11 | class EasyDataset:
12 | """ a dataset that you can easily resize and combine.
13 | Examples:
14 | ---------
15 | 2 * dataset ==> duplicate each element 2x
16 |
17 | 10 @ dataset ==> set the size to 10 (random sampling, duplicates if necessary)
18 |
19 | dataset1 + dataset2 ==> concatenate datasets
20 | """
21 |
22 | def __add__(self, other):
23 | return CatDataset([self, other])
24 |
25 | def __rmul__(self, factor):
26 | return MulDataset(factor, self)
27 |
28 | def __rmatmul__(self, factor):
29 | return ResizedDataset(factor, self)
30 |
31 | def set_epoch(self, epoch):
32 | pass # nothing to do by default
33 |
34 | def make_sampler(self, batch_size, shuffle=True, world_size=1, rank=0, drop_last=True):
35 | if not (shuffle):
36 | raise NotImplementedError() # cannot deal yet
37 | num_of_aspect_ratios = len(self._resolutions)
38 | return BatchedRandomSampler(self, batch_size, num_of_aspect_ratios, world_size=world_size, rank=rank, drop_last=drop_last)
39 |
40 |
41 | class MulDataset (EasyDataset):
42 | """ Artifically augmenting the size of a dataset.
43 | """
44 | multiplicator: int
45 |
46 | def __init__(self, multiplicator, dataset):
47 | assert isinstance(multiplicator, int) and multiplicator > 0
48 | self.multiplicator = multiplicator
49 | self.dataset = dataset
50 |
51 | def __len__(self):
52 | return self.multiplicator * len(self.dataset)
53 |
54 | def __repr__(self):
55 | return f'{self.multiplicator}*{repr(self.dataset)}'
56 |
57 | def __getitem__(self, idx):
58 | if isinstance(idx, tuple):
59 | idx, other = idx
60 | return self.dataset[idx // self.multiplicator, other]
61 | else:
62 | return self.dataset[idx // self.multiplicator]
63 |
64 | @property
65 | def _resolutions(self):
66 | return self.dataset._resolutions
67 |
68 |
69 | class ResizedDataset (EasyDataset):
70 | """ Artifically changing the size of a dataset.
71 | """
72 | new_size: int
73 |
74 | def __init__(self, new_size, dataset):
75 | assert isinstance(new_size, int) and new_size > 0
76 | self.new_size = new_size
77 | self.dataset = dataset
78 |
79 | def __len__(self):
80 | return self.new_size
81 |
82 | def __repr__(self):
83 | size_str = str(self.new_size)
84 | for i in range((len(size_str)-1) // 3):
85 | sep = -4*i-3
86 | size_str = size_str[:sep] + '_' + size_str[sep:]
87 | return f'{size_str} @ {repr(self.dataset)}'
88 |
89 | def set_epoch(self, epoch):
90 | # this random shuffle only depends on the epoch
91 | rng = np.random.default_rng(seed=epoch+777)
92 |
93 | # shuffle all indices
94 | perm = rng.permutation(len(self.dataset))
95 |
96 | # rotary extension until target size is met
97 | shuffled_idxs = np.concatenate([perm] * (1 + (len(self)-1) // len(self.dataset)))
98 | self._idxs_mapping = shuffled_idxs[:self.new_size]
99 |
100 | assert len(self._idxs_mapping) == self.new_size
101 |
102 | def __getitem__(self, idx):
103 | assert hasattr(self, '_idxs_mapping'), 'You need to call dataset.set_epoch() to use ResizedDataset.__getitem__()'
104 | if isinstance(idx, tuple):
105 | idx, other = idx
106 | return self.dataset[self._idxs_mapping[idx], other]
107 | else:
108 | return self.dataset[self._idxs_mapping[idx]]
109 |
110 | @property
111 | def _resolutions(self):
112 | return self.dataset._resolutions
113 |
114 |
115 | class CatDataset (EasyDataset):
116 | """ Concatenation of several datasets
117 | """
118 |
119 | def __init__(self, datasets):
120 | for dataset in datasets:
121 | assert isinstance(dataset, EasyDataset)
122 | self.datasets = datasets
123 | self._cum_sizes = np.cumsum([len(dataset) for dataset in datasets])
124 |
125 | def __len__(self):
126 | return self._cum_sizes[-1]
127 |
128 | def __repr__(self):
129 | # remove uselessly long transform
130 | return ' + '.join(repr(dataset).replace(',transform=Compose( ToTensor() Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)))', '') for dataset in self.datasets)
131 |
132 | def set_epoch(self, epoch):
133 | for dataset in self.datasets:
134 | dataset.set_epoch(epoch)
135 |
136 | def __getitem__(self, idx):
137 | other = None
138 | if isinstance(idx, tuple):
139 | idx, other = idx
140 |
141 | if not (0 <= idx < len(self)):
142 | raise IndexError()
143 |
144 | db_idx = np.searchsorted(self._cum_sizes, idx, 'right')
145 | dataset = self.datasets[db_idx]
146 | new_idx = idx - (self._cum_sizes[db_idx - 1] if db_idx > 0 else 0)
147 |
148 | if other is not None:
149 | new_idx = (new_idx, other)
150 | return dataset[new_idx]
151 |
152 | @property
153 | def _resolutions(self):
154 | resolutions = self.datasets[0]._resolutions
155 | for dataset in self.datasets[1:]:
156 | assert tuple(dataset._resolutions) == tuple(resolutions)
157 | return resolutions
158 |
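
The operators in the EasyDataset docstring compose as expected; a toy subclass is enough to see it (Toy and its _resolutions value are made up for this sketch, the attribute is only there because make_sampler expects it). Note that the resized part needs set_epoch() before it can be indexed.

    from dust3r.datasets.base.easy_dataset import EasyDataset

    class Toy(EasyDataset):
        _resolutions = [(224, 224)]          # one supported resolution, as make_sampler expects
        def __init__(self, n): self.n = n
        def __len__(self): return self.n
        def __getitem__(self, idx): return idx
        def __repr__(self): return f'Toy({self.n})'

    ds = 2 * Toy(3) + (10 @ Toy(5))          # duplicate x2, resize to 10 samples, concatenate
    print(repr(ds), len(ds))                 # 2*Toy(3) + 10 @ Toy(5)   16
    ds.set_epoch(0)                          # required by the ResizedDataset part before indexing
    print([ds[i] for i in range(len(ds))])
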
--------------------------------------------------------------------------------
/third_party/dust3r/dust3r/datasets/co3d.py:
--------------------------------------------------------------------------------
1 | # Copyright (C) 2024-present Naver Corporation. All rights reserved.
2 | # Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
3 | #
4 | # --------------------------------------------------------
5 | # Dataloader for preprocessed Co3d_v2
6 | # dataset at https://github.com/facebookresearch/co3d - Creative Commons Attribution-NonCommercial 4.0 International
7 | # See datasets_preprocess/preprocess_co3d.py
8 | # --------------------------------------------------------
9 | import os.path as osp
10 | import json
11 | import itertools
12 | from collections import deque
13 |
14 | import cv2
15 | import numpy as np
16 |
17 | from dust3r.datasets.base.base_stereo_view_dataset import BaseStereoViewDataset
18 | from dust3r.utils.image import imread_cv2
19 |
20 |
21 | class Co3d(BaseStereoViewDataset):
22 | def __init__(self, mask_bg=True, *args, ROOT, **kwargs):
23 | self.ROOT = ROOT
24 | super().__init__(*args, **kwargs)
25 | assert mask_bg in (True, False, 'rand')
26 | self.mask_bg = mask_bg
27 |
28 | # load all scenes
29 | with open(osp.join(self.ROOT, f'selected_seqs_{self.split}.json'), 'r') as f:
30 | self.scenes = json.load(f)
31 | self.scenes = {k: v for k, v in self.scenes.items() if len(v) > 0}
32 | self.scenes = {(k, k2): v2 for k, v in self.scenes.items()
33 | for k2, v2 in v.items()}
34 | self.scene_list = list(self.scenes.keys())
35 |
36 | # for each scene, we have 100 images ==> 360 degrees (so 25 frames ~= 90 degrees)
37 | # we prepare all pair combinations whose frame gap is a multiple of 5 and at most 30 frames (roughly 18 to 108 degrees apart)
38 | self.combinations = [(i, j)
39 | for i, j in itertools.combinations(range(100), 2)
40 | if 0 < abs(i-j) <= 30 and abs(i-j) % 5 == 0]
41 |
42 | self.invalidate = {scene: {} for scene in self.scene_list}
43 |
44 | def __len__(self):
45 | return len(self.scene_list) * len(self.combinations)
46 |
47 | def _get_views(self, idx, resolution, rng):
48 | # choose a scene
49 | obj, instance = self.scene_list[idx // len(self.combinations)]
50 | image_pool = self.scenes[obj, instance]
51 | im1_idx, im2_idx = self.combinations[idx % len(self.combinations)]
52 |
53 | # add a bit of randomness
54 | last = len(image_pool)-1
55 |
56 | if resolution not in self.invalidate[obj, instance]: # flag invalid images
57 | self.invalidate[obj, instance][resolution] = [False for _ in range(len(image_pool))]
58 |
59 | # decide now if we mask the bg
60 | mask_bg = (self.mask_bg == True) or (self.mask_bg == 'rand' and rng.choice(2))
61 |
62 | views = []
63 | imgs_idxs = [max(0, min(im_idx + rng.integers(-4, 5), last)) for im_idx in [im2_idx, im1_idx]]
64 | imgs_idxs = deque(imgs_idxs)
65 | while len(imgs_idxs) > 0: # some images (few) have zero depth
66 | im_idx = imgs_idxs.pop()
67 |
68 | if self.invalidate[obj, instance][resolution][im_idx]:
69 | # search for a valid image
70 | random_direction = 2 * rng.choice(2) - 1
71 | for offset in range(1, len(image_pool)):
72 | tentative_im_idx = (im_idx + (random_direction * offset)) % len(image_pool)
73 | if not self.invalidate[obj, instance][resolution][tentative_im_idx]:
74 | im_idx = tentative_im_idx
75 | break
76 |
77 | view_idx = image_pool[im_idx]
78 |
79 | impath = osp.join(self.ROOT, obj, instance, 'images', f'frame{view_idx:06n}.jpg')
80 |
81 | # load camera params
82 | input_metadata = np.load(impath.replace('jpg', 'npz'))
83 | camera_pose = input_metadata['camera_pose'].astype(np.float32)
84 | intrinsics = input_metadata['camera_intrinsics'].astype(np.float32)
85 |
86 | # load image and depth
87 | rgb_image = imread_cv2(impath)
88 | depthmap = imread_cv2(impath.replace('images', 'depths') + '.geometric.png', cv2.IMREAD_UNCHANGED)
89 | depthmap = (depthmap.astype(np.float32) / 65535) * np.nan_to_num(input_metadata['maximum_depth'])
90 |
91 | if mask_bg:
92 | # load object mask
93 | maskpath = osp.join(self.ROOT, obj, instance, 'masks', f'frame{view_idx:06n}.png')
94 | maskmap = imread_cv2(maskpath, cv2.IMREAD_UNCHANGED).astype(np.float32)
95 | maskmap = (maskmap / 255.0) > 0.1
96 |
97 | # update the depthmap with mask
98 | depthmap *= maskmap
99 |
100 | rgb_image, depthmap, intrinsics = self._crop_resize_if_necessary(
101 | rgb_image, depthmap, intrinsics, resolution, rng=rng, info=impath)
102 |
103 | num_valid = (depthmap > 0.0).sum()
104 | if num_valid == 0:
105 | # problem, invalidate image and retry
106 | self.invalidate[obj, instance][resolution][im_idx] = True
107 | imgs_idxs.append(im_idx)
108 | continue
109 |
110 | views.append(dict(
111 | img=rgb_image,
112 | depthmap=depthmap,
113 | camera_pose=camera_pose,
114 | camera_intrinsics=intrinsics,
115 | dataset='Co3d_v2',
116 | label=osp.join(obj, instance),
117 | instance=osp.split(impath)[1],
118 | ))
119 | return views
120 |
121 |
122 | if __name__ == "__main__":
123 | from dust3r.datasets.base.base_stereo_view_dataset import view_name
124 | from dust3r.viz import SceneViz, auto_cam_size
125 | from dust3r.utils.image import rgb
126 |
127 | dataset = Co3d(split='train', ROOT="data/co3d_subset_processed", resolution=224, aug_crop=16)
128 |
129 | for idx in np.random.permutation(len(dataset)):
130 | views = dataset[idx]
131 | assert len(views) == 2
132 | print(view_name(views[0]), view_name(views[1]))
133 | viz = SceneViz()
134 | poses = [views[view_idx]['camera_pose'] for view_idx in [0, 1]]
135 | cam_size = max(auto_cam_size(poses), 0.001)
136 | for view_idx in [0, 1]:
137 | pts3d = views[view_idx]['pts3d']
138 | valid_mask = views[view_idx]['valid_mask']
139 | colors = rgb(views[view_idx]['img'])
140 | viz.add_pointcloud(pts3d, colors, valid_mask)
141 | viz.add_camera(pose_c2w=views[view_idx]['camera_pose'],
142 | focal=views[view_idx]['camera_intrinsics'][0, 0],
143 | color=(idx*255, (1 - idx)*255, 0),
144 | image=colors,
145 | cam_size=cam_size)
146 | viz.show()
147 |
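
The pair-sampling rule in __init__ is worth spelling out: with 100 frames covering roughly 360 degrees, a gap of 5 frames is about 18 degrees, so the allowed gaps (multiples of 5 up to 30 frames) correspond to view-angle differences of roughly 18 to 108 degrees, and every scene contributes 495 index pairs. A small standalone check:

    import itertools

    combinations = [(i, j)
                    for i, j in itertools.combinations(range(100), 2)
                    if 0 < abs(i - j) <= 30 and abs(i - j) % 5 == 0]

    print(len(combinations))                                    # 495 pairs per scene
    print(sorted({abs(i - j) * 3.6 for i, j in combinations}))  # [18.0, 36.0, 54.0, 72.0, 90.0, 108.0]
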
--------------------------------------------------------------------------------
/third_party/dust3r/dust3r/datasets/utils/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (C) 2024-present Naver Corporation. All rights reserved.
2 | # Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
3 |
--------------------------------------------------------------------------------
/third_party/dust3r/dust3r/datasets/utils/cropping.py:
--------------------------------------------------------------------------------
1 | # Copyright (C) 2024-present Naver Corporation. All rights reserved.
2 | # Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
3 | #
4 | # --------------------------------------------------------
5 | # cropping utilities
6 | # --------------------------------------------------------
7 | import PIL.Image
8 | import os
9 | os.environ["OPENCV_IO_ENABLE_OPENEXR"] = "1"
10 | import cv2 # noqa
11 | import numpy as np # noqa
12 | from dust3r.utils.geometry import colmap_to_opencv_intrinsics, opencv_to_colmap_intrinsics # noqa
13 | try:
14 | lanczos = PIL.Image.Resampling.LANCZOS
15 | except AttributeError:
16 | lanczos = PIL.Image.LANCZOS
17 |
18 |
19 | class ImageList:
20 | """ Convenience class to aply the same operation to a whole set of images.
21 | """
22 |
23 | def __init__(self, images):
24 | if not isinstance(images, (tuple, list, set)):
25 | images = [images]
26 | self.images = []
27 | for image in images:
28 | if not isinstance(image, PIL.Image.Image):
29 | image = PIL.Image.fromarray(image)
30 | self.images.append(image)
31 |
32 | def __len__(self):
33 | return len(self.images)
34 |
35 | def to_pil(self):
36 | return tuple(self.images) if len(self.images) > 1 else self.images[0]
37 |
38 | @property
39 | def size(self):
40 | sizes = [im.size for im in self.images]
41 | assert all(sizes[0] == s for s in sizes)
42 | return sizes[0]
43 |
44 | def resize(self, *args, **kwargs):
45 | return ImageList(self._dispatch('resize', *args, **kwargs))
46 |
47 | def crop(self, *args, **kwargs):
48 | return ImageList(self._dispatch('crop', *args, **kwargs))
49 |
50 | def _dispatch(self, func, *args, **kwargs):
51 | return [getattr(im, func)(*args, **kwargs) for im in self.images]
52 |
53 |
54 | def rescale_image_depthmap(image, depthmap, camera_intrinsics, output_resolution):
55 | """ Jointly rescale a (image, depthmap)
56 | so that (out_width, out_height) >= output_res
57 | """
58 | image = ImageList(image)
59 | input_resolution = np.array(image.size) # (W,H)
60 | output_resolution = np.array(output_resolution)
61 | if depthmap is not None:
62 | # can also use this with masks instead of depthmaps
63 | assert tuple(depthmap.shape[:2]) == image.size[::-1]
64 | assert output_resolution.shape == (2,)
65 | # define output resolution
66 | scale_final = max(output_resolution / image.size) + 1e-8
67 | output_resolution = np.floor(input_resolution * scale_final).astype(int)
68 |
69 | # first rescale the image so that it contains the crop
70 | image = image.resize(output_resolution, resample=lanczos)
71 | if depthmap is not None:
72 | depthmap = cv2.resize(depthmap, output_resolution, fx=scale_final,
73 | fy=scale_final, interpolation=cv2.INTER_NEAREST)
74 |
75 | # no offset here; simple rescaling
76 | camera_intrinsics = camera_matrix_of_crop(
77 | camera_intrinsics, input_resolution, output_resolution, scaling=scale_final)
78 |
79 | return image.to_pil(), depthmap, camera_intrinsics
80 |
81 |
82 | def camera_matrix_of_crop(input_camera_matrix, input_resolution, output_resolution, scaling=1, offset_factor=0.5, offset=None):
83 | # Margins to offset the origin
84 | margins = np.asarray(input_resolution) * scaling - output_resolution
85 | assert np.all(margins >= 0.0)
86 | if offset is None:
87 | offset = offset_factor * margins
88 |
89 | # Generate new camera parameters
90 | output_camera_matrix_colmap = opencv_to_colmap_intrinsics(input_camera_matrix)
91 | output_camera_matrix_colmap[:2, :] *= scaling
92 | output_camera_matrix_colmap[:2, 2] -= offset
93 | output_camera_matrix = colmap_to_opencv_intrinsics(output_camera_matrix_colmap)
94 |
95 | return output_camera_matrix
96 |
97 |
98 | def crop_image_depthmap(image, depthmap, camera_intrinsics, crop_bbox):
99 | """
100 | Return a crop of the input view.
101 | """
102 | image = ImageList(image)
103 | l, t, r, b = crop_bbox
104 |
105 | image = image.crop((l, t, r, b))
106 | depthmap = depthmap[t:b, l:r]
107 |
108 | camera_intrinsics = camera_intrinsics.copy()
109 | camera_intrinsics[0, 2] -= l
110 | camera_intrinsics[1, 2] -= t
111 |
112 | return image.to_pil(), depthmap, camera_intrinsics
113 |
114 |
115 | def bbox_from_intrinsics_in_out(input_camera_matrix, output_camera_matrix, output_resolution):
116 | out_width, out_height = output_resolution
117 | l, t = np.int32(np.round(input_camera_matrix[:2, 2] - output_camera_matrix[:2, 2]))
118 | crop_bbox = (l, t, l+out_width, t+out_height)
119 | return crop_bbox
120 |
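
The intrinsics bookkeeping done by rescale_image_depthmap and crop_image_depthmap boils down to two rules: rescaling multiplies the first two rows of K by the scale factor, and cropping at a top-left corner (l, t) subtracts (l, t) from the principal point. A simplified standalone sketch (camera_matrix_of_crop additionally round-trips through the COLMAP pixel-center convention via opencv_to_colmap_intrinsics, which this sketch omits):

    import numpy as np

    K = np.float32([[500., 0., 320.],
                    [0., 500., 240.],
                    [0., 0., 1.]])

    scale = 0.5                       # e.g. 640x480 -> 320x240
    K_scaled = K.copy()
    K_scaled[:2, :] *= scale          # fx, fy, cx, cy all scale with the image

    l, t = 40, 20                     # crop box top-left corner, as in crop_image_depthmap
    K_cropped = K_scaled.copy()
    K_cropped[0, 2] -= l
    K_cropped[1, 2] -= t
    print(K_cropped)                  # fx = fy = 250, principal point at (120, 100)
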
--------------------------------------------------------------------------------
/third_party/dust3r/dust3r/datasets/utils/transforms.py:
--------------------------------------------------------------------------------
1 | # Copyright (C) 2024-present Naver Corporation. All rights reserved.
2 | # Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
3 | #
4 | # --------------------------------------------------------
5 | # DUST3R default transforms
6 | # --------------------------------------------------------
7 | import torchvision.transforms as tvf
8 | from dust3r.utils.image import ImgNorm
9 |
10 | # define the standard image transforms
11 | ColorJitter = tvf.Compose([tvf.ColorJitter(0.5, 0.5, 0.5, 0.1), ImgNorm])
12 |
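
A tiny sketch (not from the repository) of applying the composed augmentation to a PIL image; the 224x224 size is arbitrary, and ImgNorm is assumed, as in upstream DUSt3R, to convert the jittered image into a normalized tensor.

from PIL import Image
from dust3r.datasets.utils.transforms import ColorJitter

img = Image.new("RGB", (224, 224))     # dummy input
aug = ColorJitter(img)                 # color-jittered, normalized tensor of shape (3, 224, 224)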
--------------------------------------------------------------------------------
/third_party/dust3r/dust3r/heads/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (C) 2024-present Naver Corporation. All rights reserved.
2 | # Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
3 | #
4 | # --------------------------------------------------------
5 | # head factory
6 | # --------------------------------------------------------
7 | from .linear_head import LinearPts3d
8 | from .dpt_head import create_dpt_head
9 |
10 |
11 | def head_factory(head_type, output_mode, net, has_conf=False):
12 | """" build a prediction head for the decoder
13 | """
14 | if head_type == 'linear' and output_mode == 'pts3d':
15 | return LinearPts3d(net, has_conf)
16 | elif head_type == 'dpt' and output_mode == 'pts3d':
17 | return create_dpt_head(net, has_conf=has_conf)
18 | else:
19 | raise NotImplementedError(f"unexpected {head_type=} and {output_mode=}")
20 |
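
A small sketch of the factory call with a stand-in net built from SimpleNamespace; only the attributes read by LinearPts3d are provided, and the mode triplets are just examples. In real use the AsymmetricCroCo3DStereo model itself is passed, and importing dust3r.heads assumes the croco submodule is initialized.

from types import SimpleNamespace
import torch
from dust3r.heads import head_factory

fake_net = SimpleNamespace(                       # stand-in for the real network
    patch_embed=SimpleNamespace(patch_size=(16, 16)),
    depth_mode=('exp', -float('inf'), float('inf')),
    conf_mode=('exp', 1, float('inf')),
    dec_embed_dim=768,
)
head = head_factory('linear', 'pts3d', fake_net, has_conf=True)

tokens = torch.randn(2, (64 // 16) * (64 // 16), 768)   # B x N x dec_embed_dim
out = head([tokens], (64, 64))                           # {'pts3d': (2, 64, 64, 3), 'conf': (2, 64, 64)}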
--------------------------------------------------------------------------------
/third_party/dust3r/dust3r/heads/dpt_head.py:
--------------------------------------------------------------------------------
1 | # Copyright (C) 2024-present Naver Corporation. All rights reserved.
2 | # Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
3 | #
4 | # --------------------------------------------------------
5 | # dpt head implementation for DUST3R
6 | # Downstream heads assume inputs of size B x N x C (where N is the number of tokens);
7 | # if a head takes as input the output at every layer, the attribute return_all_layers should be set to True.
8 | # The forward function also takes as input a dictionary img_info with keys "height" and "width".
9 | # For PixelwiseTask, the output will be of dimension B x num_channels x H x W.
10 | # --------------------------------------------------------
11 | from einops import rearrange
12 | from typing import List
13 | import torch
14 | import torch.nn as nn
15 | from dust3r.heads.postprocess import postprocess
16 | import dust3r.utils.path_to_croco # noqa: F401
17 | from models.dpt_block import DPTOutputAdapter # noqa
18 |
19 |
20 | class DPTOutputAdapter_fix(DPTOutputAdapter):
21 | """
22 | Adapt croco's DPTOutputAdapter implementation for dust3r:
23 |     remove duplicated weights, and fix forward for dust3r
24 | """
25 |
26 | def init(self, dim_tokens_enc=768):
27 | super().init(dim_tokens_enc)
28 | # these are duplicated weights
29 | del self.act_1_postprocess
30 | del self.act_2_postprocess
31 | del self.act_3_postprocess
32 | del self.act_4_postprocess
33 |
34 | def forward(self, encoder_tokens: List[torch.Tensor], image_size=None):
35 | assert self.dim_tokens_enc is not None, 'Need to call init(dim_tokens_enc) function first'
36 | # H, W = input_info['image_size']
37 | image_size = self.image_size if image_size is None else image_size
38 | H, W = image_size
39 | # Number of patches in height and width
40 | N_H = H // (self.stride_level * self.P_H)
41 | N_W = W // (self.stride_level * self.P_W)
42 |
43 | # Hook decoder onto 4 layers from specified ViT layers
44 | layers = [encoder_tokens[hook] for hook in self.hooks]
45 |
46 | # Extract only task-relevant tokens and ignore global tokens.
47 | layers = [self.adapt_tokens(l) for l in layers]
48 |
49 | # Reshape tokens to spatial representation
50 | layers = [rearrange(l, 'b (nh nw) c -> b c nh nw', nh=N_H, nw=N_W) for l in layers]
51 |
52 | layers = [self.act_postprocess[idx](l) for idx, l in enumerate(layers)]
53 | # Project layers to chosen feature dim
54 | layers = [self.scratch.layer_rn[idx](l) for idx, l in enumerate(layers)]
55 |
56 | # Fuse layers using refinement stages
57 | path_4 = self.scratch.refinenet4(layers[3])[:, :, :layers[2].shape[2], :layers[2].shape[3]]
58 | path_3 = self.scratch.refinenet3(path_4, layers[2])
59 | path_2 = self.scratch.refinenet2(path_3, layers[1])
60 | path_1 = self.scratch.refinenet1(path_2, layers[0])
61 |
62 | # Output head
63 | out = self.head(path_1)
64 |
65 | return out
66 |
67 |
68 | class PixelwiseTaskWithDPT(nn.Module):
69 | """ DPT module for dust3r, can return 3D points + confidence for all pixels"""
70 |
71 | def __init__(self, *, n_cls_token=0, hooks_idx=None, dim_tokens=None,
72 | output_width_ratio=1, num_channels=1, postprocess=None, depth_mode=None, conf_mode=None, **kwargs):
73 | super(PixelwiseTaskWithDPT, self).__init__()
74 | self.return_all_layers = True # backbone needs to return all layers
75 | self.postprocess = postprocess
76 | self.depth_mode = depth_mode
77 | self.conf_mode = conf_mode
78 |
79 | assert n_cls_token == 0, "Not implemented"
80 | dpt_args = dict(output_width_ratio=output_width_ratio,
81 | num_channels=num_channels,
82 | **kwargs)
83 | if hooks_idx is not None:
84 | dpt_args.update(hooks=hooks_idx)
85 | self.dpt = DPTOutputAdapter_fix(**dpt_args)
86 | dpt_init_args = {} if dim_tokens is None else {'dim_tokens_enc': dim_tokens}
87 | self.dpt.init(**dpt_init_args)
88 |
89 | def forward(self, x, img_info):
90 | out = self.dpt(x, image_size=(img_info[0], img_info[1]))
91 | if self.postprocess:
92 | out = self.postprocess(out, self.depth_mode, self.conf_mode)
93 | return out
94 |
95 |
96 | def create_dpt_head(net, has_conf=False):
97 | """
98 | return PixelwiseTaskWithDPT for given net params
99 | """
100 | assert net.dec_depth > 9
101 | l2 = net.dec_depth
102 | feature_dim = 256
103 | last_dim = feature_dim//2
104 | out_nchan = 3
105 | ed = net.enc_embed_dim
106 | dd = net.dec_embed_dim
107 | return PixelwiseTaskWithDPT(num_channels=out_nchan + has_conf,
108 | feature_dim=feature_dim,
109 | last_dim=last_dim,
110 | hooks_idx=[0, l2*2//4, l2*3//4, l2],
111 | dim_tokens=[ed, dd, dd, dd],
112 | postprocess=postprocess,
113 | depth_mode=net.depth_mode,
114 | conf_mode=net.conf_mode,
115 | head_type='regression')
116 |
--------------------------------------------------------------------------------
/third_party/dust3r/dust3r/heads/linear_head.py:
--------------------------------------------------------------------------------
1 | # Copyright (C) 2024-present Naver Corporation. All rights reserved.
2 | # Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
3 | #
4 | # --------------------------------------------------------
5 | # linear head implementation for DUST3R
6 | # --------------------------------------------------------
7 | import torch.nn as nn
8 | import torch.nn.functional as F
9 | from dust3r.heads.postprocess import postprocess
10 |
11 |
12 | class LinearPts3d (nn.Module):
13 | """
14 | Linear head for dust3r
15 |     Each token outputs: 16x16 3D points (+ confidence)
16 | """
17 |
18 | def __init__(self, net, has_conf=False):
19 | super().__init__()
20 | self.patch_size = net.patch_embed.patch_size[0]
21 | self.depth_mode = net.depth_mode
22 | self.conf_mode = net.conf_mode
23 | self.has_conf = has_conf
24 |
25 | self.proj = nn.Linear(net.dec_embed_dim, (3 + has_conf)*self.patch_size**2)
26 |
27 | def setup(self, croconet):
28 | pass
29 |
30 | def forward(self, decout, img_shape):
31 | H, W = img_shape
32 | tokens = decout[-1]
33 | B, S, D = tokens.shape
34 |
35 | # extract 3D points
36 |         feat = self.proj(tokens)  # B, S, (3+has_conf)*patch_size**2
37 | feat = feat.transpose(-1, -2).view(B, -1, H//self.patch_size, W//self.patch_size)
38 | feat = F.pixel_shuffle(feat, self.patch_size) # B,3,H,W
39 |
40 | # permute + norm depth
41 | return postprocess(feat, self.depth_mode, self.conf_mode)
42 |
--------------------------------------------------------------------------------
/third_party/dust3r/dust3r/heads/postprocess.py:
--------------------------------------------------------------------------------
1 | # Copyright (C) 2024-present Naver Corporation. All rights reserved.
2 | # Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
3 | #
4 | # --------------------------------------------------------
5 | # post process function for all heads: extract 3D points/confidence from output
6 | # --------------------------------------------------------
7 | import torch
8 |
9 |
10 | def postprocess(out, depth_mode, conf_mode):
11 | """
12 | extract 3D points/confidence from prediction head output
13 | """
14 | fmap = out.permute(0, 2, 3, 1) # B,H,W,3
15 | res = dict(pts3d=reg_dense_depth(fmap[:, :, :, 0:3], mode=depth_mode))
16 |
17 | if conf_mode is not None:
18 | res['conf'] = reg_dense_conf(fmap[:, :, :, 3], mode=conf_mode)
19 | return res
20 |
21 |
22 | def reg_dense_depth(xyz, mode):
23 | """
24 | extract 3D points from prediction head output
25 | """
26 | mode, vmin, vmax = mode
27 |
28 | no_bounds = (vmin == -float('inf')) and (vmax == float('inf'))
29 | assert no_bounds
30 |
31 | if mode == 'linear':
32 | if no_bounds:
33 | return xyz # [-inf, +inf]
34 | return xyz.clip(min=vmin, max=vmax)
35 |
36 | # distance to origin
37 | d = xyz.norm(dim=-1, keepdim=True)
38 | xyz = xyz / d.clip(min=1e-8)
39 |
40 | if mode == 'square':
41 | return xyz * d.square()
42 |
43 | if mode == 'exp':
44 | return xyz * torch.expm1(d)
45 |
46 | raise ValueError(f'bad {mode=}')
47 |
48 |
49 | def reg_dense_conf(x, mode):
50 | """
51 | extract confidence from prediction head output
52 | """
53 | mode, vmin, vmax = mode
54 | if mode == 'exp':
55 | return vmin + x.exp().clip(max=vmax-vmin)
56 | if mode == 'sigmoid':
57 | return (vmax - vmin) * torch.sigmoid(x) + vmin
58 | raise ValueError(f'bad {mode=}')
59 |
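
A quick sketch on a dummy head output: depth_mode and conf_mode are (name, vmin, vmax) triplets, and the 'exp' settings below (unbounded points, confidence >= 1) are only one possible configuration.

import torch
from dust3r.heads.postprocess import postprocess

out = torch.randn(2, 4, 32, 32)   # B x (3 + conf) x H x W raw head output
res = postprocess(out,
                  depth_mode=('exp', -float('inf'), float('inf')),
                  conf_mode=('exp', 1, float('inf')))
print(res['pts3d'].shape, res['conf'].shape)   # (2, 32, 32, 3) and (2, 32, 32)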
--------------------------------------------------------------------------------
/third_party/dust3r/dust3r/image_pairs.py:
--------------------------------------------------------------------------------
1 | # Copyright (C) 2024-present Naver Corporation. All rights reserved.
2 | # Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
3 | #
4 | # --------------------------------------------------------
5 | # utilities needed to load image pairs
6 | # --------------------------------------------------------
7 | import numpy as np
8 | import torch
9 | import itertools
10 |
11 |
12 | def make_pairs(imgs, scene_graph='complete', prefilter=None, symmetrize=True):
13 | pairs = []
14 |
15 | if scene_graph == 'complete': # complete graph
16 | for i in range(len(imgs)):
17 | for j in range(i):
18 | pairs.append((imgs[i], imgs[j]))
19 |
20 | elif scene_graph.startswith('swin'):
21 | winsize = int(scene_graph.split('-')[1]) if '-' in scene_graph else 3
22 | for i in range(len(imgs)):
23 | for j in range(winsize):
24 | idx = (i + j) % len(imgs) # explicit loop closure
25 | pairs.append((imgs[i], imgs[idx]))
26 |
27 | elif scene_graph.startswith('oneref'):
28 | refid = int(scene_graph.split('-')[1]) if '-' in scene_graph else 0
29 | for j in range(len(imgs)):
30 | if j != refid:
31 | pairs.append((imgs[refid], imgs[j]))
32 |
33 | elif scene_graph == 'pairs':
34 | assert len(imgs) % 2 == 0
35 | for i in range(0, len(imgs), 2):
36 | pairs.append((imgs[i], imgs[i+1]))
37 |
38 | if symmetrize:
39 | pairs += [(img2, img1) for img1, img2 in pairs]
40 |
41 | # now, remove edges
42 | if isinstance(prefilter, str) and prefilter.startswith('seq'):
43 | pairs = filter_pairs_seq(pairs, int(prefilter[3:]))
44 |
45 | if isinstance(prefilter, str) and prefilter.startswith('cyc'):
46 | pairs = filter_pairs_seq(pairs, int(prefilter[3:]), cyclic=True)
47 |
48 | return pairs
49 |
50 | def make_pairs_fast(imgs, scene_graph='complete', prefilter=None, symmetrize=True):
51 | pairs = []
52 |
53 | if scene_graph == 'complete': # complete graph
54 | pairs = list(itertools.combinations(imgs, 2))
55 |
56 | elif scene_graph.startswith('swin'):
57 | winsize = int(scene_graph.split('-')[1]) if '-' in scene_graph else 3
58 | for i in range(len(imgs)):
59 | for j in range(winsize):
60 | idx = (i + j) % len(imgs) # explicit loop closure
61 | pairs.append((imgs[i], imgs[idx]))
62 |
63 | elif scene_graph.startswith('oneref'):
64 | refid = int(scene_graph.split('-')[1]) if '-' in scene_graph else 0
65 | for j in range(len(imgs)):
66 | if j != refid:
67 | pairs.append((imgs[refid], imgs[j]))
68 |
69 | elif scene_graph == 'pairs':
70 | assert len(imgs) % 2 == 0
71 | for i in range(0, len(imgs), 2):
72 | pairs.append((imgs[i], imgs[i+1]))
73 |
74 | if symmetrize:
75 | pairs += [(img2, img1) for img1, img2 in pairs]
76 |
77 | # now, remove edges
78 | if isinstance(prefilter, str) and prefilter.startswith('seq'):
79 | pairs = filter_pairs_seq(pairs, int(prefilter[3:]))
80 |
81 | if isinstance(prefilter, str) and prefilter.startswith('cyc'):
82 | pairs = filter_pairs_seq(pairs, int(prefilter[3:]), cyclic=True)
83 |
84 | return pairs
85 |
86 | def sel(x, kept):
87 | if isinstance(x, dict):
88 | return {k: sel(v, kept) for k, v in x.items()}
89 | if isinstance(x, (torch.Tensor, np.ndarray)):
90 | return x[kept]
91 | if isinstance(x, (tuple, list)):
92 | return type(x)([x[k] for k in kept])
93 |
94 |
95 | def _filter_edges_seq(edges, seq_dis_thr, cyclic=False):
96 | # number of images
97 | n = max(max(e) for e in edges)+1
98 |
99 | kept = []
100 | for e, (i, j) in enumerate(edges):
101 | dis = abs(i-j)
102 | if cyclic:
103 | dis = min(dis, abs(i+n-j), abs(i-n-j))
104 | if dis <= seq_dis_thr:
105 | kept.append(e)
106 | return kept
107 |
108 |
109 | def filter_pairs_seq(pairs, seq_dis_thr, cyclic=False):
110 | edges = [(img1['idx'], img2['idx']) for img1, img2 in pairs]
111 | kept = _filter_edges_seq(edges, seq_dis_thr, cyclic=cyclic)
112 | return [pairs[i] for i in kept]
113 |
114 |
115 | def filter_edges_seq(view1, view2, pred1, pred2, seq_dis_thr, cyclic=False):
116 | edges = [(int(i), int(j)) for i, j in zip(view1['idx'], view2['idx'])]
117 | kept = _filter_edges_seq(edges, seq_dis_thr, cyclic=cyclic)
118 | print(f'>> Filtering edges more than {seq_dis_thr} frames apart: kept {len(kept)}/{len(edges)} edges')
119 | return sel(view1, kept), sel(view2, kept), sel(pred1, kept), sel(pred2, kept)
120 |
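
A minimal sketch of the pairing strategies. Each "image" here is just a dict carrying the 'idx' field that the seq/cyc prefilters read; in practice these dicts come from the image loading utilities.

from dust3r.image_pairs import make_pairs

imgs = [dict(idx=i, instance=str(i)) for i in range(4)]

pairs = make_pairs(imgs, scene_graph='complete', symmetrize=True)       # 6 pairs, symmetrized to 12
pairs_swin = make_pairs(imgs, scene_graph='swin-2', symmetrize=False)   # windowed graph with cyclic wrap-around
pairs_ref = make_pairs(imgs, scene_graph='oneref-0', prefilter='seq2')  # reference view 0, keep |idx difference| <= 2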
--------------------------------------------------------------------------------
/third_party/dust3r/dust3r/inference.py:
--------------------------------------------------------------------------------
1 | # Copyright (C) 2024-present Naver Corporation. All rights reserved.
2 | # Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
3 | #
4 | # --------------------------------------------------------
5 | # utilities needed for the inference
6 | # --------------------------------------------------------
7 | import tqdm
8 | import torch
9 | from dust3r.utils.device import to_cpu, collate_with_cat
10 | from dust3r.model import AsymmetricCroCo3DStereo, inf # noqa: F401, needed when loading the model
11 | from dust3r.utils.misc import invalid_to_nans
12 | from dust3r.utils.geometry import depthmap_to_pts3d, geotrf
13 |
14 |
15 | def load_model(model_path, device):
16 | print('... loading model from', model_path)
17 | ckpt = torch.load(model_path, map_location='cpu')
18 | args = ckpt['args'].model.replace("ManyAR_PatchEmbed", "PatchEmbedDust3R")
19 | if 'landscape_only' not in args:
20 | args = args[:-1] + ', landscape_only=False)'
21 | else:
22 | args = args.replace(" ", "").replace('landscape_only=True', 'landscape_only=False')
23 | assert "landscape_only=False" in args
24 | print(f"instantiating : {args}")
25 | net = eval(args)
26 | print(net.load_state_dict(ckpt['model'], strict=False))
27 | return net.to(device)
28 |
29 |
30 | def _interleave_imgs(img1, img2):
31 | res = {}
32 | for key, value1 in img1.items():
33 | value2 = img2[key]
34 | if isinstance(value1, torch.Tensor):
35 | value = torch.stack((value1, value2), dim=1).flatten(0, 1)
36 | else:
37 | value = [x for pair in zip(value1, value2) for x in pair]
38 | res[key] = value
39 | return res
40 |
41 |
42 | def make_batch_symmetric(batch):
43 | view1, view2 = batch
44 | view1, view2 = (_interleave_imgs(view1, view2), _interleave_imgs(view2, view1))
45 | return view1, view2
46 |
47 |
48 | def loss_of_one_batch(batch, model, criterion, device, symmetrize_batch=False, use_amp=False, ret=None):
49 | view1, view2 = batch
50 | for view in batch:
51 | for name in 'img pts3d valid_mask camera_pose camera_intrinsics F_matrix corres'.split(): # pseudo_focal
52 | if name not in view:
53 | continue
54 | view[name] = view[name].to(device, non_blocking=True)
55 |
56 | if symmetrize_batch:
57 | view1, view2 = make_batch_symmetric(batch)
58 |
59 | with torch.cuda.amp.autocast(enabled=bool(use_amp)):
60 | pred1, pred2 = model(view1, view2)
61 |
62 | # loss is supposed to be symmetric
63 | with torch.cuda.amp.autocast(enabled=False):
64 | loss = criterion(view1, view2, pred1, pred2) if criterion is not None else None
65 |
66 | result = dict(view1=view1, view2=view2, pred1=pred1, pred2=pred2, loss=loss)
67 | return result[ret] if ret else result
68 |
69 |
70 | @torch.no_grad()
71 | def inference(pairs, model, device, batch_size=8):
72 | print(f'>> Inference with model on {len(pairs)} image pairs')
73 | result = []
74 |
75 | # first, check if all images have the same size
76 | multiple_shapes = not (check_if_same_size(pairs))
77 | if multiple_shapes: # force bs=1
78 | batch_size = 1
79 |
80 | for i in tqdm.trange(0, len(pairs), batch_size):
81 | res = loss_of_one_batch(collate_with_cat(pairs[i:i+batch_size]), model, None, device)
82 | result.append(to_cpu(res))
83 |
84 | result = collate_with_cat(result, lists=multiple_shapes)
85 |
86 | torch.cuda.empty_cache()
87 | return result
88 |
89 |
90 | def check_if_same_size(pairs):
91 | shapes1 = [img1['img'].shape[-2:] for img1, img2 in pairs]
92 | shapes2 = [img2['img'].shape[-2:] for img1, img2 in pairs]
93 | return all(shapes1[0] == s for s in shapes1) and all(shapes2[0] == s for s in shapes2)
94 |
95 |
96 | def get_pred_pts3d(gt, pred, use_pose=False):
97 | if 'depth' in pred and 'pseudo_focal' in pred:
98 | try:
99 | pp = gt['camera_intrinsics'][..., :2, 2]
100 | except KeyError:
101 | pp = None
102 | pts3d = depthmap_to_pts3d(**pred, pp=pp)
103 |
104 | elif 'pts3d' in pred:
105 | # pts3d from my camera
106 | pts3d = pred['pts3d']
107 |
108 | elif 'pts3d_in_other_view' in pred:
109 | # pts3d from the other camera, already transformed
110 | assert use_pose is True
111 | return pred['pts3d_in_other_view'] # return!
112 |
113 | if use_pose:
114 | camera_pose = pred.get('camera_pose')
115 | assert camera_pose is not None
116 | pts3d = geotrf(camera_pose, pts3d)
117 |
118 | return pts3d
119 |
120 |
121 | def find_opt_scaling(gt_pts1, gt_pts2, pr_pts1, pr_pts2=None, fit_mode='weiszfeld_stop_grad', valid1=None, valid2=None):
122 | assert gt_pts1.ndim == pr_pts1.ndim == 4
123 | assert gt_pts1.shape == pr_pts1.shape
124 | if gt_pts2 is not None:
125 | assert gt_pts2.ndim == pr_pts2.ndim == 4
126 | assert gt_pts2.shape == pr_pts2.shape
127 |
128 | # concat the pointcloud
129 | nan_gt_pts1 = invalid_to_nans(gt_pts1, valid1).flatten(1, 2)
130 | nan_gt_pts2 = invalid_to_nans(gt_pts2, valid2).flatten(1, 2) if gt_pts2 is not None else None
131 |
132 | pr_pts1 = invalid_to_nans(pr_pts1, valid1).flatten(1, 2)
133 | pr_pts2 = invalid_to_nans(pr_pts2, valid2).flatten(1, 2) if pr_pts2 is not None else None
134 |
135 | all_gt = torch.cat((nan_gt_pts1, nan_gt_pts2), dim=1) if gt_pts2 is not None else nan_gt_pts1
136 | all_pr = torch.cat((pr_pts1, pr_pts2), dim=1) if pr_pts2 is not None else pr_pts1
137 |
138 | dot_gt_pr = (all_pr * all_gt).sum(dim=-1)
139 | dot_gt_gt = all_gt.square().sum(dim=-1)
140 |
141 | if fit_mode.startswith('avg'):
142 | # scaling = (all_pr / all_gt).view(B, -1).mean(dim=1)
143 | scaling = dot_gt_pr.nanmean(dim=1) / dot_gt_gt.nanmean(dim=1)
144 | elif fit_mode.startswith('median'):
145 | scaling = (dot_gt_pr / dot_gt_gt).nanmedian(dim=1).values
146 | elif fit_mode.startswith('weiszfeld'):
147 | # init scaling with l2 closed form
148 | scaling = dot_gt_pr.nanmean(dim=1) / dot_gt_gt.nanmean(dim=1)
149 | # iterative re-weighted least-squares
150 | for iter in range(10):
151 | # re-weighting by inverse of distance
152 | dis = (all_pr - scaling.view(-1, 1, 1) * all_gt).norm(dim=-1)
153 | # print(dis.nanmean(-1))
154 | w = dis.clip_(min=1e-8).reciprocal()
155 | # update the scaling with the new weights
156 | scaling = (w * dot_gt_pr).nanmean(dim=1) / (w * dot_gt_gt).nanmean(dim=1)
157 | else:
158 | raise ValueError(f'bad {fit_mode=}')
159 |
160 | if fit_mode.endswith('stop_grad'):
161 | scaling = scaling.detach()
162 |
163 | scaling = scaling.clip(min=1e-3)
164 | # assert scaling.isfinite().all(), bb()
165 | return scaling
166 |
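
A minimal end-to-end sketch in the spirit of the upstream DUSt3R demo; the checkpoint path and image filenames are placeholders, and load_images is assumed to behave as in upstream dust3r.utils.image (resize and normalize into the expected view dicts).

import torch
from dust3r.utils.image import load_images
from dust3r.image_pairs import make_pairs
from dust3r.inference import inference, load_model

device = 'cuda' if torch.cuda.is_available() else 'cpu'
model = load_model('checkpoints/dust3r_checkpoint.pth', device)     # placeholder path
images = load_images(['view_a.png', 'view_b.png'], size=512)        # placeholder filenames
pairs = make_pairs(images, scene_graph='complete', symmetrize=True)
output = inference(pairs, model, device, batch_size=1)

pred1 = output['pred1']   # per-pair pts3d / conf, expressed in view1's frame
pred2 = output['pred2']   # per-pair pts3d_in_other_view / conf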
--------------------------------------------------------------------------------
/third_party/dust3r/dust3r/optim_factory.py:
--------------------------------------------------------------------------------
1 | # Copyright (C) 2024-present Naver Corporation. All rights reserved.
2 | # Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
3 | #
4 | # --------------------------------------------------------
5 | # optimization functions
6 | # --------------------------------------------------------
7 |
8 |
9 | def adjust_learning_rate_by_lr(optimizer, lr):
10 | for param_group in optimizer.param_groups:
11 | if "lr_scale" in param_group:
12 | param_group["lr"] = lr * param_group["lr_scale"]
13 | else:
14 | param_group["lr"] = lr
15 |
--------------------------------------------------------------------------------
/third_party/dust3r/dust3r/patch_embed.py:
--------------------------------------------------------------------------------
1 | # Copyright (C) 2024-present Naver Corporation. All rights reserved.
2 | # Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
3 | #
4 | # --------------------------------------------------------
5 | # PatchEmbed implementation for DUST3R,
6 | # in particular ManyAR_PatchEmbed, which handles images with a non-square aspect ratio
7 | # --------------------------------------------------------
8 | import torch
9 | import dust3r.utils.path_to_croco # noqa: F401
10 | from models.blocks import PatchEmbed # noqa
11 |
12 |
13 | def get_patch_embed(patch_embed_cls, img_size, patch_size, enc_embed_dim):
14 | assert patch_embed_cls in ['PatchEmbedDust3R', 'ManyAR_PatchEmbed']
15 | patch_embed = eval(patch_embed_cls)(img_size, patch_size, 3, enc_embed_dim)
16 | return patch_embed
17 |
18 |
19 | class PatchEmbedDust3R(PatchEmbed):
20 | def forward(self, x, **kw):
21 | B, C, H, W = x.shape
22 | assert H % self.patch_size[0] == 0, f"Input image height ({H}) is not a multiple of patch size ({self.patch_size[0]})."
23 | assert W % self.patch_size[1] == 0, f"Input image width ({W}) is not a multiple of patch size ({self.patch_size[1]})."
24 | x = self.proj(x)
25 | pos = self.position_getter(B, x.size(2), x.size(3), x.device)
26 | if self.flatten:
27 | x = x.flatten(2).transpose(1, 2) # BCHW -> BNC
28 | x = self.norm(x)
29 | return x, pos
30 |
31 |
32 | class ManyAR_PatchEmbed (PatchEmbed):
33 | """ Handle images with non-square aspect ratio.
34 | All images in the same batch have the same aspect ratio.
35 | true_shape = [(height, width) ...] indicates the actual shape of each image.
36 | """
37 |
38 | def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768, norm_layer=None, flatten=True):
39 | self.embed_dim = embed_dim
40 | super().__init__(img_size, patch_size, in_chans, embed_dim, norm_layer, flatten)
41 |
42 | def forward(self, img, true_shape):
43 | B, C, H, W = img.shape
44 | assert W >= H, f'img should be in landscape mode, but got {W=} {H=}'
45 | assert H % self.patch_size[0] == 0, f"Input image height ({H}) is not a multiple of patch size ({self.patch_size[0]})."
46 | assert W % self.patch_size[1] == 0, f"Input image width ({W}) is not a multiple of patch size ({self.patch_size[1]})."
47 | assert true_shape.shape == (B, 2), f"true_shape has the wrong shape={true_shape.shape}"
48 |
49 | # size expressed in tokens
50 | W //= self.patch_size[0]
51 | H //= self.patch_size[1]
52 | n_tokens = H * W
53 |
54 | height, width = true_shape.T
55 | is_landscape = (width >= height)
56 | is_portrait = ~is_landscape
57 |
58 | # allocate result
59 | x = img.new_zeros((B, n_tokens, self.embed_dim))
60 | pos = img.new_zeros((B, n_tokens, 2), dtype=torch.int64)
61 |
62 | # linear projection, transposed if necessary
63 | x[is_landscape] = self.proj(img[is_landscape]).permute(0, 2, 3, 1).flatten(1, 2).float()
64 | x[is_portrait] = self.proj(img[is_portrait].swapaxes(-1, -2)).permute(0, 2, 3, 1).flatten(1, 2).float()
65 |
66 | pos[is_landscape] = self.position_getter(1, H, W, pos.device)
67 | pos[is_portrait] = self.position_getter(1, W, H, pos.device)
68 |
69 | x = self.norm(x)
70 | return x, pos
71 |
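
A small sketch of the square-image variant; the sizes are arbitrary, and the croco submodule must be initialized since the PatchEmbed base class is imported from it.

import torch
from dust3r.patch_embed import get_patch_embed

patch_embed = get_patch_embed('PatchEmbedDust3R', img_size=224, patch_size=16, enc_embed_dim=1024)
x = torch.randn(2, 3, 224, 224)
tokens, pos = patch_embed(x)
print(tokens.shape, pos.shape)   # (2, 196, 1024) and (2, 196, 2)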
--------------------------------------------------------------------------------
/third_party/dust3r/dust3r/post_process.py:
--------------------------------------------------------------------------------
1 | # Copyright (C) 2024-present Naver Corporation. All rights reserved.
2 | # Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
3 | #
4 | # --------------------------------------------------------
5 | # utilities for interpreting the DUST3R output
6 | # --------------------------------------------------------
7 | import numpy as np
8 | import torch
9 | from dust3r.utils.geometry import xy_grid
10 |
11 |
12 | def estimate_focal_knowing_depth(pts3d, pp, focal_mode='median', min_focal=0.5, max_focal=3.5):
13 | """ Reprojection method, for when the absolute depth is known:
14 | 1) estimate the camera focal using a robust estimator
15 | 2) reproject points onto true rays, minimizing a certain error
16 | """
17 | B, H, W, THREE = pts3d.shape
18 | assert THREE == 3
19 |
20 | # centered pixel grid
21 | pixels = xy_grid(W, H, device=pts3d.device).view(1, -1, 2) - pp.view(-1, 1, 2) # B,HW,2
22 | pts3d = pts3d.flatten(1, 2) # (B, HW, 3)
23 |
24 | if focal_mode == 'median':
25 | with torch.no_grad():
26 | # direct estimation of focal
27 | u, v = pixels.unbind(dim=-1)
28 | x, y, z = pts3d.unbind(dim=-1)
29 | fx_votes = (u * z) / x
30 | fy_votes = (v * z) / y
31 |
32 | # assume square pixels, hence same focal for X and Y
33 | f_votes = torch.cat((fx_votes.view(B, -1), fy_votes.view(B, -1)), dim=-1)
34 | focal = torch.nanmedian(f_votes, dim=-1).values
35 |
36 | elif focal_mode == 'weiszfeld':
37 | # init focal with l2 closed form
38 | # we try to find focal = argmin Sum | pixel - focal * (x,y)/z|
39 | xy_over_z = (pts3d[..., :2] / pts3d[..., 2:3]).nan_to_num(posinf=0, neginf=0) # homogeneous (x,y,1)
40 |
41 | dot_xy_px = (xy_over_z * pixels).sum(dim=-1)
42 | dot_xy_xy = xy_over_z.square().sum(dim=-1)
43 |
44 | focal = dot_xy_px.mean(dim=1) / dot_xy_xy.mean(dim=1)
45 |
46 | # iterative re-weighted least-squares
47 | for iter in range(10):
48 | # re-weighting by inverse of distance
49 | dis = (pixels - focal.view(-1, 1, 1) * xy_over_z).norm(dim=-1)
50 | # print(dis.nanmean(-1))
51 | w = dis.clip(min=1e-8).reciprocal()
52 |             # update the focal estimate with the new weights
53 | focal = (w * dot_xy_px).mean(dim=1) / (w * dot_xy_xy).mean(dim=1)
54 | else:
55 | raise ValueError(f'bad {focal_mode=}')
56 |
57 | focal_base = max(H, W) / (2 * np.tan(np.deg2rad(60) / 2)) # size / 1.1547005383792515
58 | focal = focal.clip(min=min_focal*focal_base, max=max_focal*focal_base)
59 | # print(focal)
60 | return focal
61 |
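
A small synthetic sanity check (not from the repository): back-project a fronto-parallel plane with a known focal length and verify that the estimator recovers roughly that value; the image size, depth and focal below are made up.

import torch
from dust3r.post_process import estimate_focal_knowing_depth

H, W, f = 96, 128, 100.0
pp = torch.tensor([[W / 2.0, H / 2.0]])

# centered pixel grid and a flat scene at depth 2, back-projected with focal f
ys, xs = torch.meshgrid(torch.arange(H), torch.arange(W), indexing='ij')
uv = torch.stack((xs, ys), dim=-1).float().view(1, H, W, 2) - pp.view(1, 1, 1, 2)
z = torch.full((1, H, W, 1), 2.0)
pts3d = torch.cat((uv * z / f, z), dim=-1)

focal = estimate_focal_knowing_depth(pts3d, pp, focal_mode='weiszfeld')
print(focal)   # ~100, up to the [0.5, 3.5] * focal_base clipping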
--------------------------------------------------------------------------------
/third_party/dust3r/dust3r/utils/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (C) 2024-present Naver Corporation. All rights reserved.
2 | # Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
3 |
--------------------------------------------------------------------------------
/third_party/dust3r/dust3r/utils/device.py:
--------------------------------------------------------------------------------
1 | # Copyright (C) 2024-present Naver Corporation. All rights reserved.
2 | # Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
3 | #
4 | # --------------------------------------------------------
5 | # utility functions for DUSt3R
6 | # --------------------------------------------------------
7 | import numpy as np
8 | import torch
9 |
10 |
11 | def todevice(batch, device, callback=None, non_blocking=False):
12 |     ''' Transfer some variables to another device (e.g. GPU, CPU:torch, CPU:numpy).
13 |
14 | batch: list, tuple, dict of tensors or other things
15 | device: pytorch device or 'numpy'
16 |     callback: function that will be called on every sub-element.
17 | '''
18 | if callback:
19 | batch = callback(batch)
20 |
21 | if isinstance(batch, dict):
22 | return {k: todevice(v, device) for k, v in batch.items()}
23 |
24 | if isinstance(batch, (tuple, list)):
25 | return type(batch)(todevice(x, device) for x in batch)
26 |
27 | x = batch
28 | if device == 'numpy':
29 | if isinstance(x, torch.Tensor):
30 | x = x.detach().cpu().numpy()
31 | elif x is not None:
32 | if isinstance(x, np.ndarray):
33 | x = torch.from_numpy(x)
34 | if torch.is_tensor(x):
35 | x = x.to(device, non_blocking=non_blocking)
36 | return x
37 |
38 |
39 | to_device = todevice # alias
40 |
41 |
42 | def to_numpy(x): return todevice(x, 'numpy')
43 | def to_cpu(x): return todevice(x, 'cpu')
44 | def to_cuda(x): return todevice(x, 'cuda')
45 |
46 |
47 | def collate_with_cat(whatever, lists=False):
48 | if isinstance(whatever, dict):
49 | return {k: collate_with_cat(vals, lists=lists) for k, vals in whatever.items()}
50 |
51 | elif isinstance(whatever, (tuple, list)):
52 | if len(whatever) == 0:
53 | return whatever
54 | elem = whatever[0]
55 | T = type(whatever)
56 |
57 | if elem is None:
58 | return None
59 | if isinstance(elem, (bool, float, int, str)):
60 | return whatever
61 | if isinstance(elem, tuple):
62 | return T(collate_with_cat(x, lists=lists) for x in zip(*whatever))
63 | if isinstance(elem, dict):
64 | return {k: collate_with_cat([e[k] for e in whatever], lists=lists) for k in elem}
65 |
66 | if isinstance(elem, torch.Tensor):
67 | return listify(whatever) if lists else torch.cat(whatever)
68 | if isinstance(elem, np.ndarray):
69 | return listify(whatever) if lists else torch.cat([torch.from_numpy(x) for x in whatever])
70 |
71 | # otherwise, we just chain lists
72 | return sum(whatever, T())
73 |
74 |
75 | def listify(elems):
76 | return [x for e in elems for x in e]
77 |
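
A quick sketch of the recursive helpers on a nested batch of dummy tensors.

import numpy as np
import torch
from dust3r.utils.device import todevice, to_numpy, collate_with_cat

batch = {'img': torch.randn(2, 3, 8, 8), 'meta': [np.arange(3), 'label']}
batch_np = to_numpy(batch)               # tensors -> numpy arrays, other leaves untouched
batch_pt = todevice(batch_np, 'cpu')     # numpy arrays -> cpu tensors

views = [{'img': torch.randn(1, 3, 8, 8)}, {'img': torch.randn(1, 3, 8, 8)}]
merged = collate_with_cat(views)         # {'img': tensor of shape (2, 3, 8, 8)}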
--------------------------------------------------------------------------------
/third_party/dust3r/dust3r/utils/misc.py:
--------------------------------------------------------------------------------
1 | # Copyright (C) 2024-present Naver Corporation. All rights reserved.
2 | # Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
3 | #
4 | # --------------------------------------------------------
5 | # utility functions for DUSt3R
6 | # --------------------------------------------------------
7 | import torch
8 |
9 |
10 | def fill_default_args(kwargs, func):
11 | import inspect # a bit hacky but it works reliably
12 | signature = inspect.signature(func)
13 |
14 | for k, v in signature.parameters.items():
15 | if v.default is inspect.Parameter.empty:
16 | continue
17 | kwargs.setdefault(k, v.default)
18 |
19 | return kwargs
20 |
21 |
22 | def freeze_all_params(modules):
23 | for module in modules:
24 | try:
25 | for n, param in module.named_parameters():
26 | param.requires_grad = False
27 | except AttributeError:
28 | # module is directly a parameter
29 | module.requires_grad = False
30 |
31 |
32 | def is_symmetrized(gt1, gt2):
33 | x = gt1['instance']
34 | y = gt2['instance']
35 | if len(x) == len(y) and len(x) == 1:
36 | return False # special case of batchsize 1
37 | ok = True
38 | for i in range(0, len(x), 2):
39 | ok = ok and (x[i] == y[i+1]) and (x[i+1] == y[i])
40 | return ok
41 |
42 |
43 | def flip(tensor):
44 | """ flip so that tensor[0::2] <=> tensor[1::2] """
45 | return torch.stack((tensor[1::2], tensor[0::2]), dim=1).flatten(0, 1)
46 |
47 |
48 | def interleave(tensor1, tensor2):
49 | res1 = torch.stack((tensor1, tensor2), dim=1).flatten(0, 1)
50 | res2 = torch.stack((tensor2, tensor1), dim=1).flatten(0, 1)
51 | return res1, res2
52 |
53 |
54 | def transpose_to_landscape(head, activate=True):
55 | """ Predict in the correct aspect-ratio,
56 |         then transpose the result back to landscape
57 | and stack everything back together.
58 | """
59 | def wrapper_no(decout, true_shape):
60 | B = len(true_shape)
61 | assert true_shape[0:1].allclose(true_shape), 'true_shape must be all identical'
62 | H, W = true_shape[0].cpu().tolist()
63 | res = head(decout, (H, W))
64 | return res
65 |
66 | def wrapper_yes(decout, true_shape):
67 | B = len(true_shape)
68 | # by definition, the batch is in landscape mode so W >= H
69 | H, W = int(true_shape.min()), int(true_shape.max())
70 |
71 | height, width = true_shape.T
72 | is_landscape = (width >= height)
73 | is_portrait = ~is_landscape
74 |
75 | # true_shape = true_shape.cpu()
76 | if is_landscape.all():
77 | return head(decout, (H, W))
78 | if is_portrait.all():
79 | return transposed(head(decout, (W, H)))
80 |
81 |         # batch is a mix of both portrait & landscape
82 | def selout(ar): return [d[ar] for d in decout]
83 | l_result = head(selout(is_landscape), (H, W))
84 | p_result = transposed(head(selout(is_portrait), (W, H)))
85 |
86 | # allocate full result
87 | result = {}
88 | for k in l_result | p_result:
89 | x = l_result[k].new(B, *l_result[k].shape[1:])
90 | x[is_landscape] = l_result[k]
91 | x[is_portrait] = p_result[k]
92 | result[k] = x
93 |
94 | return result
95 |
96 | return wrapper_yes if activate else wrapper_no
97 |
98 |
99 | def transposed(dic):
100 | return {k: v.swapaxes(1, 2) for k, v in dic.items()}
101 |
102 |
103 | def invalid_to_nans(arr, valid_mask, ndim=999):
104 | if valid_mask is not None:
105 | arr = arr.clone()
106 | arr[~valid_mask] = float('nan')
107 | if arr.ndim > ndim:
108 | arr = arr.flatten(-2 - (arr.ndim - ndim), -2)
109 | return arr
110 |
111 |
112 | def invalid_to_zeros(arr, valid_mask, ndim=999):
113 | if valid_mask is not None:
114 | arr = arr.clone()
115 | arr[~valid_mask] = 0
116 | nnz = valid_mask.view(len(valid_mask), -1).sum(1)
117 | else:
118 |         nnz = arr.numel() // len(arr) if len(arr) else 0  # number of points per image
119 | if arr.ndim > ndim:
120 | arr = arr.flatten(-2 - (arr.ndim - ndim), -2)
121 | return arr, nnz
122 |
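
A small sketch of transpose_to_landscape with a toy head (toy_head and all shapes below are made up): a batch mixing landscape and portrait views is split, each subset is processed at its true resolution, and the portrait results are transposed back so everything stacks in landscape.

import torch
from dust3r.utils.misc import transpose_to_landscape

def toy_head(decout, img_shape):
    H, W = img_shape
    B = decout[0].shape[0]
    return {'pts3d': torch.zeros(B, H, W, 3)}

wrapped = transpose_to_landscape(toy_head, activate=True)
true_shape = torch.tensor([[32, 48],    # landscape view (H, W)
                           [48, 32]])   # portrait view  (H, W)
out = wrapped([torch.zeros(2, 6, 768)], true_shape)
print(out['pts3d'].shape)   # (2, 32, 48, 3), everything back in landscape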
--------------------------------------------------------------------------------
/third_party/dust3r/dust3r/utils/path_to_croco.py:
--------------------------------------------------------------------------------
1 | # Copyright (C) 2024-present Naver Corporation. All rights reserved.
2 | # Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
3 | #
4 | # --------------------------------------------------------
5 | # CroCo submodule import
6 | # --------------------------------------------------------
7 |
8 | import sys
9 | import os.path as path
10 | HERE_PATH = path.normpath(path.dirname(__file__))
11 | CROCO_REPO_PATH = path.normpath(path.join(HERE_PATH, '../../croco'))
12 | CROCO_MODELS_PATH = path.join(CROCO_REPO_PATH, 'models')
13 | # check the presence of the models directory in the repo to be sure it's cloned
14 | if path.isdir(CROCO_MODELS_PATH):
15 | # workaround for sibling import
16 | sys.path.insert(0, CROCO_REPO_PATH)
17 | else:
18 | raise ImportError(f"croco is not initialized, could not find: {CROCO_MODELS_PATH}.\n "
19 | "Did you forget to run 'git submodule update --init --recursive' ?")
20 |
--------------------------------------------------------------------------------
/utils/adain.py:
--------------------------------------------------------------------------------
1 | def masked_adain(content_feat, style_feat, content_mask, style_mask):
2 | assert (content_feat.size()[:2] == style_feat.size()[:2])
3 | size = content_feat.size()
4 | style_mean, style_std = calc_mean_std(style_feat, mask=style_mask)
5 | content_mean, content_std = calc_mean_std(content_feat, mask=content_mask)
6 | normalized_feat = (content_feat - content_mean.expand(size)) / content_std.expand(size)
7 | style_normalized_feat = normalized_feat * style_std.expand(size) + style_mean.expand(size)
8 | return content_feat * (1 - content_mask) + style_normalized_feat * content_mask
9 |
10 |
11 | def adain(content_feat, style_feat):
12 | assert (content_feat.size()[:2] == style_feat.size()[:2])
13 | size = content_feat.size()
14 | style_mean, style_std = calc_mean_std(style_feat)
15 | content_mean, content_std = calc_mean_std(content_feat)
16 | normalized_feat = (content_feat - content_mean.expand(size)) / content_std.expand(size)
17 | return normalized_feat * style_std.expand(size) + style_mean.expand(size)
18 |
19 |
20 | def calc_mean_std(feat, eps=1e-5, mask=None):
21 | # eps is a small value added to the variance to avoid divide-by-zero.
22 | size = feat.size()
23 | if len(size) == 2:
24 | return calc_mean_std_2d(feat, eps, mask)
25 |
26 | assert (len(size) == 3)
27 | C = size[0]
28 | if mask is not None:
29 | feat_var = feat.view(C, -1)[:, mask.view(-1) == 1].var(dim=1) + eps
30 | feat_std = feat_var.sqrt().view(C, 1, 1)
31 | feat_mean = feat.view(C, -1)[:, mask.view(-1) == 1].mean(dim=1).view(C, 1, 1)
32 | else:
33 | feat_var = feat.view(C, -1).var(dim=1) + eps
34 | feat_std = feat_var.sqrt().view(C, 1, 1)
35 | feat_mean = feat.view(C, -1).mean(dim=1).view(C, 1, 1)
36 |
37 | return feat_mean, feat_std
38 |
39 |
40 | def calc_mean_std_2d(feat, eps=1e-5, mask=None):
41 | # eps is a small value added to the variance to avoid divide-by-zero.
42 | size = feat.size()
43 | assert (len(size) == 2)
44 | C = size[0]
45 | if mask is not None:
46 | feat_var = feat.view(C, -1)[:, mask.view(-1) == 1].var(dim=1) + eps
47 | feat_std = feat_var.sqrt().view(C, 1)
48 | feat_mean = feat.view(C, -1)[:, mask.view(-1) == 1].mean(dim=1).view(C, 1)
49 | else:
50 | feat_var = feat.view(C, -1).var(dim=1) + eps
51 | feat_std = feat_var.sqrt().view(C, 1)
52 | feat_mean = feat.view(C, -1).mean(dim=1).view(C, 1)
53 |
54 | return feat_mean, feat_std
55 |
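
A quick sketch on random channels-first (C, H, W) features; the masks are binary maps selecting which statistics are used and where the transfer is applied. The repository root is assumed to be on the Python path.

import torch
from utils.adain import adain, masked_adain

content = torch.randn(64, 32, 32)
style = torch.randn(64, 32, 32) * 2.0 + 1.0

stylized = adain(content, style)        # match per-channel mean/std globally

content_mask = torch.zeros(32, 32)
content_mask[8:24, 8:24] = 1            # stylize only this square
style_mask = torch.ones(32, 32)         # use statistics of the whole style map
partial = masked_adain(content, style, content_mask, style_mask)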
--------------------------------------------------------------------------------
/utils/logging.py:
--------------------------------------------------------------------------------
1 | import logging
2 |
3 | # configure once, at import time
4 | logger = logging.getLogger("ReStyle3D")
5 | logger.setLevel(logging.INFO)
6 |
7 | handler = logging.StreamHandler()
8 | handler.setFormatter(logging.Formatter(
9 | "%(asctime)s - %(name)s - %(levelname)s - %(message)s"
10 | ))
11 | logger.addHandler(handler)
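
Typical usage (the messages below are arbitrary examples): import the shared logger instead of configuring logging in each module.

from utils.logging import logger

logger.info("loaded %d views", 4)
logger.warning("depth map missing, falling back to estimation")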
--------------------------------------------------------------------------------
/viewformer/__init__.py:
--------------------------------------------------------------------------------
1 | from .viewtransfer_pipeline import ViewTransferSDXLPipeline
2 | from .UNet2DConditionalModel import UNet2DConditionModel
3 |
4 |
--------------------------------------------------------------------------------
/viewformer/image_utils.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from PIL import Image
3 | import cv2
4 | import torch
5 | from torchvision import transforms
6 |
7 | def load_and_resize_image(path, size=(1024, 1024)):
8 | img = Image.open(path).convert("RGB").resize(size)
9 | return transforms.ToTensor()(img).unsqueeze(0)
10 |
11 |
12 | def match_histograms_masked_full(source_img, reference_img, mask):
13 | """
14 | Match histograms based on masked region but apply to whole image
15 |
16 | Parameters:
17 | source_img: numpy array (H x W x 3) - Source image to be modified
18 | reference_img: numpy array (H x W x 3) - Reference image to match
19 | mask: numpy array (H x W x 3) - RGB mask
20 | """
21 | # Convert to float32
22 | source_float = source_img.astype(np.float32) / 255.0
23 | reference_float = reference_img.astype(np.float32) / 255.0
24 |
25 | # Initialize output image
26 | matched = source_float.copy()
27 |
28 | # Use first channel of RGB mask and ensure it's binary
29 | mask_channel = mask[:,:,0] if len(mask.shape) == 3 else mask
30 | if mask_channel.dtype != np.uint8:
31 | mask_channel = mask_channel.astype(np.uint8) * 255
32 | _, mask_binary = cv2.threshold(mask_channel, 127, 255, cv2.THRESH_BINARY)
33 | mask_binary = cv2.bitwise_not(mask_binary) # Invert the mask
34 |
35 | # Create boolean mask
36 | bool_mask = mask_binary > 0
37 |
38 | for i in range(3):
39 | # Get masked pixels for computing transformation
40 | source_channel = source_float[:,:,i]
41 | reference_channel = reference_float[:,:,i]
42 |
43 | # Apply boolean mask correctly
44 | source_masked = source_channel[bool_mask]
45 | reference_masked = reference_channel[bool_mask]
46 |
47 | if len(source_masked) > 0 and len(reference_masked) > 0:
48 | # Use more bins for better precision
49 | nbins = 256
50 | source_hist, bin_edges = np.histogram(source_masked, nbins, [0, 1])
51 | reference_hist, _ = np.histogram(reference_masked, nbins, [0, 1])
52 |
53 | # Add small epsilon to avoid division by zero
54 | source_hist = source_hist + 1e-8
55 | reference_hist = reference_hist + 1e-8
56 |
57 | # Calculate normalized cumulative histograms
58 | source_cdf = source_hist.cumsum() / source_hist.sum()
59 | reference_cdf = reference_hist.cumsum() / reference_hist.sum()
60 |
61 | # Create interpolation function
62 | bins = np.linspace(0, 1, nbins)
63 | lookup_table = np.interp(source_cdf, reference_cdf, bins)
64 |
65 | # Apply transformation to entire channel
66 | channel_values = source_float[:,:,i] * (nbins-1)
67 | channel_indices = channel_values.astype(int)
68 | matched[:,:,i] = lookup_table[channel_indices]
69 |
70 | # Ensure output is in valid range
71 | matched = np.clip(matched, 0, 1)
72 | matched = (matched * 255).astype(np.uint8)
73 |
74 | return Image.fromarray(matched)
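
A minimal sketch with synthetic inputs; as noted in the docstring, the mask is inverted inside the function, so the matching statistics come from pixels where the mask is 0 while the resulting per-channel mapping is applied to the whole image.

import numpy as np
from viewformer.image_utils import match_histograms_masked_full

source = np.random.randint(0, 256, (64, 64, 3), dtype=np.uint8)
reference = np.random.randint(0, 256, (64, 64, 3), dtype=np.uint8)
mask = np.zeros((64, 64, 3), dtype=np.uint8)   # all zeros -> statistics from the full image

matched = match_histograms_masked_full(source, reference, mask)   # returns a PIL.Image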
--------------------------------------------------------------------------------