├── .gitignore
├── LICENSE
├── README.md
├── assets
│   └── demo.png
├── conf
│   ├── config.yaml
│   └── diffusion.yml
├── data
│   └── demo
│       ├── jellycat
│       │   ├── 001.jpg
│       │   ├── 002.jpg
│       │   ├── 003.jpg
│       │   └── 004.jpg
│       ├── jordan
│       │   ├── 001.png
│       │   ├── 002.png
│       │   ├── 003.png
│       │   ├── 004.png
│       │   ├── 005.png
│       │   ├── 006.png
│       │   ├── 007.png
│       │   └── 008.png
│       ├── kew_gardens_ruined_arch
│       │   ├── 001.jpeg
│       │   ├── 002.jpeg
│       │   └── 003.jpeg
│       └── kotor_cathedral
│           ├── 001.jpeg
│           ├── 002.jpeg
│           ├── 003.jpeg
│           ├── 004.jpeg
│           ├── 005.jpeg
│           └── 006.jpeg
├── diffusionsfm
│   ├── __init__.py
│   ├── dataset
│   │   ├── __init__.py
│   │   ├── co3d_v2.py
│   │   └── custom.py
│   ├── eval
│   │   ├── __init__.py
│   │   ├── eval_category.py
│   │   └── eval_jobs.py
│   ├── inference
│   │   ├── __init__.py
│   │   ├── ddim.py
│   │   ├── load_model.py
│   │   └── predict.py
│   ├── model
│   │   ├── base_model.py
│   │   ├── blocks.py
│   │   ├── diffuser.py
│   │   ├── diffuser_dpt.py
│   │   ├── dit.py
│   │   ├── feature_extractors.py
│   │   ├── memory_efficient_attention.py
│   │   └── scheduler.py
│   └── utils
│       ├── __init__.py
│       ├── configs.py
│       ├── distortion.py
│       ├── distributed.py
│       ├── geometry.py
│       ├── normalize.py
│       ├── rays.py
│       ├── slurm.py
│       └── visualization.py
├── docs
│   ├── eval.md
│   └── train.md
├── gradio_app.py
├── requirements.txt
└── train.py
/.gitignore:
--------------------------------------------------------------------------------
1 | __pycache__/
2 | .DS_Store
3 | wandb/
4 | slurm_logs/
5 | output/
6 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2025 Qitao Zhao
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # DiffusionSfM
2 |
3 | This repository contains the official implementation for **DiffusionSfM: Predicting Structure and Motion
4 | via Ray Origin and Endpoint Diffusion**. The paper has been accepted to [CVPR 2025](https://cvpr.thecvf.com/Conferences/2025).
5 |
6 | [Project Page](https://qitaozhao.github.io/DiffusionSfM) | [arXiv](https://arxiv.org/abs/2505.05473)
7 |
8 | ### News
9 |
10 | - 2025.05.04: Initial code release.
11 |
12 | ## Introduction
13 |
14 | **tl;dr** Given a set of multi-view images, **DiffusionSfM** represents scene geometry and cameras as pixel-wise ray origins and endpoints in a global frame. It learns a denoising diffusion model to infer these elements directly from multi-view inputs.
15 |
16 | 
17 |
18 | ## Install
19 |
20 | 1. Clone DiffusionSfM:
21 |
22 | ```bash
23 | git clone https://github.com/QitaoZhao/DiffusionSfM.git
24 | cd DiffusionSfM
25 | ```
26 |
27 | 2. Create the environment and install packages:
28 |
29 | ```bash
30 | conda create -n diffusionsfm python=3.9
31 | conda activate diffusionsfm
32 |
33 | # enable nvcc
34 | conda install -c conda-forge cudatoolkit-dev
35 |
36 | ### torch
37 | # CUDA 11.7
38 | conda install pytorch==2.0.1 torchvision==0.15.2 torchaudio==2.0.2 pytorch-cuda=11.7 -c pytorch -c nvidia
39 |
40 | pip install -r requirements.txt
41 |
42 | ### pytorch3D
43 | # CUDA 11.7
44 | conda install https://anaconda.org/pytorch3d/pytorch3d/0.7.7/download/linux-64/pytorch3d-0.7.7-py39_cu117_pyt201.tar.bz2
45 |
46 | # xformers
47 | conda install xformers -c xformers
48 | ```
49 |
50 | Tested on:
51 |
52 | - Springdale Linux 8.6 with torch 2.0.1 & CUDA 11.7 on A6000 GPUs.
53 |
54 | > **Note:** If you encounter the error
55 | >
56 | > `ImportError: .../libtorch_cpu.so: undefined symbol: iJIT_NotifyEvent`
57 | >
58 | > when importing PyTorch, refer to this [related issue](https://github.com/coleygroup/shepherd-score/issues/1) or try installing Intel MKL explicitly with:
59 |
60 | ```bash
61 | conda install mkl==2024.0
62 | ```
63 |
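
After installation, you can quickly confirm that the environment works with a short check (a minimal sanity check, not part of the original instructions; it only verifies that the key packages import and that a GPU is visible):

```python
# Minimal environment sanity check (assumes the `diffusionsfm` conda env is active).
import torch
import pytorch3d  # noqa: F401  -- verifies the PyTorch3D build matches this torch/CUDA combo
import xformers  # noqa: F401

print("torch:", torch.__version__, "| CUDA available:", torch.cuda.is_available())
if torch.cuda.is_available():
    print("GPU:", torch.cuda.get_device_name(0))
```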
64 | ## Run Demo
65 |
66 | #### (1) Try the Online Demo
67 |
68 | Check out our interactive demo on Hugging Face:
69 |
70 | 👉 [DiffusionSfM Demo](https://huggingface.co/spaces/qitaoz/DiffusionSfM)
71 |
72 | #### (2) Run the Gradio Demo Locally
73 |
74 | Download the model weights manually from [Hugging Face](https://huggingface.co/qitaoz/DiffusionSfM):
75 |
76 | ```python
77 | from huggingface_hub import hf_hub_download
78 |
79 | filepath = hf_hub_download(repo_id="qitaoz/DiffusionSfM", filename="qitaoz/DiffusionSfM")
80 | ```
81 |
82 | or [Google Drive](https://drive.google.com/file/d/1NBdq7A1QMFGhIbpK1HT3ATv2S1jXWr2h/view?usp=drive_link):
83 |
84 | ```bash
85 | gdown https://drive.google.com/uc\?id\=1NBdq7A1QMFGhIbpK1HT3ATv2S1jXWr2h
86 | unzip models.zip
87 | ```
88 | Next, run the demo:
89 |
90 | ```bash
91 | # the first run may take longer
92 | python gradio_app.py
93 | ```
94 |
95 | 
96 |
97 | You can run our model in two ways:
98 |
99 | 1. **Upload Images** — Upload your own multi-view images above.
100 | 2. **Use a Preprocessed Example** — Select one of the pre-collected examples below.
101 |
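If you prefer to call the model programmatically instead of through Gradio, the pieces used by the demo can be combined directly. The sketch below is illustrative only: the checkpoint directory and keyword arguments are assumptions based on `diffusionsfm/eval/eval_category.py` and `diffusionsfm/eval/eval_jobs.py`; see `gradio_app.py` for the authoritative pipeline.

```python
# Illustrative sketch of programmatic inference (not the official entry point).
# The checkpoint directory and keyword arguments below are assumptions; check
# gradio_app.py and diffusionsfm/inference/predict.py for the exact interface.
from glob import glob

import torch

from diffusionsfm.dataset.custom import CustomDataset
from diffusionsfm.inference.load_model import load_model
from diffusionsfm.inference.predict import predict_cameras

device = torch.device("cuda")
model, cfg = load_model(
    "output/multi_diffusionsfm_dense",  # assumed checkpoint directory (as in eval_jobs.py)
    checkpoint=800_000,
    device=device,
)

batch = CustomDataset(sorted(glob("data/demo/jellycat/*.jpg"))).get_data()
images = batch["image"].to(device)
crop_parameters = batch["crop_parameters"].to(device)

pred_cameras, _ = predict_cameras(
    model,
    images,
    device,
    crop_parameters=crop_parameters,
    num_patches_x=cfg.model.num_patches_x,
    num_patches_y=cfg.model.num_patches_y,
)
print(pred_cameras)
```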
102 | ## Training
103 |
104 | Set up wandb:
105 |
106 | ```bash
107 | wandb login
108 | ```
109 |
110 | See [docs/train.md](https://github.com/QitaoZhao/DiffusionSfM/blob/main/docs/train.md) for more detailed instructions on training.
111 |
112 | ## Evaluation
113 |
114 | See [docs/eval.md](https://github.com/QitaoZhao/DiffusionSfM/blob/main/docs/eval.md) for instructions on how to run evaluation code.
115 |
116 | ## Acknowledgments
117 |
118 | This project builds upon [RayDiffusion](https://github.com/jasonyzhang/RayDiffusion). [Amy Lin](https://amyxlase.github.io/) and [Jason Y. Zhang](https://jasonyzhang.com/) developed the initial codebase during the early stages of this project.
119 |
120 | ## Cite DiffusionSfM
121 |
122 | If you find this code helpful, please cite:
123 |
124 | ```
125 | @inproceedings{zhao2025diffusionsfm,
126 | title={DiffusionSfM: Predicting Structure and Motion via Ray Origin and Endpoint Diffusion},
127 | author={Qitao Zhao and Amy Lin and Jeff Tan and Jason Y. Zhang and Deva Ramanan and Shubham Tulsiani},
128 | booktitle={CVPR},
129 | year={2025}
130 | }
131 | ```
--------------------------------------------------------------------------------
/assets/demo.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/QitaoZhao/DiffusionSfM/4bb08800721bdcf46b0c823586a2fab4702ff282/assets/demo.png
--------------------------------------------------------------------------------
/conf/config.yaml:
--------------------------------------------------------------------------------
1 | training:
2 |   resume: False # If True, must set hydra.run.dir accordingly
3 |   pretrain_path: ""
4 |   interval_visualize: 1000
5 |   interval_save_checkpoint: 5000
6 |   interval_delete_checkpoint: 10000
7 |   interval_evaluate: 5000
8 |   delete_all_checkpoints_after_training: False
9 |   lr: 1e-4
10 |   mixed_precision: True
11 |   matmul_precision: high
12 |   max_iterations: 100000
13 |   batch_size: 64
14 |   num_workers: 8
15 |   gpu_id: 0
16 |   freeze_encoder: True
17 |   seed: 0
18 |   job_key: "" # Use this for submitit sweeps where timestamps might collide
19 |   translation_scale: 1.0
20 |   regression: False
21 |   prob_unconditional: 0
22 |   load_extra_cameras: False
23 |   calculate_intrinsics: False
24 |   distort: False
25 |   normalize_first_camera: True
26 |   diffuse_origins_and_endpoints: True
27 |   diffuse_depths: False
28 |   depth_resolution: 1
29 |   dpt_head: False
30 |   full_num_patches_x: 16
31 |   full_num_patches_y: 16
32 |   dpt_encoder_features: True
33 |   nearest_neighbor: True
34 |   no_bg_targets: True
35 |   unit_normalize_scene: False
36 |   sd_scale: 2
37 |   bfloat: True
38 |   first_cam_mediod: True
39 |   gradient_clipping: False
40 |   l1_loss: False
41 |   grad_accumulation: False
42 |   reinit: False
43 |
44 | model:
45 |   pred_x0: True
46 |   model_type: dit
47 |   num_patches_x: 16
48 |   num_patches_y: 16
49 |   depth: 16
50 |   num_images: 1
51 |   random_num_images: True
52 |   feature_extractor: dino
53 |   append_ndc: True
54 |   within_image: False
55 |   use_homogeneous: True
56 |   freeze_transformer: False
57 |   cond_depth_mask: True
58 |
59 | noise_scheduler:
60 |   type: linear
61 |   max_timesteps: 100
62 |   beta_start: 0.0120
63 |   beta_end: 0.00085
64 |   marigold_ddim: False
65 |
66 | dataset:
67 |   name: co3d
68 |   shape: all_train
69 |   apply_augmentation: True
70 |   use_global_intrinsics: True
71 |   mask_holes: True
72 |   image_size: 224
73 |
74 | debug:
75 |   wandb: True
76 |   project_name: diffusionsfm
77 |   run_name:
78 |   anomaly_detection: False
79 |
80 | hydra:
81 |   run:
82 |     dir: ./output/${now:%m%d_%H%M%S_%f}${training.job_key}
83 |   output_subdir: hydra
84 |
--------------------------------------------------------------------------------
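
The config above is consumed through Hydra in `train.py`. For quick inspection or ad-hoc overrides outside of a training run, the same file can be loaded with OmegaConf (a minimal sketch; the override values shown are examples, not recommended settings):

```python
# Load and inspect conf/config.yaml outside of Hydra (illustrative only).
from omegaconf import OmegaConf

cfg = OmegaConf.load("conf/config.yaml")
print(cfg.training.lr, cfg.training.batch_size, cfg.model.model_type)

# Dotlist-style overrides, e.g. for a quick debugging run:
overrides = OmegaConf.from_dotlist(["training.batch_size=8", "debug.wandb=False"])
cfg = OmegaConf.merge(cfg, overrides)
print(OmegaConf.to_yaml(cfg.training))
```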
/conf/diffusion.yml:
--------------------------------------------------------------------------------
1 | name: diffusion
2 | channels:
3 | - conda-forge
4 | - iopath
5 | - nvidia
6 | - pkgs/main
7 | - pytorch
8 | - xformers
9 | dependencies:
10 | - _libgcc_mutex=0.1=conda_forge
11 | - _openmp_mutex=4.5=2_gnu
12 | - blas=1.0=mkl
13 | - brotli-python=1.0.9=py39h5a03fae_9
14 | - bzip2=1.0.8=h7f98852_4
15 | - ca-certificates=2023.7.22=hbcca054_0
16 | - certifi=2023.7.22=pyhd8ed1ab_0
17 | - charset-normalizer=3.2.0=pyhd8ed1ab_0
18 | - colorama=0.4.6=pyhd8ed1ab_0
19 | - cuda-cudart=11.7.99=0
20 | - cuda-cupti=11.7.101=0
21 | - cuda-libraries=11.7.1=0
22 | - cuda-nvrtc=11.7.99=0
23 | - cuda-nvtx=11.7.91=0
24 | - cuda-runtime=11.7.1=0
25 | - ffmpeg=4.3=hf484d3e_0
26 | - filelock=3.12.2=pyhd8ed1ab_0
27 | - freetype=2.12.1=hca18f0e_1
28 | - fvcore=0.1.5.post20221221=pyhd8ed1ab_0
29 | - gmp=6.2.1=h58526e2_0
30 | - gmpy2=2.1.2=py39h376b7d2_1
31 | - gnutls=3.6.13=h85f3911_1
32 | - idna=3.4=pyhd8ed1ab_0
33 | - intel-openmp=2022.1.0=h9e868ea_3769
34 | - iopath=0.1.9=py39
35 | - jinja2=3.1.2=pyhd8ed1ab_1
36 | - jpeg=9e=h0b41bf4_3
37 | - lame=3.100=h166bdaf_1003
38 | - lcms2=2.15=hfd0df8a_0
39 | - ld_impl_linux-64=2.40=h41732ed_0
40 | - lerc=4.0.0=h27087fc_0
41 | - libblas=3.9.0=16_linux64_mkl
42 | - libcblas=3.9.0=16_linux64_mkl
43 | - libcublas=11.10.3.66=0
44 | - libcufft=10.7.2.124=h4fbf590_0
45 | - libcufile=1.7.1.12=0
46 | - libcurand=10.3.3.129=0
47 | - libcusolver=11.4.0.1=0
48 | - libcusparse=11.7.4.91=0
49 | - libdeflate=1.17=h0b41bf4_0
50 | - libffi=3.3=h58526e2_2
51 | - libgcc-ng=13.1.0=he5830b7_0
52 | - libgomp=13.1.0=he5830b7_0
53 | - libiconv=1.17=h166bdaf_0
54 | - liblapack=3.9.0=16_linux64_mkl
55 | - libnpp=11.7.4.75=0
56 | - libnvjpeg=11.8.0.2=0
57 | - libpng=1.6.39=h753d276_0
58 | - libsqlite=3.42.0=h2797004_0
59 | - libstdcxx-ng=13.1.0=hfd8a6a1_0
60 | - libtiff=4.5.0=h6adf6a1_2
61 | - libwebp-base=1.3.1=hd590300_0
62 | - libxcb=1.13=h7f98852_1004
63 | - libzlib=1.2.13=hd590300_5
64 | - markupsafe=2.1.3=py39hd1e30aa_0
65 | - mkl=2022.1.0=hc2b9512_224
66 | - mpc=1.3.1=hfe3b2da_0
67 | - mpfr=4.2.0=hb012696_0
68 | - mpmath=1.3.0=pyhd8ed1ab_0
69 | - ncurses=6.4=hcb278e6_0
70 | - nettle=3.6=he412f7d_0
71 | - networkx=3.1=pyhd8ed1ab_0
72 | - numpy=1.25.2=py39h6183b62_0
73 | - openh264=2.1.1=h780b84a_0
74 | - openjpeg=2.5.0=hfec8fc6_2
75 | - openssl=1.1.1v=hd590300_0
76 | - pillow=9.4.0=py39h2320bf1_1
77 | - pip=23.2.1=pyhd8ed1ab_0
78 | - portalocker=2.7.0=py39hf3d152e_0
79 | - pthread-stubs=0.4=h36c2ea0_1001
80 | - pysocks=1.7.1=pyha2e5f31_6
81 | - python=3.9.0=hffdb5ce_5_cpython
82 | - python_abi=3.9=3_cp39
83 | - pytorch=2.0.1=py3.9_cuda11.7_cudnn8.5.0_0
84 | - pytorch-cuda=11.7=h778d358_5
85 | - pytorch-mutex=1.0=cuda
86 | - pyyaml=6.0=py39hb9d737c_5
87 | - readline=8.2=h8228510_1
88 | - requests=2.31.0=pyhd8ed1ab_0
89 | - setuptools=68.0.0=pyhd8ed1ab_0
90 | - sqlite=3.42.0=h2c6b66d_0
91 | - sympy=1.12=pypyh9d50eac_103
92 | - tabulate=0.9.0=pyhd8ed1ab_1
93 | - termcolor=2.3.0=pyhd8ed1ab_0
94 | - tk=8.6.12=h27826a3_0
95 | - torchaudio=2.0.2=py39_cu117
96 | - torchtriton=2.0.0=py39
97 | - torchvision=0.15.2=py39_cu117
98 | - tqdm=4.66.1=pyhd8ed1ab_0
99 | - typing_extensions=4.7.1=pyha770c72_0
100 | - tzdata=2023c=h71feb2d_0
101 | - urllib3=2.0.4=pyhd8ed1ab_0
102 | - wheel=0.41.1=pyhd8ed1ab_0
103 | - xformers=0.0.21=py39_cu11.8.0_pyt2.0.1
104 | - xorg-libxau=1.0.11=hd590300_0
105 | - xorg-libxdmcp=1.1.3=h7f98852_0
106 | - xz=5.2.6=h166bdaf_0
107 | - yacs=0.1.8=pyhd8ed1ab_0
108 | - yaml=0.2.5=h7f98852_2
109 | - zlib=1.2.13=hd590300_5
110 | - zstd=1.5.2=hfc55251_7
111 |
--------------------------------------------------------------------------------
/data/demo/jellycat/001.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/QitaoZhao/DiffusionSfM/4bb08800721bdcf46b0c823586a2fab4702ff282/data/demo/jellycat/001.jpg
--------------------------------------------------------------------------------
/data/demo/jellycat/002.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/QitaoZhao/DiffusionSfM/4bb08800721bdcf46b0c823586a2fab4702ff282/data/demo/jellycat/002.jpg
--------------------------------------------------------------------------------
/data/demo/jellycat/003.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/QitaoZhao/DiffusionSfM/4bb08800721bdcf46b0c823586a2fab4702ff282/data/demo/jellycat/003.jpg
--------------------------------------------------------------------------------
/data/demo/jellycat/004.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/QitaoZhao/DiffusionSfM/4bb08800721bdcf46b0c823586a2fab4702ff282/data/demo/jellycat/004.jpg
--------------------------------------------------------------------------------
/data/demo/jordan/001.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/QitaoZhao/DiffusionSfM/4bb08800721bdcf46b0c823586a2fab4702ff282/data/demo/jordan/001.png
--------------------------------------------------------------------------------
/data/demo/jordan/002.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/QitaoZhao/DiffusionSfM/4bb08800721bdcf46b0c823586a2fab4702ff282/data/demo/jordan/002.png
--------------------------------------------------------------------------------
/data/demo/jordan/003.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/QitaoZhao/DiffusionSfM/4bb08800721bdcf46b0c823586a2fab4702ff282/data/demo/jordan/003.png
--------------------------------------------------------------------------------
/data/demo/jordan/004.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/QitaoZhao/DiffusionSfM/4bb08800721bdcf46b0c823586a2fab4702ff282/data/demo/jordan/004.png
--------------------------------------------------------------------------------
/data/demo/jordan/005.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/QitaoZhao/DiffusionSfM/4bb08800721bdcf46b0c823586a2fab4702ff282/data/demo/jordan/005.png
--------------------------------------------------------------------------------
/data/demo/jordan/006.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/QitaoZhao/DiffusionSfM/4bb08800721bdcf46b0c823586a2fab4702ff282/data/demo/jordan/006.png
--------------------------------------------------------------------------------
/data/demo/jordan/007.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/QitaoZhao/DiffusionSfM/4bb08800721bdcf46b0c823586a2fab4702ff282/data/demo/jordan/007.png
--------------------------------------------------------------------------------
/data/demo/jordan/008.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/QitaoZhao/DiffusionSfM/4bb08800721bdcf46b0c823586a2fab4702ff282/data/demo/jordan/008.png
--------------------------------------------------------------------------------
/data/demo/kew_gardens_ruined_arch/001.jpeg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/QitaoZhao/DiffusionSfM/4bb08800721bdcf46b0c823586a2fab4702ff282/data/demo/kew_gardens_ruined_arch/001.jpeg
--------------------------------------------------------------------------------
/data/demo/kew_gardens_ruined_arch/002.jpeg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/QitaoZhao/DiffusionSfM/4bb08800721bdcf46b0c823586a2fab4702ff282/data/demo/kew_gardens_ruined_arch/002.jpeg
--------------------------------------------------------------------------------
/data/demo/kew_gardens_ruined_arch/003.jpeg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/QitaoZhao/DiffusionSfM/4bb08800721bdcf46b0c823586a2fab4702ff282/data/demo/kew_gardens_ruined_arch/003.jpeg
--------------------------------------------------------------------------------
/data/demo/kotor_cathedral/001.jpeg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/QitaoZhao/DiffusionSfM/4bb08800721bdcf46b0c823586a2fab4702ff282/data/demo/kotor_cathedral/001.jpeg
--------------------------------------------------------------------------------
/data/demo/kotor_cathedral/002.jpeg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/QitaoZhao/DiffusionSfM/4bb08800721bdcf46b0c823586a2fab4702ff282/data/demo/kotor_cathedral/002.jpeg
--------------------------------------------------------------------------------
/data/demo/kotor_cathedral/003.jpeg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/QitaoZhao/DiffusionSfM/4bb08800721bdcf46b0c823586a2fab4702ff282/data/demo/kotor_cathedral/003.jpeg
--------------------------------------------------------------------------------
/data/demo/kotor_cathedral/004.jpeg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/QitaoZhao/DiffusionSfM/4bb08800721bdcf46b0c823586a2fab4702ff282/data/demo/kotor_cathedral/004.jpeg
--------------------------------------------------------------------------------
/data/demo/kotor_cathedral/005.jpeg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/QitaoZhao/DiffusionSfM/4bb08800721bdcf46b0c823586a2fab4702ff282/data/demo/kotor_cathedral/005.jpeg
--------------------------------------------------------------------------------
/data/demo/kotor_cathedral/006.jpeg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/QitaoZhao/DiffusionSfM/4bb08800721bdcf46b0c823586a2fab4702ff282/data/demo/kotor_cathedral/006.jpeg
--------------------------------------------------------------------------------
/diffusionsfm/__init__.py:
--------------------------------------------------------------------------------
1 | from .utils.rays import cameras_to_rays, rays_to_cameras, Rays
2 |
--------------------------------------------------------------------------------
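
These package-level exports (`cameras_to_rays`, `rays_to_cameras`, `Rays`) implement the core representation: per-patch ray origins and endpoints in a global frame. Below is a heavily hedged sketch of converting cameras to rays; the keyword arguments mirror the call in `diffusionsfm/eval/eval_category.py`, and the exact signatures (e.g. whether `depths` is needed, or what `rays_to_cameras` expects) should be checked in `diffusionsfm/utils/rays.py`.

```python
# Illustrative only: keyword arguments follow the cameras_to_rays call in
# eval_category.py; the actual signatures live in diffusionsfm/utils/rays.py.
import torch
from pytorch3d.renderer import PerspectiveCameras

from diffusionsfm import cameras_to_rays  # rays_to_cameras provides the inverse mapping

cameras = PerspectiveCameras(
    R=torch.eye(3)[None],
    T=torch.zeros(1, 3),
    focal_length=torch.ones(1, 2),
    device="cuda",
)
# crop_parameters follow the (cc_x, cc_y, crop_width, s) convention built in co3d_v2.py
crop_parameters = torch.tensor([[0.0, 0.0, 2.0, 1.0]], device="cuda")

rays = cameras_to_rays(
    cameras=cameras,
    num_patches_x=16,
    num_patches_y=16,
    crop_parameters=crop_parameters,
)
points = rays.get_segments()  # per-patch origin/endpoint segments, as used in eval_category.py
```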
/diffusionsfm/dataset/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/QitaoZhao/DiffusionSfM/4bb08800721bdcf46b0c823586a2fab4702ff282/diffusionsfm/dataset/__init__.py
--------------------------------------------------------------------------------
/diffusionsfm/dataset/co3d_v2.py:
--------------------------------------------------------------------------------
1 | import gzip
2 | import json
3 | import os.path as osp
4 | import random
5 | import socket
6 | import time
7 | import torch
8 | import warnings
9 |
10 | import numpy as np
11 | from PIL import Image, ImageFile
12 | from tqdm import tqdm
13 | from pytorch3d.renderer import PerspectiveCameras
14 | from torch.utils.data import Dataset
15 | from torchvision import transforms
16 | import matplotlib.pyplot as plt
17 | from scipy import ndimage as nd
18 |
19 | from diffusionsfm.utils.distortion import distort_image
20 |
21 |
22 | HOSTNAME = socket.gethostname()
23 |
24 | CO3D_DIR = "../co3d_data" # update this
25 | CO3D_ANNOTATION_DIR = osp.join(CO3D_DIR, "co3d_annotations")
26 | CO3D_DIR = CO3D_DEPTH_DIR = osp.join(CO3D_DIR, "co3d")
27 | order_path = osp.join(
28 | CO3D_DIR, "co3d_v2_random_order_{sample_num}/{category}.json"
29 | )
30 |
31 |
32 | TRAINING_CATEGORIES = [
33 | "apple",
34 | "backpack",
35 | "banana",
36 | "baseballbat",
37 | "baseballglove",
38 | "bench",
39 | "bicycle",
40 | "bottle",
41 | "bowl",
42 | "broccoli",
43 | "cake",
44 | "car",
45 | "carrot",
46 | "cellphone",
47 | "chair",
48 | "cup",
49 | "donut",
50 | "hairdryer",
51 | "handbag",
52 | "hydrant",
53 | "keyboard",
54 | "laptop",
55 | "microwave",
56 | "motorcycle",
57 | "mouse",
58 | "orange",
59 | "parkingmeter",
60 | "pizza",
61 | "plant",
62 | "stopsign",
63 | "teddybear",
64 | "toaster",
65 | "toilet",
66 | "toybus",
67 | "toyplane",
68 | "toytrain",
69 | "toytruck",
70 | "tv",
71 | "umbrella",
72 | "vase",
73 | "wineglass",
74 | ]
75 |
76 | TEST_CATEGORIES = [
77 | "ball",
78 | "book",
79 | "couch",
80 | "frisbee",
81 | "hotdog",
82 | "kite",
83 | "remote",
84 | "sandwich",
85 | "skateboard",
86 | "suitcase",
87 | ]
88 |
89 | assert len(TRAINING_CATEGORIES) + len(TEST_CATEGORIES) == 51
90 |
91 | Image.MAX_IMAGE_PIXELS = None
92 | ImageFile.LOAD_TRUNCATED_IMAGES = True
93 |
94 |
95 | def fill_depths(data, invalid=None):
96 | data_list = []
97 | for i in range(data.shape[0]):
98 | data_item = data[i].numpy()
99 | # Invalid must be 1 where stuff is invalid, 0 where valid
100 | ind = nd.distance_transform_edt(
101 | invalid[i], return_distances=False, return_indices=True
102 | )
103 | data_list.append(torch.tensor(data_item[tuple(ind)]))
104 | return torch.stack(data_list, dim=0)
105 |
106 |
107 | def full_scene_scale(batch):
108 | cameras = PerspectiveCameras(R=batch["R"], T=batch["T"], device="cuda")
109 | cc = cameras.get_camera_center()
110 | centroid = torch.mean(cc, dim=0)
111 |
112 | diffs = cc - centroid
113 | norms = torch.linalg.norm(diffs, dim=1)
114 |
115 | furthest_index = torch.argmax(norms).item()
116 | scale = norms[furthest_index].item()
117 | return scale
118 |
119 |
120 | def square_bbox(bbox, padding=0.0, astype=None, tight=False):
121 | """
122 | Computes a square bounding box, with optional padding parameters.
123 | Args:
124 | bbox: Bounding box in xyxy format (4,).
125 | Returns:
126 | square_bbox in xyxy format (4,).
127 | """
128 | if astype is None:
129 | astype = type(bbox[0])
130 | bbox = np.array(bbox)
131 | center = (bbox[:2] + bbox[2:]) / 2
132 | extents = (bbox[2:] - bbox[:2]) / 2
133 |
134 | # No black bars if tight
135 | if tight:
136 | s = min(extents) * (1 + padding)
137 | else:
138 | s = max(extents) * (1 + padding)
139 |
140 | square_bbox = np.array(
141 | [center[0] - s, center[1] - s, center[0] + s, center[1] + s],
142 | dtype=astype,
143 | )
144 | return square_bbox
145 |
146 |
147 | def unnormalize_image(image, return_numpy=True, return_int=True):
148 | if isinstance(image, torch.Tensor):
149 | image = image.detach().cpu().numpy()
150 |
151 | if image.ndim == 3:
152 | if image.shape[0] == 3:
153 | image = image[None, ...]
154 | elif image.shape[2] == 3:
155 | image = image.transpose(2, 0, 1)[None, ...]
156 | else:
157 | raise ValueError(f"Unexpected image shape: {image.shape}")
158 | elif image.ndim == 4:
159 | if image.shape[1] == 3:
160 | pass
161 | elif image.shape[3] == 3:
162 | image = image.transpose(0, 3, 1, 2)
163 | else:
164 | raise ValueError(f"Unexpected batch image shape: {image.shape}")
165 | else:
166 | raise ValueError(f"Unsupported input shape: {image.shape}")
167 |
168 | mean = np.array([0.485, 0.456, 0.406])[None, :, None, None]
169 | std = np.array([0.229, 0.224, 0.225])[None, :, None, None]
170 | image = image * std + mean
171 |
172 | if return_int:
173 | image = np.clip(image * 255.0, 0, 255).astype(np.uint8)
174 | else:
175 | image = np.clip(image, 0.0, 1.0)
176 |
177 | if image.shape[0] == 1:
178 | image = image[0]
179 |
180 | if return_numpy:
181 | return image
182 | else:
183 | return torch.from_numpy(image)
184 |
185 |
186 | def unnormalize_image_for_vis(image):
187 | assert len(image.shape) == 5 and image.shape[2] == 3
188 | mean = torch.tensor([0.485, 0.456, 0.406]).view(1, 1, 3, 1, 1).to(image.device)
189 | std = torch.tensor([0.229, 0.224, 0.225]).view(1, 1, 3, 1, 1).to(image.device)
190 | image = image * std + mean
191 | image = (image - 0.5) / 0.5
192 | return image
193 |
194 |
195 | def _transform_intrinsic(image, bbox, principal_point, focal_length):
196 | # Rescale intrinsics to match bbox
197 | half_box = np.array([image.width, image.height]).astype(np.float32) / 2
198 | org_scale = min(half_box).astype(np.float32)
199 |
200 | # Pixel coordinates
201 | principal_point_px = half_box - (np.array(principal_point) * org_scale)
202 | focal_length_px = np.array(focal_length) * org_scale
203 | principal_point_px -= bbox[:2]
204 | new_bbox = (bbox[2:] - bbox[:2]) / 2
205 | new_scale = min(new_bbox)
206 |
207 | # NDC coordinates
208 | new_principal_ndc = (new_bbox - principal_point_px) / new_scale
209 | new_focal_ndc = focal_length_px / new_scale
210 |
211 | principal_point = torch.tensor(new_principal_ndc.astype(np.float32))
212 | focal_length = torch.tensor(new_focal_ndc.astype(np.float32))
213 |
214 | return principal_point, focal_length
215 |
216 |
217 | def construct_camera_from_batch(batch, device):
218 | if isinstance(device, int):
219 | device = f"cuda:{device}"
220 |
221 | return PerspectiveCameras(
222 | R=batch["R"].reshape(-1, 3, 3),
223 | T=batch["T"].reshape(-1, 3),
224 | focal_length=batch["focal_lengths"].reshape(-1, 2),
225 | principal_point=batch["principal_points"].reshape(-1, 2),
226 | image_size=batch["image_sizes"].reshape(-1, 2),
227 | device=device,
228 | )
229 |
230 |
231 | def save_batch_images(images, fname):
232 | cmap = plt.get_cmap("hsv")
233 | num_frames = len(images)
234 | num_rows = len(images)
235 | num_cols = 4
236 | figsize = (num_cols * 2, num_rows * 2)
237 | fig, axs = plt.subplots(num_rows, num_cols, figsize=figsize)
238 | axs = axs.flatten()
239 | for i in range(num_rows):
240 | for j in range(4):
241 | if i < num_frames:
242 | axs[i * 4 + j].imshow(unnormalize_image(images[i][j]))
243 | for s in ["bottom", "top", "left", "right"]:
244 | axs[i * 4 + j].spines[s].set_color(cmap(i / (num_frames)))
245 | axs[i * 4 + j].spines[s].set_linewidth(5)
246 | axs[i * 4 + j].set_xticks([])
247 | axs[i * 4 + j].set_yticks([])
248 | else:
249 | axs[i * 4 + j].axis("off")
250 | plt.tight_layout()
251 | plt.savefig(fname)
252 |
253 |
254 | def jitter_bbox(
255 | square_bbox,
256 | jitter_scale=(1.1, 1.2),
257 | jitter_trans=(-0.07, 0.07),
258 | direction_from_size=None,
259 | ):
260 |
261 | square_bbox = np.array(square_bbox.astype(float))
262 | s = np.random.uniform(jitter_scale[0], jitter_scale[1])
263 |
264 | # Jitter only one dimension if center cropping
265 | tx, ty = np.random.uniform(jitter_trans[0], jitter_trans[1], size=2)
266 | if direction_from_size is not None:
267 | if direction_from_size[0] > direction_from_size[1]:
268 | tx = 0
269 | else:
270 | ty = 0
271 |
272 | side_length = square_bbox[2] - square_bbox[0]
273 | center = (square_bbox[:2] + square_bbox[2:]) / 2 + np.array([tx, ty]) * side_length
274 | extent = side_length / 2 * s
275 | ul = center - extent
276 | lr = ul + 2 * extent
277 | return np.concatenate((ul, lr))
278 |
279 |
280 | class Co3dDataset(Dataset):
281 | def __init__(
282 | self,
283 | category=("all_train",),
284 | split="train",
285 | transform=None,
286 | num_images=2,
287 | img_size=224,
288 | mask_images=False,
289 | crop_images=True,
290 | co3d_dir=None,
291 | co3d_annotation_dir=None,
292 | precropped_images=False,
293 | apply_augmentation=True,
294 | normalize_cameras=True,
295 | no_images=False,
296 | sample_num=None,
297 | seed=0,
298 | load_extra_cameras=False,
299 | distort_image=False,
300 | load_depths=False,
301 | center_crop=False,
302 | depth_size=256,
303 | mask_holes=False,
304 | object_mask=True,
305 | ):
306 | """
307 | Args:
308 | num_images: Number of images in each batch.
309 | perspective_correction (str):
310 | "none": No perspective correction.
311 | "warp": Warp the image and label.
312 | "label_only": Correct the label only.
313 | """
314 | start_time = time.time()
315 |
316 | self.category = category
317 | self.split = split
318 | self.transform = transform
319 | self.num_images = num_images
320 | self.img_size = img_size
321 | self.mask_images = mask_images
322 | self.crop_images = crop_images
323 | self.precropped_images = precropped_images
324 | self.apply_augmentation = apply_augmentation
325 | self.normalize_cameras = normalize_cameras
326 | self.no_images = no_images
327 | self.sample_num = sample_num
328 | self.load_extra_cameras = load_extra_cameras
329 | self.distort = distort_image
330 | self.load_depths = load_depths
331 | self.center_crop = center_crop
332 | self.depth_size = depth_size
333 | self.mask_holes = mask_holes
334 | self.object_mask = object_mask
335 |
336 | if self.apply_augmentation:
337 | if self.center_crop:
338 | self.jitter_scale = (0.8, 1.1)
339 | self.jitter_trans = (0.0, 0.0)
340 | else:
341 | self.jitter_scale = (1.1, 1.2)
342 | self.jitter_trans = (-0.07, 0.07)
343 | else:
344 | # Note if trained with apply_augmentation, we should still use
345 | # apply_augmentation at test time.
346 | self.jitter_scale = (1, 1)
347 | self.jitter_trans = (0.0, 0.0)
348 |
349 | if self.distort:
350 | self.k1_max = 1.0
351 | self.k2_max = 1.0
352 |
353 | if co3d_dir is not None:
354 | self.co3d_dir = co3d_dir
355 | self.co3d_annotation_dir = co3d_annotation_dir
356 | else:
357 | self.co3d_dir = CO3D_DIR
358 | self.co3d_annotation_dir = CO3D_ANNOTATION_DIR
359 | self.co3d_depth_dir = CO3D_DEPTH_DIR
360 |
361 | if isinstance(self.category, str):
362 | self.category = [self.category]
363 |
364 | if "all_train" in self.category:
365 | self.category = TRAINING_CATEGORIES
366 | if "all_test" in self.category:
367 | self.category = TEST_CATEGORIES
368 | if "full" in self.category:
369 | self.category = TRAINING_CATEGORIES + TEST_CATEGORIES
370 | self.category = sorted(self.category)
371 | self.is_single_category = len(self.category) == 1
372 |
373 | # Fixing seed
374 | torch.manual_seed(seed)
375 | random.seed(seed)
376 | np.random.seed(seed)
377 |
378 | print(f"Co3d ({split}):")
379 |
380 | self.low_quality_translations = [
381 | "411_55952_107659",
382 | "427_59915_115716",
383 | "435_61970_121848",
384 | "112_13265_22828",
385 | "110_13069_25642",
386 | "165_18080_34378",
387 | "368_39891_78502",
388 | "391_47029_93665",
389 | "20_695_1450",
390 | "135_15556_31096",
391 | "417_57572_110680",
392 | ] # Initialized with sequences with poor depth masks
393 | self.rotations = {}
394 | self.category_map = {}
395 | for c in tqdm(self.category):
396 | annotation_file = osp.join(
397 | self.co3d_annotation_dir, f"{c}_{self.split}.jgz"
398 | )
399 | with gzip.open(annotation_file, "r") as fin:
400 | annotation = json.loads(fin.read())
401 |
402 | counter = 0
403 | for seq_name, seq_data in annotation.items():
404 | counter += 1
405 | if len(seq_data) < self.num_images:
406 | continue
407 |
408 | filtered_data = []
409 | self.category_map[seq_name] = c
410 | bad_seq = False
411 | for data in seq_data:
412 | # Make sure translations are not ridiculous and rotations are valid
413 | det = np.linalg.det(data["R"])
414 | if (np.abs(data["T"]) > 1e5).any() or det < 0.99 or det > 1.01:
415 | bad_seq = True
416 | self.low_quality_translations.append(seq_name)
417 | break
418 |
419 | # Ignore all unnecessary information.
420 | filtered_data.append(
421 | {
422 | "filepath": data["filepath"],
423 | "bbox": data["bbox"],
424 | "R": data["R"],
425 | "T": data["T"],
426 | "focal_length": data["focal_length"],
427 | "principal_point": data["principal_point"],
428 | },
429 | )
430 |
431 | if not bad_seq:
432 | self.rotations[seq_name] = filtered_data
433 |
434 | self.sequence_list = list(self.rotations.keys())
435 |
436 | IMAGENET_DEFAULT_MEAN = (0.485, 0.456, 0.406)
437 | IMAGENET_DEFAULT_STD = (0.229, 0.224, 0.225)
438 |
439 | if self.transform is None:
440 | self.transform = transforms.Compose(
441 | [
442 | transforms.ToTensor(),
443 | transforms.Resize(self.img_size, antialias=True),
444 | transforms.Normalize(mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD),
445 | ]
446 | )
447 |
448 | self.transform_depth = transforms.Compose(
449 | [
450 | transforms.Resize(
451 | self.depth_size,
452 | antialias=False,
453 | interpolation=transforms.InterpolationMode.NEAREST_EXACT,
454 | ),
455 | ]
456 | )
457 |
458 | print(
459 | f"Low quality translation sequences, not used: {self.low_quality_translations}"
460 | )
461 | print(f"Data size: {len(self)}")
462 | print(f"Data loading took {(time.time()-start_time)} seconds.")
463 |
464 | def __len__(self):
465 | return len(self.sequence_list)
466 |
467 | def __getitem__(self, index):
468 | num_to_load = self.num_images if not self.load_extra_cameras else 8
469 |
470 | sequence_name = self.sequence_list[index % len(self.sequence_list)]
471 | metadata = self.rotations[sequence_name]
472 |
473 | if self.sample_num is not None:
474 | with open(
475 | order_path.format(sample_num=self.sample_num, category=self.category[0])
476 | ) as f:
477 | order = json.load(f)
478 | ids = order[sequence_name][:num_to_load]
479 | else:
480 | replace = len(metadata) < 8
481 | ids = np.random.choice(len(metadata), num_to_load, replace=replace)
482 |
483 | return self.get_data(index=index, ids=ids, num_valid_frames=num_to_load)
484 |
485 | def _get_scene_scale(self, sequence_name):
486 | n = len(self.rotations[sequence_name])
487 |
488 | R = torch.zeros(n, 3, 3)
489 | T = torch.zeros(n, 3)
490 |
491 | for i, ann in enumerate(self.rotations[sequence_name]):
492 | R[i, ...] = torch.tensor(self.rotations[sequence_name][i]["R"])
493 | T[i, ...] = torch.tensor(self.rotations[sequence_name][i]["T"])
494 |
495 | cameras = PerspectiveCameras(R=R, T=T)
496 | cc = cameras.get_camera_center()
497 | centroid = torch.mean(cc, dim=0)
498 | diff = cc - centroid
499 |
500 | norm = torch.norm(diff, dim=1)
501 | scale = torch.max(norm).item()
502 |
503 | return scale
504 |
505 | def _crop_image(self, image, bbox):
506 | image_crop = transforms.functional.crop(
507 | image,
508 | top=bbox[1],
509 | left=bbox[0],
510 | height=bbox[3] - bbox[1],
511 | width=bbox[2] - bbox[0],
512 | )
513 | return image_crop
514 |
515 | def _transform_intrinsic(self, image, bbox, principal_point, focal_length):
516 | half_box = np.array([image.width, image.height]).astype(np.float32) / 2
517 | org_scale = min(half_box).astype(np.float32)
518 |
519 | # Pixel coordinates
520 | principal_point_px = half_box - (np.array(principal_point) * org_scale)
521 | focal_length_px = np.array(focal_length) * org_scale
522 | principal_point_px -= bbox[:2]
523 | new_bbox = (bbox[2:] - bbox[:2]) / 2
524 | new_scale = min(new_bbox)
525 |
526 | # NDC coordinates
527 | new_principal_ndc = (new_bbox - principal_point_px) / new_scale
528 | new_focal_ndc = focal_length_px / new_scale
529 |
530 | return new_principal_ndc.astype(np.float32), new_focal_ndc.astype(np.float32)
531 |
532 | def get_data(
533 | self,
534 | index=None,
535 | sequence_name=None,
536 | ids=(0, 1),
537 | no_images=False,
538 | num_valid_frames=None,
539 | load_using_order=None,
540 | ):
541 | if load_using_order is not None:
542 | with open(
543 | order_path.format(sample_num=self.sample_num, category=self.category[0])
544 | ) as f:
545 | order = json.load(f)
546 | ids = order[sequence_name][:load_using_order]
547 |
548 | if sequence_name is None:
549 | index = index % len(self.sequence_list)
550 | sequence_name = self.sequence_list[index]
551 | metadata = self.rotations[sequence_name]
552 | category = self.category_map[sequence_name]
553 |
554 | # Read image & camera information from annotations
555 | annos = [metadata[i] for i in ids]
556 | images = []
557 | image_sizes = []
558 | PP = []
559 | FL = []
560 | crop_parameters = []
561 | filenames = []
562 | distortion_parameters = []
563 | depths = []
564 | depth_masks = []
565 | object_masks = []
566 | dino_images = []
567 | for anno in annos:
568 | filepath = anno["filepath"]
569 |
570 | if not no_images:
571 | image = Image.open(osp.join(self.co3d_dir, filepath)).convert("RGB")
572 | image_size = image.size
573 |
574 | # Optionally mask images with black background
575 | if self.mask_images:
576 | black_image = Image.new("RGB", image_size, (0, 0, 0))
577 | mask_name = osp.basename(filepath.replace(".jpg", ".png"))
578 |
579 | mask_path = osp.join(
580 | self.co3d_dir, category, sequence_name, "masks", mask_name
581 | )
582 | mask = Image.open(mask_path).convert("L")
583 |
584 | if mask.size != image_size:
585 | mask = mask.resize(image_size)
586 | mask = Image.fromarray(np.array(mask) > 125)
587 | image = Image.composite(image, black_image, mask)
588 |
589 | if self.object_mask:
590 | mask_name = osp.basename(filepath.replace(".jpg", ".png"))
591 | mask_path = osp.join(
592 | self.co3d_dir, category, sequence_name, "masks", mask_name
593 | )
594 | mask = Image.open(mask_path).convert("L")
595 |
596 | if mask.size != image_size:
597 | mask = mask.resize(image_size)
598 | mask = torch.from_numpy(np.array(mask) > 125)
599 |
600 | # Determine crop, Resnet wants square images
601 | bbox = np.array(anno["bbox"])
602 | good_bbox = ((bbox[2:] - bbox[:2]) > 30).all()
603 | bbox = (
604 | anno["bbox"]
605 | if not self.center_crop and good_bbox
606 | else [0, 0, image.width, image.height]
607 | )
608 |
609 | # Distort image and bbox if desired
610 | if self.distort:
611 | k1 = random.uniform(0, self.k1_max)
612 | k2 = random.uniform(0, self.k2_max)
613 |
614 | try:
615 | image, bbox = distort_image(
616 | image, np.array(bbox), k1, k2, modify_bbox=True
617 | )
618 |
619 | except Exception:
620 | print("INFO:")
621 | print(sequence_name)
622 | print(index)
623 | print(ids)
624 | print(k1)
625 | print(k2)
626 |
627 | distortion_parameters.append(torch.FloatTensor([k1, k2]))
628 |
629 | bbox = square_bbox(np.array(bbox), tight=self.center_crop)
630 | if self.apply_augmentation:
631 | bbox = jitter_bbox(
632 | bbox,
633 | jitter_scale=self.jitter_scale,
634 | jitter_trans=self.jitter_trans,
635 | direction_from_size=image.size if self.center_crop else None,
636 | )
637 | bbox = np.around(bbox).astype(int)
638 |
639 | # Crop parameters
640 | crop_center = (bbox[:2] + bbox[2:]) / 2
641 | principal_point = torch.tensor(anno["principal_point"])
642 | focal_length = torch.tensor(anno["focal_length"])
643 |
644 | # convert crop center to correspond to a "square" image
645 | width, height = image.size
646 | length = max(width, height)
647 | s = length / min(width, height)
648 | crop_center = crop_center + (length - np.array([width, height])) / 2
649 |
650 | # convert to NDC
651 | cc = s - 2 * s * crop_center / length
652 | crop_width = 2 * s * (bbox[2] - bbox[0]) / length
653 | crop_params = torch.tensor([-cc[0], -cc[1], crop_width, s])
654 |
655 | # Crop and normalize image
656 | if not self.precropped_images:
657 | image = self._crop_image(image, bbox)
658 |
659 | try:
660 | image = self.transform(image)
661 | except Exception as e:
662 | print("INFO:")
663 | print(sequence_name)
664 | print(index)
665 | print(ids)
666 | # k1/k2 are only defined when self.distort is True, so report the exception instead
667 | print(e)
668 |
669 | images.append(image[:, : self.img_size, : self.img_size])
670 | crop_parameters.append(crop_params)
671 |
672 | if self.load_depths:
673 | # Open depth map
674 | depth_name = osp.basename(
675 | filepath.replace(".jpg", ".jpg.geometric.png")
676 | )
677 | depth_path = osp.join(
678 | self.co3d_depth_dir,
679 | category,
680 | sequence_name,
681 | "depths",
682 | depth_name,
683 | )
684 | depth_pil = Image.open(depth_path)
685 |
686 | # 16 bit float type casting
687 | depth = torch.tensor(
688 | np.frombuffer(
689 | np.array(depth_pil, dtype=np.uint16), dtype=np.float16
690 | )
691 | .astype(np.float32)
692 | .reshape((depth_pil.size[1], depth_pil.size[0]))
693 | )
694 |
695 | # Crop and resize as with images
696 | if depth_pil.size != image_size:
697 | # bbox may have the wrong scale
698 | bbox = depth_pil.size[0] * bbox / image_size[0]
699 |
700 | if self.object_mask:
701 | assert mask.shape == depth.shape
702 |
703 | bbox = np.around(bbox).astype(int)
704 | depth = self._crop_image(depth, bbox)
705 |
706 | # Resize
707 | depth = self.transform_depth(depth.unsqueeze(0))[
708 | 0, : self.depth_size, : self.depth_size
709 | ]
710 | depths.append(depth)
711 |
712 | if self.object_mask:
713 | mask = self._crop_image(mask, bbox)
714 | mask = self.transform_depth(mask.unsqueeze(0))[
715 | 0, : self.depth_size, : self.depth_size
716 | ]
717 | object_masks.append(mask)
718 |
719 | PP.append(principal_point)
720 | FL.append(focal_length)
721 | image_sizes.append(torch.tensor([self.img_size, self.img_size]))
722 | filenames.append(filepath)
723 |
724 | if not no_images:
725 | if self.load_depths:
726 | depths = torch.stack(depths)
727 |
728 | depth_masks = torch.logical_or(depths <= 0, depths.isinf())
729 | depth_masks = (~depth_masks).long()
730 |
731 | if self.object_mask:
732 | object_masks = torch.stack(object_masks, dim=0)
733 |
734 | if self.mask_holes:
735 | depths = fill_depths(depths, depth_masks == 0)
736 |
737 | # Sometimes mask_holes misses stuff
738 | new_masks = torch.logical_or(depths <= 0, depths.isinf())
739 | new_masks = (~new_masks).long()
740 | depths[new_masks == 0] = -1
741 |
742 | assert torch.logical_or(depths > 0, depths == -1).all()
743 | assert not (depths.isinf()).any()
744 | assert not (depths.isnan()).any()
745 |
746 | if self.load_extra_cameras:
747 | # Remove the extra loaded image, for saving space
748 | images = images[: self.num_images]
749 |
750 | if self.distort:
751 | distortion_parameters = torch.stack(distortion_parameters)
752 |
753 | images = torch.stack(images)
754 | crop_parameters = torch.stack(crop_parameters)
755 | focal_lengths = torch.stack(FL)
756 | principal_points = torch.stack(PP)
757 | image_sizes = torch.stack(image_sizes)
758 | else:
759 | images = None
760 | crop_parameters = None
761 | distortion_parameters = None
762 | focal_lengths = []
763 | principal_points = []
764 | image_sizes = []
765 |
766 | # Assemble batch info to send back
767 | R = torch.stack([torch.tensor(anno["R"]) for anno in annos])
768 | T = torch.stack([torch.tensor(anno["T"]) for anno in annos])
769 |
770 | batch = {
771 | "model_id": sequence_name,
772 | "category": category,
773 | "n": len(metadata),
774 | "num_valid_frames": num_valid_frames,
775 | "ind": torch.tensor(ids),
776 | "image": images,
777 | "depth": depths,
778 | "depth_masks": depth_masks,
779 | "object_masks": object_masks,
780 | "R": R,
781 | "T": T,
782 | "focal_length": focal_lengths,
783 | "principal_point": principal_points,
784 | "image_size": image_sizes,
785 | "crop_parameters": crop_parameters,
786 | "distortion_parameters": torch.zeros(4),
787 | "filename": filenames,
788 | "category": category,
789 | "dataset": "co3d",
790 | }
791 |
792 | return batch
793 |
--------------------------------------------------------------------------------
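
For reference, a minimal usage sketch of `Co3dDataset` (arguments mirror those used in `diffusionsfm/eval/eval_category.py`; it assumes `CO3D_DIR` at the top of the file points to a local CO3D download with the matching annotation files):

```python
# Illustrative usage of Co3dDataset; requires the CO3D data referenced by CO3D_DIR.
from diffusionsfm.dataset.co3d_v2 import Co3dDataset

dataset = Co3dDataset(
    category="hydrant",
    split="test",
    num_images=8,
    apply_augmentation=False,
    load_depths=True,
    center_crop=True,
    mask_holes=True,
)
batch = dataset[0]
print(batch["image"].shape)   # (num_images, 3, 224, 224)
print(batch["depth"].shape)   # (num_images, depth_size, depth_size)
print(batch["R"].shape, batch["T"].shape, batch["crop_parameters"].shape)
```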
/diffusionsfm/dataset/custom.py:
--------------------------------------------------------------------------------
1 |
2 | import torch
3 | import numpy as np
4 | import matplotlib.pyplot as plt
5 |
6 | from PIL import Image, ImageOps
7 | from torch.utils.data import Dataset
8 | from torchvision import transforms
9 |
10 | from diffusionsfm.dataset.co3d_v2 import square_bbox
11 |
12 |
13 | class CustomDataset(Dataset):
14 | def __init__(
15 | self,
16 | image_list,
17 | ):
18 | self.images = []
19 |
20 | for image_path in sorted(image_list):
21 | img = Image.open(image_path)
22 | img = ImageOps.exif_transpose(img).convert("RGB") # Apply EXIF rotation
23 | self.images.append(img)
24 |
25 | self.n = len(self.images)
26 | self.jitter_scale = [1, 1]
27 | self.jitter_trans = [0, 0]
28 | self.transform = transforms.Compose(
29 | [
30 | transforms.ToTensor(),
31 | transforms.Resize(224),
32 | transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
33 | ]
34 | )
35 | self.transform_for_vis = transforms.Compose(
36 | [
37 | transforms.Resize(224),
38 | ]
39 | )
40 |
41 | def __len__(self):
42 | return 1
43 |
44 | def _crop_image(self, image, bbox, white_bg=False):
45 | if white_bg:
46 | # Only support PIL Images
47 | image_crop = Image.new(
48 | "RGB", (bbox[2] - bbox[0], bbox[3] - bbox[1]), (255, 255, 255)
49 | )
50 | image_crop.paste(image, (-bbox[0], -bbox[1]))
51 | else:
52 | image_crop = transforms.functional.crop(
53 | image,
54 | top=bbox[1],
55 | left=bbox[0],
56 | height=bbox[3] - bbox[1],
57 | width=bbox[2] - bbox[0],
58 | )
59 | return image_crop
60 |
61 | def __getitem__(self, index=0):  # index is unused; the dataset holds a single multi-view example
62 | return self.get_data()
63 |
64 | def get_data(self):
65 | cmap = plt.get_cmap("hsv")
66 | ids = [i for i in range(len(self.images))]
67 | images = [self.images[i] for i in ids]
68 | images_transformed = []
69 | images_for_vis = []
70 | crop_parameters = []
71 |
72 | for i, image in enumerate(images):
73 | bbox = np.array([0, 0, image.width, image.height])
74 | bbox = square_bbox(bbox, tight=True)
75 | bbox = np.around(bbox).astype(int)
76 | image = self._crop_image(image, bbox)
77 | images_transformed.append(self.transform(image))
78 | image_for_vis = self.transform_for_vis(image)
79 | color_float = cmap(i / len(images))
80 | color_rgb = tuple(int(255 * c) for c in color_float[:3])
81 | image_for_vis = ImageOps.expand(image_for_vis, border=3, fill=color_rgb)
82 | images_for_vis.append(image_for_vis)
83 |
84 | width, height = image.size
85 | length = max(width, height)
86 | s = length / min(width, height)
87 | crop_center = (bbox[:2] + bbox[2:]) / 2
88 | crop_center = crop_center + (length - np.array([width, height])) / 2
89 | # convert to NDC
90 | cc = s - 2 * s * crop_center / length
91 | crop_width = 2 * s * (bbox[2] - bbox[0]) / length
92 | crop_params = torch.tensor([-cc[0], -cc[1], crop_width, s])
93 |
94 | crop_parameters.append(crop_params)
95 | images = images_transformed
96 |
97 | batch = {}
98 | batch["image"] = torch.stack(images)
99 | batch["image_for_vis"] = images_for_vis
100 | batch["n"] = len(images)
101 | batch["ind"] = torch.tensor(ids),
102 | batch["crop_parameters"] = torch.stack(crop_parameters)
103 | batch["distortion_parameters"] = torch.zeros(4)
104 |
105 | return batch
106 |
--------------------------------------------------------------------------------
/diffusionsfm/eval/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/QitaoZhao/DiffusionSfM/4bb08800721bdcf46b0c823586a2fab4702ff282/diffusionsfm/eval/__init__.py
--------------------------------------------------------------------------------
/diffusionsfm/eval/eval_category.py:
--------------------------------------------------------------------------------
1 | import os
2 | import json
3 | import torch
4 | import torchvision
5 | import numpy as np
6 | from tqdm.auto import tqdm
7 |
8 | from diffusionsfm.dataset.co3d_v2 import (
9 | Co3dDataset,
10 | full_scene_scale,
11 | )
12 | from pytorch3d.renderer import PerspectiveCameras
13 | from diffusionsfm.utils.visualization import filter_and_align_point_clouds
14 | from diffusionsfm.inference.load_model import load_model
15 | from diffusionsfm.inference.predict import predict_cameras
16 | from diffusionsfm.utils.geometry import (
17 | compute_angular_error_batch,
18 | get_error,
19 | n_to_np_rotations,
20 | )
21 | from diffusionsfm.utils.slurm import init_slurm_signals_if_slurm
22 | from diffusionsfm.utils.rays import cameras_to_rays
23 | from diffusionsfm.utils.rays import normalize_cameras_batch
24 |
25 |
26 | @torch.no_grad()
27 | def evaluate(
28 | cfg,
29 | model,
30 | dataset,
31 | num_images,
32 | device,
33 | use_pbar=True,
34 | calculate_intrinsics=True,
35 | additional_timesteps=(),
36 | num_evaluate=None,
37 | max_num_images=None,
38 | mode=None,
39 | metrics=True,
40 | load_depth=True,
41 | ):
42 | if cfg.training.get("dpt_head", False):
43 | H_in = W_in = 224
44 | H_out = W_out = cfg.training.full_num_patches_y
45 | else:
46 | H_in = H_out = cfg.model.num_patches_x
47 | W_in = W_out = cfg.model.num_patches_y
48 |
49 | results = {}
50 | instances = np.arange(0, len(dataset)) if num_evaluate is None else np.linspace(0, len(dataset) - 1, num_evaluate, endpoint=True, dtype=int)
51 | instances = tqdm(instances) if use_pbar else instances
52 |
53 | for counter, idx in enumerate(instances):
54 | batch = dataset[idx]
55 | instance = batch["model_id"]
56 | images = batch["image"].to(device)
57 | focal_length = batch["focal_length"].to(device)[:num_images]
58 | R = batch["R"].to(device)[:num_images]
59 | T = batch["T"].to(device)[:num_images]
60 | crop_parameters = batch["crop_parameters"].to(device)[:num_images]
61 |
62 | if load_depth:
63 | depths = batch["depth"].to(device)[:num_images]
64 | depth_masks = batch["depth_masks"].to(device)[:num_images]
65 | try:
66 | object_masks = batch["object_masks"].to(device)[:num_images]
67 | except KeyError:
68 | object_masks = depth_masks.clone()
69 |
70 | # Normalize cameras and scale depths for output resolution
71 | cameras_gt = PerspectiveCameras(
72 | R=R, T=T, focal_length=focal_length, device=device
73 | )
74 | cameras_gt, _, _ = normalize_cameras_batch(
75 | [cameras_gt],
76 | first_cam_mediod=cfg.training.first_cam_mediod,
77 | normalize_first_camera=cfg.training.normalize_first_camera,
78 | depths=depths.unsqueeze(0),
79 | crop_parameters=crop_parameters.unsqueeze(0),
80 | num_patches_x=H_in,
81 | num_patches_y=W_in,
82 | return_scales=True,
83 | )
84 | cameras_gt = cameras_gt[0]
85 |
86 | gt_rays = cameras_to_rays(
87 | cameras=cameras_gt,
88 | num_patches_x=H_in,
89 | num_patches_y=W_in,
90 | crop_parameters=crop_parameters,
91 | depths=depths,
92 | mode=mode,
93 | )
94 | gt_points = gt_rays.get_segments().view(num_images, -1, 3)
95 |
96 | resize = torchvision.transforms.Resize(
97 | 224,
98 | antialias=False,
99 | interpolation=torchvision.transforms.InterpolationMode.NEAREST_EXACT,
100 | )
101 | else:
102 | cameras_gt = PerspectiveCameras(
103 | R=R, T=T, focal_length=focal_length, device=device
104 | )
105 |
106 | pred_cameras, additional_cams = predict_cameras(
107 | model,
108 | images,
109 | device,
110 | crop_parameters=crop_parameters,
111 | num_patches_x=H_out,
112 | num_patches_y=W_out,
113 | max_num_images=max_num_images,
114 | additional_timesteps=additional_timesteps,
115 | calculate_intrinsics=calculate_intrinsics,
116 | mode=mode,
117 | return_rays=True,
118 | use_homogeneous=cfg.model.get("use_homogeneous", False),
119 | )
120 | cameras_to_evaluate = additional_cams + [pred_cameras]
121 |
122 | all_cams_batch = dataset.get_data(
123 | sequence_name=instance, ids=np.arange(0, batch["n"]), no_images=True
124 | )
125 | gt_scene_scale = full_scene_scale(all_cams_batch)
126 | R_gt = R
127 | T_gt = T
128 |
129 | errors = []
130 | for _, (camera, pred_rays) in enumerate(cameras_to_evaluate):
131 | R_pred = camera.R
132 | T_pred = camera.T
133 | f_pred = camera.focal_length
134 |
135 | R_pred_rel = n_to_np_rotations(num_images, R_pred).cpu().numpy()
136 | R_gt_rel = n_to_np_rotations(num_images, batch["R"]).cpu().numpy()
137 | R_error = compute_angular_error_batch(R_pred_rel, R_gt_rel)
138 |
139 | CC_error, _ = get_error(True, R_pred, T_pred, R_gt, T_gt, gt_scene_scale)
140 |
141 | if load_depth and metrics:
142 | # Evaluate outputs at the same resolution as DUSt3R
143 | pred_points = pred_rays.get_segments().view(num_images, H_out, H_out, 3)
144 | pred_points = pred_points.permute(0, 3, 1, 2)
145 | pred_points = resize(pred_points).permute(0, 2, 3, 1).view(num_images, H_in*W_in, 3)
146 |
147 | (
148 | _,
149 | _,
150 | _,
151 | _,
152 | metric_values,
153 | ) = filter_and_align_point_clouds(
154 | num_images,
155 | gt_points,
156 | pred_points,
157 | depth_masks,
158 | depth_masks,
159 | images,
160 | metrics=metrics,
161 | num_patches_x=H_in,
162 | )
163 |
164 | (
165 | _,
166 | _,
167 | _,
168 | _,
169 | object_metric_values,
170 | ) = filter_and_align_point_clouds(
171 | num_images,
172 | gt_points,
173 | pred_points,
174 | depth_masks * object_masks,
175 | depth_masks * object_masks,
176 | images,
177 | metrics=metrics,
178 | num_patches_x=H_in,
179 | )
180 |
181 | result = {
182 | "R_pred": R_pred.detach().cpu().numpy().tolist(),
183 | "T_pred": T_pred.detach().cpu().numpy().tolist(),
184 | "f_pred": f_pred.detach().cpu().numpy().tolist(),
185 | "R_gt": R_gt.detach().cpu().numpy().tolist(),
186 | "T_gt": T_gt.detach().cpu().numpy().tolist(),
187 | "f_gt": focal_length.detach().cpu().numpy().tolist(),
188 | "scene_scale": gt_scene_scale,
189 | "R_error": R_error.tolist(),
190 | "CC_error": CC_error,
191 | }
192 |
193 | if load_depth and metrics:
194 | result["CD"] = metric_values[1]
195 | result["CD_Object"] = object_metric_values[1]
196 | else:
197 | result["CD"] = 0
198 | result["CD_Object"] = 0
199 |
200 | errors.append(result)
201 |
202 | results[instance] = errors
203 |
204 | if counter == len(dataset) - 1:
205 | break
206 | return results
207 |
208 |
209 | def save_results(
210 | output_dir,
211 | checkpoint=800_000,
212 | category="hydrant",
213 | num_images=None,
214 | calculate_additional_timesteps=True,
215 | calculate_intrinsics=True,
216 | split="test",
217 | force=False,
218 | sample_num=1,
219 | max_num_images=None,
220 | dataset="co3d",
221 | ):
222 | init_slurm_signals_if_slurm()
223 | os.umask(000) # Default to 777 permissions
224 | eval_path = os.path.join(
225 | output_dir,
226 | f"eval_{dataset}",
227 | f"{category}_{num_images}_{sample_num}_ckpt{checkpoint}.json",
228 | )
229 |
230 | if os.path.exists(eval_path) and not force:
231 | print(f"File {eval_path} already exists. Skipping.")
232 | return
233 |
234 | if num_images is not None and num_images > 8:
235 | custom_keys = {"model.num_images": num_images}
236 | ignore_keys = ["pos_table"]
237 | else:
238 | custom_keys = None
239 | ignore_keys = []
240 |
241 | device = torch.device("cuda")
242 | model, cfg = load_model(
243 | output_dir,
244 | checkpoint=checkpoint,
245 | device=device,
246 | custom_keys=custom_keys,
247 | ignore_keys=ignore_keys,
248 | )
249 | if num_images is None:
250 | num_images = cfg.dataset.num_images
251 |
252 | if cfg.training.dpt_head:
253 | # Evaluate outputs at the same resolution as DUSt3R
254 | depth_size = 224
255 | else:
256 | depth_size = cfg.model.num_patches_x
257 |
258 | dataset = Co3dDataset(
259 | category=category,
260 | split=split,
261 | num_images=num_images,
262 | apply_augmentation=False,
263 | sample_num=None if split == "train" else sample_num,
264 | use_global_intrinsics=cfg.dataset.use_global_intrinsics,
265 | load_depths=True,
266 | center_crop=True,
267 | depth_size=depth_size,
268 | mask_holes=not cfg.training.regression,
269 | img_size=256 if cfg.model.unet_diffuser else 224,
270 | )
271 | print(f"Category {category} {len(dataset)}")
272 |
273 | if calculate_additional_timesteps:
274 | additional_timesteps = [0, 10, 20, 30, 40, 50, 60, 70, 80, 90]
275 | else:
276 | additional_timesteps = []
277 |
278 | results = evaluate(
279 | cfg=cfg,
280 | model=model,
281 | dataset=dataset,
282 | num_images=num_images,
283 | device=device,
284 | calculate_intrinsics=calculate_intrinsics,
285 | additional_timesteps=additional_timesteps,
286 | max_num_images=max_num_images,
287 | mode="segment",
288 | )
289 |
290 | os.makedirs(os.path.dirname(eval_path), exist_ok=True)
291 | with open(eval_path, "w") as f:
292 | json.dump(results, f)
--------------------------------------------------------------------------------
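For reference, a minimal way to invoke `save_results` directly for a single category (the output directory, checkpoint id, and category below are assumptions; `eval_jobs.py` further down sweeps these parameters automatically):

```python
from diffusionsfm.eval.eval_category import save_results

# Hypothetical paths/settings; adjust to your own trained run.
save_results(
    output_dir="output/multi_diffusionsfm_dense",  # must contain checkpoints/ and hydra/config.yaml
    checkpoint=800_000,
    category="hydrant",
    num_images=8,
    sample_num=0,
    calculate_additional_timesteps=True,
    dataset="co3d",
)
# Results are written to output/multi_diffusionsfm_dense/eval_co3d/hydrant_8_0_ckpt800000.json
```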
/diffusionsfm/eval/eval_jobs.py:
--------------------------------------------------------------------------------
1 | """
2 | python -m diffusionsfm.eval.eval_jobs --eval_path output/multi_diffusionsfm_dense --use_submitit
3 | """
4 |
5 | import os
6 | import json
7 | import submitit
8 | import argparse
9 | import itertools
10 | from glob import glob
11 |
12 | import numpy as np
13 | from tqdm.auto import tqdm
14 |
15 | from diffusionsfm.dataset.co3d_v2 import TEST_CATEGORIES, TRAINING_CATEGORIES
16 | from diffusionsfm.eval.eval_category import save_results
17 | from diffusionsfm.utils.slurm import submitit_job_watcher
18 |
19 |
20 | def evaluate_diffusionsfm(eval_path, use_submitit, mode):
21 | JOB_PARAMS = {
22 | "output_dir": [eval_path],
23 | "checkpoint": [800_000],
24 | "num_images": [2, 3, 4, 5, 6, 7, 8],
25 | "sample_num": [0, 1, 2, 3, 4],
26 | "category": TEST_CATEGORIES, # TRAINING_CATEGORIES + TEST_CATEGORIES,
27 | "calculate_additional_timesteps": [True],
28 | }
29 | if mode == "test":
30 | JOB_PARAMS["category"] = TEST_CATEGORIES
31 | elif mode == "train1":
32 | JOB_PARAMS["category"] = TRAINING_CATEGORIES[:len(TRAINING_CATEGORIES) // 2]
33 | elif mode == "train2":
34 | JOB_PARAMS["category"] = TRAINING_CATEGORIES[len(TRAINING_CATEGORIES) // 2:]
35 | keys, values = zip(*JOB_PARAMS.items())
36 | job_configs = [dict(zip(keys, p)) for p in itertools.product(*values)]
37 |
38 | if use_submitit:
39 | log_output = "./slurm_logs"
40 | executor = submitit.AutoExecutor(
41 | cluster=None, folder=log_output, slurm_max_num_timeout=10
42 | )
43 | # Use your own parameters
44 | executor.update_parameters(
45 | slurm_additional_parameters={
46 | "nodes": 1,
47 | "cpus-per-task": 5,
48 | "gpus": 1,
49 | "time": "6:00:00",
50 | "partition": "all",
51 | "exclude": "grogu-1-9, grogu-1-14,"
52 | }
53 | )
54 | jobs = []
55 | with executor.batch():
56 | # This context manager submits all jobs at once at the end.
57 | for params in job_configs:
58 | job = executor.submit(save_results, **params)
59 | job_param = f"{params['category']}_N{params['num_images']}_{params['sample_num']}"
60 | jobs.append((job_param, job))
61 | jobs = {f"{job_param}_{job.job_id}": job for job_param, job in jobs}
62 | submitit_job_watcher(jobs)
63 | else:
64 | for job_config in tqdm(job_configs):
65 | # This is much slower.
66 | save_results(**job_config)
67 |
68 |
69 | def process_predictions(eval_path, pred_index, checkpoint=800_000, threshold_R=15, threshold_CC=0.1):
70 | """
71 | pred_index should be 1 (corresponding to T=90)
72 | """
73 | def aggregate_per_category(categories, metric_key, num_images, sample_num, threshold=None):
74 | """
75 | Aggregates one metric over all data points in a prediction file and then across categories.
76 |         - For R_error and CC_error: report threshold-based accuracy (mean of errors below the threshold)
77 | - For CD and CD_Object: use median to reduce the effect of outliers
78 | """
79 | per_category_values = []
80 |
81 | for category in tqdm(categories, desc=f"Sample {sample_num}, N={num_images}, {metric_key}"):
82 | per_pred_values = []
83 |
84 | data_path = glob(
85 | os.path.join(eval_path, "eval", f"{category}_{num_images}_{sample_num}_ckpt{checkpoint}*.json")
86 | )[0]
87 |
88 | with open(data_path) as f:
89 | eval_data = json.load(f)
90 |
91 | for preds in eval_data.values():
92 | if metric_key in ["R_error", "CC_error"]:
93 | vals = np.array(preds[pred_index][metric_key])
94 | per_pred_values.append(np.mean(vals < threshold))
95 | else:
96 | per_pred_values.append(preds[pred_index][metric_key])
97 |
98 | # Aggregate over all predictions within this category
99 | per_category_values.append(
100 | np.mean(per_pred_values) if metric_key in ["R_error", "CC_error"]
101 | else np.median(per_pred_values) # CD or CD_Object — use median to filter outliers
102 | )
103 |
104 | if metric_key in ["R_error", "CC_error"]:
105 | return np.mean(per_category_values)
106 | else:
107 | return np.median(per_category_values)
108 |
109 | def aggregate_metric(categories, metric_key, num_images, threshold=None):
110 | """Aggregates one metric over 5 random samples per category and returns the final mean"""
111 | return np.mean([
112 | aggregate_per_category(categories, metric_key, num_images, sample_num, threshold=threshold)
113 | for sample_num in range(5)
114 | ])
115 |
116 | # Output containers
117 | all_seen_acc_R, all_seen_acc_CC = [], []
118 | all_seen_CD, all_seen_CD_Object = [], []
119 | all_unseen_acc_R, all_unseen_acc_CC = [], []
120 | all_unseen_CD, all_unseen_CD_Object = [], []
121 |
122 | for num_images in range(2, 9):
123 | # Seen categories
124 | all_seen_acc_R.append(
125 | aggregate_metric(TRAINING_CATEGORIES, "R_error", num_images, threshold=threshold_R)
126 | )
127 | all_seen_acc_CC.append(
128 | aggregate_metric(TRAINING_CATEGORIES, "CC_error", num_images, threshold=threshold_CC)
129 | )
130 | all_seen_CD.append(
131 | aggregate_metric(TRAINING_CATEGORIES, "CD", num_images)
132 | )
133 | all_seen_CD_Object.append(
134 | aggregate_metric(TRAINING_CATEGORIES, "CD_Object", num_images)
135 | )
136 |
137 | # Unseen categories
138 | all_unseen_acc_R.append(
139 | aggregate_metric(TEST_CATEGORIES, "R_error", num_images, threshold=threshold_R)
140 | )
141 | all_unseen_acc_CC.append(
142 | aggregate_metric(TEST_CATEGORIES, "CC_error", num_images, threshold=threshold_CC)
143 | )
144 | all_unseen_CD.append(
145 | aggregate_metric(TEST_CATEGORIES, "CD", num_images)
146 | )
147 | all_unseen_CD_Object.append(
148 | aggregate_metric(TEST_CATEGORIES, "CD_Object", num_images)
149 | )
150 |
151 | # Print the results in formatted rows
152 | print("N= ", " ".join(f"{i: 5}" for i in range(2, 9)))
153 | print("Seen R ", " ".join([f"{x:0.3f}" for x in all_seen_acc_R]))
154 | print("Seen CC ", " ".join([f"{x:0.3f}" for x in all_seen_acc_CC]))
155 | print("Seen CD ", " ".join([f"{x:0.3f}" for x in all_seen_CD]))
156 | print("Seen CD_Obj ", " ".join([f"{x:0.3f}" for x in all_seen_CD_Object]))
157 | print("Unseen R ", " ".join([f"{x:0.3f}" for x in all_unseen_acc_R]))
158 | print("Unseen CC ", " ".join([f"{x:0.3f}" for x in all_unseen_acc_CC]))
159 | print("Unseen CD ", " ".join([f"{x:0.3f}" for x in all_unseen_CD]))
160 | print("Unseen CD_Obj", " ".join([f"{x:0.3f}" for x in all_unseen_CD_Object]))
161 |
162 |
163 | if __name__ == "__main__":
164 | parser = argparse.ArgumentParser()
165 | parser.add_argument("--eval_path", type=str, default=None)
166 | parser.add_argument("--use_submitit", action="store_true")
167 | parser.add_argument("--mode", type=str, default="test")
168 | args = parser.parse_args()
169 |
170 | eval_path = "output/multi_diffusionsfm_dense" if args.eval_path is None else args.eval_path
171 | use_submitit = args.use_submitit
172 | mode = args.mode
173 |
174 | evaluate_diffusionsfm(eval_path, use_submitit, mode)
175 | process_predictions(eval_path, 1)
--------------------------------------------------------------------------------
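A toy numpy sketch of the aggregation convention documented in `aggregate_per_category` above (all numbers are made up): rotation and camera-center errors are turned into threshold accuracies via a mean over indicator values, while Chamfer distances are summarized with a median to suppress outliers.

```python
import numpy as np

R_errors = np.array([3.2, 9.8, 21.5, 5.1])    # per-pair rotation errors in degrees
acc_at_15 = np.mean(R_errors < 15)            # threshold accuracy -> 0.75

chamfer = np.array([0.02, 0.03, 0.9])         # per-instance Chamfer distances; 0.9 is an outlier
robust_cd = np.median(chamfer)                # median -> 0.03, far less sensitive to the outlier
print(acc_at_15, robust_cd)
```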
/diffusionsfm/inference/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/QitaoZhao/DiffusionSfM/4bb08800721bdcf46b0c823586a2fab4702ff282/diffusionsfm/inference/__init__.py
--------------------------------------------------------------------------------
/diffusionsfm/inference/ddim.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import random
3 | import numpy as np
4 | from tqdm.auto import tqdm
5 |
6 | from diffusionsfm.utils.rays import compute_ndc_coordinates
7 |
8 |
9 | def inference_ddim(
10 | model,
11 | images,
12 | device,
13 | crop_parameters=None,
14 | eta=0,
15 | num_inference_steps=100,
16 | pbar=True,
17 | stop_iteration=None,
18 | num_patches_x=16,
19 | num_patches_y=16,
20 | visualize=False,
21 | max_num_images=8,
22 | seed=0,
23 | ):
24 | """
25 | Implements DDIM-style inference.
26 |
27 | To get multiple samples, batch the images multiple times.
28 |
29 | Args:
30 | model: Ray Diffuser.
31 | images (torch.Tensor): (B, N, C, H, W).
32 |         crop_parameters (torch.Tensor, optional): Crop parameters for each image
33 |             (B, N, 4). (Default: None)
34 | eta (float, optional): Stochasticity coefficient. 0 is completely deterministic,
35 | 1 is equivalent to DDPM. (Default: 0)
36 | num_inference_steps (int, optional): Number of inference steps. (Default: 100)
37 | pbar (bool, optional): Whether to show progress bar. (Default: True)
38 | """
39 | timesteps = model.noise_scheduler.compute_inference_timesteps(num_inference_steps)
40 | batch_size = images.shape[0]
41 | num_images = images.shape[1]
42 |
43 | if isinstance(eta, list):
44 | eta_0, eta_1 = float(eta[0]), float(eta[1])
45 | else:
46 | eta_0, eta_1 = 0, 0
47 |
48 | # Fixing seed
49 | if seed is not None:
50 | torch.manual_seed(seed)
51 | random.seed(seed)
52 | np.random.seed(seed)
53 |
54 | with torch.no_grad():
55 | x_tau = torch.randn(
56 | batch_size,
57 | num_images,
58 | model.ray_out if hasattr(model, "ray_out") else model.ray_dim,
59 | num_patches_x,
60 | num_patches_y,
61 | device=device,
62 | )
63 |
64 | if visualize:
65 | x_taus = [x_tau]
66 | all_pred = []
67 | noise_samples = []
68 |
69 | image_features = model.feature_extractor(images, autoresize=True)
70 |
71 | if model.append_ndc:
72 | ndc_coordinates = compute_ndc_coordinates(
73 | crop_parameters=crop_parameters,
74 | no_crop_param_device="cpu",
75 | num_patches_x=model.width,
76 | num_patches_y=model.width,
77 | distortion_coeffs=None,
78 | )[..., :2].to(device)
79 | ndc_coordinates = ndc_coordinates.permute(0, 1, 4, 2, 3)
80 | else:
81 | ndc_coordinates = None
82 |
83 | if stop_iteration is None:
84 | loop = range(len(timesteps))
85 | else:
86 | loop = range(len(timesteps) - stop_iteration + 1)
87 | loop = tqdm(loop) if pbar else loop
88 |
89 | for t in loop:
90 | tau = timesteps[t]
91 |
92 | if tau > 0 and eta_1 > 0:
93 | z = torch.randn(
94 | batch_size,
95 | num_images,
96 | model.ray_out if hasattr(model, "ray_out") else model.ray_dim,
97 | num_patches_x,
98 | num_patches_y,
99 | device=device,
100 | )
101 | else:
102 | z = 0
103 |
104 | alpha = model.noise_scheduler.alphas_cumprod[tau]
105 | if tau > 0:
106 | tau_prev = timesteps[t + 1]
107 | alpha_prev = model.noise_scheduler.alphas_cumprod[tau_prev]
108 | else:
109 | alpha_prev = torch.tensor(1.0, device=device).float()
110 |
111 | sigma_t = (
112 | torch.sqrt((1 - alpha_prev) / (1 - alpha))
113 | * torch.sqrt(1 - alpha / alpha_prev)
114 | )
115 |
116 | if num_images > max_num_images:
117 | eps_pred = torch.zeros_like(x_tau)
118 | noise_sample = torch.zeros_like(x_tau)
119 |
120 | # Randomly split image indices (excluding index 0), then prepend 0 to each split
121 | indices_split = torch.split(
122 | torch.randperm(num_images - 1) + 1, max_num_images - 1
123 | )
124 |
125 | for indices in indices_split:
126 | indices = torch.cat((torch.tensor([0]), indices)) # Ensure index 0 is always included
127 |
128 | eps_pred_ind, noise_sample_ind = model(
129 | features=image_features[:, indices],
130 | rays_noisy=x_tau[:, indices],
131 | t=int(tau),
132 | ndc_coordinates=ndc_coordinates[:, indices],
133 | indices=indices,
134 | )
135 |
136 | eps_pred[:, indices] += eps_pred_ind
137 |
138 | if noise_sample_ind is not None:
139 | noise_sample[:, indices] += noise_sample_ind
140 |
141 | # Average over splits for the shared reference index (0)
142 | eps_pred[:, 0] /= len(indices_split)
143 | noise_sample[:, 0] /= len(indices_split)
144 | else:
145 | eps_pred, noise_sample = model(
146 | features=image_features,
147 | rays_noisy=x_tau,
148 | t=int(tau),
149 | ndc_coordinates=ndc_coordinates,
150 | )
151 |
152 | if model.use_homogeneous:
153 | p1 = eps_pred[:, :, :4]
154 | p2 = eps_pred[:, :, 4:]
155 |
156 | c1 = torch.linalg.norm(p1, dim=2, keepdim=True)
157 | c2 = torch.linalg.norm(p2, dim=2, keepdim=True)
158 | eps_pred[:, :, :4] = p1 / c1
159 | eps_pred[:, :, 4:] = p2 / c2
160 |
161 | if visualize:
162 | all_pred.append(eps_pred.clone())
163 | noise_samples.append(noise_sample)
164 |
165 | # TODO: Can simplify this a lot
166 | x0_pred = eps_pred.clone()
167 | eps_pred = (x_tau - torch.sqrt(alpha) * eps_pred) / torch.sqrt(
168 | 1 - alpha
169 | )
170 |
171 | dir_x_tau = torch.sqrt(1 - alpha_prev - eta_0*sigma_t**2) * eps_pred
172 | noise = eta_1 * sigma_t * z
173 |
174 | new_x_tau = torch.sqrt(alpha_prev) * x0_pred + dir_x_tau + noise
175 | x_tau = new_x_tau
176 |
177 | if visualize:
178 | x_taus.append(x_tau.detach().clone())
179 | if visualize:
180 | return x_tau, x_taus, all_pred, noise_samples
181 | return x_tau
182 |
--------------------------------------------------------------------------------
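A self-contained sketch of the update rule inside the sampling loop above. Note that the network output is treated as an x0 prediction; the noise estimate is recovered from it before taking the DDIM step. The scalars below are placeholders for `alphas_cumprod[tau]` and `alphas_cumprod[tau_prev]`.

```python
import torch

alpha, alpha_prev = torch.tensor(0.5), torch.tensor(0.7)   # stand-ins for alphas_cumprod values
eta_0, eta_1 = 1.0, 0.0                                    # matches eta=[1, 0] used in predict.py

x_tau = torch.randn(2, 8, 6, 16, 16)                       # (B, N, ray_dim, H, W)
x0_pred = torch.randn_like(x_tau)                          # model output, read as an x0 prediction

# Recover the implied noise, then take one (deterministic, since eta_1 = 0) DDIM step.
eps = (x_tau - torch.sqrt(alpha) * x0_pred) / torch.sqrt(1 - alpha)
sigma = torch.sqrt((1 - alpha_prev) / (1 - alpha)) * torch.sqrt(1 - alpha / alpha_prev)
z = torch.randn_like(x_tau) if eta_1 > 0 else 0.0
x_prev = (
    torch.sqrt(alpha_prev) * x0_pred
    + torch.sqrt(1 - alpha_prev - eta_0 * sigma**2) * eps
    + eta_1 * sigma * z
)
```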
/diffusionsfm/inference/load_model.py:
--------------------------------------------------------------------------------
1 | import os.path as osp
2 | from glob import glob
3 |
4 | import torch
5 | from omegaconf import OmegaConf
6 |
7 | from diffusionsfm.model.diffuser import RayDiffuser
8 | from diffusionsfm.model.diffuser_dpt import RayDiffuserDPT
9 | from diffusionsfm.model.scheduler import NoiseScheduler
10 |
11 |
12 | def load_model(
13 | output_dir, checkpoint=None, device="cuda:0", custom_keys=None, ignore_keys=()
14 | ):
15 | """
16 | Loads a model and config from an output directory.
17 |
18 | E.g. to load with different number of images,
19 | ```
20 | custom_keys={"model.num_images": 15}, ignore_keys=["pos_table"]
21 | ```
22 |
23 | Args:
24 | output_dir (str): Path to the output directory.
25 | checkpoint (str or int): Path to the checkpoint to load. If None, loads the
26 | latest checkpoint.
27 | device (str): Device to load the model on.
28 | custom_keys (dict): Dictionary of custom keys to override in the config.
29 | """
30 | if checkpoint is None:
31 | checkpoint_path = sorted(glob(osp.join(output_dir, "checkpoints", "*.pth")))[-1]
32 | else:
33 | if isinstance(checkpoint, int):
34 | checkpoint_name = f"ckpt_{checkpoint:08d}.pth"
35 | else:
36 | checkpoint_name = checkpoint
37 | checkpoint_path = osp.join(output_dir, "checkpoints", checkpoint_name)
38 | print("Loading checkpoint", osp.basename(checkpoint_path))
39 |
40 | cfg = OmegaConf.load(osp.join(output_dir, "hydra", "config.yaml"))
41 | if custom_keys is not None:
42 | for k, v in custom_keys.items():
43 | OmegaConf.update(cfg, k, v)
44 | noise_scheduler = NoiseScheduler(
45 | type=cfg.noise_scheduler.type,
46 | max_timesteps=cfg.noise_scheduler.max_timesteps,
47 | beta_start=cfg.noise_scheduler.beta_start,
48 | beta_end=cfg.noise_scheduler.beta_end,
49 | )
50 |
51 | if not cfg.training.get("dpt_head", False):
52 | model = RayDiffuser(
53 | depth=cfg.model.depth,
54 | width=cfg.model.num_patches_x,
55 | P=1,
56 | max_num_images=cfg.model.num_images,
57 | noise_scheduler=noise_scheduler,
58 | feature_extractor=cfg.model.feature_extractor,
59 | append_ndc=cfg.model.append_ndc,
60 | diffuse_depths=cfg.training.get("diffuse_depths", False),
61 | depth_resolution=cfg.training.get("depth_resolution", 1),
62 | use_homogeneous=cfg.model.get("use_homogeneous", False),
63 | cond_depth_mask=cfg.model.get("cond_depth_mask", False),
64 | ).to(device)
65 | else:
66 | model = RayDiffuserDPT(
67 | depth=cfg.model.depth,
68 | width=cfg.model.num_patches_x,
69 | P=1,
70 | max_num_images=cfg.model.num_images,
71 | noise_scheduler=noise_scheduler,
72 | feature_extractor=cfg.model.feature_extractor,
73 | append_ndc=cfg.model.append_ndc,
74 | diffuse_depths=cfg.training.get("diffuse_depths", False),
75 | depth_resolution=cfg.training.get("depth_resolution", 1),
76 | encoder_features=cfg.training.get("dpt_encoder_features", False),
77 | use_homogeneous=cfg.model.get("use_homogeneous", False),
78 | cond_depth_mask=cfg.model.get("cond_depth_mask", False),
79 | ).to(device)
80 |
81 | data = torch.load(checkpoint_path)
82 | state_dict = {}
83 | for k, v in data["state_dict"].items():
84 | include = True
85 | for ignore_key in ignore_keys:
86 | if ignore_key in k:
87 | include = False
88 | if include:
89 | state_dict[k] = v
90 |
91 | missing, unexpected = model.load_state_dict(state_dict, strict=False)
92 | if len(missing) > 0:
93 | print("Missing keys:", missing)
94 | if len(unexpected) > 0:
95 | print("Unexpected keys:", unexpected)
96 | model = model.eval()
97 | return model, cfg
98 |
--------------------------------------------------------------------------------
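Usage sketch for `load_model`, mirroring the docstring above. The output directory is an assumption; the `custom_keys`/`ignore_keys` combination shows how to evaluate with more images than the model was trained with, keeping the freshly built positional tables instead of the checkpointed ones.

```python
import torch

from diffusionsfm.inference.load_model import load_model

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Load the latest checkpoint in output_dir/checkpoints/.
model, cfg = load_model("output/multi_diffusionsfm_dense", device=device)

# Load a specific checkpoint, overriding the number of images.
model, cfg = load_model(
    "output/multi_diffusionsfm_dense",
    checkpoint=800_000,
    device=device,
    custom_keys={"model.num_images": 15},
    ignore_keys=["pos_table"],
)
```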
/diffusionsfm/inference/predict.py:
--------------------------------------------------------------------------------
1 | from diffusionsfm.inference.ddim import inference_ddim
2 | from diffusionsfm.utils.rays import (
3 | Rays,
4 | rays_to_cameras,
5 | rays_to_cameras_homography,
6 | )
7 |
8 |
9 | def predict_cameras(
10 | model,
11 | images,
12 | device,
13 | crop_parameters=None,
14 | stop_iteration=None,
15 | num_patches_x=16,
16 | num_patches_y=16,
17 | additional_timesteps=(),
18 | calculate_intrinsics=False,
19 | max_num_images=8,
20 | mode=None,
21 | return_rays=False,
22 | use_homogeneous=False,
23 | seed=0,
24 | ):
25 | """
26 | Args:
27 | images (torch.Tensor): (N, C, H, W)
28 | crop_parameters (torch.Tensor): (N, 4) or None
29 | """
30 | if calculate_intrinsics:
31 | ray_to_cam = rays_to_cameras_homography
32 | else:
33 | ray_to_cam = rays_to_cameras
34 |
35 | get_spatial_rays = Rays.from_spatial
36 |
37 | rays_final, rays_intermediate, pred_intermediate, _ = inference_ddim(
38 | model,
39 | images.unsqueeze(0),
40 | device,
41 | crop_parameters=crop_parameters.unsqueeze(0),
42 | pbar=False,
43 | stop_iteration=stop_iteration,
44 | eta=[1, 0],
45 | num_inference_steps=100,
46 | num_patches_x=num_patches_x,
47 | num_patches_y=num_patches_y,
48 | visualize=True,
49 | max_num_images=max_num_images,
50 | )
51 |
52 | spatial_rays = get_spatial_rays(
53 | rays_final[0],
54 | mode=mode,
55 | num_patches_x=num_patches_x,
56 | num_patches_y=num_patches_y,
57 | use_homogeneous=use_homogeneous,
58 | )
59 |
60 | pred_cam = ray_to_cam(
61 | spatial_rays,
62 | crop_parameters,
63 | num_patches_x=num_patches_x,
64 | num_patches_y=num_patches_y,
65 | depth_resolution=model.depth_resolution,
66 | average_centers=True,
67 | directions_from_averaged_center=True,
68 | )
69 |
70 | additional_predictions = []
71 | for t in additional_timesteps:
72 | ray = pred_intermediate[t]
73 |
74 | ray = get_spatial_rays(
75 | ray[0],
76 | mode=mode,
77 | num_patches_x=num_patches_x,
78 | num_patches_y=num_patches_y,
79 | use_homogeneous=use_homogeneous,
80 | )
81 |
82 | cam = ray_to_cam(
83 | ray,
84 | crop_parameters,
85 | num_patches_x=num_patches_x,
86 | num_patches_y=num_patches_y,
87 | average_centers=True,
88 | directions_from_averaged_center=True,
89 | )
90 | if return_rays:
91 | cam = (cam, ray)
92 | additional_predictions.append(cam)
93 |
94 | if return_rays:
95 | return (pred_cam, spatial_rays), additional_predictions
96 | return pred_cam, additional_predictions, spatial_rays
--------------------------------------------------------------------------------
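A rough end-to-end sketch of calling `predict_cameras`; the random tensors stand in for the normalized image crops and crop parameters a real dataloader would provide, the output directory is an assumption, and remaining arguments (e.g. `mode`) are left at their defaults.

```python
import torch

from diffusionsfm.inference.load_model import load_model
from diffusionsfm.inference.predict import predict_cameras

device = torch.device("cuda")
model, cfg = load_model("output/multi_diffusionsfm_dense", device=device)

N = 4
images = torch.rand(N, 3, 224, 224, device=device)      # placeholder for preprocessed image crops
crop_parameters = torch.zeros(N, 4, device=device)       # placeholder for (N, 4) crop parameters

pred_cam, additional_preds, spatial_rays = predict_cameras(
    model,
    images,
    device,
    crop_parameters=crop_parameters,
    num_patches_x=cfg.model.num_patches_x,                # square patch grid assumed
    num_patches_y=cfg.model.num_patches_x,
    calculate_intrinsics=True,
    use_homogeneous=cfg.model.get("use_homogeneous", False),
)
```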
/diffusionsfm/model/base_model.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 |
4 | class BaseModel(torch.nn.Module):
5 | def load(self, path):
6 | """Load model from file.
7 |
8 | Args:
9 | path (str): file path
10 | """
11 | parameters = torch.load(path, map_location=torch.device("cpu"))
12 |
13 | if "optimizer" in parameters:
14 | parameters = parameters["model"]
15 |
16 | self.load_state_dict(parameters)
17 |
--------------------------------------------------------------------------------
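A small sketch of the two checkpoint layouts `BaseModel.load` accepts: a bare state dict, or a dict that also carries an optimizer, in which case the weights are expected under the "model" key (the file paths and the `Tiny` module are purely illustrative).

```python
import torch
import torch.nn as nn

from diffusionsfm.model.base_model import BaseModel


class Tiny(BaseModel):
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(4, 4)


net = Tiny()
torch.save(net.state_dict(), "/tmp/bare.pth")                               # plain state dict
torch.save({"model": net.state_dict(), "optimizer": {}}, "/tmp/full.pth")   # training-style checkpoint
net.load("/tmp/bare.pth")
net.load("/tmp/full.pth")
```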
/diffusionsfm/model/blocks.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from diffusionsfm.model.dit import TimestepEmbedder
4 | import ipdb
5 |
6 |
7 | def modulate(x, shift, scale):
8 | return x * (1 + scale.unsqueeze(-1).unsqueeze(-1)) + shift.unsqueeze(-1).unsqueeze(
9 | -1
10 | )
11 |
12 |
13 | def _make_fusion_block(features, use_bn, use_ln, dpt_time, resolution):
14 | return FeatureFusionBlock_custom(
15 | features,
16 | nn.ReLU(False),
17 | deconv=False,
18 | bn=use_bn,
19 | expand=False,
20 | align_corners=True,
21 | dpt_time=dpt_time,
22 | ln=use_ln,
23 | resolution=resolution
24 | )
25 |
26 |
27 | def _make_scratch(in_shape, out_shape, groups=1, expand=False):
28 | scratch = nn.Module()
29 |
30 | out_shape1 = out_shape
31 | out_shape2 = out_shape
32 | out_shape3 = out_shape
33 | out_shape4 = out_shape
34 | if expand == True:
35 | out_shape1 = out_shape
36 | out_shape2 = out_shape * 2
37 | out_shape3 = out_shape * 4
38 | out_shape4 = out_shape * 8
39 |
40 | scratch.layer1_rn = nn.Conv2d(
41 | in_shape[0],
42 | out_shape1,
43 | kernel_size=3,
44 | stride=1,
45 | padding=1,
46 | bias=False,
47 | groups=groups,
48 | )
49 | scratch.layer2_rn = nn.Conv2d(
50 | in_shape[1],
51 | out_shape2,
52 | kernel_size=3,
53 | stride=1,
54 | padding=1,
55 | bias=False,
56 | groups=groups,
57 | )
58 | scratch.layer3_rn = nn.Conv2d(
59 | in_shape[2],
60 | out_shape3,
61 | kernel_size=3,
62 | stride=1,
63 | padding=1,
64 | bias=False,
65 | groups=groups,
66 | )
67 | scratch.layer4_rn = nn.Conv2d(
68 | in_shape[3],
69 | out_shape4,
70 | kernel_size=3,
71 | stride=1,
72 | padding=1,
73 | bias=False,
74 | groups=groups,
75 | )
76 |
77 | return scratch
78 |
79 |
80 | class ResidualConvUnit_custom(nn.Module):
81 | """Residual convolution module."""
82 |
83 | def __init__(self, features, activation, bn, ln, dpt_time=False, resolution=16):
84 | """Init.
85 |
86 | Args:
87 | features (int): number of features
88 | """
89 | super().__init__()
90 |
91 | self.bn = bn
92 | self.ln = ln
93 |
94 | self.groups = 1
95 |
96 | self.conv1 = nn.Conv2d(
97 | features,
98 | features,
99 | kernel_size=3,
100 | stride=1,
101 | padding=1,
102 | bias=not self.bn,
103 | groups=self.groups,
104 | )
105 |
106 | self.conv2 = nn.Conv2d(
107 | features,
108 | features,
109 | kernel_size=3,
110 | stride=1,
111 | padding=1,
112 | bias=not self.bn,
113 | groups=self.groups,
114 | )
115 |
116 | nn.init.kaiming_uniform_(self.conv1.weight)
117 | nn.init.kaiming_uniform_(self.conv2.weight)
118 |
119 | if self.bn == True:
120 | self.bn1 = nn.BatchNorm2d(features)
121 | self.bn2 = nn.BatchNorm2d(features)
122 |
123 | if self.ln == True:
124 | self.bn1 = nn.LayerNorm((features, resolution, resolution))
125 | self.bn2 = nn.LayerNorm((features, resolution, resolution))
126 |
127 | self.activation = activation
128 |
129 | if dpt_time:
130 | self.t_embedder = TimestepEmbedder(hidden_size=features)
131 | self.adaLN_modulation = nn.Sequential(
132 | nn.SiLU(), nn.Linear(features, 3 * features, bias=True)
133 | )
134 |
135 | def forward(self, x, t=None):
136 | """Forward pass.
137 |
138 | Args:
139 | x (tensor): input
140 |
141 | Returns:
142 | tensor: output
143 | """
144 | if t is not None:
145 | # Embed timestamp & calculate shift parameters
146 | t = self.t_embedder(t) # (B*N)
147 | shift, scale, gate = self.adaLN_modulation(t).chunk(3, dim=1) # (B * N, T)
148 |
149 | # Shift & scale x
150 | x = modulate(x, shift, scale) # (B * N, T, H, W)
151 |
152 | out = self.activation(x)
153 | out = self.conv1(out)
154 | if self.bn or self.ln:
155 | out = self.bn1(out)
156 |
157 | out = self.activation(out)
158 | out = self.conv2(out)
159 | if self.bn or self.ln:
160 | out = self.bn2(out)
161 |
162 | if self.groups > 1:
163 | out = self.conv_merge(out)
164 |
165 | if t is not None:
166 | out = gate.unsqueeze(-1).unsqueeze(-1) * out
167 |
168 | return out + x
169 |
170 |
171 | class FeatureFusionBlock_custom(nn.Module):
172 | """Feature fusion block."""
173 |
174 | def __init__(
175 | self,
176 | features,
177 | activation,
178 | deconv=False,
179 | bn=False,
180 | ln=False,
181 | expand=False,
182 | align_corners=True,
183 | dpt_time=False,
184 | resolution=16,
185 | ):
186 | """Init.
187 |
188 | Args:
189 | features (int): number of features
190 | """
191 | super(FeatureFusionBlock_custom, self).__init__()
192 |
193 | self.deconv = deconv
194 | self.align_corners = align_corners
195 |
196 | self.groups = 1
197 |
198 | self.expand = expand
199 | out_features = features
200 | if self.expand == True:
201 | out_features = features // 2
202 |
203 | self.out_conv = nn.Conv2d(
204 | features,
205 | out_features,
206 | kernel_size=1,
207 | stride=1,
208 | padding=0,
209 | bias=True,
210 | groups=1,
211 | )
212 |
213 | nn.init.kaiming_uniform_(self.out_conv.weight)
214 |
215 | # The second block sees time
216 | self.resConfUnit1 = ResidualConvUnit_custom(
217 | features, activation, bn=bn, ln=ln, dpt_time=False, resolution=resolution
218 | )
219 | self.resConfUnit2 = ResidualConvUnit_custom(
220 | features, activation, bn=bn, ln=ln, dpt_time=dpt_time, resolution=resolution
221 | )
222 |
223 | def forward(self, input, activation=None, t=None):
224 | """Forward pass.
225 |
226 | Returns:
227 | tensor: output
228 | """
229 | output = input
230 |
231 | if activation is not None:
232 | res = self.resConfUnit1(activation)
233 |
234 | output += res
235 |
236 | output = self.resConfUnit2(output, t)
237 |
238 | output = torch.nn.functional.interpolate(
239 | output.float(),
240 | scale_factor=2,
241 | mode="bilinear",
242 | align_corners=self.align_corners,
243 | )
244 |
245 | output = self.out_conv(output)
246 |
247 | return output
248 |
--------------------------------------------------------------------------------
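Shape sketch of the adaLN-style `modulate` used in the fusion blocks above: per-sample shift/scale vectors from the timestep embedding are broadcast over the spatial dimensions of a (B*N, C, H, W) feature map (numbers are arbitrary).

```python
import torch

from diffusionsfm.model.blocks import modulate

x = torch.randn(6, 256, 16, 16)     # features for B*N = 6 views
shift = torch.randn(6, 256)         # produced by adaLN_modulation(t_embedding)
scale = torch.randn(6, 256)

out = modulate(x, shift, scale)     # x * (1 + scale) + shift, broadcast over H and W
assert out.shape == x.shape
```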
/diffusionsfm/model/diffuser.py:
--------------------------------------------------------------------------------
1 | import ipdb # noqa: F401
2 | import numpy as np
3 | import torch
4 | import torch.nn as nn
5 |
6 | from diffusionsfm.model.dit import DiT
7 | from diffusionsfm.model.feature_extractors import PretrainedVAE, SpatialDino
8 | from diffusionsfm.model.scheduler import NoiseScheduler
9 |
10 | from huggingface_hub import PyTorchModelHubMixin
11 |
12 |
13 | class RayDiffuser(nn.Module, PyTorchModelHubMixin,
14 | repo_url="https://github.com/QitaoZhao/DiffusionSfM",
15 | paper_url="https://huggingface.co/papers/2505.05473",
16 | pipeline_tag="image-to-3d",
17 | license="mit"):
18 | def __init__(
19 | self,
20 | model_type="dit",
21 | depth=8,
22 | width=16,
23 | hidden_size=1152,
24 | P=1,
25 | max_num_images=1,
26 | noise_scheduler=None,
27 | freeze_encoder=True,
28 | feature_extractor="dino",
29 | append_ndc=True,
30 | use_unconditional=False,
31 | diffuse_depths=False,
32 | depth_resolution=1,
33 | use_homogeneous=False,
34 | cond_depth_mask=False,
35 | ):
36 | super().__init__()
37 | if noise_scheduler is None:
38 | self.noise_scheduler = NoiseScheduler()
39 | else:
40 | self.noise_scheduler = noise_scheduler
41 |
42 | self.diffuse_depths = diffuse_depths
43 | self.depth_resolution = depth_resolution
44 | self.use_homogeneous = use_homogeneous
45 |
46 | self.ray_dim = 3
47 | if self.use_homogeneous:
48 | self.ray_dim += 1
49 |
50 | self.ray_dim += self.ray_dim * self.depth_resolution**2
51 |
52 | if self.diffuse_depths:
53 | self.ray_dim += 1
54 |
55 | self.append_ndc = append_ndc
56 | self.width = width
57 |
58 | self.max_num_images = max_num_images
59 | self.model_type = model_type
60 | self.use_unconditional = use_unconditional
61 | self.cond_depth_mask = cond_depth_mask
62 |
63 | if feature_extractor == "dino":
64 | self.feature_extractor = SpatialDino(
65 | freeze_weights=freeze_encoder, num_patches_x=width, num_patches_y=width
66 | )
67 | self.feature_dim = self.feature_extractor.feature_dim
68 | elif feature_extractor == "vae":
69 | self.feature_extractor = PretrainedVAE(
70 | freeze_weights=freeze_encoder, num_patches_x=width, num_patches_y=width
71 | )
72 | self.feature_dim = self.feature_extractor.feature_dim
73 | else:
74 | raise Exception(f"Unknown feature extractor {feature_extractor}")
75 |
76 | if self.use_unconditional:
77 | self.register_parameter(
78 | "null_token", nn.Parameter(torch.randn(self.feature_dim, 1, 1))
79 | )
80 |
81 | self.input_dim = self.feature_dim * 2
82 |
83 | if self.append_ndc:
84 | self.input_dim += 2
85 |
86 | if model_type == "dit":
87 | self.ray_predictor = DiT(
88 | in_channels=self.input_dim,
89 | out_channels=self.ray_dim,
90 | width=width,
91 | depth=depth,
92 | hidden_size=hidden_size,
93 | max_num_images=max_num_images,
94 | P=P,
95 | )
96 |
97 | self.scratch = nn.Module()
98 | self.scratch.input_conv = nn.Linear(self.ray_dim + int(self.cond_depth_mask), self.feature_dim)
99 |
100 | def forward_noise(
101 | self, x, t, epsilon=None, zero_out_mask=None
102 | ):
103 | """
104 | Applies forward diffusion (adds noise) to the input.
105 |
106 | If a mask is provided, the noise is only applied to the masked inputs.
107 | """
108 | t = t.reshape(-1, 1, 1, 1, 1)
109 |
110 | if epsilon is None:
111 | epsilon = torch.randn_like(x)
112 | else:
113 | epsilon = epsilon.reshape(x.shape)
114 |
115 | alpha_bar = self.noise_scheduler.alphas_cumprod[t]
116 | x_noise = torch.sqrt(alpha_bar) * x + torch.sqrt(1 - alpha_bar) * epsilon
117 |
118 | if zero_out_mask is not None and self.cond_depth_mask:
119 | x_noise = x_noise * zero_out_mask
120 |
121 | return x_noise, epsilon
122 |
123 | def forward(
124 | self,
125 | features=None,
126 | images=None,
127 | rays=None,
128 | rays_noisy=None,
129 | t=None,
130 | ndc_coordinates=None,
131 | unconditional_mask=None,
132 | return_dpt_activations=False,
133 | depth_mask=None,
134 | ):
135 | """
136 | Args:
137 | images: (B, N, 3, H, W).
138 | t: (B,).
139 | rays: (B, N, 6, H, W).
140 | rays_noisy: (B, N, 6, H, W).
141 | ndc_coordinates: (B, N, 2, H, W).
142 | unconditional_mask: (B, N) or (B,). Should be 1 for unconditional samples
143 | and 0 else.
144 | """
145 |
146 | if features is None:
147 | # VAE expects 256x256 images while DINO expects 224x224 images.
148 | # Both feature extractors support autoresize=True, but ideally we should
149 | # set this to be false and handle in the dataloader.
150 | features = self.feature_extractor(images, autoresize=True)
151 |
152 | B = features.shape[0]
153 |
154 | if (
155 | unconditional_mask is not None
156 | and self.use_unconditional
157 | ):
158 | null_token = self.null_token.reshape(1, 1, self.feature_dim, 1, 1)
159 | unconditional_mask = unconditional_mask.reshape(B, -1, 1, 1, 1)
160 | features = (
161 | features * (1 - unconditional_mask) + null_token * unconditional_mask
162 | )
163 |
164 | if isinstance(t, int) or isinstance(t, np.int64):
165 | t = torch.ones(1, dtype=int).to(features.device) * t
166 | else:
167 | t = t.reshape(B)
168 |
169 | if rays_noisy is None:
170 | if self.cond_depth_mask:
171 | rays_noisy, epsilon = self.forward_noise(rays, t, zero_out_mask=depth_mask.unsqueeze(2))
172 | else:
173 | rays_noisy, epsilon = self.forward_noise(rays, t)
174 | else:
175 | epsilon = None
176 |
177 | if self.cond_depth_mask:
178 | if depth_mask is None:
179 | depth_mask = torch.ones_like(rays_noisy[:, :, 0])
180 | ray_repr = torch.cat([rays_noisy, depth_mask.unsqueeze(2)], dim=2)
181 | else:
182 | ray_repr = rays_noisy
183 |
184 | ray_repr = ray_repr.permute(0, 1, 3, 4, 2)
185 | ray_repr = self.scratch.input_conv(ray_repr).permute(0, 1, 4, 2, 3).contiguous()
186 |
187 | scene_features = torch.cat([features, ray_repr], dim=2)
188 |
189 | if self.append_ndc:
190 | scene_features = torch.cat([scene_features, ndc_coordinates], dim=2)
191 |
192 | epsilon_pred = self.ray_predictor(
193 | scene_features,
194 | t,
195 | return_dpt_activations=return_dpt_activations,
196 | )
197 |
198 | if return_dpt_activations:
199 | return epsilon_pred, rays_noisy, epsilon
200 |
201 | return epsilon_pred, epsilon
202 |
--------------------------------------------------------------------------------
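A minimal sketch of the forward diffusion applied by `forward_noise` above, x_t = sqrt(alpha_bar) * x + sqrt(1 - alpha_bar) * eps, with the optional mask that zeroes rays at pixels without valid depth. The `alpha_bar` scalar stands in for `noise_scheduler.alphas_cumprod[t]`.

```python
import torch

alpha_bar = torch.tensor(0.37)                      # stand-in for noise_scheduler.alphas_cumprod[t]
x = torch.randn(1, 8, 6, 16, 16)                    # (B, N, ray_dim, H, W) clean ray origins/endpoints
epsilon = torch.randn_like(x)

x_noisy = torch.sqrt(alpha_bar) * x + torch.sqrt(1 - alpha_bar) * epsilon

# With cond_depth_mask enabled, pixels without valid depth are zeroed out as well.
depth_mask = (torch.rand(1, 8, 1, 16, 16) > 0.1).float()
x_noisy = x_noisy * depth_mask
```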
/diffusionsfm/model/diffuser_dpt.py:
--------------------------------------------------------------------------------
1 | import ipdb # noqa: F401
2 | import numpy as np
3 | import torch
4 | import torch.nn as nn
5 |
6 | from diffusionsfm.model.dit import DiT
7 | from diffusionsfm.model.feature_extractors import PretrainedVAE, SpatialDino
8 | from diffusionsfm.model.blocks import _make_fusion_block, _make_scratch
9 | from diffusionsfm.model.scheduler import NoiseScheduler
10 |
11 | from huggingface_hub import PyTorchModelHubMixin
12 |
13 |
14 | # functional implementation
15 | def nearest_neighbor_upsample(x: torch.Tensor, scale_factor: int):
16 | """Upsample {x} (NCHW) by scale factor {scale_factor} using nearest neighbor interpolation."""
17 | s = scale_factor
18 | return (
19 | x.reshape(*x.shape, 1, 1)
20 | .expand(*x.shape, s, s)
21 | .transpose(-2, -3)
22 | .reshape(*x.shape[:2], *(s * hw for hw in x.shape[2:]))
23 | )
24 |
25 |
26 | class ProjectReadout(nn.Module):
27 | def __init__(self, in_features, start_index=1):
28 | super(ProjectReadout, self).__init__()
29 | self.start_index = start_index
30 |
31 | self.project = nn.Sequential(nn.Linear(2 * in_features, in_features), nn.GELU())
32 |
33 | def forward(self, x):
34 | readout = x[:, 0].unsqueeze(1).expand_as(x[:, self.start_index :])
35 | features = torch.cat((x[:, self.start_index :], readout), -1)
36 |
37 | return self.project(features)
38 |
39 |
40 | class RayDiffuserDPT(nn.Module, PyTorchModelHubMixin,
41 | repo_url="https://github.com/QitaoZhao/DiffusionSfM",
42 | paper_url="https://huggingface.co/papers/2505.05473",
43 | pipeline_tag="image-to-3d",
44 | license="mit"):
45 | def __init__(
46 | self,
47 | model_type="dit",
48 | depth=8,
49 | width=16,
50 | hidden_size=1152,
51 | P=1,
52 | max_num_images=1,
53 | noise_scheduler=None,
54 | freeze_encoder=True,
55 | feature_extractor="dino",
56 | append_ndc=True,
57 | use_unconditional=False,
58 | diffuse_depths=False,
59 | depth_resolution=1,
60 | encoder_features=False,
61 | use_homogeneous=False,
62 | freeze_transformer=False,
63 | cond_depth_mask=False,
64 | ):
65 | super().__init__()
66 | if noise_scheduler is None:
67 | self.noise_scheduler = NoiseScheduler()
68 | else:
69 | self.noise_scheduler = noise_scheduler
70 |
71 | self.diffuse_depths = diffuse_depths
72 | self.depth_resolution = depth_resolution
73 | self.use_homogeneous = use_homogeneous
74 |
75 | self.ray_dim = 3
76 |
77 | if self.use_homogeneous:
78 | self.ray_dim += 1
79 | self.ray_dim += self.ray_dim * self.depth_resolution**2
80 |
81 | if self.diffuse_depths:
82 | self.ray_dim += 1
83 |
84 | self.append_ndc = append_ndc
85 | self.width = width
86 |
87 | self.max_num_images = max_num_images
88 | self.model_type = model_type
89 | self.use_unconditional = use_unconditional
90 | self.cond_depth_mask = cond_depth_mask
91 | self.encoder_features = encoder_features
92 |
93 | if feature_extractor == "dino":
94 | self.feature_extractor = SpatialDino(
95 | freeze_weights=freeze_encoder,
96 | num_patches_x=width,
97 | num_patches_y=width,
98 | activation_hooks=self.encoder_features,
99 | )
100 | self.feature_dim = self.feature_extractor.feature_dim
101 | elif feature_extractor == "vae":
102 | self.feature_extractor = PretrainedVAE(
103 | freeze_weights=freeze_encoder, num_patches_x=width, num_patches_y=width
104 | )
105 | self.feature_dim = self.feature_extractor.feature_dim
106 | else:
107 | raise Exception(f"Unknown feature extractor {feature_extractor}")
108 |
109 | if self.use_unconditional:
110 | self.register_parameter(
111 | "null_token", nn.Parameter(torch.randn(self.feature_dim, 1, 1))
112 | )
113 |
114 | self.input_dim = self.feature_dim * 2
115 |
116 | if self.append_ndc:
117 | self.input_dim += 2
118 |
119 | if model_type == "dit":
120 | self.ray_predictor = DiT(
121 | in_channels=self.input_dim,
122 | out_channels=self.ray_dim,
123 | width=width,
124 | depth=depth,
125 | hidden_size=hidden_size,
126 | max_num_images=max_num_images,
127 | P=P,
128 | )
129 |
130 | if freeze_transformer:
131 | for param in self.ray_predictor.parameters():
132 | param.requires_grad = False
133 |
134 | # Fusion blocks
135 | self.f = 256
136 |
137 | if self.encoder_features:
138 | feature_lens = [
139 | self.feature_extractor.feature_dim,
140 | self.feature_extractor.feature_dim,
141 | self.ray_predictor.hidden_size,
142 | self.ray_predictor.hidden_size,
143 | ]
144 | else:
145 | feature_lens = [self.ray_predictor.hidden_size] * 4
146 |
147 | self.scratch = _make_scratch(feature_lens, 256, groups=1, expand=False)
148 | self.scratch.refinenet1 = _make_fusion_block(
149 | self.f, use_bn=False, use_ln=False, dpt_time=True, resolution=128
150 | )
151 | self.scratch.refinenet2 = _make_fusion_block(
152 | self.f, use_bn=False, use_ln=False, dpt_time=True, resolution=64
153 | )
154 | self.scratch.refinenet3 = _make_fusion_block(
155 | self.f, use_bn=False, use_ln=False, dpt_time=True, resolution=32
156 | )
157 | self.scratch.refinenet4 = _make_fusion_block(
158 | self.f, use_bn=False, use_ln=False, dpt_time=True, resolution=16
159 | )
160 |
161 | self.scratch.input_conv = nn.Conv2d(
162 | self.ray_dim + int(self.cond_depth_mask),
163 | self.feature_dim,
164 | kernel_size=16,
165 | stride=16,
166 | padding=0
167 | )
168 |
169 | self.scratch.output_conv = nn.Sequential(
170 | nn.Conv2d(self.f, self.f // 2, kernel_size=3, stride=1, padding=1),
171 | nn.LeakyReLU(),
172 | nn.Conv2d(self.f // 2, 32, kernel_size=3, stride=1, padding=1),
173 | nn.LeakyReLU(),
174 | nn.Conv2d(32, self.ray_dim, kernel_size=1, stride=1, padding=0),
175 | nn.Identity(),
176 | )
177 |
178 | if self.encoder_features:
179 | self.project_opers = nn.ModuleList([
180 | ProjectReadout(in_features=self.feature_extractor.feature_dim),
181 | ProjectReadout(in_features=self.feature_extractor.feature_dim),
182 | ])
183 |
184 | def forward_noise(
185 | self, x, t, epsilon=None, zero_out_mask=None
186 | ):
187 | """
188 | Applies forward diffusion (adds noise) to the input.
189 |
190 | If a mask is provided, the noise is only applied to the masked inputs.
191 | """
192 | t = t.reshape(-1, 1, 1, 1, 1)
193 | if epsilon is None:
194 | epsilon = torch.randn_like(x)
195 | else:
196 | epsilon = epsilon.reshape(x.shape)
197 |
198 | alpha_bar = self.noise_scheduler.alphas_cumprod[t]
199 | x_noise = torch.sqrt(alpha_bar) * x + torch.sqrt(1 - alpha_bar) * epsilon
200 |
201 | if zero_out_mask is not None and self.cond_depth_mask:
202 | x_noise = zero_out_mask * x_noise
203 |
204 | return x_noise, epsilon
205 |
206 | def forward(
207 | self,
208 | features=None,
209 | images=None,
210 | rays=None,
211 | rays_noisy=None,
212 | t=None,
213 | ndc_coordinates=None,
214 | unconditional_mask=None,
215 | encoder_patches=16,
216 | depth_mask=None,
217 | multiview_unconditional=False,
218 | indices=None,
219 | ):
220 | """
221 | Args:
222 | images: (B, N, 3, H, W).
223 | t: (B,).
224 | rays: (B, N, 6, H, W).
225 | rays_noisy: (B, N, 6, H, W).
226 | ndc_coordinates: (B, N, 2, H, W).
227 | unconditional_mask: (B, N) or (B,). Should be 1 for unconditional samples
228 | and 0 else.
229 | """
230 |
231 | if features is None:
232 | # VAE expects 256x256 images while DINO expects 224x224 images.
233 | # Both feature extractors support autoresize=True, but ideally we should
234 | # set this to be false and handle in the dataloader.
235 | features = self.feature_extractor(images, autoresize=True)
236 |
237 | B = features.shape[0]
238 |
239 | if unconditional_mask is not None and self.use_unconditional:
240 | null_token = self.null_token.reshape(1, 1, self.feature_dim, 1, 1)
241 | unconditional_mask = unconditional_mask.reshape(B, -1, 1, 1, 1)
242 | features = (
243 | features * (1 - unconditional_mask) + null_token * unconditional_mask
244 | )
245 |
246 | if isinstance(t, int) or isinstance(t, np.int64):
247 | t = torch.ones(1, dtype=int).to(features.device) * t
248 | else:
249 | t = t.reshape(B)
250 |
251 | if rays_noisy is None:
252 | if self.cond_depth_mask:
253 | rays_noisy, epsilon = self.forward_noise(
254 | rays, t, zero_out_mask=depth_mask.unsqueeze(2)
255 | )
256 | else:
257 | rays_noisy, epsilon = self.forward_noise(
258 | rays, t
259 | )
260 | else:
261 | epsilon = None
262 |
263 | # DOWNSAMPLE RAYS
264 | B, N, C, H, W = rays_noisy.shape
265 |
266 | if self.cond_depth_mask:
267 | if depth_mask is None:
268 | depth_mask = torch.ones_like(rays_noisy[:, :, 0])
269 | ray_repr = torch.cat([rays_noisy, depth_mask.unsqueeze(2)], dim=2)
270 | else:
271 | ray_repr = rays_noisy
272 |
273 | ray_repr = self.scratch.input_conv(ray_repr.reshape(B * N, -1, H, W))
274 | _, CP, HP, WP = ray_repr.shape
275 | ray_repr = ray_repr.reshape(B, N, CP, HP, WP)
276 | scene_features = torch.cat([features, ray_repr], dim=2)
277 |
278 | if self.append_ndc:
279 | scene_features = torch.cat([scene_features, ndc_coordinates], dim=2)
280 |
281 | # DIT FORWARD PASS
282 | activations = self.ray_predictor(
283 | scene_features,
284 | t,
285 | return_dpt_activations=True,
286 | multiview_unconditional=multiview_unconditional,
287 | )
288 |
289 | # PROJECT ENCODER ACTIVATIONS & RESHAPE
290 | if self.encoder_features:
291 | for i in range(2):
292 | name = f"encoder{i+1}"
293 |
294 | if indices is not None:
295 | act = self.feature_extractor.activations[name][indices]
296 | else:
297 | act = self.feature_extractor.activations[name]
298 |
299 | act = self.project_opers[i](act).permute(0, 2, 1)
300 | act = act.reshape(
301 | (
302 | B * N,
303 | self.feature_extractor.feature_dim,
304 | encoder_patches,
305 | encoder_patches,
306 | )
307 | )
308 | activations[i] = act
309 |
310 | # UPSAMPLE ACTIVATIONS
311 | for i, act in enumerate(activations):
312 | k = 3 - i
313 | activations[i] = nearest_neighbor_upsample(act, 2**k)
314 |
315 | # FUSION BLOCKS
316 | layer_1_rn = self.scratch.layer1_rn(activations[0])
317 | layer_2_rn = self.scratch.layer2_rn(activations[1])
318 | layer_3_rn = self.scratch.layer3_rn(activations[2])
319 | layer_4_rn = self.scratch.layer4_rn(activations[3])
320 |
321 | # RESHAPE TIMESTEPS
322 | if t.shape[0] == B:
323 | t = t.unsqueeze(-1).repeat((1, N)).reshape(B * N)
324 | elif t.shape[0] == 1 and B > 1:
325 | t = t.repeat((B * N))
326 | else:
327 | assert False
328 |
329 | path_4 = self.scratch.refinenet4(layer_4_rn, t=t)
330 | path_3 = self.scratch.refinenet3(path_4, activation=layer_3_rn, t=t)
331 | path_2 = self.scratch.refinenet2(path_3, activation=layer_2_rn, t=t)
332 | path_1 = self.scratch.refinenet1(path_2, activation=layer_1_rn, t=t)
333 |
334 | epsilon_pred = self.scratch.output_conv(path_1)
335 | epsilon_pred = epsilon_pred.reshape((B, N, C, H, W))
336 |
337 | return epsilon_pred, epsilon
338 |
--------------------------------------------------------------------------------
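Sanity-check sketch: the reshape/expand trick in `nearest_neighbor_upsample` above produces the same result as nearest-neighbor interpolation.

```python
import torch
import torch.nn.functional as F

from diffusionsfm.model.diffuser_dpt import nearest_neighbor_upsample

x = torch.randn(2, 3, 4, 4)                              # (N, C, H, W)
a = nearest_neighbor_upsample(x, scale_factor=2)         # -> (2, 3, 8, 8)
b = F.interpolate(x, scale_factor=2, mode="nearest")
assert a.shape == (2, 3, 8, 8) and torch.allclose(a, b)
```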
/diffusionsfm/model/dit.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 |
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 | # --------------------------------------------------------
7 | # References:
8 | # GLIDE: https://github.com/openai/glide-text2im
9 | # MAE: https://github.com/facebookresearch/mae/blob/main/models_mae.py
10 | # --------------------------------------------------------
11 |
12 | import math
13 |
14 | import ipdb # noqa: F401
15 | import numpy as np
16 | import torch
17 | import torch.nn as nn
18 | from timm.models.vision_transformer import Attention, Mlp, PatchEmbed
19 | from diffusionsfm.model.memory_efficient_attention import MEAttention
20 |
21 |
22 | def modulate(x, shift, scale):
23 | return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)
24 |
25 |
26 | #################################################################################
27 | # Embedding Layers for Timesteps and Class Labels #
28 | #################################################################################
29 |
30 |
31 | class TimestepEmbedder(nn.Module):
32 | """
33 | Embeds scalar timesteps into vector representations.
34 | """
35 |
36 | def __init__(self, hidden_size, frequency_embedding_size=256):
37 | super().__init__()
38 | self.mlp = nn.Sequential(
39 | nn.Linear(frequency_embedding_size, hidden_size, bias=True),
40 | nn.SiLU(),
41 | nn.Linear(hidden_size, hidden_size, bias=True),
42 | )
43 | self.frequency_embedding_size = frequency_embedding_size
44 |
45 | @staticmethod
46 | def timestep_embedding(t, dim, max_period=10000):
47 | """
48 | Create sinusoidal timestep embeddings.
49 | :param t: a 1-D Tensor of N indices, one per batch element.
50 | These may be fractional.
51 | :param dim: the dimension of the output.
52 | :param max_period: controls the minimum frequency of the embeddings.
53 | :return: an (N, D) Tensor of positional embeddings.
54 | """
55 | # https://github.com/openai/glide-text2im/blob/main/glide_text2im/nn.py
56 | half = dim // 2
57 | freqs = torch.exp(
58 | -math.log(max_period)
59 | * torch.arange(start=0, end=half, dtype=torch.float32)
60 | / half
61 | ).to(device=t.device)
62 | args = t[:, None].float() * freqs[None]
63 | embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
64 | if dim % 2:
65 | embedding = torch.cat(
66 | [embedding, torch.zeros_like(embedding[:, :1])], dim=-1
67 | )
68 | return embedding
69 |
70 | def forward(self, t):
71 | t_freq = self.timestep_embedding(t, self.frequency_embedding_size)
72 | t_emb = self.mlp(t_freq)
73 | return t_emb
74 |
75 |
76 | #################################################################################
77 | # Core DiT Model #
78 | #################################################################################
79 |
80 |
81 | class DiTBlock(nn.Module):
82 | """
83 | A DiT block with adaptive layer norm zero (adaLN-Zero) conditioning.
84 | """
85 |
86 | def __init__(
87 | self,
88 | hidden_size,
89 | num_heads,
90 | mlp_ratio=4.0,
91 | use_xformers_attention=False,
92 | **block_kwargs
93 | ):
94 | super().__init__()
95 | self.norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
96 | attn = MEAttention if use_xformers_attention else Attention
97 | self.attn = attn(
98 | hidden_size, num_heads=num_heads, qkv_bias=True, **block_kwargs
99 | )
100 | self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
101 | mlp_hidden_dim = int(hidden_size * mlp_ratio)
102 |
103 | def approx_gelu():
104 | return nn.GELU(approximate="tanh")
105 |
106 | self.mlp = Mlp(
107 | in_features=hidden_size,
108 | hidden_features=mlp_hidden_dim,
109 | act_layer=approx_gelu,
110 | drop=0,
111 | )
112 | self.adaLN_modulation = nn.Sequential(
113 | nn.SiLU(), nn.Linear(hidden_size, 6 * hidden_size, bias=True)
114 | )
115 |
116 | def forward(self, x, c):
117 | (
118 | shift_msa,
119 | scale_msa,
120 | gate_msa,
121 | shift_mlp,
122 | scale_mlp,
123 | gate_mlp,
124 | ) = self.adaLN_modulation(c).chunk(6, dim=1)
125 | x = x + gate_msa.unsqueeze(1) * self.attn(
126 | modulate(self.norm1(x), shift_msa, scale_msa)
127 | )
128 | x = x + gate_mlp.unsqueeze(1) * self.mlp(
129 | modulate(self.norm2(x), shift_mlp, scale_mlp)
130 | )
131 | return x
132 |
133 |
134 | class FinalLayer(nn.Module):
135 | """
136 | The final layer of DiT.
137 | """
138 |
139 | def __init__(self, hidden_size, patch_size, out_channels):
140 | super().__init__()
141 | self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
142 | self.linear = nn.Linear(
143 | hidden_size, patch_size * patch_size * out_channels, bias=True
144 | )
145 | self.adaLN_modulation = nn.Sequential(
146 | nn.SiLU(), nn.Linear(hidden_size, 2 * hidden_size, bias=True)
147 | )
148 |
149 | def forward(self, x, c):
150 | shift, scale = self.adaLN_modulation(c).chunk(2, dim=1)
151 | x = modulate(self.norm_final(x), shift, scale)
152 | x = self.linear(x)
153 | return x
154 |
155 |
156 | class DiT(nn.Module):
157 | """
158 | Diffusion model with a Transformer backbone.
159 | """
160 |
161 | def __init__(
162 | self,
163 | in_channels=442,
164 | out_channels=6,
165 | width=16,
166 | hidden_size=1152,
167 | depth=8,
168 | num_heads=16,
169 | mlp_ratio=4.0,
170 | max_num_images=8,
171 | P=1,
172 | within_image=False,
173 | ):
174 | super().__init__()
175 | self.num_heads = num_heads
176 | self.in_channels = in_channels
177 | self.out_channels = out_channels
178 | self.width = width
179 | self.hidden_size = hidden_size
180 | self.max_num_images = max_num_images
181 | self.P = P
182 | self.within_image = within_image
183 |
184 | # self.x_embedder = nn.Linear(in_channels, hidden_size)
185 | # self.x_embedder = PatchEmbed(in_channels, hidden_size, kernel_size=P, hidden_size=P)
186 | self.x_embedder = PatchEmbed(
187 | img_size=self.width,
188 | patch_size=self.P,
189 | in_chans=in_channels,
190 | embed_dim=hidden_size,
191 | bias=True,
192 | flatten=False,
193 | )
194 | self.x_pos_enc = FeaturePositionalEncoding(
195 | max_num_images, hidden_size, width**2, P=self.P
196 | )
197 | self.t_embedder = TimestepEmbedder(hidden_size)
198 |
199 | try:
200 | import xformers
201 |
202 | use_xformers_attention = True
203 | except ImportError:
204 | # xformers not available
205 | use_xformers_attention = False
206 |
207 | self.blocks = nn.ModuleList(
208 | [
209 | DiTBlock(
210 | hidden_size,
211 | num_heads,
212 | mlp_ratio=mlp_ratio,
213 | use_xformers_attention=use_xformers_attention,
214 | )
215 | for _ in range(depth)
216 | ]
217 | )
218 | self.final_layer = FinalLayer(hidden_size, P, out_channels)
219 | self.initialize_weights()
220 |
221 | def initialize_weights(self):
222 | # Initialize transformer layers:
223 | def _basic_init(module):
224 | if isinstance(module, nn.Linear):
225 | torch.nn.init.xavier_uniform_(module.weight)
226 | if module.bias is not None:
227 | nn.init.constant_(module.bias, 0)
228 |
229 | self.apply(_basic_init)
230 |
231 | # Initialize patch_embed like nn.Linear (instead of nn.Conv2d):
232 | w = self.x_embedder.proj.weight.data
233 | nn.init.xavier_uniform_(w.view([w.shape[0], -1]))
234 | nn.init.constant_(self.x_embedder.proj.bias, 0)
235 |
236 | # Initialize timestep embedding MLP:
237 | nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02)
238 | nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02)
239 |
240 | # Zero-out adaLN modulation layers in DiT blocks:
241 | for block in self.blocks:
242 | nn.init.constant_(block.adaLN_modulation[-1].weight, 0)
243 | nn.init.constant_(block.adaLN_modulation[-1].bias, 0)
244 |
245 | # Zero-out output layers:
246 | # nn.init.constant_(self.final_layer.adaLN_modulation[-1].weight, 0)
247 | # nn.init.constant_(self.final_layer.adaLN_modulation[-1].bias, 0)
248 | # nn.init.constant_(self.final_layer.linear.weight, 0)
249 | # nn.init.constant_(self.final_layer.linear.bias, 0)
250 |
251 | def unpatchify(self, x):
252 | """
253 | x: (N, T, patch_size**2 * C)
254 | imgs: (N, H, W, C)
255 | """
256 | c = self.out_channels
257 | p = self.x_embedder.patch_size[0]
258 | h = w = int(x.shape[1] ** 0.5)
259 |
260 | # print("unpatchify", c, p, h, w, x.shape)
261 | # assert h * w == x.shape[2]
262 |
263 | x = x.reshape(shape=(x.shape[0], h, w, p, p, c))
264 | x = torch.einsum("nhwpqc->nhpwqc", x)
265 | imgs = x.reshape(shape=(x.shape[0], h * p, h * p, c))
266 | return imgs
267 |
268 | def forward(
269 | self,
270 | x,
271 | t,
272 | return_dpt_activations=False,
273 | multiview_unconditional=False,
274 | ):
275 | """
276 |
277 | Args:
278 | x: Image/Ray features (B, N, C, H, W).
280 |             t: Timesteps (B,).
280 |
281 | Returns:
282 | (B, N, D, H, W)
283 | """
284 | B, N, c, h, w = x.shape
285 | P = self.P
286 |
287 | x = x.reshape((B * N, c, h, w)) # (B * N, C, H, W)
288 | x = self.x_embedder(x) # (B * N, C, H / P, W / P)
289 |
290 | x = x.permute(0, 2, 3, 1) # (B * N, H / P, W / P, C)
291 | # (B, N, H / P, W / P, C)
292 | x = x.reshape((B, N, h // P, w // P, self.hidden_size))
293 | x = self.x_pos_enc(x) # (B, N, H * W / P ** 2, C)
294 | # TODO: fix positional encoding to work with (N, C, H, W) format.
295 |
296 | # Eval time, we get a scalar t
297 | if x.shape[0] != t.shape[0] and t.shape[0] == 1:
298 | t = t.repeat_interleave(B)
299 |
300 | if self.within_image or multiview_unconditional:
301 | t_within = t.repeat_interleave(N)
302 | t_within = self.t_embedder(t_within)
303 |
304 | t = self.t_embedder(t)
305 |
306 | dpt_activations = []
307 | for i, block in enumerate(self.blocks):
308 | # Within image block
309 | if (self.within_image and i % 2 == 0) or multiview_unconditional:
310 | x = x.reshape((B * N, h * w // P**2, self.hidden_size))
311 | x = block(x, t_within)
312 |
313 | # All patches block
314 | # Final layer is an all patches layer
315 | else:
316 | x = x.reshape((B, N * h * w // P**2, self.hidden_size))
317 | x = block(x, t) # (N, T, D)
318 |
319 | if return_dpt_activations and i % 4 == 3:
320 |                 # Collapse the batch and image dimensions before permuting to (B * N, C, h, w)
321 |                 x_prime = x.reshape(B * N, h, w, self.hidden_size)
322 | x_prime = x_prime.permute((0, 3, 1, 2))
323 | dpt_activations.append(x_prime)
324 |
325 | # Reshape the output back to original shape
326 | if multiview_unconditional:
327 | x = x.reshape((B, N * h * w // P**2, self.hidden_size))
328 |
329 | # (B, N * H * W / P ** 2, D)
330 | x = self.final_layer(
331 | x, t
332 | ) # (B, N * H * W / P ** 2, 6 * P ** 2) or (N, T, patch_size ** 2 * out_channels)
333 |
334 | x = x.reshape((B * N, w * w // P**2, self.out_channels * P**2))
335 | x = self.unpatchify(x) # (B * N, H, W, C)
336 | x = x.reshape((B, N) + x.shape[1:])
337 | x = x.permute(0, 1, 4, 2, 3) # (B, N, C, H, W)
338 |
339 | if return_dpt_activations:
340 | return dpt_activations[:4]
341 |
342 | return x
343 |
344 |
345 | class FeaturePositionalEncoding(nn.Module):
346 | def _get_sinusoid_encoding_table(self, n_position, d_hid, base):
347 | """Sinusoid position encoding table"""
348 |
349 | def get_position_angle_vec(position):
350 | return [
351 | position / np.power(base, 2 * (hid_j // 2) / d_hid)
352 | for hid_j in range(d_hid)
353 | ]
354 |
355 | sinusoid_table = np.array(
356 | [get_position_angle_vec(pos_i) for pos_i in range(n_position)]
357 | )
358 | sinusoid_table[:, 0::2] = np.sin(sinusoid_table[:, 0::2]) # dim 2i
359 | sinusoid_table[:, 1::2] = np.cos(sinusoid_table[:, 1::2]) # dim 2i+1
360 |
361 | return torch.FloatTensor(sinusoid_table).unsqueeze(0)
362 |
363 | def __init__(self, max_num_images=8, feature_dim=1152, num_patches=256, P=1):
364 | super().__init__()
365 | self.max_num_images = max_num_images
366 | self.feature_dim = feature_dim
367 | self.P = P
368 | self.num_patches = num_patches // self.P**2
369 |
370 | self.register_buffer(
371 | "image_pos_table",
372 | self._get_sinusoid_encoding_table(
373 | self.max_num_images, self.feature_dim, 10000
374 | ),
375 | )
376 |
377 | self.register_buffer(
378 | "token_pos_table",
379 | self._get_sinusoid_encoding_table(
380 | self.num_patches, self.feature_dim, 70007
381 | ),
382 | )
383 |
384 | def forward(self, x):
385 | batch_size = x.shape[0]
386 | num_images = x.shape[1]
387 |
388 | x = x.reshape(batch_size, num_images, self.num_patches, self.feature_dim)
389 |
390 | # To encode image index
391 | pe1 = self.image_pos_table[:, :num_images].clone().detach()
392 | pe1 = pe1.reshape((1, num_images, 1, self.feature_dim))
393 | pe1 = pe1.repeat((batch_size, 1, self.num_patches, 1))
394 |
395 | # To encode patch index
396 | pe2 = self.token_pos_table.clone().detach()
397 | pe2 = pe2.reshape((1, 1, self.num_patches, self.feature_dim))
398 | pe2 = pe2.repeat((batch_size, num_images, 1, 1))
399 |
400 | x_pe = x + pe1 + pe2
401 | x_pe = x_pe.reshape(
402 | (batch_size, num_images * self.num_patches, self.feature_dim)
403 | )
404 |
405 | return x_pe
406 |
407 | def forward_unet(self, x, B, N):
408 | D = int(self.num_patches**0.5)
409 |
410 | # x should be (B, N, T, D, D)
411 | x = x.permute((0, 2, 3, 1))
412 | x = x.reshape(B, N, self.num_patches, self.feature_dim)
413 |
414 | # To encode image index
415 | pe1 = self.image_pos_table[:, :N].clone().detach()
416 | pe1 = pe1.reshape((1, N, 1, self.feature_dim))
417 | pe1 = pe1.repeat((B, 1, self.num_patches, 1))
418 |
419 | # To encode patch index
420 | pe2 = self.token_pos_table.clone().detach()
421 | pe2 = pe2.reshape((1, 1, self.num_patches, self.feature_dim))
422 | pe2 = pe2.repeat((B, N, 1, 1))
423 |
424 | x_pe = x + pe1 + pe2
425 | x_pe = x_pe.reshape((B * N, D, D, self.feature_dim))
426 | x_pe = x_pe.permute((0, 3, 1, 2))
427 |
428 | return x_pe
429 |
--------------------------------------------------------------------------------
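Shape sketch for `FeaturePositionalEncoding`: two sinusoidal tables (built with different bases so image-index and patch-index codes stay distinguishable) are added to the features, which are then flattened into one token sequence spanning all patches of all images.

```python
import torch

from diffusionsfm.model.dit import FeaturePositionalEncoding

pos_enc = FeaturePositionalEncoding(max_num_images=8, feature_dim=1152, num_patches=256, P=1)

B, N = 2, 4
x = torch.randn(B, N, 16, 16, 1152)   # (B, N, H/P, W/P, hidden_size) from the patch embedder
tokens = pos_enc(x)                    # (B, N * 256, 1152): one token per patch per image
assert tokens.shape == (B, N * 256, 1152)
```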
/diffusionsfm/model/feature_extractors.py:
--------------------------------------------------------------------------------
1 | import importlib
2 | import os
3 | import socket
4 | import sys
5 |
6 | import ipdb # noqa: F401
7 | import torch
8 | import torch.nn as nn
9 | from omegaconf import OmegaConf
10 |
11 | HOSTNAME = socket.gethostname()
12 |
13 | if "trinity" in HOSTNAME:
14 | # Might be outdated
15 | config_path = "/home/amylin2/latent-diffusion/configs/autoencoder/autoencoder_kl_16x16x16.yaml"
16 | weights_path = "/home/amylin2/latent-diffusion/model.ckpt"
17 | elif "grogu" in HOSTNAME:
18 | # Might be outdated
19 | config_path = "/home/jasonzh2/code/latent-diffusion/configs/autoencoder/autoencoder_kl_16x16x16.yaml"
20 | weights_path = "/home/jasonzh2/code/latent-diffusion/model.ckpt"
21 | elif "ender" in HOSTNAME:
22 | config_path = "/home/jason/ray_diffusion/external/latent-diffusion/configs/autoencoder/autoencoder_kl_16x16x16.yaml"
23 | weights_path = "/home/jason/ray_diffusion/external/latent-diffusion/model.ckpt"
24 | else:
25 | config_path = None
26 | weights_path = None
27 |
28 |
29 | if weights_path is not None:
30 | LDM_PATH = os.path.dirname(weights_path)
31 | if LDM_PATH not in sys.path:
32 | sys.path.append(LDM_PATH)
33 |
34 |
35 | def resize(image, size=None, scale_factor=None):
36 | return nn.functional.interpolate(
37 | image,
38 | size=size,
39 | scale_factor=scale_factor,
40 | mode="bilinear",
41 | align_corners=False,
42 | )
43 |
44 |
45 | def instantiate_from_config(config):
46 | if "target" not in config:
47 | if config == "__is_first_stage__":
48 | return None
49 | elif config == "__is_unconditional__":
50 | return None
51 | raise KeyError("Expected key `target` to instantiate.")
52 | return get_obj_from_str(config["target"])(**config.get("params", dict()))
53 |
54 |
55 | def get_obj_from_str(string, reload=False):
56 | module, cls = string.rsplit(".", 1)
57 | if reload:
58 | module_imp = importlib.import_module(module)
59 | importlib.reload(module_imp)
60 | return getattr(importlib.import_module(module, package=None), cls)
61 |
62 |
63 | class PretrainedVAE(nn.Module):
64 | def __init__(self, freeze_weights=True, num_patches_x=16, num_patches_y=16):
65 | super().__init__()
66 | config = OmegaConf.load(config_path)
67 | self.model = instantiate_from_config(config.model)
68 | self.model.init_from_ckpt(weights_path)
69 | self.model.eval()
70 | self.feature_dim = 16
71 | self.num_patches_x = num_patches_x
72 | self.num_patches_y = num_patches_y
73 |
74 | if freeze_weights:
75 | for param in self.model.parameters():
76 | param.requires_grad = False
77 |
78 | def forward(self, x, autoresize=False):
79 | """
80 | Spatial dimensions of output will be H // 16, W // 16. If autoresize is True,
81 | then the input will be resized such that the output feature map is the correct
82 | dimensions.
83 |
84 | Args:
85 | x (torch.Tensor): Images (B, C, H, W). Should be normalized to be [-1, 1].
86 | autoresize (bool): Whether to resize the input to match the num_patch
87 | dimensions.
88 |
89 | Returns:
90 | torch.Tensor: Latent sample (B, 16, h, w)
91 | """
92 |
93 | *B, c, h, w = x.shape
94 | x = x.reshape(-1, c, h, w)
95 | if autoresize:
96 | new_w = self.num_patches_x * 16
97 | new_h = self.num_patches_y * 16
98 | x = resize(x, size=(new_h, new_w))
99 |
100 | decoded, latent = self.model(x)
101 |         # A little ambiguous because every dim is 16, but the order is (c, h, w)
102 | latent_sample = latent.sample().reshape(
103 | *B, self.feature_dim, self.num_patches_y, self.num_patches_x
104 | )
105 | return latent_sample
106 |
107 |
108 | activations = {}
109 |
110 |
111 | def get_activation(name):
112 | def hook(model, input, output):
113 | activations[name] = output
114 |
115 | return hook
116 |
117 |
118 | class SpatialDino(nn.Module):
119 | def __init__(
120 | self,
121 | freeze_weights=True,
122 | model_type="dinov2_vits14",
123 | num_patches_x=16,
124 | num_patches_y=16,
125 | activation_hooks=False,
126 | ):
127 | super().__init__()
128 | self.model = torch.hub.load("facebookresearch/dinov2", model_type)
129 | self.feature_dim = self.model.embed_dim
130 | self.num_patches_x = num_patches_x
131 | self.num_patches_y = num_patches_y
132 | if freeze_weights:
133 | for param in self.model.parameters():
134 | param.requires_grad = False
135 |
136 | self.activation_hooks = activation_hooks
137 |
138 | if self.activation_hooks:
139 | self.model.blocks[5].register_forward_hook(get_activation("encoder1"))
140 | self.model.blocks[11].register_forward_hook(get_activation("encoder2"))
141 | self.activations = activations
142 |
143 | def forward(self, x, autoresize=False):
144 | """
145 | Spatial dimensions of output will be H // 14, W // 14. If autoresize is True,
146 | then the output will be resized to the correct dimensions.
147 |
148 | Args:
149 | x (torch.Tensor): Images (B, C, H, W). Should be ImageNet normalized.
150 |             autoresize (bool): Whether to resize the output feature map to match
151 |                 the num_patch dimensions.
152 |
153 |         Returns:
154 |             feature_map (torch.Tensor): (B, C, h, w)
155 | """
156 | *B, c, h, w = x.shape
157 |
158 | x = x.reshape(-1, c, h, w)
159 | # if autoresize:
160 | # new_w = self.num_patches_x * 14
161 | # new_h = self.num_patches_y * 14
162 | # x = resize(x, size=(new_h, new_w))
163 |
164 | # Output will be (B, H * W, C)
165 | features = self.model.forward_features(x)["x_norm_patchtokens"]
166 | features = features.permute(0, 2, 1)
167 | features = features.reshape( # (B, C, H, W)
168 | -1, self.feature_dim, h // 14, w // 14
169 | )
170 | if autoresize:
171 | features = resize(features, size=(self.num_patches_y, self.num_patches_x))
172 |
173 | features = features.reshape(
174 | *B, self.feature_dim, self.num_patches_y, self.num_patches_x
175 | )
176 | return features
177 |
--------------------------------------------------------------------------------
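`SpatialDino` wraps a frozen DINOv2 backbone and reshapes its patch tokens into a spatial feature map, optionally resizing that map to a fixed patch grid. A minimal sketch, assuming `torch.hub` can download `dinov2_vits14` (embed dim 384) and that the input sides are multiples of the 14-pixel patch size:

```
import torch

from diffusionsfm.model.feature_extractors import SpatialDino

extractor = SpatialDino(freeze_weights=True, num_patches_x=16, num_patches_y=16)

# (batch, num_views, C, H, W); 224 / 14 = 16 tokens per side.
# Inputs are expected to be ImageNet-normalized.
images = torch.randn(2, 3, 3, 224, 224)
with torch.no_grad():
    feats = extractor(images, autoresize=True)

# Leading dims are preserved; features come back as (batch, views, C, h, w).
assert feats.shape == (2, 3, 384, 16, 16)
```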
/diffusionsfm/model/memory_efficient_attention.py:
--------------------------------------------------------------------------------
1 | import ipdb  # noqa: F401
2 | import torch.nn as nn
3 | from xformers.ops import memory_efficient_attention
4 |
5 |
6 | class MEAttention(nn.Module):
7 | def __init__(
8 | self,
9 | dim,
10 | num_heads=8,
11 | qkv_bias=False,
12 | qk_norm=False,
13 | attn_drop=0.0,
14 | proj_drop=0.0,
15 | norm_layer=nn.LayerNorm,
16 | ):
17 | super().__init__()
18 | assert dim % num_heads == 0, "dim should be divisible by num_heads"
19 | self.num_heads = num_heads
20 | self.head_dim = dim // num_heads
21 | self.scale = self.head_dim**-0.5
22 |
23 | self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
24 | self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
25 | self.k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
26 | self.attn_drop = nn.Dropout(attn_drop)
27 | self.proj = nn.Linear(dim, dim)
28 | self.proj_drop = nn.Dropout(proj_drop)
29 |
30 | def forward(self, x):
31 | B, N, C = x.shape
32 | qkv = (
33 | self.qkv(x)
34 | .reshape(B, N, 3, self.num_heads, self.head_dim)
35 | .permute(2, 0, 3, 1, 4)
36 | )
37 | q, k, v = qkv.unbind(0)
38 | q, k = self.q_norm(q), self.k_norm(k)
39 |
40 | # MEA expects [B, N, H, D], whereas timm uses [B, H, N, D]
41 | x = memory_efficient_attention(
42 | q.transpose(1, 2),
43 | k.transpose(1, 2),
44 | v.transpose(1, 2),
45 | scale=self.scale,
46 | )
47 | x = x.reshape(B, N, C)
48 |
49 | x = self.proj(x)
50 | x = self.proj_drop(x)
51 | return x
52 |
--------------------------------------------------------------------------------
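`MEAttention` mirrors a standard timm-style attention block but routes the scaled dot-product through xformers' `memory_efficient_attention`, which expects `[B, N, H, D]` rather than timm's `[B, H, N, D]`; note that `attn_drop` is constructed for interface parity but is never applied inside the fused kernel. A minimal shape sketch, assuming xformers is installed and a CUDA device is available (the fused kernels generally do not run on CPU):

```
import torch

from diffusionsfm.model.memory_efficient_attention import MEAttention

attn = MEAttention(dim=384, num_heads=8, qkv_bias=True).cuda()

tokens = torch.randn(2, 256, 384, device="cuda")  # (batch, tokens, dim)
out = attn(tokens)
assert out.shape == tokens.shape  # attention is shape-preserving
```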
/diffusionsfm/model/scheduler.py:
--------------------------------------------------------------------------------
1 | import ipdb # noqa: F401
2 | import matplotlib.pyplot as plt
3 | import numpy as np
4 | import torch
5 | import torch.nn as nn
6 |
7 | from diffusionsfm.utils.visualization import plot_to_image
8 |
9 |
10 | class NoiseScheduler(nn.Module):
11 | def __init__(
12 | self,
13 | max_timesteps=1000,
14 | beta_start=0.0001,
15 | beta_end=0.02,
16 | cos_power=2,
17 | num_inference_steps=100,
18 | type="linear",
19 | ):
20 | super().__init__()
21 | self.max_timesteps = max_timesteps
22 | self.num_inference_steps = num_inference_steps
23 | self.beta_start = beta_start
24 | self.beta_end = beta_end
25 | self.cos_power = cos_power
26 | self.type = type
27 |
28 | if type == "linear":
29 | self.register_linear_schedule()
30 | elif type == "cosine":
31 | self.register_cosine_schedule(cos_power)
32 | elif type == "scaled_linear":
33 | self.register_scaled_linear_schedule()
34 |
35 | self.inference_timesteps = self.compute_inference_timesteps()
36 |
37 | def register_linear_schedule(self):
38 | # zero terminal SNR (https://arxiv.org/pdf/2305.08891)
39 | betas = torch.linspace(
40 | self.beta_start,
41 | self.beta_end,
42 | self.max_timesteps,
43 | dtype=torch.float32,
44 | )
45 | alphas = 1.0 - betas
46 | alphas_cumprod = torch.cumprod(alphas, dim=0)
47 | alphas_bar_sqrt = alphas_cumprod.sqrt()
48 |
49 | alphas_bar_sqrt_0 = alphas_bar_sqrt[0].clone()
50 | alphas_bar_sqrt_T = alphas_bar_sqrt[-1].clone()
51 |
52 | alphas_bar_sqrt -= alphas_bar_sqrt_T
53 | alphas_bar_sqrt *= alphas_bar_sqrt_0 / (alphas_bar_sqrt_0 - alphas_bar_sqrt_T)
54 |
55 | alphas_bar = alphas_bar_sqrt**2
56 | alphas = alphas_bar[1:] / alphas_bar[:-1]
57 | alphas = torch.cat([alphas_bar[0:1], alphas])
58 | betas = 1 - alphas
59 |
60 | self.register_buffer(
61 | "betas",
62 | betas,
63 | )
64 | self.register_buffer("alphas", 1.0 - self.betas)
65 | self.register_buffer("alphas_cumprod", torch.cumprod(self.alphas, dim=0))
66 |
67 | def register_cosine_schedule(self, cos_power, s=0.008):
68 | timesteps = (
69 | torch.arange(self.max_timesteps + 1, dtype=torch.float32)
70 | / self.max_timesteps
71 | )
72 | alpha_bars = (timesteps + s) / (1 + s) * np.pi / 2
73 | alpha_bars = torch.cos(alpha_bars).pow(cos_power)
74 | alpha_bars = alpha_bars / alpha_bars[0]
75 | betas = 1 - alpha_bars[1:] / alpha_bars[:-1]
76 | betas = np.clip(betas, a_min=0, a_max=0.999)
77 |
78 | self.register_buffer(
79 | "betas",
80 | betas,
81 | )
82 | self.register_buffer("alphas", 1.0 - betas)
83 | self.register_buffer("alphas_cumprod", torch.cumprod(self.alphas, dim=0))
84 |
85 | def register_scaled_linear_schedule(self):
86 | self.register_buffer(
87 | "betas",
88 | torch.linspace(
89 | self.beta_start**0.5,
90 | self.beta_end**0.5,
91 | self.max_timesteps,
92 | dtype=torch.float32,
93 | )
94 | ** 2,
95 | )
96 | self.register_buffer("alphas", 1.0 - self.betas)
97 | self.register_buffer("alphas_cumprod", torch.cumprod(self.alphas, dim=0))
98 |
99 | def compute_inference_timesteps(
100 | self, num_inference_steps=None, num_train_steps=None
101 | ):
102 |         # based on the diffusers scheduling code
103 | if num_inference_steps is None:
104 | num_inference_steps = self.num_inference_steps
105 | if num_train_steps is None:
106 | num_train_steps = self.max_timesteps
107 | step_ratio = num_train_steps // num_inference_steps
108 | timesteps = (
109 | (np.arange(0, num_inference_steps) * step_ratio).round()[::-1].astype(int)
110 | )
111 | return timesteps
112 |
113 | def plot_schedule(self, return_image=False):
114 | fig = plt.figure(figsize=(6, 4), dpi=100)
115 | alpha_bars = self.alphas_cumprod.cpu().numpy()
116 | plt.plot(np.sqrt(alpha_bars))
117 | plt.grid()
118 | if self.type == "linear":
119 | plt.title(
120 | f"Linear (T={self.max_timesteps}, S={self.beta_start}, E={self.beta_end})"
121 | )
122 |         else:
123 |             # cosine (or other) schedule
124 |             plt.title(f"Cosine (T={self.max_timesteps}, P={self.cos_power})")
125 | if return_image:
126 | image = plot_to_image(fig)
127 | plt.close(fig)
128 | return image
129 |
--------------------------------------------------------------------------------
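`NoiseScheduler` precomputes `betas`, `alphas`, and `alphas_cumprod` as buffers for the chosen schedule and exposes a strided set of inference timesteps in the style of diffusers. A minimal inspection sketch using the constructor defaults; nothing beyond the class itself is taken from the repository:

```
from diffusionsfm.model.scheduler import NoiseScheduler

scheduler = NoiseScheduler(max_timesteps=1000, num_inference_steps=100, type="linear")

# The zero-terminal-SNR rescaling pushes alpha_bar at the last step to ~0.
assert scheduler.alphas_cumprod.shape == (1000,)
print(scheduler.alphas_cumprod[0].item(), scheduler.alphas_cumprod[-1].item())

# 100 strided timesteps ordered from t=990 down to t=0.
ts = scheduler.compute_inference_timesteps()
print(ts[:3], ts[-3:])  # e.g. [990 980 970] ... [20 10 0]
```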
/diffusionsfm/utils/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/QitaoZhao/DiffusionSfM/4bb08800721bdcf46b0c823586a2fab4702ff282/diffusionsfm/utils/__init__.py
--------------------------------------------------------------------------------
/diffusionsfm/utils/configs.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 |
4 | from omegaconf import OmegaConf
5 |
6 |
7 | def load_cfg(config_path):
8 | """
9 | Loads a yaml configuration file.
10 |
11 | Follows the chain of yaml configuration files that have a `_BASE` key, and updates
12 | the new keys accordingly. _BASE configurations can be specified using relative
13 | paths.
14 | """
15 | config_dir = os.path.dirname(config_path)
16 | config_path = os.path.basename(config_path)
17 | return load_cfg_recursive(config_dir, config_path)
18 |
19 |
20 | def load_cfg_recursive(config_dir, config_path):
21 | """
22 | Recursively loads config files.
23 |
24 | Follows the chain of yaml configuration files that have a `_BASE` key, and updates
25 | the new keys accordingly. _BASE configurations can be specified using relative
26 | paths.
27 | """
28 | cfg = OmegaConf.load(os.path.join(config_dir, config_path))
29 | base_path = OmegaConf.select(cfg, "_BASE", default=None)
30 | if base_path is not None:
31 | base_cfg = load_cfg_recursive(config_dir, base_path)
32 | cfg = OmegaConf.merge(base_cfg, cfg)
33 | return cfg
34 |
35 |
36 | def get_cfg():
37 | parser = argparse.ArgumentParser()
38 | parser.add_argument("--config-path", type=str, required=True)
39 | args = parser.parse_args()
40 | cfg = load_cfg(args.config_path)
41 | print(OmegaConf.to_yaml(cfg))
42 |
43 | exp_dir = os.path.join(cfg.training.runs_dir, cfg.training.exp_tag)
44 | os.makedirs(exp_dir, exist_ok=True)
45 | to_path = os.path.join(exp_dir, os.path.basename(args.config_path))
46 | if not os.path.exists(to_path):
47 | OmegaConf.save(config=cfg, f=to_path)
48 | return cfg
49 |
50 |
51 | def get_cfg_from_path(config_path):
52 | """
53 |     Args:
54 |         config_path: Path to the yaml configuration file.
55 | """
56 | print("getting config from path")
57 |
58 | cfg = load_cfg(config_path)
59 | print(OmegaConf.to_yaml(cfg))
60 |
61 | exp_dir = os.path.join(cfg.training.runs_dir, cfg.training.exp_tag)
62 | os.makedirs(exp_dir, exist_ok=True)
63 | to_path = os.path.join(exp_dir, os.path.basename(config_path))
64 | if not os.path.exists(to_path):
65 | OmegaConf.save(config=cfg, f=to_path)
66 | return cfg
67 |
--------------------------------------------------------------------------------
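`load_cfg` follows the chain of `_BASE` keys, merging each child config on top of its base so that child values win. A small sketch of that behavior with two throwaway yaml files; the file names and keys are made up for illustration:

```
import os
import tempfile

from diffusionsfm.utils.configs import load_cfg

tmp = tempfile.mkdtemp()

# The base provides defaults; the child overrides one key and adds another.
with open(os.path.join(tmp, "base.yaml"), "w") as f:
    f.write("training:\n  batch_size: 8\n  lr: 0.0001\n")
with open(os.path.join(tmp, "child.yaml"), "w") as f:
    f.write("_BASE: base.yaml\ntraining:\n  batch_size: 4\nmodel:\n  num_images: 8\n")

cfg = load_cfg(os.path.join(tmp, "child.yaml"))
assert cfg.training.batch_size == 4   # overridden by the child
assert cfg.training.lr == 0.0001      # inherited from the base
assert cfg.model.num_images == 8
```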
/diffusionsfm/utils/distortion.py:
--------------------------------------------------------------------------------
1 | import cv2
2 | import ipdb
3 | import numpy as np
4 | from PIL import Image
5 | import torch
6 |
7 |
8 | # https://gist.github.com/davegreenwood/820d51ac5ec88a2aeda28d3079e7d9eb
9 | def apply_distortion(pts, k1, k2):
10 | """
11 | Arguments:
12 | pts (N x 2): numpy array in NDC coordinates
13 | k1, k2 distortion coefficients
14 | Return:
15 | pts (N x 2): distorted points in NDC coordinates
16 | """
17 | r2 = np.square(pts).sum(-1)
18 | f = 1 + k1 * r2 + k2 * r2**2
19 | return f[..., None] * pts
20 |
21 |
22 | # https://gist.github.com/davegreenwood/820d51ac5ec88a2aeda28d3079e7d9eb
23 | def apply_distortion_tensor(pts, k1, k2):
24 | """
25 | Arguments:
26 |         pts (N x 2): torch tensor in NDC coordinates
27 | k1, k2 distortion coefficients
28 | Return:
29 | pts (N x 2): distorted points in NDC coordinates
30 | """
31 | r2 = torch.square(pts).sum(-1)
32 | f = 1 + k1 * r2 + k2 * r2**2
33 | return f[..., None] * pts
34 |
35 |
36 | # https://gist.github.com/davegreenwood/820d51ac5ec88a2aeda28d3079e7d9eb
37 | def remove_distortion_iter(points, k1, k2):
38 | """
39 | Arguments:
40 | pts (N x 2): numpy array in NDC coordinates
41 | k1, k2 distortion coefficients
42 | Return:
43 |         pts (N x 2): undistorted points in NDC coordinates
44 | """
45 | pts = ptsd = points
46 | for _ in range(5):
47 | r2 = np.square(pts).sum(-1)
48 | f = 1 + k1 * r2 + k2 * r2**2
49 | pts = ptsd / f[..., None]
50 |
51 | return pts
52 |
53 |
54 | def make_square(im, fill_color=(0, 0, 0)):
55 | x, y = im.size
56 | size = max(x, y)
57 | new_im = Image.new("RGB", (size, size), fill_color)
58 | corner = (int((size - x) / 2), int((size - y) / 2))
59 | new_im.paste(im, corner)
60 | return new_im, corner
61 |
62 |
63 | def pixel_to_ndc(coords, image_size):
64 | """
65 | Converts pixel coordinates to normalized device coordinates (Pytorch3D convention
66 | with upper left = (1, 1)) for a square image.
67 |
68 | Args:
69 | coords: Pixel coordinates UL=(0, 0), LR=(image_size, image_size).
70 | image_size (int): Image size.
71 |
72 | Returns:
73 | NDC coordinates UL=(1, 1) LR=(-1, -1).
74 | """
75 | coords = np.array(coords)
76 | return 1 - coords / image_size * 2
77 |
78 |
79 | def ndc_to_pixel(coords, image_size):
80 | """
81 | Converts normalized device coordinates to pixel coordinates for a square image.
82 | """
83 | num_points = coords.shape[0]
84 | sizes = np.tile(np.array(image_size, dtype=np.float32)[None, ...], (num_points, 1))
85 |
86 | coords = np.array(coords, dtype=np.float32)
87 | return (1 - coords) * sizes / 2
88 |
89 |
90 | def distort_image(image, bbox, k1, k2, modify_bbox=False):
91 | # We want to operate in -1 to 1 space using the padded square of the original image
92 | image, corner = make_square(image)
93 | bbox[:2] += np.array(corner)
94 | bbox[2:] += np.array(corner)
95 |
96 | # Construct grid points
97 | x = np.linspace(1, -1, image.width, dtype=np.float32)
98 | y = np.linspace(1, -1, image.height, dtype=np.float32)
99 | x, y = np.meshgrid(x, y, indexing="xy")
100 | xy_grid = np.stack((x, y), axis=-1)
101 | points = xy_grid.reshape((image.height * image.width, 2))
102 | new_points = ndc_to_pixel(apply_distortion(points, k1, k2), image.size)
103 |
104 | # Distort image by remapping
105 | map_x = new_points[:, 0].reshape((image.height, image.width))
106 | map_y = new_points[:, 1].reshape((image.height, image.width))
107 | distorted = cv2.remap(
108 | np.asarray(image),
109 | map_x,
110 | map_y,
111 | cv2.INTER_LINEAR,
112 | )
113 | distorted = Image.fromarray(distorted)
114 |
115 | # Find distorted crop bounds - inverse process of above
116 | if modify_bbox:
117 | center = (bbox[:2] + bbox[2:]) / 2
118 | top, bottom = (bbox[0], center[1]), (bbox[2], center[1])
119 | left, right = (center[0], bbox[1]), (center[0], bbox[3])
120 | bbox_points = np.array(
121 | [
122 | pixel_to_ndc(top, image.size),
123 | pixel_to_ndc(left, image.size),
124 | pixel_to_ndc(bottom, image.size),
125 | pixel_to_ndc(right, image.size),
126 | ],
127 | dtype=np.float32,
128 | )
129 | else:
130 | bbox_points = np.array(
131 | [pixel_to_ndc(bbox[:2], image.size), pixel_to_ndc(bbox[2:], image.size)],
132 | dtype=np.float32,
133 | )
134 |
135 | # Inverse mapping
136 | distorted_bbox = remove_distortion_iter(bbox_points, k1, k2)
137 |
138 | if modify_bbox:
139 | p = ndc_to_pixel(distorted_bbox, image.size)
140 | distorted_bbox = np.array([p[0][0], p[1][1], p[2][0], p[3][1]])
141 | else:
142 | distorted_bbox = ndc_to_pixel(distorted_bbox, image.size).reshape(4)
143 |
144 | return distorted, distorted_bbox
145 |
--------------------------------------------------------------------------------
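`apply_distortion` implements the two-coefficient radial model, scaling each NDC point by `f = 1 + k1 * r^2 + k2 * r^4`, and `remove_distortion_iter` inverts it with a short fixed-point iteration. A small round-trip sketch with made-up coefficients:

```
import numpy as np

from diffusionsfm.utils.distortion import apply_distortion, remove_distortion_iter

rng = np.random.default_rng(0)
pts = rng.uniform(-0.5, 0.5, size=(5, 2))  # points in NDC
k1, k2 = -0.2, 0.05                        # illustrative coefficients

distorted = apply_distortion(pts, k1, k2)
recovered = remove_distortion_iter(distorted, k1, k2)

# Five fixed-point iterations are plenty for mild distortion like this.
assert np.allclose(recovered, pts, atol=1e-4)
```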
/diffusionsfm/utils/distributed.py:
--------------------------------------------------------------------------------
1 | import os
2 | import socket
3 | from contextlib import closing
4 |
5 | import torch.distributed as dist
6 |
7 |
8 | def get_open_port():
9 | with closing(socket.socket(socket.AF_INET, socket.SOCK_STREAM)) as s:
10 | s.bind(("", 0))
11 | s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
12 | return s.getsockname()[1]
13 |
14 |
15 | # Distributed process group
16 | def ddp_setup(rank, world_size, port="12345"):
17 | """
18 | Args:
19 |         rank: unique identifier of the current process
20 | world_size: number of processes
21 | """
22 | os.environ["MASTER_ADDR"] = "localhost"
23 | print(f"MasterPort: {str(port)}")
24 | os.environ["MASTER_PORT"] = str(port)
25 |
26 | # initialize the process group
27 | dist.init_process_group(backend="nccl", rank=rank, world_size=world_size)
28 |
29 |
30 | def cleanup():
31 | dist.destroy_process_group()
32 |
--------------------------------------------------------------------------------
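`ddp_setup` sets the rendezvous environment variables and joins an NCCL process group, while `get_open_port` can supply a free port if the default is taken. A minimal sketch of how these helpers are typically combined with `torch.multiprocessing`, assuming a machine with NVIDIA GPUs; the worker body is illustrative only:

```
import torch
import torch.multiprocessing as mp

from diffusionsfm.utils.distributed import cleanup, ddp_setup, get_open_port


def worker(rank, world_size, port):
    ddp_setup(rank, world_size, port=port)  # joins the NCCL process group
    torch.cuda.set_device(rank)
    # ... build the model, wrap it in DistributedDataParallel, train ...
    cleanup()                               # tears the process group down


if __name__ == "__main__":
    world_size = torch.cuda.device_count()
    mp.spawn(worker, args=(world_size, get_open_port()), nprocs=world_size)
```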
/diffusionsfm/utils/geometry.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | from pytorch3d.renderer import FoVPerspectiveCameras
4 | from pytorch3d.transforms import quaternion_to_matrix
5 |
6 |
7 | def generate_random_rotations(N=1, device="cpu"):
8 | q = torch.randn(N, 4, device=device)
9 | q = q / q.norm(dim=-1, keepdim=True)
10 | return quaternion_to_matrix(q)
11 |
12 |
13 | def symmetric_orthogonalization(x):
14 | """Maps 9D input vectors onto SO(3) via symmetric orthogonalization.
15 |
16 | x: should have size [batch_size, 9]
17 |
18 | Output has size [batch_size, 3, 3], where each inner 3x3 matrix is in SO(3).
19 | """
20 | m = x.view(-1, 3, 3)
21 | u, s, v = torch.svd(m)
22 | vt = torch.transpose(v, 1, 2)
23 | det = torch.det(torch.matmul(u, vt))
24 | det = det.view(-1, 1, 1)
25 | vt = torch.cat((vt[:, :2, :], vt[:, -1:, :] * det), 1)
26 | r = torch.matmul(u, vt)
27 | return r
28 |
29 |
30 | def get_permutations(num_images):
31 | permutations = []
32 | for i in range(0, num_images):
33 | for j in range(0, num_images):
34 | if i != j:
35 | permutations.append((j, i))
36 |
37 | return permutations
38 |
39 |
40 | def n_to_np_rotations(num_frames, n_rots):
41 | R_pred_rel = []
42 | permutations = get_permutations(num_frames)
43 | for i, j in permutations:
44 | R_pred_rel.append(n_rots[i].T @ n_rots[j])
45 | R_pred_rel = torch.stack(R_pred_rel)
46 |
47 | return R_pred_rel
48 |
49 |
50 | def compute_angular_error_batch(rotation1, rotation2):
51 | R_rel = np.einsum("Bij,Bjk ->Bik", rotation2, rotation1.transpose(0, 2, 1))
52 | t = (np.trace(R_rel, axis1=1, axis2=2) - 1) / 2
53 | theta = np.arccos(np.clip(t, -1, 1))
54 | return theta * 180 / np.pi
55 |
56 |
57 | # A should be GT, B should be predicted
58 | def compute_optimal_alignment(A, B):
59 | """
60 | Compute the optimal scale s, rotation R, and translation t that minimizes:
61 |     || A - (s * B @ R + t) ||^2
62 |
63 | Reference: Umeyama (TPAMI 91)
64 |
65 | Args:
66 | A (torch.Tensor): (N, 3).
67 | B (torch.Tensor): (N, 3).
68 |
69 | Returns:
70 |         A_hat (torch.Tensor): aligned points s * B @ R + t (N, 3); s (float): scale.
71 | R (torch.Tensor): rotation matrix (3, 3).
72 | t (torch.Tensor): translation (3,).
73 | """
74 | A_bar = A.mean(0)
75 | B_bar = B.mean(0)
76 | # normally with R @ B, this would be A @ B.T
77 | H = (B - B_bar).T @ (A - A_bar)
78 | U, S, Vh = torch.linalg.svd(H, full_matrices=True)
79 | s = torch.linalg.det(U @ Vh)
80 | S_prime = torch.diag(torch.tensor([1, 1, torch.sign(s)], device=A.device))
81 | variance = torch.sum((B - B_bar) ** 2)
82 | scale = 1 / variance * torch.trace(torch.diag(S) @ S_prime)
83 | R = U @ S_prime @ Vh
84 | t = A_bar - scale * B_bar @ R
85 |
86 | A_hat = scale * B @ R + t
87 | return A_hat, scale, R, t
88 |
89 |
90 | def compute_optimal_translation_alignment(T_A, T_B, R_B):
91 | """
92 | Assuming right-multiplied rotation matrices.
93 |
94 | E.g., for world2cam R and T, a world coordinate is transformed to camera coordinate
95 | system using X_cam = X_world.T @ R + T = R.T @ X_world + T
96 |
97 | Finds s, t that minimizes || T_A - (s * T_B + R_B.T @ t) ||^2
98 |
99 | Args:
100 | T_A (torch.Tensor): Target translation (N, 3).
101 | T_B (torch.Tensor): Initial translation (N, 3).
102 | R_B (torch.Tensor): Initial rotation (N, 3, 3).
103 |
104 | Returns:
105 | T_A_hat (torch.Tensor): s * T_B + t @ R_B (N, 3).
106 | scale s (torch.Tensor): (1,).
107 | translation t (torch.Tensor): (1, 3).
108 | """
109 | n = len(T_A)
110 |
111 | T_A = T_A.unsqueeze(2)
112 | T_B = T_B.unsqueeze(2)
113 |
114 | A = torch.sum(T_B * T_A)
115 | B = (T_B.transpose(1, 2) @ R_B.transpose(1, 2)).sum(0) @ (R_B @ T_A).sum(0) / n
116 | C = torch.sum(T_B * T_B)
117 | D = (T_B.transpose(1, 2) @ R_B.transpose(1, 2)).sum(0)
118 | E = (D * D).sum() / n
119 |
120 | s = (A - B.sum()) / (C - E.sum())
121 |
122 | t = (R_B @ (T_A - s * T_B)).sum(0) / n
123 |
124 | T_A_hat = s * T_B + R_B.transpose(1, 2) @ t
125 |
126 | return T_A_hat.squeeze(2), s, t.transpose(1, 0)
127 |
128 |
129 | def get_error(predict_rotations, R_pred, T_pred, R_gt, T_gt, gt_scene_scale):
130 | if predict_rotations:
131 | cameras_gt = FoVPerspectiveCameras(R=R_gt, T=T_gt)
132 | cc_gt = cameras_gt.get_camera_center()
133 | cameras_pred = FoVPerspectiveCameras(R=R_pred, T=T_pred)
134 | cc_pred = cameras_pred.get_camera_center()
135 |
136 | A_hat, _, _, _ = compute_optimal_alignment(cc_gt, cc_pred)
137 | norm = torch.linalg.norm(cc_gt - A_hat, dim=1) / gt_scene_scale
138 |
139 | norms = np.ndarray.tolist(norm.detach().cpu().numpy())
140 | return norms, A_hat
141 | else:
142 | T_A_hat, _, _ = compute_optimal_translation_alignment(T_gt, T_pred, R_pred)
143 | norm = torch.linalg.norm(T_gt - T_A_hat, dim=1) / gt_scene_scale
144 | norms = np.ndarray.tolist(norm.detach().cpu().numpy())
145 | return norms, T_A_hat
146 |
--------------------------------------------------------------------------------
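`compute_optimal_alignment` recovers the similarity transform (scale, rotation, translation) that best maps the predicted points `B` onto the ground truth `A` under the right-multiplication convention `A ≈ s * B @ R + t` (Umeyama). A small synthetic sanity-check sketch; everything except the imported function is made up:

```
import torch
from pytorch3d.transforms import random_rotation

from diffusionsfm.utils.geometry import compute_optimal_alignment

torch.manual_seed(0)
A = torch.randn(100, 3)                       # "ground-truth" points
R = random_rotation()                         # random 3x3 rotation
s, t = 2.5, torch.tensor([0.1, -0.3, 0.7])

# Build B so that A = s * B @ R + t holds exactly, then try to recover it.
B = ((A - t) / s) @ R.T

A_hat, s_hat, R_hat, t_hat = compute_optimal_alignment(A, B)
assert torch.allclose(A_hat, A, atol=1e-4)
assert abs(s_hat - s) < 1e-4
```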
/diffusionsfm/utils/normalize.py:
--------------------------------------------------------------------------------
1 | import ipdb # noqa: F401
2 | import torch
3 | from pytorch3d.transforms import Rotate, Translate
4 |
5 |
6 | def intersect_skew_line_groups(p, r, mask=None):
7 | # p, r both of shape (B, N, n_intersected_lines, 3)
8 | # mask of shape (B, N, n_intersected_lines)
9 | p_intersect, r = intersect_skew_lines_high_dim(p, r, mask=mask)
10 | if p_intersect is None:
11 | return None, None, None, None
12 | _, p_line_intersect = point_line_distance(
13 | p, r, p_intersect[..., None, :].expand_as(p)
14 | )
15 | intersect_dist_squared = ((p_line_intersect - p_intersect[..., None, :]) ** 2).sum(
16 | dim=-1
17 | )
18 | return p_intersect, p_line_intersect, intersect_dist_squared, r
19 |
20 |
21 | def intersect_skew_lines_high_dim(p, r, mask=None):
22 |     # Implements https://en.wikipedia.org/wiki/Skew_lines in more than two dimensions
23 | dim = p.shape[-1]
24 | # make sure the heading vectors are l2-normed
25 | if mask is None:
26 | mask = torch.ones_like(p[..., 0])
27 | r = torch.nn.functional.normalize(r, dim=-1)
28 |
29 | eye = torch.eye(dim, device=p.device, dtype=p.dtype)[None, None]
30 | I_min_cov = (eye - (r[..., None] * r[..., None, :])) * mask[..., None, None]
31 | sum_proj = I_min_cov.matmul(p[..., None]).sum(dim=-3)
32 |
33 | p_intersect = torch.linalg.lstsq(I_min_cov.sum(dim=-3), sum_proj).solution[..., 0]
34 |
35 | if torch.any(torch.isnan(p_intersect)):
36 | print(p_intersect)
37 | return None, None
38 | ipdb.set_trace()
39 | assert False
40 | return p_intersect, r
41 |
42 |
43 | def point_line_distance(p1, r1, p2):
44 | df = p2 - p1
45 | proj_vector = df - ((df * r1).sum(dim=-1, keepdim=True) * r1)
46 | line_pt_nearest = p2 - proj_vector
47 | d = (proj_vector).norm(dim=-1)
48 | return d, line_pt_nearest
49 |
50 |
51 | def compute_optical_axis_intersection(cameras):
52 | centers = cameras.get_camera_center()
53 | principal_points = cameras.principal_point
54 |
55 | one_vec = torch.ones((len(cameras), 1), device=centers.device)
56 | optical_axis = torch.cat((principal_points, one_vec), -1)
57 |
58 | pp = cameras.unproject_points(optical_axis, from_ndc=True, world_coordinates=True)
59 | pp2 = torch.diagonal(pp, dim1=0, dim2=1).T
60 |
61 | directions = pp2 - centers
62 | centers = centers.unsqueeze(0).unsqueeze(0)
63 | directions = directions.unsqueeze(0).unsqueeze(0)
64 |
65 | p_intersect, p_line_intersect, _, r = intersect_skew_line_groups(
66 | p=centers, r=directions, mask=None
67 | )
68 |
69 | if p_intersect is None:
70 | dist = None
71 | else:
72 | p_intersect = p_intersect.squeeze().unsqueeze(0)
73 | dist = (p_intersect - centers).norm(dim=-1)
74 |
75 | return p_intersect, dist, p_line_intersect, pp2, r
76 |
77 |
78 | def first_camera_transform(cameras, rotation_only=True):
79 | new_cameras = cameras.clone()
80 | new_transform = new_cameras.get_world_to_view_transform()
81 | tR = Rotate(new_cameras.R[0].unsqueeze(0))
82 | if rotation_only:
83 | t = tR.inverse()
84 | else:
85 | tT = Translate(new_cameras.T[0].unsqueeze(0))
86 | t = tR.compose(tT).inverse()
87 |
88 | new_transform = t.compose(new_transform)
89 | new_cameras.R = new_transform.get_matrix()[:, :3, :3]
90 | new_cameras.T = new_transform.get_matrix()[:, 3, :3]
91 |
92 | return new_cameras
93 |
--------------------------------------------------------------------------------
/diffusionsfm/utils/slurm.py:
--------------------------------------------------------------------------------
1 | import os
2 | import signal
3 | import subprocess
4 | import sys
5 | import time
6 |
7 |
8 | def submitit_job_watcher(jobs, check_period: int = 15):
9 | job_out = {}
10 |
11 | try:
12 | while True:
13 | job_states = [job.state for job in jobs.values()]
14 | state_counts = {
15 | state: len([j for j in job_states if j == state])
16 | for state in set(job_states)
17 | }
18 |
19 | n_done = sum(job.done() for job in jobs.values())
20 |
21 | for job_name, job in jobs.items():
22 | if job_name not in job_out and job.done():
23 | job_out[job_name] = {
24 | "stderr": job.stderr(),
25 | "stdout": job.stdout(),
26 | }
27 |
28 | exc = job.exception()
29 | if exc is not None:
30 | print(f"{job_name} crashed!!!")
31 | if job_out[job_name]["stderr"] is not None:
32 | print("===== STDERR =====")
33 | print(job_out[job_name]["stderr"])
34 | else:
35 | print(f"{job_name} done!")
36 |
37 | print("Job states:")
38 | for state, count in state_counts.items():
39 | print(f" {state:15s} {count:6d} ({100.*count/len(jobs):.1f}%)")
40 |
41 | if n_done == len(jobs):
42 | print("All done!")
43 | return
44 |
45 | time.sleep(check_period)
46 |
47 | except KeyboardInterrupt:
48 | for job_name, job in jobs.items():
49 | if not job.done():
50 | print(f"Killing {job_name}")
51 | job.cancel(check=False)
52 |
53 |
54 | def get_jid():
55 | if "SLURM_ARRAY_TASK_ID" in os.environ:
56 | return f"{os.environ['SLURM_ARRAY_JOB_ID']}_{os.environ['SLURM_ARRAY_TASK_ID']}"
57 | return os.environ["SLURM_JOB_ID"]
58 |
59 |
60 | def signal_helper(signum, frame):
61 | print(f"Caught signal {signal.Signals(signum).name} on for the this job")
62 | jid = get_jid()
63 | cmd = ["scontrol", "requeue", jid]
64 | try:
65 | print("calling", cmd)
66 | rtn = subprocess.check_call(cmd)
67 | print("subprocc", rtn)
68 | except:
69 | print("subproc call failed")
70 | return sys.exit(10)
71 |
72 |
73 | def bypass(signum, frame):
74 | print(f"Ignoring signal {signal.Signals(signum).name} on for the this job")
75 |
76 |
77 | def init_slurm_signals():
78 | signal.signal(signal.SIGCONT, bypass)
79 | signal.signal(signal.SIGCHLD, bypass)
80 | signal.signal(signal.SIGTERM, bypass)
81 | signal.signal(signal.SIGUSR2, signal_helper)
82 | print("SLURM signal installed", flush=True)
83 |
84 |
85 | def init_slurm_signals_if_slurm():
86 | if "SLURM_JOB_ID" in os.environ:
87 | init_slurm_signals()
88 |
--------------------------------------------------------------------------------
/diffusionsfm/utils/visualization.py:
--------------------------------------------------------------------------------
1 | import io
2 |
3 |
4 | import ipdb # noqa: F401
5 | import matplotlib.pyplot as plt
6 | import numpy as np
7 | import trimesh
8 | import torch
9 | import torchvision
10 | from pytorch3d.loss import chamfer_distance
11 | from scipy.spatial.transform import Rotation
12 |
13 | from diffusionsfm.inference.ddim import inference_ddim
14 | from diffusionsfm.utils.rays import (
15 | Rays,
16 | cameras_to_rays,
17 | rays_to_cameras,
18 | rays_to_cameras_homography,
19 | )
20 | from diffusionsfm.utils.geometry import (
21 | compute_optimal_alignment,
22 | )
23 |
24 | cmap = plt.get_cmap("hsv")
25 |
26 |
27 | def create_training_visualizations(
28 | model,
29 | images,
30 | device,
31 | cameras_gt,
32 | num_images,
33 | crop_parameters,
34 | pred_x0=False,
35 | no_crop_param_device="cpu",
36 | visualize_pred=False,
37 | return_first=False,
38 | calculate_intrinsics=False,
39 | mode=None,
40 | depths=None,
41 | scale_min=-1,
42 | scale_max=1,
43 | diffuse_depths=False,
44 | vis_mode=None,
45 | average_centers=True,
46 | full_num_patches_x=16,
47 | full_num_patches_y=16,
48 | use_homogeneous=False,
49 | distortion_coefficients=None,
50 | ):
51 |
52 | if model.depth_resolution == 1:
53 | W_in = W_out = full_num_patches_x
54 | H_in = H_out = full_num_patches_y
55 | else:
56 | W_in = H_in = model.width
57 | W_out = model.width * model.depth_resolution
58 | H_out = model.width * model.depth_resolution
59 |
60 | rays_final, rays_intermediate, pred_intermediate, _ = inference_ddim(
61 | model,
62 | images,
63 | device,
64 | crop_parameters=crop_parameters,
65 | eta=[1, 0],
66 | num_patches_x=W_in,
67 | num_patches_y=H_in,
68 | visualize=True,
69 | )
70 |
71 | if vis_mode is None:
72 | vis_mode = mode
73 |
74 | T = model.noise_scheduler.max_timesteps
75 | if T == 1000:
76 | ts = [0, 100, 200, 300, 400, 500, 600, 700, 800, 900, 999]
77 | else:
78 | ts = [0, 10, 20, 30, 40, 50, 60, 70, 80, 90, 99]
79 |
80 | # Get predicted cameras from rays
81 | pred_cameras_batched = []
82 | vis_images = []
83 | pred_rays = []
84 | for index in range(len(images)):
85 | pred_cameras = []
86 | per_sample_images = []
87 | for ii in range(num_images):
88 | rays_gt = cameras_to_rays(
89 | cameras_gt[index],
90 | crop_parameters[index],
91 | no_crop_param_device=no_crop_param_device,
92 | num_patches_x=W_in,
93 | num_patches_y=H_in,
94 | depths=None if depths is None else depths[index],
95 | mode=mode,
96 | depth_resolution=model.depth_resolution,
97 | distortion_coefficients=(
98 | None
99 | if distortion_coefficients is None
100 | else distortion_coefficients[index]
101 | ),
102 | )
103 | image_vis = (images[index, ii].cpu().permute(1, 2, 0).numpy() + 1) / 2
104 |
105 | if diffuse_depths:
106 | fig, axs = plt.subplots(3, 13, figsize=(15, 4.5), dpi=100)
107 | else:
108 | fig, axs = plt.subplots(3, 9, figsize=(12, 4.5), dpi=100)
109 |
110 | for i, t in enumerate(ts):
111 | r, c = i // 4, i % 4
112 | if visualize_pred:
113 | curr = pred_intermediate[t][index]
114 | else:
115 | curr = rays_intermediate[t][index]
116 | rays = Rays.from_spatial(
117 | curr,
118 | mode=mode,
119 | num_patches_x=H_in,
120 | num_patches_y=W_in,
121 | use_homogeneous=use_homogeneous,
122 | )
123 |
124 | if vis_mode == "segment":
125 | vis = (
126 | torch.clip(
127 | rays.get_segments()[ii], min=scale_min, max=scale_max
128 | )
129 | - scale_min
130 | ) / (scale_max - scale_min)
131 |
132 | else:
133 | vis = (
134 | torch.nn.functional.normalize(rays.get_moments()[ii], dim=-1)
135 | + 1
136 | ) / 2
137 |
138 | axs[r, c].imshow(vis.reshape(W_out, H_out, 3).cpu())
139 | axs[r, c].set_title(f"T={T - t}")
140 |
141 | i += 1
142 | r, c = i // 4, i % 4
143 |
144 | if vis_mode == "segment":
145 | vis = (
146 | torch.clip(rays_gt.get_segments()[ii], min=scale_min, max=scale_max)
147 | - scale_min
148 | ) / (scale_max - scale_min)
149 | else:
150 | vis = (
151 | torch.nn.functional.normalize(rays_gt.get_moments()[ii], dim=-1) + 1
152 | ) / 2
153 |
154 | axs[r, c].imshow(vis.reshape(W_out, H_out, 3).cpu())
155 |
156 | type_str = "Endpoints" if vis_mode == "segment" else "Moments"
157 | axs[r, c].set_title(f"GT {type_str}")
158 |
159 | for i, t in enumerate(ts):
160 | r, c = i // 4, i % 4 + 4
161 | if visualize_pred:
162 | curr = pred_intermediate[t][index]
163 | else:
164 | curr = rays_intermediate[t][index]
165 | rays = Rays.from_spatial(
166 | curr,
167 | mode,
168 | num_patches_x=H_in,
169 | num_patches_y=W_in,
170 | use_homogeneous=use_homogeneous,
171 | )
172 |
173 | if vis_mode == "segment":
174 | vis = (
175 | torch.clip(
176 | rays.get_origins(high_res=True)[ii],
177 | min=scale_min,
178 | max=scale_max,
179 | )
180 | - scale_min
181 | ) / (scale_max - scale_min)
182 | else:
183 | vis = (
184 | torch.nn.functional.normalize(rays.get_directions()[ii], dim=-1)
185 | + 1
186 | ) / 2
187 |
188 | axs[r, c].imshow(vis.reshape(W_out, H_out, 3).cpu())
189 | axs[r, c].set_title(f"T={T - t}")
190 |
191 | i += 1
192 | r, c = i // 4, i % 4 + 4
193 |
194 | if vis_mode == "segment":
195 | vis = (
196 | torch.clip(
197 | rays_gt.get_origins(high_res=True)[ii],
198 | min=scale_min,
199 | max=scale_max,
200 | )
201 | - scale_min
202 | ) / (scale_max - scale_min)
203 | else:
204 | vis = (
205 | torch.nn.functional.normalize(rays_gt.get_directions()[ii], dim=-1)
206 | + 1
207 | ) / 2
208 | axs[r, c].imshow(vis.reshape(W_out, H_out, 3).cpu())
209 | type_str = "Origins" if vis_mode == "segment" else "Directions"
210 | axs[r, c].set_title(f"GT {type_str}")
211 |
212 | if diffuse_depths:
213 | for i, t in enumerate(ts):
214 | r, c = i // 4, i % 4 + 8
215 | if visualize_pred:
216 | curr = pred_intermediate[t][index]
217 | else:
218 | curr = rays_intermediate[t][index]
219 | rays = Rays.from_spatial(
220 | curr,
221 | mode,
222 | num_patches_x=H_in,
223 | num_patches_y=W_in,
224 | use_homogeneous=use_homogeneous,
225 | )
226 |
227 | vis = rays.depths[ii]
228 | if len(rays.depths[ii].shape) < 2:
229 | vis = rays.depths[ii].reshape(H_out, W_out)
230 |
231 | axs[r, c].imshow(vis.cpu())
232 | axs[r, c].set_title(f"T={T - t}")
233 |
234 | i += 1
235 | r, c = i // 4, i % 4 + 8
236 |
237 | vis = depths[index][ii]
238 | if len(rays.depths[ii].shape) < 2:
239 | vis = depths[index][ii].reshape(256, 256)
240 |
241 | axs[r, c].imshow(vis.cpu())
242 | axs[r, c].set_title(f"GT Depths")
243 |
244 | axs[2, -1].imshow(image_vis)
245 | axs[2, -1].set_title("Input Image")
246 | for s in ["bottom", "top", "left", "right"]:
247 | axs[2, -1].spines[s].set_color(cmap(ii / (num_images)))
248 | axs[2, -1].spines[s].set_linewidth(5)
249 |
250 | for ax in axs.flatten():
251 | ax.set_xticks([])
252 | ax.set_yticks([])
253 | plt.tight_layout()
254 | img = plot_to_image(fig)
255 | plt.close()
256 | per_sample_images.append(img)
257 |
258 | if return_first:
259 | rays_camera = pred_intermediate[0][index]
260 | elif pred_x0:
261 | rays_camera = pred_intermediate[-1][index]
262 | else:
263 | rays_camera = rays_final[index]
264 | rays = Rays.from_spatial(
265 | rays_camera,
266 | mode=mode,
267 | num_patches_x=H_in,
268 | num_patches_y=W_in,
269 | use_homogeneous=use_homogeneous,
270 | )
271 | if calculate_intrinsics:
272 | pred_camera = rays_to_cameras_homography(
273 | rays=rays[ii, None],
274 | crop_parameters=crop_parameters[index],
275 | num_patches_x=W_in,
276 | num_patches_y=H_in,
277 | average_centers=average_centers,
278 | depth_resolution=model.depth_resolution,
279 | )
280 | else:
281 | pred_camera = rays_to_cameras(
282 | rays=rays[ii, None],
283 | crop_parameters=crop_parameters[index],
284 | no_crop_param_device=no_crop_param_device,
285 | num_patches_x=W_in,
286 | num_patches_y=H_in,
287 | depth_resolution=model.depth_resolution,
288 | average_centers=average_centers,
289 | )
290 | pred_cameras.append(pred_camera[0])
291 | pred_rays.append(rays)
292 |
293 | pred_cameras_batched.append(pred_cameras)
294 | vis_images.append(np.vstack(per_sample_images))
295 |
296 | return vis_images, pred_cameras_batched, pred_rays
297 |
298 |
299 | def plot_to_image(figure, dpi=100):
300 | """Converts matplotlib fig to a png for logging with tf.summary.image."""
301 | buffer = io.BytesIO()
302 | figure.savefig(buffer, format="raw", dpi=dpi)
303 | plt.close(figure)
304 | buffer.seek(0)
305 | image = np.reshape(
306 | np.frombuffer(buffer.getvalue(), dtype=np.uint8),
307 | newshape=(int(figure.bbox.bounds[3]), int(figure.bbox.bounds[2]), -1),
308 | )
309 | return image[..., :3]
310 |
311 |
312 | def view_color_coded_images_from_tensor(images, depth=False):
313 | num_frames = images.shape[0]
314 | cmap = plt.get_cmap("hsv")
315 | num_rows = 3
316 | num_cols = 3
317 | figsize = (num_cols * 2, num_rows * 2)
318 | fig, axs = plt.subplots(num_rows, num_cols, figsize=figsize)
319 | axs = axs.flatten()
320 | for i in range(num_rows * num_cols):
321 | if i < num_frames:
322 | if images[i].shape[0] == 3:
323 | image = images[i].permute(1, 2, 0)
324 | else:
325 | image = images[i].unsqueeze(-1)
326 |
327 | if not depth:
328 | image = image * 0.5 + 0.5
329 | else:
330 | image = image.repeat((1, 1, 3)) / torch.max(image)
331 |
332 | axs[i].imshow(image)
333 | for s in ["bottom", "top", "left", "right"]:
334 | axs[i].spines[s].set_color(cmap(i / (num_frames)))
335 | axs[i].spines[s].set_linewidth(5)
336 | axs[i].set_xticks([])
337 | axs[i].set_yticks([])
338 | else:
339 | axs[i].axis("off")
340 | plt.tight_layout()
341 | return fig
342 |
343 |
344 | def color_and_filter_points(points, images, mask, num_show, resolution):
345 | # Resize images
346 | resize = torchvision.transforms.Resize(resolution)
347 | images = resize(images) * 0.5 + 0.5
348 |
349 | # Reshape points and calculate mask
350 | points = points.reshape(num_show * resolution * resolution, 3)
351 | mask = mask.reshape(num_show * resolution * resolution)
352 | depth_mask = torch.argwhere(mask > 0.5)[:, 0]
353 | points = points[depth_mask]
354 |
355 | # Mask and reshape colors
356 | colors = images.permute(0, 2, 3, 1).reshape(num_show * resolution * resolution, 3)
357 | colors = colors[depth_mask]
358 |
359 | return points, colors
360 |
361 |
362 | def filter_and_align_point_clouds(
363 | num_frames,
364 | gt_points,
365 | pred_points,
366 | gt_masks,
367 | pred_masks,
368 | images,
369 | metrics=False,
370 | num_patches_x=16,
371 | ):
372 |
373 | # Filter and color points
374 | gt_points, gt_colors = color_and_filter_points(
375 | gt_points, images, gt_masks, num_show=num_frames, resolution=num_patches_x
376 | )
377 | pred_points, pred_colors = color_and_filter_points(
378 | pred_points, images, pred_masks, num_show=num_frames, resolution=num_patches_x
379 | )
380 |
381 | pred_points, _, _, _ = compute_optimal_alignment(
382 | gt_points.float(), pred_points.float()
383 | )
384 |
385 | # Scale PCL so that furthest point from centroid is distance 1
386 | centroid = torch.mean(gt_points, dim=0)
387 | dists = torch.norm(gt_points - centroid.unsqueeze(0), dim=-1)
388 | scale = torch.mean(dists)
389 | gt_points_scaled = (gt_points - centroid) / scale
390 | pred_points_scaled = (pred_points - centroid) / scale
391 |
392 | if metrics:
393 |
394 | cd, _ = chamfer_distance(
395 | pred_points_scaled.unsqueeze(0), gt_points_scaled.unsqueeze(0)
396 | )
397 | cd = cd.item()
398 | mse = torch.mean(
399 | torch.norm(pred_points_scaled - gt_points_scaled, dim=-1), dim=-1
400 | ).item()
401 | else:
402 | mse, cd = None, None
403 |
404 | return (
405 | gt_points,
406 | pred_points,
407 | gt_colors,
408 | pred_colors,
409 | [mse, cd, None],
410 | )
411 |
412 |
413 | def add_scene_cam(scene, c2w, edge_color, image=None, focal=None, imsize=None, screen_width=0.03):
414 | OPENGL = np.array([
415 | [1, 0, 0, 0],
416 | [0, -1, 0, 0],
417 | [0, 0, -1, 0],
418 | [0, 0, 0, 1]
419 | ])
420 |
421 | if image is not None:
422 | H, W, THREE = image.shape
423 | assert THREE == 3
424 | if image.dtype != np.uint8:
425 | image = np.uint8(255*image)
426 | elif imsize is not None:
427 | W, H = imsize
428 | elif focal is not None:
429 | H = W = focal / 1.1
430 | else:
431 | H = W = 1
432 |
433 | if focal is None:
434 | focal = min(H, W) * 1.1 # default value
435 | elif isinstance(focal, np.ndarray):
436 | focal = focal[0]
437 |
438 | # create fake camera
439 | height = focal * screen_width / H
440 | width = screen_width * 0.5**0.5
441 | rot45 = np.eye(4)
442 | rot45[:3, :3] = Rotation.from_euler('z', np.deg2rad(45)).as_matrix()
443 | rot45[2, 3] = -height # set the tip of the cone = optical center
444 | aspect_ratio = np.eye(4)
445 | aspect_ratio[0, 0] = W/H
446 | transform = c2w @ OPENGL @ aspect_ratio @ rot45
447 | cam = trimesh.creation.cone(width, height, sections=4)
448 |
449 | # this is the camera mesh
450 | rot2 = np.eye(4)
451 | rot2[:3, :3] = Rotation.from_euler('z', np.deg2rad(4)).as_matrix()
452 | vertices = cam.vertices
453 | vertices_offset = 0.9 * cam.vertices
454 | vertices = np.r_[vertices, vertices_offset, geotrf(rot2, cam.vertices)]
455 | vertices = geotrf(transform, vertices)
456 | faces = []
457 | for face in cam.faces:
458 | if 0 in face:
459 | continue
460 | a, b, c = face
461 | a2, b2, c2 = face + len(cam.vertices)
462 |
463 | # add 3 pseudo-edges
464 | faces.append((a, b, b2))
465 | faces.append((a, a2, c))
466 | faces.append((c2, b, c))
467 |
468 | faces.append((a, b2, a2))
469 | faces.append((a2, c, c2))
470 | faces.append((c2, b2, b))
471 |
472 | # no culling
473 | faces += [(c, b, a) for a, b, c in faces]
474 |
475 | for i,face in enumerate(cam.faces):
476 | if 0 in face:
477 | continue
478 |
479 | if i == 1 or i == 5:
480 | a, b, c = face
481 | faces.append((a, b, c))
482 |
483 | cam = trimesh.Trimesh(vertices=vertices, faces=faces)
484 | cam.visual.face_colors[:, :3] = edge_color
485 |
486 | scene.add_geometry(cam)
487 |
488 |
489 | def geotrf(Trf, pts, ncol=None, norm=False):
490 | """ Apply a geometric transformation to a list of 3-D points.
491 |
492 |     Trf: 3x3 or 4x4 projection matrix (typically a homography)
493 |     pts: numpy/torch/tuple of coordinates. Shape must be (...,2) or (...,3)
494 |
495 |     ncol: int. number of columns of the result (2 or 3)
496 |     norm: float. if != 0, the result is projected on the z=norm plane.
497 |
498 | Returns an array of projected 2d points.
499 | """
500 | assert Trf.ndim >= 2
501 | if isinstance(Trf, np.ndarray):
502 | pts = np.asarray(pts)
503 | elif isinstance(Trf, torch.Tensor):
504 | pts = torch.as_tensor(pts, dtype=Trf.dtype)
505 |
506 | # adapt shape if necessary
507 | output_reshape = pts.shape[:-1]
508 | ncol = ncol or pts.shape[-1]
509 |
510 | # optimized code
511 | if (isinstance(Trf, torch.Tensor) and isinstance(pts, torch.Tensor) and
512 | Trf.ndim == 3 and pts.ndim == 4):
513 | d = pts.shape[3]
514 | if Trf.shape[-1] == d:
515 | pts = torch.einsum("bij, bhwj -> bhwi", Trf, pts)
516 | elif Trf.shape[-1] == d+1:
517 | pts = torch.einsum("bij, bhwj -> bhwi", Trf[:, :d, :d], pts) + Trf[:, None, None, :d, d]
518 | else:
519 | raise ValueError(f'bad shape, not ending with 3 or 4, for {pts.shape=}')
520 | else:
521 | if Trf.ndim >= 3:
522 | n = Trf.ndim-2
523 | assert Trf.shape[:n] == pts.shape[:n], 'batch size does not match'
524 | Trf = Trf.reshape(-1, Trf.shape[-2], Trf.shape[-1])
525 |
526 | if pts.ndim > Trf.ndim:
527 | # Trf == (B,d,d) & pts == (B,H,W,d) --> (B, H*W, d)
528 | pts = pts.reshape(Trf.shape[0], -1, pts.shape[-1])
529 | elif pts.ndim == 2:
530 | # Trf == (B,d,d) & pts == (B,d) --> (B, 1, d)
531 | pts = pts[:, None, :]
532 |
533 | if pts.shape[-1]+1 == Trf.shape[-1]:
534 | Trf = Trf.swapaxes(-1, -2) # transpose Trf
535 | pts = pts @ Trf[..., :-1, :] + Trf[..., -1:, :]
536 | elif pts.shape[-1] == Trf.shape[-1]:
537 | Trf = Trf.swapaxes(-1, -2) # transpose Trf
538 | pts = pts @ Trf
539 | else:
540 | pts = Trf @ pts.T
541 | if pts.ndim >= 2:
542 | pts = pts.swapaxes(-1, -2)
543 |
544 | if norm:
545 | pts = pts / pts[..., -1:] # DONT DO /= BECAUSE OF WEIRD PYTORCH BUG
546 | if norm != 1:
547 | pts *= norm
548 |
549 | res = pts[..., :ncol].reshape(*output_reshape, ncol)
550 | return res
--------------------------------------------------------------------------------
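`geotrf` applies a 3x3 or 4x4 transform to points given in the matching or one-lower dimension, handling the homogeneous row and batching automatically. A tiny sketch with a pure translation (the transform and points are made up):

```
import numpy as np

from diffusionsfm.utils.visualization import geotrf

# 4x4 rigid transform: identity rotation plus a translation of (1, 2, 3).
Trf = np.eye(4)
Trf[:3, 3] = [1.0, 2.0, 3.0]

pts = np.zeros((5, 3))  # five points at the origin
out = geotrf(Trf, pts)

assert out.shape == (5, 3)
assert np.allclose(out, [1.0, 2.0, 3.0])
```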
/docs/eval.md:
--------------------------------------------------------------------------------
1 | ## Evaluation Directions
2 |
3 | Use the scripts from `diffusionsfm/eval` to evaluate the performance of the dense model on the CO3D dataset:
4 |
5 | ```
6 | python -m diffusionsfm.eval.eval_jobs --eval_path output/multi_diffusionsfm_dense --use_submitit
7 | ```
8 |
9 | **Note:** The `use_submitit` flag is optional. If you have a SLURM system available, enabling it will dispatch jobs in parallel across available GPUs, significantly accelerating the evaluation process.
10 |
11 | The expected output at the end of evaluating the dense model is:
12 |
13 | ```
14 | N= 2 3 4 5 6 7 8
15 | Seen R 0.926 0.941 0.946 0.950 0.953 0.955 0.955
16 | Seen CC 1.000 0.956 0.934 0.924 0.917 0.911 0.907
17 | Seen CD 0.023 0.023 0.026 0.026 0.028 0.031 0.030
18 | Seen CD_Obj 0.040 0.037 0.033 0.032 0.032 0.032 0.033
19 | Unseen R 0.913 0.928 0.938 0.945 0.950 0.951 0.953
20 | Unseen CC 1.000 0.926 0.884 0.870 0.864 0.851 0.847
21 | Unseen CD 0.024 0.024 0.025 0.024 0.025 0.026 0.027
22 | Unseen CD_Obj 0.028 0.023 0.022 0.022 0.023 0.021 0.020
23 | ```
24 |
25 | This reports rotation and camera center accuracy, as well as Chamfer Distance computed on all points (CD) and on foreground points only (CD_Obj), evaluated on held-out sequences from both seen and unseen object categories with varying numbers of input images. Performance is averaged over five runs to reduce variance.
26 |
27 | Note that minor variations in the reported numbers may occur due to randomness in the evaluation and inference processes.
--------------------------------------------------------------------------------
/docs/train.md:
--------------------------------------------------------------------------------
1 | ## Training Directions
2 |
3 | ### Prepare CO3D Dataset
4 |
5 | Please refer to the instructions from [RayDiffusion](https://github.com/jasonyzhang/RayDiffusion/blob/main/docs/train.md#training-directions) to set up the CO3D dataset.
6 |
7 | ### Setting up `accelerate`
8 |
9 | Use `accelerate config` to set up `accelerate`. We recommend using multiple GPUs without any mixed precision (we handle AMP ourselves).
10 |
11 | ### Training models
12 |
13 | Our model is trained in two stages. In the first stage, we train a *sparse model* that predicts ray origins and endpoints at a low resolution (16×16). In the second stage, we initialize the dense model using the DiT weights from the sparse model and append a DPT decoder to produce high-resolution outputs (256×256 ray origins and endpoints).
14 |
15 | To train the sparse model, run:
16 |
17 | ```
18 | accelerate launch --multi_gpu --gpu_ids 0,1,2,3,4,5,6,7 --num_processes 8 train.py \
19 | training.batch_size=8 \
20 | training.max_iterations=400000 \
21 | model.num_images=8 \
22 | dataset.name=co3d \
23 | debug.project_name=diffusionsfm_co3d \
24 | debug.run_name=co3d_diffusionsfm_sparse
25 | ```
26 |
27 | To train the dense model (initialized from the sparse model weights), run:
28 |
29 | ```
30 | accelerate launch --multi_gpu --gpu_ids 0,1,2,3,4,5,6,7 --num_processes 8 train.py \
31 | training.batch_size=4 \
32 | training.max_iterations=800000 \
33 | model.num_images=8 \
34 | dataset.name=co3d \
35 | debug.project_name=diffusionsfm_co3d \
36 | debug.run_name=co3d_diffusionsfm_dense \
37 | training.dpt_head=True \
38 | training.full_num_patches_x=256 \
39 | training.full_num_patches_y=256 \
40 | training.gradient_clipping=True \
41 | training.reinit=True \
42 | training.freeze_encoder=True \
43 | model.freeze_transformer=True \
44 | training.pretrain_path=.pth
45 | ```
46 |
47 | Some notes:
48 |
49 | - `batch_size` refers to the batch size per GPU. The total batch size will be `batch_size * num_gpu`.
50 | - Depending on your setup, you can adjust the number of GPUs and batch size. You may also need to adjust the number of training iterations accordingly.
51 | - You can resume training from a checkpoint by specifying `train.resume=True hydra.run.dir=/path/to/your/output_dir`.
52 | - If you are getting NaNs, try turning off mixed precision. This will increase the amount of memory used.
53 |
54 | For debugging, we recommend using a single-GPU job with a single category:
55 |
56 | ```
57 | accelerate launch train.py training.batch_size=4 dataset.category=apple debug.wandb=False hydra.run.dir=output_debug
58 | ```
--------------------------------------------------------------------------------
/gradio_app.py:
--------------------------------------------------------------------------------
1 | import os
2 | import time
3 | import shutil
4 | import argparse
5 | import functools
6 | import torch
7 | import torchvision
8 | from PIL import Image
9 | import gradio as gr
10 | import numpy as np
11 | import matplotlib.pyplot as plt
12 | import trimesh
13 |
14 | from diffusionsfm.dataset.custom import CustomDataset
15 | from diffusionsfm.dataset.co3d_v2 import unnormalize_image
16 | from diffusionsfm.inference.load_model import load_model
17 | from diffusionsfm.inference.predict import predict_cameras
18 | from diffusionsfm.utils.visualization import add_scene_cam
19 |
20 |
21 | def info_fn():
22 | gr.Info("Data preprocessing completed!")
23 |
24 |
25 | def get_select_index(evt: gr.SelectData):
26 | selected = evt.index
27 | return examples_full[selected][0], selected
28 |
29 |
30 | def check_img_input(control_image):
31 | if control_image is None:
32 | raise gr.Error("Please select or upload an input image.")
33 |
34 |
35 | def preprocess(args, image_block, selected):
36 | cate_name = time.strftime("%m%d_%H%M%S") if selected is None else examples_list[selected]
37 |
38 | demo_dir = os.path.join(args.output_dir, f'demo/{cate_name}')
39 | shutil.rmtree(demo_dir, ignore_errors=True)
40 |
41 | os.makedirs(os.path.join(demo_dir, 'source'), exist_ok=True)
42 | os.makedirs(os.path.join(demo_dir, 'processed'), exist_ok=True)
43 |
44 | dataset = CustomDataset(image_block)
45 | batch = dataset.get_data()
46 | batch['cate_name'] = cate_name
47 |
48 | processed_image_block = []
49 | for i, file_path in enumerate(image_block):
50 | file_name = os.path.basename(file_path)
51 | raw_img = Image.open(file_path)
52 | try:
53 | raw_img.save(os.path.join(demo_dir, 'source', file_name))
54 | except OSError:
55 | raw_img.convert('RGB').save(os.path.join(demo_dir, 'source', file_name))
56 |
57 | batch['image_for_vis'][i].save(os.path.join(demo_dir, 'processed', file_name))
58 | processed_image_block.append(os.path.join(demo_dir, 'processed', file_name))
59 |
60 | return processed_image_block, batch
61 |
62 |
63 | def transform_cameras(pred_cameras):
64 | num_cameras = pred_cameras.R.shape[0]
65 | Rs = pred_cameras.R.transpose(1, 2).detach()
66 | ts = pred_cameras.T.unsqueeze(-1).detach()
67 | c2ws = torch.zeros(num_cameras, 4, 4)
68 | c2ws[:, :3, :3] = Rs
69 | c2ws[:, :3, -1:] = ts
70 | c2ws[:, 3, 3] = 1
71 | c2ws[:, :2] *= -1 # PyTorch3D to OpenCV
72 | c2ws = torch.linalg.inv(c2ws).numpy()
73 |
74 | return c2ws
75 |
76 |
77 | def run_inference(args, cfg, model, batch):
78 | device = args.device
79 | images = batch["image"].to(device)
80 | crop_parameters = batch["crop_parameters"].to(device)
81 |
82 | (pred_cameras, pred_rays), _ = predict_cameras(
83 | model=model,
84 | images=images,
85 | device=device,
86 | crop_parameters=crop_parameters,
87 | stop_iteration=90,
88 | num_patches_x=cfg.training.full_num_patches_x,
89 | num_patches_y=cfg.training.full_num_patches_y,
90 | calculate_intrinsics=True,
91 | max_num_images=8,
92 | mode="segment",
93 | return_rays=True,
94 | use_homogeneous=True,
95 | seed=0,
96 | )
97 |
98 | # Unnormalize and resize input images
99 | images = unnormalize_image(images, return_numpy=False, return_int=False)
100 | images = torchvision.transforms.Resize(256)(images)
101 | rgbs = images.permute(0, 2, 3, 1).contiguous().view(-1, 3)
102 | xyzs = pred_rays.get_segments().view(-1, 3).cpu()
103 |
104 | # Create point cloud and scene
105 | scene = trimesh.Scene()
106 | point_cloud = trimesh.points.PointCloud(xyzs, colors=rgbs)
107 | scene.add_geometry(point_cloud)
108 |
109 | # Add predicted cameras to the scene
110 | num_images = images.shape[0]
111 | c2ws = transform_cameras(pred_cameras)
112 | cmap = plt.get_cmap("hsv")
113 |
114 | for i, c2w in enumerate(c2ws):
115 | color_rgb = (np.array(cmap(i / num_images))[:3] * 255).astype(int)
116 | add_scene_cam(
117 | scene=scene,
118 | c2w=c2w,
119 | edge_color=color_rgb,
120 | image=None,
121 | focal=None,
122 | imsize=(256, 256),
123 | screen_width=0.1
124 | )
125 |
126 | # Export GLB
127 | cate_name = batch['cate_name']
128 | output_path = os.path.join(args.output_dir, f'demo/{cate_name}/{cate_name}.glb')
129 | scene.export(output_path)
130 |
131 | return output_path
132 |
133 |
134 | if __name__ == "__main__":
135 | parser = argparse.ArgumentParser()
136 | parser.add_argument('--output_dir', default='output/multi_diffusionsfm_dense', type=str, help='Output directory')
137 | parser.add_argument('--device', default='cuda', type=str, help='Device to run inference on')
138 | args = parser.parse_args()
139 |
140 | _TITLE = "DiffusionSfM: Predicting Structure and Motion via Ray Origin and Endpoint Diffusion"
141 | _DESCRIPTION = """
142 |