├── .gitignore ├── CONTRIBUTING.md ├── LICENSE.txt ├── README.md ├── assets └── teaser.jpg ├── requirements.txt ├── restyle_image.py ├── restyle_scene.py ├── scene_transfer ├── __init__.py ├── attention_utils.py ├── config.py ├── ddpm_inversion.py ├── depth_estimator.py ├── image_utils.py ├── latent_utils.py ├── model_utils.py ├── sd15_transfer.py ├── sdxl_refiner.py └── semantic_matching.py ├── scene_transfer_model.py ├── scripts ├── download_data.py ├── download_data.sh └── download_weights.sh ├── third_party ├── depth_anything_v2 │ ├── dinov2.py │ ├── dinov2_layers │ │ ├── __init__.py │ │ ├── attention.py │ │ ├── block.py │ │ ├── drop_path.py │ │ ├── layer_scale.py │ │ ├── mlp.py │ │ ├── patch_embed.py │ │ └── swiglu_ffn.py │ ├── dpt.py │ └── util │ │ ├── blocks.py │ │ └── transform.py └── dust3r │ ├── LICENSE │ ├── croco │ ├── LICENSE │ ├── NOTICE │ ├── README.MD │ ├── assets │ │ ├── Chateau1.png │ │ ├── Chateau2.png │ │ └── arch.jpg │ ├── croco-stereo-flow-demo.ipynb │ ├── datasets │ │ ├── __init__.py │ │ ├── crops │ │ │ ├── README.MD │ │ │ └── extract_crops_from_images.py │ │ ├── habitat_sim │ │ │ ├── README.MD │ │ │ ├── __init__.py │ │ │ ├── generate_from_metadata.py │ │ │ ├── generate_from_metadata_files.py │ │ │ ├── generate_multiview_images.py │ │ │ ├── multiview_habitat_sim_generator.py │ │ │ ├── pack_metadata_files.py │ │ │ └── paths.py │ │ ├── pairs_dataset.py │ │ └── transforms.py │ ├── demo.py │ ├── interactive_demo.ipynb │ ├── models │ │ ├── blocks.py │ │ ├── criterion.py │ │ ├── croco.py │ │ ├── croco_downstream.py │ │ ├── curope │ │ │ ├── __init__.py │ │ │ ├── build │ │ │ │ └── temp.linux-x86_64-cpython-38 │ │ │ │ │ └── build.ninja │ │ │ ├── curope.cpp │ │ │ ├── curope2d.py │ │ │ ├── kernels.cu │ │ │ └── setup.py │ │ ├── dpt_block.py │ │ ├── head_downstream.py │ │ ├── masking.py │ │ └── pos_embed.py │ ├── pretrain.py │ ├── stereoflow │ │ ├── README.MD │ │ ├── augmentor.py │ │ ├── criterion.py │ │ ├── datasets_flow.py │ │ ├── datasets_stereo.py │ │ ├── download_model.sh │ │ ├── engine.py │ │ ├── test.py │ │ └── train.py │ └── utils │ │ └── misc.py │ ├── datasets_preprocess │ ├── path_to_root.py │ └── preprocess_co3d.py │ └── dust3r │ ├── __init__.py │ ├── cloud_opt │ ├── __init__.py │ ├── base_opt.py │ ├── commons.py │ ├── init_im_poses.py │ ├── optimizer.py │ └── pair_viewer.py │ ├── datasets │ ├── __init__.py │ ├── base │ │ ├── __init__.py │ │ ├── base_stereo_view_dataset.py │ │ ├── batched_sampler.py │ │ └── easy_dataset.py │ ├── co3d.py │ └── utils │ │ ├── __init__.py │ │ ├── cropping.py │ │ └── transforms.py │ ├── heads │ ├── __init__.py │ ├── dpt_head.py │ ├── linear_head.py │ └── postprocess.py │ ├── image_pairs.py │ ├── inference.py │ ├── losses.py │ ├── model.py │ ├── optim_factory.py │ ├── patch_embed.py │ ├── post_process.py │ ├── utils │ ├── __init__.py │ ├── device.py │ ├── geometry.py │ ├── image.py │ ├── misc.py │ └── path_to_croco.py │ └── viz.py ├── utils ├── adain.py ├── logging.py └── proj_utils.py └── viewformer ├── UNet2DConditionalModel.py ├── __init__.py ├── image_utils.py ├── sdxl.py ├── stylelifter.py └── viewtransfer_pipeline.py /.gitignore: -------------------------------------------------------------------------------- 1 | __pycache__ 2 | output 3 | checkpoints 4 | data 5 | demo -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # How to contribute 2 | 3 | We'd love to accept your patches and contributions to this 
project. 4 | 5 | ## Before you begin 6 | 7 | ### Sign our Contributor License Agreement 8 | 9 | Contributions to this project must be accompanied by a 10 | [Contributor License Agreement](https://cla.developers.google.com/about) (CLA). 11 | You (or your employer) retain the copyright to your contribution; this simply 12 | gives us permission to use and redistribute your contributions as part of the 13 | project. 14 | 15 | If you or your current employer have already signed the Google CLA (even if it 16 | was for a different project), you probably don't need to do it again. 17 | 18 | Visit to see your current agreements or to 19 | sign a new one. 20 | 21 | ### Review our community guidelines 22 | 23 | This project follows 24 | [Google's Open Source Community Guidelines](https://opensource.google/conduct/). 25 | 26 | ## Contribution process 27 | 28 | ### Code reviews 29 | 30 | All submissions, including submissions by project members, require review. We 31 | use GitHub pull requests for this purpose. Consult 32 | [GitHub Help](https://help.github.com/articles/about-pull-requests/) for more 33 | information on using pull requests. -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # 🎨 ReStyle3D: Scene-Level Appearance Transfer with Semantic Correspondences 2 | 3 | ### ACM SIGGRAPH 2025 4 | 5 | [![ProjectPage](https://img.shields.io/badge/Project_Page-ReStyle3D-blue)](https://restyle3d.github.io/) [![arXiv](https://img.shields.io/badge/arXiv-2502.10377-blue?logo=arxiv&color=%23B31B1B)](https://arxiv.org/abs/2502.10377) [![Hugging Face (LCM) Space](https://img.shields.io/badge/🤗%20Hugging%20Face%20-Space-yellow)](https://huggingface.co/gradient-spaces/ReStyle3D) [![License](https://img.shields.io/badge/License-Apache--2.0-929292)](https://www.apache.org/licenses/LICENSE-2.0) 6 | 7 | Official implementation of the paper titled "Scene-level Appearance Transfer with Semantic Correspondences". 8 | 9 | [Liyuan Zhu](https://www.zhuliyuan.net/)1, 10 | [Shengqu Cai](https://primecai.github.io/)1,\*, 11 | [Shengyu Huang](https://shengyuh.github.io/)2,\*, 12 | [Gordon Wetzstein](https://stanford.edu/~gordonwz/)1, 13 | [Naji Khosravan](https://www.najikhosravan.com/)3, 14 | [Iro Armeni](https://ir0.github.io/)1 15 | 16 | 17 | 18 | 1Stanford University, 2NVIDIA Research, 3Zillow Group | \* denotes equal contribution 19 | 20 | 21 | ```bibtex 22 | @inproceedings{zhu2025_restyle3d, 23 | author = {Liyuan Zhu and Shengqu Cai and Shengyu Huang and Gordon Wetzstein and Naji Khosravan and Iro Armeni}, 24 | title = {Scene-level Appearance Transfer with Semantic Correspondences}, 25 | booktitle = {ACM SIGGRAPH 2025 Conference Papers}, 26 | year = {2025}, 27 | } 28 | ``` 29 | 30 | We introduce ReStyle3D, a novel framework for scene-level appearance 31 | transfer from a single style image to a real-world scene represented by 32 | multiple views. This method combines explicit semantic correspondences 33 | with multi-view consistency to achieve precise and coherent stylization. 34 |

35 | ![ReStyle3D teaser](assets/teaser.jpg) 36 | 37 | 38 |

39 | 40 | 41 | ## 🛠️ Setup 42 | ### ✅ Tested Environments 43 | - Ubuntu 22.04 LTS, Python 3.10.15, CUDA 12.2, GeForce RTX 4090/3090 44 | 45 | - CentOS Linux 7, Python 3.12.1, CUDA 12.4, NVIDIA A100 46 | 47 | ### 📦 Repository 48 | ``` 49 | git clone git@github.com:GradientSpaces/ReStyle3D.git 50 | cd ReStyle3D 51 | ``` 52 | 53 | ### 💻 Installation 54 | ``` 55 | conda create -n restyle3d python=3.10 56 | conda activate restyle3d 57 | pip install -r requirements.txt 58 | ``` 59 | 60 | ### 📦 Pretrained Checkpoints 61 | Download the pretrained models by running: 62 | ``` 63 | bash scripts/download_weights.sh 64 | ``` 65 | 66 | 67 | ## 🚀 Usage 68 | 69 | We download our dataset: 70 | ``` 71 | bash scripts/download_data.sh 72 | ``` 73 | 74 | ### 🎮 Demo (Single-view) 75 | We include 3 demo images to run semantic appearance transfer: 76 | ``` 77 | python restyle_image.py 78 | ``` 79 | 80 | 81 | 82 | ### 🎨 Stylizing Multi-view Scenes 83 | To run on a single scene and style: 84 | ``` 85 | python restyle_scene.py \ 86 | --scene_path demo/scene_transfer/bedroom/ \ 87 | --scene_type bedroom \ 88 | --style_path demo/design_styles/bedroom/pexels-itsterrymag-2631746 89 | ``` 90 | 91 | ### 📂 Dataset: SceneTransfer 92 | We organize the data into two components: 93 | 94 | 1. Interior Scenes: 95 | Multi-view real-world scans with aligned images, depth, and semantic segmentations. 96 | ``` 97 | 📁 data/ 98 | └── interiors/ 99 | ├── bedroom/ 100 | │ ├── 0/ 101 | │ │ ├── images/ # multi-view RGB images 102 | │ │ ├── depth/ # depth maps 103 | │ │ └── seg_dict/ # semantic segmentation dictionaries 104 | │ └── 1/ 105 | │ └── ... 106 | ├── living_room/ 107 | └── kitchen/ 108 | ``` 109 | 2. Design Styles: 110 | Style examplars with precomputed semantic segmentation. 111 | ``` 112 | 📁 data/ 113 | └── design_styles/ 114 | ├── bedroom/ 115 | │ └── pexels-itsterrymag-2631746/ 116 | │ ├── image.jpg # style reference image 117 | │ ├── seg_dict.pth # semantic segmentation dictionary 118 | │ └── seg.png # segmentation visualization 119 | ├── living_room/ 120 | └── kitchen/ 121 | ``` 122 | 123 | 124 | 125 | 126 | 127 | ## 🚧 TODO 128 | - [ ] Release full dataset 129 | - [ ] Release evaluation code 130 | - [ ] Customize dataset 131 | 132 | 133 | ## 🙏 Acknowledgement 134 | Our codebase is built on top of the following works: 135 | - [Cross-image-attention](https://github.com/garibida/cross-image-attention) 136 | - [ODISE](https://github.com/NVlabs/ODISE) 137 | - [ViewCrafter](https://github.com/Drexubery/ViewCrafter) 138 | - [GenWarp](https://github.com/sony/genwarp) 139 | - [DUSt3R](https://github.com/naver/dust3r) 140 | 141 | We appreciate the open-source efforts from the authors. 142 | 143 | ## 📫 Contact 144 | If you encounter any issues or have questions, feel free to reach out: [Liyuan Zhu](liyzhu@stanford.edu). 
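## 🐍 Programmatic Usage
The CLI above wraps a single function, `restyle_scene(...)`, defined in `restyle_scene.py`, so the same pipeline can also be driven from Python. Below is a minimal sketch, assuming the weights and demo data have been downloaded as described above; the paths mirror the CLI example and may need to be adapted to your local layout:

```python
from restyle_scene import restyle_scene

# Stylize one demo scene; output_root and downsample match the CLI defaults.
restyle_scene(
    scene_path="demo/scene_transfer/bedroom/",
    style_path="demo/design_styles/bedroom/pexels-itsterrymag-2631746",
    scene_type="bedroom",
    output_root="output/demo_restyle3d",
    downsample=4,
)
```

Intermediate single-view results are written under `output/2d_results/`, and the final multi-view renders under `<output_root>/<scene_type>_<scene_id>/<style_id>/`.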
145 | 146 | 147 | 148 | -------------------------------------------------------------------------------- /assets/teaser.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GradientSpaces/ReStyle3D/5a6440eeee53783530dbbda7c26898b8b4c351da/assets/teaser.jpg -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | # Use PyTorch CUDA wheels in addition to PyPI 2 | --extra-index-url https://download.pytorch.org/whl/cu121 3 | 4 | torch==2.5.0 5 | torchvision==0.20.0 6 | torchaudio==2.5.0 7 | 8 | # Everything below comes from the default PyPI index 9 | diffusers==0.31.0 10 | xformers==0.0.28.post2 11 | transformers==4.43.2 12 | accelerate==1.0.1 13 | einops 14 | roma 15 | open3d 16 | scikit-learn 17 | pyrallis 18 | jaxtyping 19 | opencv-python 20 | matplotlib 21 | huggingface_hub[cli] 22 | git+https://github.com/pesser/splatting 23 | -------------------------------------------------------------------------------- /restyle_scene.py: -------------------------------------------------------------------------------- 1 | import os 2 | from pathlib import Path 3 | import argparse 4 | 5 | from restyle_image import generate_single_view_stylized 6 | from viewformer.stylelifter import StyleLifter 7 | from utils.logging import logger 8 | 9 | def restyle_scene(scene_path: str, style_path: str, scene_type: str, output_root: str, downsample: int): 10 | scene_path = Path(scene_path) 11 | style_path = Path(style_path) 12 | scene_id = scene_path.parts[-1] 13 | style_id = style_path.parts[-1] 14 | 15 | # Get structure image 16 | image_files = sorted((scene_path / "images").glob("*.jpg")) 17 | if not image_files: 18 | logger.error(f"No images found in {scene_path / 'images'}") 19 | return 20 | struct_img = image_files[0] 21 | frame_name = struct_img.stem 22 | 23 | struct_seg = scene_path / "seg_dict" / f"{frame_name}.pth" 24 | 25 | # Look for style image.* (jpg or png) 26 | image_candidates = list(style_path.glob("image.*")) 27 | if not image_candidates: 28 | logger.error(f"No image.* found in {style_path}") 29 | return 30 | style_img = image_candidates[0] 31 | 32 | style_seg = style_path / "seg_dict.pth" 33 | 34 | # Check required files 35 | for p in [struct_img, struct_seg, style_img, style_seg]: 36 | if not p.exists(): 37 | logger.error(f"Missing required input: {p}") 38 | return 39 | 40 | # Step 1: generate single-view stylized image 41 | logger.info(f"Generating single-view stylization: {style_id} → {scene_path}") 42 | stylized_2d_output = Path("output/2d_results") / f"{scene_id}_style_{style_id}" 43 | generate_single_view_stylized( 44 | struct_img_path=struct_img, 45 | style_img_path=style_img, 46 | struct_seg_dict=struct_seg, 47 | style_seg_dict=style_seg, 48 | output_path=stylized_2d_output / "intermediate", 49 | scene_type=scene_type, 50 | ) 51 | 52 | # Step 2: multi-view lifting 53 | logger.info(f"Starting multi-view style lifting...") 54 | stylelifter = StyleLifter(ckpt_path="checkpoints") 55 | output_3d_path = Path(output_root) / f"{scene_type}_{scene_id}" / style_id 56 | 57 | stylelifter( 58 | src_scene=str(scene_path), 59 | stylized_path=stylized_2d_output / "stylized.png", 60 | output_path=output_3d_path, 61 | downsample=downsample 62 | ) 63 | 64 | logger.info(f"✅ Scene stylization complete for {scene_type}/{scene_id} using {style_id}.") 65 | logger.info(f"🖼️ Results saved to: {output_3d_path.resolve()}") 66 
| 67 | 68 | 69 | if __name__ == "__main__": 70 | parser = argparse.ArgumentParser(description="ReStyle3D: Scene Stylization Pipeline") 71 | parser.add_argument("--scene_path", type=str, required=True, help="Path to scene directory (e.g., data/interiors/bedroom/0)") 72 | parser.add_argument("--style_path", type=str, required=True, help="Path to style folder (e.g., data/design_styles_v2/bedroom/pexels-xxx)") 73 | parser.add_argument("--scene_type", type=str, required=True, help="Scene type (e.g., bedroom, kitchen, living_room)") 74 | parser.add_argument("--output_root", type=str, default="output/demo_restyle3d", help="Root path to save results") 75 | parser.add_argument("--downsample", type=int, default=4, help="Downsampling stride for multi-view processing (default: 4)") 76 | 77 | args = parser.parse_args() 78 | restyle_scene(args.scene_path, args.style_path, args.scene_type, args.output_root, args.downsample) 79 | 80 | -------------------------------------------------------------------------------- /scene_transfer/__init__.py: -------------------------------------------------------------------------------- 1 | OUT_INDEX = 0 2 | STYLE_INDEX = 1 3 | STRUCT_INDEX = 2 4 | -------------------------------------------------------------------------------- /scene_transfer/attention_utils.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | from scene_transfer import OUT_INDEX 4 | 5 | def should_mix_keys_and_values(model, hidden_states: torch.Tensor) -> bool: 6 | """ Verify whether we should perform the mixing in the current timestep. """ 7 | is_in_32_timestep_range = ( 8 | model.config.cross_attn_32_range.start <= model.step < model.config.cross_attn_32_range.end 9 | ) 10 | is_in_64_timestep_range = ( 11 | model.config.cross_attn_64_range.start <= model.step < model.config.cross_attn_64_range.end 12 | ) 13 | is_hidden_states_32_square = (hidden_states.shape[1] == 32 ** 2) 14 | is_hidden_states_64_square = (hidden_states.shape[1] == 64 ** 2) 15 | should_mix = (is_in_32_timestep_range and is_hidden_states_32_square) or \ 16 | (is_in_64_timestep_range and is_hidden_states_64_square) 17 | return should_mix 18 | 19 | 20 | def compute_scaled_dot_product_attention(Q, K, V, edit_map=False, is_cross=False, contrast_strength=1.0, masks=None): 21 | """ Compute the scale dot product attention, potentially with our contrasting operation. """ 22 | cost_volume = Q @ K.transpose(-2, -1) / math.sqrt(Q.size(-1)) 23 | # attn_weight = torch.softmax((Q @ K.transpose(-2, -1) / math.sqrt(Q.size(-1))), dim=-1) 24 | if masks is not None: 25 | mask_64, mask_32 = masks 26 | if (Q.shape[-2] == 32 ** 2): 27 | cost_volume[OUT_INDEX] = cost_volume[OUT_INDEX] * mask_32 28 | 29 | if (Q.shape[-2] == 64 ** 2): 30 | cost_volume[OUT_INDEX] = cost_volume[OUT_INDEX] * mask_64 31 | 32 | attn_weight = torch.softmax(cost_volume, dim=-1) 33 | if edit_map and not is_cross: 34 | attn_weight[OUT_INDEX] = torch.stack([ 35 | torch.clip(enhance_tensor(attn_weight[OUT_INDEX][head_idx], contrast_factor=contrast_strength), 36 | min=0.0, max=1.0) 37 | for head_idx in range(attn_weight.shape[1]) 38 | ]) 39 | return attn_weight @ V, attn_weight 40 | 41 | def enhance_tensor(tensor: torch.Tensor, contrast_factor: float = 1.67) -> torch.Tensor: 42 | """ Compute the attention map contrasting. 
""" 43 | adjusted_tensor = (tensor - tensor.mean(dim=-1)) * contrast_factor + tensor.mean(dim=-1) 44 | return adjusted_tensor -------------------------------------------------------------------------------- /scene_transfer/config.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | from pathlib import Path 3 | from typing import NamedTuple, Optional 4 | 5 | 6 | class Range(NamedTuple): 7 | start: int 8 | end: int 9 | 10 | 11 | @dataclass 12 | class RunConfig: 13 | # Appearance image path 14 | app_image_path: Path 15 | # Struct image path 16 | struct_image_path: Path 17 | # Domain name (e.g., buildings, animals) 18 | domain_name: Optional[str] = None 19 | # Output path 20 | output_path: Path = Path('./output/test') 21 | # Random seed 22 | seed: int = 42 23 | # Input prompt for inversion (will use domain name as default) 24 | prompt: Optional[str] = None 25 | # Number of timesteps 26 | num_timesteps: int = 120 27 | # Whether to use a binary mask for performing AdaIN 28 | use_masked_adain: bool = False 29 | # Timesteps to apply cross-attention on 64x64 layers 30 | cross_attn_64_range: Range = Range(start=10, end=70) 31 | # Timesteps to apply cross-attention on 32x32 layers 32 | cross_attn_32_range: Range = Range(start=10, end=70) 33 | # Timesteps to apply AdaIn 34 | adain_range: Range = Range(start=20, end=100) 35 | # Swap guidance scale 36 | swap_guidance_scale: float = 2.0 37 | # Attention contrasting strength 38 | contrast_strength: float = 1.67 39 | # Object nouns to use for self-segmentation (will use the domain name as default) 40 | object_noun: Optional[str] = None 41 | # Whether to load previously saved inverted latent codes 42 | load_latents: bool = True 43 | # Number of steps to skip in the denoising process (used value from original edit-friendly DDPM paper) 44 | skip_steps: int = 32 45 | # ControlNet guidance scale 46 | controlnet_guidance: float = 1.0 47 | # Predict depth 48 | pred_depth: bool = True 49 | # Appearance image depth path 50 | app_depth_path: Path = None 51 | # Struct image depth path 52 | struct_depth_path: Path = None 53 | 54 | def config_exp(self): 55 | self.output_path = self.output_path 56 | self.output_path.mkdir(parents=True, exist_ok=True) 57 | 58 | # Handle the domain name, prompt, and object nouns used for masking, etc. 
59 | if self.use_masked_adain and self.domain_name is None: 60 | raise ValueError("Must provide --domain_name and --prompt when using masked AdaIN") 61 | if not self.use_masked_adain and self.domain_name is None: 62 | self.domain_name = "object" 63 | if self.prompt is None: 64 | self.prompt = f"A photo of a {self.domain_name}" 65 | if self.object_noun is None: 66 | self.object_noun = self.domain_name 67 | 68 | # Define the paths to store the inverted latents to 69 | self.latents_path = Path(self.output_path) / "latents" 70 | self.latents_path.mkdir(parents=True, exist_ok=True) 71 | self.app_latent_save_path = self.latents_path / f"{self.app_image_path.stem}.pt" 72 | self.struct_latent_save_path = self.latents_path / f"{self.struct_image_path.stem}.pt" 73 | 74 | if self.pred_depth: 75 | self.app_depth_path = self.output_path / "app_depth.png" 76 | self.struct_depth_path = self.output_path / "struct_depth.png" -------------------------------------------------------------------------------- /scene_transfer/depth_estimator.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import cv2 3 | import diffusers 4 | 5 | 6 | from scene_transfer.config import RunConfig 7 | from third_party.depth_anything_v2.dpt import DepthAnythingV2 8 | 9 | 10 | DEVICE = 'cuda' if torch.cuda.is_available() else 'mps' if torch.backends.mps.is_available() else 'cpu' 11 | 12 | def get_DepthAnyThing_model(encoder='vitl'): 13 | model_configs = { 14 | 'vits': {'encoder': 'vits', 'features': 64, 'out_channels': [48, 96, 192, 384]}, 15 | 'vitb': {'encoder': 'vitb', 'features': 128, 'out_channels': [96, 192, 384, 768]}, 16 | 'vitl': {'encoder': 'vitl', 'features': 256, 'out_channels': [256, 512, 1024, 1024]}, 17 | 'vitg': {'encoder': 'vitg', 'features': 384, 'out_channels': [1536, 1536, 1536, 1536]} 18 | } 19 | model = DepthAnythingV2(**model_configs[encoder]) 20 | model.load_state_dict(torch.load(f'checkpoints/depth_anything_v2_{encoder}.pth', map_location='cpu')) 21 | model = model.to(DEVICE).eval() 22 | return model 23 | 24 | def normalize_depthmap(depthmap): 25 | min_, max_ = depthmap.min(), depthmap.max() 26 | depthmap = (depthmap - min_) / (max_ - min_) * 255 27 | return depthmap 28 | 29 | def get_depthmaps(cfg: RunConfig, model='Depth-Anything'): 30 | if model == "Depth-Anything": 31 | model = get_DepthAnyThing_model() 32 | app_image = cv2.imread(cfg.app_image_path) 33 | struct_image = cv2.imread(cfg.struct_image_path) 34 | 35 | app_depth = model.infer_image(app_image) 36 | struct_depth = model.infer_image(struct_image) 37 | 38 | elif model == "Marigold": 39 | pipe = diffusers.MarigoldDepthPipeline.from_pretrained( 40 | "prs-eth/marigold-depth-lcm-v1-0", variant="fp16", torch_dtype=torch.float16 41 | ).to(DEVICE) 42 | app_image = diffusers.utils.load_image(str(cfg.app_image_path)) 43 | struct_image = diffusers.utils.load_image(str(cfg.struct_image_path)) 44 | app_depth = pipe(app_image)[0].squeeze() 45 | struct_depth = pipe(struct_image)[0].squeeze() 46 | else: 47 | raise NotImplementedError("unknown depth estimator!") 48 | 49 | app_depth = normalize_depthmap(app_depth) 50 | struct_depth = normalize_depthmap(struct_depth) 51 | 52 | return app_depth, struct_depth 53 | 54 | def get_depthmap(image, model='Depth-Anything', normalize=True): 55 | if model == "Depth-Anything": 56 | model = get_DepthAnyThing_model() 57 | 58 | depth = model.infer_image(image) 59 | 60 | else: 61 | raise NotImplementedError("unknown depth estimator!") 62 | 63 | if normalize: 64 | depth = 
normalize_depthmap(depth) 65 | 66 | return depth -------------------------------------------------------------------------------- /scene_transfer/image_utils.py: -------------------------------------------------------------------------------- 1 | import pathlib 2 | from typing import Optional, Tuple 3 | 4 | import numpy as np 5 | from PIL import Image 6 | 7 | from scene_transfer.config import RunConfig 8 | 9 | 10 | def load_images(cfg: RunConfig, save_path: Optional[pathlib.Path] = None) -> Tuple[Image.Image, Image.Image]: 11 | image_style = load_size(cfg.app_image_path) 12 | image_struct = load_size(cfg.struct_image_path) 13 | 14 | if save_path is not None: 15 | Image.fromarray(image_style).save(save_path / f"in_style.png") 16 | Image.fromarray(image_struct).save(save_path / f"in_struct.png") 17 | return image_style, image_struct 18 | 19 | 20 | def load_size(image_path: pathlib.Path, 21 | left: int = 0, 22 | right: int = 0, 23 | top: int = 0, 24 | bottom: int = 0, 25 | size: int = 512, 26 | resize: bool = True) -> Image.Image: 27 | if isinstance(image_path, (str, pathlib.Path)): 28 | image = np.array(Image.open(str(image_path)).convert('RGB')) 29 | else: 30 | image = image_path 31 | 32 | if resize: 33 | # Resize the image 34 | resized_image = Image.fromarray(image).resize((size, size)) 35 | 36 | # Convert back to numpy array 37 | resized_array = np.array(resized_image) 38 | return resized_array 39 | 40 | h, w, _ = image.shape 41 | 42 | left = min(left, w - 1) 43 | right = min(right, w - left - 1) 44 | top = min(top, h - left - 1) 45 | bottom = min(bottom, h - top - 1) 46 | image = image[top:h - bottom, left:w - right] 47 | 48 | h, w, c = image.shape 49 | 50 | if h < w: 51 | offset = (w - h) // 2 52 | image = image[:, offset:offset + h] 53 | elif w < h: 54 | offset = (h - w) // 2 55 | image = image[offset:offset + w] 56 | 57 | image = np.array(Image.fromarray(image).resize((size, size))) 58 | return image 59 | 60 | 61 | def save_generated_masks(model, cfg: RunConfig): 62 | tensor2im(model.image_app_mask_32).save(cfg.output_path / f"mask_style_32.png") 63 | tensor2im(model.image_struct_mask_32).save(cfg.output_path / f"mask_struct_32.png") 64 | tensor2im(model.image_app_mask_64).save(cfg.output_path / f"mask_style_64.png") 65 | tensor2im(model.image_struct_mask_64).save(cfg.output_path / f"mask_struct_64.png") 66 | 67 | 68 | def tensor2im(x) -> Image.Image: 69 | return Image.fromarray(x.cpu().numpy().astype(np.uint8) * 255) -------------------------------------------------------------------------------- /scene_transfer/latent_utils.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | from typing import Tuple, Union 3 | 4 | import numpy as np 5 | import torch 6 | from PIL import Image 7 | 8 | from scene_transfer_model import SceneTransfer 9 | from scene_transfer.config import RunConfig 10 | from scene_transfer import image_utils 11 | from scene_transfer.ddpm_inversion import invert 12 | from scene_transfer.depth_estimator import get_depthmaps 13 | from utils.logging import logger 14 | 15 | def load_latents_or_invert_images(model: Union[SceneTransfer], cfg: RunConfig): 16 | if cfg.load_latents and cfg.app_latent_save_path.exists() and cfg.struct_latent_save_path.exists(): 17 | logger.info("Loading existing latents...") 18 | latents_app, latents_struct = load_latents(cfg.app_latent_save_path, cfg.struct_latent_save_path) 19 | noise_app, noise_struct = load_noise(cfg.app_latent_save_path, cfg.struct_latent_save_path) 20 | else: 21 
| logger.info("Inverting images...") 22 | app_image, struct_image = image_utils.load_images(cfg=cfg, save_path=cfg.output_path) 23 | # Load depth images 24 | if cfg.pred_depth: 25 | depth_app, depth_struct = get_depthmaps(cfg) 26 | depth_app, depth_struct = Image.fromarray(depth_app).convert("RGB"), Image.fromarray(depth_struct).convert("RGB") 27 | depth_app.save(cfg.app_depth_path) 28 | depth_struct.save(cfg.struct_depth_path) 29 | else: 30 | depth_app = Image.open(cfg.app_depth_path).convert("RGB") 31 | depth_struct = Image.open(cfg.struct_depth_path).convert("RGB") 32 | 33 | # Ensure depth images are the same size as the input images 34 | depth_app = depth_app.resize(app_image.shape[:2]) 35 | depth_struct = depth_struct.resize(struct_image.shape[:2]) 36 | 37 | # Normalize depth images to [0, 1] 38 | depth_app = np.array(depth_app).astype(np.float32) / 255.0 39 | depth_struct = np.array(depth_struct).astype(np.float32) / 255.0 40 | 41 | # Convert back to PIL Image 42 | depth_app = Image.fromarray((depth_app * 255).astype(np.uint8)) 43 | depth_struct = Image.fromarray((depth_struct * 255).astype(np.uint8)) 44 | 45 | model.enable_edit = False # Deactivate the cross-image attention layers 46 | latents_app, latents_struct, noise_app, noise_struct = invert_images(app_image=app_image, 47 | struct_image=struct_image, 48 | sd_model=model.pipe, 49 | depth_app=depth_app, 50 | depth_struct=depth_struct, 51 | cfg=cfg) 52 | model.enable_edit = True 53 | return latents_app, latents_struct, noise_app, noise_struct 54 | 55 | 56 | def load_latents(app_latent_save_path: Path, struct_latent_save_path: Path) -> Tuple[torch.Tensor, torch.Tensor]: 57 | latents_app = torch.load(app_latent_save_path, weights_only=True) 58 | latents_struct = torch.load(struct_latent_save_path, weights_only=True) 59 | if type(latents_struct) == list: 60 | latents_app = [l.to("cuda") for l in latents_app] 61 | latents_struct = [l.to("cuda") for l in latents_struct] 62 | else: 63 | latents_app = latents_app.to("cuda") 64 | latents_struct = latents_struct.to("cuda") 65 | return latents_app, latents_struct 66 | 67 | 68 | def load_noise(app_latent_save_path: Path, struct_latent_save_path: Path) -> Tuple[torch.Tensor, torch.Tensor]: 69 | latents_app = torch.load(app_latent_save_path.parent / (app_latent_save_path.stem + "_ddpm_noise.pt")) 70 | latents_struct = torch.load(struct_latent_save_path.parent / (struct_latent_save_path.stem + "_ddpm_noise.pt")) 71 | latents_app = latents_app.to("cuda") 72 | latents_struct = latents_struct.to("cuda") 73 | return latents_app, latents_struct 74 | 75 | 76 | def invert_images(sd_model: Union[SceneTransfer], app_image: Image.Image, struct_image: Image.Image, depth_app: Image.Image, depth_struct: Image.Image, cfg: RunConfig): 77 | input_app = torch.from_numpy(np.array(app_image)).float() / 127.5 - 1.0 78 | input_struct = torch.from_numpy(np.array(struct_image)).float() / 127.5 - 1.0 79 | 80 | zs_app, latents_app = invert(x0=input_app.permute(2, 0, 1).unsqueeze(0).to('cuda'), 81 | pipe=sd_model, 82 | prompt_src=cfg.prompt, 83 | num_diffusion_steps=cfg.num_timesteps, 84 | cfg_scale_src=3.5, 85 | depth=depth_app) 86 | 87 | zs_struct, latents_struct = invert(x0=input_struct.permute(2, 0, 1).unsqueeze(0).to('cuda'), 88 | pipe=sd_model, 89 | prompt_src=cfg.prompt, 90 | num_diffusion_steps=cfg.num_timesteps, 91 | cfg_scale_src=3.5, 92 | depth=depth_struct) 93 | 94 | # Save the inverted latents and noises 95 | torch.save(latents_app, cfg.latents_path / f"{cfg.app_image_path.stem}.pt") 96 | 
torch.save(latents_struct, cfg.latents_path / f"{cfg.struct_image_path.stem}.pt") 97 | torch.save(zs_app, cfg.latents_path / f"{cfg.app_image_path.stem}_ddpm_noise.pt") 98 | torch.save(zs_struct, cfg.latents_path / f"{cfg.struct_image_path.stem}_ddpm_noise.pt") 99 | return latents_app, latents_struct, zs_app, zs_struct 100 | 101 | 102 | def get_init_latents_and_noises(model: Union[SceneTransfer], cfg: RunConfig) -> Tuple[torch.Tensor, torch.Tensor]: 103 | # If we stored all the latents along the diffusion process, select the desired one based on the skip_steps 104 | if model.latents_struct.dim() == 4 and model.latents_app.dim() == 4 and model.latents_app.shape[0] > 1: 105 | model.latents_struct = model.latents_struct[cfg.skip_steps] 106 | model.latents_app = model.latents_app[cfg.skip_steps] 107 | init_latents = torch.stack([model.latents_struct, model.latents_app, model.latents_struct]) 108 | init_zs = [model.zs_struct[cfg.skip_steps:], model.zs_app[cfg.skip_steps:], model.zs_struct[cfg.skip_steps:]] 109 | return init_latents, init_zs 110 | -------------------------------------------------------------------------------- /scene_transfer/model_utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from diffusers import DDIMScheduler, ControlNetModel 3 | from diffusers import EulerAncestralDiscreteScheduler, AutoencoderKL 4 | from typing import Optional 5 | from scene_transfer.sd15_transfer import SemanticAttentionSD15 6 | from scene_transfer.sdxl_refiner import StableDiffusionXLControlNetPipeline 7 | from utils.logging import logger 8 | 9 | def get_scene_transfer_sd15() -> SemanticAttentionSD15: 10 | logger.info("Loading SD1.5...") 11 | device = torch.device(f'cuda') if torch.cuda.is_available() else torch.device('cpu') 12 | pipe = SemanticAttentionSD15.from_pretrained("runwayml/stable-diffusion-v1-5", 13 | controlnet=ControlNetModel.from_pretrained('lllyasviel/sd-controlnet-depth'), 14 | safety_checker=None).to(device) 15 | # pipe.unet = FreeUUNet2DConditionModel.from_pretrained("runwayml/stable-diffusion-v1-5", subfolder="unet").to(device) 16 | pipe.unet.enable_freeu(s1=0.9, s2=0.2, b1=1.5, b2=1.6) 17 | pipe.scheduler = DDIMScheduler.from_config("runwayml/stable-diffusion-v1-5", subfolder="scheduler") 18 | return pipe 19 | 20 | def get_refining_pipe(precision : torch.dtype = torch.float16) -> StableDiffusionXLControlNetPipeline: 21 | logger.info("Loading SDXL...") 22 | # Setup device 23 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 24 | 25 | controlnet = ControlNetModel.from_pretrained( 26 | "diffusers/controlnet-depth-sdxl-1.0", 27 | variant="fp16", 28 | use_safetensors=True, 29 | torch_dtype=precision, 30 | ) 31 | controlnet.enable_xformers_memory_efficient_attention() 32 | vae = AutoencoderKL.from_pretrained("madebyollin/sdxl-vae-fp16-fix", torch_dtype=precision) 33 | pipe = StableDiffusionXLControlNetPipeline.from_pretrained( 34 | "stabilityai/stable-diffusion-xl-base-1.0", 35 | controlnet=controlnet, 36 | vae=vae, 37 | variant="fp16", 38 | use_safetensors=True, 39 | torch_dtype=precision, 40 | ) 41 | pipe.to(device) 42 | pipe.enable_model_cpu_offload() 43 | pipe.enable_xformers_memory_efficient_attention() 44 | 45 | return pipe -------------------------------------------------------------------------------- /scene_transfer/semantic_matching.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from utils.logging import logger 3 | 4 | def 
get_mask_for_label(label_name, objects_list, mask_tensor): 5 | """ 6 | Get the mask for a given label from a 2D mask tensor, including all instances of the label. 7 | 8 | Args: 9 | label_name (str): The label to find the mask for 10 | objects_list (list): List of dictionaries containing object information 11 | mask_tensor (torch.Tensor): 2D PyTorch tensor where pixel values correspond to object IDs 12 | 13 | Returns: 14 | torch.Tensor: Binary mask for the given label, including all instances 15 | """ 16 | # Find all objects with the given label 17 | target_objects = [obj for obj in objects_list if obj['label'] == label_name] 18 | 19 | if not target_objects: 20 | raise ValueError(f"No object found with label '{label_name}'") 21 | 22 | # Get the IDs of all target objects 23 | target_ids = [obj['id'] for obj in target_objects] 24 | 25 | # Create a binary mask where any of the target objects are True and everything else is False 26 | binary_mask = torch.zeros_like(mask_tensor, dtype=torch.bool) 27 | for target_id in target_ids: 28 | binary_mask |= (mask_tensor == target_id) 29 | 30 | return binary_mask 31 | 32 | 33 | def match_semantic_labels(src_dict, tgt_dict): 34 | """ 35 | Match semantic labels between source and target images in seg_dict. 36 | 37 | Args: 38 | seg_dict (dict): Dictionary containing segmentation predictions for source and target images. 39 | 40 | Returns: 41 | list: List of tuples containing matched labels and their similarity scores. 42 | """ 43 | # Extract labels from source and target predictions 44 | source_labels = [obj['label'] for obj in src_dict['pred_tgt'][1]] 45 | target_labels = [obj['label'] for obj in tgt_dict['pred_tgt'][1]] 46 | # Find the intersection between source and target labels 47 | common_labels = list(set(source_labels) & set(target_labels)) 48 | 49 | # Initialize a list to store matched labels and their masks 50 | matched_labels = [] 51 | 52 | # Iterate through common labels 53 | for label in common_labels: 54 | # Get masks for the current label in both source and target 55 | source_mask = get_mask_for_label(label, src_dict['pred_tgt'][1], src_dict['pred_tgt'][0]) 56 | target_mask = get_mask_for_label(label, tgt_dict['pred_tgt'][1], tgt_dict['pred_tgt'][0]) 57 | 58 | # get the area of the source and target masks 59 | source_mask_area = source_mask.sum() 60 | target_mask_area = target_mask.sum() 61 | # get the area of the original images for both source and target 62 | source_img_area = src_dict['pred_tgt'][0].shape[-1] * src_dict['pred_tgt'][0].shape[-2] 63 | target_img_area = tgt_dict['pred_tgt'][0].shape[-1] * tgt_dict['pred_tgt'][0].shape[-2] 64 | # get the ratio of the source and target masks to the original images 65 | source_mask_ratio = source_mask_area / source_img_area 66 | target_mask_ratio = target_mask_area / target_img_area 67 | 68 | # Skip labels with mask ratios below 1% 69 | if source_mask_ratio < 0.01 or target_mask_ratio < 0.01: 70 | logger.info(f"[Semantic Matching] Skipping {label} mask due to small mask ratio") 71 | continue 72 | 73 | matched_labels.append((label, source_mask, target_mask)) 74 | 75 | return matched_labels 76 | 77 | def merge_similar_labels(seg_dict, labels=['wall', 'floor']): 78 | """ 79 | Merge similar labels in the objects list. 
80 | 81 | Args: 82 | objects_list (list): List of dictionaries containing object information 83 | 84 | Returns: 85 | list: List of dictionaries containing merged object information 86 | """ 87 | for obj in seg_dict['pred_src'][1]: 88 | for label in labels: 89 | if label in obj['label']: 90 | obj['label'] = label 91 | seg_dict['pred_src'][1][seg_dict['pred_src'][1].index(obj)]['label'] = label 92 | 93 | for obj in seg_dict['pred_tgt'][1]: 94 | for label in labels: 95 | if label in obj['label']: 96 | obj['label'] = label 97 | seg_dict['pred_tgt'][1][seg_dict['pred_tgt'][1].index(obj)]['label'] = label 98 | 99 | return seg_dict 100 | 101 | -------------------------------------------------------------------------------- /scripts/download_data.py: -------------------------------------------------------------------------------- 1 | from huggingface_hub import hf_hub_download 2 | 3 | hf_hub_download( 4 | repo_id="gradient-spaces/SceneTransfer", 5 | filename="demo.zip", 6 | repo_type="dataset", 7 | local_dir=".", # Downloads to current directory 8 | ) 9 | 10 | 11 | # Download SceneTransfer.zip 12 | hf_hub_download( 13 | repo_id="gradient-spaces/SceneTransfer", 14 | filename="SceneTransfer.zip", 15 | repo_type="dataset", 16 | local_dir=".", # Downloads to current directory 17 | ) -------------------------------------------------------------------------------- /scripts/download_data.sh: -------------------------------------------------------------------------------- 1 | python scripts/download_data.py 2 | unzip demo.zip 3 | rm -rf demo.zip 4 | 5 | unzip SceneTransfer.zip 6 | rm -rf SceneTransfer.zip -------------------------------------------------------------------------------- /scripts/download_weights.sh: -------------------------------------------------------------------------------- 1 | mkdir checkpoints && cd checkpoints 2 | wget https://download.europe.naverlabs.com/ComputerVision/DUSt3R/DUSt3R_ViTLarge_BaseDecoder_512_dpt.pth 3 | wget https://huggingface.co/depth-anything/Depth-Anything-V2-Large/resolve/main/depth_anything_v2_vitl.pth 4 | cd .. 5 | 6 | 7 | -------------------------------------------------------------------------------- /third_party/depth_anything_v2/dinov2_layers/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | from .mlp import Mlp 8 | from .patch_embed import PatchEmbed 9 | from .swiglu_ffn import SwiGLUFFN, SwiGLUFFNFused 10 | from .block import NestedTensorBlock 11 | from .attention import MemEffAttention 12 | -------------------------------------------------------------------------------- /third_party/depth_anything_v2/dinov2_layers/attention.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
6 | 7 | # References: 8 | # https://github.com/facebookresearch/dino/blob/master/vision_transformer.py 9 | # https://github.com/rwightman/pytorch-image-models/tree/master/timm/models/vision_transformer.py 10 | 11 | import logging 12 | 13 | from torch import Tensor 14 | from torch import nn 15 | 16 | 17 | logger = logging.getLogger("dinov2") 18 | 19 | 20 | try: 21 | from xformers.ops import memory_efficient_attention, unbind, fmha 22 | 23 | XFORMERS_AVAILABLE = True 24 | except ImportError: 25 | logger.warning("xFormers not available") 26 | XFORMERS_AVAILABLE = False 27 | 28 | 29 | class Attention(nn.Module): 30 | def __init__( 31 | self, 32 | dim: int, 33 | num_heads: int = 8, 34 | qkv_bias: bool = False, 35 | proj_bias: bool = True, 36 | attn_drop: float = 0.0, 37 | proj_drop: float = 0.0, 38 | ) -> None: 39 | super().__init__() 40 | self.num_heads = num_heads 41 | head_dim = dim // num_heads 42 | self.scale = head_dim**-0.5 43 | 44 | self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) 45 | self.attn_drop = nn.Dropout(attn_drop) 46 | self.proj = nn.Linear(dim, dim, bias=proj_bias) 47 | self.proj_drop = nn.Dropout(proj_drop) 48 | 49 | def forward(self, x: Tensor) -> Tensor: 50 | B, N, C = x.shape 51 | qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) 52 | 53 | q, k, v = qkv[0] * self.scale, qkv[1], qkv[2] 54 | attn = q @ k.transpose(-2, -1) 55 | 56 | attn = attn.softmax(dim=-1) 57 | attn = self.attn_drop(attn) 58 | 59 | x = (attn @ v).transpose(1, 2).reshape(B, N, C) 60 | x = self.proj(x) 61 | x = self.proj_drop(x) 62 | return x 63 | 64 | 65 | class MemEffAttention(Attention): 66 | def forward(self, x: Tensor, attn_bias=None) -> Tensor: 67 | if not XFORMERS_AVAILABLE: 68 | assert attn_bias is None, "xFormers is required for nested tensors usage" 69 | return super().forward(x) 70 | 71 | B, N, C = x.shape 72 | qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads) 73 | 74 | q, k, v = unbind(qkv, 2) 75 | 76 | x = memory_efficient_attention(q, k, v, attn_bias=attn_bias) 77 | x = x.reshape([B, N, C]) 78 | 79 | x = self.proj(x) 80 | x = self.proj_drop(x) 81 | return x 82 | 83 | -------------------------------------------------------------------------------- /third_party/depth_anything_v2/dinov2_layers/drop_path.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
6 | 7 | # References: 8 | # https://github.com/facebookresearch/dino/blob/master/vision_transformer.py 9 | # https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/drop.py 10 | 11 | 12 | from torch import nn 13 | 14 | 15 | def drop_path(x, drop_prob: float = 0.0, training: bool = False): 16 | if drop_prob == 0.0 or not training: 17 | return x 18 | keep_prob = 1 - drop_prob 19 | shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets 20 | random_tensor = x.new_empty(shape).bernoulli_(keep_prob) 21 | if keep_prob > 0.0: 22 | random_tensor.div_(keep_prob) 23 | output = x * random_tensor 24 | return output 25 | 26 | 27 | class DropPath(nn.Module): 28 | """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).""" 29 | 30 | def __init__(self, drop_prob=None): 31 | super(DropPath, self).__init__() 32 | self.drop_prob = drop_prob 33 | 34 | def forward(self, x): 35 | return drop_path(x, self.drop_prob, self.training) 36 | -------------------------------------------------------------------------------- /third_party/depth_anything_v2/dinov2_layers/layer_scale.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | # Modified from: https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/vision_transformer.py#L103-L110 8 | 9 | from typing import Union 10 | 11 | import torch 12 | from torch import Tensor 13 | from torch import nn 14 | 15 | 16 | class LayerScale(nn.Module): 17 | def __init__( 18 | self, 19 | dim: int, 20 | init_values: Union[float, Tensor] = 1e-5, 21 | inplace: bool = False, 22 | ) -> None: 23 | super().__init__() 24 | self.inplace = inplace 25 | self.gamma = nn.Parameter(init_values * torch.ones(dim)) 26 | 27 | def forward(self, x: Tensor) -> Tensor: 28 | return x.mul_(self.gamma) if self.inplace else x * self.gamma 29 | -------------------------------------------------------------------------------- /third_party/depth_anything_v2/dinov2_layers/mlp.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
6 | 7 | # References: 8 | # https://github.com/facebookresearch/dino/blob/master/vision_transformer.py 9 | # https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/mlp.py 10 | 11 | 12 | from typing import Callable, Optional 13 | 14 | from torch import Tensor, nn 15 | 16 | 17 | class Mlp(nn.Module): 18 | def __init__( 19 | self, 20 | in_features: int, 21 | hidden_features: Optional[int] = None, 22 | out_features: Optional[int] = None, 23 | act_layer: Callable[..., nn.Module] = nn.GELU, 24 | drop: float = 0.0, 25 | bias: bool = True, 26 | ) -> None: 27 | super().__init__() 28 | out_features = out_features or in_features 29 | hidden_features = hidden_features or in_features 30 | self.fc1 = nn.Linear(in_features, hidden_features, bias=bias) 31 | self.act = act_layer() 32 | self.fc2 = nn.Linear(hidden_features, out_features, bias=bias) 33 | self.drop = nn.Dropout(drop) 34 | 35 | def forward(self, x: Tensor) -> Tensor: 36 | x = self.fc1(x) 37 | x = self.act(x) 38 | x = self.drop(x) 39 | x = self.fc2(x) 40 | x = self.drop(x) 41 | return x 42 | -------------------------------------------------------------------------------- /third_party/depth_anything_v2/dinov2_layers/patch_embed.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | # References: 8 | # https://github.com/facebookresearch/dino/blob/master/vision_transformer.py 9 | # https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/patch_embed.py 10 | 11 | from typing import Callable, Optional, Tuple, Union 12 | 13 | from torch import Tensor 14 | import torch.nn as nn 15 | 16 | 17 | def make_2tuple(x): 18 | if isinstance(x, tuple): 19 | assert len(x) == 2 20 | return x 21 | 22 | assert isinstance(x, int) 23 | return (x, x) 24 | 25 | 26 | class PatchEmbed(nn.Module): 27 | """ 28 | 2D image to patch embedding: (B,C,H,W) -> (B,N,D) 29 | 30 | Args: 31 | img_size: Image size. 32 | patch_size: Patch token size. 33 | in_chans: Number of input image channels. 34 | embed_dim: Number of linear projection output channels. 35 | norm_layer: Normalization layer. 
36 | """ 37 | 38 | def __init__( 39 | self, 40 | img_size: Union[int, Tuple[int, int]] = 224, 41 | patch_size: Union[int, Tuple[int, int]] = 16, 42 | in_chans: int = 3, 43 | embed_dim: int = 768, 44 | norm_layer: Optional[Callable] = None, 45 | flatten_embedding: bool = True, 46 | ) -> None: 47 | super().__init__() 48 | 49 | image_HW = make_2tuple(img_size) 50 | patch_HW = make_2tuple(patch_size) 51 | patch_grid_size = ( 52 | image_HW[0] // patch_HW[0], 53 | image_HW[1] // patch_HW[1], 54 | ) 55 | 56 | self.img_size = image_HW 57 | self.patch_size = patch_HW 58 | self.patches_resolution = patch_grid_size 59 | self.num_patches = patch_grid_size[0] * patch_grid_size[1] 60 | 61 | self.in_chans = in_chans 62 | self.embed_dim = embed_dim 63 | 64 | self.flatten_embedding = flatten_embedding 65 | 66 | self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_HW, stride=patch_HW) 67 | self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity() 68 | 69 | def forward(self, x: Tensor) -> Tensor: 70 | _, _, H, W = x.shape 71 | patch_H, patch_W = self.patch_size 72 | 73 | assert H % patch_H == 0, f"Input image height {H} is not a multiple of patch height {patch_H}" 74 | assert W % patch_W == 0, f"Input image width {W} is not a multiple of patch width: {patch_W}" 75 | 76 | x = self.proj(x) # B C H W 77 | H, W = x.size(2), x.size(3) 78 | x = x.flatten(2).transpose(1, 2) # B HW C 79 | x = self.norm(x) 80 | if not self.flatten_embedding: 81 | x = x.reshape(-1, H, W, self.embed_dim) # B H W C 82 | return x 83 | 84 | def flops(self) -> float: 85 | Ho, Wo = self.patches_resolution 86 | flops = Ho * Wo * self.embed_dim * self.in_chans * (self.patch_size[0] * self.patch_size[1]) 87 | if self.norm is not None: 88 | flops += Ho * Wo * self.embed_dim 89 | return flops 90 | -------------------------------------------------------------------------------- /third_party/depth_anything_v2/dinov2_layers/swiglu_ffn.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
6 | 7 | from typing import Callable, Optional 8 | 9 | from torch import Tensor, nn 10 | import torch.nn.functional as F 11 | 12 | 13 | class SwiGLUFFN(nn.Module): 14 | def __init__( 15 | self, 16 | in_features: int, 17 | hidden_features: Optional[int] = None, 18 | out_features: Optional[int] = None, 19 | act_layer: Callable[..., nn.Module] = None, 20 | drop: float = 0.0, 21 | bias: bool = True, 22 | ) -> None: 23 | super().__init__() 24 | out_features = out_features or in_features 25 | hidden_features = hidden_features or in_features 26 | self.w12 = nn.Linear(in_features, 2 * hidden_features, bias=bias) 27 | self.w3 = nn.Linear(hidden_features, out_features, bias=bias) 28 | 29 | def forward(self, x: Tensor) -> Tensor: 30 | x12 = self.w12(x) 31 | x1, x2 = x12.chunk(2, dim=-1) 32 | hidden = F.silu(x1) * x2 33 | return self.w3(hidden) 34 | 35 | 36 | try: 37 | from xformers.ops import SwiGLU 38 | 39 | XFORMERS_AVAILABLE = True 40 | except ImportError: 41 | SwiGLU = SwiGLUFFN 42 | XFORMERS_AVAILABLE = False 43 | 44 | 45 | class SwiGLUFFNFused(SwiGLU): 46 | def __init__( 47 | self, 48 | in_features: int, 49 | hidden_features: Optional[int] = None, 50 | out_features: Optional[int] = None, 51 | act_layer: Callable[..., nn.Module] = None, 52 | drop: float = 0.0, 53 | bias: bool = True, 54 | ) -> None: 55 | out_features = out_features or in_features 56 | hidden_features = hidden_features or in_features 57 | hidden_features = (int(hidden_features * 2 / 3) + 7) // 8 * 8 58 | super().__init__( 59 | in_features=in_features, 60 | hidden_features=hidden_features, 61 | out_features=out_features, 62 | bias=bias, 63 | ) 64 | -------------------------------------------------------------------------------- /third_party/depth_anything_v2/util/blocks.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | 3 | 4 | def _make_scratch(in_shape, out_shape, groups=1, expand=False): 5 | scratch = nn.Module() 6 | 7 | out_shape1 = out_shape 8 | out_shape2 = out_shape 9 | out_shape3 = out_shape 10 | if len(in_shape) >= 4: 11 | out_shape4 = out_shape 12 | 13 | if expand: 14 | out_shape1 = out_shape 15 | out_shape2 = out_shape * 2 16 | out_shape3 = out_shape * 4 17 | if len(in_shape) >= 4: 18 | out_shape4 = out_shape * 8 19 | 20 | scratch.layer1_rn = nn.Conv2d(in_shape[0], out_shape1, kernel_size=3, stride=1, padding=1, bias=False, groups=groups) 21 | scratch.layer2_rn = nn.Conv2d(in_shape[1], out_shape2, kernel_size=3, stride=1, padding=1, bias=False, groups=groups) 22 | scratch.layer3_rn = nn.Conv2d(in_shape[2], out_shape3, kernel_size=3, stride=1, padding=1, bias=False, groups=groups) 23 | if len(in_shape) >= 4: 24 | scratch.layer4_rn = nn.Conv2d(in_shape[3], out_shape4, kernel_size=3, stride=1, padding=1, bias=False, groups=groups) 25 | 26 | return scratch 27 | 28 | 29 | class ResidualConvUnit(nn.Module): 30 | """Residual convolution module. 31 | """ 32 | 33 | def __init__(self, features, activation, bn): 34 | """Init. 
35 | 36 | Args: 37 | features (int): number of features 38 | """ 39 | super().__init__() 40 | 41 | self.bn = bn 42 | 43 | self.groups=1 44 | 45 | self.conv1 = nn.Conv2d(features, features, kernel_size=3, stride=1, padding=1, bias=True, groups=self.groups) 46 | 47 | self.conv2 = nn.Conv2d(features, features, kernel_size=3, stride=1, padding=1, bias=True, groups=self.groups) 48 | 49 | if self.bn == True: 50 | self.bn1 = nn.BatchNorm2d(features) 51 | self.bn2 = nn.BatchNorm2d(features) 52 | 53 | self.activation = activation 54 | 55 | self.skip_add = nn.quantized.FloatFunctional() 56 | 57 | def forward(self, x): 58 | """Forward pass. 59 | 60 | Args: 61 | x (tensor): input 62 | 63 | Returns: 64 | tensor: output 65 | """ 66 | 67 | out = self.activation(x) 68 | out = self.conv1(out) 69 | if self.bn == True: 70 | out = self.bn1(out) 71 | 72 | out = self.activation(out) 73 | out = self.conv2(out) 74 | if self.bn == True: 75 | out = self.bn2(out) 76 | 77 | if self.groups > 1: 78 | out = self.conv_merge(out) 79 | 80 | return self.skip_add.add(out, x) 81 | 82 | 83 | class FeatureFusionBlock(nn.Module): 84 | """Feature fusion block. 85 | """ 86 | 87 | def __init__( 88 | self, 89 | features, 90 | activation, 91 | deconv=False, 92 | bn=False, 93 | expand=False, 94 | align_corners=True, 95 | size=None 96 | ): 97 | """Init. 98 | 99 | Args: 100 | features (int): number of features 101 | """ 102 | super(FeatureFusionBlock, self).__init__() 103 | 104 | self.deconv = deconv 105 | self.align_corners = align_corners 106 | 107 | self.groups=1 108 | 109 | self.expand = expand 110 | out_features = features 111 | if self.expand == True: 112 | out_features = features // 2 113 | 114 | self.out_conv = nn.Conv2d(features, out_features, kernel_size=1, stride=1, padding=0, bias=True, groups=1) 115 | 116 | self.resConfUnit1 = ResidualConvUnit(features, activation, bn) 117 | self.resConfUnit2 = ResidualConvUnit(features, activation, bn) 118 | 119 | self.skip_add = nn.quantized.FloatFunctional() 120 | 121 | self.size=size 122 | 123 | def forward(self, *xs, size=None): 124 | """Forward pass. 125 | 126 | Returns: 127 | tensor: output 128 | """ 129 | output = xs[0] 130 | 131 | if len(xs) == 2: 132 | res = self.resConfUnit1(xs[1]) 133 | output = self.skip_add.add(output, res) 134 | 135 | output = self.resConfUnit2(output) 136 | 137 | if (size is None) and (self.size is None): 138 | modifier = {"scale_factor": 2} 139 | elif size is None: 140 | modifier = {"size": self.size} 141 | else: 142 | modifier = {"size": size} 143 | 144 | output = nn.functional.interpolate(output, **modifier, mode="bilinear", align_corners=self.align_corners) 145 | 146 | output = self.out_conv(output) 147 | 148 | return output 149 | -------------------------------------------------------------------------------- /third_party/depth_anything_v2/util/transform.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import cv2 3 | 4 | 5 | class Resize(object): 6 | """Resize sample to given size (width, height). 7 | """ 8 | 9 | def __init__( 10 | self, 11 | width, 12 | height, 13 | resize_target=True, 14 | keep_aspect_ratio=False, 15 | ensure_multiple_of=1, 16 | resize_method="lower_bound", 17 | image_interpolation_method=cv2.INTER_AREA, 18 | ): 19 | """Init. 20 | 21 | Args: 22 | width (int): desired output width 23 | height (int): desired output height 24 | resize_target (bool, optional): 25 | True: Resize the full sample (image, mask, target). 26 | False: Resize image only. 27 | Defaults to True. 
28 | keep_aspect_ratio (bool, optional): 29 | True: Keep the aspect ratio of the input sample. 30 | Output sample might not have the given width and height, and 31 | resize behaviour depends on the parameter 'resize_method'. 32 | Defaults to False. 33 | ensure_multiple_of (int, optional): 34 | Output width and height is constrained to be multiple of this parameter. 35 | Defaults to 1. 36 | resize_method (str, optional): 37 | "lower_bound": Output will be at least as large as the given size. 38 | "upper_bound": Output will be at max as large as the given size. (Output size might be smaller than given size.) 39 | "minimal": Scale as least as possible. (Output size might be smaller than given size.) 40 | Defaults to "lower_bound". 41 | """ 42 | self.__width = width 43 | self.__height = height 44 | 45 | self.__resize_target = resize_target 46 | self.__keep_aspect_ratio = keep_aspect_ratio 47 | self.__multiple_of = ensure_multiple_of 48 | self.__resize_method = resize_method 49 | self.__image_interpolation_method = image_interpolation_method 50 | 51 | def constrain_to_multiple_of(self, x, min_val=0, max_val=None): 52 | y = (np.round(x / self.__multiple_of) * self.__multiple_of).astype(int) 53 | 54 | if max_val is not None and y > max_val: 55 | y = (np.floor(x / self.__multiple_of) * self.__multiple_of).astype(int) 56 | 57 | if y < min_val: 58 | y = (np.ceil(x / self.__multiple_of) * self.__multiple_of).astype(int) 59 | 60 | return y 61 | 62 | def get_size(self, width, height): 63 | # determine new height and width 64 | scale_height = self.__height / height 65 | scale_width = self.__width / width 66 | 67 | if self.__keep_aspect_ratio: 68 | if self.__resize_method == "lower_bound": 69 | # scale such that output size is lower bound 70 | if scale_width > scale_height: 71 | # fit width 72 | scale_height = scale_width 73 | else: 74 | # fit height 75 | scale_width = scale_height 76 | elif self.__resize_method == "upper_bound": 77 | # scale such that output size is upper bound 78 | if scale_width < scale_height: 79 | # fit width 80 | scale_height = scale_width 81 | else: 82 | # fit height 83 | scale_width = scale_height 84 | elif self.__resize_method == "minimal": 85 | # scale as least as possbile 86 | if abs(1 - scale_width) < abs(1 - scale_height): 87 | # fit width 88 | scale_height = scale_width 89 | else: 90 | # fit height 91 | scale_width = scale_height 92 | else: 93 | raise ValueError(f"resize_method {self.__resize_method} not implemented") 94 | 95 | if self.__resize_method == "lower_bound": 96 | new_height = self.constrain_to_multiple_of(scale_height * height, min_val=self.__height) 97 | new_width = self.constrain_to_multiple_of(scale_width * width, min_val=self.__width) 98 | elif self.__resize_method == "upper_bound": 99 | new_height = self.constrain_to_multiple_of(scale_height * height, max_val=self.__height) 100 | new_width = self.constrain_to_multiple_of(scale_width * width, max_val=self.__width) 101 | elif self.__resize_method == "minimal": 102 | new_height = self.constrain_to_multiple_of(scale_height * height) 103 | new_width = self.constrain_to_multiple_of(scale_width * width) 104 | else: 105 | raise ValueError(f"resize_method {self.__resize_method} not implemented") 106 | 107 | return (new_width, new_height) 108 | 109 | def __call__(self, sample): 110 | width, height = self.get_size(sample["image"].shape[1], sample["image"].shape[0]) 111 | 112 | # resize sample 113 | sample["image"] = cv2.resize(sample["image"], (width, height), interpolation=self.__image_interpolation_method) 114 | 115 
| if self.__resize_target: 116 | if "depth" in sample: 117 | sample["depth"] = cv2.resize(sample["depth"], (width, height), interpolation=cv2.INTER_NEAREST) 118 | 119 | if "mask" in sample: 120 | sample["mask"] = cv2.resize(sample["mask"].astype(np.float32), (width, height), interpolation=cv2.INTER_NEAREST) 121 | 122 | return sample 123 | 124 | 125 | class NormalizeImage(object): 126 | """Normlize image by given mean and std. 127 | """ 128 | 129 | def __init__(self, mean, std): 130 | self.__mean = mean 131 | self.__std = std 132 | 133 | def __call__(self, sample): 134 | sample["image"] = (sample["image"] - self.__mean) / self.__std 135 | 136 | return sample 137 | 138 | 139 | class PrepareForNet(object): 140 | """Prepare sample for usage as network input. 141 | """ 142 | 143 | def __init__(self): 144 | pass 145 | 146 | def __call__(self, sample): 147 | image = np.transpose(sample["image"], (2, 0, 1)) 148 | sample["image"] = np.ascontiguousarray(image).astype(np.float32) 149 | 150 | if "depth" in sample: 151 | depth = sample["depth"].astype(np.float32) 152 | sample["depth"] = np.ascontiguousarray(depth) 153 | 154 | if "mask" in sample: 155 | sample["mask"] = sample["mask"].astype(np.float32) 156 | sample["mask"] = np.ascontiguousarray(sample["mask"]) 157 | 158 | return sample -------------------------------------------------------------------------------- /third_party/dust3r/LICENSE: -------------------------------------------------------------------------------- 1 | DUSt3R, Copyright (c) 2024-present Naver Corporation, is licensed under the Creative Commons Attribution-NonCommercial-ShareAlike 4.0 license. 2 | 3 | A summary of the CC BY-NC-SA 4.0 license is located here: 4 | https://creativecommons.org/licenses/by-nc-sa/4.0/ 5 | 6 | The CC BY-NC-SA 4.0 license is located here: 7 | https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode 8 | -------------------------------------------------------------------------------- /third_party/dust3r/croco/LICENSE: -------------------------------------------------------------------------------- 1 | CroCo, Copyright (c) 2022-present Naver Corporation, is licensed under the Creative Commons Attribution-NonCommercial-ShareAlike 4.0 license. 2 | 3 | A summary of the CC BY-NC-SA 4.0 license is located here: 4 | https://creativecommons.org/licenses/by-nc-sa/4.0/ 5 | 6 | The CC BY-NC-SA 4.0 license is located here: 7 | https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode 8 | 9 | 10 | SEE NOTICE BELOW WITH RESPECT TO THE FILE: models/pos_embed.py, models/blocks.py 11 | 12 | *************************** 13 | 14 | NOTICE WITH RESPECT TO THE FILE: models/pos_embed.py 15 | 16 | This software is being redistributed in a modifiled form. 
The original form is available here: 17 | 18 | https://github.com/facebookresearch/mae/blob/main/util/pos_embed.py 19 | 20 | This software in this file incorporates parts of the following software available here: 21 | 22 | Transformer: https://github.com/tensorflow/models/blob/master/official/legacy/transformer/model_utils.py 23 | available under the following license: https://github.com/tensorflow/models/blob/master/LICENSE 24 | 25 | MoCo v3: https://github.com/facebookresearch/moco-v3 26 | available under the following license: https://github.com/facebookresearch/moco-v3/blob/main/LICENSE 27 | 28 | DeiT: https://github.com/facebookresearch/deit 29 | available under the following license: https://github.com/facebookresearch/deit/blob/main/LICENSE 30 | 31 | 32 | ORIGINAL COPYRIGHT NOTICE AND PERMISSION NOTICE AVAILABLE HERE IS REPRODUCE BELOW: 33 | 34 | https://github.com/facebookresearch/mae/blob/main/LICENSE 35 | 36 | Attribution-NonCommercial 4.0 International 37 | 38 | *************************** 39 | 40 | NOTICE WITH RESPECT TO THE FILE: models/blocks.py 41 | 42 | This software is being redistributed in a modifiled form. The original form is available here: 43 | 44 | https://github.com/rwightman/pytorch-image-models 45 | 46 | ORIGINAL COPYRIGHT NOTICE AND PERMISSION NOTICE AVAILABLE HERE IS REPRODUCE BELOW: 47 | 48 | https://github.com/rwightman/pytorch-image-models/blob/master/LICENSE 49 | 50 | Apache License 51 | Version 2.0, January 2004 52 | http://www.apache.org/licenses/ -------------------------------------------------------------------------------- /third_party/dust3r/croco/NOTICE: -------------------------------------------------------------------------------- 1 | CroCo 2 | Copyright 2022-present NAVER Corp. 3 | 4 | This project contains subcomponents with separate copyright notices and license terms. 5 | Your use of the source code for these subcomponents is subject to the terms and conditions of the following licenses. 
6 | 7 | ==== 8 | 9 | facebookresearch/mae 10 | https://github.com/facebookresearch/mae 11 | 12 | Attribution-NonCommercial 4.0 International 13 | 14 | ==== 15 | 16 | rwightman/pytorch-image-models 17 | https://github.com/rwightman/pytorch-image-models 18 | 19 | Apache License 20 | Version 2.0, January 2004 21 | http://www.apache.org/licenses/ -------------------------------------------------------------------------------- /third_party/dust3r/croco/assets/Chateau1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GradientSpaces/ReStyle3D/5a6440eeee53783530dbbda7c26898b8b4c351da/third_party/dust3r/croco/assets/Chateau1.png -------------------------------------------------------------------------------- /third_party/dust3r/croco/assets/Chateau2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GradientSpaces/ReStyle3D/5a6440eeee53783530dbbda7c26898b8b4c351da/third_party/dust3r/croco/assets/Chateau2.png -------------------------------------------------------------------------------- /third_party/dust3r/croco/assets/arch.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GradientSpaces/ReStyle3D/5a6440eeee53783530dbbda7c26898b8b4c351da/third_party/dust3r/croco/assets/arch.jpg -------------------------------------------------------------------------------- /third_party/dust3r/croco/croco-stereo-flow-demo.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "id": "9bca0f41", 6 | "metadata": {}, 7 | "source": [ 8 | "# Simple inference example with CroCo-Stereo or CroCo-Flow" 9 | ] 10 | }, 11 | { 12 | "cell_type": "code", 13 | "execution_count": null, 14 | "id": "80653ef7", 15 | "metadata": {}, 16 | "outputs": [], 17 | "source": [ 18 | "# Copyright (C) 2022-present Naver Corporation. All rights reserved.\n", 19 | "# Licensed under CC BY-NC-SA 4.0 (non-commercial use only)." 
20 | ] 21 | }, 22 | { 23 | "cell_type": "markdown", 24 | "id": "4f033862", 25 | "metadata": {}, 26 | "source": [ 27 | "First download the model(s) of your choice by running\n", 28 | "```\n", 29 | "bash stereoflow/download_model.sh crocostereo.pth\n", 30 | "bash stereoflow/download_model.sh crocoflow.pth\n", 31 | "```" 32 | ] 33 | }, 34 | { 35 | "cell_type": "code", 36 | "execution_count": null, 37 | "id": "1fb2e392", 38 | "metadata": {}, 39 | "outputs": [], 40 | "source": [ 41 | "import torch\n", 42 | "use_gpu = torch.cuda.is_available() and torch.cuda.device_count()>0\n", 43 | "device = torch.device('cuda:0' if use_gpu else 'cpu')\n", 44 | "import matplotlib.pylab as plt" 45 | ] 46 | }, 47 | { 48 | "cell_type": "code", 49 | "execution_count": null, 50 | "id": "e0e25d77", 51 | "metadata": {}, 52 | "outputs": [], 53 | "source": [ 54 | "from stereoflow.test import _load_model_and_criterion\n", 55 | "from stereoflow.engine import tiled_pred\n", 56 | "from stereoflow.datasets_stereo import img_to_tensor, vis_disparity\n", 57 | "from stereoflow.datasets_flow import flowToColor\n", 58 | "tile_overlap=0.7 # recommended value, higher value can be slightly better but slower" 59 | ] 60 | }, 61 | { 62 | "cell_type": "markdown", 63 | "id": "86a921f5", 64 | "metadata": {}, 65 | "source": [ 66 | "### CroCo-Stereo example" 67 | ] 68 | }, 69 | { 70 | "cell_type": "code", 71 | "execution_count": null, 72 | "id": "64e483cb", 73 | "metadata": {}, 74 | "outputs": [], 75 | "source": [ 76 | "image1 = np.asarray(Image.open(''))\n", 77 | "image2 = np.asarray(Image.open(''))" 78 | ] 79 | }, 80 | { 81 | "cell_type": "code", 82 | "execution_count": null, 83 | "id": "f0d04303", 84 | "metadata": {}, 85 | "outputs": [], 86 | "source": [ 87 | "model, _, cropsize, with_conf, task, tile_conf_mode = _load_model_and_criterion('stereoflow_models/crocostereo.pth', None, device)\n" 88 | ] 89 | }, 90 | { 91 | "cell_type": "code", 92 | "execution_count": null, 93 | "id": "47dc14b5", 94 | "metadata": {}, 95 | "outputs": [], 96 | "source": [ 97 | "im1 = img_to_tensor(image1).to(device).unsqueeze(0)\n", 98 | "im2 = img_to_tensor(image2).to(device).unsqueeze(0)\n", 99 | "with torch.inference_mode():\n", 100 | " pred, _, _ = tiled_pred(model, None, im1, im2, None, conf_mode=tile_conf_mode, overlap=tile_overlap, crop=cropsize, with_conf=with_conf, return_time=False)\n", 101 | "pred = pred.squeeze(0).squeeze(0).cpu().numpy()" 102 | ] 103 | }, 104 | { 105 | "cell_type": "code", 106 | "execution_count": null, 107 | "id": "583b9f16", 108 | "metadata": {}, 109 | "outputs": [], 110 | "source": [ 111 | "plt.imshow(vis_disparity(pred))\n", 112 | "plt.axis('off')" 113 | ] 114 | }, 115 | { 116 | "cell_type": "markdown", 117 | "id": "d2df5d70", 118 | "metadata": {}, 119 | "source": [ 120 | "### CroCo-Flow example" 121 | ] 122 | }, 123 | { 124 | "cell_type": "code", 125 | "execution_count": null, 126 | "id": "9ee257a7", 127 | "metadata": {}, 128 | "outputs": [], 129 | "source": [ 130 | "image1 = np.asarray(Image.open(''))\n", 131 | "image2 = np.asarray(Image.open(''))" 132 | ] 133 | }, 134 | { 135 | "cell_type": "code", 136 | "execution_count": null, 137 | "id": "d5edccf0", 138 | "metadata": {}, 139 | "outputs": [], 140 | "source": [ 141 | "model, _, cropsize, with_conf, task, tile_conf_mode = _load_model_and_criterion('stereoflow_models/crocoflow.pth', None, device)\n" 142 | ] 143 | }, 144 | { 145 | "cell_type": "code", 146 | "execution_count": null, 147 | "id": "b19692c3", 148 | "metadata": {}, 149 | "outputs": [], 150 | "source": [ 151 | "im1 
= img_to_tensor(image1).to(device).unsqueeze(0)\n", 152 | "im2 = img_to_tensor(image2).to(device).unsqueeze(0)\n", 153 | "with torch.inference_mode():\n", 154 | " pred, _, _ = tiled_pred(model, None, im1, im2, None, conf_mode=tile_conf_mode, overlap=tile_overlap, crop=cropsize, with_conf=with_conf, return_time=False)\n", 155 | "pred = pred.squeeze(0).permute(1,2,0).cpu().numpy()" 156 | ] 157 | }, 158 | { 159 | "cell_type": "code", 160 | "execution_count": null, 161 | "id": "26f79db3", 162 | "metadata": {}, 163 | "outputs": [], 164 | "source": [ 165 | "plt.imshow(flowToColor(pred))\n", 166 | "plt.axis('off')" 167 | ] 168 | } 169 | ], 170 | "metadata": { 171 | "kernelspec": { 172 | "display_name": "Python 3 (ipykernel)", 173 | "language": "python", 174 | "name": "python3" 175 | }, 176 | "language_info": { 177 | "codemirror_mode": { 178 | "name": "ipython", 179 | "version": 3 180 | }, 181 | "file_extension": ".py", 182 | "mimetype": "text/x-python", 183 | "name": "python", 184 | "nbconvert_exporter": "python", 185 | "pygments_lexer": "ipython3", 186 | "version": "3.9.7" 187 | } 188 | }, 189 | "nbformat": 4, 190 | "nbformat_minor": 5 191 | } 192 | -------------------------------------------------------------------------------- /third_party/dust3r/croco/datasets/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GradientSpaces/ReStyle3D/5a6440eeee53783530dbbda7c26898b8b4c351da/third_party/dust3r/croco/datasets/__init__.py -------------------------------------------------------------------------------- /third_party/dust3r/croco/datasets/crops/README.MD: -------------------------------------------------------------------------------- 1 | ## Generation of crops from the real datasets 2 | 3 | The instructions below allow to generate the crops used for pre-training CroCo v2 from the following real-world datasets: ARKitScenes, MegaDepth, 3DStreetView and IndoorVL. 4 | 5 | ### Download the metadata of the crops to generate 6 | 7 | First, download the metadata and put them in `./data/`: 8 | ``` 9 | mkdir -p data 10 | cd data/ 11 | wget https://download.europe.naverlabs.com/ComputerVision/CroCo/data/crop_metadata.zip 12 | unzip crop_metadata.zip 13 | rm crop_metadata.zip 14 | cd .. 15 | ``` 16 | 17 | ### Prepare the original datasets 18 | 19 | Second, download the original datasets in `./data/original_datasets/`. 20 | ``` 21 | mkdir -p data/original_datasets 22 | ``` 23 | 24 | ##### ARKitScenes 25 | 26 | Download the `raw` dataset from https://github.com/apple/ARKitScenes/blob/main/DATA.md and put it in `./data/original_datasets/ARKitScenes/`. 27 | The resulting file structure should be like: 28 | ``` 29 | ./data/original_datasets/ARKitScenes/ 30 | └───Training 31 | └───40753679 32 | │ │ ultrawide 33 | │ │ ... 34 | └───40753686 35 | │ 36 | ... 37 | ``` 38 | 39 | ##### MegaDepth 40 | 41 | Download `MegaDepth v1 Dataset` from https://www.cs.cornell.edu/projects/megadepth/ and put it in `./data/original_datasets/MegaDepth/`. 42 | The resulting file structure should be like: 43 | 44 | ``` 45 | ./data/original_datasets/MegaDepth/ 46 | └───0000 47 | │ └───images 48 | │ │ │ 1000557903_87fa96b8a4_o.jpg 49 | │ │ └ ... 50 | │ └─── ... 51 | └───0001 52 | │ │ 53 | │ └ ... 54 | └─── ... 55 | ``` 56 | 57 | ##### 3DStreetView 58 | 59 | Download `3D_Street_View` dataset from https://github.com/amir32002/3D_Street_View and put it in `./data/original_datasets/3DStreetView/`. 
60 | The resulting file structure should be like: 61 | 62 | ``` 63 | ./data/original_datasets/3DStreetView/ 64 | └───dataset_aligned 65 | │ └───0002 66 | │ │ │ 0000002_0000001_0000002_0000001.jpg 67 | │ │ └ ... 68 | │ └─── ... 69 | └───dataset_unaligned 70 | │ └───0003 71 | │ │ │ 0000003_0000001_0000002_0000001.jpg 72 | │ │ └ ... 73 | │ └─── ... 74 | ``` 75 | 76 | ##### IndoorVL 77 | 78 | Download the `IndoorVL` datasets using [Kapture](https://github.com/naver/kapture). 79 | 80 | ``` 81 | pip install kapture 82 | mkdir -p ./data/original_datasets/IndoorVL 83 | cd ./data/original_datasets/IndoorVL 84 | kapture_download_dataset.py update 85 | kapture_download_dataset.py install "HyundaiDepartmentStore_*" 86 | kapture_download_dataset.py install "GangnamStation_*" 87 | cd - 88 | ``` 89 | 90 | ### Extract the crops 91 | 92 | Now, extract the crops for each of the dataset: 93 | ``` 94 | for dataset in ARKitScenes MegaDepth 3DStreetView IndoorVL; 95 | do 96 | python3 datasets/crops/extract_crops_from_images.py --crops ./data/crop_metadata/${dataset}/crops_release.txt --root-dir ./data/original_datasets/${dataset}/ --output-dir ./data/${dataset}_crops/ --imsize 256 --nthread 8 --max-subdir-levels 5 --ideal-number-pairs-in-dir 500; 97 | done 98 | ``` 99 | 100 | ##### Note for IndoorVL 101 | 102 | Due to some legal issues, we can only release 144,228 pairs out of the 1,593,689 pairs used in the paper. 103 | To account for it in terms of number of pre-training iterations, the pre-training command in this repository uses 125 training epochs including 12 warm-up epochs and learning rate cosine schedule of 250, instead of 100, 10 and 200 respectively. 104 | The impact on the performance is negligible. 105 | -------------------------------------------------------------------------------- /third_party/dust3r/croco/datasets/crops/extract_crops_from_images.py: -------------------------------------------------------------------------------- 1 | # Copyright (C) 2022-present Naver Corporation. All rights reserved. 2 | # Licensed under CC BY-NC-SA 4.0 (non-commercial use only). 3 | # 4 | # -------------------------------------------------------- 5 | # Extracting crops for pre-training 6 | # -------------------------------------------------------- 7 | 8 | import os 9 | import argparse 10 | from tqdm import tqdm 11 | from PIL import Image 12 | import functools 13 | from multiprocessing import Pool 14 | import math 15 | 16 | 17 | def arg_parser(): 18 | parser = argparse.ArgumentParser('Generate cropped image pairs from image crop list') 19 | 20 | parser.add_argument('--crops', type=str, required=True, help='crop file') 21 | parser.add_argument('--root-dir', type=str, required=True, help='root directory') 22 | parser.add_argument('--output-dir', type=str, required=True, help='output directory') 23 | parser.add_argument('--imsize', type=int, default=256, help='size of the crops') 24 | parser.add_argument('--nthread', type=int, required=True, help='number of simultaneous threads') 25 | parser.add_argument('--max-subdir-levels', type=int, default=5, help='maximum number of subdirectories') 26 | parser.add_argument('--ideal-number-pairs-in-dir', type=int, default=500, help='number of pairs stored in a dir') 27 | return parser 28 | 29 | 30 | def main(args): 31 | listing_path = os.path.join(args.output_dir, 'listing.txt') 32 | 33 | print(f'Loading list of crops ... 
({args.nthread} threads)') 34 | crops, num_crops_to_generate = load_crop_file(args.crops) 35 | 36 | print(f'Preparing jobs ({len(crops)} candidate image pairs)...') 37 | num_levels = min(math.ceil(math.log(num_crops_to_generate, args.ideal_number_pairs_in_dir)), args.max_subdir_levels) 38 | num_pairs_in_dir = math.ceil(num_crops_to_generate ** (1/num_levels)) 39 | 40 | jobs = prepare_jobs(crops, num_levels, num_pairs_in_dir) 41 | del crops 42 | 43 | os.makedirs(args.output_dir, exist_ok=True) 44 | mmap = Pool(args.nthread).imap_unordered if args.nthread > 1 else map 45 | call = functools.partial(save_image_crops, args) 46 | 47 | print(f"Generating cropped images to {args.output_dir} ...") 48 | with open(listing_path, 'w') as listing: 49 | listing.write('# pair_path\n') 50 | for results in tqdm(mmap(call, jobs), total=len(jobs)): 51 | for path in results: 52 | listing.write(f'{path}\n') 53 | print('Finished writing listing to', listing_path) 54 | 55 | 56 | def load_crop_file(path): 57 | data = open(path).read().splitlines() 58 | pairs = [] 59 | num_crops_to_generate = 0 60 | for line in tqdm(data): 61 | if line.startswith('#'): 62 | continue 63 | line = line.split(', ') 64 | if len(line) < 8: 65 | img1, img2, rotation = line 66 | pairs.append((img1, img2, int(rotation), [])) 67 | else: 68 | l1, r1, t1, b1, l2, r2, t2, b2 = map(int, line) 69 | rect1, rect2 = (l1, t1, r1, b1), (l2, t2, r2, b2) 70 | pairs[-1][-1].append((rect1, rect2)) 71 | num_crops_to_generate += 1 72 | return pairs, num_crops_to_generate 73 | 74 | 75 | def prepare_jobs(pairs, num_levels, num_pairs_in_dir): 76 | jobs = [] 77 | powers = [num_pairs_in_dir**level for level in reversed(range(num_levels))] 78 | 79 | def get_path(idx): 80 | idx_array = [] 81 | d = idx 82 | for level in range(num_levels - 1): 83 | idx_array.append(idx // powers[level]) 84 | idx = idx % powers[level] 85 | idx_array.append(d) 86 | return '/'.join(map(lambda x: hex(x)[2:], idx_array)) 87 | 88 | idx = 0 89 | for pair_data in tqdm(pairs): 90 | img1, img2, rotation, crops = pair_data 91 | if -60 <= rotation and rotation <= 60: 92 | rotation = 0 # most likely not a true rotation 93 | paths = [get_path(idx + k) for k in range(len(crops))] 94 | idx += len(crops) 95 | jobs.append(((img1, img2), rotation, crops, paths)) 96 | return jobs 97 | 98 | 99 | def load_image(path): 100 | try: 101 | return Image.open(path).convert('RGB') 102 | except Exception as e: 103 | print('skipping', path, e) 104 | raise OSError() 105 | 106 | 107 | def save_image_crops(args, data): 108 | # load images 109 | img_pair, rot, crops, paths = data 110 | try: 111 | img1, img2 = [load_image(os.path.join(args.root_dir, impath)) for impath in img_pair] 112 | except OSError as e: 113 | return [] 114 | 115 | def area(sz): 116 | return sz[0] * sz[1] 117 | 118 | tgt_size = (args.imsize, args.imsize) 119 | 120 | def prepare_crop(img, rect, rot=0): 121 | # actual crop 122 | img = img.crop(rect) 123 | 124 | # resize to desired size 125 | interp = Image.Resampling.LANCZOS if area(img.size) > 4*area(tgt_size) else Image.Resampling.BICUBIC 126 | img = img.resize(tgt_size, resample=interp) 127 | 128 | # rotate the image 129 | rot90 = (round(rot/90) % 4) * 90 130 | if rot90 == 90: 131 | img = img.transpose(Image.Transpose.ROTATE_90) 132 | elif rot90 == 180: 133 | img = img.transpose(Image.Transpose.ROTATE_180) 134 | elif rot90 == 270: 135 | img = img.transpose(Image.Transpose.ROTATE_270) 136 | return img 137 | 138 | results = [] 139 | for (rect1, rect2), path in zip(crops, paths): 140 | crop1 = 
prepare_crop(img1, rect1) 141 | crop2 = prepare_crop(img2, rect2, rot) 142 | 143 | fullpath1 = os.path.join(args.output_dir, path+'_1.jpg') 144 | fullpath2 = os.path.join(args.output_dir, path+'_2.jpg') 145 | os.makedirs(os.path.dirname(fullpath1), exist_ok=True) 146 | 147 | assert not os.path.isfile(fullpath1), fullpath1 148 | assert not os.path.isfile(fullpath2), fullpath2 149 | crop1.save(fullpath1) 150 | crop2.save(fullpath2) 151 | results.append(path) 152 | 153 | return results 154 | 155 | 156 | if __name__ == '__main__': 157 | args = arg_parser().parse_args() 158 | main(args) 159 | 160 | -------------------------------------------------------------------------------- /third_party/dust3r/croco/datasets/habitat_sim/README.MD: -------------------------------------------------------------------------------- 1 | ## Generation of synthetic image pairs using Habitat-Sim 2 | 3 | These instructions allow to generate pre-training pairs from the Habitat simulator. 4 | As we did not save metadata of the pairs used in the original paper, they are not strictly the same, but these data use the same setting and are equivalent. 5 | 6 | ### Download Habitat-Sim scenes 7 | Download Habitat-Sim scenes: 8 | - Download links can be found here: https://github.com/facebookresearch/habitat-sim/blob/main/DATASETS.md 9 | - We used scenes from the HM3D, habitat-test-scenes, Replica, ReplicaCad and ScanNet datasets. 10 | - Please put the scenes under `./data/habitat-sim-data/scene_datasets/` following the structure below, or update manually paths in `paths.py`. 11 | ``` 12 | ./data/ 13 | └──habitat-sim-data/ 14 | └──scene_datasets/ 15 | ├──hm3d/ 16 | ├──gibson/ 17 | ├──habitat-test-scenes/ 18 | ├──replica_cad_baked_lighting/ 19 | ├──replica_cad/ 20 | ├──ReplicaDataset/ 21 | └──scannet/ 22 | ``` 23 | 24 | ### Image pairs generation 25 | We provide metadata to generate reproducible images pairs for pretraining and validation. 26 | Experiments described in the paper used similar data, but whose generation was not reproducible at the time. 27 | 28 | Specifications: 29 | - 256x256 resolution images, with 60 degrees field of view . 30 | - Up to 1000 image pairs per scene. 31 | - Number of scenes considered/number of images pairs per dataset: 32 | - Scannet: 1097 scenes / 985 209 pairs 33 | - HM3D: 34 | - hm3d/train: 800 / 800k pairs 35 | - hm3d/val: 100 scenes / 100k pairs 36 | - hm3d/minival: 10 scenes / 10k pairs 37 | - habitat-test-scenes: 3 scenes / 3k pairs 38 | - replica_cad_baked_lighting: 13 scenes / 13k pairs 39 | 40 | - Scenes from hm3d/val and hm3d/minival pairs were not used for the pre-training but kept for validation purposes. 41 | 42 | Download metadata and extract it: 43 | ```bash 44 | mkdir -p data/habitat_release_metadata/ 45 | cd data/habitat_release_metadata/ 46 | wget https://download.europe.naverlabs.com/ComputerVision/CroCo/data/habitat_release_metadata/multiview_habitat_metadata.tar.gz 47 | tar -xvf multiview_habitat_metadata.tar.gz 48 | cd ../.. 49 | # Location of the metadata 50 | METADATA_DIR="./data/habitat_release_metadata/multiview_habitat_metadata" 51 | ``` 52 | 53 | Generate image pairs from metadata: 54 | - The following command will print a list of commandlines to generate image pairs for each scene: 55 | ```bash 56 | # Target output directory 57 | PAIRS_DATASET_DIR="./data/habitat_release/" 58 | python datasets/habitat_sim/generate_from_metadata_files.py --input_dir=$METADATA_DIR --output_dir=$PAIRS_DATASET_DIR 59 | ``` 60 | - One can launch multiple of such commands in parallel e.g. 
using GNU Parallel: 61 | ```bash 62 | python datasets/habitat_sim/generate_from_metadata_files.py --input_dir=$METADATA_DIR --output_dir=$PAIRS_DATASET_DIR | parallel -j 16 63 | ``` 64 | 65 | ## Metadata generation 66 | 67 | Image pairs were randomly sampled using the following commands, whose outputs contain randomness and are thus not exactly reproducible: 68 | ```bash 69 | # Print commandlines to generate image pairs from the different scenes available. 70 | PAIRS_DATASET_DIR=MY_CUSTOM_PATH 71 | python datasets/habitat_sim/generate_multiview_images.py --list_commands --output_dir=$PAIRS_DATASET_DIR 72 | 73 | # Once a dataset is generated, pack metadata files for reproducibility. 74 | METADATA_DIR=MY_CUSTON_PATH 75 | python datasets/habitat_sim/pack_metadata_files.py $PAIRS_DATASET_DIR $METADATA_DIR 76 | ``` 77 | -------------------------------------------------------------------------------- /third_party/dust3r/croco/datasets/habitat_sim/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GradientSpaces/ReStyle3D/5a6440eeee53783530dbbda7c26898b8b4c351da/third_party/dust3r/croco/datasets/habitat_sim/__init__.py -------------------------------------------------------------------------------- /third_party/dust3r/croco/datasets/habitat_sim/generate_from_metadata.py: -------------------------------------------------------------------------------- 1 | # Copyright (C) 2022-present Naver Corporation. All rights reserved. 2 | # Licensed under CC BY-NC-SA 4.0 (non-commercial use only). 3 | 4 | """ 5 | Script to generate image pairs for a given scene reproducing poses provided in a metadata file. 6 | """ 7 | import os 8 | from datasets.habitat_sim.multiview_habitat_sim_generator import MultiviewHabitatSimGenerator 9 | from datasets.habitat_sim.paths import SCENES_DATASET 10 | import argparse 11 | import quaternion 12 | import PIL.Image 13 | import cv2 14 | import json 15 | from tqdm import tqdm 16 | 17 | def generate_multiview_images_from_metadata(metadata_filename, 18 | output_dir, 19 | overload_params = dict(), 20 | scene_datasets_paths=None, 21 | exist_ok=False): 22 | """ 23 | Generate images from a metadata file for reproducibility purposes. 
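    Arguments, as used in the body below:
      metadata_filename: path to a metadata.json file describing the viewpoints to re-render.
      output_dir: directory receiving the color images, the optional depth EXR files, the per-view
        camera parameters and a copy of the metadata.
      overload_params: optional dict whose entries override the loaded metadata.
      scene_datasets_paths: optional mapping from dataset labels to local paths, used to remap the
        'scene', 'navmesh' and 'scene_dataset_config_file' entries.
      exist_ok: forwarded to os.makedirs for output_dir.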
24 | """ 25 | # Reorder paths by decreasing label length, to avoid collisions when testing if a string by such label 26 | if scene_datasets_paths is not None: 27 | scene_datasets_paths = dict(sorted(scene_datasets_paths.items(), key= lambda x: len(x[0]), reverse=True)) 28 | 29 | with open(metadata_filename, 'r') as f: 30 | input_metadata = json.load(f) 31 | metadata = dict() 32 | for key, value in input_metadata.items(): 33 | # Optionally replace some paths 34 | if key in ("scene_dataset_config_file", "scene", "navmesh") and value != "": 35 | if scene_datasets_paths is not None: 36 | for dataset_label, dataset_path in scene_datasets_paths.items(): 37 | if value.startswith(dataset_label): 38 | value = os.path.normpath(os.path.join(dataset_path, os.path.relpath(value, dataset_label))) 39 | break 40 | metadata[key] = value 41 | 42 | # Overload some parameters 43 | for key, value in overload_params.items(): 44 | metadata[key] = value 45 | 46 | generation_entries = dict([(key, value) for key, value in metadata.items() if not (key in ('multiviews', 'output_dir', 'generate_depth'))]) 47 | generate_depth = metadata["generate_depth"] 48 | 49 | os.makedirs(output_dir, exist_ok=exist_ok) 50 | 51 | generator = MultiviewHabitatSimGenerator(**generation_entries) 52 | 53 | # Generate views 54 | for idx_label, data in tqdm(metadata['multiviews'].items()): 55 | positions = data["positions"] 56 | orientations = data["orientations"] 57 | n = len(positions) 58 | for oidx in range(n): 59 | observation = generator.render_viewpoint(positions[oidx], quaternion.from_float_array(orientations[oidx])) 60 | observation_label = f"{oidx + 1}" # Leonid is indexing starting from 1 61 | # Color image saved using PIL 62 | img = PIL.Image.fromarray(observation['color'][:,:,:3]) 63 | filename = os.path.join(output_dir, f"{idx_label}_{observation_label}.jpeg") 64 | img.save(filename) 65 | if generate_depth: 66 | # Depth image as EXR file 67 | filename = os.path.join(output_dir, f"{idx_label}_{observation_label}_depth.exr") 68 | cv2.imwrite(filename, observation['depth'], [cv2.IMWRITE_EXR_TYPE, cv2.IMWRITE_EXR_TYPE_HALF]) 69 | # Camera parameters 70 | camera_params = dict([(key, observation[key].tolist()) for key in ("camera_intrinsics", "R_cam2world", "t_cam2world")]) 71 | filename = os.path.join(output_dir, f"{idx_label}_{observation_label}_camera_params.json") 72 | with open(filename, "w") as f: 73 | json.dump(camera_params, f) 74 | # Save metadata 75 | with open(os.path.join(output_dir, "metadata.json"), "w") as f: 76 | json.dump(metadata, f) 77 | 78 | generator.close() 79 | 80 | if __name__ == "__main__": 81 | parser = argparse.ArgumentParser() 82 | parser.add_argument("--metadata_filename", required=True) 83 | parser.add_argument("--output_dir", required=True) 84 | args = parser.parse_args() 85 | 86 | generate_multiview_images_from_metadata(metadata_filename=args.metadata_filename, 87 | output_dir=args.output_dir, 88 | scene_datasets_paths=SCENES_DATASET, 89 | overload_params=dict(), 90 | exist_ok=True) 91 | 92 | -------------------------------------------------------------------------------- /third_party/dust3r/croco/datasets/habitat_sim/generate_from_metadata_files.py: -------------------------------------------------------------------------------- 1 | # Copyright (C) 2022-present Naver Corporation. All rights reserved. 2 | # Licensed under CC BY-NC-SA 4.0 (non-commercial use only). 3 | 4 | """ 5 | Script generating commandlines to generate image pairs from metadata files. 
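Each printed line has the form (optionally prepended by the --prefix argument):

    python datasets/habitat_sim/generate_from_metadata.py --metadata_filename=<path/to/metadata.json> --output_dir=<scene_output_dir>

so the output can be executed directly in a shell or piped to a job runner such as GNU Parallel.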
6 | """ 7 | import os 8 | import glob 9 | from tqdm import tqdm 10 | import argparse 11 | 12 | if __name__ == "__main__": 13 | parser = argparse.ArgumentParser() 14 | parser.add_argument("--input_dir", required=True) 15 | parser.add_argument("--output_dir", required=True) 16 | parser.add_argument("--prefix", default="", help="Commanline prefix, useful e.g. to setup environment.") 17 | args = parser.parse_args() 18 | 19 | input_metadata_filenames = glob.iglob(f"{args.input_dir}/**/metadata.json", recursive=True) 20 | 21 | for metadata_filename in tqdm(input_metadata_filenames): 22 | output_dir = os.path.join(args.output_dir, os.path.relpath(os.path.dirname(metadata_filename), args.input_dir)) 23 | # Do not process the scene if the metadata file already exists 24 | if os.path.exists(os.path.join(output_dir, "metadata.json")): 25 | continue 26 | commandline = f"{args.prefix}python datasets/habitat_sim/generate_from_metadata.py --metadata_filename={metadata_filename} --output_dir={output_dir}" 27 | print(commandline) 28 | -------------------------------------------------------------------------------- /third_party/dust3r/croco/datasets/habitat_sim/pack_metadata_files.py: -------------------------------------------------------------------------------- 1 | # Copyright (C) 2022-present Naver Corporation. All rights reserved. 2 | # Licensed under CC BY-NC-SA 4.0 (non-commercial use only). 3 | """ 4 | Utility script to pack metadata files of the dataset in order to be able to re-generate it elsewhere. 5 | """ 6 | import os 7 | import glob 8 | from tqdm import tqdm 9 | import shutil 10 | import json 11 | from datasets.habitat_sim.paths import * 12 | import argparse 13 | import collections 14 | 15 | if __name__ == "__main__": 16 | parser = argparse.ArgumentParser() 17 | parser.add_argument("input_dir") 18 | parser.add_argument("output_dir") 19 | args = parser.parse_args() 20 | 21 | input_dirname = args.input_dir 22 | output_dirname = args.output_dir 23 | 24 | input_metadata_filenames = glob.iglob(f"{input_dirname}/**/metadata.json", recursive=True) 25 | 26 | images_count = collections.defaultdict(lambda : 0) 27 | 28 | os.makedirs(output_dirname) 29 | for input_filename in tqdm(input_metadata_filenames): 30 | # Ignore empty files 31 | with open(input_filename, "r") as f: 32 | original_metadata = json.load(f) 33 | if "multiviews" not in original_metadata or len(original_metadata["multiviews"]) == 0: 34 | print("No views in", input_filename) 35 | continue 36 | 37 | relpath = os.path.relpath(input_filename, input_dirname) 38 | print(relpath) 39 | 40 | # Copy metadata, while replacing scene paths by generic keys depending on the dataset, for portability. 41 | # Data paths are sorted by decreasing length to avoid potential bugs due to paths starting by the same string pattern. 
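        # Hypothetical illustration of the collision being avoided: if two entries were
        # "./data/x/replica_cad" and "./data/x/replica_cad_baked_lighting" (no trailing slashes),
        # a scene path under the second would also satisfy startswith() for the first;
        # testing the longer path first keeps the mapping unambiguous.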
42 | scenes_dataset_paths = dict(sorted(SCENES_DATASET.items(), key=lambda x: len(x[1]), reverse=True)) 43 | metadata = dict() 44 | for key, value in original_metadata.items(): 45 | if key in ("scene_dataset_config_file", "scene", "navmesh") and value != "": 46 | known_path = False 47 | for dataset, dataset_path in scenes_dataset_paths.items(): 48 | if value.startswith(dataset_path): 49 | value = os.path.join(dataset, os.path.relpath(value, dataset_path)) 50 | known_path = True 51 | break 52 | if not known_path: 53 | raise KeyError("Unknown path:" + value) 54 | metadata[key] = value 55 | 56 | # Compile some general statistics while packing data 57 | scene_split = metadata["scene"].split("/") 58 | upper_level = "/".join(scene_split[:2]) if scene_split[0] == "hm3d" else scene_split[0] 59 | images_count[upper_level] += len(metadata["multiviews"]) 60 | 61 | output_filename = os.path.join(output_dirname, relpath) 62 | os.makedirs(os.path.dirname(output_filename), exist_ok=True) 63 | with open(output_filename, "w") as f: 64 | json.dump(metadata, f) 65 | 66 | # Print statistics 67 | print("Images count:") 68 | for upper_level, count in images_count.items(): 69 | print(f"- {upper_level}: {count}") -------------------------------------------------------------------------------- /third_party/dust3r/croco/datasets/habitat_sim/paths.py: -------------------------------------------------------------------------------- 1 | # Copyright (C) 2022-present Naver Corporation. All rights reserved. 2 | # Licensed under CC BY-NC-SA 4.0 (non-commercial use only). 3 | 4 | """ 5 | Paths to Habitat-Sim scenes 6 | """ 7 | 8 | import os 9 | import json 10 | import collections 11 | from tqdm import tqdm 12 | 13 | 14 | # Hardcoded path to the different scene datasets 15 | SCENES_DATASET = { 16 | "hm3d": "./data/habitat-sim-data/scene_datasets/hm3d/", 17 | "gibson": "./data/habitat-sim-data/scene_datasets/gibson/", 18 | "habitat-test-scenes": "./data/habitat-sim/scene_datasets/habitat-test-scenes/", 19 | "replica_cad_baked_lighting": "./data/habitat-sim/scene_datasets/replica_cad_baked_lighting/", 20 | "replica_cad": "./data/habitat-sim/scene_datasets/replica_cad/", 21 | "replica": "./data/habitat-sim/scene_datasets/ReplicaDataset/", 22 | "scannet": "./data/habitat-sim/scene_datasets/scannet/" 23 | } 24 | 25 | SceneData = collections.namedtuple("SceneData", ["scene_dataset_config_file", "scene", "navmesh", "output_dir"]) 26 | 27 | def list_replicacad_scenes(base_output_dir, base_path=SCENES_DATASET["replica_cad"]): 28 | scene_dataset_config_file = os.path.join(base_path, "replicaCAD.scene_dataset_config.json") 29 | scenes = [f"apt_{i}" for i in range(6)] + ["empty_stage"] 30 | navmeshes = [f"navmeshes/apt_{i}_static_furniture.navmesh" for i in range(6)] + ["empty_stage.navmesh"] 31 | scenes_data = [] 32 | for idx in range(len(scenes)): 33 | output_dir = os.path.join(base_output_dir, "ReplicaCAD", scenes[idx]) 34 | # Add scene 35 | data = SceneData(scene_dataset_config_file=scene_dataset_config_file, 36 | scene = scenes[idx] + ".scene_instance.json", 37 | navmesh = os.path.join(base_path, navmeshes[idx]), 38 | output_dir = output_dir) 39 | scenes_data.append(data) 40 | return scenes_data 41 | 42 | def list_replica_cad_baked_lighting_scenes(base_output_dir, base_path=SCENES_DATASET["replica_cad_baked_lighting"]): 43 | scene_dataset_config_file = os.path.join(base_path, "replicaCAD_baked.scene_dataset_config.json") 44 | scenes = sum([[f"Baked_sc{i}_staging_{j:02}" for i in range(5)] for j in range(21)], []) 45 | navmeshes = 
""#[f"navmeshes/apt_{i}_static_furniture.navmesh" for i in range(6)] + ["empty_stage.navmesh"] 46 | scenes_data = [] 47 | for idx in range(len(scenes)): 48 | output_dir = os.path.join(base_output_dir, "replica_cad_baked_lighting", scenes[idx]) 49 | data = SceneData(scene_dataset_config_file=scene_dataset_config_file, 50 | scene = scenes[idx], 51 | navmesh = "", 52 | output_dir = output_dir) 53 | scenes_data.append(data) 54 | return scenes_data 55 | 56 | def list_replica_scenes(base_output_dir, base_path): 57 | scenes_data = [] 58 | for scene_id in os.listdir(base_path): 59 | scene = os.path.join(base_path, scene_id, "mesh.ply") 60 | navmesh = os.path.join(base_path, scene_id, "habitat/mesh_preseg_semantic.navmesh") # Not sure if I should use it 61 | scene_dataset_config_file = "" 62 | output_dir = os.path.join(base_output_dir, scene_id) 63 | # Add scene only if it does not exist already, or if exist_ok 64 | data = SceneData(scene_dataset_config_file = scene_dataset_config_file, 65 | scene = scene, 66 | navmesh = navmesh, 67 | output_dir = output_dir) 68 | scenes_data.append(data) 69 | return scenes_data 70 | 71 | 72 | def list_scenes(base_output_dir, base_path): 73 | """ 74 | Generic method iterating through a base_path folder to find scenes. 75 | """ 76 | scenes_data = [] 77 | for root, dirs, files in os.walk(base_path, followlinks=True): 78 | folder_scenes_data = [] 79 | for file in files: 80 | name, ext = os.path.splitext(file) 81 | if ext == ".glb": 82 | scene = os.path.join(root, name + ".glb") 83 | navmesh = os.path.join(root, name + ".navmesh") 84 | if not os.path.exists(navmesh): 85 | navmesh = "" 86 | relpath = os.path.relpath(root, base_path) 87 | output_dir = os.path.abspath(os.path.join(base_output_dir, relpath, name)) 88 | data = SceneData(scene_dataset_config_file="", 89 | scene = scene, 90 | navmesh = navmesh, 91 | output_dir = output_dir) 92 | folder_scenes_data.append(data) 93 | 94 | # Specific check for HM3D: 95 | # When two meshesxxxx.basis.glb and xxxx.glb are present, use the 'basis' version. 
96 | basis_scenes = [data.scene[:-len(".basis.glb")] for data in folder_scenes_data if data.scene.endswith(".basis.glb")] 97 | if len(basis_scenes) != 0: 98 | folder_scenes_data = [data for data in folder_scenes_data if not (data.scene[:-len(".glb")] in basis_scenes)] 99 | 100 | scenes_data.extend(folder_scenes_data) 101 | return scenes_data 102 | 103 | def list_scenes_available(base_output_dir, scenes_dataset_paths=SCENES_DATASET): 104 | scenes_data = [] 105 | 106 | # HM3D 107 | for split in ("minival", "train", "val", "examples"): 108 | scenes_data += list_scenes(base_output_dir=os.path.join(base_output_dir, f"hm3d/{split}/"), 109 | base_path=f"{scenes_dataset_paths['hm3d']}/{split}") 110 | 111 | # Gibson 112 | scenes_data += list_scenes(base_output_dir=os.path.join(base_output_dir, "gibson"), 113 | base_path=scenes_dataset_paths["gibson"]) 114 | 115 | # Habitat test scenes (just a few) 116 | scenes_data += list_scenes(base_output_dir=os.path.join(base_output_dir, "habitat-test-scenes"), 117 | base_path=scenes_dataset_paths["habitat-test-scenes"]) 118 | 119 | # ReplicaCAD (baked lightning) 120 | scenes_data += list_replica_cad_baked_lighting_scenes(base_output_dir=base_output_dir) 121 | 122 | # ScanNet 123 | scenes_data += list_scenes(base_output_dir=os.path.join(base_output_dir, "scannet"), 124 | base_path=scenes_dataset_paths["scannet"]) 125 | 126 | # Replica 127 | list_replica_scenes(base_output_dir=os.path.join(base_output_dir, "replica"), 128 | base_path=scenes_dataset_paths["replica"]) 129 | return scenes_data 130 | -------------------------------------------------------------------------------- /third_party/dust3r/croco/datasets/pairs_dataset.py: -------------------------------------------------------------------------------- 1 | # Copyright (C) 2022-present Naver Corporation. All rights reserved. 2 | # Licensed under CC BY-NC-SA 4.0 (non-commercial use only). 
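#
# Minimal usage sketch (assuming the pair caches / listing files referred to below have already
# been generated for the chosen datasets; dataset names and paths here are illustrative):
#
#   dataset = PairsDataset('habitat_release+ARKitScenes', trfs='crop224+acolor', data_dir='./data/')
#   im1, im2 = dataset[0]   # the two transformed images of the first pair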
3 | 4 | import os 5 | from torch.utils.data import Dataset 6 | from PIL import Image 7 | 8 | from datasets.transforms import get_pair_transforms 9 | 10 | def load_image(impath): 11 | return Image.open(impath) 12 | 13 | def load_pairs_from_cache_file(fname, root=''): 14 | assert os.path.isfile(fname), "cannot parse pairs from {:s}, file does not exist".format(fname) 15 | with open(fname, 'r') as fid: 16 | lines = fid.read().strip().splitlines() 17 | pairs = [ (os.path.join(root,l.split()[0]), os.path.join(root,l.split()[1])) for l in lines] 18 | return pairs 19 | 20 | def load_pairs_from_list_file(fname, root=''): 21 | assert os.path.isfile(fname), "cannot parse pairs from {:s}, file does not exist".format(fname) 22 | with open(fname, 'r') as fid: 23 | lines = fid.read().strip().splitlines() 24 | pairs = [ (os.path.join(root,l+'_1.jpg'), os.path.join(root,l+'_2.jpg')) for l in lines if not l.startswith('#')] 25 | return pairs 26 | 27 | 28 | def write_cache_file(fname, pairs, root=''): 29 | if len(root)>0: 30 | if not root.endswith('/'): root+='/' 31 | assert os.path.isdir(root) 32 | s = '' 33 | for im1, im2 in pairs: 34 | if len(root)>0: 35 | assert im1.startswith(root), im1 36 | assert im2.startswith(root), im2 37 | s += '{:s} {:s}\n'.format(im1[len(root):], im2[len(root):]) 38 | with open(fname, 'w') as fid: 39 | fid.write(s[:-1]) 40 | 41 | def parse_and_cache_all_pairs(dname, data_dir='./data/'): 42 | if dname=='habitat_release': 43 | dirname = os.path.join(data_dir, 'habitat_release') 44 | assert os.path.isdir(dirname), "cannot find folder for habitat_release pairs: "+dirname 45 | cache_file = os.path.join(dirname, 'pairs.txt') 46 | assert not os.path.isfile(cache_file), "cache file already exists: "+cache_file 47 | 48 | print('Parsing pairs for dataset: '+dname) 49 | pairs = [] 50 | for root, dirs, files in os.walk(dirname): 51 | if 'val' in root: continue 52 | dirs.sort() 53 | pairs += [ (os.path.join(root,f), os.path.join(root,f[:-len('_1.jpeg')]+'_2.jpeg')) for f in sorted(files) if f.endswith('_1.jpeg')] 54 | print('Found {:,} pairs'.format(len(pairs))) 55 | print('Writing cache to: '+cache_file) 56 | write_cache_file(cache_file, pairs, root=dirname) 57 | 58 | else: 59 | raise NotImplementedError('Unknown dataset: '+dname) 60 | 61 | def dnames_to_image_pairs(dnames, data_dir='./data/'): 62 | """ 63 | dnames: list of datasets with image pairs, separated by + 64 | """ 65 | all_pairs = [] 66 | for dname in dnames.split('+'): 67 | if dname=='habitat_release': 68 | dirname = os.path.join(data_dir, 'habitat_release') 69 | assert os.path.isdir(dirname), "cannot find folder for habitat_release pairs: "+dirname 70 | cache_file = os.path.join(dirname, 'pairs.txt') 71 | assert os.path.isfile(cache_file), "cannot find cache file for habitat_release pairs, please first create the cache file, see instructions. "+cache_file 72 | pairs = load_pairs_from_cache_file(cache_file, root=dirname) 73 | elif dname in ['ARKitScenes', 'MegaDepth', '3DStreetView', 'IndoorVL']: 74 | dirname = os.path.join(data_dir, dname+'_crops') 75 | assert os.path.isdir(dirname), "cannot find folder for {:s} pairs: {:s}".format(dname, dirname) 76 | list_file = os.path.join(dirname, 'listing.txt') 77 | assert os.path.isfile(list_file), "cannot find list file for {:s} pairs, see instructions. 
{:s}".format(dname, list_file) 78 | pairs = load_pairs_from_list_file(list_file, root=dirname) 79 | print(' {:s}: {:,} pairs'.format(dname, len(pairs))) 80 | all_pairs += pairs 81 | if '+' in dnames: print(' Total: {:,} pairs'.format(len(all_pairs))) 82 | return all_pairs 83 | 84 | 85 | class PairsDataset(Dataset): 86 | 87 | def __init__(self, dnames, trfs='', totensor=True, normalize=True, data_dir='./data/'): 88 | super().__init__() 89 | self.image_pairs = dnames_to_image_pairs(dnames, data_dir=data_dir) 90 | self.transforms = get_pair_transforms(transform_str=trfs, totensor=totensor, normalize=normalize) 91 | 92 | def __len__(self): 93 | return len(self.image_pairs) 94 | 95 | def __getitem__(self, index): 96 | im1path, im2path = self.image_pairs[index] 97 | im1 = load_image(im1path) 98 | im2 = load_image(im2path) 99 | if self.transforms is not None: im1, im2 = self.transforms(im1, im2) 100 | return im1, im2 101 | 102 | 103 | if __name__=="__main__": 104 | import argparse 105 | parser = argparse.ArgumentParser(prog="Computing and caching list of pairs for a given dataset") 106 | parser.add_argument('--data_dir', default='./data/', type=str, help="path where data are stored") 107 | parser.add_argument('--dataset', default='habitat_release', type=str, help="name of the dataset") 108 | args = parser.parse_args() 109 | parse_and_cache_all_pairs(dname=args.dataset, data_dir=args.data_dir) 110 | -------------------------------------------------------------------------------- /third_party/dust3r/croco/datasets/transforms.py: -------------------------------------------------------------------------------- 1 | # Copyright (C) 2022-present Naver Corporation. All rights reserved. 2 | # Licensed under CC BY-NC-SA 4.0 (non-commercial use only). 3 | 4 | import torch 5 | import torchvision.transforms 6 | import torchvision.transforms.functional as F 7 | 8 | # "Pair": apply a transform on a pair 9 | # "Both": apply the exact same transform to both images 10 | 11 | class ComposePair(torchvision.transforms.Compose): 12 | def __call__(self, img1, img2): 13 | for t in self.transforms: 14 | img1, img2 = t(img1, img2) 15 | return img1, img2 16 | 17 | class NormalizeBoth(torchvision.transforms.Normalize): 18 | def forward(self, img1, img2): 19 | img1 = super().forward(img1) 20 | img2 = super().forward(img2) 21 | return img1, img2 22 | 23 | class ToTensorBoth(torchvision.transforms.ToTensor): 24 | def __call__(self, img1, img2): 25 | img1 = super().__call__(img1) 26 | img2 = super().__call__(img2) 27 | return img1, img2 28 | 29 | class RandomCropPair(torchvision.transforms.RandomCrop): 30 | # the crop will be intentionally different for the two images with this class 31 | def forward(self, img1, img2): 32 | img1 = super().forward(img1) 33 | img2 = super().forward(img2) 34 | return img1, img2 35 | 36 | class ColorJitterPair(torchvision.transforms.ColorJitter): 37 | # can be symmetric (same for both images) or assymetric (different jitter params for each image) depending on assymetric_prob 38 | def __init__(self, assymetric_prob, **kwargs): 39 | super().__init__(**kwargs) 40 | self.assymetric_prob = assymetric_prob 41 | def jitter_one(self, img, fn_idx, brightness_factor, contrast_factor, saturation_factor, hue_factor): 42 | for fn_id in fn_idx: 43 | if fn_id == 0 and brightness_factor is not None: 44 | img = F.adjust_brightness(img, brightness_factor) 45 | elif fn_id == 1 and contrast_factor is not None: 46 | img = F.adjust_contrast(img, contrast_factor) 47 | elif fn_id == 2 and saturation_factor is not None: 48 
| img = F.adjust_saturation(img, saturation_factor) 49 | elif fn_id == 3 and hue_factor is not None: 50 | img = F.adjust_hue(img, hue_factor) 51 | return img 52 | 53 | def forward(self, img1, img2): 54 | 55 | fn_idx, brightness_factor, contrast_factor, saturation_factor, hue_factor = self.get_params( 56 | self.brightness, self.contrast, self.saturation, self.hue 57 | ) 58 | img1 = self.jitter_one(img1, fn_idx, brightness_factor, contrast_factor, saturation_factor, hue_factor) 59 | if torch.rand(1) < self.assymetric_prob: # assymetric: 60 | fn_idx, brightness_factor, contrast_factor, saturation_factor, hue_factor = self.get_params( 61 | self.brightness, self.contrast, self.saturation, self.hue 62 | ) 63 | img2 = self.jitter_one(img2, fn_idx, brightness_factor, contrast_factor, saturation_factor, hue_factor) 64 | return img1, img2 65 | 66 | def get_pair_transforms(transform_str, totensor=True, normalize=True): 67 | # transform_str is eg crop224+color 68 | trfs = [] 69 | for s in transform_str.split('+'): 70 | if s.startswith('crop'): 71 | size = int(s[len('crop'):]) 72 | trfs.append(RandomCropPair(size)) 73 | elif s=='acolor': 74 | trfs.append(ColorJitterPair(assymetric_prob=1.0, brightness=(0.6, 1.4), contrast=(0.6, 1.4), saturation=(0.6, 1.4), hue=0.0)) 75 | elif s=='': # if transform_str was "" 76 | pass 77 | else: 78 | raise NotImplementedError('Unknown augmentation: '+s) 79 | 80 | if totensor: 81 | trfs.append( ToTensorBoth() ) 82 | if normalize: 83 | trfs.append( NormalizeBoth(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) ) 84 | 85 | if len(trfs)==0: 86 | return None 87 | elif len(trfs)==1: 88 | return trfs 89 | else: 90 | return ComposePair(trfs) 91 | 92 | 93 | 94 | 95 | 96 | -------------------------------------------------------------------------------- /third_party/dust3r/croco/demo.py: -------------------------------------------------------------------------------- 1 | # Copyright (C) 2022-present Naver Corporation. All rights reserved. 2 | # Licensed under CC BY-NC-SA 4.0 (non-commercial use only). 
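#
# Usage sketch: this demo expects the relative paths below to resolve (i.e. it is meant to be run
# from the croco/ folder) with the CroCo v2 ViT-Large checkpoint placed at
# pretrained_models/CroCo_V2_ViTLarge_BaseDecoder.pth. It masks assets/Chateau1.png, reconstructs
# it with assets/Chateau2.png as the reference view, and saves a side-by-side visualization to
# demo_output.png.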
3 | 4 | import torch 5 | from models.croco import CroCoNet 6 | from PIL import Image 7 | import torchvision.transforms 8 | from torchvision.transforms import ToTensor, Normalize, Compose 9 | 10 | def main(): 11 | device = torch.device('cuda:0' if torch.cuda.is_available() and torch.cuda.device_count()>0 else 'cpu') 12 | 13 | # load 224x224 images and transform them to tensor 14 | imagenet_mean = [0.485, 0.456, 0.406] 15 | imagenet_mean_tensor = torch.tensor(imagenet_mean).view(1,3,1,1).to(device, non_blocking=True) 16 | imagenet_std = [0.229, 0.224, 0.225] 17 | imagenet_std_tensor = torch.tensor(imagenet_std).view(1,3,1,1).to(device, non_blocking=True) 18 | trfs = Compose([ToTensor(), Normalize(mean=imagenet_mean, std=imagenet_std)]) 19 | image1 = trfs(Image.open('assets/Chateau1.png').convert('RGB')).to(device, non_blocking=True).unsqueeze(0) 20 | image2 = trfs(Image.open('assets/Chateau2.png').convert('RGB')).to(device, non_blocking=True).unsqueeze(0) 21 | 22 | # load model 23 | ckpt = torch.load('pretrained_models/CroCo_V2_ViTLarge_BaseDecoder.pth', 'cpu') 24 | model = CroCoNet( **ckpt.get('croco_kwargs',{})).to(device) 25 | model.eval() 26 | msg = model.load_state_dict(ckpt['model'], strict=True) 27 | 28 | # forward 29 | with torch.inference_mode(): 30 | out, mask, target = model(image1, image2) 31 | 32 | # the output is normalized, thus use the mean/std of the actual image to go back to RGB space 33 | patchified = model.patchify(image1) 34 | mean = patchified.mean(dim=-1, keepdim=True) 35 | var = patchified.var(dim=-1, keepdim=True) 36 | decoded_image = model.unpatchify(out * (var + 1.e-6)**.5 + mean) 37 | # undo imagenet normalization, prepare masked image 38 | decoded_image = decoded_image * imagenet_std_tensor + imagenet_mean_tensor 39 | input_image = image1 * imagenet_std_tensor + imagenet_mean_tensor 40 | ref_image = image2 * imagenet_std_tensor + imagenet_mean_tensor 41 | image_masks = model.unpatchify(model.patchify(torch.ones_like(ref_image)) * mask[:,:,None]) 42 | masked_input_image = ((1 - image_masks) * input_image) 43 | 44 | # make visualization 45 | visualization = torch.cat((ref_image, masked_input_image, decoded_image, input_image), dim=3) # 4*(B, 3, H, W) -> B, 3, H, W*4 46 | B, C, H, W = visualization.shape 47 | visualization = visualization.permute(1, 0, 2, 3).reshape(C, B*H, W) 48 | visualization = torchvision.transforms.functional.to_pil_image(torch.clamp(visualization, 0, 1)) 49 | fname = "demo_output.png" 50 | visualization.save(fname) 51 | print('Visualization save in '+fname) 52 | 53 | 54 | if __name__=="__main__": 55 | main() 56 | -------------------------------------------------------------------------------- /third_party/dust3r/croco/models/criterion.py: -------------------------------------------------------------------------------- 1 | # Copyright (C) 2022-present Naver Corporation. All rights reserved. 2 | # Licensed under CC BY-NC-SA 4.0 (non-commercial use only). 
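#
# Usage sketch for the MaskedMSE criterion defined below (shapes follow the comments in forward()):
#   criterion = MaskedMSE(norm_pix_loss=True, masked=True)
#   loss = criterion(pred, mask, target)   # pred/target: (N, L, D) patches, mask: (N, L), 1 = masked
# With masked=True the per-patch MSE is averaged over the masked patches only.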
3 | # 4 | # -------------------------------------------------------- 5 | # Criterion to train CroCo 6 | # -------------------------------------------------------- 7 | # References: 8 | # MAE: https://github.com/facebookresearch/mae 9 | # -------------------------------------------------------- 10 | 11 | import torch 12 | 13 | class MaskedMSE(torch.nn.Module): 14 | 15 | def __init__(self, norm_pix_loss=False, masked=True): 16 | """ 17 | norm_pix_loss: normalize each patch by their pixel mean and variance 18 | masked: compute loss over the masked patches only 19 | """ 20 | super().__init__() 21 | self.norm_pix_loss = norm_pix_loss 22 | self.masked = masked 23 | 24 | def forward(self, pred, mask, target): 25 | 26 | if self.norm_pix_loss: 27 | mean = target.mean(dim=-1, keepdim=True) 28 | var = target.var(dim=-1, keepdim=True) 29 | target = (target - mean) / (var + 1.e-6)**.5 30 | 31 | loss = (pred - target) ** 2 32 | loss = loss.mean(dim=-1) # [N, L], mean loss per patch 33 | if self.masked: 34 | loss = (loss * mask).sum() / mask.sum() # mean loss on masked patches 35 | else: 36 | loss = loss.mean() # mean loss 37 | return loss 38 | -------------------------------------------------------------------------------- /third_party/dust3r/croco/models/croco_downstream.py: -------------------------------------------------------------------------------- 1 | # Copyright (C) 2022-present Naver Corporation. All rights reserved. 2 | # Licensed under CC BY-NC-SA 4.0 (non-commercial use only). 3 | 4 | # -------------------------------------------------------- 5 | # CroCo model for downstream tasks 6 | # -------------------------------------------------------- 7 | 8 | import torch 9 | 10 | from .croco import CroCoNet 11 | 12 | 13 | def croco_args_from_ckpt(ckpt): 14 | if 'croco_kwargs' in ckpt: # CroCo v2 released models 15 | return ckpt['croco_kwargs'] 16 | elif 'args' in ckpt and hasattr(ckpt['args'], 'model'): # pretrained using the official code release 17 | s = ckpt['args'].model # eg "CroCoNet(enc_embed_dim=1024, enc_num_heads=16, enc_depth=24)" 18 | assert s.startswith('CroCoNet(') 19 | return eval('dict'+s[len('CroCoNet'):]) # transform it into the string of a dictionary and evaluate it 20 | else: # CroCo v1 released models 21 | return dict() 22 | 23 | class CroCoDownstreamMonocularEncoder(CroCoNet): 24 | 25 | def __init__(self, 26 | head, 27 | **kwargs): 28 | """ Build network for monocular downstream task, only using the encoder. 
29 | It takes an extra argument head, that is called with the features 30 | and a dictionary img_info containing 'width' and 'height' keys 31 | The head is setup with the croconet arguments in this init function 32 | NOTE: It works by *calling super().__init__() but with redefined setters 33 | 34 | """ 35 | super(CroCoDownstreamMonocularEncoder, self).__init__(**kwargs) 36 | head.setup(self) 37 | self.head = head 38 | 39 | def _set_mask_generator(self, *args, **kwargs): 40 | """ No mask generator """ 41 | return 42 | 43 | def _set_mask_token(self, *args, **kwargs): 44 | """ No mask token """ 45 | self.mask_token = None 46 | return 47 | 48 | def _set_decoder(self, *args, **kwargs): 49 | """ No decoder """ 50 | return 51 | 52 | def _set_prediction_head(self, *args, **kwargs): 53 | """ No 'prediction head' for downstream tasks.""" 54 | return 55 | 56 | def forward(self, img): 57 | """ 58 | img if of size batch_size x 3 x h x w 59 | """ 60 | B, C, H, W = img.size() 61 | img_info = {'height': H, 'width': W} 62 | need_all_layers = hasattr(self.head, 'return_all_blocks') and self.head.return_all_blocks 63 | out, _, _ = self._encode_image(img, do_mask=False, return_all_blocks=need_all_layers) 64 | return self.head(out, img_info) 65 | 66 | 67 | class CroCoDownstreamBinocular(CroCoNet): 68 | 69 | def __init__(self, 70 | head, 71 | **kwargs): 72 | """ Build network for binocular downstream task 73 | It takes an extra argument head, that is called with the features 74 | and a dictionary img_info containing 'width' and 'height' keys 75 | The head is setup with the croconet arguments in this init function 76 | """ 77 | super(CroCoDownstreamBinocular, self).__init__(**kwargs) 78 | head.setup(self) 79 | self.head = head 80 | 81 | def _set_mask_generator(self, *args, **kwargs): 82 | """ No mask generator """ 83 | return 84 | 85 | def _set_mask_token(self, *args, **kwargs): 86 | """ No mask token """ 87 | self.mask_token = None 88 | return 89 | 90 | def _set_prediction_head(self, *args, **kwargs): 91 | """ No prediction head for downstream tasks, define your own head """ 92 | return 93 | 94 | def encode_image_pairs(self, img1, img2, return_all_blocks=False): 95 | """ run encoder for a pair of images 96 | it is actually ~5% faster to concatenate the images along the batch dimension 97 | than to encode them separately 98 | """ 99 | ## the two commented lines below is the naive version with separate encoding 100 | #out, pos, _ = self._encode_image(img1, do_mask=False, return_all_blocks=return_all_blocks) 101 | #out2, pos2, _ = self._encode_image(img2, do_mask=False, return_all_blocks=False) 102 | ## and now the faster version 103 | out, pos, _ = self._encode_image( torch.cat( (img1,img2), dim=0), do_mask=False, return_all_blocks=return_all_blocks ) 104 | if return_all_blocks: 105 | out,out2 = list(map(list, zip(*[o.chunk(2, dim=0) for o in out]))) 106 | out2 = out2[-1] 107 | else: 108 | out,out2 = out.chunk(2, dim=0) 109 | pos,pos2 = pos.chunk(2, dim=0) 110 | return out, out2, pos, pos2 111 | 112 | def forward(self, img1, img2): 113 | B, C, H, W = img1.size() 114 | img_info = {'height': H, 'width': W} 115 | return_all_blocks = hasattr(self.head, 'return_all_blocks') and self.head.return_all_blocks 116 | out, out2, pos, pos2 = self.encode_image_pairs(img1, img2, return_all_blocks=return_all_blocks) 117 | if return_all_blocks: 118 | decout = self._decoder(out[-1], pos, None, out2, pos2, return_all_blocks=return_all_blocks) 119 | decout = out+decout 120 | else: 121 | decout = self._decoder(out, pos, None, out2, 
pos2, return_all_blocks=return_all_blocks) 122 | return self.head(decout, img_info) -------------------------------------------------------------------------------- /third_party/dust3r/croco/models/curope/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (C) 2022-present Naver Corporation. All rights reserved. 2 | # Licensed under CC BY-NC-SA 4.0 (non-commercial use only). 3 | 4 | from .curope2d import cuRoPE2D 5 | -------------------------------------------------------------------------------- /third_party/dust3r/croco/models/curope/build/temp.linux-x86_64-cpython-38/build.ninja: -------------------------------------------------------------------------------- 1 | ninja_required_version = 1.3 2 | cxx = g++ 3 | nvcc = /usr/local/cuda/bin/nvcc 4 | 5 | cflags = -pthread -B /root/anaconda3/envs/3studio/compiler_compat -Wl,--sysroot=/ -Wsign-compare -DNDEBUG -g -fwrapv -O3 -Wall -Wstrict-prototypes -fPIC -I/root/anaconda3/envs/3studio/lib/python3.8/site-packages/torch/include -I/root/anaconda3/envs/3studio/lib/python3.8/site-packages/torch/include/torch/csrc/api/include -I/root/anaconda3/envs/3studio/lib/python3.8/site-packages/torch/include/TH -I/root/anaconda3/envs/3studio/lib/python3.8/site-packages/torch/include/THC -I/usr/local/cuda/include -I/root/anaconda3/envs/3studio/include/python3.8 -c 6 | post_cflags = -O3 -DTORCH_API_INCLUDE_EXTENSION_H '-DPYBIND11_COMPILER_TYPE="_gcc"' '-DPYBIND11_STDLIB="_libstdcpp"' '-DPYBIND11_BUILD_ABI="_cxxabi1011"' -DTORCH_EXTENSION_NAME=curope -D_GLIBCXX_USE_CXX11_ABI=0 -std=c++17 7 | cuda_cflags = -I/root/anaconda3/envs/3studio/lib/python3.8/site-packages/torch/include -I/root/anaconda3/envs/3studio/lib/python3.8/site-packages/torch/include/torch/csrc/api/include -I/root/anaconda3/envs/3studio/lib/python3.8/site-packages/torch/include/TH -I/root/anaconda3/envs/3studio/lib/python3.8/site-packages/torch/include/THC -I/usr/local/cuda/include -I/root/anaconda3/envs/3studio/include/python3.8 -c 8 | cuda_post_cflags = -D__CUDA_NO_HALF_OPERATORS__ -D__CUDA_NO_HALF_CONVERSIONS__ -D__CUDA_NO_BFLOAT16_CONVERSIONS__ -D__CUDA_NO_HALF2_OPERATORS__ --expt-relaxed-constexpr --compiler-options ''"'"'-fPIC'"'"'' -O3 --ptxas-options=-v --use_fast_math -gencode arch=compute_37,code=sm_37 -gencode arch=compute_50,code=sm_50 -gencode arch=compute_60,code=sm_60 -gencode arch=compute_70,code=sm_70 -gencode arch=compute_75,code=sm_75 -gencode arch=compute_80,code=sm_80 -gencode arch=compute_86,code=sm_86 -DTORCH_API_INCLUDE_EXTENSION_H '-DPYBIND11_COMPILER_TYPE="_gcc"' '-DPYBIND11_STDLIB="_libstdcpp"' '-DPYBIND11_BUILD_ABI="_cxxabi1011"' -DTORCH_EXTENSION_NAME=curope -D_GLIBCXX_USE_CXX11_ABI=0 -std=c++17 9 | cuda_dlink_post_cflags = 10 | ldflags = 11 | 12 | rule compile 13 | command = $cxx -MMD -MF $out.d $cflags -c $in -o $out $post_cflags 14 | depfile = $out.d 15 | deps = gcc 16 | 17 | rule cuda_compile 18 | depfile = $out.d 19 | deps = gcc 20 | command = $nvcc $cuda_cflags -c $in -o $out $cuda_post_cflags 21 | 22 | 23 | 24 | 25 | 26 | build /apdcephfs_cq10/share_1290939/karmyu/dust3r-gaussian-splatting/dust3r/croco/models/curope/build/temp.linux-x86_64-cpython-38/curope.o: compile /apdcephfs_cq10/share_1290939/karmyu/dust3r-gaussian-splatting/dust3r/croco/models/curope/curope.cpp 27 | build /apdcephfs_cq10/share_1290939/karmyu/dust3r-gaussian-splatting/dust3r/croco/models/curope/build/temp.linux-x86_64-cpython-38/kernels.o: cuda_compile 
/apdcephfs_cq10/share_1290939/karmyu/dust3r-gaussian-splatting/dust3r/croco/models/curope/kernels.cu 28 | 29 | 30 | 31 | 32 | 33 | 34 | 35 | -------------------------------------------------------------------------------- /third_party/dust3r/croco/models/curope/curope.cpp: -------------------------------------------------------------------------------- 1 | /* 2 | Copyright (C) 2022-present Naver Corporation. All rights reserved. 3 | Licensed under CC BY-NC-SA 4.0 (non-commercial use only). 4 | */ 5 | 6 | #include 7 | 8 | // forward declaration 9 | void rope_2d_cuda( torch::Tensor tokens, const torch::Tensor pos, const float base, const float fwd ); 10 | 11 | void rope_2d_cpu( torch::Tensor tokens, const torch::Tensor positions, const float base, const float fwd ) 12 | { 13 | const int B = tokens.size(0); 14 | const int N = tokens.size(1); 15 | const int H = tokens.size(2); 16 | const int D = tokens.size(3) / 4; 17 | 18 | auto tok = tokens.accessor(); 19 | auto pos = positions.accessor(); 20 | 21 | for (int b = 0; b < B; b++) { 22 | for (int x = 0; x < 2; x++) { // y and then x (2d) 23 | for (int n = 0; n < N; n++) { 24 | 25 | // grab the token position 26 | const int p = pos[b][n][x]; 27 | 28 | for (int h = 0; h < H; h++) { 29 | for (int d = 0; d < D; d++) { 30 | // grab the two values 31 | float u = tok[b][n][h][d+0+x*2*D]; 32 | float v = tok[b][n][h][d+D+x*2*D]; 33 | 34 | // grab the cos,sin 35 | const float inv_freq = fwd * p / powf(base, d/float(D)); 36 | float c = cosf(inv_freq); 37 | float s = sinf(inv_freq); 38 | 39 | // write the result 40 | tok[b][n][h][d+0+x*2*D] = u*c - v*s; 41 | tok[b][n][h][d+D+x*2*D] = v*c + u*s; 42 | } 43 | } 44 | } 45 | } 46 | } 47 | } 48 | 49 | void rope_2d( torch::Tensor tokens, // B,N,H,D 50 | const torch::Tensor positions, // B,N,2 51 | const float base, 52 | const float fwd ) 53 | { 54 | TORCH_CHECK(tokens.dim() == 4, "tokens must have 4 dimensions"); 55 | TORCH_CHECK(positions.dim() == 3, "positions must have 3 dimensions"); 56 | TORCH_CHECK(tokens.size(0) == positions.size(0), "batch size differs between tokens & positions"); 57 | TORCH_CHECK(tokens.size(1) == positions.size(1), "seq_length differs between tokens & positions"); 58 | TORCH_CHECK(positions.size(2) == 2, "positions.shape[2] must be equal to 2"); 59 | TORCH_CHECK(tokens.is_cuda() == positions.is_cuda(), "tokens and positions are not on the same device" ); 60 | 61 | if (tokens.is_cuda()) 62 | rope_2d_cuda( tokens, positions, base, fwd ); 63 | else 64 | rope_2d_cpu( tokens, positions, base, fwd ); 65 | } 66 | 67 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 68 | m.def("rope_2d", &rope_2d, "RoPE 2d forward/backward"); 69 | } 70 | -------------------------------------------------------------------------------- /third_party/dust3r/croco/models/curope/curope2d.py: -------------------------------------------------------------------------------- 1 | # Copyright (C) 2022-present Naver Corporation. All rights reserved. 2 | # Licensed under CC BY-NC-SA 4.0 (non-commercial use only). 3 | 4 | import torch 5 | 6 | try: 7 | import curope as _kernels # run `python setup.py install` 8 | except ModuleNotFoundError: 9 | from . 
import curope as _kernels # run `python setup.py build_ext --inplace` 10 | 11 | 12 | class cuRoPE2D_func (torch.autograd.Function): 13 | 14 | @staticmethod 15 | def forward(ctx, tokens, positions, base, F0=1): 16 | ctx.save_for_backward(positions) 17 | ctx.saved_base = base 18 | ctx.saved_F0 = F0 19 | # tokens = tokens.clone() # uncomment this if inplace doesn't work 20 | _kernels.rope_2d( tokens, positions, base, F0 ) 21 | ctx.mark_dirty(tokens) 22 | return tokens 23 | 24 | @staticmethod 25 | def backward(ctx, grad_res): 26 | positions, base, F0 = ctx.saved_tensors[0], ctx.saved_base, ctx.saved_F0 27 | _kernels.rope_2d( grad_res, positions, base, -F0 ) 28 | ctx.mark_dirty(grad_res) 29 | return grad_res, None, None, None 30 | 31 | 32 | class cuRoPE2D(torch.nn.Module): 33 | def __init__(self, freq=100.0, F0=1.0): 34 | super().__init__() 35 | self.base = freq 36 | self.F0 = F0 37 | 38 | def forward(self, tokens, positions): 39 | cuRoPE2D_func.apply( tokens.transpose(1,2), positions, self.base, self.F0 ) 40 | return tokens -------------------------------------------------------------------------------- /third_party/dust3r/croco/models/curope/kernels.cu: -------------------------------------------------------------------------------- 1 | /* 2 | Copyright (C) 2022-present Naver Corporation. All rights reserved. 3 | Licensed under CC BY-NC-SA 4.0 (non-commercial use only). 4 | */ 5 | 6 | #include 7 | #include 8 | #include 9 | #include 10 | 11 | #define CHECK_CUDA(tensor) {\ 12 | TORCH_CHECK((tensor).is_cuda(), #tensor " is not in cuda memory"); \ 13 | TORCH_CHECK((tensor).is_contiguous(), #tensor " is not contiguous"); } 14 | void CHECK_KERNEL() {auto error = cudaGetLastError(); TORCH_CHECK( error == cudaSuccess, cudaGetErrorString(error));} 15 | 16 | 17 | template < typename scalar_t > 18 | __global__ void rope_2d_cuda_kernel( 19 | //scalar_t* __restrict__ tokens, 20 | torch::PackedTensorAccessor32 tokens, 21 | const int64_t* __restrict__ pos, 22 | const float base, 23 | const float fwd ) 24 | // const int N, const int H, const int D ) 25 | { 26 | // tokens shape = (B, N, H, D) 27 | const int N = tokens.size(1); 28 | const int H = tokens.size(2); 29 | const int D = tokens.size(3); 30 | 31 | // each block update a single token, for all heads 32 | // each thread takes care of a single output 33 | extern __shared__ float shared[]; 34 | float* shared_inv_freq = shared + D; 35 | 36 | const int b = blockIdx.x / N; 37 | const int n = blockIdx.x % N; 38 | 39 | const int Q = D / 4; 40 | // one token = [0..Q : Q..2Q : 2Q..3Q : 3Q..D] 41 | // u_Y v_Y u_X v_X 42 | 43 | // shared memory: first, compute inv_freq 44 | if (threadIdx.x < Q) 45 | shared_inv_freq[threadIdx.x] = fwd / powf(base, threadIdx.x/float(Q)); 46 | __syncthreads(); 47 | 48 | // start of X or Y part 49 | const int X = threadIdx.x < D/2 ? 
0 : 1; 50 | const int m = (X*D/2) + (threadIdx.x % Q); // index of u_Y or u_X 51 | 52 | // grab the cos,sin appropriate for me 53 | const float freq = pos[blockIdx.x*2+X] * shared_inv_freq[threadIdx.x % Q]; 54 | const float cos = cosf(freq); 55 | const float sin = sinf(freq); 56 | /* 57 | float* shared_cos_sin = shared + D + D/4; 58 | if ((threadIdx.x % (D/2)) < Q) 59 | shared_cos_sin[m+0] = cosf(freq); 60 | else 61 | shared_cos_sin[m+Q] = sinf(freq); 62 | __syncthreads(); 63 | const float cos = shared_cos_sin[m+0]; 64 | const float sin = shared_cos_sin[m+Q]; 65 | */ 66 | 67 | for (int h = 0; h < H; h++) 68 | { 69 | // then, load all the token for this head in shared memory 70 | shared[threadIdx.x] = tokens[b][n][h][threadIdx.x]; 71 | __syncthreads(); 72 | 73 | const float u = shared[m]; 74 | const float v = shared[m+Q]; 75 | 76 | // write output 77 | if ((threadIdx.x % (D/2)) < Q) 78 | tokens[b][n][h][threadIdx.x] = u*cos - v*sin; 79 | else 80 | tokens[b][n][h][threadIdx.x] = v*cos + u*sin; 81 | } 82 | } 83 | 84 | void rope_2d_cuda( torch::Tensor tokens, const torch::Tensor pos, const float base, const float fwd ) 85 | { 86 | const int B = tokens.size(0); // batch size 87 | const int N = tokens.size(1); // sequence length 88 | const int H = tokens.size(2); // number of heads 89 | const int D = tokens.size(3); // dimension per head 90 | 91 | TORCH_CHECK(tokens.stride(3) == 1 && tokens.stride(2) == D, "tokens are not contiguous"); 92 | TORCH_CHECK(pos.is_contiguous(), "positions are not contiguous"); 93 | TORCH_CHECK(pos.size(0) == B && pos.size(1) == N && pos.size(2) == 2, "bad pos.shape"); 94 | TORCH_CHECK(D % 4 == 0, "token dim must be multiple of 4"); 95 | 96 | // one block for each layer, one thread per local-max 97 | const int THREADS_PER_BLOCK = D; 98 | const int N_BLOCKS = B * N; // each block takes care of H*D values 99 | const int SHARED_MEM = sizeof(float) * (D + D/4); 100 | 101 | AT_DISPATCH_FLOATING_TYPES_AND_HALF(tokens.type(), "rope_2d_cuda", ([&] { 102 | rope_2d_cuda_kernel <<>> ( 103 | //tokens.data_ptr(), 104 | tokens.packed_accessor32(), 105 | pos.data_ptr(), 106 | base, fwd); //, N, H, D ); 107 | })); 108 | } 109 | -------------------------------------------------------------------------------- /third_party/dust3r/croco/models/curope/setup.py: -------------------------------------------------------------------------------- 1 | # Copyright (C) 2022-present Naver Corporation. All rights reserved. 2 | # Licensed under CC BY-NC-SA 4.0 (non-commercial use only). 
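# How this extension is typically built (a sketch; both commands are the ones
# referenced in curope/__init__.py and curope2d.py, run from this directory):
#
#   python setup.py install              # install into the active environment
#   python setup.py build_ext --inplace  # or build in place for a local import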
3 | 4 | from setuptools import setup 5 | from torch import cuda 6 | from torch.utils.cpp_extension import BuildExtension, CUDAExtension 7 | 8 | # compile for all possible CUDA architectures 9 | all_cuda_archs = cuda.get_gencode_flags().replace('compute=','arch=').split() 10 | # alternatively, you can list cuda archs that you want, eg: 11 | # all_cuda_archs = [ 12 | # '-gencode', 'arch=compute_70,code=sm_70', 13 | # '-gencode', 'arch=compute_75,code=sm_75', 14 | # '-gencode', 'arch=compute_80,code=sm_80', 15 | # '-gencode', 'arch=compute_86,code=sm_86' 16 | # ] 17 | 18 | setup( 19 | name = 'curope', 20 | ext_modules = [ 21 | CUDAExtension( 22 | name='curope', 23 | sources=[ 24 | "curope.cpp", 25 | "kernels.cu", 26 | ], 27 | extra_compile_args = dict( 28 | nvcc=['-O3','--ptxas-options=-v',"--use_fast_math"]+all_cuda_archs, 29 | cxx=['-O3']) 30 | ) 31 | ], 32 | cmdclass = { 33 | 'build_ext': BuildExtension 34 | }) 35 | -------------------------------------------------------------------------------- /third_party/dust3r/croco/models/head_downstream.py: -------------------------------------------------------------------------------- 1 | # Copyright (C) 2022-present Naver Corporation. All rights reserved. 2 | # Licensed under CC BY-NC-SA 4.0 (non-commercial use only). 3 | 4 | # -------------------------------------------------------- 5 | # Heads for downstream tasks 6 | # -------------------------------------------------------- 7 | 8 | """ 9 | A head is a module where the __init__ defines only the head hyperparameters. 10 | A method setup(croconet) takes a CroCoNet and set all layers according to the head and croconet attributes. 11 | The forward takes the features as well as a dictionary img_info containing the keys 'width' and 'height' 12 | """ 13 | 14 | import torch 15 | import torch.nn as nn 16 | from .dpt_block import DPTOutputAdapter 17 | 18 | 19 | class PixelwiseTaskWithDPT(nn.Module): 20 | """ DPT module for CroCo. 
21 | by default, hooks_idx will be equal to: 22 | * for encoder-only: 4 equally spread layers 23 | * for encoder+decoder: last encoder + 3 equally spread layers of the decoder 24 | """ 25 | 26 | def __init__(self, *, hooks_idx=None, layer_dims=[96,192,384,768], 27 | output_width_ratio=1, num_channels=1, postprocess=None, **kwargs): 28 | super(PixelwiseTaskWithDPT, self).__init__() 29 | self.return_all_blocks = True # backbone needs to return all layers 30 | self.postprocess = postprocess 31 | self.output_width_ratio = output_width_ratio 32 | self.num_channels = num_channels 33 | self.hooks_idx = hooks_idx 34 | self.layer_dims = layer_dims 35 | 36 | def setup(self, croconet): 37 | dpt_args = {'output_width_ratio': self.output_width_ratio, 'num_channels': self.num_channels} 38 | if self.hooks_idx is None: 39 | if hasattr(croconet, 'dec_blocks'): # encoder + decoder 40 | step = {8: 3, 12: 4, 24: 8}[croconet.dec_depth] 41 | hooks_idx = [croconet.dec_depth+croconet.enc_depth-1-i*step for i in range(3,-1,-1)] 42 | else: # encoder only 43 | step = croconet.enc_depth//4 44 | hooks_idx = [croconet.enc_depth-1-i*step for i in range(3,-1,-1)] 45 | self.hooks_idx = hooks_idx 46 | print(f' PixelwiseTaskWithDPT: automatically setting hook_idxs={self.hooks_idx}') 47 | dpt_args['hooks'] = self.hooks_idx 48 | dpt_args['layer_dims'] = self.layer_dims 49 | self.dpt = DPTOutputAdapter(**dpt_args) 50 | dim_tokens = [croconet.enc_embed_dim if hook0: 36 | pos_embed = np.concatenate([np.zeros([n_cls_token, embed_dim]), pos_embed], axis=0) 37 | return pos_embed 38 | 39 | 40 | def get_2d_sincos_pos_embed_from_grid(embed_dim, grid): 41 | assert embed_dim % 2 == 0 42 | 43 | # use half of dimensions to encode grid_h 44 | emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) # (H*W, D/2) 45 | emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) # (H*W, D/2) 46 | 47 | emb = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D) 48 | return emb 49 | 50 | 51 | def get_1d_sincos_pos_embed_from_grid(embed_dim, pos): 52 | """ 53 | embed_dim: output dimension for each position 54 | pos: a list of positions to be encoded: size (M,) 55 | out: (M, D) 56 | """ 57 | assert embed_dim % 2 == 0 58 | omega = np.arange(embed_dim // 2, dtype=float) 59 | omega /= embed_dim / 2. 60 | omega = 1. 
/ 10000**omega # (D/2,) 61 | 62 | pos = pos.reshape(-1) # (M,) 63 | out = np.einsum('m,d->md', pos, omega) # (M, D/2), outer product 64 | 65 | emb_sin = np.sin(out) # (M, D/2) 66 | emb_cos = np.cos(out) # (M, D/2) 67 | 68 | emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D) 69 | return emb 70 | 71 | 72 | # -------------------------------------------------------- 73 | # Interpolate position embeddings for high-resolution 74 | # References: 75 | # MAE: https://github.com/facebookresearch/mae/blob/main/util/pos_embed.py 76 | # DeiT: https://github.com/facebookresearch/deit 77 | # -------------------------------------------------------- 78 | def interpolate_pos_embed(model, checkpoint_model): 79 | if 'pos_embed' in checkpoint_model: 80 | pos_embed_checkpoint = checkpoint_model['pos_embed'] 81 | embedding_size = pos_embed_checkpoint.shape[-1] 82 | num_patches = model.patch_embed.num_patches 83 | num_extra_tokens = model.pos_embed.shape[-2] - num_patches 84 | # height (== width) for the checkpoint position embedding 85 | orig_size = int((pos_embed_checkpoint.shape[-2] - num_extra_tokens) ** 0.5) 86 | # height (== width) for the new position embedding 87 | new_size = int(num_patches ** 0.5) 88 | # class_token and dist_token are kept unchanged 89 | if orig_size != new_size: 90 | print("Position interpolate from %dx%d to %dx%d" % (orig_size, orig_size, new_size, new_size)) 91 | extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens] 92 | # only the position tokens are interpolated 93 | pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:] 94 | pos_tokens = pos_tokens.reshape(-1, orig_size, orig_size, embedding_size).permute(0, 3, 1, 2) 95 | pos_tokens = torch.nn.functional.interpolate( 96 | pos_tokens, size=(new_size, new_size), mode='bicubic', align_corners=False) 97 | pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2) 98 | new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1) 99 | checkpoint_model['pos_embed'] = new_pos_embed 100 | 101 | 102 | #---------------------------------------------------------- 103 | # RoPE2D: RoPE implementation in 2D 104 | #---------------------------------------------------------- 105 | 106 | try: 107 | from models.curope import cuRoPE2D 108 | RoPE2D = cuRoPE2D 109 | except ImportError: 110 | print('Warning, cannot find cuda-compiled version of RoPE2D, using a slow pytorch version instead') 111 | 112 | class RoPE2D(torch.nn.Module): 113 | 114 | def __init__(self, freq=100.0, F0=1.0): 115 | super().__init__() 116 | self.base = freq 117 | self.F0 = F0 118 | self.cache = {} 119 | 120 | def get_cos_sin(self, D, seq_len, device, dtype): 121 | if (D,seq_len,device,dtype) not in self.cache: 122 | inv_freq = 1.0 / (self.base ** (torch.arange(0, D, 2).float().to(device) / D)) 123 | t = torch.arange(seq_len, device=device, dtype=inv_freq.dtype) 124 | freqs = torch.einsum("i,j->ij", t, inv_freq).to(dtype) 125 | freqs = torch.cat((freqs, freqs), dim=-1) 126 | cos = freqs.cos() # (Seq, Dim) 127 | sin = freqs.sin() 128 | self.cache[D,seq_len,device,dtype] = (cos,sin) 129 | return self.cache[D,seq_len,device,dtype] 130 | 131 | @staticmethod 132 | def rotate_half(x): 133 | x1, x2 = x[..., : x.shape[-1] // 2], x[..., x.shape[-1] // 2 :] 134 | return torch.cat((-x2, x1), dim=-1) 135 | 136 | def apply_rope1d(self, tokens, pos1d, cos, sin): 137 | assert pos1d.ndim==2 138 | cos = torch.nn.functional.embedding(pos1d, cos)[:, None, :, :] 139 | sin = torch.nn.functional.embedding(pos1d, sin)[:, None, :, :] 140 | return (tokens * cos) + (self.rotate_half(tokens) * sin) 
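    # Illustrative usage sketch for this fallback RoPE2D (shapes follow the
    # forward() docstring below; the concrete numbers are hypothetical):
    #
    #   rope = RoPE2D(freq=100.0)
    #   tokens = torch.randn(2, 4, 9, 64)           # B=2, 4 heads, 3x3 grid, dim 64
    #   ys, xs = torch.meshgrid(torch.arange(3), torch.arange(3), indexing='ij')
    #   positions = torch.stack([ys.flatten(), xs.flatten()], -1).expand(2, 9, 2)
    #   out = rope(tokens, positions)               # same shape, y/x halves rotated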
141 | 142 | def forward(self, tokens, positions): 143 | """ 144 | input: 145 | * tokens: batch_size x nheads x ntokens x dim 146 | * positions: batch_size x ntokens x 2 (y and x position of each token) 147 | output: 148 | * tokens after appplying RoPE2D (batch_size x nheads x ntokens x dim) 149 | """ 150 | assert tokens.size(3)%2==0, "number of dimensions should be a multiple of two" 151 | D = tokens.size(3) // 2 152 | assert positions.ndim==3 and positions.shape[-1] == 2 # Batch, Seq, 2 153 | cos, sin = self.get_cos_sin(D, int(positions.max())+1, tokens.device, tokens.dtype) 154 | # split features into two along the feature dimension, and apply rope1d on each half 155 | y, x = tokens.chunk(2, dim=-1) 156 | y = self.apply_rope1d(y, positions[:,:,0], cos, sin) 157 | x = self.apply_rope1d(x, positions[:,:,1], cos, sin) 158 | tokens = torch.cat((y, x), dim=-1) 159 | return tokens -------------------------------------------------------------------------------- /third_party/dust3r/croco/stereoflow/download_model.sh: -------------------------------------------------------------------------------- 1 | # Copyright (C) 2022-present Naver Corporation. All rights reserved. 2 | # Licensed under CC BY-NC-SA 4.0 (non-commercial use only). 3 | 4 | model=$1 5 | outfile="stereoflow_models/${model}" 6 | if [[ ! -f $outfile ]] 7 | then 8 | mkdir -p stereoflow_models/; 9 | wget https://download.europe.naverlabs.com/ComputerVision/CroCo/StereoFlow_models/$1 -P stereoflow_models/; 10 | else 11 | echo "Model ${model} already downloaded in ${outfile}." 12 | fi -------------------------------------------------------------------------------- /third_party/dust3r/datasets_preprocess/path_to_root.py: -------------------------------------------------------------------------------- 1 | # Copyright (C) 2024-present Naver Corporation. All rights reserved. 2 | # Licensed under CC BY-NC-SA 4.0 (non-commercial use only). 3 | # 4 | # -------------------------------------------------------- 5 | # DUSt3R repo root import 6 | # -------------------------------------------------------- 7 | 8 | import sys 9 | import os.path as path 10 | HERE_PATH = path.normpath(path.dirname(__file__)) 11 | DUST3R_REPO_PATH = path.normpath(path.join(HERE_PATH, '../')) 12 | # workaround for sibling import 13 | sys.path.insert(0, DUST3R_REPO_PATH) 14 | -------------------------------------------------------------------------------- /third_party/dust3r/dust3r/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (C) 2024-present Naver Corporation. All rights reserved. 2 | # Licensed under CC BY-NC-SA 4.0 (non-commercial use only). 3 | -------------------------------------------------------------------------------- /third_party/dust3r/dust3r/cloud_opt/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (C) 2024-present Naver Corporation. All rights reserved. 2 | # Licensed under CC BY-NC-SA 4.0 (non-commercial use only). 
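# Typical call pattern for the global_aligner() wrapper defined below (a sketch:
# inference() lives in dust3r/inference.py and compute_global_alignment() in
# base_opt.py, neither shown here, so the exact argument values are illustrative):
#
#   output = inference(pairs, model, device, batch_size=1)
#   scene = global_aligner(output, device=device,
#                          mode=GlobalAlignerMode.PointCloudOptimizer)
#   scene.compute_global_alignment(init='mst', niter=300, schedule='cosine', lr=0.01)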
3 | # 4 | # -------------------------------------------------------- 5 | # global alignment optimization wrapper function 6 | # -------------------------------------------------------- 7 | from enum import Enum 8 | 9 | from .optimizer import PointCloudOptimizer 10 | from .pair_viewer import PairViewer 11 | 12 | 13 | class GlobalAlignerMode(Enum): 14 | PointCloudOptimizer = "PointCloudOptimizer" 15 | PairViewer = "PairViewer" 16 | 17 | 18 | def global_aligner(dust3r_output, device, mode=GlobalAlignerMode.PointCloudOptimizer, **optim_kw): 19 | # extract all inputs 20 | view1, view2, pred1, pred2 = [dust3r_output[k] for k in 'view1 view2 pred1 pred2'.split()] 21 | # build the optimizer 22 | if mode == GlobalAlignerMode.PointCloudOptimizer: 23 | net = PointCloudOptimizer(view1, view2, pred1, pred2, **optim_kw).to(device) 24 | elif mode == GlobalAlignerMode.PairViewer: 25 | net = PairViewer(view1, view2, pred1, pred2, **optim_kw).to(device) 26 | else: 27 | raise NotImplementedError(f'Unknown mode {mode}') 28 | 29 | return net 30 | -------------------------------------------------------------------------------- /third_party/dust3r/dust3r/cloud_opt/commons.py: -------------------------------------------------------------------------------- 1 | # Copyright (C) 2024-present Naver Corporation. All rights reserved. 2 | # Licensed under CC BY-NC-SA 4.0 (non-commercial use only). 3 | # 4 | # -------------------------------------------------------- 5 | # utility functions for global alignment 6 | # -------------------------------------------------------- 7 | import torch 8 | import torch.nn as nn 9 | import numpy as np 10 | 11 | 12 | def edge_str(i, j): 13 | return f'{i}_{j}' 14 | 15 | 16 | def i_j_ij(ij): 17 | return edge_str(*ij), ij 18 | 19 | 20 | def edge_conf(conf_i, conf_j, edge): 21 | return float(conf_i[edge].mean() * conf_j[edge].mean()) 22 | 23 | 24 | def compute_edge_scores(edges, conf_i, conf_j): 25 | return {(i, j): edge_conf(conf_i, conf_j, e) for e, (i, j) in edges} 26 | 27 | 28 | def NoGradParamDict(x): 29 | assert isinstance(x, dict) 30 | return nn.ParameterDict(x).requires_grad_(False) 31 | 32 | 33 | def get_imshapes(edges, pred_i, pred_j): 34 | n_imgs = max(max(e) for e in edges) + 1 35 | imshapes = [None] * n_imgs 36 | for e, (i, j) in enumerate(edges): 37 | shape_i = tuple(pred_i[e].shape[0:2]) 38 | shape_j = tuple(pred_j[e].shape[0:2]) 39 | if imshapes[i]: 40 | assert imshapes[i] == shape_i, f'incorrect shape for image {i}' 41 | if imshapes[j]: 42 | assert imshapes[j] == shape_j, f'incorrect shape for image {j}' 43 | imshapes[i] = shape_i 44 | imshapes[j] = shape_j 45 | return imshapes 46 | 47 | 48 | def get_conf_trf(mode): 49 | if mode == 'log': 50 | def conf_trf(x): return x.log() 51 | elif mode == 'sqrt': 52 | def conf_trf(x): return x.sqrt() 53 | elif mode == 'm1': 54 | def conf_trf(x): return x-1 55 | elif mode in ('id', 'none'): 56 | def conf_trf(x): return x 57 | else: 58 | raise ValueError(f'bad mode for {mode=}') 59 | return conf_trf 60 | 61 | 62 | def l2_dist(a, b, weight): 63 | return ((a - b).square().sum(dim=-1) * weight) 64 | 65 | 66 | def l1_dist(a, b, weight): 67 | return ((a - b).norm(dim=-1) * weight) 68 | 69 | 70 | ALL_DISTS = dict(l1=l1_dist, l2=l2_dist) 71 | 72 | 73 | def signed_log1p(x): 74 | sign = torch.sign(x) 75 | return sign * torch.log1p(torch.abs(x)) 76 | 77 | 78 | def signed_expm1(x): 79 | sign = torch.sign(x) 80 | return sign * torch.expm1(torch.abs(x)) 81 | 82 | 83 | def cosine_schedule(t, lr_start, lr_end): 84 | assert 0 <= t <= 1 85 | return lr_end 
+ (lr_start - lr_end) * (1+np.cos(t * np.pi))/2 86 | 87 | 88 | def linear_schedule(t, lr_start, lr_end): 89 | assert 0 <= t <= 1 90 | return lr_start + (lr_end - lr_start) * t 91 | -------------------------------------------------------------------------------- /third_party/dust3r/dust3r/cloud_opt/pair_viewer.py: -------------------------------------------------------------------------------- 1 | # Copyright (C) 2024-present Naver Corporation. All rights reserved. 2 | # Licensed under CC BY-NC-SA 4.0 (non-commercial use only). 3 | # 4 | # -------------------------------------------------------- 5 | # Dummy optimizer for visualizing pairs 6 | # -------------------------------------------------------- 7 | import numpy as np 8 | import torch 9 | import torch.nn as nn 10 | import cv2 11 | 12 | from dust3r.cloud_opt.base_opt import BasePCOptimizer 13 | from dust3r.utils.geometry import inv, geotrf, depthmap_to_absolute_camera_coordinates 14 | from dust3r.cloud_opt.commons import edge_str 15 | from dust3r.post_process import estimate_focal_knowing_depth 16 | 17 | 18 | class PairViewer (BasePCOptimizer): 19 | """ 20 | This a Dummy Optimizer. 21 | To use only when the goal is to visualize the results for a pair of images (with is_symmetrized) 22 | """ 23 | 24 | def __init__(self, *args, **kwargs): 25 | super().__init__(*args, **kwargs) 26 | assert self.is_symmetrized and self.n_edges == 2 27 | self.has_im_poses = True 28 | 29 | # compute all parameters directly from raw input 30 | self.focals = [] 31 | self.pp = [] 32 | rel_poses = [] 33 | confs = [] 34 | for i in range(self.n_imgs): 35 | conf = float(self.conf_i[edge_str(i, 1-i)].mean() * self.conf_j[edge_str(i, 1-i)].mean()) 36 | print(f' - {conf=:.3} for edge {i}-{1-i}') 37 | confs.append(conf) 38 | 39 | H, W = self.imshapes[i] 40 | pts3d = self.pred_i[edge_str(i, 1-i)] 41 | pp = torch.tensor((W/2, H/2)) 42 | focal = float(estimate_focal_knowing_depth(pts3d[None], pp, focal_mode='weiszfeld')) 43 | self.focals.append(focal) 44 | self.pp.append(pp) 45 | 46 | # estimate the pose of pts1 in image 2 47 | pixels = np.mgrid[:W, :H].T.astype(np.float32) 48 | pts3d = self.pred_j[edge_str(1-i, i)].numpy() 49 | assert pts3d.shape[:2] == (H, W) 50 | msk = self.get_masks()[i].numpy() 51 | K = np.float32([(focal, 0, pp[0]), (0, focal, pp[1]), (0, 0, 1)]) 52 | 53 | try: 54 | res = cv2.solvePnPRansac(pts3d[msk], pixels[msk], K, None, 55 | iterationsCount=100, reprojectionError=5, flags=cv2.SOLVEPNP_SQPNP) 56 | success, R, T, inliers = res 57 | assert success 58 | 59 | R = cv2.Rodrigues(R)[0] # world to cam 60 | pose = inv(np.r_[np.c_[R, T], [(0, 0, 0, 1)]]) # cam to world 61 | except: 62 | pose = np.eye(4) 63 | rel_poses.append(torch.from_numpy(pose.astype(np.float32))) 64 | 65 | # let's use the pair with the most confidence 66 | if confs[0] > confs[1]: 67 | # ptcloud is expressed in camera1 68 | self.im_poses = [torch.eye(4), rel_poses[1]] # I, cam2-to-cam1 69 | self.depth = [self.pred_i['0_1'][..., 2], geotrf(inv(rel_poses[1]), self.pred_j['0_1'])[..., 2]] 70 | else: 71 | # ptcloud is expressed in camera2 72 | self.im_poses = [rel_poses[0], torch.eye(4)] # I, cam1-to-cam2 73 | self.depth = [geotrf(inv(rel_poses[0]), self.pred_j['1_0'])[..., 2], self.pred_i['1_0'][..., 2]] 74 | 75 | self.im_poses = nn.Parameter(torch.stack(self.im_poses, dim=0), requires_grad=False) 76 | self.focals = nn.Parameter(torch.tensor(self.focals), requires_grad=False) 77 | self.pp = nn.Parameter(torch.stack(self.pp, dim=0), requires_grad=False) 78 | self.depth = 
nn.ParameterList(self.depth) 79 | for p in self.parameters(): 80 | p.requires_grad = False 81 | 82 | def _set_depthmap(self, idx, depth, force=False): 83 | print('_set_depthmap is ignored in PairViewer') 84 | return 85 | 86 | def get_depthmaps(self, raw=False): 87 | depth = [d.to(self.device) for d in self.depth] 88 | return depth 89 | 90 | def _set_focal(self, idx, focal, force=False): 91 | self.focals[idx] = focal 92 | 93 | def get_focals(self): 94 | return self.focals 95 | 96 | def get_known_focal_mask(self): 97 | return torch.tensor([not (p.requires_grad) for p in self.focals]) 98 | 99 | def get_principal_points(self): 100 | return self.pp 101 | 102 | def get_intrinsics(self): 103 | focals = self.get_focals() 104 | pps = self.get_principal_points() 105 | K = torch.zeros((len(focals), 3, 3), device=self.device) 106 | for i in range(len(focals)): 107 | K[i, 0, 0] = K[i, 1, 1] = focals[i] 108 | K[i, :2, 2] = pps[i] 109 | K[i, 2, 2] = 1 110 | return K 111 | 112 | def get_im_poses(self): 113 | return self.im_poses 114 | 115 | def depth_to_pts3d(self): 116 | pts3d = [] 117 | for d, intrinsics, im_pose in zip(self.depth, self.get_intrinsics(), self.get_im_poses()): 118 | pts, _ = depthmap_to_absolute_camera_coordinates(d.cpu().numpy(), 119 | intrinsics.cpu().numpy(), 120 | im_pose.cpu().numpy()) 121 | pts3d.append(torch.from_numpy(pts).to(device=self.device)) 122 | return pts3d 123 | 124 | def forward(self): 125 | return float('nan') 126 | -------------------------------------------------------------------------------- /third_party/dust3r/dust3r/datasets/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (C) 2024-present Naver Corporation. All rights reserved. 2 | # Licensed under CC BY-NC-SA 4.0 (non-commercial use only). 3 | from .utils.transforms import * 4 | from .base.batched_sampler import BatchedRandomSampler # noqa: F401 5 | from .co3d import Co3d # noqa: F401 6 | 7 | 8 | def get_data_loader(dataset, batch_size, num_workers=8, shuffle=True, drop_last=True, pin_mem=True): 9 | import torch 10 | from croco.utils.misc import get_world_size, get_rank 11 | 12 | # pytorch dataset 13 | if isinstance(dataset, str): 14 | dataset = eval(dataset) 15 | 16 | world_size = get_world_size() 17 | rank = get_rank() 18 | 19 | try: 20 | sampler = dataset.make_sampler(batch_size, shuffle=shuffle, world_size=world_size, 21 | rank=rank, drop_last=drop_last) 22 | except (AttributeError, NotImplementedError): 23 | # not avail for this dataset 24 | if torch.distributed.is_initialized(): 25 | sampler = torch.utils.data.DistributedSampler( 26 | dataset, num_replicas=world_size, rank=rank, shuffle=shuffle, drop_last=drop_last 27 | ) 28 | elif shuffle: 29 | sampler = torch.utils.data.RandomSampler(dataset) 30 | else: 31 | sampler = torch.utils.data.SequentialSampler(dataset) 32 | 33 | data_loader = torch.utils.data.DataLoader( 34 | dataset, 35 | sampler=sampler, 36 | batch_size=batch_size, 37 | num_workers=num_workers, 38 | pin_memory=pin_mem, 39 | drop_last=drop_last, 40 | ) 41 | 42 | return data_loader 43 | -------------------------------------------------------------------------------- /third_party/dust3r/dust3r/datasets/base/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (C) 2024-present Naver Corporation. All rights reserved. 2 | # Licensed under CC BY-NC-SA 4.0 (non-commercial use only). 
3 | -------------------------------------------------------------------------------- /third_party/dust3r/dust3r/datasets/base/batched_sampler.py: -------------------------------------------------------------------------------- 1 | # Copyright (C) 2024-present Naver Corporation. All rights reserved. 2 | # Licensed under CC BY-NC-SA 4.0 (non-commercial use only). 3 | # 4 | # -------------------------------------------------------- 5 | # Random sampling under a constraint 6 | # -------------------------------------------------------- 7 | import numpy as np 8 | import torch 9 | 10 | 11 | class BatchedRandomSampler: 12 | """ Random sampling under a constraint: each sample in the batch has the same feature, 13 | which is chosen randomly from a known pool of 'features' for each batch. 14 | 15 | For instance, the 'feature' could be the image aspect-ratio. 16 | 17 | The index returned is a tuple (sample_idx, feat_idx). 18 | This sampler ensures that each series of `batch_size` indices has the same `feat_idx`. 19 | """ 20 | 21 | def __init__(self, dataset, batch_size, pool_size, world_size=1, rank=0, drop_last=True): 22 | self.batch_size = batch_size 23 | self.pool_size = pool_size 24 | 25 | self.len_dataset = N = len(dataset) 26 | self.total_size = round_by(N, batch_size*world_size) if drop_last else N 27 | assert world_size == 1 or drop_last, 'must drop the last batch in distributed mode' 28 | 29 | # distributed sampler 30 | self.world_size = world_size 31 | self.rank = rank 32 | self.epoch = None 33 | 34 | def __len__(self): 35 | return self.total_size // self.world_size 36 | 37 | def set_epoch(self, epoch): 38 | self.epoch = epoch 39 | 40 | def __iter__(self): 41 | # prepare RNG 42 | if self.epoch is None: 43 | assert self.world_size == 1 and self.rank == 0, 'use set_epoch() if distributed mode is used' 44 | seed = int(torch.empty((), dtype=torch.int64).random_().item()) 45 | else: 46 | seed = self.epoch + 777 47 | rng = np.random.default_rng(seed=seed) 48 | 49 | # random indices (will restart from 0 if not drop_last) 50 | sample_idxs = np.arange(self.total_size) 51 | rng.shuffle(sample_idxs) 52 | 53 | # random feat_idxs (same across each batch) 54 | n_batches = (self.total_size+self.batch_size-1) // self.batch_size 55 | feat_idxs = rng.integers(self.pool_size, size=n_batches) 56 | feat_idxs = np.broadcast_to(feat_idxs[:, None], (n_batches, self.batch_size)) 57 | feat_idxs = feat_idxs.ravel()[:self.total_size] 58 | 59 | # put them together 60 | idxs = np.c_[sample_idxs, feat_idxs] # shape = (total_size, 2) 61 | 62 | # Distributed sampler: we select a subset of batches 63 | # make sure the slice for each node is aligned with batch_size 64 | size_per_proc = self.batch_size * ((self.total_size + self.world_size * 65 | self.batch_size-1) // (self.world_size * self.batch_size)) 66 | idxs = idxs[self.rank*size_per_proc: (self.rank+1)*size_per_proc] 67 | 68 | yield from (tuple(idx) for idx in idxs) 69 | 70 | 71 | def round_by(total, multiple, up=False): 72 | if up: 73 | total = total + multiple-1 74 | return (total//multiple) * multiple 75 | -------------------------------------------------------------------------------- /third_party/dust3r/dust3r/datasets/base/easy_dataset.py: -------------------------------------------------------------------------------- 1 | # Copyright (C) 2024-present Naver Corporation. All rights reserved. 2 | # Licensed under CC BY-NC-SA 4.0 (non-commercial use only). 
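# The EasyDataset base class below overloads *, @ and + so datasets can be
# duplicated, resized and concatenated inline (see its docstring).  A minimal
# sketch, assuming ds_a and ds_b are two already-constructed EasyDataset instances:
#
#   train_set = 2 * ds_a + 10_000 @ ds_b   # MulDataset + ResizedDataset -> CatDataset
#   train_set.set_epoch(0)                 # required before indexing a ResizedDataset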
3 | # 4 | # -------------------------------------------------------- 5 | # A dataset base class that you can easily resize and combine. 6 | # -------------------------------------------------------- 7 | import numpy as np 8 | from dust3r.datasets.base.batched_sampler import BatchedRandomSampler 9 | 10 | 11 | class EasyDataset: 12 | """ a dataset that you can easily resize and combine. 13 | Examples: 14 | --------- 15 | 2 * dataset ==> duplicate each element 2x 16 | 17 | 10 @ dataset ==> set the size to 10 (random sampling, duplicates if necessary) 18 | 19 | dataset1 + dataset2 ==> concatenate datasets 20 | """ 21 | 22 | def __add__(self, other): 23 | return CatDataset([self, other]) 24 | 25 | def __rmul__(self, factor): 26 | return MulDataset(factor, self) 27 | 28 | def __rmatmul__(self, factor): 29 | return ResizedDataset(factor, self) 30 | 31 | def set_epoch(self, epoch): 32 | pass # nothing to do by default 33 | 34 | def make_sampler(self, batch_size, shuffle=True, world_size=1, rank=0, drop_last=True): 35 | if not (shuffle): 36 | raise NotImplementedError() # cannot deal yet 37 | num_of_aspect_ratios = len(self._resolutions) 38 | return BatchedRandomSampler(self, batch_size, num_of_aspect_ratios, world_size=world_size, rank=rank, drop_last=drop_last) 39 | 40 | 41 | class MulDataset (EasyDataset): 42 | """ Artifically augmenting the size of a dataset. 43 | """ 44 | multiplicator: int 45 | 46 | def __init__(self, multiplicator, dataset): 47 | assert isinstance(multiplicator, int) and multiplicator > 0 48 | self.multiplicator = multiplicator 49 | self.dataset = dataset 50 | 51 | def __len__(self): 52 | return self.multiplicator * len(self.dataset) 53 | 54 | def __repr__(self): 55 | return f'{self.multiplicator}*{repr(self.dataset)}' 56 | 57 | def __getitem__(self, idx): 58 | if isinstance(idx, tuple): 59 | idx, other = idx 60 | return self.dataset[idx // self.multiplicator, other] 61 | else: 62 | return self.dataset[idx // self.multiplicator] 63 | 64 | @property 65 | def _resolutions(self): 66 | return self.dataset._resolutions 67 | 68 | 69 | class ResizedDataset (EasyDataset): 70 | """ Artifically changing the size of a dataset. 
71 | """ 72 | new_size: int 73 | 74 | def __init__(self, new_size, dataset): 75 | assert isinstance(new_size, int) and new_size > 0 76 | self.new_size = new_size 77 | self.dataset = dataset 78 | 79 | def __len__(self): 80 | return self.new_size 81 | 82 | def __repr__(self): 83 | size_str = str(self.new_size) 84 | for i in range((len(size_str)-1) // 3): 85 | sep = -4*i-3 86 | size_str = size_str[:sep] + '_' + size_str[sep:] 87 | return f'{size_str} @ {repr(self.dataset)}' 88 | 89 | def set_epoch(self, epoch): 90 | # this random shuffle only depends on the epoch 91 | rng = np.random.default_rng(seed=epoch+777) 92 | 93 | # shuffle all indices 94 | perm = rng.permutation(len(self.dataset)) 95 | 96 | # rotary extension until target size is met 97 | shuffled_idxs = np.concatenate([perm] * (1 + (len(self)-1) // len(self.dataset))) 98 | self._idxs_mapping = shuffled_idxs[:self.new_size] 99 | 100 | assert len(self._idxs_mapping) == self.new_size 101 | 102 | def __getitem__(self, idx): 103 | assert hasattr(self, '_idxs_mapping'), 'You need to call dataset.set_epoch() to use ResizedDataset.__getitem__()' 104 | if isinstance(idx, tuple): 105 | idx, other = idx 106 | return self.dataset[self._idxs_mapping[idx], other] 107 | else: 108 | return self.dataset[self._idxs_mapping[idx]] 109 | 110 | @property 111 | def _resolutions(self): 112 | return self.dataset._resolutions 113 | 114 | 115 | class CatDataset (EasyDataset): 116 | """ Concatenation of several datasets 117 | """ 118 | 119 | def __init__(self, datasets): 120 | for dataset in datasets: 121 | assert isinstance(dataset, EasyDataset) 122 | self.datasets = datasets 123 | self._cum_sizes = np.cumsum([len(dataset) for dataset in datasets]) 124 | 125 | def __len__(self): 126 | return self._cum_sizes[-1] 127 | 128 | def __repr__(self): 129 | # remove uselessly long transform 130 | return ' + '.join(repr(dataset).replace(',transform=Compose( ToTensor() Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)))', '') for dataset in self.datasets) 131 | 132 | def set_epoch(self, epoch): 133 | for dataset in self.datasets: 134 | dataset.set_epoch(epoch) 135 | 136 | def __getitem__(self, idx): 137 | other = None 138 | if isinstance(idx, tuple): 139 | idx, other = idx 140 | 141 | if not (0 <= idx < len(self)): 142 | raise IndexError() 143 | 144 | db_idx = np.searchsorted(self._cum_sizes, idx, 'right') 145 | dataset = self.datasets[db_idx] 146 | new_idx = idx - (self._cum_sizes[db_idx - 1] if db_idx > 0 else 0) 147 | 148 | if other is not None: 149 | new_idx = (new_idx, other) 150 | return dataset[new_idx] 151 | 152 | @property 153 | def _resolutions(self): 154 | resolutions = self.datasets[0]._resolutions 155 | for dataset in self.datasets[1:]: 156 | assert tuple(dataset._resolutions) == tuple(resolutions) 157 | return resolutions 158 | -------------------------------------------------------------------------------- /third_party/dust3r/dust3r/datasets/co3d.py: -------------------------------------------------------------------------------- 1 | # Copyright (C) 2024-present Naver Corporation. All rights reserved. 2 | # Licensed under CC BY-NC-SA 4.0 (non-commercial use only). 
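# Sizing note for the Co3d loader below: self.combinations keeps frame pairs (i, j)
# with 0 < |i - j| <= 30 and |i - j| a multiple of 5, so each 100-frame sequence
# contributes 95 + 90 + 85 + 80 + 75 + 70 = 495 pairs and
# len(dataset) == number_of_scenes * 495.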
3 | # 4 | # -------------------------------------------------------- 5 | # Dataloader for preprocessed Co3d_v2 6 | # dataset at https://github.com/facebookresearch/co3d - Creative Commons Attribution-NonCommercial 4.0 International 7 | # See datasets_preprocess/preprocess_co3d.py 8 | # -------------------------------------------------------- 9 | import os.path as osp 10 | import json 11 | import itertools 12 | from collections import deque 13 | 14 | import cv2 15 | import numpy as np 16 | 17 | from dust3r.datasets.base.base_stereo_view_dataset import BaseStereoViewDataset 18 | from dust3r.utils.image import imread_cv2 19 | 20 | 21 | class Co3d(BaseStereoViewDataset): 22 | def __init__(self, mask_bg=True, *args, ROOT, **kwargs): 23 | self.ROOT = ROOT 24 | super().__init__(*args, **kwargs) 25 | assert mask_bg in (True, False, 'rand') 26 | self.mask_bg = mask_bg 27 | 28 | # load all scenes 29 | with open(osp.join(self.ROOT, f'selected_seqs_{self.split}.json'), 'r') as f: 30 | self.scenes = json.load(f) 31 | self.scenes = {k: v for k, v in self.scenes.items() if len(v) > 0} 32 | self.scenes = {(k, k2): v2 for k, v in self.scenes.items() 33 | for k2, v2 in v.items()} 34 | self.scene_list = list(self.scenes.keys()) 35 | 36 | # for each scene, we have 100 images ==> 360 degrees (so 25 frames ~= 90 degrees) 37 | # we prepare all combinations such that i-j = +/- [5, 10, .., 90] degrees 38 | self.combinations = [(i, j) 39 | for i, j in itertools.combinations(range(100), 2) 40 | if 0 < abs(i-j) <= 30 and abs(i-j) % 5 == 0] 41 | 42 | self.invalidate = {scene: {} for scene in self.scene_list} 43 | 44 | def __len__(self): 45 | return len(self.scene_list) * len(self.combinations) 46 | 47 | def _get_views(self, idx, resolution, rng): 48 | # choose a scene 49 | obj, instance = self.scene_list[idx // len(self.combinations)] 50 | image_pool = self.scenes[obj, instance] 51 | im1_idx, im2_idx = self.combinations[idx % len(self.combinations)] 52 | 53 | # add a bit of randomness 54 | last = len(image_pool)-1 55 | 56 | if resolution not in self.invalidate[obj, instance]: # flag invalid images 57 | self.invalidate[obj, instance][resolution] = [False for _ in range(len(image_pool))] 58 | 59 | # decide now if we mask the bg 60 | mask_bg = (self.mask_bg == True) or (self.mask_bg == 'rand' and rng.choice(2)) 61 | 62 | views = [] 63 | imgs_idxs = [max(0, min(im_idx + rng.integers(-4, 5), last)) for im_idx in [im2_idx, im1_idx]] 64 | imgs_idxs = deque(imgs_idxs) 65 | while len(imgs_idxs) > 0: # some images (few) have zero depth 66 | im_idx = imgs_idxs.pop() 67 | 68 | if self.invalidate[obj, instance][resolution][im_idx]: 69 | # search for a valid image 70 | random_direction = 2 * rng.choice(2) - 1 71 | for offset in range(1, len(image_pool)): 72 | tentative_im_idx = (im_idx + (random_direction * offset)) % len(image_pool) 73 | if not self.invalidate[obj, instance][resolution][tentative_im_idx]: 74 | im_idx = tentative_im_idx 75 | break 76 | 77 | view_idx = image_pool[im_idx] 78 | 79 | impath = osp.join(self.ROOT, obj, instance, 'images', f'frame{view_idx:06n}.jpg') 80 | 81 | # load camera params 82 | input_metadata = np.load(impath.replace('jpg', 'npz')) 83 | camera_pose = input_metadata['camera_pose'].astype(np.float32) 84 | intrinsics = input_metadata['camera_intrinsics'].astype(np.float32) 85 | 86 | # load image and depth 87 | rgb_image = imread_cv2(impath) 88 | depthmap = imread_cv2(impath.replace('images', 'depths') + '.geometric.png', cv2.IMREAD_UNCHANGED) 89 | depthmap = (depthmap.astype(np.float32) / 65535) * 
np.nan_to_num(input_metadata['maximum_depth']) 90 | 91 | if mask_bg: 92 | # load object mask 93 | maskpath = osp.join(self.ROOT, obj, instance, 'masks', f'frame{view_idx:06n}.png') 94 | maskmap = imread_cv2(maskpath, cv2.IMREAD_UNCHANGED).astype(np.float32) 95 | maskmap = (maskmap / 255.0) > 0.1 96 | 97 | # update the depthmap with mask 98 | depthmap *= maskmap 99 | 100 | rgb_image, depthmap, intrinsics = self._crop_resize_if_necessary( 101 | rgb_image, depthmap, intrinsics, resolution, rng=rng, info=impath) 102 | 103 | num_valid = (depthmap > 0.0).sum() 104 | if num_valid == 0: 105 | # problem, invalidate image and retry 106 | self.invalidate[obj, instance][resolution][im_idx] = True 107 | imgs_idxs.append(im_idx) 108 | continue 109 | 110 | views.append(dict( 111 | img=rgb_image, 112 | depthmap=depthmap, 113 | camera_pose=camera_pose, 114 | camera_intrinsics=intrinsics, 115 | dataset='Co3d_v2', 116 | label=osp.join(obj, instance), 117 | instance=osp.split(impath)[1], 118 | )) 119 | return views 120 | 121 | 122 | if __name__ == "__main__": 123 | from dust3r.datasets.base.base_stereo_view_dataset import view_name 124 | from dust3r.viz import SceneViz, auto_cam_size 125 | from dust3r.utils.image import rgb 126 | 127 | dataset = Co3d(split='train', ROOT="data/co3d_subset_processed", resolution=224, aug_crop=16) 128 | 129 | for idx in np.random.permutation(len(dataset)): 130 | views = dataset[idx] 131 | assert len(views) == 2 132 | print(view_name(views[0]), view_name(views[1])) 133 | viz = SceneViz() 134 | poses = [views[view_idx]['camera_pose'] for view_idx in [0, 1]] 135 | cam_size = max(auto_cam_size(poses), 0.001) 136 | for view_idx in [0, 1]: 137 | pts3d = views[view_idx]['pts3d'] 138 | valid_mask = views[view_idx]['valid_mask'] 139 | colors = rgb(views[view_idx]['img']) 140 | viz.add_pointcloud(pts3d, colors, valid_mask) 141 | viz.add_camera(pose_c2w=views[view_idx]['camera_pose'], 142 | focal=views[view_idx]['camera_intrinsics'][0, 0], 143 | color=(idx*255, (1 - idx)*255, 0), 144 | image=colors, 145 | cam_size=cam_size) 146 | viz.show() 147 | -------------------------------------------------------------------------------- /third_party/dust3r/dust3r/datasets/utils/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (C) 2024-present Naver Corporation. All rights reserved. 2 | # Licensed under CC BY-NC-SA 4.0 (non-commercial use only). 3 | -------------------------------------------------------------------------------- /third_party/dust3r/dust3r/datasets/utils/cropping.py: -------------------------------------------------------------------------------- 1 | # Copyright (C) 2024-present Naver Corporation. All rights reserved. 2 | # Licensed under CC BY-NC-SA 4.0 (non-commercial use only). 3 | # 4 | # -------------------------------------------------------- 5 | # croppping utilities 6 | # -------------------------------------------------------- 7 | import PIL.Image 8 | import os 9 | os.environ["OPENCV_IO_ENABLE_OPENEXR"] = "1" 10 | import cv2 # noqa 11 | import numpy as np # noqa 12 | from dust3r.utils.geometry import colmap_to_opencv_intrinsics, opencv_to_colmap_intrinsics # noqa 13 | try: 14 | lanczos = PIL.Image.Resampling.LANCZOS 15 | except AttributeError: 16 | lanczos = PIL.Image.LANCZOS 17 | 18 | 19 | class ImageList: 20 | """ Convenience class to aply the same operation to a whole set of images. 
21 | """ 22 | 23 | def __init__(self, images): 24 | if not isinstance(images, (tuple, list, set)): 25 | images = [images] 26 | self.images = [] 27 | for image in images: 28 | if not isinstance(image, PIL.Image.Image): 29 | image = PIL.Image.fromarray(image) 30 | self.images.append(image) 31 | 32 | def __len__(self): 33 | return len(self.images) 34 | 35 | def to_pil(self): 36 | return tuple(self.images) if len(self.images) > 1 else self.images[0] 37 | 38 | @property 39 | def size(self): 40 | sizes = [im.size for im in self.images] 41 | assert all(sizes[0] == s for s in sizes) 42 | return sizes[0] 43 | 44 | def resize(self, *args, **kwargs): 45 | return ImageList(self._dispatch('resize', *args, **kwargs)) 46 | 47 | def crop(self, *args, **kwargs): 48 | return ImageList(self._dispatch('crop', *args, **kwargs)) 49 | 50 | def _dispatch(self, func, *args, **kwargs): 51 | return [getattr(im, func)(*args, **kwargs) for im in self.images] 52 | 53 | 54 | def rescale_image_depthmap(image, depthmap, camera_intrinsics, output_resolution): 55 | """ Jointly rescale a (image, depthmap) 56 | so that (out_width, out_height) >= output_res 57 | """ 58 | image = ImageList(image) 59 | input_resolution = np.array(image.size) # (W,H) 60 | output_resolution = np.array(output_resolution) 61 | if depthmap is not None: 62 | # can also use this with masks instead of depthmaps 63 | assert tuple(depthmap.shape[:2]) == image.size[::-1] 64 | assert output_resolution.shape == (2,) 65 | # define output resolution 66 | scale_final = max(output_resolution / image.size) + 1e-8 67 | output_resolution = np.floor(input_resolution * scale_final).astype(int) 68 | 69 | # first rescale the image so that it contains the crop 70 | image = image.resize(output_resolution, resample=lanczos) 71 | if depthmap is not None: 72 | depthmap = cv2.resize(depthmap, output_resolution, fx=scale_final, 73 | fy=scale_final, interpolation=cv2.INTER_NEAREST) 74 | 75 | # no offset here; simple rescaling 76 | camera_intrinsics = camera_matrix_of_crop( 77 | camera_intrinsics, input_resolution, output_resolution, scaling=scale_final) 78 | 79 | return image.to_pil(), depthmap, camera_intrinsics 80 | 81 | 82 | def camera_matrix_of_crop(input_camera_matrix, input_resolution, output_resolution, scaling=1, offset_factor=0.5, offset=None): 83 | # Margins to offset the origin 84 | margins = np.asarray(input_resolution) * scaling - output_resolution 85 | assert np.all(margins >= 0.0) 86 | if offset is None: 87 | offset = offset_factor * margins 88 | 89 | # Generate new camera parameters 90 | output_camera_matrix_colmap = opencv_to_colmap_intrinsics(input_camera_matrix) 91 | output_camera_matrix_colmap[:2, :] *= scaling 92 | output_camera_matrix_colmap[:2, 2] -= offset 93 | output_camera_matrix = colmap_to_opencv_intrinsics(output_camera_matrix_colmap) 94 | 95 | return output_camera_matrix 96 | 97 | 98 | def crop_image_depthmap(image, depthmap, camera_intrinsics, crop_bbox): 99 | """ 100 | Return a crop of the input view. 
101 | """ 102 | image = ImageList(image) 103 | l, t, r, b = crop_bbox 104 | 105 | image = image.crop((l, t, r, b)) 106 | depthmap = depthmap[t:b, l:r] 107 | 108 | camera_intrinsics = camera_intrinsics.copy() 109 | camera_intrinsics[0, 2] -= l 110 | camera_intrinsics[1, 2] -= t 111 | 112 | return image.to_pil(), depthmap, camera_intrinsics 113 | 114 | 115 | def bbox_from_intrinsics_in_out(input_camera_matrix, output_camera_matrix, output_resolution): 116 | out_width, out_height = output_resolution 117 | l, t = np.int32(np.round(input_camera_matrix[:2, 2] - output_camera_matrix[:2, 2])) 118 | crop_bbox = (l, t, l+out_width, t+out_height) 119 | return crop_bbox 120 | -------------------------------------------------------------------------------- /third_party/dust3r/dust3r/datasets/utils/transforms.py: -------------------------------------------------------------------------------- 1 | # Copyright (C) 2024-present Naver Corporation. All rights reserved. 2 | # Licensed under CC BY-NC-SA 4.0 (non-commercial use only). 3 | # 4 | # -------------------------------------------------------- 5 | # DUST3R default transforms 6 | # -------------------------------------------------------- 7 | import torchvision.transforms as tvf 8 | from dust3r.utils.image import ImgNorm 9 | 10 | # define the standard image transforms 11 | ColorJitter = tvf.Compose([tvf.ColorJitter(0.5, 0.5, 0.5, 0.1), ImgNorm]) 12 | -------------------------------------------------------------------------------- /third_party/dust3r/dust3r/heads/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (C) 2024-present Naver Corporation. All rights reserved. 2 | # Licensed under CC BY-NC-SA 4.0 (non-commercial use only). 3 | # 4 | # -------------------------------------------------------- 5 | # head factory 6 | # -------------------------------------------------------- 7 | from .linear_head import LinearPts3d 8 | from .dpt_head import create_dpt_head 9 | 10 | 11 | def head_factory(head_type, output_mode, net, has_conf=False): 12 | """" build a prediction head for the decoder 13 | """ 14 | if head_type == 'linear' and output_mode == 'pts3d': 15 | return LinearPts3d(net, has_conf) 16 | elif head_type == 'dpt' and output_mode == 'pts3d': 17 | return create_dpt_head(net, has_conf=has_conf) 18 | else: 19 | raise NotImplementedError(f"unexpected {head_type=} and {output_mode=}") 20 | -------------------------------------------------------------------------------- /third_party/dust3r/dust3r/heads/dpt_head.py: -------------------------------------------------------------------------------- 1 | # Copyright (C) 2024-present Naver Corporation. All rights reserved. 2 | # Licensed under CC BY-NC-SA 4.0 (non-commercial use only). 
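# Worked example for create_dpt_head() at the bottom of this file (the numbers are
# illustrative, e.g. a net with dec_depth=12, enc_embed_dim=1024, dec_embed_dim=768):
# hooks_idx = [0, 12*2//4, 12*3//4, 12] = [0, 6, 9, 12], i.e. the encoder output
# plus three evenly spaced decoder layers, with dim_tokens = [1024, 768, 768, 768]
# feeding the DPT adapter.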
3 | # 4 | # -------------------------------------------------------- 5 | # dpt head implementation for DUST3R 6 | # Downstream heads assume inputs of size B x N x C (where N is the number of tokens) ; 7 | # or if it takes as input the output at every layer, the attribute return_all_layers should be set to True 8 | # the forward function also takes as input a dictionnary img_info with key "height" and "width" 9 | # for PixelwiseTask, the output will be of dimension B x num_channels x H x W 10 | # -------------------------------------------------------- 11 | from einops import rearrange 12 | from typing import List 13 | import torch 14 | import torch.nn as nn 15 | from dust3r.heads.postprocess import postprocess 16 | import dust3r.utils.path_to_croco # noqa: F401 17 | from models.dpt_block import DPTOutputAdapter # noqa 18 | 19 | 20 | class DPTOutputAdapter_fix(DPTOutputAdapter): 21 | """ 22 | Adapt croco's DPTOutputAdapter implementation for dust3r: 23 | remove duplicated weigths, and fix forward for dust3r 24 | """ 25 | 26 | def init(self, dim_tokens_enc=768): 27 | super().init(dim_tokens_enc) 28 | # these are duplicated weights 29 | del self.act_1_postprocess 30 | del self.act_2_postprocess 31 | del self.act_3_postprocess 32 | del self.act_4_postprocess 33 | 34 | def forward(self, encoder_tokens: List[torch.Tensor], image_size=None): 35 | assert self.dim_tokens_enc is not None, 'Need to call init(dim_tokens_enc) function first' 36 | # H, W = input_info['image_size'] 37 | image_size = self.image_size if image_size is None else image_size 38 | H, W = image_size 39 | # Number of patches in height and width 40 | N_H = H // (self.stride_level * self.P_H) 41 | N_W = W // (self.stride_level * self.P_W) 42 | 43 | # Hook decoder onto 4 layers from specified ViT layers 44 | layers = [encoder_tokens[hook] for hook in self.hooks] 45 | 46 | # Extract only task-relevant tokens and ignore global tokens. 
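        # Shape sketch for the steps below (16x16 patches and stride_level == 1 are
        # assumed, illustrative values only): a 384x512 input gives N_H = 24 and
        # N_W = 32, so each hooked layer goes from (B, 24*32, C) tokens to a
        # (B, C, 24, 32) map before the act_postprocess / refinenet stages.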
47 | layers = [self.adapt_tokens(l) for l in layers] 48 | 49 | # Reshape tokens to spatial representation 50 | layers = [rearrange(l, 'b (nh nw) c -> b c nh nw', nh=N_H, nw=N_W) for l in layers] 51 | 52 | layers = [self.act_postprocess[idx](l) for idx, l in enumerate(layers)] 53 | # Project layers to chosen feature dim 54 | layers = [self.scratch.layer_rn[idx](l) for idx, l in enumerate(layers)] 55 | 56 | # Fuse layers using refinement stages 57 | path_4 = self.scratch.refinenet4(layers[3])[:, :, :layers[2].shape[2], :layers[2].shape[3]] 58 | path_3 = self.scratch.refinenet3(path_4, layers[2]) 59 | path_2 = self.scratch.refinenet2(path_3, layers[1]) 60 | path_1 = self.scratch.refinenet1(path_2, layers[0]) 61 | 62 | # Output head 63 | out = self.head(path_1) 64 | 65 | return out 66 | 67 | 68 | class PixelwiseTaskWithDPT(nn.Module): 69 | """ DPT module for dust3r, can return 3D points + confidence for all pixels""" 70 | 71 | def __init__(self, *, n_cls_token=0, hooks_idx=None, dim_tokens=None, 72 | output_width_ratio=1, num_channels=1, postprocess=None, depth_mode=None, conf_mode=None, **kwargs): 73 | super(PixelwiseTaskWithDPT, self).__init__() 74 | self.return_all_layers = True # backbone needs to return all layers 75 | self.postprocess = postprocess 76 | self.depth_mode = depth_mode 77 | self.conf_mode = conf_mode 78 | 79 | assert n_cls_token == 0, "Not implemented" 80 | dpt_args = dict(output_width_ratio=output_width_ratio, 81 | num_channels=num_channels, 82 | **kwargs) 83 | if hooks_idx is not None: 84 | dpt_args.update(hooks=hooks_idx) 85 | self.dpt = DPTOutputAdapter_fix(**dpt_args) 86 | dpt_init_args = {} if dim_tokens is None else {'dim_tokens_enc': dim_tokens} 87 | self.dpt.init(**dpt_init_args) 88 | 89 | def forward(self, x, img_info): 90 | out = self.dpt(x, image_size=(img_info[0], img_info[1])) 91 | if self.postprocess: 92 | out = self.postprocess(out, self.depth_mode, self.conf_mode) 93 | return out 94 | 95 | 96 | def create_dpt_head(net, has_conf=False): 97 | """ 98 | return PixelwiseTaskWithDPT for given net params 99 | """ 100 | assert net.dec_depth > 9 101 | l2 = net.dec_depth 102 | feature_dim = 256 103 | last_dim = feature_dim//2 104 | out_nchan = 3 105 | ed = net.enc_embed_dim 106 | dd = net.dec_embed_dim 107 | return PixelwiseTaskWithDPT(num_channels=out_nchan + has_conf, 108 | feature_dim=feature_dim, 109 | last_dim=last_dim, 110 | hooks_idx=[0, l2*2//4, l2*3//4, l2], 111 | dim_tokens=[ed, dd, dd, dd], 112 | postprocess=postprocess, 113 | depth_mode=net.depth_mode, 114 | conf_mode=net.conf_mode, 115 | head_type='regression') 116 | -------------------------------------------------------------------------------- /third_party/dust3r/dust3r/heads/linear_head.py: -------------------------------------------------------------------------------- 1 | # Copyright (C) 2024-present Naver Corporation. All rights reserved. 2 | # Licensed under CC BY-NC-SA 4.0 (non-commercial use only). 
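# Shape walk-through for LinearPts3d below (illustrative numbers: patch_size=16,
# dec_embed_dim=768, has_conf=True, a 224x224 input): each of the 14*14 = 196
# tokens is projected to (3 + 1) * 16**2 = 1024 values, reshaped to
# (B, 1024, 14, 14), and F.pixel_shuffle(..., 16) expands that into a dense
# (B, 4, 224, 224) map of 3D points plus confidence.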
3 | # 4 | # -------------------------------------------------------- 5 | # linear head implementation for DUST3R 6 | # -------------------------------------------------------- 7 | import torch.nn as nn 8 | import torch.nn.functional as F 9 | from dust3r.heads.postprocess import postprocess 10 | 11 | 12 | class LinearPts3d (nn.Module): 13 | """ 14 | Linear head for dust3r 15 | Each token outputs: - 16x16 3D points (+ confidence) 16 | """ 17 | 18 | def __init__(self, net, has_conf=False): 19 | super().__init__() 20 | self.patch_size = net.patch_embed.patch_size[0] 21 | self.depth_mode = net.depth_mode 22 | self.conf_mode = net.conf_mode 23 | self.has_conf = has_conf 24 | 25 | self.proj = nn.Linear(net.dec_embed_dim, (3 + has_conf)*self.patch_size**2) 26 | 27 | def setup(self, croconet): 28 | pass 29 | 30 | def forward(self, decout, img_shape): 31 | H, W = img_shape 32 | tokens = decout[-1] 33 | B, S, D = tokens.shape 34 | 35 | # extract 3D points 36 | feat = self.proj(tokens) # B,S,D 37 | feat = feat.transpose(-1, -2).view(B, -1, H//self.patch_size, W//self.patch_size) 38 | feat = F.pixel_shuffle(feat, self.patch_size) # B,3,H,W 39 | 40 | # permute + norm depth 41 | return postprocess(feat, self.depth_mode, self.conf_mode) 42 | -------------------------------------------------------------------------------- /third_party/dust3r/dust3r/heads/postprocess.py: -------------------------------------------------------------------------------- 1 | # Copyright (C) 2024-present Naver Corporation. All rights reserved. 2 | # Licensed under CC BY-NC-SA 4.0 (non-commercial use only). 3 | # 4 | # -------------------------------------------------------- 5 | # post process function for all heads: extract 3D points/confidence from output 6 | # -------------------------------------------------------- 7 | import torch 8 | 9 | 10 | def postprocess(out, depth_mode, conf_mode): 11 | """ 12 | extract 3D points/confidence from prediction head output 13 | """ 14 | fmap = out.permute(0, 2, 3, 1) # B,H,W,3 15 | res = dict(pts3d=reg_dense_depth(fmap[:, :, :, 0:3], mode=depth_mode)) 16 | 17 | if conf_mode is not None: 18 | res['conf'] = reg_dense_conf(fmap[:, :, :, 3], mode=conf_mode) 19 | return res 20 | 21 | 22 | def reg_dense_depth(xyz, mode): 23 | """ 24 | extract 3D points from prediction head output 25 | """ 26 | mode, vmin, vmax = mode 27 | 28 | no_bounds = (vmin == -float('inf')) and (vmax == float('inf')) 29 | assert no_bounds 30 | 31 | if mode == 'linear': 32 | if no_bounds: 33 | return xyz # [-inf, +inf] 34 | return xyz.clip(min=vmin, max=vmax) 35 | 36 | # distance to origin 37 | d = xyz.norm(dim=-1, keepdim=True) 38 | xyz = xyz / d.clip(min=1e-8) 39 | 40 | if mode == 'square': 41 | return xyz * d.square() 42 | 43 | if mode == 'exp': 44 | return xyz * torch.expm1(d) 45 | 46 | raise ValueError(f'bad {mode=}') 47 | 48 | 49 | def reg_dense_conf(x, mode): 50 | """ 51 | extract confidence from prediction head output 52 | """ 53 | mode, vmin, vmax = mode 54 | if mode == 'exp': 55 | return vmin + x.exp().clip(max=vmax-vmin) 56 | if mode == 'sigmoid': 57 | return (vmax - vmin) * torch.sigmoid(x) + vmin 58 | raise ValueError(f'bad {mode=}') 59 | -------------------------------------------------------------------------------- /third_party/dust3r/dust3r/image_pairs.py: -------------------------------------------------------------------------------- 1 | # Copyright (C) 2024-present Naver Corporation. All rights reserved. 2 | # Licensed under CC BY-NC-SA 4.0 (non-commercial use only). 
3 | # 4 | # -------------------------------------------------------- 5 | # utilities needed to load image pairs 6 | # -------------------------------------------------------- 7 | import numpy as np 8 | import torch 9 | import itertools 10 | 11 | 12 | def make_pairs(imgs, scene_graph='complete', prefilter=None, symmetrize=True): 13 | pairs = [] 14 | 15 | if scene_graph == 'complete': # complete graph 16 | for i in range(len(imgs)): 17 | for j in range(i): 18 | pairs.append((imgs[i], imgs[j])) 19 | 20 | elif scene_graph.startswith('swin'): 21 | winsize = int(scene_graph.split('-')[1]) if '-' in scene_graph else 3 22 | for i in range(len(imgs)): 23 | for j in range(winsize): 24 | idx = (i + j) % len(imgs) # explicit loop closure 25 | pairs.append((imgs[i], imgs[idx])) 26 | 27 | elif scene_graph.startswith('oneref'): 28 | refid = int(scene_graph.split('-')[1]) if '-' in scene_graph else 0 29 | for j in range(len(imgs)): 30 | if j != refid: 31 | pairs.append((imgs[refid], imgs[j])) 32 | 33 | elif scene_graph == 'pairs': 34 | assert len(imgs) % 2 == 0 35 | for i in range(0, len(imgs), 2): 36 | pairs.append((imgs[i], imgs[i+1])) 37 | 38 | if symmetrize: 39 | pairs += [(img2, img1) for img1, img2 in pairs] 40 | 41 | # now, remove edges 42 | if isinstance(prefilter, str) and prefilter.startswith('seq'): 43 | pairs = filter_pairs_seq(pairs, int(prefilter[3:])) 44 | 45 | if isinstance(prefilter, str) and prefilter.startswith('cyc'): 46 | pairs = filter_pairs_seq(pairs, int(prefilter[3:]), cyclic=True) 47 | 48 | return pairs 49 | 50 | def make_pairs_fast(imgs, scene_graph='complete', prefilter=None, symmetrize=True): 51 | pairs = [] 52 | 53 | if scene_graph == 'complete': # complete graph 54 | pairs = list(itertools.combinations(imgs, 2)) 55 | 56 | elif scene_graph.startswith('swin'): 57 | winsize = int(scene_graph.split('-')[1]) if '-' in scene_graph else 3 58 | for i in range(len(imgs)): 59 | for j in range(winsize): 60 | idx = (i + j) % len(imgs) # explicit loop closure 61 | pairs.append((imgs[i], imgs[idx])) 62 | 63 | elif scene_graph.startswith('oneref'): 64 | refid = int(scene_graph.split('-')[1]) if '-' in scene_graph else 0 65 | for j in range(len(imgs)): 66 | if j != refid: 67 | pairs.append((imgs[refid], imgs[j])) 68 | 69 | elif scene_graph == 'pairs': 70 | assert len(imgs) % 2 == 0 71 | for i in range(0, len(imgs), 2): 72 | pairs.append((imgs[i], imgs[i+1])) 73 | 74 | if symmetrize: 75 | pairs += [(img2, img1) for img1, img2 in pairs] 76 | 77 | # now, remove edges 78 | if isinstance(prefilter, str) and prefilter.startswith('seq'): 79 | pairs = filter_pairs_seq(pairs, int(prefilter[3:])) 80 | 81 | if isinstance(prefilter, str) and prefilter.startswith('cyc'): 82 | pairs = filter_pairs_seq(pairs, int(prefilter[3:]), cyclic=True) 83 | 84 | return pairs 85 | 86 | def sel(x, kept): 87 | if isinstance(x, dict): 88 | return {k: sel(v, kept) for k, v in x.items()} 89 | if isinstance(x, (torch.Tensor, np.ndarray)): 90 | return x[kept] 91 | if isinstance(x, (tuple, list)): 92 | return type(x)([x[k] for k in kept]) 93 | 94 | 95 | def _filter_edges_seq(edges, seq_dis_thr, cyclic=False): 96 | # number of images 97 | n = max(max(e) for e in edges)+1 98 | 99 | kept = [] 100 | for e, (i, j) in enumerate(edges): 101 | dis = abs(i-j) 102 | if cyclic: 103 | dis = min(dis, abs(i+n-j), abs(i-n-j)) 104 | if dis <= seq_dis_thr: 105 | kept.append(e) 106 | return kept 107 | 108 | 109 | def filter_pairs_seq(pairs, seq_dis_thr, cyclic=False): 110 | edges = [(img1['idx'], img2['idx']) for img1, img2 in pairs] 111 | 
kept = _filter_edges_seq(edges, seq_dis_thr, cyclic=cyclic) 112 | return [pairs[i] for i in kept] 113 | 114 | 115 | def filter_edges_seq(view1, view2, pred1, pred2, seq_dis_thr, cyclic=False): 116 | edges = [(int(i), int(j)) for i, j in zip(view1['idx'], view2['idx'])] 117 | kept = _filter_edges_seq(edges, seq_dis_thr, cyclic=cyclic) 118 | print(f'>> Filtering edges more than {seq_dis_thr} frames apart: kept {len(kept)}/{len(edges)} edges') 119 | return sel(view1, kept), sel(view2, kept), sel(pred1, kept), sel(pred2, kept) 120 | -------------------------------------------------------------------------------- /third_party/dust3r/dust3r/inference.py: -------------------------------------------------------------------------------- 1 | # Copyright (C) 2024-present Naver Corporation. All rights reserved. 2 | # Licensed under CC BY-NC-SA 4.0 (non-commercial use only). 3 | # 4 | # -------------------------------------------------------- 5 | # utilities needed for the inference 6 | # -------------------------------------------------------- 7 | import tqdm 8 | import torch 9 | from dust3r.utils.device import to_cpu, collate_with_cat 10 | from dust3r.model import AsymmetricCroCo3DStereo, inf # noqa: F401, needed when loading the model 11 | from dust3r.utils.misc import invalid_to_nans 12 | from dust3r.utils.geometry import depthmap_to_pts3d, geotrf 13 | 14 | 15 | def load_model(model_path, device): 16 | print('... loading model from', model_path) 17 | ckpt = torch.load(model_path, map_location='cpu') 18 | args = ckpt['args'].model.replace("ManyAR_PatchEmbed", "PatchEmbedDust3R") 19 | if 'landscape_only' not in args: 20 | args = args[:-1] + ', landscape_only=False)' 21 | else: 22 | args = args.replace(" ", "").replace('landscape_only=True', 'landscape_only=False') 23 | assert "landscape_only=False" in args 24 | print(f"instantiating : {args}") 25 | net = eval(args) 26 | print(net.load_state_dict(ckpt['model'], strict=False)) 27 | return net.to(device) 28 | 29 | 30 | def _interleave_imgs(img1, img2): 31 | res = {} 32 | for key, value1 in img1.items(): 33 | value2 = img2[key] 34 | if isinstance(value1, torch.Tensor): 35 | value = torch.stack((value1, value2), dim=1).flatten(0, 1) 36 | else: 37 | value = [x for pair in zip(value1, value2) for x in pair] 38 | res[key] = value 39 | return res 40 | 41 | 42 | def make_batch_symmetric(batch): 43 | view1, view2 = batch 44 | view1, view2 = (_interleave_imgs(view1, view2), _interleave_imgs(view2, view1)) 45 | return view1, view2 46 | 47 | 48 | def loss_of_one_batch(batch, model, criterion, device, symmetrize_batch=False, use_amp=False, ret=None): 49 | view1, view2 = batch 50 | for view in batch: 51 | for name in 'img pts3d valid_mask camera_pose camera_intrinsics F_matrix corres'.split(): # pseudo_focal 52 | if name not in view: 53 | continue 54 | view[name] = view[name].to(device, non_blocking=True) 55 | 56 | if symmetrize_batch: 57 | view1, view2 = make_batch_symmetric(batch) 58 | 59 | with torch.cuda.amp.autocast(enabled=bool(use_amp)): 60 | pred1, pred2 = model(view1, view2) 61 | 62 | # loss is supposed to be symmetric 63 | with torch.cuda.amp.autocast(enabled=False): 64 | loss = criterion(view1, view2, pred1, pred2) if criterion is not None else None 65 | 66 | result = dict(view1=view1, view2=view2, pred1=pred1, pred2=pred2, loss=loss) 67 | return result[ret] if ret else result 68 | 69 | 70 | @torch.no_grad() 71 | def inference(pairs, model, device, batch_size=8): 72 | print(f'>> Inference with model on {len(pairs)} image pairs') 73 | result = [] 74 | 75 | 
# first, check if all images have the same size 76 | multiple_shapes = not (check_if_same_size(pairs)) 77 | if multiple_shapes: # force bs=1 78 | batch_size = 1 79 | 80 | for i in tqdm.trange(0, len(pairs), batch_size): 81 | res = loss_of_one_batch(collate_with_cat(pairs[i:i+batch_size]), model, None, device) 82 | result.append(to_cpu(res)) 83 | 84 | result = collate_with_cat(result, lists=multiple_shapes) 85 | 86 | torch.cuda.empty_cache() 87 | return result 88 | 89 | 90 | def check_if_same_size(pairs): 91 | shapes1 = [img1['img'].shape[-2:] for img1, img2 in pairs] 92 | shapes2 = [img2['img'].shape[-2:] for img1, img2 in pairs] 93 | return all(shapes1[0] == s for s in shapes1) and all(shapes2[0] == s for s in shapes2) 94 | 95 | 96 | def get_pred_pts3d(gt, pred, use_pose=False): 97 | if 'depth' in pred and 'pseudo_focal' in pred: 98 | try: 99 | pp = gt['camera_intrinsics'][..., :2, 2] 100 | except KeyError: 101 | pp = None 102 | pts3d = depthmap_to_pts3d(**pred, pp=pp) 103 | 104 | elif 'pts3d' in pred: 105 | # pts3d from my camera 106 | pts3d = pred['pts3d'] 107 | 108 | elif 'pts3d_in_other_view' in pred: 109 | # pts3d from the other camera, already transformed 110 | assert use_pose is True 111 | return pred['pts3d_in_other_view'] # return! 112 | 113 | if use_pose: 114 | camera_pose = pred.get('camera_pose') 115 | assert camera_pose is not None 116 | pts3d = geotrf(camera_pose, pts3d) 117 | 118 | return pts3d 119 | 120 | 121 | def find_opt_scaling(gt_pts1, gt_pts2, pr_pts1, pr_pts2=None, fit_mode='weiszfeld_stop_grad', valid1=None, valid2=None): 122 | assert gt_pts1.ndim == pr_pts1.ndim == 4 123 | assert gt_pts1.shape == pr_pts1.shape 124 | if gt_pts2 is not None: 125 | assert gt_pts2.ndim == pr_pts2.ndim == 4 126 | assert gt_pts2.shape == pr_pts2.shape 127 | 128 | # concat the pointcloud 129 | nan_gt_pts1 = invalid_to_nans(gt_pts1, valid1).flatten(1, 2) 130 | nan_gt_pts2 = invalid_to_nans(gt_pts2, valid2).flatten(1, 2) if gt_pts2 is not None else None 131 | 132 | pr_pts1 = invalid_to_nans(pr_pts1, valid1).flatten(1, 2) 133 | pr_pts2 = invalid_to_nans(pr_pts2, valid2).flatten(1, 2) if pr_pts2 is not None else None 134 | 135 | all_gt = torch.cat((nan_gt_pts1, nan_gt_pts2), dim=1) if gt_pts2 is not None else nan_gt_pts1 136 | all_pr = torch.cat((pr_pts1, pr_pts2), dim=1) if pr_pts2 is not None else pr_pts1 137 | 138 | dot_gt_pr = (all_pr * all_gt).sum(dim=-1) 139 | dot_gt_gt = all_gt.square().sum(dim=-1) 140 | 141 | if fit_mode.startswith('avg'): 142 | # scaling = (all_pr / all_gt).view(B, -1).mean(dim=1) 143 | scaling = dot_gt_pr.nanmean(dim=1) / dot_gt_gt.nanmean(dim=1) 144 | elif fit_mode.startswith('median'): 145 | scaling = (dot_gt_pr / dot_gt_gt).nanmedian(dim=1).values 146 | elif fit_mode.startswith('weiszfeld'): 147 | # init scaling with l2 closed form 148 | scaling = dot_gt_pr.nanmean(dim=1) / dot_gt_gt.nanmean(dim=1) 149 | # iterative re-weighted least-squares 150 | for iter in range(10): 151 | # re-weighting by inverse of distance 152 | dis = (all_pr - scaling.view(-1, 1, 1) * all_gt).norm(dim=-1) 153 | # print(dis.nanmean(-1)) 154 | w = dis.clip_(min=1e-8).reciprocal() 155 | # update the scaling with the new weights 156 | scaling = (w * dot_gt_pr).nanmean(dim=1) / (w * dot_gt_gt).nanmean(dim=1) 157 | else: 158 | raise ValueError(f'bad {fit_mode=}') 159 | 160 | if fit_mode.endswith('stop_grad'): 161 | scaling = scaling.detach() 162 | 163 | scaling = scaling.clip(min=1e-3) 164 | # assert scaling.isfinite().all(), bb() 165 | return scaling 166 | 
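For orientation, a minimal sketch of how the pair-building and inference utilities above chain together. The checkpoint and image paths are placeholders, and load_images is assumed to be the stock loader in dust3r/utils/image.py.

import torch
from dust3r.inference import inference, load_model
from dust3r.image_pairs import make_pairs
from dust3r.utils.image import load_images  # assumed: the standard DUSt3R image loader

device = 'cuda' if torch.cuda.is_available() else 'cpu'
model = load_model('checkpoints/dust3r_checkpoint.pth', device)  # placeholder checkpoint path

# load_images is expected to return a list of dicts with 'img', 'true_shape', 'idx', 'instance'
views = load_images(['demo/view_0.png', 'demo/view_1.png'], size=512)

# complete graph over the views; symmetrize so every edge is predicted in both directions
pairs = make_pairs(views, scene_graph='complete', prefilter=None, symmetrize=True)
output = inference(pairs, model, device, batch_size=1)

pred1 = output['pred1']  # per-pixel 3D points (+ confidence) in the first view's frame
pred2 = output['pred2']  # the second view's points, expressed in the first view's frame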
-------------------------------------------------------------------------------- /third_party/dust3r/dust3r/optim_factory.py: -------------------------------------------------------------------------------- 1 | # Copyright (C) 2024-present Naver Corporation. All rights reserved. 2 | # Licensed under CC BY-NC-SA 4.0 (non-commercial use only). 3 | # 4 | # -------------------------------------------------------- 5 | # optimization functions 6 | # -------------------------------------------------------- 7 | 8 | 9 | def adjust_learning_rate_by_lr(optimizer, lr): 10 | for param_group in optimizer.param_groups: 11 | if "lr_scale" in param_group: 12 | param_group["lr"] = lr * param_group["lr_scale"] 13 | else: 14 | param_group["lr"] = lr 15 | -------------------------------------------------------------------------------- /third_party/dust3r/dust3r/patch_embed.py: -------------------------------------------------------------------------------- 1 | # Copyright (C) 2024-present Naver Corporation. All rights reserved. 2 | # Licensed under CC BY-NC-SA 4.0 (non-commercial use only). 3 | # 4 | # -------------------------------------------------------- 5 | # PatchEmbed implementation for DUST3R, 6 | # in particular ManyAR_PatchEmbed that Handle images with non-square aspect ratio 7 | # -------------------------------------------------------- 8 | import torch 9 | import dust3r.utils.path_to_croco # noqa: F401 10 | from models.blocks import PatchEmbed # noqa 11 | 12 | 13 | def get_patch_embed(patch_embed_cls, img_size, patch_size, enc_embed_dim): 14 | assert patch_embed_cls in ['PatchEmbedDust3R', 'ManyAR_PatchEmbed'] 15 | patch_embed = eval(patch_embed_cls)(img_size, patch_size, 3, enc_embed_dim) 16 | return patch_embed 17 | 18 | 19 | class PatchEmbedDust3R(PatchEmbed): 20 | def forward(self, x, **kw): 21 | B, C, H, W = x.shape 22 | assert H % self.patch_size[0] == 0, f"Input image height ({H}) is not a multiple of patch size ({self.patch_size[0]})." 23 | assert W % self.patch_size[1] == 0, f"Input image width ({W}) is not a multiple of patch size ({self.patch_size[1]})." 24 | x = self.proj(x) 25 | pos = self.position_getter(B, x.size(2), x.size(3), x.device) 26 | if self.flatten: 27 | x = x.flatten(2).transpose(1, 2) # BCHW -> BNC 28 | x = self.norm(x) 29 | return x, pos 30 | 31 | 32 | class ManyAR_PatchEmbed (PatchEmbed): 33 | """ Handle images with non-square aspect ratio. 34 | All images in the same batch have the same aspect ratio. 35 | true_shape = [(height, width) ...] indicates the actual shape of each image. 36 | """ 37 | 38 | def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768, norm_layer=None, flatten=True): 39 | self.embed_dim = embed_dim 40 | super().__init__(img_size, patch_size, in_chans, embed_dim, norm_layer, flatten) 41 | 42 | def forward(self, img, true_shape): 43 | B, C, H, W = img.shape 44 | assert W >= H, f'img should be in landscape mode, but got {W=} {H=}' 45 | assert H % self.patch_size[0] == 0, f"Input image height ({H}) is not a multiple of patch size ({self.patch_size[0]})." 46 | assert W % self.patch_size[1] == 0, f"Input image width ({W}) is not a multiple of patch size ({self.patch_size[1]})." 
47 | assert true_shape.shape == (B, 2), f"true_shape has the wrong shape={true_shape.shape}" 48 | 49 | # size expressed in tokens 50 | W //= self.patch_size[0] 51 | H //= self.patch_size[1] 52 | n_tokens = H * W 53 | 54 | height, width = true_shape.T 55 | is_landscape = (width >= height) 56 | is_portrait = ~is_landscape 57 | 58 | # allocate result 59 | x = img.new_zeros((B, n_tokens, self.embed_dim)) 60 | pos = img.new_zeros((B, n_tokens, 2), dtype=torch.int64) 61 | 62 | # linear projection, transposed if necessary 63 | x[is_landscape] = self.proj(img[is_landscape]).permute(0, 2, 3, 1).flatten(1, 2).float() 64 | x[is_portrait] = self.proj(img[is_portrait].swapaxes(-1, -2)).permute(0, 2, 3, 1).flatten(1, 2).float() 65 | 66 | pos[is_landscape] = self.position_getter(1, H, W, pos.device) 67 | pos[is_portrait] = self.position_getter(1, W, H, pos.device) 68 | 69 | x = self.norm(x) 70 | return x, pos 71 | -------------------------------------------------------------------------------- /third_party/dust3r/dust3r/post_process.py: -------------------------------------------------------------------------------- 1 | # Copyright (C) 2024-present Naver Corporation. All rights reserved. 2 | # Licensed under CC BY-NC-SA 4.0 (non-commercial use only). 3 | # 4 | # -------------------------------------------------------- 5 | # utilities for interpreting the DUST3R output 6 | # -------------------------------------------------------- 7 | import numpy as np 8 | import torch 9 | from dust3r.utils.geometry import xy_grid 10 | 11 | 12 | def estimate_focal_knowing_depth(pts3d, pp, focal_mode='median', min_focal=0.5, max_focal=3.5): 13 | """ Reprojection method, for when the absolute depth is known: 14 | 1) estimate the camera focal using a robust estimator 15 | 2) reproject points onto true rays, minimizing a certain error 16 | """ 17 | B, H, W, THREE = pts3d.shape 18 | assert THREE == 3 19 | 20 | # centered pixel grid 21 | pixels = xy_grid(W, H, device=pts3d.device).view(1, -1, 2) - pp.view(-1, 1, 2) # B,HW,2 22 | pts3d = pts3d.flatten(1, 2) # (B, HW, 3) 23 | 24 | if focal_mode == 'median': 25 | with torch.no_grad(): 26 | # direct estimation of focal 27 | u, v = pixels.unbind(dim=-1) 28 | x, y, z = pts3d.unbind(dim=-1) 29 | fx_votes = (u * z) / x 30 | fy_votes = (v * z) / y 31 | 32 | # assume square pixels, hence same focal for X and Y 33 | f_votes = torch.cat((fx_votes.view(B, -1), fy_votes.view(B, -1)), dim=-1) 34 | focal = torch.nanmedian(f_votes, dim=-1).values 35 | 36 | elif focal_mode == 'weiszfeld': 37 | # init focal with l2 closed form 38 | # we try to find focal = argmin Sum | pixel - focal * (x,y)/z| 39 | xy_over_z = (pts3d[..., :2] / pts3d[..., 2:3]).nan_to_num(posinf=0, neginf=0) # homogeneous (x,y,1) 40 | 41 | dot_xy_px = (xy_over_z * pixels).sum(dim=-1) 42 | dot_xy_xy = xy_over_z.square().sum(dim=-1) 43 | 44 | focal = dot_xy_px.mean(dim=1) / dot_xy_xy.mean(dim=1) 45 | 46 | # iterative re-weighted least-squares 47 | for iter in range(10): 48 | # re-weighting by inverse of distance 49 | dis = (pixels - focal.view(-1, 1, 1) * xy_over_z).norm(dim=-1) 50 | # print(dis.nanmean(-1)) 51 | w = dis.clip(min=1e-8).reciprocal() 52 | # update the scaling with the new weights 53 | focal = (w * dot_xy_px).mean(dim=1) / (w * dot_xy_xy).mean(dim=1) 54 | else: 55 | raise ValueError(f'bad {focal_mode=}') 56 | 57 | focal_base = max(H, W) / (2 * np.tan(np.deg2rad(60) / 2)) # size / 1.1547005383792515 58 | focal = focal.clip(min=min_focal*focal_base, max=max_focal*focal_base) 59 | # print(focal) 60 | return focal 61 | 
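Continuing the sketch above, estimate_focal_knowing_depth recovers a per-image focal length directly from a predicted pointmap; here the principal point is assumed to sit at the image center.

from dust3r.post_process import estimate_focal_knowing_depth

pts3d = output['pred1']['pts3d']  # (B, H, W, 3) pointmap in the first camera's frame
B, H, W, _ = pts3d.shape
pp = pts3d.new_tensor([W / 2.0, H / 2.0]).unsqueeze(0).repeat(B, 1)  # assumed principal point: image center
focal = estimate_focal_knowing_depth(pts3d, pp, focal_mode='weiszfeld')  # (B,) focal lengths in pixels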
-------------------------------------------------------------------------------- /third_party/dust3r/dust3r/utils/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (C) 2024-present Naver Corporation. All rights reserved. 2 | # Licensed under CC BY-NC-SA 4.0 (non-commercial use only). 3 | -------------------------------------------------------------------------------- /third_party/dust3r/dust3r/utils/device.py: -------------------------------------------------------------------------------- 1 | # Copyright (C) 2024-present Naver Corporation. All rights reserved. 2 | # Licensed under CC BY-NC-SA 4.0 (non-commercial use only). 3 | # 4 | # -------------------------------------------------------- 5 | # utilitary functions for DUSt3R 6 | # -------------------------------------------------------- 7 | import numpy as np 8 | import torch 9 | 10 | 11 | def todevice(batch, device, callback=None, non_blocking=False): 12 | ''' Transfer some variables to another device (i.e. GPU, CPU:torch, CPU:numpy). 13 | 14 | batch: list, tuple, dict of tensors or other things 15 | device: pytorch device or 'numpy' 16 | callback: function that would be called on every sub-elements. 17 | ''' 18 | if callback: 19 | batch = callback(batch) 20 | 21 | if isinstance(batch, dict): 22 | return {k: todevice(v, device) for k, v in batch.items()} 23 | 24 | if isinstance(batch, (tuple, list)): 25 | return type(batch)(todevice(x, device) for x in batch) 26 | 27 | x = batch 28 | if device == 'numpy': 29 | if isinstance(x, torch.Tensor): 30 | x = x.detach().cpu().numpy() 31 | elif x is not None: 32 | if isinstance(x, np.ndarray): 33 | x = torch.from_numpy(x) 34 | if torch.is_tensor(x): 35 | x = x.to(device, non_blocking=non_blocking) 36 | return x 37 | 38 | 39 | to_device = todevice # alias 40 | 41 | 42 | def to_numpy(x): return todevice(x, 'numpy') 43 | def to_cpu(x): return todevice(x, 'cpu') 44 | def to_cuda(x): return todevice(x, 'cuda') 45 | 46 | 47 | def collate_with_cat(whatever, lists=False): 48 | if isinstance(whatever, dict): 49 | return {k: collate_with_cat(vals, lists=lists) for k, vals in whatever.items()} 50 | 51 | elif isinstance(whatever, (tuple, list)): 52 | if len(whatever) == 0: 53 | return whatever 54 | elem = whatever[0] 55 | T = type(whatever) 56 | 57 | if elem is None: 58 | return None 59 | if isinstance(elem, (bool, float, int, str)): 60 | return whatever 61 | if isinstance(elem, tuple): 62 | return T(collate_with_cat(x, lists=lists) for x in zip(*whatever)) 63 | if isinstance(elem, dict): 64 | return {k: collate_with_cat([e[k] for e in whatever], lists=lists) for k in elem} 65 | 66 | if isinstance(elem, torch.Tensor): 67 | return listify(whatever) if lists else torch.cat(whatever) 68 | if isinstance(elem, np.ndarray): 69 | return listify(whatever) if lists else torch.cat([torch.from_numpy(x) for x in whatever]) 70 | 71 | # otherwise, we just chain lists 72 | return sum(whatever, T()) 73 | 74 | 75 | def listify(elems): 76 | return [x for e in elems for x in e] 77 | -------------------------------------------------------------------------------- /third_party/dust3r/dust3r/utils/misc.py: -------------------------------------------------------------------------------- 1 | # Copyright (C) 2024-present Naver Corporation. All rights reserved. 2 | # Licensed under CC BY-NC-SA 4.0 (non-commercial use only). 
3 | # 4 | # -------------------------------------------------------- 5 | # utilitary functions for DUSt3R 6 | # -------------------------------------------------------- 7 | import torch 8 | 9 | 10 | def fill_default_args(kwargs, func): 11 | import inspect # a bit hacky but it works reliably 12 | signature = inspect.signature(func) 13 | 14 | for k, v in signature.parameters.items(): 15 | if v.default is inspect.Parameter.empty: 16 | continue 17 | kwargs.setdefault(k, v.default) 18 | 19 | return kwargs 20 | 21 | 22 | def freeze_all_params(modules): 23 | for module in modules: 24 | try: 25 | for n, param in module.named_parameters(): 26 | param.requires_grad = False 27 | except AttributeError: 28 | # module is directly a parameter 29 | module.requires_grad = False 30 | 31 | 32 | def is_symmetrized(gt1, gt2): 33 | x = gt1['instance'] 34 | y = gt2['instance'] 35 | if len(x) == len(y) and len(x) == 1: 36 | return False # special case of batchsize 1 37 | ok = True 38 | for i in range(0, len(x), 2): 39 | ok = ok and (x[i] == y[i+1]) and (x[i+1] == y[i]) 40 | return ok 41 | 42 | 43 | def flip(tensor): 44 | """ flip so that tensor[0::2] <=> tensor[1::2] """ 45 | return torch.stack((tensor[1::2], tensor[0::2]), dim=1).flatten(0, 1) 46 | 47 | 48 | def interleave(tensor1, tensor2): 49 | res1 = torch.stack((tensor1, tensor2), dim=1).flatten(0, 1) 50 | res2 = torch.stack((tensor2, tensor1), dim=1).flatten(0, 1) 51 | return res1, res2 52 | 53 | 54 | def transpose_to_landscape(head, activate=True): 55 | """ Predict in the correct aspect-ratio, 56 | then transpose the result in landscape 57 | and stack everything back together. 58 | """ 59 | def wrapper_no(decout, true_shape): 60 | B = len(true_shape) 61 | assert true_shape[0:1].allclose(true_shape), 'true_shape must be all identical' 62 | H, W = true_shape[0].cpu().tolist() 63 | res = head(decout, (H, W)) 64 | return res 65 | 66 | def wrapper_yes(decout, true_shape): 67 | B = len(true_shape) 68 | # by definition, the batch is in landscape mode so W >= H 69 | H, W = int(true_shape.min()), int(true_shape.max()) 70 | 71 | height, width = true_shape.T 72 | is_landscape = (width >= height) 73 | is_portrait = ~is_landscape 74 | 75 | # true_shape = true_shape.cpu() 76 | if is_landscape.all(): 77 | return head(decout, (H, W)) 78 | if is_portrait.all(): 79 | return transposed(head(decout, (W, H))) 80 | 81 | # batch is a mix of both portraint & landscape 82 | def selout(ar): return [d[ar] for d in decout] 83 | l_result = head(selout(is_landscape), (H, W)) 84 | p_result = transposed(head(selout(is_portrait), (W, H))) 85 | 86 | # allocate full result 87 | result = {} 88 | for k in l_result | p_result: 89 | x = l_result[k].new(B, *l_result[k].shape[1:]) 90 | x[is_landscape] = l_result[k] 91 | x[is_portrait] = p_result[k] 92 | result[k] = x 93 | 94 | return result 95 | 96 | return wrapper_yes if activate else wrapper_no 97 | 98 | 99 | def transposed(dic): 100 | return {k: v.swapaxes(1, 2) for k, v in dic.items()} 101 | 102 | 103 | def invalid_to_nans(arr, valid_mask, ndim=999): 104 | if valid_mask is not None: 105 | arr = arr.clone() 106 | arr[~valid_mask] = float('nan') 107 | if arr.ndim > ndim: 108 | arr = arr.flatten(-2 - (arr.ndim - ndim), -2) 109 | return arr 110 | 111 | 112 | def invalid_to_zeros(arr, valid_mask, ndim=999): 113 | if valid_mask is not None: 114 | arr = arr.clone() 115 | arr[~valid_mask] = 0 116 | nnz = valid_mask.view(len(valid_mask), -1).sum(1) 117 | else: 118 | nnz = arr.numel() // len(arr) if len(arr) else 0 # number of point per image 119 | if 
arr.ndim > ndim: 120 | arr = arr.flatten(-2 - (arr.ndim - ndim), -2) 121 | return arr, nnz 122 | -------------------------------------------------------------------------------- /third_party/dust3r/dust3r/utils/path_to_croco.py: -------------------------------------------------------------------------------- 1 | # Copyright (C) 2024-present Naver Corporation. All rights reserved. 2 | # Licensed under CC BY-NC-SA 4.0 (non-commercial use only). 3 | # 4 | # -------------------------------------------------------- 5 | # CroCo submodule import 6 | # -------------------------------------------------------- 7 | 8 | import sys 9 | import os.path as path 10 | HERE_PATH = path.normpath(path.dirname(__file__)) 11 | CROCO_REPO_PATH = path.normpath(path.join(HERE_PATH, '../../croco')) 12 | CROCO_MODELS_PATH = path.join(CROCO_REPO_PATH, 'models') 13 | # check the presence of models directory in repo to be sure its cloned 14 | if path.isdir(CROCO_MODELS_PATH): 15 | # workaround for sibling import 16 | sys.path.insert(0, CROCO_REPO_PATH) 17 | else: 18 | raise ImportError(f"croco is not initialized, could not find: {CROCO_MODELS_PATH}.\n " 19 | "Did you forget to run 'git submodule update --init --recursive' ?") 20 | -------------------------------------------------------------------------------- /utils/adain.py: -------------------------------------------------------------------------------- 1 | def masked_adain(content_feat, style_feat, content_mask, style_mask): 2 | assert (content_feat.size()[:2] == style_feat.size()[:2]) 3 | size = content_feat.size() 4 | style_mean, style_std = calc_mean_std(style_feat, mask=style_mask) 5 | content_mean, content_std = calc_mean_std(content_feat, mask=content_mask) 6 | normalized_feat = (content_feat - content_mean.expand(size)) / content_std.expand(size) 7 | style_normalized_feat = normalized_feat * style_std.expand(size) + style_mean.expand(size) 8 | return content_feat * (1 - content_mask) + style_normalized_feat * content_mask 9 | 10 | 11 | def adain(content_feat, style_feat): 12 | assert (content_feat.size()[:2] == style_feat.size()[:2]) 13 | size = content_feat.size() 14 | style_mean, style_std = calc_mean_std(style_feat) 15 | content_mean, content_std = calc_mean_std(content_feat) 16 | normalized_feat = (content_feat - content_mean.expand(size)) / content_std.expand(size) 17 | return normalized_feat * style_std.expand(size) + style_mean.expand(size) 18 | 19 | 20 | def calc_mean_std(feat, eps=1e-5, mask=None): 21 | # eps is a small value added to the variance to avoid divide-by-zero. 22 | size = feat.size() 23 | if len(size) == 2: 24 | return calc_mean_std_2d(feat, eps, mask) 25 | 26 | assert (len(size) == 3) 27 | C = size[0] 28 | if mask is not None: 29 | feat_var = feat.view(C, -1)[:, mask.view(-1) == 1].var(dim=1) + eps 30 | feat_std = feat_var.sqrt().view(C, 1, 1) 31 | feat_mean = feat.view(C, -1)[:, mask.view(-1) == 1].mean(dim=1).view(C, 1, 1) 32 | else: 33 | feat_var = feat.view(C, -1).var(dim=1) + eps 34 | feat_std = feat_var.sqrt().view(C, 1, 1) 35 | feat_mean = feat.view(C, -1).mean(dim=1).view(C, 1, 1) 36 | 37 | return feat_mean, feat_std 38 | 39 | 40 | def calc_mean_std_2d(feat, eps=1e-5, mask=None): 41 | # eps is a small value added to the variance to avoid divide-by-zero. 
42 | size = feat.size() 43 | assert (len(size) == 2) 44 | C = size[0] 45 | if mask is not None: 46 | feat_var = feat.view(C, -1)[:, mask.view(-1) == 1].var(dim=1) + eps 47 | feat_std = feat_var.sqrt().view(C, 1) 48 | feat_mean = feat.view(C, -1)[:, mask.view(-1) == 1].mean(dim=1).view(C, 1) 49 | else: 50 | feat_var = feat.view(C, -1).var(dim=1) + eps 51 | feat_std = feat_var.sqrt().view(C, 1) 52 | feat_mean = feat.view(C, -1).mean(dim=1).view(C, 1) 53 | 54 | return feat_mean, feat_std 55 | -------------------------------------------------------------------------------- /utils/logging.py: -------------------------------------------------------------------------------- 1 | import logging 2 | 3 | # configure once, at import time 4 | logger = logging.getLogger("ReStyle3D") 5 | logger.setLevel(logging.INFO) 6 | 7 | handler = logging.StreamHandler() 8 | handler.setFormatter(logging.Formatter( 9 | "%(asctime)s - %(name)s - %(levelname)s - %(message)s" 10 | )) 11 | logger.addHandler(handler) -------------------------------------------------------------------------------- /viewformer/__init__.py: -------------------------------------------------------------------------------- 1 | from .viewtransfer_pipeline import ViewTransferSDXLPipeline 2 | from .UNet2DConditionalModel import UNet2DConditionModel 3 | 4 | -------------------------------------------------------------------------------- /viewformer/image_utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from PIL import Image 3 | import cv2 4 | import torch 5 | from torchvision import transforms 6 | 7 | def load_and_resize_image(path, size=(1024, 1024)): 8 | img = Image.open(path).convert("RGB").resize(size) 9 | return transforms.ToTensor()(img).unsqueeze(0) 10 | 11 | 12 | def match_histograms_masked_full(source_img, reference_img, mask): 13 | """ 14 | Match histograms based on masked region but apply to whole image 15 | 16 | Parameters: 17 | source_img: numpy array (H x W x 3) - Source image to be modified 18 | reference_img: numpy array (H x W x 3) - Reference image to match 19 | mask: numpy array (H x W x 3) - RGB mask 20 | """ 21 | # Convert to float32 22 | source_float = source_img.astype(np.float32) / 255.0 23 | reference_float = reference_img.astype(np.float32) / 255.0 24 | 25 | # Initialize output image 26 | matched = source_float.copy() 27 | 28 | # Use first channel of RGB mask and ensure it's binary 29 | mask_channel = mask[:,:,0] if len(mask.shape) == 3 else mask 30 | if mask_channel.dtype != np.uint8: 31 | mask_channel = mask_channel.astype(np.uint8) * 255 32 | _, mask_binary = cv2.threshold(mask_channel, 127, 255, cv2.THRESH_BINARY) 33 | mask_binary = cv2.bitwise_not(mask_binary) # Invert the mask 34 | 35 | # Create boolean mask 36 | bool_mask = mask_binary > 0 37 | 38 | for i in range(3): 39 | # Get masked pixels for computing transformation 40 | source_channel = source_float[:,:,i] 41 | reference_channel = reference_float[:,:,i] 42 | 43 | # Apply boolean mask correctly 44 | source_masked = source_channel[bool_mask] 45 | reference_masked = reference_channel[bool_mask] 46 | 47 | if len(source_masked) > 0 and len(reference_masked) > 0: 48 | # Use more bins for better precision 49 | nbins = 256 50 | source_hist, bin_edges = np.histogram(source_masked, nbins, [0, 1]) 51 | reference_hist, _ = np.histogram(reference_masked, nbins, [0, 1]) 52 | 53 | # Add small epsilon to avoid division by zero 54 | source_hist = source_hist + 1e-8 55 | reference_hist = reference_hist + 1e-8 
56 | 57 | # Calculate normalized cumulative histograms 58 | source_cdf = source_hist.cumsum() / source_hist.sum() 59 | reference_cdf = reference_hist.cumsum() / reference_hist.sum() 60 | 61 | # Create interpolation function 62 | bins = np.linspace(0, 1, nbins) 63 | lookup_table = np.interp(source_cdf, reference_cdf, bins) 64 | 65 | # Apply transformation to entire channel 66 | channel_values = source_float[:,:,i] * (nbins-1) 67 | channel_indices = channel_values.astype(int) 68 | matched[:,:,i] = lookup_table[channel_indices] 69 | 70 | # Ensure output is in valid range 71 | matched = np.clip(matched, 0, 1) 72 | matched = (matched * 255).astype(np.uint8) 73 | 74 | return Image.fromarray(matched) --------------------------------------------------------------------------------
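Finally, a small usage sketch for match_histograms_masked_full defined above. The file names are placeholders and all three images are assumed to share the same resolution; note that the mask is inverted internally, so white mask pixels are the ones excluded from the histogram statistics.

import numpy as np
from PIL import Image
from viewformer.image_utils import match_histograms_masked_full

source = np.array(Image.open('demo/stylized_view.png').convert('RGB'))       # image whose colors get adjusted
reference = np.array(Image.open('demo/style_reference.png').convert('RGB'))  # color statistics to match
mask = np.array(Image.open('demo/object_mask.png').convert('RGB'))           # white = pixels ignored for the statistics

matched = match_histograms_masked_full(source, reference, mask)  # returns a PIL.Image
matched.save('demo/stylized_view_matched.png')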