├── .gitignore ├── LICENSE ├── README.md ├── assets └── demo.png ├── conf ├── config.yaml └── diffusion.yml ├── data └── demo │ ├── jellycat │ ├── 001.jpg │ ├── 002.jpg │ ├── 003.jpg │ └── 004.jpg │ ├── jordan │ ├── 001.png │ ├── 002.png │ ├── 003.png │ ├── 004.png │ ├── 005.png │ ├── 006.png │ ├── 007.png │ └── 008.png │ ├── kew_gardens_ruined_arch │ ├── 001.jpeg │ ├── 002.jpeg │ └── 003.jpeg │ └── kotor_cathedral │ ├── 001.jpeg │ ├── 002.jpeg │ ├── 003.jpeg │ ├── 004.jpeg │ ├── 005.jpeg │ └── 006.jpeg ├── diffusionsfm ├── __init__.py ├── dataset │ ├── __init__.py │ ├── co3d_v2.py │ └── custom.py ├── eval │ ├── __init__.py │ ├── eval_category.py │ └── eval_jobs.py ├── inference │ ├── __init__.py │ ├── ddim.py │ ├── load_model.py │ └── predict.py ├── model │ ├── base_model.py │ ├── blocks.py │ ├── diffuser.py │ ├── diffuser_dpt.py │ ├── dit.py │ ├── feature_extractors.py │ ├── memory_efficient_attention.py │ └── scheduler.py └── utils │ ├── __init__.py │ ├── configs.py │ ├── distortion.py │ ├── distributed.py │ ├── geometry.py │ ├── normalize.py │ ├── rays.py │ ├── slurm.py │ └── visualization.py ├── docs ├── eval.md └── train.md ├── gradio_app.py ├── requirements.txt └── train.py /.gitignore: -------------------------------------------------------------------------------- 1 | __pycache__/ 2 | .DS_Store 3 | wandb/ 4 | slurm_logs/ 5 | output/ 6 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2025 Qitao Zhao 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # DiffusionSfM 2 | 3 | This repository contains the official implementation for **DiffusionSfM: Predicting Structure and Motion** 4 | **via Ray Origin and Endpoint Diffusion**. The paper has been accepted to [CVPR 2025](https://cvpr.thecvf.com/Conferences/2025). 5 | 6 | [Project Page](https://qitaozhao.github.io/DiffusionSfM) | [arXiv](https://arxiv.org/abs/2505.05473) | 7 | 8 | ### News 9 | 10 | - 2025.05.04: Initial code release. 11 | 12 | ## Introduction 13 | 14 | **tl;dr** Given a set of multi-view images, **DiffusionSfM** represents scene geometry and cameras as pixel-wise ray origins and endpoints in a global frame. 
It learns a denoising diffusion model to infer these elements directly from multi-view inputs. 15 | 16 | ![teaser](https://raw.githubusercontent.com/QitaoZhao/QitaoZhao.github.io/main/research/DiffusionSfM/figures/teaser.png) 17 | 18 | ## Install 19 | 20 | 1. Clone DiffusionSfM: 21 | 22 | ```bash 23 | git clone https://github.com/QitaoZhao/DiffusionSfM.git 24 | cd DiffusionSfM 25 | ``` 26 | 27 | 2. Create the environment and install packages: 28 | 29 | ```bash 30 | conda create -n diffusionsfm python=3.9 31 | conda activate diffusionsfm 32 | 33 | # enable nvcc 34 | conda install -c conda-forge cudatoolkit-dev 35 | 36 | ### torch 37 | # CUDA 11.7 38 | conda install pytorch==2.0.1 torchvision==0.15.2 torchaudio==2.0.2 pytorch-cuda=11.7 -c pytorch -c nvidia 39 | 40 | pip install -r requirements.txt 41 | 42 | ### pytorch3D 43 | # CUDA 11.7 44 | conda install https://anaconda.org/pytorch3d/pytorch3d/0.7.7/download/linux-64/pytorch3d-0.7.7-py39_cu117_pyt201.tar.bz2 45 | 46 | # xformers 47 | conda install xformers -c xformers 48 | ``` 49 | 50 | Tested on: 51 | 52 | - Springdale Linux 8.6 with torch 2.0.1 & CUDA 11.7 on A6000 GPUs. 53 | 54 | > **Note:** If you encounter the error 55 | 56 | > ImportError: .../libtorch_cpu.so: undefined symbol: iJIT_NotifyEvent 57 | 58 | > when importing PyTorch, refer to this [related issue](https://github.com/coleygroup/shepherd-score/issues/1) or try installing Intel MKL explicitly with: 59 | 60 | ``` 61 | conda install mkl==2024.0 62 | ``` 63 | 64 | ## Run Demo 65 | 66 | #### (1) Try the Online Demo 67 | 68 | Check out our interactive demo on Hugging Face: 69 | 70 | 👉 [DiffusionSfM Demo](https://huggingface.co/spaces/qitaoz/DiffusionSfM) 71 | 72 | #### (2) Run the Gradio Demo Locally 73 | 74 | Download the model weights manually from [Hugging Face](https://huggingface.co/qitaoz/DiffusionSfM): 75 | 76 | ```python 77 | from huggingface_hub import hf_hub_download 78 | 79 | filepath = hf_hub_download(repo_id="qitaoz/DiffusionSfM", filename="qitaoz/DiffusionSfM") 80 | ``` 81 | 82 | or [Google Drive](https://drive.google.com/file/d/1NBdq7A1QMFGhIbpK1HT3ATv2S1jXWr2h/view?usp=drive_link): 83 | 84 | ```bash 85 | gdown https://drive.google.com/uc\?id\=1NBdq7A1QMFGhIbpK1HT3ATv2S1jXWr2h 86 | unzip models.zip 87 | ``` 88 | Next run the demo like so: 89 | 90 | ```bash 91 | # first-time running may take a longer time 92 | python gradio_app.py 93 | ``` 94 | 95 | ![teaser](assets/demo.png) 96 | 97 | You can run our model in two ways: 98 | 99 | 1. **Upload Images** — Upload your own multi-view images above. 100 | 2. **Use a Preprocessed Example** — Select one of the pre-collected examples below. 101 | 102 | ## Training 103 | 104 | Set up wandb: 105 | 106 | ```bash 107 | wandb login 108 | ``` 109 | 110 | See [docs/train.md](https://github.com/QitaoZhao/DiffusionSfM/blob/main/docs/train.md) for more detailed instructions on training. 111 | 112 | ## Evaluation 113 | 114 | See [docs/eval.md](https://github.com/QitaoZhao/DiffusionSfM/blob/main/docs/eval.md) for instructions on how to run evaluation code. 115 | 116 | ## Acknowledgments 117 | 118 | This project builds upon [RayDiffusion](https://github.com/jasonyzhang/RayDiffusion). [Amy Lin](https://amyxlase.github.io/) and [Jason Y. Zhang](https://jasonyzhang.com/) developed the initial codebase during the early stages of this project. 
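
For programmatic use outside the Gradio demo, the inference utilities can also be called directly from Python. The sketch below is assembled from the call sites in `diffusionsfm/eval/eval_category.py` and `diffusionsfm/dataset/custom.py`; the checkpoint directory, checkpoint iteration, and keyword arguments are illustrative assumptions taken from the evaluation scripts and may need to be adapted to the released weights.

```python
import torch

from diffusionsfm.dataset.custom import CustomDataset
from diffusionsfm.inference.load_model import load_model
from diffusionsfm.inference.predict import predict_cameras

device = torch.device("cuda")

# Assumed checkpoint directory and iteration; point this at the unzipped model weights.
model, cfg = load_model("output/multi_diffusionsfm_dense", checkpoint=800_000, device=device)

# Build a batch from your own images (EXIF rotation and square cropping are handled by the dataset).
dataset = CustomDataset(["data/demo/jellycat/001.jpg", "data/demo/jellycat/002.jpg"])
batch = dataset.get_data()
images = batch["image"].to(device)
crop_parameters = batch["crop_parameters"].to(device)

# Predict cameras for the input views; see diffusionsfm/inference/predict.py for further options.
pred_cameras, _ = predict_cameras(
    model,
    images,
    device,
    crop_parameters=crop_parameters,
    num_patches_x=cfg.model.num_patches_x,
    num_patches_y=cfg.model.num_patches_y,
)
```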
119 | 120 | ## Cite DiffusionSfM 121 | 122 | If you find this code helpful, please cite: 123 | 124 | ``` 125 | @inproceedings{zhao2025diffusionsfm, 126 | title={DiffusionSfM: Predicting Structure and Motion via Ray Origin and Endpoint Diffusion}, 127 | author={Qitao Zhao and Amy Lin and Jeff Tan and Jason Y. Zhang and Deva Ramanan and Shubham Tulsiani}, 128 | booktitle={CVPR}, 129 | year={2025} 130 | } 131 | ``` -------------------------------------------------------------------------------- /assets/demo.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/QitaoZhao/DiffusionSfM/4bb08800721bdcf46b0c823586a2fab4702ff282/assets/demo.png -------------------------------------------------------------------------------- /conf/config.yaml: -------------------------------------------------------------------------------- 1 | training: 2 | resume: False # If True, must set hydra.run.dir accordingly 3 | pretrain_path: "" 4 | interval_visualize: 1000 5 | interval_save_checkpoint: 5000 6 | interval_delete_checkpoint: 10000 7 | interval_evaluate: 5000 8 | delete_all_checkpoints_after_training: False 9 | lr: 1e-4 10 | mixed_precision: True 11 | matmul_precision: high 12 | max_iterations: 100000 13 | batch_size: 64 14 | num_workers: 8 15 | gpu_id: 0 16 | freeze_encoder: True 17 | seed: 0 18 | job_key: "" # Use this for submitit sweeps where timestamps might collide 19 | translation_scale: 1.0 20 | regression: False 21 | prob_unconditional: 0 22 | load_extra_cameras: False 23 | calculate_intrinsics: False 24 | distort: False 25 | normalize_first_camera: True 26 | diffuse_origins_and_endpoints: True 27 | diffuse_depths: False 28 | depth_resolution: 1 29 | dpt_head: False 30 | full_num_patches_x: 16 31 | full_num_patches_y: 16 32 | dpt_encoder_features: True 33 | nearest_neighbor: True 34 | no_bg_targets: True 35 | unit_normalize_scene: False 36 | sd_scale: 2 37 | bfloat: True 38 | first_cam_mediod: True 39 | gradient_clipping: False 40 | l1_loss: False 41 | grad_accumulation: False 42 | reinit: False 43 | 44 | model: 45 | pred_x0: True 46 | model_type: dit 47 | num_patches_x: 16 48 | num_patches_y: 16 49 | depth: 16 50 | num_images: 1 51 | random_num_images: True 52 | feature_extractor: dino 53 | append_ndc: True 54 | within_image: False 55 | use_homogeneous: True 56 | freeze_transformer: False 57 | cond_depth_mask: True 58 | 59 | noise_scheduler: 60 | type: linear 61 | max_timesteps: 100 62 | beta_start: 0.0120 63 | beta_end: 0.00085 64 | marigold_ddim: False 65 | 66 | dataset: 67 | name: co3d 68 | shape: all_train 69 | apply_augmentation: True 70 | use_global_intrinsics: True 71 | mask_holes: True 72 | image_size: 224 73 | 74 | debug: 75 | wandb: True 76 | project_name: diffusionsfm 77 | run_name: 78 | anomaly_detection: False 79 | 80 | hydra: 81 | run: 82 | dir: ./output/${now:%m%d_%H%M%S_%f}${training.job_key} 83 | output_subdir: hydra 84 | -------------------------------------------------------------------------------- /conf/diffusion.yml: -------------------------------------------------------------------------------- 1 | name: diffusion 2 | channels: 3 | - conda-forge 4 | - iopath 5 | - nvidia 6 | - pkgs/main 7 | - pytorch 8 | - xformers 9 | dependencies: 10 | - _libgcc_mutex=0.1=conda_forge 11 | - _openmp_mutex=4.5=2_gnu 12 | - blas=1.0=mkl 13 | - brotli-python=1.0.9=py39h5a03fae_9 14 | - bzip2=1.0.8=h7f98852_4 15 | - ca-certificates=2023.7.22=hbcca054_0 16 | - certifi=2023.7.22=pyhd8ed1ab_0 17 | - 
charset-normalizer=3.2.0=pyhd8ed1ab_0 18 | - colorama=0.4.6=pyhd8ed1ab_0 19 | - cuda-cudart=11.7.99=0 20 | - cuda-cupti=11.7.101=0 21 | - cuda-libraries=11.7.1=0 22 | - cuda-nvrtc=11.7.99=0 23 | - cuda-nvtx=11.7.91=0 24 | - cuda-runtime=11.7.1=0 25 | - ffmpeg=4.3=hf484d3e_0 26 | - filelock=3.12.2=pyhd8ed1ab_0 27 | - freetype=2.12.1=hca18f0e_1 28 | - fvcore=0.1.5.post20221221=pyhd8ed1ab_0 29 | - gmp=6.2.1=h58526e2_0 30 | - gmpy2=2.1.2=py39h376b7d2_1 31 | - gnutls=3.6.13=h85f3911_1 32 | - idna=3.4=pyhd8ed1ab_0 33 | - intel-openmp=2022.1.0=h9e868ea_3769 34 | - iopath=0.1.9=py39 35 | - jinja2=3.1.2=pyhd8ed1ab_1 36 | - jpeg=9e=h0b41bf4_3 37 | - lame=3.100=h166bdaf_1003 38 | - lcms2=2.15=hfd0df8a_0 39 | - ld_impl_linux-64=2.40=h41732ed_0 40 | - lerc=4.0.0=h27087fc_0 41 | - libblas=3.9.0=16_linux64_mkl 42 | - libcblas=3.9.0=16_linux64_mkl 43 | - libcublas=11.10.3.66=0 44 | - libcufft=10.7.2.124=h4fbf590_0 45 | - libcufile=1.7.1.12=0 46 | - libcurand=10.3.3.129=0 47 | - libcusolver=11.4.0.1=0 48 | - libcusparse=11.7.4.91=0 49 | - libdeflate=1.17=h0b41bf4_0 50 | - libffi=3.3=h58526e2_2 51 | - libgcc-ng=13.1.0=he5830b7_0 52 | - libgomp=13.1.0=he5830b7_0 53 | - libiconv=1.17=h166bdaf_0 54 | - liblapack=3.9.0=16_linux64_mkl 55 | - libnpp=11.7.4.75=0 56 | - libnvjpeg=11.8.0.2=0 57 | - libpng=1.6.39=h753d276_0 58 | - libsqlite=3.42.0=h2797004_0 59 | - libstdcxx-ng=13.1.0=hfd8a6a1_0 60 | - libtiff=4.5.0=h6adf6a1_2 61 | - libwebp-base=1.3.1=hd590300_0 62 | - libxcb=1.13=h7f98852_1004 63 | - libzlib=1.2.13=hd590300_5 64 | - markupsafe=2.1.3=py39hd1e30aa_0 65 | - mkl=2022.1.0=hc2b9512_224 66 | - mpc=1.3.1=hfe3b2da_0 67 | - mpfr=4.2.0=hb012696_0 68 | - mpmath=1.3.0=pyhd8ed1ab_0 69 | - ncurses=6.4=hcb278e6_0 70 | - nettle=3.6=he412f7d_0 71 | - networkx=3.1=pyhd8ed1ab_0 72 | - numpy=1.25.2=py39h6183b62_0 73 | - openh264=2.1.1=h780b84a_0 74 | - openjpeg=2.5.0=hfec8fc6_2 75 | - openssl=1.1.1v=hd590300_0 76 | - pillow=9.4.0=py39h2320bf1_1 77 | - pip=23.2.1=pyhd8ed1ab_0 78 | - portalocker=2.7.0=py39hf3d152e_0 79 | - pthread-stubs=0.4=h36c2ea0_1001 80 | - pysocks=1.7.1=pyha2e5f31_6 81 | - python=3.9.0=hffdb5ce_5_cpython 82 | - python_abi=3.9=3_cp39 83 | - pytorch=2.0.1=py3.9_cuda11.7_cudnn8.5.0_0 84 | - pytorch-cuda=11.7=h778d358_5 85 | - pytorch-mutex=1.0=cuda 86 | - pyyaml=6.0=py39hb9d737c_5 87 | - readline=8.2=h8228510_1 88 | - requests=2.31.0=pyhd8ed1ab_0 89 | - setuptools=68.0.0=pyhd8ed1ab_0 90 | - sqlite=3.42.0=h2c6b66d_0 91 | - sympy=1.12=pypyh9d50eac_103 92 | - tabulate=0.9.0=pyhd8ed1ab_1 93 | - termcolor=2.3.0=pyhd8ed1ab_0 94 | - tk=8.6.12=h27826a3_0 95 | - torchaudio=2.0.2=py39_cu117 96 | - torchtriton=2.0.0=py39 97 | - torchvision=0.15.2=py39_cu117 98 | - tqdm=4.66.1=pyhd8ed1ab_0 99 | - typing_extensions=4.7.1=pyha770c72_0 100 | - tzdata=2023c=h71feb2d_0 101 | - urllib3=2.0.4=pyhd8ed1ab_0 102 | - wheel=0.41.1=pyhd8ed1ab_0 103 | - xformers=0.0.21=py39_cu11.8.0_pyt2.0.1 104 | - xorg-libxau=1.0.11=hd590300_0 105 | - xorg-libxdmcp=1.1.3=h7f98852_0 106 | - xz=5.2.6=h166bdaf_0 107 | - yacs=0.1.8=pyhd8ed1ab_0 108 | - yaml=0.2.5=h7f98852_2 109 | - zlib=1.2.13=hd590300_5 110 | - zstd=1.5.2=hfc55251_7 111 | -------------------------------------------------------------------------------- /data/demo/jellycat/001.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/QitaoZhao/DiffusionSfM/4bb08800721bdcf46b0c823586a2fab4702ff282/data/demo/jellycat/001.jpg -------------------------------------------------------------------------------- 
/data/demo/jellycat/002.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/QitaoZhao/DiffusionSfM/4bb08800721bdcf46b0c823586a2fab4702ff282/data/demo/jellycat/002.jpg -------------------------------------------------------------------------------- /data/demo/jellycat/003.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/QitaoZhao/DiffusionSfM/4bb08800721bdcf46b0c823586a2fab4702ff282/data/demo/jellycat/003.jpg -------------------------------------------------------------------------------- /data/demo/jellycat/004.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/QitaoZhao/DiffusionSfM/4bb08800721bdcf46b0c823586a2fab4702ff282/data/demo/jellycat/004.jpg -------------------------------------------------------------------------------- /data/demo/jordan/001.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/QitaoZhao/DiffusionSfM/4bb08800721bdcf46b0c823586a2fab4702ff282/data/demo/jordan/001.png -------------------------------------------------------------------------------- /data/demo/jordan/002.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/QitaoZhao/DiffusionSfM/4bb08800721bdcf46b0c823586a2fab4702ff282/data/demo/jordan/002.png -------------------------------------------------------------------------------- /data/demo/jordan/003.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/QitaoZhao/DiffusionSfM/4bb08800721bdcf46b0c823586a2fab4702ff282/data/demo/jordan/003.png -------------------------------------------------------------------------------- /data/demo/jordan/004.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/QitaoZhao/DiffusionSfM/4bb08800721bdcf46b0c823586a2fab4702ff282/data/demo/jordan/004.png -------------------------------------------------------------------------------- /data/demo/jordan/005.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/QitaoZhao/DiffusionSfM/4bb08800721bdcf46b0c823586a2fab4702ff282/data/demo/jordan/005.png -------------------------------------------------------------------------------- /data/demo/jordan/006.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/QitaoZhao/DiffusionSfM/4bb08800721bdcf46b0c823586a2fab4702ff282/data/demo/jordan/006.png -------------------------------------------------------------------------------- /data/demo/jordan/007.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/QitaoZhao/DiffusionSfM/4bb08800721bdcf46b0c823586a2fab4702ff282/data/demo/jordan/007.png -------------------------------------------------------------------------------- /data/demo/jordan/008.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/QitaoZhao/DiffusionSfM/4bb08800721bdcf46b0c823586a2fab4702ff282/data/demo/jordan/008.png -------------------------------------------------------------------------------- /data/demo/kew_gardens_ruined_arch/001.jpeg: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/QitaoZhao/DiffusionSfM/4bb08800721bdcf46b0c823586a2fab4702ff282/data/demo/kew_gardens_ruined_arch/001.jpeg -------------------------------------------------------------------------------- /data/demo/kew_gardens_ruined_arch/002.jpeg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/QitaoZhao/DiffusionSfM/4bb08800721bdcf46b0c823586a2fab4702ff282/data/demo/kew_gardens_ruined_arch/002.jpeg -------------------------------------------------------------------------------- /data/demo/kew_gardens_ruined_arch/003.jpeg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/QitaoZhao/DiffusionSfM/4bb08800721bdcf46b0c823586a2fab4702ff282/data/demo/kew_gardens_ruined_arch/003.jpeg -------------------------------------------------------------------------------- /data/demo/kotor_cathedral/001.jpeg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/QitaoZhao/DiffusionSfM/4bb08800721bdcf46b0c823586a2fab4702ff282/data/demo/kotor_cathedral/001.jpeg -------------------------------------------------------------------------------- /data/demo/kotor_cathedral/002.jpeg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/QitaoZhao/DiffusionSfM/4bb08800721bdcf46b0c823586a2fab4702ff282/data/demo/kotor_cathedral/002.jpeg -------------------------------------------------------------------------------- /data/demo/kotor_cathedral/003.jpeg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/QitaoZhao/DiffusionSfM/4bb08800721bdcf46b0c823586a2fab4702ff282/data/demo/kotor_cathedral/003.jpeg -------------------------------------------------------------------------------- /data/demo/kotor_cathedral/004.jpeg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/QitaoZhao/DiffusionSfM/4bb08800721bdcf46b0c823586a2fab4702ff282/data/demo/kotor_cathedral/004.jpeg -------------------------------------------------------------------------------- /data/demo/kotor_cathedral/005.jpeg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/QitaoZhao/DiffusionSfM/4bb08800721bdcf46b0c823586a2fab4702ff282/data/demo/kotor_cathedral/005.jpeg -------------------------------------------------------------------------------- /data/demo/kotor_cathedral/006.jpeg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/QitaoZhao/DiffusionSfM/4bb08800721bdcf46b0c823586a2fab4702ff282/data/demo/kotor_cathedral/006.jpeg -------------------------------------------------------------------------------- /diffusionsfm/__init__.py: -------------------------------------------------------------------------------- 1 | from .utils.rays import cameras_to_rays, rays_to_cameras, Rays 2 | -------------------------------------------------------------------------------- /diffusionsfm/dataset/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/QitaoZhao/DiffusionSfM/4bb08800721bdcf46b0c823586a2fab4702ff282/diffusionsfm/dataset/__init__.py 
-------------------------------------------------------------------------------- /diffusionsfm/dataset/co3d_v2.py: -------------------------------------------------------------------------------- 1 | import gzip 2 | import json 3 | import os.path as osp 4 | import random 5 | import socket 6 | import time 7 | import torch 8 | import warnings 9 | 10 | import numpy as np 11 | from PIL import Image, ImageFile 12 | from tqdm import tqdm 13 | from pytorch3d.renderer import PerspectiveCameras 14 | from torch.utils.data import Dataset 15 | from torchvision import transforms 16 | import matplotlib.pyplot as plt 17 | from scipy import ndimage as nd 18 | 19 | from diffusionsfm.utils.distortion import distort_image 20 | 21 | 22 | HOSTNAME = socket.gethostname() 23 | 24 | CO3D_DIR = "../co3d_data" # update this 25 | CO3D_ANNOTATION_DIR = osp.join(CO3D_DIR, "co3d_annotations") 26 | CO3D_DIR = CO3D_DEPTH_DIR = osp.join(CO3D_DIR, "co3d") 27 | order_path = osp.join( 28 | CO3D_DIR, "co3d_v2_random_order_{sample_num}/{category}.json" 29 | ) 30 | 31 | 32 | TRAINING_CATEGORIES = [ 33 | "apple", 34 | "backpack", 35 | "banana", 36 | "baseballbat", 37 | "baseballglove", 38 | "bench", 39 | "bicycle", 40 | "bottle", 41 | "bowl", 42 | "broccoli", 43 | "cake", 44 | "car", 45 | "carrot", 46 | "cellphone", 47 | "chair", 48 | "cup", 49 | "donut", 50 | "hairdryer", 51 | "handbag", 52 | "hydrant", 53 | "keyboard", 54 | "laptop", 55 | "microwave", 56 | "motorcycle", 57 | "mouse", 58 | "orange", 59 | "parkingmeter", 60 | "pizza", 61 | "plant", 62 | "stopsign", 63 | "teddybear", 64 | "toaster", 65 | "toilet", 66 | "toybus", 67 | "toyplane", 68 | "toytrain", 69 | "toytruck", 70 | "tv", 71 | "umbrella", 72 | "vase", 73 | "wineglass", 74 | ] 75 | 76 | TEST_CATEGORIES = [ 77 | "ball", 78 | "book", 79 | "couch", 80 | "frisbee", 81 | "hotdog", 82 | "kite", 83 | "remote", 84 | "sandwich", 85 | "skateboard", 86 | "suitcase", 87 | ] 88 | 89 | assert len(TRAINING_CATEGORIES) + len(TEST_CATEGORIES) == 51 90 | 91 | Image.MAX_IMAGE_PIXELS = None 92 | ImageFile.LOAD_TRUNCATED_IMAGES = True 93 | 94 | 95 | def fill_depths(data, invalid=None): 96 | data_list = [] 97 | for i in range(data.shape[0]): 98 | data_item = data[i].numpy() 99 | # Invalid must be 1 where stuff is invalid, 0 where valid 100 | ind = nd.distance_transform_edt( 101 | invalid[i], return_distances=False, return_indices=True 102 | ) 103 | data_list.append(torch.tensor(data_item[tuple(ind)])) 104 | return torch.stack(data_list, dim=0) 105 | 106 | 107 | def full_scene_scale(batch): 108 | cameras = PerspectiveCameras(R=batch["R"], T=batch["T"], device="cuda") 109 | cc = cameras.get_camera_center() 110 | centroid = torch.mean(cc, dim=0) 111 | 112 | diffs = cc - centroid 113 | norms = torch.linalg.norm(diffs, dim=1) 114 | 115 | furthest_index = torch.argmax(norms).item() 116 | scale = norms[furthest_index].item() 117 | return scale 118 | 119 | 120 | def square_bbox(bbox, padding=0.0, astype=None, tight=False): 121 | """ 122 | Computes a square bounding box, with optional padding parameters. 123 | Args: 124 | bbox: Bounding box in xyxy format (4,). 125 | Returns: 126 | square_bbox in xyxy format (4,). 
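    Example (illustrative, with the default padding=0.0 and tight=False):
        square_bbox([10, 20, 50, 100]) -> array([-10, 20, 70, 100])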
127 | """ 128 | if astype is None: 129 | astype = type(bbox[0]) 130 | bbox = np.array(bbox) 131 | center = (bbox[:2] + bbox[2:]) / 2 132 | extents = (bbox[2:] - bbox[:2]) / 2 133 | 134 | # No black bars if tight 135 | if tight: 136 | s = min(extents) * (1 + padding) 137 | else: 138 | s = max(extents) * (1 + padding) 139 | 140 | square_bbox = np.array( 141 | [center[0] - s, center[1] - s, center[0] + s, center[1] + s], 142 | dtype=astype, 143 | ) 144 | return square_bbox 145 | 146 | 147 | def unnormalize_image(image, return_numpy=True, return_int=True): 148 | if isinstance(image, torch.Tensor): 149 | image = image.detach().cpu().numpy() 150 | 151 | if image.ndim == 3: 152 | if image.shape[0] == 3: 153 | image = image[None, ...] 154 | elif image.shape[2] == 3: 155 | image = image.transpose(2, 0, 1)[None, ...] 156 | else: 157 | raise ValueError(f"Unexpected image shape: {image.shape}") 158 | elif image.ndim == 4: 159 | if image.shape[1] == 3: 160 | pass 161 | elif image.shape[3] == 3: 162 | image = image.transpose(0, 3, 1, 2) 163 | else: 164 | raise ValueError(f"Unexpected batch image shape: {image.shape}") 165 | else: 166 | raise ValueError(f"Unsupported input shape: {image.shape}") 167 | 168 | mean = np.array([0.485, 0.456, 0.406])[None, :, None, None] 169 | std = np.array([0.229, 0.224, 0.225])[None, :, None, None] 170 | image = image * std + mean 171 | 172 | if return_int: 173 | image = np.clip(image * 255.0, 0, 255).astype(np.uint8) 174 | else: 175 | image = np.clip(image, 0.0, 1.0) 176 | 177 | if image.shape[0] == 1: 178 | image = image[0] 179 | 180 | if return_numpy: 181 | return image 182 | else: 183 | return torch.from_numpy(image) 184 | 185 | 186 | def unnormalize_image_for_vis(image): 187 | assert len(image.shape) == 5 and image.shape[2] == 3 188 | mean = torch.tensor([0.485, 0.456, 0.406]).view(1, 1, 3, 1, 1).to(image.device) 189 | std = torch.tensor([0.229, 0.224, 0.225]).view(1, 1, 3, 1, 1).to(image.device) 190 | image = image * std + mean 191 | image = (image - 0.5) / 0.5 192 | return image 193 | 194 | 195 | def _transform_intrinsic(image, bbox, principal_point, focal_length): 196 | # Rescale intrinsics to match bbox 197 | half_box = np.array([image.width, image.height]).astype(np.float32) / 2 198 | org_scale = min(half_box).astype(np.float32) 199 | 200 | # Pixel coordinates 201 | principal_point_px = half_box - (np.array(principal_point) * org_scale) 202 | focal_length_px = np.array(focal_length) * org_scale 203 | principal_point_px -= bbox[:2] 204 | new_bbox = (bbox[2:] - bbox[:2]) / 2 205 | new_scale = min(new_bbox) 206 | 207 | # NDC coordinates 208 | new_principal_ndc = (new_bbox - principal_point_px) / new_scale 209 | new_focal_ndc = focal_length_px / new_scale 210 | 211 | principal_point = torch.tensor(new_principal_ndc.astype(np.float32)) 212 | focal_length = torch.tensor(new_focal_ndc.astype(np.float32)) 213 | 214 | return principal_point, focal_length 215 | 216 | 217 | def construct_camera_from_batch(batch, device): 218 | if isinstance(device, int): 219 | device = f"cuda:{device}" 220 | 221 | return PerspectiveCameras( 222 | R=batch["R"].reshape(-1, 3, 3), 223 | T=batch["T"].reshape(-1, 3), 224 | focal_length=batch["focal_lengths"].reshape(-1, 2), 225 | principal_point=batch["principal_points"].reshape(-1, 2), 226 | image_size=batch["image_sizes"].reshape(-1, 2), 227 | device=device, 228 | ) 229 | 230 | 231 | def save_batch_images(images, fname): 232 | cmap = plt.get_cmap("hsv") 233 | num_frames = len(images) 234 | num_rows = len(images) 235 | num_cols = 4 236 | 
figsize = (num_cols * 2, num_rows * 2) 237 | fig, axs = plt.subplots(num_rows, num_cols, figsize=figsize) 238 | axs = axs.flatten() 239 | for i in range(num_rows): 240 | for j in range(4): 241 | if i < num_frames: 242 | axs[i * 4 + j].imshow(unnormalize_image(images[i][j])) 243 | for s in ["bottom", "top", "left", "right"]: 244 | axs[i * 4 + j].spines[s].set_color(cmap(i / (num_frames))) 245 | axs[i * 4 + j].spines[s].set_linewidth(5) 246 | axs[i * 4 + j].set_xticks([]) 247 | axs[i * 4 + j].set_yticks([]) 248 | else: 249 | axs[i * 4 + j].axis("off") 250 | plt.tight_layout() 251 | plt.savefig(fname) 252 | 253 | 254 | def jitter_bbox( 255 | square_bbox, 256 | jitter_scale=(1.1, 1.2), 257 | jitter_trans=(-0.07, 0.07), 258 | direction_from_size=None, 259 | ): 260 | 261 | square_bbox = np.array(square_bbox.astype(float)) 262 | s = np.random.uniform(jitter_scale[0], jitter_scale[1]) 263 | 264 | # Jitter only one dimension if center cropping 265 | tx, ty = np.random.uniform(jitter_trans[0], jitter_trans[1], size=2) 266 | if direction_from_size is not None: 267 | if direction_from_size[0] > direction_from_size[1]: 268 | tx = 0 269 | else: 270 | ty = 0 271 | 272 | side_length = square_bbox[2] - square_bbox[0] 273 | center = (square_bbox[:2] + square_bbox[2:]) / 2 + np.array([tx, ty]) * side_length 274 | extent = side_length / 2 * s 275 | ul = center - extent 276 | lr = ul + 2 * extent 277 | return np.concatenate((ul, lr)) 278 | 279 | 280 | class Co3dDataset(Dataset): 281 | def __init__( 282 | self, 283 | category=("all_train",), 284 | split="train", 285 | transform=None, 286 | num_images=2, 287 | img_size=224, 288 | mask_images=False, 289 | crop_images=True, 290 | co3d_dir=None, 291 | co3d_annotation_dir=None, 292 | precropped_images=False, 293 | apply_augmentation=True, 294 | normalize_cameras=True, 295 | no_images=False, 296 | sample_num=None, 297 | seed=0, 298 | load_extra_cameras=False, 299 | distort_image=False, 300 | load_depths=False, 301 | center_crop=False, 302 | depth_size=256, 303 | mask_holes=False, 304 | object_mask=True, 305 | ): 306 | """ 307 | Args: 308 | num_images: Number of images in each batch. 309 | perspective_correction (str): 310 | "none": No perspective correction. 311 | "warp": Warp the image and label. 312 | "label_only": Correct the label only. 313 | """ 314 | start_time = time.time() 315 | 316 | self.category = category 317 | self.split = split 318 | self.transform = transform 319 | self.num_images = num_images 320 | self.img_size = img_size 321 | self.mask_images = mask_images 322 | self.crop_images = crop_images 323 | self.precropped_images = precropped_images 324 | self.apply_augmentation = apply_augmentation 325 | self.normalize_cameras = normalize_cameras 326 | self.no_images = no_images 327 | self.sample_num = sample_num 328 | self.load_extra_cameras = load_extra_cameras 329 | self.distort = distort_image 330 | self.load_depths = load_depths 331 | self.center_crop = center_crop 332 | self.depth_size = depth_size 333 | self.mask_holes = mask_holes 334 | self.object_mask = object_mask 335 | 336 | if self.apply_augmentation: 337 | if self.center_crop: 338 | self.jitter_scale = (0.8, 1.1) 339 | self.jitter_trans = (0.0, 0.0) 340 | else: 341 | self.jitter_scale = (1.1, 1.2) 342 | self.jitter_trans = (-0.07, 0.07) 343 | else: 344 | # Note if trained with apply_augmentation, we should still use 345 | # apply_augmentation at test time. 
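            # The values below keep the crop deterministic: unit scale and zero translation.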
346 | self.jitter_scale = (1, 1) 347 | self.jitter_trans = (0.0, 0.0) 348 | 349 | if self.distort: 350 | self.k1_max = 1.0 351 | self.k2_max = 1.0 352 | 353 | if co3d_dir is not None: 354 | self.co3d_dir = co3d_dir 355 | self.co3d_annotation_dir = co3d_annotation_dir 356 | else: 357 | self.co3d_dir = CO3D_DIR 358 | self.co3d_annotation_dir = CO3D_ANNOTATION_DIR 359 | self.co3d_depth_dir = CO3D_DEPTH_DIR 360 | 361 | if isinstance(self.category, str): 362 | self.category = [self.category] 363 | 364 | if "all_train" in self.category: 365 | self.category = TRAINING_CATEGORIES 366 | if "all_test" in self.category: 367 | self.category = TEST_CATEGORIES 368 | if "full" in self.category: 369 | self.category = TRAINING_CATEGORIES + TEST_CATEGORIES 370 | self.category = sorted(self.category) 371 | self.is_single_category = len(self.category) == 1 372 | 373 | # Fixing seed 374 | torch.manual_seed(seed) 375 | random.seed(seed) 376 | np.random.seed(seed) 377 | 378 | print(f"Co3d ({split}):") 379 | 380 | self.low_quality_translations = [ 381 | "411_55952_107659", 382 | "427_59915_115716", 383 | "435_61970_121848", 384 | "112_13265_22828", 385 | "110_13069_25642", 386 | "165_18080_34378", 387 | "368_39891_78502", 388 | "391_47029_93665", 389 | "20_695_1450", 390 | "135_15556_31096", 391 | "417_57572_110680", 392 | ] # Initialized with sequences with poor depth masks 393 | self.rotations = {} 394 | self.category_map = {} 395 | for c in tqdm(self.category): 396 | annotation_file = osp.join( 397 | self.co3d_annotation_dir, f"{c}_{self.split}.jgz" 398 | ) 399 | with gzip.open(annotation_file, "r") as fin: 400 | annotation = json.loads(fin.read()) 401 | 402 | counter = 0 403 | for seq_name, seq_data in annotation.items(): 404 | counter += 1 405 | if len(seq_data) < self.num_images: 406 | continue 407 | 408 | filtered_data = [] 409 | self.category_map[seq_name] = c 410 | bad_seq = False 411 | for data in seq_data: 412 | # Make sure translations are not ridiculous and rotations are valid 413 | det = np.linalg.det(data["R"]) 414 | if (np.abs(data["T"]) > 1e5).any() or det < 0.99 or det > 1.01: 415 | bad_seq = True 416 | self.low_quality_translations.append(seq_name) 417 | break 418 | 419 | # Ignore all unnecessary information. 
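                    # Keep only the fields needed downstream: image path, bbox, and camera extrinsics/intrinsics.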
420 | filtered_data.append( 421 | { 422 | "filepath": data["filepath"], 423 | "bbox": data["bbox"], 424 | "R": data["R"], 425 | "T": data["T"], 426 | "focal_length": data["focal_length"], 427 | "principal_point": data["principal_point"], 428 | }, 429 | ) 430 | 431 | if not bad_seq: 432 | self.rotations[seq_name] = filtered_data 433 | 434 | self.sequence_list = list(self.rotations.keys()) 435 | 436 | IMAGENET_DEFAULT_MEAN = (0.485, 0.456, 0.406) 437 | IMAGENET_DEFAULT_STD = (0.229, 0.224, 0.225) 438 | 439 | if self.transform is None: 440 | self.transform = transforms.Compose( 441 | [ 442 | transforms.ToTensor(), 443 | transforms.Resize(self.img_size, antialias=True), 444 | transforms.Normalize(mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD), 445 | ] 446 | ) 447 | 448 | self.transform_depth = transforms.Compose( 449 | [ 450 | transforms.Resize( 451 | self.depth_size, 452 | antialias=False, 453 | interpolation=transforms.InterpolationMode.NEAREST_EXACT, 454 | ), 455 | ] 456 | ) 457 | 458 | print( 459 | f"Low quality translation sequences, not used: {self.low_quality_translations}" 460 | ) 461 | print(f"Data size: {len(self)}") 462 | print(f"Data loading took {(time.time()-start_time)} seconds.") 463 | 464 | def __len__(self): 465 | return len(self.sequence_list) 466 | 467 | def __getitem__(self, index): 468 | num_to_load = self.num_images if not self.load_extra_cameras else 8 469 | 470 | sequence_name = self.sequence_list[index % len(self.sequence_list)] 471 | metadata = self.rotations[sequence_name] 472 | 473 | if self.sample_num is not None: 474 | with open( 475 | order_path.format(sample_num=self.sample_num, category=self.category[0]) 476 | ) as f: 477 | order = json.load(f) 478 | ids = order[sequence_name][:num_to_load] 479 | else: 480 | replace = len(metadata) < 8 481 | ids = np.random.choice(len(metadata), num_to_load, replace=replace) 482 | 483 | return self.get_data(index=index, ids=ids, num_valid_frames=num_to_load) 484 | 485 | def _get_scene_scale(self, sequence_name): 486 | n = len(self.rotations[sequence_name]) 487 | 488 | R = torch.zeros(n, 3, 3) 489 | T = torch.zeros(n, 3) 490 | 491 | for i, ann in enumerate(self.rotations[sequence_name]): 492 | R[i, ...] = torch.tensor(self.rotations[sequence_name][i]["R"]) 493 | T[i, ...] 
= torch.tensor(self.rotations[sequence_name][i]["T"]) 494 | 495 | cameras = PerspectiveCameras(R=R, T=T) 496 | cc = cameras.get_camera_center() 497 | centeroid = torch.mean(cc, dim=0) 498 | diff = cc - centeroid 499 | 500 | norm = torch.norm(diff, dim=1) 501 | scale = torch.max(norm).item() 502 | 503 | return scale 504 | 505 | def _crop_image(self, image, bbox): 506 | image_crop = transforms.functional.crop( 507 | image, 508 | top=bbox[1], 509 | left=bbox[0], 510 | height=bbox[3] - bbox[1], 511 | width=bbox[2] - bbox[0], 512 | ) 513 | return image_crop 514 | 515 | def _transform_intrinsic(self, image, bbox, principal_point, focal_length): 516 | half_box = np.array([image.width, image.height]).astype(np.float32) / 2 517 | org_scale = min(half_box).astype(np.float32) 518 | 519 | # Pixel coordinates 520 | principal_point_px = half_box - (np.array(principal_point) * org_scale) 521 | focal_length_px = np.array(focal_length) * org_scale 522 | principal_point_px -= bbox[:2] 523 | new_bbox = (bbox[2:] - bbox[:2]) / 2 524 | new_scale = min(new_bbox) 525 | 526 | # NDC coordinates 527 | new_principal_ndc = (new_bbox - principal_point_px) / new_scale 528 | new_focal_ndc = focal_length_px / new_scale 529 | 530 | return new_principal_ndc.astype(np.float32), new_focal_ndc.astype(np.float32) 531 | 532 | def get_data( 533 | self, 534 | index=None, 535 | sequence_name=None, 536 | ids=(0, 1), 537 | no_images=False, 538 | num_valid_frames=None, 539 | load_using_order=None, 540 | ): 541 | if load_using_order is not None: 542 | with open( 543 | order_path.format(sample_num=self.sample_num, category=self.category[0]) 544 | ) as f: 545 | order = json.load(f) 546 | ids = order[sequence_name][:load_using_order] 547 | 548 | if sequence_name is None: 549 | index = index % len(self.sequence_list) 550 | sequence_name = self.sequence_list[index] 551 | metadata = self.rotations[sequence_name] 552 | category = self.category_map[sequence_name] 553 | 554 | # Read image & camera information from annotations 555 | annos = [metadata[i] for i in ids] 556 | images = [] 557 | image_sizes = [] 558 | PP = [] 559 | FL = [] 560 | crop_parameters = [] 561 | filenames = [] 562 | distortion_parameters = [] 563 | depths = [] 564 | depth_masks = [] 565 | object_masks = [] 566 | dino_images = [] 567 | for anno in annos: 568 | filepath = anno["filepath"] 569 | 570 | if not no_images: 571 | image = Image.open(osp.join(self.co3d_dir, filepath)).convert("RGB") 572 | image_size = image.size 573 | 574 | # Optionally mask images with black background 575 | if self.mask_images: 576 | black_image = Image.new("RGB", image_size, (0, 0, 0)) 577 | mask_name = osp.basename(filepath.replace(".jpg", ".png")) 578 | 579 | mask_path = osp.join( 580 | self.co3d_dir, category, sequence_name, "masks", mask_name 581 | ) 582 | mask = Image.open(mask_path).convert("L") 583 | 584 | if mask.size != image_size: 585 | mask = mask.resize(image_size) 586 | mask = Image.fromarray(np.array(mask) > 125) 587 | image = Image.composite(image, black_image, mask) 588 | 589 | if self.object_mask: 590 | mask_name = osp.basename(filepath.replace(".jpg", ".png")) 591 | mask_path = osp.join( 592 | self.co3d_dir, category, sequence_name, "masks", mask_name 593 | ) 594 | mask = Image.open(mask_path).convert("L") 595 | 596 | if mask.size != image_size: 597 | mask = mask.resize(image_size) 598 | mask = torch.from_numpy(np.array(mask) > 125) 599 | 600 | # Determine crop, Resnet wants square images 601 | bbox = np.array(anno["bbox"]) 602 | good_bbox = ((bbox[2:] - bbox[:2]) > 30).all() 
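                # Fall back to a full-image crop when center-cropping or when the annotated bbox is degenerate (any side <= 30 px).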
603 | bbox = ( 604 | anno["bbox"] 605 | if not self.center_crop and good_bbox 606 | else [0, 0, image.width, image.height] 607 | ) 608 | 609 | # Distort image and bbox if desired 610 | if self.distort: 611 | k1 = random.uniform(0, self.k1_max) 612 | k2 = random.uniform(0, self.k2_max) 613 | 614 | try: 615 | image, bbox = distort_image( 616 | image, np.array(bbox), k1, k2, modify_bbox=True 617 | ) 618 | 619 | except: 620 | print("INFO:") 621 | print(sequence_name) 622 | print(index) 623 | print(ids) 624 | print(k1) 625 | print(k2) 626 | 627 | distortion_parameters.append(torch.FloatTensor([k1, k2])) 628 | 629 | bbox = square_bbox(np.array(bbox), tight=self.center_crop) 630 | if self.apply_augmentation: 631 | bbox = jitter_bbox( 632 | bbox, 633 | jitter_scale=self.jitter_scale, 634 | jitter_trans=self.jitter_trans, 635 | direction_from_size=image.size if self.center_crop else None, 636 | ) 637 | bbox = np.around(bbox).astype(int) 638 | 639 | # Crop parameters 640 | crop_center = (bbox[:2] + bbox[2:]) / 2 641 | principal_point = torch.tensor(anno["principal_point"]) 642 | focal_length = torch.tensor(anno["focal_length"]) 643 | 644 | # convert crop center to correspond to a "square" image 645 | width, height = image.size 646 | length = max(width, height) 647 | s = length / min(width, height) 648 | crop_center = crop_center + (length - np.array([width, height])) / 2 649 | 650 | # convert to NDC 651 | cc = s - 2 * s * crop_center / length 652 | crop_width = 2 * s * (bbox[2] - bbox[0]) / length 653 | crop_params = torch.tensor([-cc[0], -cc[1], crop_width, s]) 654 | 655 | # Crop and normalize image 656 | if not self.precropped_images: 657 | image = self._crop_image(image, bbox) 658 | 659 | try: 660 | image = self.transform(image) 661 | except: 662 | print("INFO:") 663 | print(sequence_name) 664 | print(index) 665 | print(ids) 666 | print(k1) 667 | print(k2) 668 | 669 | images.append(image[:, : self.img_size, : self.img_size]) 670 | crop_parameters.append(crop_params) 671 | 672 | if self.load_depths: 673 | # Open depth map 674 | depth_name = osp.basename( 675 | filepath.replace(".jpg", ".jpg.geometric.png") 676 | ) 677 | depth_path = osp.join( 678 | self.co3d_depth_dir, 679 | category, 680 | sequence_name, 681 | "depths", 682 | depth_name, 683 | ) 684 | depth_pil = Image.open(depth_path) 685 | 686 | # 16 bit float type casting 687 | depth = torch.tensor( 688 | np.frombuffer( 689 | np.array(depth_pil, dtype=np.uint16), dtype=np.float16 690 | ) 691 | .astype(np.float32) 692 | .reshape((depth_pil.size[1], depth_pil.size[0])) 693 | ) 694 | 695 | # Crop and resize as with images 696 | if depth_pil.size != image_size: 697 | # bbox may have the wrong scale 698 | bbox = depth_pil.size[0] * bbox / image_size[0] 699 | 700 | if self.object_mask: 701 | assert mask.shape == depth.shape 702 | 703 | bbox = np.around(bbox).astype(int) 704 | depth = self._crop_image(depth, bbox) 705 | 706 | # Resize 707 | depth = self.transform_depth(depth.unsqueeze(0))[ 708 | 0, : self.depth_size, : self.depth_size 709 | ] 710 | depths.append(depth) 711 | 712 | if self.object_mask: 713 | mask = self._crop_image(mask, bbox) 714 | mask = self.transform_depth(mask.unsqueeze(0))[ 715 | 0, : self.depth_size, : self.depth_size 716 | ] 717 | object_masks.append(mask) 718 | 719 | PP.append(principal_point) 720 | FL.append(focal_length) 721 | image_sizes.append(torch.tensor([self.img_size, self.img_size])) 722 | filenames.append(filepath) 723 | 724 | if not no_images: 725 | if self.load_depths: 726 | depths = torch.stack(depths) 727 | 728 
| depth_masks = torch.logical_or(depths <= 0, depths.isinf()) 729 | depth_masks = (~depth_masks).long() 730 | 731 | if self.object_mask: 732 | object_masks = torch.stack(object_masks, dim=0) 733 | 734 | if self.mask_holes: 735 | depths = fill_depths(depths, depth_masks == 0) 736 | 737 | # Sometimes mask_holes misses stuff 738 | new_masks = torch.logical_or(depths <= 0, depths.isinf()) 739 | new_masks = (~new_masks).long() 740 | depths[new_masks == 0] = -1 741 | 742 | assert torch.logical_or(depths > 0, depths == -1).all() 743 | assert not (depths.isinf()).any() 744 | assert not (depths.isnan()).any() 745 | 746 | if self.load_extra_cameras: 747 | # Remove the extra loaded image, for saving space 748 | images = images[: self.num_images] 749 | 750 | if self.distort: 751 | distortion_parameters = torch.stack(distortion_parameters) 752 | 753 | images = torch.stack(images) 754 | crop_parameters = torch.stack(crop_parameters) 755 | focal_lengths = torch.stack(FL) 756 | principal_points = torch.stack(PP) 757 | image_sizes = torch.stack(image_sizes) 758 | else: 759 | images = None 760 | crop_parameters = None 761 | distortion_parameters = None 762 | focal_lengths = [] 763 | principal_points = [] 764 | image_sizes = [] 765 | 766 | # Assemble batch info to send back 767 | R = torch.stack([torch.tensor(anno["R"]) for anno in annos]) 768 | T = torch.stack([torch.tensor(anno["T"]) for anno in annos]) 769 | 770 | batch = { 771 | "model_id": sequence_name, 772 | "category": category, 773 | "n": len(metadata), 774 | "num_valid_frames": num_valid_frames, 775 | "ind": torch.tensor(ids), 776 | "image": images, 777 | "depth": depths, 778 | "depth_masks": depth_masks, 779 | "object_masks": object_masks, 780 | "R": R, 781 | "T": T, 782 | "focal_length": focal_lengths, 783 | "principal_point": principal_points, 784 | "image_size": image_sizes, 785 | "crop_parameters": crop_parameters, 786 | "distortion_parameters": torch.zeros(4), 787 | "filename": filenames, 788 | "category": category, 789 | "dataset": "co3d", 790 | } 791 | 792 | return batch 793 | -------------------------------------------------------------------------------- /diffusionsfm/dataset/custom.py: -------------------------------------------------------------------------------- 1 | 2 | import torch 3 | import numpy as np 4 | import matplotlib.pyplot as plt 5 | 6 | from PIL import Image, ImageOps 7 | from torch.utils.data import Dataset 8 | from torchvision import transforms 9 | 10 | from diffusionsfm.dataset.co3d_v2 import square_bbox 11 | 12 | 13 | class CustomDataset(Dataset): 14 | def __init__( 15 | self, 16 | image_list, 17 | ): 18 | self.images = [] 19 | 20 | for image_path in sorted(image_list): 21 | img = Image.open(image_path) 22 | img = ImageOps.exif_transpose(img).convert("RGB") # Apply EXIF rotation 23 | self.images.append(img) 24 | 25 | self.n = len(self.images) 26 | self.jitter_scale = [1, 1] 27 | self.jitter_trans = [0, 0] 28 | self.transform = transforms.Compose( 29 | [ 30 | transforms.ToTensor(), 31 | transforms.Resize(224), 32 | transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]), 33 | ] 34 | ) 35 | self.transform_for_vis = transforms.Compose( 36 | [ 37 | transforms.Resize(224), 38 | ] 39 | ) 40 | 41 | def __len__(self): 42 | return 1 43 | 44 | def _crop_image(self, image, bbox, white_bg=False): 45 | if white_bg: 46 | # Only support PIL Images 47 | image_crop = Image.new( 48 | "RGB", (bbox[2] - bbox[0], bbox[3] - bbox[1]), (255, 255, 255) 49 | ) 50 | image_crop.paste(image, (-bbox[0], -bbox[1])) 51 | else: 52 | 
image_crop = transforms.functional.crop( 53 | image, 54 | top=bbox[1], 55 | left=bbox[0], 56 | height=bbox[3] - bbox[1], 57 | width=bbox[2] - bbox[0], 58 | ) 59 | return image_crop 60 | 61 | def __getitem__(self): 62 | return self.get_data() 63 | 64 | def get_data(self): 65 | cmap = plt.get_cmap("hsv") 66 | ids = [i for i in range(len(self.images))] 67 | images = [self.images[i] for i in ids] 68 | images_transformed = [] 69 | images_for_vis = [] 70 | crop_parameters = [] 71 | 72 | for i, image in enumerate(images): 73 | bbox = np.array([0, 0, image.width, image.height]) 74 | bbox = square_bbox(bbox, tight=True) 75 | bbox = np.around(bbox).astype(int) 76 | image = self._crop_image(image, bbox) 77 | images_transformed.append(self.transform(image)) 78 | image_for_vis = self.transform_for_vis(image) 79 | color_float = cmap(i / len(images)) 80 | color_rgb = tuple(int(255 * c) for c in color_float[:3]) 81 | image_for_vis = ImageOps.expand(image_for_vis, border=3, fill=color_rgb) 82 | images_for_vis.append(image_for_vis) 83 | 84 | width, height = image.size 85 | length = max(width, height) 86 | s = length / min(width, height) 87 | crop_center = (bbox[:2] + bbox[2:]) / 2 88 | crop_center = crop_center + (length - np.array([width, height])) / 2 89 | # convert to NDC 90 | cc = s - 2 * s * crop_center / length 91 | crop_width = 2 * s * (bbox[2] - bbox[0]) / length 92 | crop_params = torch.tensor([-cc[0], -cc[1], crop_width, s]) 93 | 94 | crop_parameters.append(crop_params) 95 | images = images_transformed 96 | 97 | batch = {} 98 | batch["image"] = torch.stack(images) 99 | batch["image_for_vis"] = images_for_vis 100 | batch["n"] = len(images) 101 | batch["ind"] = torch.tensor(ids), 102 | batch["crop_parameters"] = torch.stack(crop_parameters) 103 | batch["distortion_parameters"] = torch.zeros(4) 104 | 105 | return batch 106 | -------------------------------------------------------------------------------- /diffusionsfm/eval/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/QitaoZhao/DiffusionSfM/4bb08800721bdcf46b0c823586a2fab4702ff282/diffusionsfm/eval/__init__.py -------------------------------------------------------------------------------- /diffusionsfm/eval/eval_category.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | import torch 4 | import torchvision 5 | import numpy as np 6 | from tqdm.auto import tqdm 7 | 8 | from diffusionsfm.dataset.co3d_v2 import ( 9 | Co3dDataset, 10 | full_scene_scale, 11 | ) 12 | from pytorch3d.renderer import PerspectiveCameras 13 | from diffusionsfm.utils.visualization import filter_and_align_point_clouds 14 | from diffusionsfm.inference.load_model import load_model 15 | from diffusionsfm.inference.predict import predict_cameras 16 | from diffusionsfm.utils.geometry import ( 17 | compute_angular_error_batch, 18 | get_error, 19 | n_to_np_rotations, 20 | ) 21 | from diffusionsfm.utils.slurm import init_slurm_signals_if_slurm 22 | from diffusionsfm.utils.rays import cameras_to_rays 23 | from diffusionsfm.utils.rays import normalize_cameras_batch 24 | 25 | 26 | @torch.no_grad() 27 | def evaluate( 28 | cfg, 29 | model, 30 | dataset, 31 | num_images, 32 | device, 33 | use_pbar=True, 34 | calculate_intrinsics=True, 35 | additional_timesteps=(), 36 | num_evaluate=None, 37 | max_num_images=None, 38 | mode=None, 39 | metrics=True, 40 | load_depth=True, 41 | ): 42 | if cfg.training.get("dpt_head", False): 43 | H_in = W_in = 224 44 
| H_out = W_out = cfg.training.full_num_patches_y 45 | else: 46 | H_in = H_out = cfg.model.num_patches_x 47 | W_in = W_out = cfg.model.num_patches_y 48 | 49 | results = {} 50 | instances = np.arange(0, len(dataset)) if num_evaluate is None else np.linspace(0, len(dataset) - 1, num_evaluate, endpoint=True, dtype=int) 51 | instances = tqdm(instances) if use_pbar else instances 52 | 53 | for counter, idx in enumerate(instances): 54 | batch = dataset[idx] 55 | instance = batch["model_id"] 56 | images = batch["image"].to(device) 57 | focal_length = batch["focal_length"].to(device)[:num_images] 58 | R = batch["R"].to(device)[:num_images] 59 | T = batch["T"].to(device)[:num_images] 60 | crop_parameters = batch["crop_parameters"].to(device)[:num_images] 61 | 62 | if load_depth: 63 | depths = batch["depth"].to(device)[:num_images] 64 | depth_masks = batch["depth_masks"].to(device)[:num_images] 65 | try: 66 | object_masks = batch["object_masks"].to(device)[:num_images] 67 | except KeyError: 68 | object_masks = depth_masks.clone() 69 | 70 | # Normalize cameras and scale depths for output resolution 71 | cameras_gt = PerspectiveCameras( 72 | R=R, T=T, focal_length=focal_length, device=device 73 | ) 74 | cameras_gt, _, _ = normalize_cameras_batch( 75 | [cameras_gt], 76 | first_cam_mediod=cfg.training.first_cam_mediod, 77 | normalize_first_camera=cfg.training.normalize_first_camera, 78 | depths=depths.unsqueeze(0), 79 | crop_parameters=crop_parameters.unsqueeze(0), 80 | num_patches_x=H_in, 81 | num_patches_y=W_in, 82 | return_scales=True, 83 | ) 84 | cameras_gt = cameras_gt[0] 85 | 86 | gt_rays = cameras_to_rays( 87 | cameras=cameras_gt, 88 | num_patches_x=H_in, 89 | num_patches_y=W_in, 90 | crop_parameters=crop_parameters, 91 | depths=depths, 92 | mode=mode, 93 | ) 94 | gt_points = gt_rays.get_segments().view(num_images, -1, 3) 95 | 96 | resize = torchvision.transforms.Resize( 97 | 224, 98 | antialias=False, 99 | interpolation=torchvision.transforms.InterpolationMode.NEAREST_EXACT, 100 | ) 101 | else: 102 | cameras_gt = PerspectiveCameras( 103 | R=R, T=T, focal_length=focal_length, device=device 104 | ) 105 | 106 | pred_cameras, additional_cams = predict_cameras( 107 | model, 108 | images, 109 | device, 110 | crop_parameters=crop_parameters, 111 | num_patches_x=H_out, 112 | num_patches_y=W_out, 113 | max_num_images=max_num_images, 114 | additional_timesteps=additional_timesteps, 115 | calculate_intrinsics=calculate_intrinsics, 116 | mode=mode, 117 | return_rays=True, 118 | use_homogeneous=cfg.model.get("use_homogeneous", False), 119 | ) 120 | cameras_to_evaluate = additional_cams + [pred_cameras] 121 | 122 | all_cams_batch = dataset.get_data( 123 | sequence_name=instance, ids=np.arange(0, batch["n"]), no_images=True 124 | ) 125 | gt_scene_scale = full_scene_scale(all_cams_batch) 126 | R_gt = R 127 | T_gt = T 128 | 129 | errors = [] 130 | for _, (camera, pred_rays) in enumerate(cameras_to_evaluate): 131 | R_pred = camera.R 132 | T_pred = camera.T 133 | f_pred = camera.focal_length 134 | 135 | R_pred_rel = n_to_np_rotations(num_images, R_pred).cpu().numpy() 136 | R_gt_rel = n_to_np_rotations(num_images, batch["R"]).cpu().numpy() 137 | R_error = compute_angular_error_batch(R_pred_rel, R_gt_rel) 138 | 139 | CC_error, _ = get_error(True, R_pred, T_pred, R_gt, T_gt, gt_scene_scale) 140 | 141 | if load_depth and metrics: 142 | # Evaluate outputs at the same resolution as DUSt3R 143 | pred_points = pred_rays.get_segments().view(num_images, H_out, H_out, 3) 144 | pred_points = pred_points.permute(0, 3, 1, 2) 
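                # Nearest-neighbor resize to the 224x224 evaluation grid (the DUSt3R resolution) before computing Chamfer distance.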
145 | pred_points = resize(pred_points).permute(0, 2, 3, 1).view(num_images, H_in*W_in, 3) 146 | 147 | ( 148 | _, 149 | _, 150 | _, 151 | _, 152 | metric_values, 153 | ) = filter_and_align_point_clouds( 154 | num_images, 155 | gt_points, 156 | pred_points, 157 | depth_masks, 158 | depth_masks, 159 | images, 160 | metrics=metrics, 161 | num_patches_x=H_in, 162 | ) 163 | 164 | ( 165 | _, 166 | _, 167 | _, 168 | _, 169 | object_metric_values, 170 | ) = filter_and_align_point_clouds( 171 | num_images, 172 | gt_points, 173 | pred_points, 174 | depth_masks * object_masks, 175 | depth_masks * object_masks, 176 | images, 177 | metrics=metrics, 178 | num_patches_x=H_in, 179 | ) 180 | 181 | result = { 182 | "R_pred": R_pred.detach().cpu().numpy().tolist(), 183 | "T_pred": T_pred.detach().cpu().numpy().tolist(), 184 | "f_pred": f_pred.detach().cpu().numpy().tolist(), 185 | "R_gt": R_gt.detach().cpu().numpy().tolist(), 186 | "T_gt": T_gt.detach().cpu().numpy().tolist(), 187 | "f_gt": focal_length.detach().cpu().numpy().tolist(), 188 | "scene_scale": gt_scene_scale, 189 | "R_error": R_error.tolist(), 190 | "CC_error": CC_error, 191 | } 192 | 193 | if load_depth and metrics: 194 | result["CD"] = metric_values[1] 195 | result["CD_Object"] = object_metric_values[1] 196 | else: 197 | result["CD"] = 0 198 | result["CD_Object"] = 0 199 | 200 | errors.append(result) 201 | 202 | results[instance] = errors 203 | 204 | if counter == len(dataset) - 1: 205 | break 206 | return results 207 | 208 | 209 | def save_results( 210 | output_dir, 211 | checkpoint=800_000, 212 | category="hydrant", 213 | num_images=None, 214 | calculate_additional_timesteps=True, 215 | calculate_intrinsics=True, 216 | split="test", 217 | force=False, 218 | sample_num=1, 219 | max_num_images=None, 220 | dataset="co3d", 221 | ): 222 | init_slurm_signals_if_slurm() 223 | os.umask(000) # Default to 777 permissions 224 | eval_path = os.path.join( 225 | output_dir, 226 | f"eval_{dataset}", 227 | f"{category}_{num_images}_{sample_num}_ckpt{checkpoint}.json", 228 | ) 229 | 230 | if os.path.exists(eval_path) and not force: 231 | print(f"File {eval_path} already exists. 
Skipping.") 232 | return 233 | 234 | if num_images is not None and num_images > 8: 235 | custom_keys = {"model.num_images": num_images} 236 | ignore_keys = ["pos_table"] 237 | else: 238 | custom_keys = None 239 | ignore_keys = [] 240 | 241 | device = torch.device("cuda") 242 | model, cfg = load_model( 243 | output_dir, 244 | checkpoint=checkpoint, 245 | device=device, 246 | custom_keys=custom_keys, 247 | ignore_keys=ignore_keys, 248 | ) 249 | if num_images is None: 250 | num_images = cfg.dataset.num_images 251 | 252 | if cfg.training.dpt_head: 253 | # Evaluate outputs at the same resolution as DUSt3R 254 | depth_size = 224 255 | else: 256 | depth_size = cfg.model.num_patches_x 257 | 258 | dataset = Co3dDataset( 259 | category=category, 260 | split=split, 261 | num_images=num_images, 262 | apply_augmentation=False, 263 | sample_num=None if split == "train" else sample_num, 264 | use_global_intrinsics=cfg.dataset.use_global_intrinsics, 265 | load_depths=True, 266 | center_crop=True, 267 | depth_size=depth_size, 268 | mask_holes=not cfg.training.regression, 269 | img_size=256 if cfg.model.unet_diffuser else 224, 270 | ) 271 | print(f"Category {category} {len(dataset)}") 272 | 273 | if calculate_additional_timesteps: 274 | additional_timesteps = [0, 10, 20, 30, 40, 50, 60, 70, 80, 90] 275 | else: 276 | additional_timesteps = [] 277 | 278 | results = evaluate( 279 | cfg=cfg, 280 | model=model, 281 | dataset=dataset, 282 | num_images=num_images, 283 | device=device, 284 | calculate_intrinsics=calculate_intrinsics, 285 | additional_timesteps=additional_timesteps, 286 | max_num_images=max_num_images, 287 | mode="segment", 288 | ) 289 | 290 | os.makedirs(os.path.dirname(eval_path), exist_ok=True) 291 | with open(eval_path, "w") as f: 292 | json.dump(results, f) -------------------------------------------------------------------------------- /diffusionsfm/eval/eval_jobs.py: -------------------------------------------------------------------------------- 1 | """ 2 | python -m diffusionsfm.eval.eval_jobs --eval_path output/multi_diffusionsfm_dense --use_submitit 3 | """ 4 | 5 | import os 6 | import json 7 | import submitit 8 | import argparse 9 | import itertools 10 | from glob import glob 11 | 12 | import numpy as np 13 | from tqdm.auto import tqdm 14 | 15 | from diffusionsfm.dataset.co3d_v2 import TEST_CATEGORIES, TRAINING_CATEGORIES 16 | from diffusionsfm.eval.eval_category import save_results 17 | from diffusionsfm.utils.slurm import submitit_job_watcher 18 | 19 | 20 | def evaluate_diffusionsfm(eval_path, use_submitit, mode): 21 | JOB_PARAMS = { 22 | "output_dir": [eval_path], 23 | "checkpoint": [800_000], 24 | "num_images": [2, 3, 4, 5, 6, 7, 8], 25 | "sample_num": [0, 1, 2, 3, 4], 26 | "category": TEST_CATEGORIES, # TRAINING_CATEGORIES + TEST_CATEGORIES, 27 | "calculate_additional_timesteps": [True], 28 | } 29 | if mode == "test": 30 | JOB_PARAMS["category"] = TEST_CATEGORIES 31 | elif mode == "train1": 32 | JOB_PARAMS["category"] = TRAINING_CATEGORIES[:len(TRAINING_CATEGORIES) // 2] 33 | elif mode == "train2": 34 | JOB_PARAMS["category"] = TRAINING_CATEGORIES[len(TRAINING_CATEGORIES) // 2:] 35 | keys, values = zip(*JOB_PARAMS.items()) 36 | job_configs = [dict(zip(keys, p)) for p in itertools.product(*values)] 37 | 38 | if use_submitit: 39 | log_output = "./slurm_logs" 40 | executor = submitit.AutoExecutor( 41 | cluster=None, folder=log_output, slurm_max_num_timeout=10 42 | ) 43 | # Use your own parameters 44 | executor.update_parameters( 45 | slurm_additional_parameters={ 46 | "nodes": 1, 47 | 
"cpus-per-task": 5, 48 | "gpus": 1, 49 | "time": "6:00:00", 50 | "partition": "all", 51 | "exclude": "grogu-1-9, grogu-1-14," 52 | } 53 | ) 54 | jobs = [] 55 | with executor.batch(): 56 | # This context manager submits all jobs at once at the end. 57 | for params in job_configs: 58 | job = executor.submit(save_results, **params) 59 | job_param = f"{params['category']}_N{params['num_images']}_{params['sample_num']}" 60 | jobs.append((job_param, job)) 61 | jobs = {f"{job_param}_{job.job_id}": job for job_param, job in jobs} 62 | submitit_job_watcher(jobs) 63 | else: 64 | for job_config in tqdm(job_configs): 65 | # This is much slower. 66 | save_results(**job_config) 67 | 68 | 69 | def process_predictions(eval_path, pred_index, checkpoint=800_000, threshold_R=15, threshold_CC=0.1): 70 | """ 71 | pred_index should be 1 (corresponding to T=90) 72 | """ 73 | def aggregate_per_category(categories, metric_key, num_images, sample_num, threshold=None): 74 | """ 75 | Aggregates one metric over all data points in a prediction file and then across categories. 76 | - For R_error and CC_error: use mean to threshold-based accuracy 77 | - For CD and CD_Object: use median to reduce the effect of outliers 78 | """ 79 | per_category_values = [] 80 | 81 | for category in tqdm(categories, desc=f"Sample {sample_num}, N={num_images}, {metric_key}"): 82 | per_pred_values = [] 83 | 84 | data_path = glob( 85 | os.path.join(eval_path, "eval", f"{category}_{num_images}_{sample_num}_ckpt{checkpoint}*.json") 86 | )[0] 87 | 88 | with open(data_path) as f: 89 | eval_data = json.load(f) 90 | 91 | for preds in eval_data.values(): 92 | if metric_key in ["R_error", "CC_error"]: 93 | vals = np.array(preds[pred_index][metric_key]) 94 | per_pred_values.append(np.mean(vals < threshold)) 95 | else: 96 | per_pred_values.append(preds[pred_index][metric_key]) 97 | 98 | # Aggregate over all predictions within this category 99 | per_category_values.append( 100 | np.mean(per_pred_values) if metric_key in ["R_error", "CC_error"] 101 | else np.median(per_pred_values) # CD or CD_Object — use median to filter outliers 102 | ) 103 | 104 | if metric_key in ["R_error", "CC_error"]: 105 | return np.mean(per_category_values) 106 | else: 107 | return np.median(per_category_values) 108 | 109 | def aggregate_metric(categories, metric_key, num_images, threshold=None): 110 | """Aggregates one metric over 5 random samples per category and returns the final mean""" 111 | return np.mean([ 112 | aggregate_per_category(categories, metric_key, num_images, sample_num, threshold=threshold) 113 | for sample_num in range(5) 114 | ]) 115 | 116 | # Output containers 117 | all_seen_acc_R, all_seen_acc_CC = [], [] 118 | all_seen_CD, all_seen_CD_Object = [], [] 119 | all_unseen_acc_R, all_unseen_acc_CC = [], [] 120 | all_unseen_CD, all_unseen_CD_Object = [], [] 121 | 122 | for num_images in range(2, 9): 123 | # Seen categories 124 | all_seen_acc_R.append( 125 | aggregate_metric(TRAINING_CATEGORIES, "R_error", num_images, threshold=threshold_R) 126 | ) 127 | all_seen_acc_CC.append( 128 | aggregate_metric(TRAINING_CATEGORIES, "CC_error", num_images, threshold=threshold_CC) 129 | ) 130 | all_seen_CD.append( 131 | aggregate_metric(TRAINING_CATEGORIES, "CD", num_images) 132 | ) 133 | all_seen_CD_Object.append( 134 | aggregate_metric(TRAINING_CATEGORIES, "CD_Object", num_images) 135 | ) 136 | 137 | # Unseen categories 138 | all_unseen_acc_R.append( 139 | aggregate_metric(TEST_CATEGORIES, "R_error", num_images, threshold=threshold_R) 140 | ) 141 | 
all_unseen_acc_CC.append( 142 | aggregate_metric(TEST_CATEGORIES, "CC_error", num_images, threshold=threshold_CC) 143 | ) 144 | all_unseen_CD.append( 145 | aggregate_metric(TEST_CATEGORIES, "CD", num_images) 146 | ) 147 | all_unseen_CD_Object.append( 148 | aggregate_metric(TEST_CATEGORIES, "CD_Object", num_images) 149 | ) 150 | 151 | # Print the results in formatted rows 152 | print("N= ", " ".join(f"{i: 5}" for i in range(2, 9))) 153 | print("Seen R ", " ".join([f"{x:0.3f}" for x in all_seen_acc_R])) 154 | print("Seen CC ", " ".join([f"{x:0.3f}" for x in all_seen_acc_CC])) 155 | print("Seen CD ", " ".join([f"{x:0.3f}" for x in all_seen_CD])) 156 | print("Seen CD_Obj ", " ".join([f"{x:0.3f}" for x in all_seen_CD_Object])) 157 | print("Unseen R ", " ".join([f"{x:0.3f}" for x in all_unseen_acc_R])) 158 | print("Unseen CC ", " ".join([f"{x:0.3f}" for x in all_unseen_acc_CC])) 159 | print("Unseen CD ", " ".join([f"{x:0.3f}" for x in all_unseen_CD])) 160 | print("Unseen CD_Obj", " ".join([f"{x:0.3f}" for x in all_unseen_CD_Object])) 161 | 162 | 163 | if __name__ == "__main__": 164 | parser = argparse.ArgumentParser() 165 | parser.add_argument("--eval_path", type=str, default=None) 166 | parser.add_argument("--use_submitit", action="store_true") 167 | parser.add_argument("--mode", type=str, default="test") 168 | args = parser.parse_args() 169 | 170 | eval_path = "output/multi_diffusionsfm_dense" if args.eval_path is None else args.eval_path 171 | use_submitit = args.use_submitit 172 | mode = args.mode 173 | 174 | evaluate_diffusionsfm(eval_path, use_submitit, mode) 175 | process_predictions(eval_path, 1) -------------------------------------------------------------------------------- /diffusionsfm/inference/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/QitaoZhao/DiffusionSfM/4bb08800721bdcf46b0c823586a2fab4702ff282/diffusionsfm/inference/__init__.py -------------------------------------------------------------------------------- /diffusionsfm/inference/ddim.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import random 3 | import numpy as np 4 | from tqdm.auto import tqdm 5 | 6 | from diffusionsfm.utils.rays import compute_ndc_coordinates 7 | 8 | 9 | def inference_ddim( 10 | model, 11 | images, 12 | device, 13 | crop_parameters=None, 14 | eta=0, 15 | num_inference_steps=100, 16 | pbar=True, 17 | stop_iteration=None, 18 | num_patches_x=16, 19 | num_patches_y=16, 20 | visualize=False, 21 | max_num_images=8, 22 | seed=0, 23 | ): 24 | """ 25 | Implements DDIM-style inference. 26 | 27 | To get multiple samples, batch the images multiple times. 28 | 29 | Args: 30 | model: Ray Diffuser. 31 | images (torch.Tensor): (B, N, C, H, W). 32 | patch_rays_gt (torch.Tensor): If provided, the patch rays which are ground 33 | truth (B, N, P, 6). 34 | eta (float, optional): Stochasticity coefficient. 0 is completely deterministic, 35 | 1 is equivalent to DDPM. (Default: 0) 36 | num_inference_steps (int, optional): Number of inference steps. (Default: 100) 37 | pbar (bool, optional): Whether to show progress bar. 
(Default: True) 38 | """ 39 | timesteps = model.noise_scheduler.compute_inference_timesteps(num_inference_steps) 40 | batch_size = images.shape[0] 41 | num_images = images.shape[1] 42 | 43 | if isinstance(eta, list): 44 | eta_0, eta_1 = float(eta[0]), float(eta[1]) 45 | else: 46 | eta_0, eta_1 = 0, 0 47 | 48 | # Fixing seed 49 | if seed is not None: 50 | torch.manual_seed(seed) 51 | random.seed(seed) 52 | np.random.seed(seed) 53 | 54 | with torch.no_grad(): 55 | x_tau = torch.randn( 56 | batch_size, 57 | num_images, 58 | model.ray_out if hasattr(model, "ray_out") else model.ray_dim, 59 | num_patches_x, 60 | num_patches_y, 61 | device=device, 62 | ) 63 | 64 | if visualize: 65 | x_taus = [x_tau] 66 | all_pred = [] 67 | noise_samples = [] 68 | 69 | image_features = model.feature_extractor(images, autoresize=True) 70 | 71 | if model.append_ndc: 72 | ndc_coordinates = compute_ndc_coordinates( 73 | crop_parameters=crop_parameters, 74 | no_crop_param_device="cpu", 75 | num_patches_x=model.width, 76 | num_patches_y=model.width, 77 | distortion_coeffs=None, 78 | )[..., :2].to(device) 79 | ndc_coordinates = ndc_coordinates.permute(0, 1, 4, 2, 3) 80 | else: 81 | ndc_coordinates = None 82 | 83 | if stop_iteration is None: 84 | loop = range(len(timesteps)) 85 | else: 86 | loop = range(len(timesteps) - stop_iteration + 1) 87 | loop = tqdm(loop) if pbar else loop 88 | 89 | for t in loop: 90 | tau = timesteps[t] 91 | 92 | if tau > 0 and eta_1 > 0: 93 | z = torch.randn( 94 | batch_size, 95 | num_images, 96 | model.ray_out if hasattr(model, "ray_out") else model.ray_dim, 97 | num_patches_x, 98 | num_patches_y, 99 | device=device, 100 | ) 101 | else: 102 | z = 0 103 | 104 | alpha = model.noise_scheduler.alphas_cumprod[tau] 105 | if tau > 0: 106 | tau_prev = timesteps[t + 1] 107 | alpha_prev = model.noise_scheduler.alphas_cumprod[tau_prev] 108 | else: 109 | alpha_prev = torch.tensor(1.0, device=device).float() 110 | 111 | sigma_t = ( 112 | torch.sqrt((1 - alpha_prev) / (1 - alpha)) 113 | * torch.sqrt(1 - alpha / alpha_prev) 114 | ) 115 | 116 | if num_images > max_num_images: 117 | eps_pred = torch.zeros_like(x_tau) 118 | noise_sample = torch.zeros_like(x_tau) 119 | 120 | # Randomly split image indices (excluding index 0), then prepend 0 to each split 121 | indices_split = torch.split( 122 | torch.randperm(num_images - 1) + 1, max_num_images - 1 123 | ) 124 | 125 | for indices in indices_split: 126 | indices = torch.cat((torch.tensor([0]), indices)) # Ensure index 0 is always included 127 | 128 | eps_pred_ind, noise_sample_ind = model( 129 | features=image_features[:, indices], 130 | rays_noisy=x_tau[:, indices], 131 | t=int(tau), 132 | ndc_coordinates=ndc_coordinates[:, indices], 133 | indices=indices, 134 | ) 135 | 136 | eps_pred[:, indices] += eps_pred_ind 137 | 138 | if noise_sample_ind is not None: 139 | noise_sample[:, indices] += noise_sample_ind 140 | 141 | # Average over splits for the shared reference index (0) 142 | eps_pred[:, 0] /= len(indices_split) 143 | noise_sample[:, 0] /= len(indices_split) 144 | else: 145 | eps_pred, noise_sample = model( 146 | features=image_features, 147 | rays_noisy=x_tau, 148 | t=int(tau), 149 | ndc_coordinates=ndc_coordinates, 150 | ) 151 | 152 | if model.use_homogeneous: 153 | p1 = eps_pred[:, :, :4] 154 | p2 = eps_pred[:, :, 4:] 155 | 156 | c1 = torch.linalg.norm(p1, dim=2, keepdim=True) 157 | c2 = torch.linalg.norm(p2, dim=2, keepdim=True) 158 | eps_pred[:, :, :4] = p1 / c1 159 | eps_pred[:, :, 4:] = p2 / c2 160 | 161 | if visualize: 162 | 
all_pred.append(eps_pred.clone()) 163 | noise_samples.append(noise_sample) 164 | 165 | # TODO: Can simplify this a lot 166 | x0_pred = eps_pred.clone() 167 | eps_pred = (x_tau - torch.sqrt(alpha) * eps_pred) / torch.sqrt( 168 | 1 - alpha 169 | ) 170 | 171 | dir_x_tau = torch.sqrt(1 - alpha_prev - eta_0*sigma_t**2) * eps_pred 172 | noise = eta_1 * sigma_t * z 173 | 174 | new_x_tau = torch.sqrt(alpha_prev) * x0_pred + dir_x_tau + noise 175 | x_tau = new_x_tau 176 | 177 | if visualize: 178 | x_taus.append(x_tau.detach().clone()) 179 | if visualize: 180 | return x_tau, x_taus, all_pred, noise_samples 181 | return x_tau 182 | -------------------------------------------------------------------------------- /diffusionsfm/inference/load_model.py: -------------------------------------------------------------------------------- 1 | import os.path as osp 2 | from glob import glob 3 | 4 | import torch 5 | from omegaconf import OmegaConf 6 | 7 | from diffusionsfm.model.diffuser import RayDiffuser 8 | from diffusionsfm.model.diffuser_dpt import RayDiffuserDPT 9 | from diffusionsfm.model.scheduler import NoiseScheduler 10 | 11 | 12 | def load_model( 13 | output_dir, checkpoint=None, device="cuda:0", custom_keys=None, ignore_keys=() 14 | ): 15 | """ 16 | Loads a model and config from an output directory. 17 | 18 | E.g. to load with different number of images, 19 | ``` 20 | custom_keys={"model.num_images": 15}, ignore_keys=["pos_table"] 21 | ``` 22 | 23 | Args: 24 | output_dir (str): Path to the output directory. 25 | checkpoint (str or int): Path to the checkpoint to load. If None, loads the 26 | latest checkpoint. 27 | device (str): Device to load the model on. 28 | custom_keys (dict): Dictionary of custom keys to override in the config. 29 | """ 30 | if checkpoint is None: 31 | checkpoint_path = sorted(glob(osp.join(output_dir, "checkpoints", "*.pth")))[-1] 32 | else: 33 | if isinstance(checkpoint, int): 34 | checkpoint_name = f"ckpt_{checkpoint:08d}.pth" 35 | else: 36 | checkpoint_name = checkpoint 37 | checkpoint_path = osp.join(output_dir, "checkpoints", checkpoint_name) 38 | print("Loading checkpoint", osp.basename(checkpoint_path)) 39 | 40 | cfg = OmegaConf.load(osp.join(output_dir, "hydra", "config.yaml")) 41 | if custom_keys is not None: 42 | for k, v in custom_keys.items(): 43 | OmegaConf.update(cfg, k, v) 44 | noise_scheduler = NoiseScheduler( 45 | type=cfg.noise_scheduler.type, 46 | max_timesteps=cfg.noise_scheduler.max_timesteps, 47 | beta_start=cfg.noise_scheduler.beta_start, 48 | beta_end=cfg.noise_scheduler.beta_end, 49 | ) 50 | 51 | if not cfg.training.get("dpt_head", False): 52 | model = RayDiffuser( 53 | depth=cfg.model.depth, 54 | width=cfg.model.num_patches_x, 55 | P=1, 56 | max_num_images=cfg.model.num_images, 57 | noise_scheduler=noise_scheduler, 58 | feature_extractor=cfg.model.feature_extractor, 59 | append_ndc=cfg.model.append_ndc, 60 | diffuse_depths=cfg.training.get("diffuse_depths", False), 61 | depth_resolution=cfg.training.get("depth_resolution", 1), 62 | use_homogeneous=cfg.model.get("use_homogeneous", False), 63 | cond_depth_mask=cfg.model.get("cond_depth_mask", False), 64 | ).to(device) 65 | else: 66 | model = RayDiffuserDPT( 67 | depth=cfg.model.depth, 68 | width=cfg.model.num_patches_x, 69 | P=1, 70 | max_num_images=cfg.model.num_images, 71 | noise_scheduler=noise_scheduler, 72 | feature_extractor=cfg.model.feature_extractor, 73 | append_ndc=cfg.model.append_ndc, 74 | diffuse_depths=cfg.training.get("diffuse_depths", False), 75 | 
depth_resolution=cfg.training.get("depth_resolution", 1), 76 | encoder_features=cfg.training.get("dpt_encoder_features", False), 77 | use_homogeneous=cfg.model.get("use_homogeneous", False), 78 | cond_depth_mask=cfg.model.get("cond_depth_mask", False), 79 | ).to(device) 80 | 81 | data = torch.load(checkpoint_path) 82 | state_dict = {} 83 | for k, v in data["state_dict"].items(): 84 | include = True 85 | for ignore_key in ignore_keys: 86 | if ignore_key in k: 87 | include = False 88 | if include: 89 | state_dict[k] = v 90 | 91 | missing, unexpected = model.load_state_dict(state_dict, strict=False) 92 | if len(missing) > 0: 93 | print("Missing keys:", missing) 94 | if len(unexpected) > 0: 95 | print("Unexpected keys:", unexpected) 96 | model = model.eval() 97 | return model, cfg 98 | -------------------------------------------------------------------------------- /diffusionsfm/inference/predict.py: -------------------------------------------------------------------------------- 1 | from diffusionsfm.inference.ddim import inference_ddim 2 | from diffusionsfm.utils.rays import ( 3 | Rays, 4 | rays_to_cameras, 5 | rays_to_cameras_homography, 6 | ) 7 | 8 | 9 | def predict_cameras( 10 | model, 11 | images, 12 | device, 13 | crop_parameters=None, 14 | stop_iteration=None, 15 | num_patches_x=16, 16 | num_patches_y=16, 17 | additional_timesteps=(), 18 | calculate_intrinsics=False, 19 | max_num_images=8, 20 | mode=None, 21 | return_rays=False, 22 | use_homogeneous=False, 23 | seed=0, 24 | ): 25 | """ 26 | Args: 27 | images (torch.Tensor): (N, C, H, W) 28 | crop_parameters (torch.Tensor): (N, 4) or None 29 | """ 30 | if calculate_intrinsics: 31 | ray_to_cam = rays_to_cameras_homography 32 | else: 33 | ray_to_cam = rays_to_cameras 34 | 35 | get_spatial_rays = Rays.from_spatial 36 | 37 | rays_final, rays_intermediate, pred_intermediate, _ = inference_ddim( 38 | model, 39 | images.unsqueeze(0), 40 | device, 41 | crop_parameters=crop_parameters.unsqueeze(0), 42 | pbar=False, 43 | stop_iteration=stop_iteration, 44 | eta=[1, 0], 45 | num_inference_steps=100, 46 | num_patches_x=num_patches_x, 47 | num_patches_y=num_patches_y, 48 | visualize=True, 49 | max_num_images=max_num_images, 50 | ) 51 | 52 | spatial_rays = get_spatial_rays( 53 | rays_final[0], 54 | mode=mode, 55 | num_patches_x=num_patches_x, 56 | num_patches_y=num_patches_y, 57 | use_homogeneous=use_homogeneous, 58 | ) 59 | 60 | pred_cam = ray_to_cam( 61 | spatial_rays, 62 | crop_parameters, 63 | num_patches_x=num_patches_x, 64 | num_patches_y=num_patches_y, 65 | depth_resolution=model.depth_resolution, 66 | average_centers=True, 67 | directions_from_averaged_center=True, 68 | ) 69 | 70 | additional_predictions = [] 71 | for t in additional_timesteps: 72 | ray = pred_intermediate[t] 73 | 74 | ray = get_spatial_rays( 75 | ray[0], 76 | mode=mode, 77 | num_patches_x=num_patches_x, 78 | num_patches_y=num_patches_y, 79 | use_homogeneous=use_homogeneous, 80 | ) 81 | 82 | cam = ray_to_cam( 83 | ray, 84 | crop_parameters, 85 | num_patches_x=num_patches_x, 86 | num_patches_y=num_patches_y, 87 | average_centers=True, 88 | directions_from_averaged_center=True, 89 | ) 90 | if return_rays: 91 | cam = (cam, ray) 92 | additional_predictions.append(cam) 93 | 94 | if return_rays: 95 | return (pred_cam, spatial_rays), additional_predictions 96 | return pred_cam, additional_predictions, spatial_rays -------------------------------------------------------------------------------- /diffusionsfm/model/base_model.py: 
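Before the model definitions that follow, here is a minimal, hypothetical sketch of how the inference utilities above (`load_model` and `predict_cameras`) might be wired together. The checkpoint directory, view count, and the placeholder `images`/`crop_parameters` tensors are assumptions for illustration only, not code from this repository; real inputs would come from the dataset or the demo preprocessing.

```python
import torch

from diffusionsfm.inference.load_model import load_model
from diffusionsfm.inference.predict import predict_cameras

device = torch.device("cuda")

# Hypothetical checkpoint directory and iteration; adjust to your own training run.
model, cfg = load_model(
    "output/multi_diffusionsfm_dense", checkpoint=800_000, device=device
)

num_views = 4
# Placeholders with plausible shapes: (N, 3, H, W) images and (N, 4) crop parameters.
images = torch.randn(num_views, 3, 224, 224, device=device)
crop_parameters = torch.zeros(num_views, 4, device=device)

pred_cam, additional_predictions, spatial_rays = predict_cameras(
    model,
    images,
    device,
    crop_parameters=crop_parameters,
    calculate_intrinsics=True,
    use_homogeneous=cfg.model.get("use_homogeneous", False),
)
```

With the default arguments this runs the full 100-step DDIM schedule; `additional_timesteps` can be passed to also recover cameras from intermediate denoising steps, as the evaluation code in `eval_category.py` does.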
-------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | class BaseModel(torch.nn.Module): 5 | def load(self, path): 6 | """Load model from file. 7 | 8 | Args: 9 | path (str): file path 10 | """ 11 | parameters = torch.load(path, map_location=torch.device("cpu")) 12 | 13 | if "optimizer" in parameters: 14 | parameters = parameters["model"] 15 | 16 | self.load_state_dict(parameters) 17 | -------------------------------------------------------------------------------- /diffusionsfm/model/blocks.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from diffusionsfm.model.dit import TimestepEmbedder 4 | import ipdb 5 | 6 | 7 | def modulate(x, shift, scale): 8 | return x * (1 + scale.unsqueeze(-1).unsqueeze(-1)) + shift.unsqueeze(-1).unsqueeze( 9 | -1 10 | ) 11 | 12 | 13 | def _make_fusion_block(features, use_bn, use_ln, dpt_time, resolution): 14 | return FeatureFusionBlock_custom( 15 | features, 16 | nn.ReLU(False), 17 | deconv=False, 18 | bn=use_bn, 19 | expand=False, 20 | align_corners=True, 21 | dpt_time=dpt_time, 22 | ln=use_ln, 23 | resolution=resolution 24 | ) 25 | 26 | 27 | def _make_scratch(in_shape, out_shape, groups=1, expand=False): 28 | scratch = nn.Module() 29 | 30 | out_shape1 = out_shape 31 | out_shape2 = out_shape 32 | out_shape3 = out_shape 33 | out_shape4 = out_shape 34 | if expand == True: 35 | out_shape1 = out_shape 36 | out_shape2 = out_shape * 2 37 | out_shape3 = out_shape * 4 38 | out_shape4 = out_shape * 8 39 | 40 | scratch.layer1_rn = nn.Conv2d( 41 | in_shape[0], 42 | out_shape1, 43 | kernel_size=3, 44 | stride=1, 45 | padding=1, 46 | bias=False, 47 | groups=groups, 48 | ) 49 | scratch.layer2_rn = nn.Conv2d( 50 | in_shape[1], 51 | out_shape2, 52 | kernel_size=3, 53 | stride=1, 54 | padding=1, 55 | bias=False, 56 | groups=groups, 57 | ) 58 | scratch.layer3_rn = nn.Conv2d( 59 | in_shape[2], 60 | out_shape3, 61 | kernel_size=3, 62 | stride=1, 63 | padding=1, 64 | bias=False, 65 | groups=groups, 66 | ) 67 | scratch.layer4_rn = nn.Conv2d( 68 | in_shape[3], 69 | out_shape4, 70 | kernel_size=3, 71 | stride=1, 72 | padding=1, 73 | bias=False, 74 | groups=groups, 75 | ) 76 | 77 | return scratch 78 | 79 | 80 | class ResidualConvUnit_custom(nn.Module): 81 | """Residual convolution module.""" 82 | 83 | def __init__(self, features, activation, bn, ln, dpt_time=False, resolution=16): 84 | """Init. 
85 | 86 | Args: 87 | features (int): number of features 88 | """ 89 | super().__init__() 90 | 91 | self.bn = bn 92 | self.ln = ln 93 | 94 | self.groups = 1 95 | 96 | self.conv1 = nn.Conv2d( 97 | features, 98 | features, 99 | kernel_size=3, 100 | stride=1, 101 | padding=1, 102 | bias=not self.bn, 103 | groups=self.groups, 104 | ) 105 | 106 | self.conv2 = nn.Conv2d( 107 | features, 108 | features, 109 | kernel_size=3, 110 | stride=1, 111 | padding=1, 112 | bias=not self.bn, 113 | groups=self.groups, 114 | ) 115 | 116 | nn.init.kaiming_uniform_(self.conv1.weight) 117 | nn.init.kaiming_uniform_(self.conv2.weight) 118 | 119 | if self.bn == True: 120 | self.bn1 = nn.BatchNorm2d(features) 121 | self.bn2 = nn.BatchNorm2d(features) 122 | 123 | if self.ln == True: 124 | self.bn1 = nn.LayerNorm((features, resolution, resolution)) 125 | self.bn2 = nn.LayerNorm((features, resolution, resolution)) 126 | 127 | self.activation = activation 128 | 129 | if dpt_time: 130 | self.t_embedder = TimestepEmbedder(hidden_size=features) 131 | self.adaLN_modulation = nn.Sequential( 132 | nn.SiLU(), nn.Linear(features, 3 * features, bias=True) 133 | ) 134 | 135 | def forward(self, x, t=None): 136 | """Forward pass. 137 | 138 | Args: 139 | x (tensor): input 140 | 141 | Returns: 142 | tensor: output 143 | """ 144 | if t is not None: 145 | # Embed timestamp & calculate shift parameters 146 | t = self.t_embedder(t) # (B*N) 147 | shift, scale, gate = self.adaLN_modulation(t).chunk(3, dim=1) # (B * N, T) 148 | 149 | # Shift & scale x 150 | x = modulate(x, shift, scale) # (B * N, T, H, W) 151 | 152 | out = self.activation(x) 153 | out = self.conv1(out) 154 | if self.bn or self.ln: 155 | out = self.bn1(out) 156 | 157 | out = self.activation(out) 158 | out = self.conv2(out) 159 | if self.bn or self.ln: 160 | out = self.bn2(out) 161 | 162 | if self.groups > 1: 163 | out = self.conv_merge(out) 164 | 165 | if t is not None: 166 | out = gate.unsqueeze(-1).unsqueeze(-1) * out 167 | 168 | return out + x 169 | 170 | 171 | class FeatureFusionBlock_custom(nn.Module): 172 | """Feature fusion block.""" 173 | 174 | def __init__( 175 | self, 176 | features, 177 | activation, 178 | deconv=False, 179 | bn=False, 180 | ln=False, 181 | expand=False, 182 | align_corners=True, 183 | dpt_time=False, 184 | resolution=16, 185 | ): 186 | """Init. 187 | 188 | Args: 189 | features (int): number of features 190 | """ 191 | super(FeatureFusionBlock_custom, self).__init__() 192 | 193 | self.deconv = deconv 194 | self.align_corners = align_corners 195 | 196 | self.groups = 1 197 | 198 | self.expand = expand 199 | out_features = features 200 | if self.expand == True: 201 | out_features = features // 2 202 | 203 | self.out_conv = nn.Conv2d( 204 | features, 205 | out_features, 206 | kernel_size=1, 207 | stride=1, 208 | padding=0, 209 | bias=True, 210 | groups=1, 211 | ) 212 | 213 | nn.init.kaiming_uniform_(self.out_conv.weight) 214 | 215 | # The second block sees time 216 | self.resConfUnit1 = ResidualConvUnit_custom( 217 | features, activation, bn=bn, ln=ln, dpt_time=False, resolution=resolution 218 | ) 219 | self.resConfUnit2 = ResidualConvUnit_custom( 220 | features, activation, bn=bn, ln=ln, dpt_time=dpt_time, resolution=resolution 221 | ) 222 | 223 | def forward(self, input, activation=None, t=None): 224 | """Forward pass. 
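        Args:
            input (tensor): output of the previous fusion block (or the deepest
                DiT activation for the first block in the chain).
            activation (tensor, optional): lateral activation fused in through
                resConfUnit1 before refinement.
            t (tensor, optional): diffusion timesteps, passed to resConfUnit2 for
                adaLN time modulation.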
225 | 226 | Returns: 227 | tensor: output 228 | """ 229 | output = input 230 | 231 | if activation is not None: 232 | res = self.resConfUnit1(activation) 233 | 234 | output += res 235 | 236 | output = self.resConfUnit2(output, t) 237 | 238 | output = torch.nn.functional.interpolate( 239 | output.float(), 240 | scale_factor=2, 241 | mode="bilinear", 242 | align_corners=self.align_corners, 243 | ) 244 | 245 | output = self.out_conv(output) 246 | 247 | return output 248 | -------------------------------------------------------------------------------- /diffusionsfm/model/diffuser.py: -------------------------------------------------------------------------------- 1 | import ipdb # noqa: F401 2 | import numpy as np 3 | import torch 4 | import torch.nn as nn 5 | 6 | from diffusionsfm.model.dit import DiT 7 | from diffusionsfm.model.feature_extractors import PretrainedVAE, SpatialDino 8 | from diffusionsfm.model.scheduler import NoiseScheduler 9 | 10 | from huggingface_hub import PyTorchModelHubMixin 11 | 12 | 13 | class RayDiffuser(nn.Module, PyTorchModelHubMixin, 14 | repo_url="https://github.com/QitaoZhao/DiffusionSfM", 15 | paper_url="https://huggingface.co/papers/2505.05473", 16 | pipeline_tag="image-to-3d", 17 | license="mit"): 18 | def __init__( 19 | self, 20 | model_type="dit", 21 | depth=8, 22 | width=16, 23 | hidden_size=1152, 24 | P=1, 25 | max_num_images=1, 26 | noise_scheduler=None, 27 | freeze_encoder=True, 28 | feature_extractor="dino", 29 | append_ndc=True, 30 | use_unconditional=False, 31 | diffuse_depths=False, 32 | depth_resolution=1, 33 | use_homogeneous=False, 34 | cond_depth_mask=False, 35 | ): 36 | super().__init__() 37 | if noise_scheduler is None: 38 | self.noise_scheduler = NoiseScheduler() 39 | else: 40 | self.noise_scheduler = noise_scheduler 41 | 42 | self.diffuse_depths = diffuse_depths 43 | self.depth_resolution = depth_resolution 44 | self.use_homogeneous = use_homogeneous 45 | 46 | self.ray_dim = 3 47 | if self.use_homogeneous: 48 | self.ray_dim += 1 49 | 50 | self.ray_dim += self.ray_dim * self.depth_resolution**2 51 | 52 | if self.diffuse_depths: 53 | self.ray_dim += 1 54 | 55 | self.append_ndc = append_ndc 56 | self.width = width 57 | 58 | self.max_num_images = max_num_images 59 | self.model_type = model_type 60 | self.use_unconditional = use_unconditional 61 | self.cond_depth_mask = cond_depth_mask 62 | 63 | if feature_extractor == "dino": 64 | self.feature_extractor = SpatialDino( 65 | freeze_weights=freeze_encoder, num_patches_x=width, num_patches_y=width 66 | ) 67 | self.feature_dim = self.feature_extractor.feature_dim 68 | elif feature_extractor == "vae": 69 | self.feature_extractor = PretrainedVAE( 70 | freeze_weights=freeze_encoder, num_patches_x=width, num_patches_y=width 71 | ) 72 | self.feature_dim = self.feature_extractor.feature_dim 73 | else: 74 | raise Exception(f"Unknown feature extractor {feature_extractor}") 75 | 76 | if self.use_unconditional: 77 | self.register_parameter( 78 | "null_token", nn.Parameter(torch.randn(self.feature_dim, 1, 1)) 79 | ) 80 | 81 | self.input_dim = self.feature_dim * 2 82 | 83 | if self.append_ndc: 84 | self.input_dim += 2 85 | 86 | if model_type == "dit": 87 | self.ray_predictor = DiT( 88 | in_channels=self.input_dim, 89 | out_channels=self.ray_dim, 90 | width=width, 91 | depth=depth, 92 | hidden_size=hidden_size, 93 | max_num_images=max_num_images, 94 | P=P, 95 | ) 96 | 97 | self.scratch = nn.Module() 98 | self.scratch.input_conv = nn.Linear(self.ray_dim + int(self.cond_depth_mask), self.feature_dim) 99 | 100 | def 
forward_noise( 101 | self, x, t, epsilon=None, zero_out_mask=None 102 | ): 103 | """ 104 | Applies forward diffusion (adds noise) to the input. 105 | 106 | If a mask is provided, the noise is only applied to the masked inputs. 107 | """ 108 | t = t.reshape(-1, 1, 1, 1, 1) 109 | 110 | if epsilon is None: 111 | epsilon = torch.randn_like(x) 112 | else: 113 | epsilon = epsilon.reshape(x.shape) 114 | 115 | alpha_bar = self.noise_scheduler.alphas_cumprod[t] 116 | x_noise = torch.sqrt(alpha_bar) * x + torch.sqrt(1 - alpha_bar) * epsilon 117 | 118 | if zero_out_mask is not None and self.cond_depth_mask: 119 | x_noise = x_noise * zero_out_mask 120 | 121 | return x_noise, epsilon 122 | 123 | def forward( 124 | self, 125 | features=None, 126 | images=None, 127 | rays=None, 128 | rays_noisy=None, 129 | t=None, 130 | ndc_coordinates=None, 131 | unconditional_mask=None, 132 | return_dpt_activations=False, 133 | depth_mask=None, 134 | ): 135 | """ 136 | Args: 137 | images: (B, N, 3, H, W). 138 | t: (B,). 139 | rays: (B, N, 6, H, W). 140 | rays_noisy: (B, N, 6, H, W). 141 | ndc_coordinates: (B, N, 2, H, W). 142 | unconditional_mask: (B, N) or (B,). Should be 1 for unconditional samples 143 | and 0 else. 144 | """ 145 | 146 | if features is None: 147 | # VAE expects 256x256 images while DINO expects 224x224 images. 148 | # Both feature extractors support autoresize=True, but ideally we should 149 | # set this to be false and handle in the dataloader. 150 | features = self.feature_extractor(images, autoresize=True) 151 | 152 | B = features.shape[0] 153 | 154 | if ( 155 | unconditional_mask is not None 156 | and self.use_unconditional 157 | ): 158 | null_token = self.null_token.reshape(1, 1, self.feature_dim, 1, 1) 159 | unconditional_mask = unconditional_mask.reshape(B, -1, 1, 1, 1) 160 | features = ( 161 | features * (1 - unconditional_mask) + null_token * unconditional_mask 162 | ) 163 | 164 | if isinstance(t, int) or isinstance(t, np.int64): 165 | t = torch.ones(1, dtype=int).to(features.device) * t 166 | else: 167 | t = t.reshape(B) 168 | 169 | if rays_noisy is None: 170 | if self.cond_depth_mask: 171 | rays_noisy, epsilon = self.forward_noise(rays, t, zero_out_mask=depth_mask.unsqueeze(2)) 172 | else: 173 | rays_noisy, epsilon = self.forward_noise(rays, t) 174 | else: 175 | epsilon = None 176 | 177 | if self.cond_depth_mask: 178 | if depth_mask is None: 179 | depth_mask = torch.ones_like(rays_noisy[:, :, 0]) 180 | ray_repr = torch.cat([rays_noisy, depth_mask.unsqueeze(2)], dim=2) 181 | else: 182 | ray_repr = rays_noisy 183 | 184 | ray_repr = ray_repr.permute(0, 1, 3, 4, 2) 185 | ray_repr = self.scratch.input_conv(ray_repr).permute(0, 1, 4, 2, 3).contiguous() 186 | 187 | scene_features = torch.cat([features, ray_repr], dim=2) 188 | 189 | if self.append_ndc: 190 | scene_features = torch.cat([scene_features, ndc_coordinates], dim=2) 191 | 192 | epsilon_pred = self.ray_predictor( 193 | scene_features, 194 | t, 195 | return_dpt_activations=return_dpt_activations, 196 | ) 197 | 198 | if return_dpt_activations: 199 | return epsilon_pred, rays_noisy, epsilon 200 | 201 | return epsilon_pred, epsilon 202 | -------------------------------------------------------------------------------- /diffusionsfm/model/diffuser_dpt.py: -------------------------------------------------------------------------------- 1 | import ipdb # noqa: F401 2 | import numpy as np 3 | import torch 4 | import torch.nn as nn 5 | 6 | from diffusionsfm.model.dit import DiT 7 | from diffusionsfm.model.feature_extractors import PretrainedVAE, 
SpatialDino 8 | from diffusionsfm.model.blocks import _make_fusion_block, _make_scratch 9 | from diffusionsfm.model.scheduler import NoiseScheduler 10 | 11 | from huggingface_hub import PyTorchModelHubMixin 12 | 13 | 14 | # functional implementation 15 | def nearest_neighbor_upsample(x: torch.Tensor, scale_factor: int): 16 | """Upsample {x} (NCHW) by scale factor {scale_factor} using nearest neighbor interpolation.""" 17 | s = scale_factor 18 | return ( 19 | x.reshape(*x.shape, 1, 1) 20 | .expand(*x.shape, s, s) 21 | .transpose(-2, -3) 22 | .reshape(*x.shape[:2], *(s * hw for hw in x.shape[2:])) 23 | ) 24 | 25 | 26 | class ProjectReadout(nn.Module): 27 | def __init__(self, in_features, start_index=1): 28 | super(ProjectReadout, self).__init__() 29 | self.start_index = start_index 30 | 31 | self.project = nn.Sequential(nn.Linear(2 * in_features, in_features), nn.GELU()) 32 | 33 | def forward(self, x): 34 | readout = x[:, 0].unsqueeze(1).expand_as(x[:, self.start_index :]) 35 | features = torch.cat((x[:, self.start_index :], readout), -1) 36 | 37 | return self.project(features) 38 | 39 | 40 | class RayDiffuserDPT(nn.Module, PyTorchModelHubMixin, 41 | repo_url="https://github.com/QitaoZhao/DiffusionSfM", 42 | paper_url="https://huggingface.co/papers/2505.05473", 43 | pipeline_tag="image-to-3d", 44 | license="mit"): 45 | def __init__( 46 | self, 47 | model_type="dit", 48 | depth=8, 49 | width=16, 50 | hidden_size=1152, 51 | P=1, 52 | max_num_images=1, 53 | noise_scheduler=None, 54 | freeze_encoder=True, 55 | feature_extractor="dino", 56 | append_ndc=True, 57 | use_unconditional=False, 58 | diffuse_depths=False, 59 | depth_resolution=1, 60 | encoder_features=False, 61 | use_homogeneous=False, 62 | freeze_transformer=False, 63 | cond_depth_mask=False, 64 | ): 65 | super().__init__() 66 | if noise_scheduler is None: 67 | self.noise_scheduler = NoiseScheduler() 68 | else: 69 | self.noise_scheduler = noise_scheduler 70 | 71 | self.diffuse_depths = diffuse_depths 72 | self.depth_resolution = depth_resolution 73 | self.use_homogeneous = use_homogeneous 74 | 75 | self.ray_dim = 3 76 | 77 | if self.use_homogeneous: 78 | self.ray_dim += 1 79 | self.ray_dim += self.ray_dim * self.depth_resolution**2 80 | 81 | if self.diffuse_depths: 82 | self.ray_dim += 1 83 | 84 | self.append_ndc = append_ndc 85 | self.width = width 86 | 87 | self.max_num_images = max_num_images 88 | self.model_type = model_type 89 | self.use_unconditional = use_unconditional 90 | self.cond_depth_mask = cond_depth_mask 91 | self.encoder_features = encoder_features 92 | 93 | if feature_extractor == "dino": 94 | self.feature_extractor = SpatialDino( 95 | freeze_weights=freeze_encoder, 96 | num_patches_x=width, 97 | num_patches_y=width, 98 | activation_hooks=self.encoder_features, 99 | ) 100 | self.feature_dim = self.feature_extractor.feature_dim 101 | elif feature_extractor == "vae": 102 | self.feature_extractor = PretrainedVAE( 103 | freeze_weights=freeze_encoder, num_patches_x=width, num_patches_y=width 104 | ) 105 | self.feature_dim = self.feature_extractor.feature_dim 106 | else: 107 | raise Exception(f"Unknown feature extractor {feature_extractor}") 108 | 109 | if self.use_unconditional: 110 | self.register_parameter( 111 | "null_token", nn.Parameter(torch.randn(self.feature_dim, 1, 1)) 112 | ) 113 | 114 | self.input_dim = self.feature_dim * 2 115 | 116 | if self.append_ndc: 117 | self.input_dim += 2 118 | 119 | if model_type == "dit": 120 | self.ray_predictor = DiT( 121 | in_channels=self.input_dim, 122 | out_channels=self.ray_dim, 
123 | width=width, 124 | depth=depth, 125 | hidden_size=hidden_size, 126 | max_num_images=max_num_images, 127 | P=P, 128 | ) 129 | 130 | if freeze_transformer: 131 | for param in self.ray_predictor.parameters(): 132 | param.requires_grad = False 133 | 134 | # Fusion blocks 135 | self.f = 256 136 | 137 | if self.encoder_features: 138 | feature_lens = [ 139 | self.feature_extractor.feature_dim, 140 | self.feature_extractor.feature_dim, 141 | self.ray_predictor.hidden_size, 142 | self.ray_predictor.hidden_size, 143 | ] 144 | else: 145 | feature_lens = [self.ray_predictor.hidden_size] * 4 146 | 147 | self.scratch = _make_scratch(feature_lens, 256, groups=1, expand=False) 148 | self.scratch.refinenet1 = _make_fusion_block( 149 | self.f, use_bn=False, use_ln=False, dpt_time=True, resolution=128 150 | ) 151 | self.scratch.refinenet2 = _make_fusion_block( 152 | self.f, use_bn=False, use_ln=False, dpt_time=True, resolution=64 153 | ) 154 | self.scratch.refinenet3 = _make_fusion_block( 155 | self.f, use_bn=False, use_ln=False, dpt_time=True, resolution=32 156 | ) 157 | self.scratch.refinenet4 = _make_fusion_block( 158 | self.f, use_bn=False, use_ln=False, dpt_time=True, resolution=16 159 | ) 160 | 161 | self.scratch.input_conv = nn.Conv2d( 162 | self.ray_dim + int(self.cond_depth_mask), 163 | self.feature_dim, 164 | kernel_size=16, 165 | stride=16, 166 | padding=0 167 | ) 168 | 169 | self.scratch.output_conv = nn.Sequential( 170 | nn.Conv2d(self.f, self.f // 2, kernel_size=3, stride=1, padding=1), 171 | nn.LeakyReLU(), 172 | nn.Conv2d(self.f // 2, 32, kernel_size=3, stride=1, padding=1), 173 | nn.LeakyReLU(), 174 | nn.Conv2d(32, self.ray_dim, kernel_size=1, stride=1, padding=0), 175 | nn.Identity(), 176 | ) 177 | 178 | if self.encoder_features: 179 | self.project_opers = nn.ModuleList([ 180 | ProjectReadout(in_features=self.feature_extractor.feature_dim), 181 | ProjectReadout(in_features=self.feature_extractor.feature_dim), 182 | ]) 183 | 184 | def forward_noise( 185 | self, x, t, epsilon=None, zero_out_mask=None 186 | ): 187 | """ 188 | Applies forward diffusion (adds noise) to the input. 189 | 190 | If a mask is provided, the noise is only applied to the masked inputs. 191 | """ 192 | t = t.reshape(-1, 1, 1, 1, 1) 193 | if epsilon is None: 194 | epsilon = torch.randn_like(x) 195 | else: 196 | epsilon = epsilon.reshape(x.shape) 197 | 198 | alpha_bar = self.noise_scheduler.alphas_cumprod[t] 199 | x_noise = torch.sqrt(alpha_bar) * x + torch.sqrt(1 - alpha_bar) * epsilon 200 | 201 | if zero_out_mask is not None and self.cond_depth_mask: 202 | x_noise = zero_out_mask * x_noise 203 | 204 | return x_noise, epsilon 205 | 206 | def forward( 207 | self, 208 | features=None, 209 | images=None, 210 | rays=None, 211 | rays_noisy=None, 212 | t=None, 213 | ndc_coordinates=None, 214 | unconditional_mask=None, 215 | encoder_patches=16, 216 | depth_mask=None, 217 | multiview_unconditional=False, 218 | indices=None, 219 | ): 220 | """ 221 | Args: 222 | images: (B, N, 3, H, W). 223 | t: (B,). 224 | rays: (B, N, 6, H, W). 225 | rays_noisy: (B, N, 6, H, W). 226 | ndc_coordinates: (B, N, 2, H, W). 227 | unconditional_mask: (B, N) or (B,). Should be 1 for unconditional samples 228 | and 0 else. 229 | """ 230 | 231 | if features is None: 232 | # VAE expects 256x256 images while DINO expects 224x224 images. 233 | # Both feature extractors support autoresize=True, but ideally we should 234 | # set this to be false and handle in the dataloader. 
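        # Editorial note on shapes: with the default SpatialDino extractor, the
        # `features` computed below are (B, N, feature_dim, width, width), while the
        # noisy rays stay at full resolution (B, N, ray_dim, H, W) and are only
        # brought to the same spatial size by scratch.input_conv (a 16x16-stride
        # patchify conv) before being concatenated along the channel dimension.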
235 | features = self.feature_extractor(images, autoresize=True) 236 | 237 | B = features.shape[0] 238 | 239 | if unconditional_mask is not None and self.use_unconditional: 240 | null_token = self.null_token.reshape(1, 1, self.feature_dim, 1, 1) 241 | unconditional_mask = unconditional_mask.reshape(B, -1, 1, 1, 1) 242 | features = ( 243 | features * (1 - unconditional_mask) + null_token * unconditional_mask 244 | ) 245 | 246 | if isinstance(t, int) or isinstance(t, np.int64): 247 | t = torch.ones(1, dtype=int).to(features.device) * t 248 | else: 249 | t = t.reshape(B) 250 | 251 | if rays_noisy is None: 252 | if self.cond_depth_mask: 253 | rays_noisy, epsilon = self.forward_noise( 254 | rays, t, zero_out_mask=depth_mask.unsqueeze(2) 255 | ) 256 | else: 257 | rays_noisy, epsilon = self.forward_noise( 258 | rays, t 259 | ) 260 | else: 261 | epsilon = None 262 | 263 | # DOWNSAMPLE RAYS 264 | B, N, C, H, W = rays_noisy.shape 265 | 266 | if self.cond_depth_mask: 267 | if depth_mask is None: 268 | depth_mask = torch.ones_like(rays_noisy[:, :, 0]) 269 | ray_repr = torch.cat([rays_noisy, depth_mask.unsqueeze(2)], dim=2) 270 | else: 271 | ray_repr = rays_noisy 272 | 273 | ray_repr = self.scratch.input_conv(ray_repr.reshape(B * N, -1, H, W)) 274 | _, CP, HP, WP = ray_repr.shape 275 | ray_repr = ray_repr.reshape(B, N, CP, HP, WP) 276 | scene_features = torch.cat([features, ray_repr], dim=2) 277 | 278 | if self.append_ndc: 279 | scene_features = torch.cat([scene_features, ndc_coordinates], dim=2) 280 | 281 | # DIT FORWARD PASS 282 | activations = self.ray_predictor( 283 | scene_features, 284 | t, 285 | return_dpt_activations=True, 286 | multiview_unconditional=multiview_unconditional, 287 | ) 288 | 289 | # PROJECT ENCODER ACTIVATIONS & RESHAPE 290 | if self.encoder_features: 291 | for i in range(2): 292 | name = f"encoder{i+1}" 293 | 294 | if indices is not None: 295 | act = self.feature_extractor.activations[name][indices] 296 | else: 297 | act = self.feature_extractor.activations[name] 298 | 299 | act = self.project_opers[i](act).permute(0, 2, 1) 300 | act = act.reshape( 301 | ( 302 | B * N, 303 | self.feature_extractor.feature_dim, 304 | encoder_patches, 305 | encoder_patches, 306 | ) 307 | ) 308 | activations[i] = act 309 | 310 | # UPSAMPLE ACTIVATIONS 311 | for i, act in enumerate(activations): 312 | k = 3 - i 313 | activations[i] = nearest_neighbor_upsample(act, 2**k) 314 | 315 | # FUSION BLOCKS 316 | layer_1_rn = self.scratch.layer1_rn(activations[0]) 317 | layer_2_rn = self.scratch.layer2_rn(activations[1]) 318 | layer_3_rn = self.scratch.layer3_rn(activations[2]) 319 | layer_4_rn = self.scratch.layer4_rn(activations[3]) 320 | 321 | # RESHAPE TIMESTEPS 322 | if t.shape[0] == B: 323 | t = t.unsqueeze(-1).repeat((1, N)).reshape(B * N) 324 | elif t.shape[0] == 1 and B > 1: 325 | t = t.repeat((B * N)) 326 | else: 327 | assert False 328 | 329 | path_4 = self.scratch.refinenet4(layer_4_rn, t=t) 330 | path_3 = self.scratch.refinenet3(path_4, activation=layer_3_rn, t=t) 331 | path_2 = self.scratch.refinenet2(path_3, activation=layer_2_rn, t=t) 332 | path_1 = self.scratch.refinenet1(path_2, activation=layer_1_rn, t=t) 333 | 334 | epsilon_pred = self.scratch.output_conv(path_1) 335 | epsilon_pred = epsilon_pred.reshape((B, N, C, H, W)) 336 | 337 | return epsilon_pred, epsilon 338 | -------------------------------------------------------------------------------- /diffusionsfm/model/dit.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, 
Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | # -------------------------------------------------------- 7 | # References: 8 | # GLIDE: https://github.com/openai/glide-text2im 9 | # MAE: https://github.com/facebookresearch/mae/blob/main/models_mae.py 10 | # -------------------------------------------------------- 11 | 12 | import math 13 | 14 | import ipdb # noqa: F401 15 | import numpy as np 16 | import torch 17 | import torch.nn as nn 18 | from timm.models.vision_transformer import Attention, Mlp, PatchEmbed 19 | from diffusionsfm.model.memory_efficient_attention import MEAttention 20 | 21 | 22 | def modulate(x, shift, scale): 23 | return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1) 24 | 25 | 26 | ################################################################################# 27 | # Embedding Layers for Timesteps and Class Labels # 28 | ################################################################################# 29 | 30 | 31 | class TimestepEmbedder(nn.Module): 32 | """ 33 | Embeds scalar timesteps into vector representations. 34 | """ 35 | 36 | def __init__(self, hidden_size, frequency_embedding_size=256): 37 | super().__init__() 38 | self.mlp = nn.Sequential( 39 | nn.Linear(frequency_embedding_size, hidden_size, bias=True), 40 | nn.SiLU(), 41 | nn.Linear(hidden_size, hidden_size, bias=True), 42 | ) 43 | self.frequency_embedding_size = frequency_embedding_size 44 | 45 | @staticmethod 46 | def timestep_embedding(t, dim, max_period=10000): 47 | """ 48 | Create sinusoidal timestep embeddings. 49 | :param t: a 1-D Tensor of N indices, one per batch element. 50 | These may be fractional. 51 | :param dim: the dimension of the output. 52 | :param max_period: controls the minimum frequency of the embeddings. 53 | :return: an (N, D) Tensor of positional embeddings. 54 | """ 55 | # https://github.com/openai/glide-text2im/blob/main/glide_text2im/nn.py 56 | half = dim // 2 57 | freqs = torch.exp( 58 | -math.log(max_period) 59 | * torch.arange(start=0, end=half, dtype=torch.float32) 60 | / half 61 | ).to(device=t.device) 62 | args = t[:, None].float() * freqs[None] 63 | embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) 64 | if dim % 2: 65 | embedding = torch.cat( 66 | [embedding, torch.zeros_like(embedding[:, :1])], dim=-1 67 | ) 68 | return embedding 69 | 70 | def forward(self, t): 71 | t_freq = self.timestep_embedding(t, self.frequency_embedding_size) 72 | t_emb = self.mlp(t_freq) 73 | return t_emb 74 | 75 | 76 | ################################################################################# 77 | # Core DiT Model # 78 | ################################################################################# 79 | 80 | 81 | class DiTBlock(nn.Module): 82 | """ 83 | A DiT block with adaptive layer norm zero (adaLN-Zero) conditioning. 
84 | """ 85 | 86 | def __init__( 87 | self, 88 | hidden_size, 89 | num_heads, 90 | mlp_ratio=4.0, 91 | use_xformers_attention=False, 92 | **block_kwargs 93 | ): 94 | super().__init__() 95 | self.norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) 96 | attn = MEAttention if use_xformers_attention else Attention 97 | self.attn = attn( 98 | hidden_size, num_heads=num_heads, qkv_bias=True, **block_kwargs 99 | ) 100 | self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) 101 | mlp_hidden_dim = int(hidden_size * mlp_ratio) 102 | 103 | def approx_gelu(): 104 | return nn.GELU(approximate="tanh") 105 | 106 | self.mlp = Mlp( 107 | in_features=hidden_size, 108 | hidden_features=mlp_hidden_dim, 109 | act_layer=approx_gelu, 110 | drop=0, 111 | ) 112 | self.adaLN_modulation = nn.Sequential( 113 | nn.SiLU(), nn.Linear(hidden_size, 6 * hidden_size, bias=True) 114 | ) 115 | 116 | def forward(self, x, c): 117 | ( 118 | shift_msa, 119 | scale_msa, 120 | gate_msa, 121 | shift_mlp, 122 | scale_mlp, 123 | gate_mlp, 124 | ) = self.adaLN_modulation(c).chunk(6, dim=1) 125 | x = x + gate_msa.unsqueeze(1) * self.attn( 126 | modulate(self.norm1(x), shift_msa, scale_msa) 127 | ) 128 | x = x + gate_mlp.unsqueeze(1) * self.mlp( 129 | modulate(self.norm2(x), shift_mlp, scale_mlp) 130 | ) 131 | return x 132 | 133 | 134 | class FinalLayer(nn.Module): 135 | """ 136 | The final layer of DiT. 137 | """ 138 | 139 | def __init__(self, hidden_size, patch_size, out_channels): 140 | super().__init__() 141 | self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) 142 | self.linear = nn.Linear( 143 | hidden_size, patch_size * patch_size * out_channels, bias=True 144 | ) 145 | self.adaLN_modulation = nn.Sequential( 146 | nn.SiLU(), nn.Linear(hidden_size, 2 * hidden_size, bias=True) 147 | ) 148 | 149 | def forward(self, x, c): 150 | shift, scale = self.adaLN_modulation(c).chunk(2, dim=1) 151 | x = modulate(self.norm_final(x), shift, scale) 152 | x = self.linear(x) 153 | return x 154 | 155 | 156 | class DiT(nn.Module): 157 | """ 158 | Diffusion model with a Transformer backbone. 
159 | """ 160 | 161 | def __init__( 162 | self, 163 | in_channels=442, 164 | out_channels=6, 165 | width=16, 166 | hidden_size=1152, 167 | depth=8, 168 | num_heads=16, 169 | mlp_ratio=4.0, 170 | max_num_images=8, 171 | P=1, 172 | within_image=False, 173 | ): 174 | super().__init__() 175 | self.num_heads = num_heads 176 | self.in_channels = in_channels 177 | self.out_channels = out_channels 178 | self.width = width 179 | self.hidden_size = hidden_size 180 | self.max_num_images = max_num_images 181 | self.P = P 182 | self.within_image = within_image 183 | 184 | # self.x_embedder = nn.Linear(in_channels, hidden_size) 185 | # self.x_embedder = PatchEmbed(in_channels, hidden_size, kernel_size=P, hidden_size=P) 186 | self.x_embedder = PatchEmbed( 187 | img_size=self.width, 188 | patch_size=self.P, 189 | in_chans=in_channels, 190 | embed_dim=hidden_size, 191 | bias=True, 192 | flatten=False, 193 | ) 194 | self.x_pos_enc = FeaturePositionalEncoding( 195 | max_num_images, hidden_size, width**2, P=self.P 196 | ) 197 | self.t_embedder = TimestepEmbedder(hidden_size) 198 | 199 | try: 200 | import xformers 201 | 202 | use_xformers_attention = True 203 | except ImportError: 204 | # xformers not available 205 | use_xformers_attention = False 206 | 207 | self.blocks = nn.ModuleList( 208 | [ 209 | DiTBlock( 210 | hidden_size, 211 | num_heads, 212 | mlp_ratio=mlp_ratio, 213 | use_xformers_attention=use_xformers_attention, 214 | ) 215 | for _ in range(depth) 216 | ] 217 | ) 218 | self.final_layer = FinalLayer(hidden_size, P, out_channels) 219 | self.initialize_weights() 220 | 221 | def initialize_weights(self): 222 | # Initialize transformer layers: 223 | def _basic_init(module): 224 | if isinstance(module, nn.Linear): 225 | torch.nn.init.xavier_uniform_(module.weight) 226 | if module.bias is not None: 227 | nn.init.constant_(module.bias, 0) 228 | 229 | self.apply(_basic_init) 230 | 231 | # Initialize patch_embed like nn.Linear (instead of nn.Conv2d): 232 | w = self.x_embedder.proj.weight.data 233 | nn.init.xavier_uniform_(w.view([w.shape[0], -1])) 234 | nn.init.constant_(self.x_embedder.proj.bias, 0) 235 | 236 | # Initialize timestep embedding MLP: 237 | nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02) 238 | nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02) 239 | 240 | # Zero-out adaLN modulation layers in DiT blocks: 241 | for block in self.blocks: 242 | nn.init.constant_(block.adaLN_modulation[-1].weight, 0) 243 | nn.init.constant_(block.adaLN_modulation[-1].bias, 0) 244 | 245 | # Zero-out output layers: 246 | # nn.init.constant_(self.final_layer.adaLN_modulation[-1].weight, 0) 247 | # nn.init.constant_(self.final_layer.adaLN_modulation[-1].bias, 0) 248 | # nn.init.constant_(self.final_layer.linear.weight, 0) 249 | # nn.init.constant_(self.final_layer.linear.bias, 0) 250 | 251 | def unpatchify(self, x): 252 | """ 253 | x: (N, T, patch_size**2 * C) 254 | imgs: (N, H, W, C) 255 | """ 256 | c = self.out_channels 257 | p = self.x_embedder.patch_size[0] 258 | h = w = int(x.shape[1] ** 0.5) 259 | 260 | # print("unpatchify", c, p, h, w, x.shape) 261 | # assert h * w == x.shape[2] 262 | 263 | x = x.reshape(shape=(x.shape[0], h, w, p, p, c)) 264 | x = torch.einsum("nhwpqc->nhpwqc", x) 265 | imgs = x.reshape(shape=(x.shape[0], h * p, h * p, c)) 266 | return imgs 267 | 268 | def forward( 269 | self, 270 | x, 271 | t, 272 | return_dpt_activations=False, 273 | multiview_unconditional=False, 274 | ): 275 | """ 276 | 277 | Args: 278 | x: Image/Ray features (B, N, C, H, W). 279 | t: Timesteps (N,). 
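            return_dpt_activations: if True, collect the intermediate activations
                of every fourth block and return them (for the DPT head) instead
                of the final prediction.
            multiview_unconditional: if True, run every block with within-image
                attention (tokens reshaped per image) rather than jointly over
                all N views.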
280 | 281 | Returns: 282 | (B, N, D, H, W) 283 | """ 284 | B, N, c, h, w = x.shape 285 | P = self.P 286 | 287 | x = x.reshape((B * N, c, h, w)) # (B * N, C, H, W) 288 | x = self.x_embedder(x) # (B * N, C, H / P, W / P) 289 | 290 | x = x.permute(0, 2, 3, 1) # (B * N, H / P, W / P, C) 291 | # (B, N, H / P, W / P, C) 292 | x = x.reshape((B, N, h // P, w // P, self.hidden_size)) 293 | x = self.x_pos_enc(x) # (B, N, H * W / P ** 2, C) 294 | # TODO: fix positional encoding to work with (N, C, H, W) format. 295 | 296 | # Eval time, we get a scalar t 297 | if x.shape[0] != t.shape[0] and t.shape[0] == 1: 298 | t = t.repeat_interleave(B) 299 | 300 | if self.within_image or multiview_unconditional: 301 | t_within = t.repeat_interleave(N) 302 | t_within = self.t_embedder(t_within) 303 | 304 | t = self.t_embedder(t) 305 | 306 | dpt_activations = [] 307 | for i, block in enumerate(self.blocks): 308 | # Within image block 309 | if (self.within_image and i % 2 == 0) or multiview_unconditional: 310 | x = x.reshape((B * N, h * w // P**2, self.hidden_size)) 311 | x = block(x, t_within) 312 | 313 | # All patches block 314 | # Final layer is an all patches layer 315 | else: 316 | x = x.reshape((B, N * h * w // P**2, self.hidden_size)) 317 | x = block(x, t) # (N, T, D) 318 | 319 | if return_dpt_activations and i % 4 == 3: 320 | x_prime = x.reshape(B, N, h, w, self.hidden_size) 321 | x_prime = x.reshape(B * N, h, w, self.hidden_size) 322 | x_prime = x_prime.permute((0, 3, 1, 2)) 323 | dpt_activations.append(x_prime) 324 | 325 | # Reshape the output back to original shape 326 | if multiview_unconditional: 327 | x = x.reshape((B, N * h * w // P**2, self.hidden_size)) 328 | 329 | # (B, N * H * W / P ** 2, D) 330 | x = self.final_layer( 331 | x, t 332 | ) # (B, N * H * W / P ** 2, 6 * P ** 2) or (N, T, patch_size ** 2 * out_channels) 333 | 334 | x = x.reshape((B * N, w * w // P**2, self.out_channels * P**2)) 335 | x = self.unpatchify(x) # (B * N, H, W, C) 336 | x = x.reshape((B, N) + x.shape[1:]) 337 | x = x.permute(0, 1, 4, 2, 3) # (B, N, C, H, W) 338 | 339 | if return_dpt_activations: 340 | return dpt_activations[:4] 341 | 342 | return x 343 | 344 | 345 | class FeaturePositionalEncoding(nn.Module): 346 | def _get_sinusoid_encoding_table(self, n_position, d_hid, base): 347 | """Sinusoid position encoding table""" 348 | 349 | def get_position_angle_vec(position): 350 | return [ 351 | position / np.power(base, 2 * (hid_j // 2) / d_hid) 352 | for hid_j in range(d_hid) 353 | ] 354 | 355 | sinusoid_table = np.array( 356 | [get_position_angle_vec(pos_i) for pos_i in range(n_position)] 357 | ) 358 | sinusoid_table[:, 0::2] = np.sin(sinusoid_table[:, 0::2]) # dim 2i 359 | sinusoid_table[:, 1::2] = np.cos(sinusoid_table[:, 1::2]) # dim 2i+1 360 | 361 | return torch.FloatTensor(sinusoid_table).unsqueeze(0) 362 | 363 | def __init__(self, max_num_images=8, feature_dim=1152, num_patches=256, P=1): 364 | super().__init__() 365 | self.max_num_images = max_num_images 366 | self.feature_dim = feature_dim 367 | self.P = P 368 | self.num_patches = num_patches // self.P**2 369 | 370 | self.register_buffer( 371 | "image_pos_table", 372 | self._get_sinusoid_encoding_table( 373 | self.max_num_images, self.feature_dim, 10000 374 | ), 375 | ) 376 | 377 | self.register_buffer( 378 | "token_pos_table", 379 | self._get_sinusoid_encoding_table( 380 | self.num_patches, self.feature_dim, 70007 381 | ), 382 | ) 383 | 384 | def forward(self, x): 385 | batch_size = x.shape[0] 386 | num_images = x.shape[1] 387 | 388 | x = x.reshape(batch_size, 
num_images, self.num_patches, self.feature_dim) 389 | 390 | # To encode image index 391 | pe1 = self.image_pos_table[:, :num_images].clone().detach() 392 | pe1 = pe1.reshape((1, num_images, 1, self.feature_dim)) 393 | pe1 = pe1.repeat((batch_size, 1, self.num_patches, 1)) 394 | 395 | # To encode patch index 396 | pe2 = self.token_pos_table.clone().detach() 397 | pe2 = pe2.reshape((1, 1, self.num_patches, self.feature_dim)) 398 | pe2 = pe2.repeat((batch_size, num_images, 1, 1)) 399 | 400 | x_pe = x + pe1 + pe2 401 | x_pe = x_pe.reshape( 402 | (batch_size, num_images * self.num_patches, self.feature_dim) 403 | ) 404 | 405 | return x_pe 406 | 407 | def forward_unet(self, x, B, N): 408 | D = int(self.num_patches**0.5) 409 | 410 | # x should be (B, N, T, D, D) 411 | x = x.permute((0, 2, 3, 1)) 412 | x = x.reshape(B, N, self.num_patches, self.feature_dim) 413 | 414 | # To encode image index 415 | pe1 = self.image_pos_table[:, :N].clone().detach() 416 | pe1 = pe1.reshape((1, N, 1, self.feature_dim)) 417 | pe1 = pe1.repeat((B, 1, self.num_patches, 1)) 418 | 419 | # To encode patch index 420 | pe2 = self.token_pos_table.clone().detach() 421 | pe2 = pe2.reshape((1, 1, self.num_patches, self.feature_dim)) 422 | pe2 = pe2.repeat((B, N, 1, 1)) 423 | 424 | x_pe = x + pe1 + pe2 425 | x_pe = x_pe.reshape((B * N, D, D, self.feature_dim)) 426 | x_pe = x_pe.permute((0, 3, 1, 2)) 427 | 428 | return x_pe 429 | -------------------------------------------------------------------------------- /diffusionsfm/model/feature_extractors.py: -------------------------------------------------------------------------------- 1 | import importlib 2 | import os 3 | import socket 4 | import sys 5 | 6 | import ipdb # noqa: F401 7 | import torch 8 | import torch.nn as nn 9 | from omegaconf import OmegaConf 10 | 11 | HOSTNAME = socket.gethostname() 12 | 13 | if "trinity" in HOSTNAME: 14 | # Might be outdated 15 | config_path = "/home/amylin2/latent-diffusion/configs/autoencoder/autoencoder_kl_16x16x16.yaml" 16 | weights_path = "/home/amylin2/latent-diffusion/model.ckpt" 17 | elif "grogu" in HOSTNAME: 18 | # Might be outdated 19 | config_path = "/home/jasonzh2/code/latent-diffusion/configs/autoencoder/autoencoder_kl_16x16x16.yaml" 20 | weights_path = "/home/jasonzh2/code/latent-diffusion/model.ckpt" 21 | elif "ender" in HOSTNAME: 22 | config_path = "/home/jason/ray_diffusion/external/latent-diffusion/configs/autoencoder/autoencoder_kl_16x16x16.yaml" 23 | weights_path = "/home/jason/ray_diffusion/external/latent-diffusion/model.ckpt" 24 | else: 25 | config_path = None 26 | weights_path = None 27 | 28 | 29 | if weights_path is not None: 30 | LDM_PATH = os.path.dirname(weights_path) 31 | if LDM_PATH not in sys.path: 32 | sys.path.append(LDM_PATH) 33 | 34 | 35 | def resize(image, size=None, scale_factor=None): 36 | return nn.functional.interpolate( 37 | image, 38 | size=size, 39 | scale_factor=scale_factor, 40 | mode="bilinear", 41 | align_corners=False, 42 | ) 43 | 44 | 45 | def instantiate_from_config(config): 46 | if "target" not in config: 47 | if config == "__is_first_stage__": 48 | return None 49 | elif config == "__is_unconditional__": 50 | return None 51 | raise KeyError("Expected key `target` to instantiate.") 52 | return get_obj_from_str(config["target"])(**config.get("params", dict())) 53 | 54 | 55 | def get_obj_from_str(string, reload=False): 56 | module, cls = string.rsplit(".", 1) 57 | if reload: 58 | module_imp = importlib.import_module(module) 59 | importlib.reload(module_imp) 60 | return 
getattr(importlib.import_module(module, package=None), cls) 61 | 62 | 63 | class PretrainedVAE(nn.Module): 64 | def __init__(self, freeze_weights=True, num_patches_x=16, num_patches_y=16): 65 | super().__init__() 66 | config = OmegaConf.load(config_path) 67 | self.model = instantiate_from_config(config.model) 68 | self.model.init_from_ckpt(weights_path) 69 | self.model.eval() 70 | self.feature_dim = 16 71 | self.num_patches_x = num_patches_x 72 | self.num_patches_y = num_patches_y 73 | 74 | if freeze_weights: 75 | for param in self.model.parameters(): 76 | param.requires_grad = False 77 | 78 | def forward(self, x, autoresize=False): 79 | """ 80 | Spatial dimensions of output will be H // 16, W // 16. If autoresize is True, 81 | then the input will be resized such that the output feature map is the correct 82 | dimensions. 83 | 84 | Args: 85 | x (torch.Tensor): Images (B, C, H, W). Should be normalized to be [-1, 1]. 86 | autoresize (bool): Whether to resize the input to match the num_patch 87 | dimensions. 88 | 89 | Returns: 90 | torch.Tensor: Latent sample (B, 16, h, w) 91 | """ 92 | 93 | *B, c, h, w = x.shape 94 | x = x.reshape(-1, c, h, w) 95 | if autoresize: 96 | new_w = self.num_patches_x * 16 97 | new_h = self.num_patches_y * 16 98 | x = resize(x, size=(new_h, new_w)) 99 | 100 | decoded, latent = self.model(x) 101 | # A little ambiguous bc it's all 16, but it is (c, h, w) 102 | latent_sample = latent.sample().reshape( 103 | *B, self.feature_dim, self.num_patches_y, self.num_patches_x 104 | ) 105 | return latent_sample 106 | 107 | 108 | activations = {} 109 | 110 | 111 | def get_activation(name): 112 | def hook(model, input, output): 113 | activations[name] = output 114 | 115 | return hook 116 | 117 | 118 | class SpatialDino(nn.Module): 119 | def __init__( 120 | self, 121 | freeze_weights=True, 122 | model_type="dinov2_vits14", 123 | num_patches_x=16, 124 | num_patches_y=16, 125 | activation_hooks=False, 126 | ): 127 | super().__init__() 128 | self.model = torch.hub.load("facebookresearch/dinov2", model_type) 129 | self.feature_dim = self.model.embed_dim 130 | self.num_patches_x = num_patches_x 131 | self.num_patches_y = num_patches_y 132 | if freeze_weights: 133 | for param in self.model.parameters(): 134 | param.requires_grad = False 135 | 136 | self.activation_hooks = activation_hooks 137 | 138 | if self.activation_hooks: 139 | self.model.blocks[5].register_forward_hook(get_activation("encoder1")) 140 | self.model.blocks[11].register_forward_hook(get_activation("encoder2")) 141 | self.activations = activations 142 | 143 | def forward(self, x, autoresize=False): 144 | """ 145 | Spatial dimensions of output will be H // 14, W // 14. If autoresize is True, 146 | then the output will be resized to the correct dimensions. 147 | 148 | Args: 149 | x (torch.Tensor): Images (B, C, H, W). Should be ImageNet normalized. 150 | autoresize (bool): Whether to resize the input to match the num_patch 151 | dimensions. 
152 | 153 | Returns: 154 | feature_map (torch.tensor): (B, C, h, w) 155 | """ 156 | *B, c, h, w = x.shape 157 | 158 | x = x.reshape(-1, c, h, w) 159 | # if autoresize: 160 | # new_w = self.num_patches_x * 14 161 | # new_h = self.num_patches_y * 14 162 | # x = resize(x, size=(new_h, new_w)) 163 | 164 | # Output will be (B, H * W, C) 165 | features = self.model.forward_features(x)["x_norm_patchtokens"] 166 | features = features.permute(0, 2, 1) 167 | features = features.reshape( # (B, C, H, W) 168 | -1, self.feature_dim, h // 14, w // 14 169 | ) 170 | if autoresize: 171 | features = resize(features, size=(self.num_patches_y, self.num_patches_x)) 172 | 173 | features = features.reshape( 174 | *B, self.feature_dim, self.num_patches_y, self.num_patches_x 175 | ) 176 | return features 177 | -------------------------------------------------------------------------------- /diffusionsfm/model/memory_efficient_attention.py: -------------------------------------------------------------------------------- 1 | import ipdb 2 | import torch.nn as nn 3 | from xformers.ops import memory_efficient_attention 4 | 5 | 6 | class MEAttention(nn.Module): 7 | def __init__( 8 | self, 9 | dim, 10 | num_heads=8, 11 | qkv_bias=False, 12 | qk_norm=False, 13 | attn_drop=0.0, 14 | proj_drop=0.0, 15 | norm_layer=nn.LayerNorm, 16 | ): 17 | super().__init__() 18 | assert dim % num_heads == 0, "dim should be divisible by num_heads" 19 | self.num_heads = num_heads 20 | self.head_dim = dim // num_heads 21 | self.scale = self.head_dim**-0.5 22 | 23 | self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) 24 | self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity() 25 | self.k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity() 26 | self.attn_drop = nn.Dropout(attn_drop) 27 | self.proj = nn.Linear(dim, dim) 28 | self.proj_drop = nn.Dropout(proj_drop) 29 | 30 | def forward(self, x): 31 | B, N, C = x.shape 32 | qkv = ( 33 | self.qkv(x) 34 | .reshape(B, N, 3, self.num_heads, self.head_dim) 35 | .permute(2, 0, 3, 1, 4) 36 | ) 37 | q, k, v = qkv.unbind(0) 38 | q, k = self.q_norm(q), self.k_norm(k) 39 | 40 | # MEA expects [B, N, H, D], whereas timm uses [B, H, N, D] 41 | x = memory_efficient_attention( 42 | q.transpose(1, 2), 43 | k.transpose(1, 2), 44 | v.transpose(1, 2), 45 | scale=self.scale, 46 | ) 47 | x = x.reshape(B, N, C) 48 | 49 | x = self.proj(x) 50 | x = self.proj_drop(x) 51 | return x 52 | -------------------------------------------------------------------------------- /diffusionsfm/model/scheduler.py: -------------------------------------------------------------------------------- 1 | import ipdb # noqa: F401 2 | import matplotlib.pyplot as plt 3 | import numpy as np 4 | import torch 5 | import torch.nn as nn 6 | 7 | from diffusionsfm.utils.visualization import plot_to_image 8 | 9 | 10 | class NoiseScheduler(nn.Module): 11 | def __init__( 12 | self, 13 | max_timesteps=1000, 14 | beta_start=0.0001, 15 | beta_end=0.02, 16 | cos_power=2, 17 | num_inference_steps=100, 18 | type="linear", 19 | ): 20 | super().__init__() 21 | self.max_timesteps = max_timesteps 22 | self.num_inference_steps = num_inference_steps 23 | self.beta_start = beta_start 24 | self.beta_end = beta_end 25 | self.cos_power = cos_power 26 | self.type = type 27 | 28 | if type == "linear": 29 | self.register_linear_schedule() 30 | elif type == "cosine": 31 | self.register_cosine_schedule(cos_power) 32 | elif type == "scaled_linear": 33 | self.register_scaled_linear_schedule() 34 | 35 | self.inference_timesteps = 
self.compute_inference_timesteps() 36 | 37 | def register_linear_schedule(self): 38 | # zero terminal SNR (https://arxiv.org/pdf/2305.08891) 39 | betas = torch.linspace( 40 | self.beta_start, 41 | self.beta_end, 42 | self.max_timesteps, 43 | dtype=torch.float32, 44 | ) 45 | alphas = 1.0 - betas 46 | alphas_cumprod = torch.cumprod(alphas, dim=0) 47 | alphas_bar_sqrt = alphas_cumprod.sqrt() 48 | 49 | alphas_bar_sqrt_0 = alphas_bar_sqrt[0].clone() 50 | alphas_bar_sqrt_T = alphas_bar_sqrt[-1].clone() 51 | 52 | alphas_bar_sqrt -= alphas_bar_sqrt_T 53 | alphas_bar_sqrt *= alphas_bar_sqrt_0 / (alphas_bar_sqrt_0 - alphas_bar_sqrt_T) 54 | 55 | alphas_bar = alphas_bar_sqrt**2 56 | alphas = alphas_bar[1:] / alphas_bar[:-1] 57 | alphas = torch.cat([alphas_bar[0:1], alphas]) 58 | betas = 1 - alphas 59 | 60 | self.register_buffer( 61 | "betas", 62 | betas, 63 | ) 64 | self.register_buffer("alphas", 1.0 - self.betas) 65 | self.register_buffer("alphas_cumprod", torch.cumprod(self.alphas, dim=0)) 66 | 67 | def register_cosine_schedule(self, cos_power, s=0.008): 68 | timesteps = ( 69 | torch.arange(self.max_timesteps + 1, dtype=torch.float32) 70 | / self.max_timesteps 71 | ) 72 | alpha_bars = (timesteps + s) / (1 + s) * np.pi / 2 73 | alpha_bars = torch.cos(alpha_bars).pow(cos_power) 74 | alpha_bars = alpha_bars / alpha_bars[0] 75 | betas = 1 - alpha_bars[1:] / alpha_bars[:-1] 76 | betas = np.clip(betas, a_min=0, a_max=0.999) 77 | 78 | self.register_buffer( 79 | "betas", 80 | betas, 81 | ) 82 | self.register_buffer("alphas", 1.0 - betas) 83 | self.register_buffer("alphas_cumprod", torch.cumprod(self.alphas, dim=0)) 84 | 85 | def register_scaled_linear_schedule(self): 86 | self.register_buffer( 87 | "betas", 88 | torch.linspace( 89 | self.beta_start**0.5, 90 | self.beta_end**0.5, 91 | self.max_timesteps, 92 | dtype=torch.float32, 93 | ) 94 | ** 2, 95 | ) 96 | self.register_buffer("alphas", 1.0 - self.betas) 97 | self.register_buffer("alphas_cumprod", torch.cumprod(self.alphas, dim=0)) 98 | 99 | def compute_inference_timesteps( 100 | self, num_inference_steps=None, num_train_steps=None 101 | ): 102 | # based on diffusers's scheduling code 103 | if num_inference_steps is None: 104 | num_inference_steps = self.num_inference_steps 105 | if num_train_steps is None: 106 | num_train_steps = self.max_timesteps 107 | step_ratio = num_train_steps // num_inference_steps 108 | timesteps = ( 109 | (np.arange(0, num_inference_steps) * step_ratio).round()[::-1].astype(int) 110 | ) 111 | return timesteps 112 | 113 | def plot_schedule(self, return_image=False): 114 | fig = plt.figure(figsize=(6, 4), dpi=100) 115 | alpha_bars = self.alphas_cumprod.cpu().numpy() 116 | plt.plot(np.sqrt(alpha_bars)) 117 | plt.grid() 118 | if self.type == "linear": 119 | plt.title( 120 | f"Linear (T={self.max_timesteps}, S={self.beta_start}, E={self.beta_end})" 121 | ) 122 | else: 123 | self.type == "cosine" 124 | plt.title(f"Cosine (T={self.max_timesteps}, P={self.cos_power})") 125 | if return_image: 126 | image = plot_to_image(fig) 127 | plt.close(fig) 128 | return image 129 | -------------------------------------------------------------------------------- /diffusionsfm/utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/QitaoZhao/DiffusionSfM/4bb08800721bdcf46b0c823586a2fab4702ff282/diffusionsfm/utils/__init__.py -------------------------------------------------------------------------------- /diffusionsfm/utils/configs.py: 
-------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | 4 | from omegaconf import OmegaConf 5 | 6 | 7 | def load_cfg(config_path): 8 | """ 9 | Loads a yaml configuration file. 10 | 11 | Follows the chain of yaml configuration files that have a `_BASE` key, and updates 12 | the new keys accordingly. _BASE configurations can be specified using relative 13 | paths. 14 | """ 15 | config_dir = os.path.dirname(config_path) 16 | config_path = os.path.basename(config_path) 17 | return load_cfg_recursive(config_dir, config_path) 18 | 19 | 20 | def load_cfg_recursive(config_dir, config_path): 21 | """ 22 | Recursively loads config files. 23 | 24 | Follows the chain of yaml configuration files that have a `_BASE` key, and updates 25 | the new keys accordingly. _BASE configurations can be specified using relative 26 | paths. 27 | """ 28 | cfg = OmegaConf.load(os.path.join(config_dir, config_path)) 29 | base_path = OmegaConf.select(cfg, "_BASE", default=None) 30 | if base_path is not None: 31 | base_cfg = load_cfg_recursive(config_dir, base_path) 32 | cfg = OmegaConf.merge(base_cfg, cfg) 33 | return cfg 34 | 35 | 36 | def get_cfg(): 37 | parser = argparse.ArgumentParser() 38 | parser.add_argument("--config-path", type=str, required=True) 39 | args = parser.parse_args() 40 | cfg = load_cfg(args.config_path) 41 | print(OmegaConf.to_yaml(cfg)) 42 | 43 | exp_dir = os.path.join(cfg.training.runs_dir, cfg.training.exp_tag) 44 | os.makedirs(exp_dir, exist_ok=True) 45 | to_path = os.path.join(exp_dir, os.path.basename(args.config_path)) 46 | if not os.path.exists(to_path): 47 | OmegaConf.save(config=cfg, f=to_path) 48 | return cfg 49 | 50 | 51 | def get_cfg_from_path(config_path): 52 | """ 53 | args: 54 | config_path - get config from path 55 | """ 56 | print("getting config from path") 57 | 58 | cfg = load_cfg(config_path) 59 | print(OmegaConf.to_yaml(cfg)) 60 | 61 | exp_dir = os.path.join(cfg.training.runs_dir, cfg.training.exp_tag) 62 | os.makedirs(exp_dir, exist_ok=True) 63 | to_path = os.path.join(exp_dir, os.path.basename(config_path)) 64 | if not os.path.exists(to_path): 65 | OmegaConf.save(config=cfg, f=to_path) 66 | return cfg 67 | -------------------------------------------------------------------------------- /diffusionsfm/utils/distortion.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import ipdb 3 | import numpy as np 4 | from PIL import Image 5 | import torch 6 | 7 | 8 | # https://gist.github.com/davegreenwood/820d51ac5ec88a2aeda28d3079e7d9eb 9 | def apply_distortion(pts, k1, k2): 10 | """ 11 | Arguments: 12 | pts (N x 2): numpy array in NDC coordinates 13 | k1, k2 distortion coefficients 14 | Return: 15 | pts (N x 2): distorted points in NDC coordinates 16 | """ 17 | r2 = np.square(pts).sum(-1) 18 | f = 1 + k1 * r2 + k2 * r2**2 19 | return f[..., None] * pts 20 | 21 | 22 | # https://gist.github.com/davegreenwood/820d51ac5ec88a2aeda28d3079e7d9eb 23 | def apply_distortion_tensor(pts, k1, k2): 24 | """ 25 | Arguments: 26 | pts (N x 2): numpy array in NDC coordinates 27 | k1, k2 distortion coefficients 28 | Return: 29 | pts (N x 2): distorted points in NDC coordinates 30 | """ 31 | r2 = torch.square(pts).sum(-1) 32 | f = 1 + k1 * r2 + k2 * r2**2 33 | return f[..., None] * pts 34 | 35 | 36 | # https://gist.github.com/davegreenwood/820d51ac5ec88a2aeda28d3079e7d9eb 37 | def remove_distortion_iter(points, k1, k2): 38 | """ 39 | Arguments: 40 | pts (N x 2): numpy array in NDC 
coordinates 41 | k1, k2 distortion coefficients 42 | Return: 43 | pts (N x 2): distorted points in NDC coordinates 44 | """ 45 | pts = ptsd = points 46 | for _ in range(5): 47 | r2 = np.square(pts).sum(-1) 48 | f = 1 + k1 * r2 + k2 * r2**2 49 | pts = ptsd / f[..., None] 50 | 51 | return pts 52 | 53 | 54 | def make_square(im, fill_color=(0, 0, 0)): 55 | x, y = im.size 56 | size = max(x, y) 57 | new_im = Image.new("RGB", (size, size), fill_color) 58 | corner = (int((size - x) / 2), int((size - y) / 2)) 59 | new_im.paste(im, corner) 60 | return new_im, corner 61 | 62 | 63 | def pixel_to_ndc(coords, image_size): 64 | """ 65 | Converts pixel coordinates to normalized device coordinates (Pytorch3D convention 66 | with upper left = (1, 1)) for a square image. 67 | 68 | Args: 69 | coords: Pixel coordinates UL=(0, 0), LR=(image_size, image_size). 70 | image_size (int): Image size. 71 | 72 | Returns: 73 | NDC coordinates UL=(1, 1) LR=(-1, -1). 74 | """ 75 | coords = np.array(coords) 76 | return 1 - coords / image_size * 2 77 | 78 | 79 | def ndc_to_pixel(coords, image_size): 80 | """ 81 | Converts normalized device coordinates to pixel coordinates for a square image. 82 | """ 83 | num_points = coords.shape[0] 84 | sizes = np.tile(np.array(image_size, dtype=np.float32)[None, ...], (num_points, 1)) 85 | 86 | coords = np.array(coords, dtype=np.float32) 87 | return (1 - coords) * sizes / 2 88 | 89 | 90 | def distort_image(image, bbox, k1, k2, modify_bbox=False): 91 | # We want to operate in -1 to 1 space using the padded square of the original image 92 | image, corner = make_square(image) 93 | bbox[:2] += np.array(corner) 94 | bbox[2:] += np.array(corner) 95 | 96 | # Construct grid points 97 | x = np.linspace(1, -1, image.width, dtype=np.float32) 98 | y = np.linspace(1, -1, image.height, dtype=np.float32) 99 | x, y = np.meshgrid(x, y, indexing="xy") 100 | xy_grid = np.stack((x, y), axis=-1) 101 | points = xy_grid.reshape((image.height * image.width, 2)) 102 | new_points = ndc_to_pixel(apply_distortion(points, k1, k2), image.size) 103 | 104 | # Distort image by remapping 105 | map_x = new_points[:, 0].reshape((image.height, image.width)) 106 | map_y = new_points[:, 1].reshape((image.height, image.width)) 107 | distorted = cv2.remap( 108 | np.asarray(image), 109 | map_x, 110 | map_y, 111 | cv2.INTER_LINEAR, 112 | ) 113 | distorted = Image.fromarray(distorted) 114 | 115 | # Find distorted crop bounds - inverse process of above 116 | if modify_bbox: 117 | center = (bbox[:2] + bbox[2:]) / 2 118 | top, bottom = (bbox[0], center[1]), (bbox[2], center[1]) 119 | left, right = (center[0], bbox[1]), (center[0], bbox[3]) 120 | bbox_points = np.array( 121 | [ 122 | pixel_to_ndc(top, image.size), 123 | pixel_to_ndc(left, image.size), 124 | pixel_to_ndc(bottom, image.size), 125 | pixel_to_ndc(right, image.size), 126 | ], 127 | dtype=np.float32, 128 | ) 129 | else: 130 | bbox_points = np.array( 131 | [pixel_to_ndc(bbox[:2], image.size), pixel_to_ndc(bbox[2:], image.size)], 132 | dtype=np.float32, 133 | ) 134 | 135 | # Inverse mapping 136 | distorted_bbox = remove_distortion_iter(bbox_points, k1, k2) 137 | 138 | if modify_bbox: 139 | p = ndc_to_pixel(distorted_bbox, image.size) 140 | distorted_bbox = np.array([p[0][0], p[1][1], p[2][0], p[3][1]]) 141 | else: 142 | distorted_bbox = ndc_to_pixel(distorted_bbox, image.size).reshape(4) 143 | 144 | return distorted, distorted_bbox 145 | -------------------------------------------------------------------------------- /diffusionsfm/utils/distributed.py: 
-------------------------------------------------------------------------------- 1 | import os 2 | import socket 3 | from contextlib import closing 4 | 5 | import torch.distributed as dist 6 | 7 | 8 | def get_open_port(): 9 | with closing(socket.socket(socket.AF_INET, socket.SOCK_STREAM)) as s: 10 | s.bind(("", 0)) 11 | s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) 12 | return s.getsockname()[1] 13 | 14 | 15 | # Distributed process group 16 | def ddp_setup(rank, world_size, port="12345"): 17 | """ 18 | Args: 19 | rank: Unique Identifier 20 | world_size: number of processes 21 | """ 22 | os.environ["MASTER_ADDR"] = "localhost" 23 | print(f"MasterPort: {str(port)}") 24 | os.environ["MASTER_PORT"] = str(port) 25 | 26 | # initialize the process group 27 | dist.init_process_group(backend="nccl", rank=rank, world_size=world_size) 28 | 29 | 30 | def cleanup(): 31 | dist.destroy_process_group() 32 | -------------------------------------------------------------------------------- /diffusionsfm/utils/geometry.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | from pytorch3d.renderer import FoVPerspectiveCameras 4 | from pytorch3d.transforms import quaternion_to_matrix 5 | 6 | 7 | def generate_random_rotations(N=1, device="cpu"): 8 | q = torch.randn(N, 4, device=device) 9 | q = q / q.norm(dim=-1, keepdim=True) 10 | return quaternion_to_matrix(q) 11 | 12 | 13 | def symmetric_orthogonalization(x): 14 | """Maps 9D input vectors onto SO(3) via symmetric orthogonalization. 15 | 16 | x: should have size [batch_size, 9] 17 | 18 | Output has size [batch_size, 3, 3], where each inner 3x3 matrix is in SO(3). 19 | """ 20 | m = x.view(-1, 3, 3) 21 | u, s, v = torch.svd(m) 22 | vt = torch.transpose(v, 1, 2) 23 | det = torch.det(torch.matmul(u, vt)) 24 | det = det.view(-1, 1, 1) 25 | vt = torch.cat((vt[:, :2, :], vt[:, -1:, :] * det), 1) 26 | r = torch.matmul(u, vt) 27 | return r 28 | 29 | 30 | def get_permutations(num_images): 31 | permutations = [] 32 | for i in range(0, num_images): 33 | for j in range(0, num_images): 34 | if i != j: 35 | permutations.append((j, i)) 36 | 37 | return permutations 38 | 39 | 40 | def n_to_np_rotations(num_frames, n_rots): 41 | R_pred_rel = [] 42 | permutations = get_permutations(num_frames) 43 | for i, j in permutations: 44 | R_pred_rel.append(n_rots[i].T @ n_rots[j]) 45 | R_pred_rel = torch.stack(R_pred_rel) 46 | 47 | return R_pred_rel 48 | 49 | 50 | def compute_angular_error_batch(rotation1, rotation2): 51 | R_rel = np.einsum("Bij,Bjk ->Bik", rotation2, rotation1.transpose(0, 2, 1)) 52 | t = (np.trace(R_rel, axis1=1, axis2=2) - 1) / 2 53 | theta = np.arccos(np.clip(t, -1, 1)) 54 | return theta * 180 / np.pi 55 | 56 | 57 | # A should be GT, B should be predicted 58 | def compute_optimal_alignment(A, B): 59 | """ 60 | Compute the optimal scale s, rotation R, and translation t that minimizes: 61 | || A - (s * B @ R + T) || ^ 2 62 | 63 | Reference: Umeyama (TPAMI 91) 64 | 65 | Args: 66 | A (torch.Tensor): (N, 3). 67 | B (torch.Tensor): (N, 3). 68 | 69 | Returns: 70 | s (float): scale. 71 | R (torch.Tensor): rotation matrix (3, 3). 72 | t (torch.Tensor): translation (3,). 
73 | """ 74 | A_bar = A.mean(0) 75 | B_bar = B.mean(0) 76 | # normally with R @ B, this would be A @ B.T 77 | H = (B - B_bar).T @ (A - A_bar) 78 | U, S, Vh = torch.linalg.svd(H, full_matrices=True) 79 | s = torch.linalg.det(U @ Vh) 80 | S_prime = torch.diag(torch.tensor([1, 1, torch.sign(s)], device=A.device)) 81 | variance = torch.sum((B - B_bar) ** 2) 82 | scale = 1 / variance * torch.trace(torch.diag(S) @ S_prime) 83 | R = U @ S_prime @ Vh 84 | t = A_bar - scale * B_bar @ R 85 | 86 | A_hat = scale * B @ R + t 87 | return A_hat, scale, R, t 88 | 89 | 90 | def compute_optimal_translation_alignment(T_A, T_B, R_B): 91 | """ 92 | Assuming right-multiplied rotation matrices. 93 | 94 | E.g., for world2cam R and T, a world coordinate is transformed to camera coordinate 95 | system using X_cam = X_world.T @ R + T = R.T @ X_world + T 96 | 97 | Finds s, t that minimizes || T_A - (s * T_B + R_B.T @ t) ||^2 98 | 99 | Args: 100 | T_A (torch.Tensor): Target translation (N, 3). 101 | T_B (torch.Tensor): Initial translation (N, 3). 102 | R_B (torch.Tensor): Initial rotation (N, 3, 3). 103 | 104 | Returns: 105 | T_A_hat (torch.Tensor): s * T_B + t @ R_B (N, 3). 106 | scale s (torch.Tensor): (1,). 107 | translation t (torch.Tensor): (1, 3). 108 | """ 109 | n = len(T_A) 110 | 111 | T_A = T_A.unsqueeze(2) 112 | T_B = T_B.unsqueeze(2) 113 | 114 | A = torch.sum(T_B * T_A) 115 | B = (T_B.transpose(1, 2) @ R_B.transpose(1, 2)).sum(0) @ (R_B @ T_A).sum(0) / n 116 | C = torch.sum(T_B * T_B) 117 | D = (T_B.transpose(1, 2) @ R_B.transpose(1, 2)).sum(0) 118 | E = (D * D).sum() / n 119 | 120 | s = (A - B.sum()) / (C - E.sum()) 121 | 122 | t = (R_B @ (T_A - s * T_B)).sum(0) / n 123 | 124 | T_A_hat = s * T_B + R_B.transpose(1, 2) @ t 125 | 126 | return T_A_hat.squeeze(2), s, t.transpose(1, 0) 127 | 128 | 129 | def get_error(predict_rotations, R_pred, T_pred, R_gt, T_gt, gt_scene_scale): 130 | if predict_rotations: 131 | cameras_gt = FoVPerspectiveCameras(R=R_gt, T=T_gt) 132 | cc_gt = cameras_gt.get_camera_center() 133 | cameras_pred = FoVPerspectiveCameras(R=R_pred, T=T_pred) 134 | cc_pred = cameras_pred.get_camera_center() 135 | 136 | A_hat, _, _, _ = compute_optimal_alignment(cc_gt, cc_pred) 137 | norm = torch.linalg.norm(cc_gt - A_hat, dim=1) / gt_scene_scale 138 | 139 | norms = np.ndarray.tolist(norm.detach().cpu().numpy()) 140 | return norms, A_hat 141 | else: 142 | T_A_hat, _, _ = compute_optimal_translation_alignment(T_gt, T_pred, R_pred) 143 | norm = torch.linalg.norm(T_gt - T_A_hat, dim=1) / gt_scene_scale 144 | norms = np.ndarray.tolist(norm.detach().cpu().numpy()) 145 | return norms, T_A_hat 146 | -------------------------------------------------------------------------------- /diffusionsfm/utils/normalize.py: -------------------------------------------------------------------------------- 1 | import ipdb # noqa: F401 2 | import torch 3 | from pytorch3d.transforms import Rotate, Translate 4 | 5 | 6 | def intersect_skew_line_groups(p, r, mask=None): 7 | # p, r both of shape (B, N, n_intersected_lines, 3) 8 | # mask of shape (B, N, n_intersected_lines) 9 | p_intersect, r = intersect_skew_lines_high_dim(p, r, mask=mask) 10 | if p_intersect is None: 11 | return None, None, None, None 12 | _, p_line_intersect = point_line_distance( 13 | p, r, p_intersect[..., None, :].expand_as(p) 14 | ) 15 | intersect_dist_squared = ((p_line_intersect - p_intersect[..., None, :]) ** 2).sum( 16 | dim=-1 17 | ) 18 | return p_intersect, p_line_intersect, intersect_dist_squared, r 19 | 20 | 21 | def intersect_skew_lines_high_dim(p, 
r, mask=None): 22 | # Implements https://en.wikipedia.org/wiki/Skew_lines In more than two dimensions 23 | dim = p.shape[-1] 24 | # make sure the heading vectors are l2-normed 25 | if mask is None: 26 | mask = torch.ones_like(p[..., 0]) 27 | r = torch.nn.functional.normalize(r, dim=-1) 28 | 29 | eye = torch.eye(dim, device=p.device, dtype=p.dtype)[None, None] 30 | I_min_cov = (eye - (r[..., None] * r[..., None, :])) * mask[..., None, None] 31 | sum_proj = I_min_cov.matmul(p[..., None]).sum(dim=-3) 32 | 33 | p_intersect = torch.linalg.lstsq(I_min_cov.sum(dim=-3), sum_proj).solution[..., 0] 34 | 35 | if torch.any(torch.isnan(p_intersect)): 36 | print(p_intersect) 37 | return None, None 38 | ipdb.set_trace() 39 | assert False 40 | return p_intersect, r 41 | 42 | 43 | def point_line_distance(p1, r1, p2): 44 | df = p2 - p1 45 | proj_vector = df - ((df * r1).sum(dim=-1, keepdim=True) * r1) 46 | line_pt_nearest = p2 - proj_vector 47 | d = (proj_vector).norm(dim=-1) 48 | return d, line_pt_nearest 49 | 50 | 51 | def compute_optical_axis_intersection(cameras): 52 | centers = cameras.get_camera_center() 53 | principal_points = cameras.principal_point 54 | 55 | one_vec = torch.ones((len(cameras), 1), device=centers.device) 56 | optical_axis = torch.cat((principal_points, one_vec), -1) 57 | 58 | pp = cameras.unproject_points(optical_axis, from_ndc=True, world_coordinates=True) 59 | pp2 = torch.diagonal(pp, dim1=0, dim2=1).T 60 | 61 | directions = pp2 - centers 62 | centers = centers.unsqueeze(0).unsqueeze(0) 63 | directions = directions.unsqueeze(0).unsqueeze(0) 64 | 65 | p_intersect, p_line_intersect, _, r = intersect_skew_line_groups( 66 | p=centers, r=directions, mask=None 67 | ) 68 | 69 | if p_intersect is None: 70 | dist = None 71 | else: 72 | p_intersect = p_intersect.squeeze().unsqueeze(0) 73 | dist = (p_intersect - centers).norm(dim=-1) 74 | 75 | return p_intersect, dist, p_line_intersect, pp2, r 76 | 77 | 78 | def first_camera_transform(cameras, rotation_only=True): 79 | new_cameras = cameras.clone() 80 | new_transform = new_cameras.get_world_to_view_transform() 81 | tR = Rotate(new_cameras.R[0].unsqueeze(0)) 82 | if rotation_only: 83 | t = tR.inverse() 84 | else: 85 | tT = Translate(new_cameras.T[0].unsqueeze(0)) 86 | t = tR.compose(tT).inverse() 87 | 88 | new_transform = t.compose(new_transform) 89 | new_cameras.R = new_transform.get_matrix()[:, :3, :3] 90 | new_cameras.T = new_transform.get_matrix()[:, 3, :3] 91 | 92 | return new_cameras 93 | -------------------------------------------------------------------------------- /diffusionsfm/utils/slurm.py: -------------------------------------------------------------------------------- 1 | import os 2 | import signal 3 | import subprocess 4 | import sys 5 | import time 6 | 7 | 8 | def submitit_job_watcher(jobs, check_period: int = 15): 9 | job_out = {} 10 | 11 | try: 12 | while True: 13 | job_states = [job.state for job in jobs.values()] 14 | state_counts = { 15 | state: len([j for j in job_states if j == state]) 16 | for state in set(job_states) 17 | } 18 | 19 | n_done = sum(job.done() for job in jobs.values()) 20 | 21 | for job_name, job in jobs.items(): 22 | if job_name not in job_out and job.done(): 23 | job_out[job_name] = { 24 | "stderr": job.stderr(), 25 | "stdout": job.stdout(), 26 | } 27 | 28 | exc = job.exception() 29 | if exc is not None: 30 | print(f"{job_name} crashed!!!") 31 | if job_out[job_name]["stderr"] is not None: 32 | print("===== STDERR =====") 33 | print(job_out[job_name]["stderr"]) 34 | else: 35 | print(f"{job_name} 
done!") 36 | 37 | print("Job states:") 38 | for state, count in state_counts.items(): 39 | print(f" {state:15s} {count:6d} ({100.*count/len(jobs):.1f}%)") 40 | 41 | if n_done == len(jobs): 42 | print("All done!") 43 | return 44 | 45 | time.sleep(check_period) 46 | 47 | except KeyboardInterrupt: 48 | for job_name, job in jobs.items(): 49 | if not job.done(): 50 | print(f"Killing {job_name}") 51 | job.cancel(check=False) 52 | 53 | 54 | def get_jid(): 55 | if "SLURM_ARRAY_TASK_ID" in os.environ: 56 | return f"{os.environ['SLURM_ARRAY_JOB_ID']}_{os.environ['SLURM_ARRAY_TASK_ID']}" 57 | return os.environ["SLURM_JOB_ID"] 58 | 59 | 60 | def signal_helper(signum, frame): 61 | print(f"Caught signal {signal.Signals(signum).name} on for the this job") 62 | jid = get_jid() 63 | cmd = ["scontrol", "requeue", jid] 64 | try: 65 | print("calling", cmd) 66 | rtn = subprocess.check_call(cmd) 67 | print("subprocc", rtn) 68 | except: 69 | print("subproc call failed") 70 | return sys.exit(10) 71 | 72 | 73 | def bypass(signum, frame): 74 | print(f"Ignoring signal {signal.Signals(signum).name} on for the this job") 75 | 76 | 77 | def init_slurm_signals(): 78 | signal.signal(signal.SIGCONT, bypass) 79 | signal.signal(signal.SIGCHLD, bypass) 80 | signal.signal(signal.SIGTERM, bypass) 81 | signal.signal(signal.SIGUSR2, signal_helper) 82 | print("SLURM signal installed", flush=True) 83 | 84 | 85 | def init_slurm_signals_if_slurm(): 86 | if "SLURM_JOB_ID" in os.environ: 87 | init_slurm_signals() 88 | -------------------------------------------------------------------------------- /diffusionsfm/utils/visualization.py: -------------------------------------------------------------------------------- 1 | from http.client import MOVED_PERMANENTLY 2 | import io 3 | 4 | import ipdb # noqa: F401 5 | import matplotlib.pyplot as plt 6 | import numpy as np 7 | import trimesh 8 | import torch 9 | import torchvision 10 | from pytorch3d.loss import chamfer_distance 11 | from scipy.spatial.transform import Rotation 12 | 13 | from diffusionsfm.inference.ddim import inference_ddim 14 | from diffusionsfm.utils.rays import ( 15 | Rays, 16 | cameras_to_rays, 17 | rays_to_cameras, 18 | rays_to_cameras_homography, 19 | ) 20 | from diffusionsfm.utils.geometry import ( 21 | compute_optimal_alignment, 22 | ) 23 | 24 | cmap = plt.get_cmap("hsv") 25 | 26 | 27 | def create_training_visualizations( 28 | model, 29 | images, 30 | device, 31 | cameras_gt, 32 | num_images, 33 | crop_parameters, 34 | pred_x0=False, 35 | no_crop_param_device="cpu", 36 | visualize_pred=False, 37 | return_first=False, 38 | calculate_intrinsics=False, 39 | mode=None, 40 | depths=None, 41 | scale_min=-1, 42 | scale_max=1, 43 | diffuse_depths=False, 44 | vis_mode=None, 45 | average_centers=True, 46 | full_num_patches_x=16, 47 | full_num_patches_y=16, 48 | use_homogeneous=False, 49 | distortion_coefficients=None, 50 | ): 51 | 52 | if model.depth_resolution == 1: 53 | W_in = W_out = full_num_patches_x 54 | H_in = H_out = full_num_patches_y 55 | else: 56 | W_in = H_in = model.width 57 | W_out = model.width * model.depth_resolution 58 | H_out = model.width * model.depth_resolution 59 | 60 | rays_final, rays_intermediate, pred_intermediate, _ = inference_ddim( 61 | model, 62 | images, 63 | device, 64 | crop_parameters=crop_parameters, 65 | eta=[1, 0], 66 | num_patches_x=W_in, 67 | num_patches_y=H_in, 68 | visualize=True, 69 | ) 70 | 71 | if vis_mode is None: 72 | vis_mode = mode 73 | 74 | T = model.noise_scheduler.max_timesteps 75 | if T == 1000: 76 | ts = [0, 100, 200, 300, 
400, 500, 600, 700, 800, 900, 999] 77 | else: 78 | ts = [0, 10, 20, 30, 40, 50, 60, 70, 80, 90, 99] 79 | 80 | # Get predicted cameras from rays 81 | pred_cameras_batched = [] 82 | vis_images = [] 83 | pred_rays = [] 84 | for index in range(len(images)): 85 | pred_cameras = [] 86 | per_sample_images = [] 87 | for ii in range(num_images): 88 | rays_gt = cameras_to_rays( 89 | cameras_gt[index], 90 | crop_parameters[index], 91 | no_crop_param_device=no_crop_param_device, 92 | num_patches_x=W_in, 93 | num_patches_y=H_in, 94 | depths=None if depths is None else depths[index], 95 | mode=mode, 96 | depth_resolution=model.depth_resolution, 97 | distortion_coefficients=( 98 | None 99 | if distortion_coefficients is None 100 | else distortion_coefficients[index] 101 | ), 102 | ) 103 | image_vis = (images[index, ii].cpu().permute(1, 2, 0).numpy() + 1) / 2 104 | 105 | if diffuse_depths: 106 | fig, axs = plt.subplots(3, 13, figsize=(15, 4.5), dpi=100) 107 | else: 108 | fig, axs = plt.subplots(3, 9, figsize=(12, 4.5), dpi=100) 109 | 110 | for i, t in enumerate(ts): 111 | r, c = i // 4, i % 4 112 | if visualize_pred: 113 | curr = pred_intermediate[t][index] 114 | else: 115 | curr = rays_intermediate[t][index] 116 | rays = Rays.from_spatial( 117 | curr, 118 | mode=mode, 119 | num_patches_x=H_in, 120 | num_patches_y=W_in, 121 | use_homogeneous=use_homogeneous, 122 | ) 123 | 124 | if vis_mode == "segment": 125 | vis = ( 126 | torch.clip( 127 | rays.get_segments()[ii], min=scale_min, max=scale_max 128 | ) 129 | - scale_min 130 | ) / (scale_max - scale_min) 131 | 132 | else: 133 | vis = ( 134 | torch.nn.functional.normalize(rays.get_moments()[ii], dim=-1) 135 | + 1 136 | ) / 2 137 | 138 | axs[r, c].imshow(vis.reshape(W_out, H_out, 3).cpu()) 139 | axs[r, c].set_title(f"T={T - t}") 140 | 141 | i += 1 142 | r, c = i // 4, i % 4 143 | 144 | if vis_mode == "segment": 145 | vis = ( 146 | torch.clip(rays_gt.get_segments()[ii], min=scale_min, max=scale_max) 147 | - scale_min 148 | ) / (scale_max - scale_min) 149 | else: 150 | vis = ( 151 | torch.nn.functional.normalize(rays_gt.get_moments()[ii], dim=-1) + 1 152 | ) / 2 153 | 154 | axs[r, c].imshow(vis.reshape(W_out, H_out, 3).cpu()) 155 | 156 | type_str = "Endpoints" if vis_mode == "segment" else "Moments" 157 | axs[r, c].set_title(f"GT {type_str}") 158 | 159 | for i, t in enumerate(ts): 160 | r, c = i // 4, i % 4 + 4 161 | if visualize_pred: 162 | curr = pred_intermediate[t][index] 163 | else: 164 | curr = rays_intermediate[t][index] 165 | rays = Rays.from_spatial( 166 | curr, 167 | mode, 168 | num_patches_x=H_in, 169 | num_patches_y=W_in, 170 | use_homogeneous=use_homogeneous, 171 | ) 172 | 173 | if vis_mode == "segment": 174 | vis = ( 175 | torch.clip( 176 | rays.get_origins(high_res=True)[ii], 177 | min=scale_min, 178 | max=scale_max, 179 | ) 180 | - scale_min 181 | ) / (scale_max - scale_min) 182 | else: 183 | vis = ( 184 | torch.nn.functional.normalize(rays.get_directions()[ii], dim=-1) 185 | + 1 186 | ) / 2 187 | 188 | axs[r, c].imshow(vis.reshape(W_out, H_out, 3).cpu()) 189 | axs[r, c].set_title(f"T={T - t}") 190 | 191 | i += 1 192 | r, c = i // 4, i % 4 + 4 193 | 194 | if vis_mode == "segment": 195 | vis = ( 196 | torch.clip( 197 | rays_gt.get_origins(high_res=True)[ii], 198 | min=scale_min, 199 | max=scale_max, 200 | ) 201 | - scale_min 202 | ) / (scale_max - scale_min) 203 | else: 204 | vis = ( 205 | torch.nn.functional.normalize(rays_gt.get_directions()[ii], dim=-1) 206 | + 1 207 | ) / 2 208 | axs[r, c].imshow(vis.reshape(W_out, H_out, 3).cpu()) 209 | 
type_str = "Origins" if vis_mode == "segment" else "Directions" 210 | axs[r, c].set_title(f"GT {type_str}") 211 | 212 | if diffuse_depths: 213 | for i, t in enumerate(ts): 214 | r, c = i // 4, i % 4 + 8 215 | if visualize_pred: 216 | curr = pred_intermediate[t][index] 217 | else: 218 | curr = rays_intermediate[t][index] 219 | rays = Rays.from_spatial( 220 | curr, 221 | mode, 222 | num_patches_x=H_in, 223 | num_patches_y=W_in, 224 | use_homogeneous=use_homogeneous, 225 | ) 226 | 227 | vis = rays.depths[ii] 228 | if len(rays.depths[ii].shape) < 2: 229 | vis = rays.depths[ii].reshape(H_out, W_out) 230 | 231 | axs[r, c].imshow(vis.cpu()) 232 | axs[r, c].set_title(f"T={T - t}") 233 | 234 | i += 1 235 | r, c = i // 4, i % 4 + 8 236 | 237 | vis = depths[index][ii] 238 | if len(rays.depths[ii].shape) < 2: 239 | vis = depths[index][ii].reshape(256, 256) 240 | 241 | axs[r, c].imshow(vis.cpu()) 242 | axs[r, c].set_title(f"GT Depths") 243 | 244 | axs[2, -1].imshow(image_vis) 245 | axs[2, -1].set_title("Input Image") 246 | for s in ["bottom", "top", "left", "right"]: 247 | axs[2, -1].spines[s].set_color(cmap(ii / (num_images))) 248 | axs[2, -1].spines[s].set_linewidth(5) 249 | 250 | for ax in axs.flatten(): 251 | ax.set_xticks([]) 252 | ax.set_yticks([]) 253 | plt.tight_layout() 254 | img = plot_to_image(fig) 255 | plt.close() 256 | per_sample_images.append(img) 257 | 258 | if return_first: 259 | rays_camera = pred_intermediate[0][index] 260 | elif pred_x0: 261 | rays_camera = pred_intermediate[-1][index] 262 | else: 263 | rays_camera = rays_final[index] 264 | rays = Rays.from_spatial( 265 | rays_camera, 266 | mode=mode, 267 | num_patches_x=H_in, 268 | num_patches_y=W_in, 269 | use_homogeneous=use_homogeneous, 270 | ) 271 | if calculate_intrinsics: 272 | pred_camera = rays_to_cameras_homography( 273 | rays=rays[ii, None], 274 | crop_parameters=crop_parameters[index], 275 | num_patches_x=W_in, 276 | num_patches_y=H_in, 277 | average_centers=average_centers, 278 | depth_resolution=model.depth_resolution, 279 | ) 280 | else: 281 | pred_camera = rays_to_cameras( 282 | rays=rays[ii, None], 283 | crop_parameters=crop_parameters[index], 284 | no_crop_param_device=no_crop_param_device, 285 | num_patches_x=W_in, 286 | num_patches_y=H_in, 287 | depth_resolution=model.depth_resolution, 288 | average_centers=average_centers, 289 | ) 290 | pred_cameras.append(pred_camera[0]) 291 | pred_rays.append(rays) 292 | 293 | pred_cameras_batched.append(pred_cameras) 294 | vis_images.append(np.vstack(per_sample_images)) 295 | 296 | return vis_images, pred_cameras_batched, pred_rays 297 | 298 | 299 | def plot_to_image(figure, dpi=100): 300 | """Converts matplotlib fig to a png for logging with tf.summary.image.""" 301 | buffer = io.BytesIO() 302 | figure.savefig(buffer, format="raw", dpi=dpi) 303 | plt.close(figure) 304 | buffer.seek(0) 305 | image = np.reshape( 306 | np.frombuffer(buffer.getvalue(), dtype=np.uint8), 307 | newshape=(int(figure.bbox.bounds[3]), int(figure.bbox.bounds[2]), -1), 308 | ) 309 | return image[..., :3] 310 | 311 | 312 | def view_color_coded_images_from_tensor(images, depth=False): 313 | num_frames = images.shape[0] 314 | cmap = plt.get_cmap("hsv") 315 | num_rows = 3 316 | num_cols = 3 317 | figsize = (num_cols * 2, num_rows * 2) 318 | fig, axs = plt.subplots(num_rows, num_cols, figsize=figsize) 319 | axs = axs.flatten() 320 | for i in range(num_rows * num_cols): 321 | if i < num_frames: 322 | if images[i].shape[0] == 3: 323 | image = images[i].permute(1, 2, 0) 324 | else: 325 | image = 
images[i].unsqueeze(-1) 326 | 327 | if not depth: 328 | image = image * 0.5 + 0.5 329 | else: 330 | image = image.repeat((1, 1, 3)) / torch.max(image) 331 | 332 | axs[i].imshow(image) 333 | for s in ["bottom", "top", "left", "right"]: 334 | axs[i].spines[s].set_color(cmap(i / (num_frames))) 335 | axs[i].spines[s].set_linewidth(5) 336 | axs[i].set_xticks([]) 337 | axs[i].set_yticks([]) 338 | else: 339 | axs[i].axis("off") 340 | plt.tight_layout() 341 | return fig 342 | 343 | 344 | def color_and_filter_points(points, images, mask, num_show, resolution): 345 | # Resize images 346 | resize = torchvision.transforms.Resize(resolution) 347 | images = resize(images) * 0.5 + 0.5 348 | 349 | # Reshape points and calculate mask 350 | points = points.reshape(num_show * resolution * resolution, 3) 351 | mask = mask.reshape(num_show * resolution * resolution) 352 | depth_mask = torch.argwhere(mask > 0.5)[:, 0] 353 | points = points[depth_mask] 354 | 355 | # Mask and reshape colors 356 | colors = images.permute(0, 2, 3, 1).reshape(num_show * resolution * resolution, 3) 357 | colors = colors[depth_mask] 358 | 359 | return points, colors 360 | 361 | 362 | def filter_and_align_point_clouds( 363 | num_frames, 364 | gt_points, 365 | pred_points, 366 | gt_masks, 367 | pred_masks, 368 | images, 369 | metrics=False, 370 | num_patches_x=16, 371 | ): 372 | 373 | # Filter and color points 374 | gt_points, gt_colors = color_and_filter_points( 375 | gt_points, images, gt_masks, num_show=num_frames, resolution=num_patches_x 376 | ) 377 | pred_points, pred_colors = color_and_filter_points( 378 | pred_points, images, pred_masks, num_show=num_frames, resolution=num_patches_x 379 | ) 380 | 381 | pred_points, _, _, _ = compute_optimal_alignment( 382 | gt_points.float(), pred_points.float() 383 | ) 384 | 385 | # Scale PCL so that furthest point from centroid is distance 1 386 | centroid = torch.mean(gt_points, dim=0) 387 | dists = torch.norm(gt_points - centroid.unsqueeze(0), dim=-1) 388 | scale = torch.mean(dists) 389 | gt_points_scaled = (gt_points - centroid) / scale 390 | pred_points_scaled = (pred_points - centroid) / scale 391 | 392 | if metrics: 393 | 394 | cd, _ = chamfer_distance( 395 | pred_points_scaled.unsqueeze(0), gt_points_scaled.unsqueeze(0) 396 | ) 397 | cd = cd.item() 398 | mse = torch.mean( 399 | torch.norm(pred_points_scaled - gt_points_scaled, dim=-1), dim=-1 400 | ).item() 401 | else: 402 | mse, cd = None, None 403 | 404 | return ( 405 | gt_points, 406 | pred_points, 407 | gt_colors, 408 | pred_colors, 409 | [mse, cd, None], 410 | ) 411 | 412 | 413 | def add_scene_cam(scene, c2w, edge_color, image=None, focal=None, imsize=None, screen_width=0.03): 414 | OPENGL = np.array([ 415 | [1, 0, 0, 0], 416 | [0, -1, 0, 0], 417 | [0, 0, -1, 0], 418 | [0, 0, 0, 1] 419 | ]) 420 | 421 | if image is not None: 422 | H, W, THREE = image.shape 423 | assert THREE == 3 424 | if image.dtype != np.uint8: 425 | image = np.uint8(255*image) 426 | elif imsize is not None: 427 | W, H = imsize 428 | elif focal is not None: 429 | H = W = focal / 1.1 430 | else: 431 | H = W = 1 432 | 433 | if focal is None: 434 | focal = min(H, W) * 1.1 # default value 435 | elif isinstance(focal, np.ndarray): 436 | focal = focal[0] 437 | 438 | # create fake camera 439 | height = focal * screen_width / H 440 | width = screen_width * 0.5**0.5 441 | rot45 = np.eye(4) 442 | rot45[:3, :3] = Rotation.from_euler('z', np.deg2rad(45)).as_matrix() 443 | rot45[2, 3] = -height # set the tip of the cone = optical center 444 | aspect_ratio = np.eye(4) 445 | 
aspect_ratio[0, 0] = W/H 446 | transform = c2w @ OPENGL @ aspect_ratio @ rot45 447 | cam = trimesh.creation.cone(width, height, sections=4) 448 | 449 | # this is the camera mesh 450 | rot2 = np.eye(4) 451 | rot2[:3, :3] = Rotation.from_euler('z', np.deg2rad(4)).as_matrix() 452 | vertices = cam.vertices 453 | vertices_offset = 0.9 * cam.vertices 454 | vertices = np.r_[vertices, vertices_offset, geotrf(rot2, cam.vertices)] 455 | vertices = geotrf(transform, vertices) 456 | faces = [] 457 | for face in cam.faces: 458 | if 0 in face: 459 | continue 460 | a, b, c = face 461 | a2, b2, c2 = face + len(cam.vertices) 462 | 463 | # add 3 pseudo-edges 464 | faces.append((a, b, b2)) 465 | faces.append((a, a2, c)) 466 | faces.append((c2, b, c)) 467 | 468 | faces.append((a, b2, a2)) 469 | faces.append((a2, c, c2)) 470 | faces.append((c2, b2, b)) 471 | 472 | # no culling 473 | faces += [(c, b, a) for a, b, c in faces] 474 | 475 | for i,face in enumerate(cam.faces): 476 | if 0 in face: 477 | continue 478 | 479 | if i == 1 or i == 5: 480 | a, b, c = face 481 | faces.append((a, b, c)) 482 | 483 | cam = trimesh.Trimesh(vertices=vertices, faces=faces) 484 | cam.visual.face_colors[:, :3] = edge_color 485 | 486 | scene.add_geometry(cam) 487 | 488 | 489 | def geotrf(Trf, pts, ncol=None, norm=False): 490 | """ Apply a geometric transformation to a list of 3-D points. 491 | 492 | H: 3x3 or 4x4 projection matrix (typically a Homography) 493 | p: numpy/torch/tuple of coordinates. Shape must be (...,2) or (...,3) 494 | 495 | ncol: int. number of columns of the result (2 or 3) 496 | norm: float. if != 0, the resut is projected on the z=norm plane. 497 | 498 | Returns an array of projected 2d points. 499 | """ 500 | assert Trf.ndim >= 2 501 | if isinstance(Trf, np.ndarray): 502 | pts = np.asarray(pts) 503 | elif isinstance(Trf, torch.Tensor): 504 | pts = torch.as_tensor(pts, dtype=Trf.dtype) 505 | 506 | # adapt shape if necessary 507 | output_reshape = pts.shape[:-1] 508 | ncol = ncol or pts.shape[-1] 509 | 510 | # optimized code 511 | if (isinstance(Trf, torch.Tensor) and isinstance(pts, torch.Tensor) and 512 | Trf.ndim == 3 and pts.ndim == 4): 513 | d = pts.shape[3] 514 | if Trf.shape[-1] == d: 515 | pts = torch.einsum("bij, bhwj -> bhwi", Trf, pts) 516 | elif Trf.shape[-1] == d+1: 517 | pts = torch.einsum("bij, bhwj -> bhwi", Trf[:, :d, :d], pts) + Trf[:, None, None, :d, d] 518 | else: 519 | raise ValueError(f'bad shape, not ending with 3 or 4, for {pts.shape=}') 520 | else: 521 | if Trf.ndim >= 3: 522 | n = Trf.ndim-2 523 | assert Trf.shape[:n] == pts.shape[:n], 'batch size does not match' 524 | Trf = Trf.reshape(-1, Trf.shape[-2], Trf.shape[-1]) 525 | 526 | if pts.ndim > Trf.ndim: 527 | # Trf == (B,d,d) & pts == (B,H,W,d) --> (B, H*W, d) 528 | pts = pts.reshape(Trf.shape[0], -1, pts.shape[-1]) 529 | elif pts.ndim == 2: 530 | # Trf == (B,d,d) & pts == (B,d) --> (B, 1, d) 531 | pts = pts[:, None, :] 532 | 533 | if pts.shape[-1]+1 == Trf.shape[-1]: 534 | Trf = Trf.swapaxes(-1, -2) # transpose Trf 535 | pts = pts @ Trf[..., :-1, :] + Trf[..., -1:, :] 536 | elif pts.shape[-1] == Trf.shape[-1]: 537 | Trf = Trf.swapaxes(-1, -2) # transpose Trf 538 | pts = pts @ Trf 539 | else: 540 | pts = Trf @ pts.T 541 | if pts.ndim >= 2: 542 | pts = pts.swapaxes(-1, -2) 543 | 544 | if norm: 545 | pts = pts / pts[..., -1:] # DONT DO /= BECAUSE OF WEIRD PYTORCH BUG 546 | if norm != 1: 547 | pts *= norm 548 | 549 | res = pts[..., :ncol].reshape(*output_reshape, ncol) 550 | return res 
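
The `geotrf` helper above dispatches on the shapes and types of `Trf` and `pts`; for the common case of a single 4×4 rigid transform applied to an `(N, 3)` point array it pads the points to homogeneous coordinates internally. Below is a minimal usage sketch under stated assumptions: the rotation/translation values are invented purely for illustration, the points are random, and `geotrf` is assumed importable from `diffusionsfm.utils.visualization`.

```python
import numpy as np
from scipy.spatial.transform import Rotation

from diffusionsfm.utils.visualization import geotrf

# Hypothetical 4x4 camera-to-world transform (values chosen only for illustration).
c2w = np.eye(4)
c2w[:3, :3] = Rotation.from_euler("z", np.deg2rad(30)).as_matrix()
c2w[:3, 3] = [0.1, -0.2, 1.5]

# Random (N, 3) points standing in for camera-frame geometry.
pts_cam = np.random.rand(100, 3).astype(np.float32)

# Equivalent to rotating and translating each point: R @ p + t.
pts_world = geotrf(c2w, pts_cam)

assert pts_world.shape == (100, 3)
assert np.allclose(pts_world, pts_cam @ c2w[:3, :3].T + c2w[:3, 3], atol=1e-5)
```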
-------------------------------------------------------------------------------- /docs/eval.md: -------------------------------------------------------------------------------- 1 | ## Evaluation Directions 2 | 3 | Use the scripts from `diffusionsfm/eval` to evaluate the performance of the dense model on the CO3D dataset: 4 | 5 | ``` 6 | python -m diffusionsfm.eval.eval_jobs --eval_path output/multi_diffusionsfm_dense --use_submitit 7 | ``` 8 | 9 | **Note:** The `use_submitit` flag is optional. If you have a SLURM system available, enabling it will dispatch jobs in parallel across available GPUs, significantly accelerating the evaluation process. 10 | 11 | The expected output at the end of evaluating the dense model is: 12 | 13 | ``` 14 | N= 2 3 4 5 6 7 8 15 | Seen R 0.926 0.941 0.946 0.950 0.953 0.955 0.955 16 | Seen CC 1.000 0.956 0.934 0.924 0.917 0.911 0.907 17 | Seen CD 0.023 0.023 0.026 0.026 0.028 0.031 0.030 18 | Seen CD_Obj 0.040 0.037 0.033 0.032 0.032 0.032 0.033 19 | Unseen R 0.913 0.928 0.938 0.945 0.950 0.951 0.953 20 | Unseen CC 1.000 0.926 0.884 0.870 0.864 0.851 0.847 21 | Unseen CD 0.024 0.024 0.025 0.024 0.025 0.026 0.027 22 | Unseen CD_Obj 0.028 0.023 0.022 0.022 0.023 0.021 0.020 23 | ``` 24 | 25 | This reports rotation and camera center accuracy, as well as Chamfer Distance on both all points (CD) and foreground points (CD_Obj), evaluated on held-out sequences from both seen and unseen object categories using varying numbers of input images. Performance is averaged over five runs to reduce variance. 26 | 27 | Note that minor variations in the reported numbers may occur due to randomness in the evaluation and inference processes. -------------------------------------------------------------------------------- /docs/train.md: -------------------------------------------------------------------------------- 1 | ## Training Directions 2 | 3 | ### Prepare CO3D Dataset 4 | 5 | Please refer to the instructions from [RayDiffusion](https://github.com/jasonyzhang/RayDiffusion/blob/main/docs/train.md#training-directions) to set up the CO3D dataset. 6 | 7 | ### Setting up `accelerate` 8 | 9 | Use `accelerate config` to set up `accelerate`. We recommend using multiple GPUs without any mixed precision (we handle AMP ourselves). 10 | 11 | ### Training models 12 | 13 | Our model is trained in two stages. In the first stage, we train a *sparse model* that predicts ray origins and endpoints at a low resolution (16×16). In the second stage, we initialize the dense model using the DiT weights from the sparse model and append a DPT decoder to produce high-resolution outputs (256×256 ray origins and endpoints). 
14 | 15 | To train the sparse model, run: 16 | 17 | ``` 18 | accelerate launch --multi_gpu --gpu_ids 0,1,2,3,4,5,6,7 --num_processes 8 train.py \ 19 | training.batch_size=8 \ 20 | training.max_iterations=400000 \ 21 | model.num_images=8 \ 22 | dataset.name=co3d \ 23 | debug.project_name=diffusionsfm_co3d \ 24 | debug.run_name=co3d_diffusionsfm_sparse 25 | ``` 26 | 27 | To train the dense model (initialized from the sparse model weights), run: 28 | 29 | ``` 30 | accelerate launch --multi_gpu --gpu_ids 0,1,2,3,4,5,6,7 --num_processes 8 train.py \ 31 | training.batch_size=4 \ 32 | training.max_iterations=800000 \ 33 | model.num_images=8 \ 34 | dataset.name=co3d \ 35 | debug.project_name=diffusionsfm_co3d \ 36 | debug.run_name=co3d_diffusionsfm_dense \ 37 | training.dpt_head=True \ 38 | training.full_num_patches_x=256 \ 39 | training.full_num_patches_y=256 \ 40 | training.gradient_clipping=True \ 41 | training.reinit=True \ 42 | training.freeze_encoder=True \ 43 | model.freeze_transformer=True \ 44 | training.pretrain_path=.pth 45 | ``` 46 | 47 | Some notes: 48 | 49 | - `batch_size` refers to the batch size per GPU. The total batch size will be `batch_size * num_gpu`. 50 | - Depending on your setup, you can adjust the number of GPUs and batch size. You may also need to adjust the number of training iterations accordingly. 51 | - You can resume training from a checkpoint by specifying `train.resume=True hydra.run.dir=/path/to/your/output_dir` 52 | - If you are getting NaNs, try turning off mixed precision. This will increase the amount of memory used. 53 | 54 | For debugging, we recommend using a single-GPU job with a single category: 55 | 56 | ``` 57 | accelerate launch train.py training.batch_size=4 dataset.category=apple debug.wandb=False hydra.run.dir=output_debug 58 | ``` -------------------------------------------------------------------------------- /gradio_app.py: -------------------------------------------------------------------------------- 1 | import os 2 | import time 3 | import shutil 4 | import argparse 5 | import functools 6 | import torch 7 | import torchvision 8 | from PIL import Image 9 | import gradio as gr 10 | import numpy as np 11 | import matplotlib.pyplot as plt 12 | import trimesh 13 | 14 | from diffusionsfm.dataset.custom import CustomDataset 15 | from diffusionsfm.dataset.co3d_v2 import unnormalize_image 16 | from diffusionsfm.inference.load_model import load_model 17 | from diffusionsfm.inference.predict import predict_cameras 18 | from diffusionsfm.utils.visualization import add_scene_cam 19 | 20 | 21 | def info_fn(): 22 | gr.Info("Data preprocessing completed!") 23 | 24 | 25 | def get_select_index(evt: gr.SelectData): 26 | selected = evt.index 27 | return examples_full[selected][0], selected 28 | 29 | 30 | def check_img_input(control_image): 31 | if control_image is None: 32 | raise gr.Error("Please select or upload an input image.") 33 | 34 | 35 | def preprocess(args, image_block, selected): 36 | cate_name = time.strftime("%m%d_%H%M%S") if selected is None else examples_list[selected] 37 | 38 | demo_dir = os.path.join(args.output_dir, f'demo/{cate_name}') 39 | shutil.rmtree(demo_dir, ignore_errors=True) 40 | 41 | os.makedirs(os.path.join(demo_dir, 'source'), exist_ok=True) 42 | os.makedirs(os.path.join(demo_dir, 'processed'), exist_ok=True) 43 | 44 | dataset = CustomDataset(image_block) 45 | batch = dataset.get_data() 46 | batch['cate_name'] = cate_name 47 | 48 | processed_image_block = [] 49 | for i, file_path in enumerate(image_block): 50 | file_name = 
os.path.basename(file_path) 51 | raw_img = Image.open(file_path) 52 | try: 53 | raw_img.save(os.path.join(demo_dir, 'source', file_name)) 54 | except OSError: 55 | raw_img.convert('RGB').save(os.path.join(demo_dir, 'source', file_name)) 56 | 57 | batch['image_for_vis'][i].save(os.path.join(demo_dir, 'processed', file_name)) 58 | processed_image_block.append(os.path.join(demo_dir, 'processed', file_name)) 59 | 60 | return processed_image_block, batch 61 | 62 | 63 | def transform_cameras(pred_cameras): 64 | num_cameras = pred_cameras.R.shape[0] 65 | Rs = pred_cameras.R.transpose(1, 2).detach() 66 | ts = pred_cameras.T.unsqueeze(-1).detach() 67 | c2ws = torch.zeros(num_cameras, 4, 4) 68 | c2ws[:, :3, :3] = Rs 69 | c2ws[:, :3, -1:] = ts 70 | c2ws[:, 3, 3] = 1 71 | c2ws[:, :2] *= -1 # PyTorch3D to OpenCV 72 | c2ws = torch.linalg.inv(c2ws).numpy() 73 | 74 | return c2ws 75 | 76 | 77 | def run_inference(args, cfg, model, batch): 78 | device = args.device 79 | images = batch["image"].to(device) 80 | crop_parameters = batch["crop_parameters"].to(device) 81 | 82 | (pred_cameras, pred_rays), _ = predict_cameras( 83 | model=model, 84 | images=images, 85 | device=device, 86 | crop_parameters=crop_parameters, 87 | stop_iteration=90, 88 | num_patches_x=cfg.training.full_num_patches_x, 89 | num_patches_y=cfg.training.full_num_patches_y, 90 | calculate_intrinsics=True, 91 | max_num_images=8, 92 | mode="segment", 93 | return_rays=True, 94 | use_homogeneous=True, 95 | seed=0, 96 | ) 97 | 98 | # Unnormalize and resize input images 99 | images = unnormalize_image(images, return_numpy=False, return_int=False) 100 | images = torchvision.transforms.Resize(256)(images) 101 | rgbs = images.permute(0, 2, 3, 1).contiguous().view(-1, 3) 102 | xyzs = pred_rays.get_segments().view(-1, 3).cpu() 103 | 104 | # Create point cloud and scene 105 | scene = trimesh.Scene() 106 | point_cloud = trimesh.points.PointCloud(xyzs, colors=rgbs) 107 | scene.add_geometry(point_cloud) 108 | 109 | # Add predicted cameras to the scene 110 | num_images = images.shape[0] 111 | c2ws = transform_cameras(pred_cameras) 112 | cmap = plt.get_cmap("hsv") 113 | 114 | for i, c2w in enumerate(c2ws): 115 | color_rgb = (np.array(cmap(i / num_images))[:3] * 255).astype(int) 116 | add_scene_cam( 117 | scene=scene, 118 | c2w=c2w, 119 | edge_color=color_rgb, 120 | image=None, 121 | focal=None, 122 | imsize=(256, 256), 123 | screen_width=0.1 124 | ) 125 | 126 | # Export GLB 127 | cate_name = batch['cate_name'] 128 | output_path = os.path.join(args.output_dir, f'demo/{cate_name}/{cate_name}.glb') 129 | scene.export(output_path) 130 | 131 | return output_path 132 | 133 | 134 | if __name__ == "__main__": 135 | parser = argparse.ArgumentParser() 136 | parser.add_argument('--output_dir', default='output/multi_diffusionsfm_dense', type=str, help='Output directory') 137 | parser.add_argument('--device', default='cuda', type=str, help='Device to run inference on') 138 | args = parser.parse_args() 139 | 140 | _TITLE = "DiffusionSfM: Predicting Structure and Motion via Ray Origin and Endpoint Diffusion" 141 | _DESCRIPTION = """ 142 |
143 | 144 | 145 |
146 | DiffusionSfM learns to predict scene geometry and camera poses as pixel-wise ray origins and endpoints using a denoising diffusion model. 147 | """ 148 | 149 | # Load demo examples 150 | examples_list = ["kew_gardens_ruined_arch", "jellycat", "kotor_cathedral", "jordan"] 151 | examples_full = [] 152 | for example in examples_list: 153 | folder = os.path.join(os.path.dirname(__file__), "data/demo", example) 154 | examples = sorted(os.path.join(folder, x) for x in os.listdir(folder)) 155 | examples_full.append([examples]) 156 | 157 | model, cfg = load_model(args.output_dir, device=args.device) 158 | print("Loaded DiffusionSfM model!") 159 | 160 | preprocess = functools.partial(preprocess, args) 161 | run_inference = functools.partial(run_inference, args, cfg, model) 162 | 163 | with gr.Blocks(title=_TITLE, theme=gr.themes.Soft()) as demo: 164 | gr.Markdown(f"# {_TITLE}") 165 | gr.Markdown(_DESCRIPTION) 166 | 167 | with gr.Row(variant='panel'): 168 | with gr.Column(scale=2): 169 | image_block = gr.Files(file_count="multiple", label="Upload Images") 170 | 171 | gr.Markdown( 172 | "You can run our model by either: (1) **Uploading images** above " 173 | "or (2) selecting a **pre-collected example** below." 174 | ) 175 | 176 | gallery = gr.Gallery( 177 | value=[example[0][0] for example in examples_full], 178 | label="Examples", 179 | show_label=True, 180 | columns=[4], 181 | rows=[1], 182 | object_fit="contain", 183 | height="256", 184 | ) 185 | 186 | selected = gr.State() 187 | batch = gr.State() 188 | 189 | preprocessed_data = gr.Gallery( 190 | label="Preprocessed Images", 191 | show_label=True, 192 | columns=[4], 193 | rows=[1], 194 | object_fit="contain", 195 | height="256", 196 | ) 197 | 198 | with gr.Row(variant='panel'): 199 | run_inference_btn = gr.Button("Run Inference") 200 | 201 | with gr.Column(scale=4): 202 | output_3D = gr.Model3D( 203 | clear_color=[0.0, 0.0, 0.0, 0.0], 204 | height=520, 205 | zoom_speed=0.5, 206 | pan_speed=0.5, 207 | label="3D Point Clouds and Recovered Cameras" 208 | ) 209 | 210 | # Link image gallery selection 211 | gallery.select( 212 | fn=get_select_index, 213 | inputs=None, 214 | outputs=[image_block, selected] 215 | ).success( 216 | fn=preprocess, 217 | inputs=[image_block, selected], 218 | outputs=[preprocessed_data, batch], 219 | queue=False, 220 | show_progress="full" 221 | ) 222 | 223 | # Handle user uploads 224 | image_block.upload( 225 | preprocess, 226 | inputs=[image_block], 227 | outputs=[preprocessed_data, batch], 228 | queue=False, 229 | show_progress="full" 230 | ).success(info_fn, None, None) 231 | 232 | # Run 3D reconstruction 233 | run_inference_btn.click( 234 | check_img_input, 235 | inputs=[image_block], 236 | queue=False 237 | ).success( 238 | run_inference, 239 | inputs=[batch], 240 | outputs=[output_3D] 241 | ) 242 | 243 | demo.queue().launch(share=True) -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | black 2 | einops 3 | flake8 4 | hydra-core 5 | imageio 6 | imageio-ffmpeg 7 | ipdb 8 | ipython 9 | isort 10 | jupyter 11 | jupyter_black 12 | jupyter_nbextensions_configurator 13 | jupyterlab 14 | kaleido 15 | matplotlib 16 | numpy==1.26.4 17 | omegaconf 18 | opencv-python 19 | plotly 20 | scikit-learn 21 | scipy 22 | timm 23 | trimesh 24 | tqdm 25 | iopath 26 | pandas 27 | wandb 28 | submitit 29 | accelerate==0.32.1 30 | gradio==4.44.1 31 | pydantic==2.10.6 32 | gdown 
--------------------------------------------------------------------------------
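
The inference entry points exercised by `gradio_app.py` can also be driven from a plain script. The snippet below is a minimal sketch distilled from `run_inference` above: the checkpoint directory and image paths are assumptions (any directory holding the downloaded dense-model weights and any set of multi-view images will do), and the keyword arguments simply mirror those used by the demo.

```python
import torch

from diffusionsfm.dataset.custom import CustomDataset
from diffusionsfm.inference.load_model import load_model
from diffusionsfm.inference.predict import predict_cameras

device = "cuda"

# Assumed checkpoint location (see the README for download instructions).
model, cfg = load_model("output/multi_diffusionsfm_dense", device=device)

# Placeholder image paths; replace with your own multi-view captures.
my_images = ["data/demo/jellycat/001.jpg", "data/demo/jellycat/002.jpg"]
batch = CustomDataset(my_images).get_data()

(pred_cameras, pred_rays), _ = predict_cameras(
    model=model,
    images=batch["image"].to(device),
    device=device,
    crop_parameters=batch["crop_parameters"].to(device),
    stop_iteration=90,
    num_patches_x=cfg.training.full_num_patches_x,
    num_patches_y=cfg.training.full_num_patches_y,
    calculate_intrinsics=True,
    max_num_images=8,
    mode="segment",
    return_rays=True,
    use_homogeneous=True,
    seed=0,
)

# Per-pixel ray endpoints (scene geometry) and the recovered camera rotations.
xyz = pred_rays.get_segments().view(-1, 3).cpu()
print(xyz.shape, pred_cameras.R.shape)
```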