├── .gitignore
├── README.md
├── asset
│   └── teasure_figure.png
├── config
│   ├── base_config.yaml
│   └── experiments
│       └── generation.yaml
├── download.sh
├── inference
│   └── generate.py
├── model
│   ├── gsdecoder
│   │   ├── camera_embedding.py
│   │   ├── cuda_splatting.py
│   │   ├── decoder_splatting.py
│   │   ├── gs_decoder_architecture.py
│   │   └── load_gsdecoder.py
│   ├── multiview_rf
│   │   ├── load_mv_sd3.py
│   │   ├── mv_sd3_architecture.py
│   │   └── text_embedding.py
│   ├── refiner
│   │   ├── camera_util.py
│   │   ├── gs_util.py
│   │   └── sds_pp_refiner.py
│   └── util.py
├── requirements.txt
└── util
    ├── camera_visualization.py
    └── dist_util.py
/.gitignore:
--------------------------------------------------------------------------------
1 | .vscode
2 | baseline_results
3 | checkpoints
4 | final
5 | final_latest_4
6 | output
7 | point_clouds
8 | results
9 | sample_ablation
10 | sample_scene
11 | samples
12 | teasure
13 | tmp
14 | wandb
15 | __pycache__
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # [CVPR 2025] SplatFlow: Multi-View Rectified Flow Model for 3D Gaussian Splatting Synthesis
2 |
3 |
4 |
5 |
6 |
7 |
8 | This repository contains the official PyTorch implementation of the paper "SplatFlow: Multi-View Rectified Flow Model for 3D Gaussian Splatting Synthesis".
9 | 
10 |
11 |
12 | ## Installation
13 | Our code is tested with Python 3.10, PyTorch 2.4.0, and CUDA 11.8. To install the required packages, run the following commands:
14 | ```bash
15 | pip3 install torch==2.4.0 torchvision==0.19.0 torchaudio==2.4.0 --index-url https://download.pytorch.org/whl/cu118
16 | pip install -r requirements.txt
17 | ```
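As a quick sanity check, you can confirm that the CUDA build of PyTorch was installed correctly:

```bash
python -c "import torch; print(torch.__version__, torch.cuda.is_available())"
```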
18 |
19 | ## Model Checkpoints
20 | We have extended the model training period after submission to enhance its performance. Updated model checkpoints are now available, and benchmark results will be revised following the review process. You can download the model checkpoints with the following command:
21 | ```bash
22 | bash download.sh
23 | ```
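The script places both checkpoints under `checkpoints/`, which matches the default `inference.gsdecoder_ckpt` and `inference.mv_rf_ckpt` paths in `config/base_config.yaml`:

```bash
ls checkpoints/
# expected: gs_decoder.pt  mv_rf_ema.pt
```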
24 |
25 |
26 |
27 | ## Inference
28 | ### 1. 3DGS Generation
29 |
30 | ```bash
31 | export PYTHONPATH=$(pwd)
32 | huggingface-cli login
33 | python inference/generate.py +experiments=generation inference.generate.prompt="Your prompt here"
34 | ```
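Sampling options live under `inference.sample` in `config/base_config.yaml` and can be overridden with the usual Hydra syntax; the values below are illustrative, not recommended settings:

```bash
python inference/generate.py +experiments=generation \
    inference.generate.prompt="Your prompt here" \
    inference.sample.num_steps=100 inference.sample.sd3_guidance=false
```

Outputs are written to `inference.generate.save_path` (default `samples/<prompt>/`): the initial `gaussian.ply`, the refined `refined_gaussian.ply`, camera parameters in `cameras.pt`, per-view renderings under `mv_results/`, a camera-pose plot, and a rendered video of the refined scene.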
35 |
36 | ## Citation
37 | If you find this repository helpful for your project, please consider citing our work. :)
38 | ```
39 | @article{go2024splatflow,
40 | title={SplatFlow: Multi-View Rectified Flow Model for 3D Gaussian Splatting Synthesis},
41 | author={Go, Hyojun and Park, Byeongjun and Jang, Jiho and Kim, Jin-Young and Kwon, Soonwoo and Kim, Changick},
42 | journal={arXiv preprint arXiv:2411.16443},
43 | year={2024}
44 | }
45 | ```
46 |
47 | ## Acknowledgement
48 | We thank [Director3D](https://github.com/imlixinyang/Director3D).
49 |
50 |
51 | ## TODO:
52 | - [ ] Update project page.
53 | - [x] Add model checkpoints
54 | - [x] Code verification.
55 | - [ ] Add more details to the README.
56 | - [ ] Add the training script.
57 | - [ ] Share text annotations for the RealEstate10K, DL3DV-10K, MVImgNet, and ACID datasets (used in our follow-up work; may be integrated with https://github.com/gohyojun15/VideoRFSplat).
58 |
--------------------------------------------------------------------------------
/asset/teasure_figure.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/gohyojun15/SplatFlow/6964764ca8ac4ea2ca7c6e4e2f662457cec26040/asset/teasure_figure.png
--------------------------------------------------------------------------------
/config/base_config.yaml:
--------------------------------------------------------------------------------
1 |
2 | general:
3 | iterations: 0
4 | num_workers: 0
5 | global_batch_size: 0
6 | global_seed: 42
7 | gpu_offset: 0
8 | mixed_precision: true
9 | sampled_view: 8
10 | debug: true
11 |
12 | optim:
13 | lr: 0
14 | weight_decay: 0
15 | warmup_steps: 0
16 |
17 | gsdecoder:
18 | gan_loss:
19 | enable: true
20 | disc_start: 0
21 | disc_weight: 0.1
22 | loading_ckpt_path: Null
23 |
24 | mv_rf_model:
25 | resume_from_ckpt: Null
26 | weighting_scheme: "logit_normal"
27 | logit_mean: 0
28 | logit_std: 1.0
29 | mode_scale: 1.29
30 | precondition_outputs: true
31 | hf_path: "stabilityai/stable-diffusion-3-medium-diffusers"
32 |
33 | depth_encoder:
34 | hf_path: "depth-anything/Depth-Anything-V2-Small-hf"
35 |
36 | defaults:
37 | - _self_
38 | - override hydra/job_logging: disabled
39 | - override hydra/hydra_logging: disabled
40 |
41 |
42 | # Inference configuration.
43 | inference:
44 | mv_rf_ckpt: checkpoints/mv_rf_ema.pt
45 | gsdecoder_ckpt: checkpoints/gs_decoder.pt
46 | sample:
47 | num_steps: 200
48 | cfg: true
49 | cfg_scale: [7.0, 5.0, 1.0]
50 | stop_ray: 50
51 | sd3_guidance: true
52 | sd3_cfg: true
53 | sd3_scale: 3.0
54 |
55 | hydra:
56 | run:
57 | dir: .
58 | output_subdir: null
--------------------------------------------------------------------------------
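A note on the 38-channel latent that this configuration and `inference/generate.py` operate on: each view stacks 16 image channels, 16 depth channels, and 6 Plücker-ray channels, and the three `inference.sample.cfg_scale` values are applied to exactly those groups, in that order. A minimal sketch of the split, mirroring the generation script:

```python
import torch

# (batch, views, channels, height, width) latent, shaped as in inference/generate.py
latents = torch.randn(1, 8, 38, 32, 32)

# 16 image + 16 depth + 6 Plücker-ray channels
img_latent, depth_latent, ray_latent = latents.split([16, 16, 6], dim=2)
print(img_latent.shape, depth_latent.shape, ray_latent.shape)
```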
/config/experiments/generation.yaml:
--------------------------------------------------------------------------------
1 | # @package _global_
2 |
3 | inference:
4 | # generation
5 | generate:
6 | save_path: "samples"
7 | prompt: "A blue car on a clean street with buildings in the background"
8 | refiner:
9 | args:
10 | sd_model_key: 'stabilityai/stable-diffusion-2-1-base'
11 | num_views: 1
12 | img_size: 512
13 | guidance_scale: 7.5
14 | min_step_percent: 0.02
15 | max_step_percent: 0.5
16 | num_densifications: 4
17 | lr_scale: 0.25
18 | lrs: {'xyz': 2e-4, 'features': 1e-2, 'opacity': 5e-2, 'scales': 1e-3, 'rotations': 1e-2, 'embeddings': 1e-2}
19 | use_lods: True
20 | lambda_latent_sds: 1
21 | lambda_image_sds: 0.1
22 | lambda_image_variation: 0.001
23 | opacity_threshold: 0.001
24 | text_templete: $text$
25 | negative_text_templete: 'unclear. noisy. point cloud. low-res. low-quality. low-resolution. unrealistic.'
26 | total_iterations: 1000
--------------------------------------------------------------------------------
/download.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | # Via google drive
3 | # gdown 1Ch9YK0eA7-alMIKK8NxKaoyJ7rWZiqif
4 | # gdown 1BUXCmR7jDTiGfbV55NHf7dZp6GsV5LCf
5 |
6 |
7 | mkdir -p checkpoints
8 | cd checkpoints
9 | wget "https://huggingface.co/HJGO/splatflow/resolve/main/gs_decoder.pt?download=true" -O gs_decoder.pt
10 | wait
11 |
12 | wget "https://huggingface.co/HJGO/splatflow/resolve/main/mv_rf_ema.pt?download=true" -O mv_rf_ema.pt
13 | wait
--------------------------------------------------------------------------------
/inference/generate.py:
--------------------------------------------------------------------------------
1 | import os
2 | from pathlib import Path
3 |
4 | import hydra
5 | import matplotlib.pyplot as plt
6 | import torch
7 | import torchvision
8 | from diffusers import FlowMatchEulerDiscreteScheduler
9 | from einops import rearrange, repeat
10 | from pytorch3d.utils import cameras_from_opencv_projection
11 | from scipy.spatial.transform import Rotation as R
12 | from tqdm import tqdm
13 |
14 | from model.gsdecoder.camera_embedding import get_plucker_rays, optimize_plucker_ray
15 | from model.gsdecoder.load_gsdecoder import create_gsdecoder
16 | from model.multiview_rf.load_mv_sd3 import create_sd_multiview_rf_model
17 | from model.multiview_rf.text_embedding import compute_text_embeddings
18 | from model.refiner.camera_util import (
19 | export_mv,
20 | export_ply_for_gaussians,
21 | export_video,
22 | load_ply_for_gaussians,
23 | )
24 | from model.refiner.sds_pp_refiner import GSRefinerSDSPlusPlus
25 | from model.util import create_sd3_transformer, create_vae
26 | from util import camera_visualization, dist_util
27 |
28 |
29 | def visualize_camera_pose(cameras, path):
30 | # visualize cameras
31 | extrinsic = cameras[0][:, :, 3:]
32 | device = cameras.device
33 | num_view = len(extrinsic)
34 | homo_extrinsic = repeat(torch.eye(4).to(device, dtype=extrinsic.dtype), 'i j -> v i j', v=num_view).clone()
35 | homo_extrinsic[:, :3] = extrinsic
36 | w2c = homo_extrinsic.inverse()
37 | image_size = repeat(torch.tensor([256, 256], device=device), "i -> v i", v=num_view)
38 | cameras_ours = cameras_from_opencv_projection(
39 | R=w2c[:, :3, :3],
40 | tvec=w2c[:, :3, -1],
41 | camera_matrix=cameras[0, :, :, :3],
42 | image_size=image_size,
43 | )
44 | fig = plt.figure()
45 | ax = fig.add_subplot(projection="3d")
46 | ax.clear()
47 | points = camera_visualization.plot_cameras(ax, cameras_ours)
48 | cc = points[:, -1]
49 | max_scene = cc.max(dim=0)[0].cpu()
50 | min_scene = cc.min(dim=0)[0].cpu()
51 | ax.set_xticklabels([])
52 | ax.set_yticklabels([])
53 | ax.set_zticklabels([])
54 | ax.set_xlim3d([min_scene[0] - 0.1, max_scene[0] + 0.3])
55 | ax.set_ylim3d([min_scene[2] - 0.1, max_scene[2] + 0.1])
56 | ax.set_zlim3d([min_scene[1] - 0.1, max_scene[1] + 0.1])
57 | ax.invert_yaxis()
58 | plt.savefig(os.path.join(path, "pose.pdf"), bbox_inches="tight", pad_inches=0, transparent=True)
59 |
60 | def save_results(path, output, poses, Ks):
61 | means, covariance, opacity, rgb, rotation, scale = output
62 | _, v, h, w, _ = means.shape
63 |
64 | means = rearrange(means, "() v h w xyz -> (v h w) xyz")
65 | opacity = rearrange(opacity, "() v h w o -> (v h w) o")
66 | rgb = rearrange(rgb, "() v h w rgb -> (v h w) rgb")
67 | rotation = rearrange(rotation, "() v h w q -> (v h w) q")
68 | scale = rearrange(scale, "() v h w s -> (v h w) s")
69 | source_rotations = repeat(poses[..., :3, :3], "() v i j -> (v h w) i j", h=h, w=w)
70 |
71 | cam_rotation_matrix = R.from_quat(rotation.detach().cpu().numpy()).as_matrix()
72 | world_rotation_matrix = source_rotations.detach().cpu().numpy() @ cam_rotation_matrix
73 | world_rotations = R.from_matrix(world_rotation_matrix).as_quat()
74 | world_rotations = torch.from_numpy(world_rotations).to(source_rotations.device)
75 |
76 | export_ply_for_gaussians(os.path.join(path, 'gaussian'), (means, rgb[:, None], opacity, scale, world_rotations))
77 | cameras = torch.cat([Ks[0], poses[0, :, :3]], dim=-1)
78 | torch.save(cameras, path / "cameras.pt")
79 |
80 |
81 | def generate_sampling(
82 | text_prompt,
83 | text_encoders,
84 | tokenizers,
85 | model,
86 | gs_decoder,
87 | vae,
88 | noise_scheduler,
89 | stable_diffusion3_transformer,
90 | cfg,
91 | device,
92 | dtype,
93 | ):
94 | """
95 | Generate a sample, assuming a batch size of 1.
96 | """
97 | batch_size = 1
98 |
99 | if isinstance(text_prompt, str):
100 | text_prompt = [text_prompt]
101 | negative_prompt = [""]
102 |
103 | noise_scheduler.set_timesteps(cfg.inference.sample.num_steps)
104 | timesteps = noise_scheduler.timesteps
105 | latents = torch.randn(batch_size, 8, 38, 32, 32, device=device, dtype=dtype)
106 |
107 | with torch.no_grad():
108 | prompt_embeds, pooled_prompt_embeds = compute_text_embeddings(
109 | text_prompt,
110 | text_encoders,
111 | tokenizers,
112 | 77,
113 | device,
114 | )
115 | if cfg.inference.sample.cfg:
116 | negative_prompt_embeds, pooled_negative_prompt_embeds = compute_text_embeddings(
117 | negative_prompt, text_encoders, tokenizers, 77, device=device
118 | )
119 | prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
120 | pooled_prompt_embeds = torch.cat([pooled_negative_prompt_embeds, pooled_prompt_embeds], dim=0)
121 | cfg_scales = []
122 | cfg_scales.append([cfg.inference.sample.cfg_scale[0]] * 16) # Image
123 | cfg_scales.append([cfg.inference.sample.cfg_scale[1]] * 16) # Depth
124 | cfg_scales.append([cfg.inference.sample.cfg_scale[2]] * 6) # Pose
125 | cfg_scales = torch.tensor(sum(cfg_scales, []), dtype=latents.dtype, device=latents.device)
126 | cfg_scales = repeat(cfg_scales, "i -> () () i () ()")
127 |
128 | for i, t in tqdm(enumerate(timesteps)):
129 | latent_model_input = torch.cat([latents] * 2) if cfg.inference.sample.cfg else latents
130 | timestep = t.expand(latent_model_input.shape[0]).to(device)
131 |
132 | noise_pred = model(
133 | hidden_states=latent_model_input,
134 | timestep=timestep,
135 | encoder_hidden_states=prompt_embeds,
136 | pooled_projections=pooled_prompt_embeds,
137 | return_dict=False,
138 | )[0]
139 |
140 | if cfg.inference.sample.cfg:
141 | noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
142 | noise_pred = noise_pred_uncond + cfg_scales * (noise_pred_text - noise_pred_uncond)
143 |
144 | if noise_scheduler.step_index is None:
145 | noise_scheduler._init_step_index(t)
146 |
147 | if noise_scheduler.step_index <= cfg.inference.sample.stop_ray:
148 | with torch.enable_grad():
149 | original_latent = latents - noise_pred * noise_scheduler.sigmas[noise_scheduler.step_index]
150 | poses, Ks, inv_poses = optimize_plucker_ray(original_latent[:, :, -6:])
151 | clean_ray_latent = get_plucker_rays(poses[:, :, :3], Ks, is_diffusion=True)
152 |
153 | if cfg.inference.sample.sd3_guidance and (i % 3 == 0) and (i < cfg.inference.sample.stop_ray):
154 | image_2d_flatten_latent = rearrange(latent_model_input, "b v c h w -> (b v) c h w")
155 | image_timestep = timestep.repeat(latent_model_input.shape[1])
156 |
157 | if cfg.inference.sample.sd3_cfg:
158 | positive_prompt_embed = prompt_embeds[1].repeat(latent_model_input.shape[1], 1, 1)
159 | negative_prompt_embed = prompt_embeds[0].repeat(latent_model_input.shape[1], 1, 1)
160 | image_prompt_embed = torch.cat([negative_prompt_embed, positive_prompt_embed], dim=0)
161 |
162 | positive_pooled_prompt_embed = pooled_prompt_embeds[1].repeat(latent_model_input.shape[1], 1)
163 | negative_pooled_prompt_embed = pooled_prompt_embeds[0].repeat(latent_model_input.shape[1], 1)
164 | image_pooled_prompt_embed = torch.cat(
165 | [negative_pooled_prompt_embed, positive_pooled_prompt_embed], dim=0
166 | )
167 | else:
168 | image_prompt_embed = prompt_embeds.repeat(latent_model_input.shape[1], 1, 1)
169 | image_pooled_prompt_embed = pooled_prompt_embeds.repeat(latent_model_input.shape[1], 1)
170 |
171 | sd_noise_pred = stable_diffusion3_transformer(
172 | hidden_states=image_2d_flatten_latent[:, :16, :, :],
173 | timestep=image_timestep,
174 | encoder_hidden_states=image_prompt_embed,
175 | pooled_projections=image_pooled_prompt_embed,
176 | return_dict=False,
177 | )[0]
178 |
179 | if cfg.inference.sample.sd3_cfg:
180 | sd_noise_pred_uncond, sd_noise_pred_text = sd_noise_pred.chunk(2)
181 | sd_noise_pred = sd_noise_pred_uncond + cfg.inference.sample.sd3_scale * (
182 | sd_noise_pred_text - sd_noise_pred_uncond
183 | )
184 | sd_noise_pred = rearrange(sd_noise_pred, "(b v) c h w -> b v c h w", v=latent_model_input.shape[1])
185 | noise_pred[:, :, :16, :, :] = sd_noise_pred
186 |
187 | latents = noise_scheduler.step(noise_pred, t, latents, return_dict=False)[0]
188 | # update ray latent
189 | sigmas = noise_scheduler.sigmas[noise_scheduler.step_index]
190 | noise = torch.randn_like(latents[:, :, 32:])
191 | noisy_ray_latent = (1.0 - sigmas) * clean_ray_latent + sigmas * noise
192 | latents[:, :, -6:] = noisy_ray_latent
193 |
194 | with torch.inference_mode():
195 | img_latent, depth_latent, _ = latents.split([16, 16, 6], dim=2)
196 | output = gs_decoder(img_latent, depth_latent, None, poses[:, :, :3], Ks, near=0.05, far=20)
197 | means, covariance, opacity, rgb, rotation, scale = output
198 | # save results
199 | path = Path(f"{cfg.inference.generate.save_path}/{text_prompt[0]}")
200 | os.makedirs(path, exist_ok=True)
201 |
202 | text_prompt_save = path / "text_prompt.txt"
203 | with open(text_prompt_save, "w", encoding="utf-8") as f:
204 | f.write(text_prompt[0])
205 |
206 | # Save 3DGS and camera parameters
207 | save_results(path, output, poses, Ks)
208 |
209 | ### Save images
210 | rendered_images = gs_decoder.eval_rendering(
211 | means.float(),
212 | covariance.float(),
213 | opacity.float(),
214 | rgb.float(),
215 | poses[:, :, :3].float(),
216 | Ks.float(),
217 | near=0.05,
218 | far=20,
219 | )
220 |
221 | rendered_images = rendered_images.clamp(min=0, max=1)
222 | transform = torchvision.transforms.ToPILImage()
223 |
224 | mv_results_path = path / "mv_results"
225 | os.makedirs(mv_results_path, exist_ok=True)
226 |
227 | for view, img in enumerate(rendered_images):
228 | transform(img).save(mv_results_path / f"render_img_{view}.png")
229 |
230 |
231 | @hydra.main(config_path="../config", config_name="base_config.yaml", version_base="1.1")
232 | def main(cfg):
233 | dist_util.setup_dist(cfg.general)
234 | device = dist_util.device()
235 | dtype = torch.float16
236 | print(f"Device: {device}")
237 |
238 | path = Path(f"{cfg.inference.generate.save_path}/{cfg.inference.generate.prompt}")
239 |
240 | if os.path.exists(path / "cameras.pt") and os.path.exists(path / "gaussian.ply"):
241 | print("Step 1 is already done")  # skip the first (generation) step
242 | else:
243 | """
244 | First Step: Generate an initial 3DGS
245 | """
246 | # create vae part
247 | vae = create_vae(cfg)
248 | decoder = create_gsdecoder(cfg)
249 | vae, decoder = vae.to(device=device, dtype=dtype), decoder.to(device=device)
250 | decoder.load_state_dict(torch.load(cfg.inference.gsdecoder_ckpt, map_location="cpu", weights_only=False))
251 | vae.eval(), decoder.eval()
252 |
253 | # MV RF
254 | model, tokenizer, text_encoders = create_sd_multiview_rf_model()
255 | model = model.to(device=device, dtype=dtype)
256 | model.load_state_dict(torch.load(cfg.inference.mv_rf_ckpt, map_location="cpu", weights_only=False))
257 | model.eval()
258 |
259 | # Text encoders
260 | text_encoders_list = []
261 | for i, text_encoder in enumerate(text_encoders):
262 | text_encoder.requires_grad_(False)
263 | text_encoder.to(device, dtype=dtype)
264 | text_encoders_list.append(text_encoder.to(device, dtype=dtype))
265 | text_encoders = text_encoders_list
266 |
267 | noise_scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained(
268 | cfg.mv_rf_model.hf_path,
269 | subfolder="scheduler",
270 | force_download=False,
271 | shift=1.0,
272 | )
273 |
274 | if cfg.inference.sample.sd3_guidance:
275 | stable_diffusion3_transformer = create_sd3_transformer()
276 | stable_diffusion3_transformer = stable_diffusion3_transformer.to(device, dtype=dtype)
277 | stable_diffusion3_transformer.eval()
278 | else:
279 | stable_diffusion3_transformer = None
280 |
281 | generate_sampling(
282 | cfg.inference.generate.prompt,
283 | text_encoders,
284 | tokenizer,
285 | model,
286 | decoder,
287 | vae,
288 | noise_scheduler,
289 | stable_diffusion3_transformer,
290 | cfg,
291 | device,
292 | dtype,
293 | )
294 |
295 | """
296 | Second Step: Refine the 3DGS
297 | """
298 | refiner = GSRefinerSDSPlusPlus(**cfg.inference.generate.refiner.args)
299 | refiner.to(device)
300 | gaussians = load_ply_for_gaussians(path / "gaussian.ply", device=device)
301 | cameras = torch.load(path / "cameras.pt", weights_only=False).unsqueeze(0)
302 | refined_gaussians = refiner.refine_gaussians(gaussians, cfg.inference.generate.prompt, dense_cameras=cameras)
303 |
304 | visualize_camera_pose(cameras, path)
305 | export_ply_for_gaussians(os.path.join(path, 'refined_gaussian'), [p[0] for p in refined_gaussians])
306 |
307 | def render_fn(cameras, h=512, w=512):
308 | return refiner.renderer(cameras, refined_gaussians, h=h, w=w, bg=None)[:2]
309 |
310 | export_mv(
311 | render_fn,
312 | path,
313 | cameras,
314 | )
315 |
316 | export_video(render_fn, path, "refined_video", cameras, device=device)
317 |
318 |
319 | if __name__ == "__main__":
320 | main()
321 |
--------------------------------------------------------------------------------
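`inference/generate.py` runs in two stages: the multi-view rectified flow model samples image/depth/ray latents that the GS decoder turns into an initial 3DGS, and the SDS++ refiner then polishes that result. If `gaussian.ply` and `cameras.pt` already exist for a prompt, the first stage is skipped and only the refinement is rerun; removing those files forces a full regeneration (the path below assumes the default `save_path`):

```bash
# Force the generation stage to run again for an existing prompt directory
rm "samples/Your prompt here/gaussian.ply" "samples/Your prompt here/cameras.pt"
```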
/model/gsdecoder/camera_embedding.py:
--------------------------------------------------------------------------------
1 | import cv2
2 | import torch
3 | from einops import rearrange, repeat
4 | from torch import nn
5 |
6 | # Reference: https://github.com/valeoai/LaRa/blob/main/semanticbev/models/components/LaRa_embeddings.py
7 | # We slightly modify the official code to fit in our setting.
8 |
9 |
10 | def meshgrid(spatial_shape, normalized=True, indexing="ij", device=None):
11 | """Create evenly spaced position coordinates for self.spatial_shape with values in [v_min, v_max].
12 | :param v_min: minimum coordinate value per dimension.
13 | :param v_max: maximum coordinate value per dimension.
14 | :return: position coordinates tensor of shape (*shape, len(shape)).
15 | """
16 | if normalized:
17 | axis_coords = [torch.linspace(-1.0, 1.0, steps=s, device=device) for s in spatial_shape]
18 | else:
19 | axis_coords = [torch.linspace(0, s - 1, steps=s, device=device) for s in spatial_shape]
20 |
21 | grid_coords = torch.meshgrid(*axis_coords, indexing=indexing)
22 |
23 | return torch.stack(grid_coords, dim=-1)
24 |
25 |
26 | def get_plucker_rays(extrinsics, intrinsics, h=32, w=32, stride=8, is_diffusion=False):
27 | b, v = extrinsics.shape[:2]
28 |
29 | # Adjust intrinsics scale due to downsizing by input_stride (we take feature maps as input not the raw images)
30 |
31 | updated_intrinsics = intrinsics.clone().unsqueeze(1) if len(intrinsics.shape) == 3 else intrinsics.clone()
32 | updated_intrinsics[..., 0, 0] *= 1 / stride
33 | updated_intrinsics[..., 0, 2] *= 1 / stride
34 | updated_intrinsics[..., 1, 1] *= 1 / stride
35 | updated_intrinsics[..., 1, 2] *= 1 / stride
36 |
37 | # create positional encodings
38 | pixel_coords = meshgrid((w, h), normalized=False, indexing="xy", device=extrinsics.device)
39 | ones = torch.ones((h, w, 1), device=extrinsics.device)
40 |
41 | pixel_coords = torch.cat([pixel_coords, ones], dim=-1) # [x, y, 1] vectors of pixel coordinates
42 | pixel_coords = rearrange(pixel_coords, "h w c -> c (h w)")
43 | pixel_coords = repeat(pixel_coords, "... -> b v ...", b=b, v=v)
44 |
45 | # Split extrinsics into rots and trans; rots == c2w rotation, trans == camera center in world coordinates
46 | rots, trans = extrinsics.split([3, 1], dim=-1)
47 |
48 | # pixel_coords.shape = [B, N, 3, K] | N # of cams, K # of pixels
49 | directions = rots @ updated_intrinsics.inverse() @ pixel_coords
50 | directions = directions / directions.norm(dim=2, keepdim=True)
51 | directions = rearrange(directions, "b v c (h w) -> b v c h w", h=h, w=w)
52 | cam_origins = repeat(trans.squeeze(-1), "b v c -> b v c h w", h=h, w=w)
53 | moments = torch.cross(cam_origins, directions, dim=2)
54 |
55 | output = [] if is_diffusion else [cam_origins]
56 | output.append(directions)
57 | output.append(moments)
58 |
59 | return torch.cat(output, dim=2)
60 |
61 |
62 | def optimize_plucker_ray(ray_latent):
63 | b, v, _, h, w = ray_latent.shape
64 | ray_latent = ray_latent.float()  # this function does not support bfloat16
65 | directions, moments = ray_latent.split(3, dim=2)
66 |
67 | # Reverse Process
68 | c = torch.linalg.norm(directions, dim=2, keepdim=True)
69 | origins = torch.cross(directions, moments / c, dim=2)
70 |
71 | new_trans = intersect_skew_lines_high_dim(
72 | rearrange(origins, "b n c h w -> b n (h w) c"), rearrange(directions, "b n c h w -> b n (h w) c")
73 | )
74 |
75 | # Retrieve target rays
76 | I_intrinsic_ = torch.tensor([[1, 0, h // 2], [0, 1, w // 2], [0, 0, 1]], dtype=ray_latent.dtype, device=c.device)
77 | I_intrinsic = repeat(I_intrinsic_, "i j -> b v i j", b=b, v=v)
78 | I_rot = repeat(torch.eye(3, dtype=ray_latent.dtype, device=c.device), "i j -> b v i j", b=b, v=v)
79 |
80 | # create positional encodings
81 | pixel_coords = meshgrid((w, h), normalized=False, indexing="xy", device=ray_latent.device)
82 | ones = torch.ones((h, w, 1), device=ray_latent.device)
83 |
84 | pixel_coords = torch.cat([pixel_coords, ones], dim=-1) # [x, y, 1] vectors of pixel coordinates
85 | pixel_coords = rearrange(pixel_coords, "h w c -> c (h w)")
86 | pixel_coords = repeat(pixel_coords, "... -> b v ...", b=b, v=v)
87 |
88 | I_directions = I_rot @ I_intrinsic.inverse() @ pixel_coords
89 | I_directions = I_directions / I_directions.norm(dim=2, keepdim=True)
90 | I_directions = rearrange(I_directions, "b v c (h w) -> b v c h w", h=h, w=w)
91 |
92 | new_rots, new_intrinsics = [], []
93 | for bb in range(b):
94 | Rs, Ks = [], []
95 | for vv in range(v):
96 | R, f, pp = compute_optimal_rotation_intrinsics(
97 | I_directions[bb, vv], directions[bb, vv], reproj_threshold=0.2
98 | )
99 | Rs.append(R)
100 | K = I_intrinsic_.clone()
101 | K[:2, :2] = torch.diag(1 / f)
102 | K[:, -1][:2] += pp
103 | Ks.append(K)
104 |
105 | new_rots.append(torch.stack(Rs))
106 | new_intrinsics.append(torch.stack(Ks))
107 |
108 | new_rots = torch.stack(new_rots)
109 | new_intrinsics = torch.stack(new_intrinsics)
110 |
111 | ff = nn.Parameter(new_intrinsics[..., [0, 1], [0, 1]].mean(1)[0], requires_grad=True)
112 | optimizer = torch.optim.Adam([ff], lr=0.01)
113 | X = torch.tensor(
114 | [[1, 0, h // 2], [0, 1, w // 2], [0, 0, 1]], dtype=ray_latent.dtype, device=c.device, requires_grad=True
115 | )
116 |
117 | # Optimization loop
118 | num_iterations = 10
119 | for i in range(num_iterations + 1):
120 | optimizer.zero_grad()
121 | scale_matrix = torch.diag(torch.cat([ff, torch.ones_like(ff[:1])], dim=0))
122 |
123 | triu_intrinsics = X @ scale_matrix
124 | re_X = repeat(triu_intrinsics, "i j -> b v i j", b=b, v=v)
125 |
126 | I_directions = I_rot @ re_X.inverse() @ pixel_coords
127 | I_directions = I_directions / I_directions.norm(dim=2, keepdim=True)
128 | I_directions = rearrange(I_directions, "b v c (h w) -> b v c h w", h=h, w=w)
129 |
130 | new_rots = []
131 | for bb in range(b):
132 | Rs = []
133 | for vv in range(v):
134 | R = compute_optimal_rotation_alignment(I_directions[bb, vv], directions[bb, vv])
135 | Rs.append(R)
136 |
137 | new_rots.append(torch.stack(Rs))
138 |
139 | new_rots = torch.stack(new_rots)
140 |
141 | if i == num_iterations:
142 | break
143 |
144 | I_directions = new_rots.clone() @ re_X.inverse() @ pixel_coords.clone()
145 | I_directions = I_directions / I_directions.norm(dim=2, keepdim=True)
146 | I_directions = rearrange(I_directions, "b v c (h w) -> b v c h w", h=h, w=w)
147 | loss = torch.norm(I_directions - directions, dim=2).mean()
148 |
149 | loss.backward()
150 | optimizer.step()
151 |
152 | new_rots = new_rots.detach().clone()
153 | new_intrinsics = re_X.detach().clone()
154 |
155 | stride = 8
156 | new_intrinsics[..., 0, 0] *= stride
157 | new_intrinsics[..., 0, 2] *= stride
158 | new_intrinsics[..., 1, 1] *= stride
159 | new_intrinsics[..., 1, 2] *= stride
160 |
161 | # normalize camera pose
162 | new_poses = torch.cat([new_rots, new_trans.unsqueeze(-1)], dim=-1)
163 | bottom = torch.tensor([0, 0, 0, 1], dtype=new_poses.dtype, device=new_poses.device)
164 | homo_poses = torch.cat([new_poses, repeat(bottom, "i -> b v () i", b=b, v=v)], dim=-2)
165 | inv_poses = torch.cat([homo_poses.inverse()[:, :, :3], repeat(bottom, "i -> b v () i", b=b, v=v)], dim=-2)
166 | return homo_poses, new_intrinsics, inv_poses
167 |
168 |
169 | # Refer to RayDiffusion
170 | def intersect_skew_lines_high_dim(p, r):
171 | # p : num views x 3 x num points
172 | # Implements the "in more than two dimensions" case from https://en.wikipedia.org/wiki/Skew_lines
173 | dim = p.shape[-1]
174 |
175 | eye = torch.eye(dim, device=p.device, dtype=p.dtype)[None, None, None]
176 | I_min_cov = eye - (r[..., None] * r[..., None, :])
177 | sum_proj = I_min_cov.matmul(p[..., None]).sum(dim=-3)
178 |
179 | p_intersect = torch.linalg.lstsq(I_min_cov.sum(dim=-3), sum_proj).solution[..., 0]
180 |
181 | return p_intersect
182 |
183 |
184 | def compute_optimal_rotation_intrinsics(rays_origin, rays_target, z_threshold=1e-8, reproj_threshold=0.2):
185 | """
186 | Note: for some reason, f seems to be 1/f.
187 |
188 | Args:
189 | rays_origin (torch.Tensor): (3, H, W)
190 | rays_target (torch.Tensor): (3, H, W)
191 | z_threshold (float): Threshold for z value to be considered valid.
192 |
193 | Returns:
194 | R (torch.tensor): (3, 3)
195 | focal_length (torch.tensor): (2,)
196 | principal_point (torch.tensor): (2,)
197 | """
198 | device = rays_origin.device
199 | _, h, w = rays_origin.shape
200 |
201 | rays_origin = rearrange(rays_origin, "c h w -> (h w) c")
202 | rays_target = rearrange(rays_target, "c h w -> (h w) c")
203 |
204 | z_mask = torch.logical_and(torch.abs(rays_target) > z_threshold, torch.abs(rays_origin) > z_threshold)[:, 2]
205 | rays_target = rays_target[z_mask]
206 | rays_origin = rays_origin[z_mask]
207 | rays_origin = rays_origin[:, :2] / rays_origin[:, -1:]
208 | rays_target = rays_target[:, :2] / rays_target[:, -1:]
209 |
210 | A, _ = cv2.findHomography(
211 | rays_origin.cpu().numpy(),
212 | rays_target.cpu().numpy(),
213 | cv2.RANSAC,
214 | reproj_threshold,
215 | )
216 | A = torch.from_numpy(A).float().to(device)
217 |
218 | if torch.linalg.det(A) < 0:
219 | A = -A
220 |
221 | R, L = ql_decomposition(A)
222 | L = L / L[2][2]
223 |
224 | f = torch.stack((L[0][0], L[1][1]))
225 | pp = torch.stack((L[2][0], L[2][1]))
226 | return R, f, pp
227 |
228 |
229 | def ql_decomposition(A):
230 | P = torch.tensor([[0, 0, 1], [0, 1, 0], [1, 0, 0]], device=A.device).float()
231 | A_tilde = torch.matmul(A, P)
232 | Q_tilde, R_tilde = torch.linalg.qr(A_tilde)
233 | Q = torch.matmul(Q_tilde, P)
234 | L = torch.matmul(torch.matmul(P, R_tilde), P)
235 | d = torch.diag(L)
236 | Q[:, 0] *= torch.sign(d[0])
237 | Q[:, 1] *= torch.sign(d[1])
238 | Q[:, 2] *= torch.sign(d[2])
239 | L[0] *= torch.sign(d[0])
240 | L[1] *= torch.sign(d[1])
241 | L[2] *= torch.sign(d[2])
242 | return Q, L
243 |
244 |
245 | def compute_optimal_rotation_alignment(A, B):
246 | """
247 | Compute optimal R that minimizes: || A - B @ R ||_F
248 |
249 | Args:
250 | A (torch.Tensor): (3, H, W)
251 | B (torch.Tensor): (3, H, W)
252 |
253 | Returns:
254 | R (torch.tensor): (3, 3)
255 | """
256 | A = rearrange(A, "c h w -> (h w) c")
257 | B = rearrange(B, "c h w -> (h w) c")
258 |
259 | # normally with R @ B, this would be A @ B.T
260 | H = B.T @ A
261 | U, _, Vh = torch.linalg.svd(H, full_matrices=True)
262 | s = torch.linalg.det(U @ Vh)
263 | S_prime = torch.diag(torch.tensor([1, 1, torch.sign(s)], device=A.device))
264 | return U @ S_prime @ Vh
265 |
--------------------------------------------------------------------------------
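The ray latents above use the Plücker parameterization: a camera ray with origin `o` and unit direction `d` is encoded as the pair `(d, m)` with moment `m = o × d`, which is what `get_plucker_rays` stacks per pixel and `optimize_plucker_ray` inverts back into poses and intrinsics. A minimal sketch for a single ray:

```python
import torch

o = torch.tensor([0.0, 0.0, 2.0])   # camera center in world coordinates
d = torch.tensor([0.0, 0.0, -1.0])  # unit ray direction
m = torch.linalg.cross(o, d)        # moment vector, m = o x d

# With is_diffusion=True, get_plucker_rays returns only (direction, moment),
# i.e. the 6-channel ray latent consumed by the multi-view rectified flow model.
ray = torch.cat([d, m])
```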
/model/gsdecoder/cuda_splatting.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from diff_gaussian_rasterization import (
3 | GaussianRasterizationSettings,
4 | GaussianRasterizer,
5 | )
6 | from einops import einsum, rearrange, repeat
7 |
8 |
9 | def get_fov(intrinsics):
10 | intrinsics_inv = intrinsics.inverse()
11 |
12 | def process_vector(vector):
13 | vector = torch.tensor(vector, dtype=torch.float32, device=intrinsics.device)
14 | vector = einsum(intrinsics_inv, vector, "b i j, j -> b i")
15 | return vector / vector.norm(dim=-1, keepdim=True)
16 |
17 | left = process_vector([0, 0.5, 1])
18 | right = process_vector([1, 0.5, 1])
19 | top = process_vector([0.5, 0, 1])
20 | bottom = process_vector([0.5, 1, 1])
21 | fov_x = (left * right).sum(dim=-1).acos()
22 | fov_y = (top * bottom).sum(dim=-1).acos()
23 | return torch.stack((fov_x, fov_y), dim=-1)
24 |
25 |
26 | def get_projection_matrix(near, far, fov_x, fov_y):
27 | """Maps points in the viewing frustum to (-1, 1) on the X/Y axes and (0, 1) on the Z
28 | axis. Differs from the OpenGL version in that Z doesn't have range (-1, 1) after
29 | transformation and that Z is flipped.
30 | """
31 | tan_fov_x = (0.5 * fov_x).tan()
32 | tan_fov_y = (0.5 * fov_y).tan()
33 |
34 | top = tan_fov_y * near
35 | bottom = -top
36 | right = tan_fov_x * near
37 | left = -right
38 |
39 | (b,) = fov_x.shape
40 | result = torch.zeros((b, 4, 4), dtype=torch.float32, device=fov_x.device)
41 | result[:, 0, 0] = 2 * near / (right - left)
42 | result[:, 1, 1] = 2 * near / (top - bottom)
43 | result[:, 0, 2] = (right + left) / (right - left)
44 | result[:, 1, 2] = (top + bottom) / (top - bottom)
45 | result[:, 3, 2] = 1
46 | result[:, 2, 2] = far / (far - near)
47 | result[:, 2, 3] = -(far * near) / (far - near)
48 | return result
49 |
50 |
51 | def render_cuda(
52 | extrinsics,
53 | intrinsics,
54 | near,
55 | far,
56 | image_shape,
57 | num_views,
58 | background_color,
59 | gaussian_means,
60 | gaussian_covariances,
61 | gaussians_rgb,
62 | gaussian_opacities,
63 | ):
64 | b, _, _ = extrinsics.shape
65 | h, w = image_shape
66 |
67 | update_intrinsics = intrinsics.clone()
68 | update_intrinsics[:, 0] *= 1 / w
69 | update_intrinsics[:, 1] *= 1 / h
70 |
71 | fov_x, fov_y = get_fov(update_intrinsics).unbind(dim=-1)
72 | tan_fov_x = (0.5 * fov_x).tan()
73 | tan_fov_y = (0.5 * fov_y).tan()
74 |
75 | projection_matrix = get_projection_matrix(near, far, fov_x, fov_y)
76 | projection_matrix = rearrange(projection_matrix, "b i j -> b j i")
77 | view_matrix = rearrange(extrinsics.inverse(), "b i j -> b j i")
78 | full_projection = view_matrix @ projection_matrix
79 |
80 | gaussian_means = repeat(gaussian_means, "b n c -> (b v) n c", v=num_views)
81 | gaussian_covariances = repeat(gaussian_covariances, "b n i j -> (b v) n i j", v=num_views)
82 | gaussians_rgb = repeat(gaussians_rgb, "b n c -> (b v) n c", v=num_views)
83 | gaussian_opacities = repeat(gaussian_opacities, "b n c -> (b v) n c", v=num_views)
84 |
85 | all_images = []
86 | all_depths = []
87 | for i in range(b):
88 | # Set up a tensor for the gradients of the screen-space means.
89 | mean_gradients = torch.zeros_like(gaussian_means[i], requires_grad=True)
90 | try:
91 | mean_gradients.retain_grad()
92 | except Exception:
93 | pass
94 |
95 | settings = GaussianRasterizationSettings(
96 | image_height=h,
97 | image_width=w,
98 | tanfovx=tan_fov_x[i].item(),
99 | tanfovy=tan_fov_y[i].item(),
100 | bg=background_color[i],
101 | scale_modifier=1.0,
102 | viewmatrix=view_matrix[i],
103 | projmatrix=full_projection[i],
104 | sh_degree=0,
105 | campos=extrinsics[i, :3, 3],
106 | prefiltered=False, # This matches the original usage.
107 | debug=False,
108 | )
109 | rasterizer = GaussianRasterizer(settings)
110 |
111 | row, col = torch.triu_indices(3, 3)
112 |
113 | image, _, depth, _ = rasterizer(
114 | means3D=gaussian_means[i],
115 | means2D=mean_gradients,
116 | shs=None,
117 | colors_precomp=gaussians_rgb[i],
118 | opacities=gaussian_opacities[i],
119 | cov3D_precomp=gaussian_covariances[i, :, row, col],
120 | )
121 | all_images.append(image)
122 | all_depths.append(depth)
123 |
124 | return torch.stack(all_images), torch.stack(all_depths)
125 |
--------------------------------------------------------------------------------
/model/gsdecoder/decoder_splatting.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from einops import einsum, rearrange, repeat
3 | from torch import nn
4 |
5 | from .camera_embedding import meshgrid
6 | from .cuda_splatting import render_cuda
7 |
8 | # Reference: https://github.com/dcharatan/pixelsplat/blob/main/src/model/encoder/common/gaussian_adapter.py
9 | # We slightly modify the official code to fit in our setting.
10 |
11 |
12 | # https://github.com/facebookresearch/pytorch3d/blob/main/pytorch3d/transforms/rotation_conversions.py
13 | def quaternion_to_matrix(quaternions, eps: float = 1e-8):
14 | # Order changed to match scipy format!
15 | i, j, k, r = torch.unbind(quaternions, dim=-1)
16 | two_s = 2 / ((quaternions * quaternions).sum(dim=-1) + eps)
17 |
18 | o = torch.stack(
19 | (
20 | 1 - two_s * (j * j + k * k),
21 | two_s * (i * j - k * r),
22 | two_s * (i * k + j * r),
23 | two_s * (i * j + k * r),
24 | 1 - two_s * (i * i + k * k),
25 | two_s * (j * k - i * r),
26 | two_s * (i * k - j * r),
27 | two_s * (j * k + i * r),
28 | 1 - two_s * (i * i + j * j),
29 | ),
30 | -1,
31 | )
32 | return rearrange(o, "... (i j) -> ... i j", i=3, j=3)
33 |
34 |
35 | def build_covariance(scale, rotation_xyzw):
36 | scale = scale.diag_embed()
37 | rotation = quaternion_to_matrix(rotation_xyzw)
38 | return rotation @ scale @ rearrange(scale, "... i j -> ... j i") @ rearrange(rotation, "... i j -> ... j i")
39 |
40 |
41 | class DecoderSplatting(nn.Module):
42 | def __init__(self):
43 | super().__init__()
44 | self.register_buffer("background_color", torch.tensor([0.0, 0.0, 0.0], dtype=torch.float32), persistent=False)
45 | self.act_scale = nn.Softplus()
46 | self.act_rgb = nn.Softplus()
47 |
48 | def get_scale_multiplier(self, intrinsics):
49 | pixel_size = torch.ones((2,), dtype=torch.float32, device=intrinsics.device)
50 | xy_multipliers = einsum(
51 | intrinsics[..., :2, :2].inverse(),
52 | pixel_size,
53 | "... i j, j -> ... i",
54 | )
55 | return xy_multipliers.sum(dim=-1)
56 |
57 | def gaussian_adapter(self, raw_gaussians, extrinsics, intrinsics, near=1, far=100):
58 | b, v, c, h, w = raw_gaussians.shape
59 | if len(intrinsics.shape) == 3:
60 | intrinsics = repeat(intrinsics, "b i j -> b v i j", v=v)
61 | extrinsics = repeat(extrinsics, "b v i j -> b v () () i j")
62 | intrinsics = repeat(intrinsics, "b v i j -> b v () () i j")
63 | raw_gaussians = rearrange(raw_gaussians, "b v c h w -> b v h w c")
64 |
65 | rgb, disp, opacity, scales, rotations, xy_offset = raw_gaussians.split((3, 1, 1, 3, 4, 2), dim=-1)
66 |
67 | # calculate xy_offset and origin/direction for each view.
68 | pixel_coords = meshgrid((w, h), normalized=False, indexing="xy", device=raw_gaussians.device)
69 | pixel_coords = repeat(pixel_coords, "h w c -> b v h w c", b=b, v=v)
70 |
71 | coordinates = pixel_coords + (xy_offset.sigmoid() - 0.5)
72 | coordinates = torch.cat([coordinates, torch.ones_like(coordinates[..., :1])], dim=-1)
73 |
74 | directions = einsum(intrinsics.inverse(), coordinates, "... i j, ... j -> ... i")
75 | directions = directions / directions.norm(dim=-1, keepdim=True)
76 | # directions = directions / directions[..., -1:]
77 | directions = torch.cat([directions, torch.zeros_like(directions[..., :1])], dim=-1)
78 | directions = einsum(extrinsics, directions, "... i j, ... j -> ... i")
79 | origins = extrinsics[..., -1].broadcast_to(directions.shape)
80 |
81 | # calculate depth from disparity
82 | depths = 1.0 / (disp.sigmoid() * (1.0 / near - 1.0 / far) + 1.0 / far)
83 |
84 | # calculate all parameters of gaussian splats
85 | means = origins + directions * depths
86 |
87 | multiplier = self.get_scale_multiplier(intrinsics)
88 | scales = self.act_scale(scales) * multiplier[..., None]
89 |
90 | if len(torch.where(scales > 0.05)[0]) > 0:
91 | big_gaussian_reg_loss = torch.mean(scales[torch.where(scales > 0.05)])
92 | else:
93 | big_gaussian_reg_loss = 0
94 |
95 | if len(torch.where(scales < 1e-6)[0]) > 0:
96 | small_gaussian_reg_loss = torch.mean(-torch.log(scales[torch.where(scales < 1e-6)]) * 0.1)
97 | else:
98 | small_gaussian_reg_loss = 0
99 |
100 | rotations = rotations / (rotations.norm(dim=-1, keepdim=True) + 1e-8)
101 | covariances = build_covariance(scales, rotations)
102 | c2w_rotations = extrinsics[..., :3, :3]
103 | covariances = c2w_rotations @ covariances @ c2w_rotations.transpose(-1, -2)
104 |
105 | opacity = opacity.sigmoid()
106 | rgb = self.act_rgb(rgb)
107 | if self.training:
108 | return means, covariances, opacity, rgb, big_gaussian_reg_loss, small_gaussian_reg_loss
109 | else:
110 | return means, covariances, opacity, rgb, rotations, scales
111 |
112 | def forward(self, raw_gaussians, target_extrinsics, source_extrinsics, intrinsics, near=1, far=100):
113 | b, v = target_extrinsics.shape[:2]
114 | h, w = raw_gaussians.shape[-2:]
115 |
116 | means, covariance, opacity, rgb, b_loss, s_loss = self.gaussian_adapter(
117 | raw_gaussians, source_extrinsics, intrinsics, near, far
118 | )
119 | lower_ext = torch.tensor([0.0, 0.0, 0.0, 1.0], device=target_extrinsics.device, dtype=torch.float32)
120 | lower_ext = repeat(lower_ext, "i -> b v () i", b=b, v=v)
121 | homo_extrinsics = torch.cat([target_extrinsics, lower_ext], dim=2)
122 | output_image, output_depth = render_cuda(
123 | rearrange(homo_extrinsics, "b v i j -> (b v) i j"),
124 | repeat(intrinsics, "b i j -> (b v) i j", v=v),
125 | near,
126 | far,
127 | (h, w),
128 | v,
129 | repeat(self.background_color, "i -> bv i", bv=b * v),
130 | rearrange(means, "b v h w xyz -> b (v h w) xyz"),
131 | rearrange(covariance, "b v h w i j -> b (v h w) i j"),
132 | rearrange(rgb, "b v h w c -> b (v h w) c"),
133 | rearrange(opacity, "b v h w o -> b (v h w) o"),
134 | )
135 | return output_image, output_depth, b_loss, s_loss
136 |
137 | def eval_forward(self, means, covariance, opacity, rgb, target_extrinsics, target_intrinsics, near, far):
138 | # TODO: merge eval and training forward logic
139 | _, _, h, w, _ = means.shape
140 | b, target_view_num, _, _ = target_extrinsics.shape
141 |
142 | if len(target_intrinsics.shape) == 3:
143 | target_intrinsics = repeat(target_intrinsics, "b i j -> (b v) i j", v=target_view_num)
144 | else:
145 | target_intrinsics = repeat(target_intrinsics, "b v i j -> (b v) i j")
146 |
147 | lower_ext = torch.tensor([0.0, 0.0, 0.0, 1.0], device=target_extrinsics.device, dtype=torch.float32)
148 | lower_ext = repeat(lower_ext, "i -> b v () i", b=b, v=target_view_num)
149 | homo_extrinsics = torch.cat([target_extrinsics, lower_ext], dim=2)
150 | output_image, output_depth = render_cuda(
151 | rearrange(homo_extrinsics, "b v i j -> (b v) i j"),
152 | target_intrinsics,
153 | near,
154 | far,
155 | (h, w),
156 | target_view_num,
157 | repeat(self.background_color, "i -> bv i", bv=b * target_view_num),
158 | rearrange(means, "b v h w xyz -> b (v h w) xyz"),
159 | rearrange(covariance, "b v h w i j -> b (v h w) i j"),
160 | rearrange(rgb, "b v h w rgb -> b (v h w) rgb"),
161 | rearrange(opacity, "b v h w o -> b (v h w) o"),
162 | )
163 | return output_image
164 |
--------------------------------------------------------------------------------
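For reference, the 14 raw channels per pixel that `gaussian_adapter` consumes (and that the GS decoder below predicts) split into RGB, disparity, opacity, scale, a rotation quaternion, and a sub-pixel xy offset. A small sketch of the layout, mirroring the `split((3, 1, 1, 3, 4, 2), dim=-1)` above; the spatial size is illustrative:

```python
import torch
from einops import rearrange

# (batch, views, channels, height, width) raw decoder output
raw_gaussians = torch.randn(1, 8, 14, 256, 256)
raw_gaussians = rearrange(raw_gaussians, "b v c h w -> b v h w c")

rgb, disp, opacity, scales, rotations, xy_offset = raw_gaussians.split((3, 1, 1, 3, 4, 2), dim=-1)
```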
/model/gsdecoder/gs_decoder_architecture.py:
--------------------------------------------------------------------------------
1 | from typing import Optional, Tuple
2 |
3 | import torch
4 | import torch.nn as nn
5 | import torch.nn.functional as F
6 | from diffusers.models.attention_processor import Attention
7 | from diffusers.models.unets.unet_2d_blocks import UNetMidBlock2D, get_up_block
8 | from einops import rearrange, repeat
9 |
10 | from .camera_embedding import get_plucker_rays
11 | from .decoder_splatting import DecoderSplatting
12 |
13 |
14 | class AttnProcessor2_0_modified:
15 | r"""
16 | We copied AttnProcessor2_0 from the diffusers library and modified it to attend jointly across all views of a sample.
17 | """
18 |
19 | def __init__(self):
20 | if not hasattr(F, "scaled_dot_product_attention"):
21 | raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
22 |
23 | def __call__(
24 | self,
25 | attn: Attention,
26 | hidden_states: torch.Tensor,
27 | encoder_hidden_states: Optional[torch.Tensor] = None,
28 | attention_mask: Optional[torch.Tensor] = None,
29 | temb: Optional[torch.Tensor] = None,
30 | num_views: Optional[int] = None,
31 | *args,
32 | **kwargs,
33 | ) -> torch.Tensor:
34 | residual = hidden_states
35 | if attn.spatial_norm is not None:
36 | hidden_states = attn.spatial_norm(hidden_states, temb)
37 |
38 | input_ndim = hidden_states.ndim
39 |
40 | if input_ndim == 4:
41 | batch_size, channel, height, width = hidden_states.shape
42 | hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
43 |
44 | batch_size, sequence_length, _ = (
45 | hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
46 | )
47 |
48 | if attention_mask is not None:
49 | attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
50 | # scaled_dot_product_attention expects attention_mask shape to be
51 | # (batch, heads, source_length, target_length)
52 | attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
53 |
54 | if attn.group_norm is not None:
55 | hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
56 |
57 | hidden_states = rearrange(hidden_states, "(b v) t c -> b (v t) c", v=num_views)
58 | batch_size = hidden_states.shape[0]
59 | query = attn.to_q(hidden_states)
60 |
61 | if encoder_hidden_states is None:
62 | encoder_hidden_states = hidden_states
63 | elif attn.norm_cross:
64 | encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
65 |
66 | key = attn.to_k(encoder_hidden_states)
67 | value = attn.to_v(encoder_hidden_states)
68 |
69 | inner_dim = key.shape[-1]
70 | head_dim = inner_dim // attn.heads
71 |
72 | query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
73 |
74 | key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
75 | value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
76 |
77 | # the output of sdp = (batch, num_heads, seq_len, head_dim)
78 | # TODO: add support for attn.scale when we move to Torch 2.1
79 | hidden_states = F.scaled_dot_product_attention(
80 | query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
81 | )
82 |
83 | hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
84 | hidden_states = hidden_states.to(query.dtype)
85 |
86 | # linear proj
87 | hidden_states = attn.to_out[0](hidden_states)
88 | # dropout
89 | hidden_states = attn.to_out[1](hidden_states)
90 |
91 | hidden_states = rearrange(hidden_states, "b (v t) c -> (b v) t c", v=num_views)
92 | batch_size = hidden_states.shape[0]
93 |
94 | if input_ndim == 4:
95 | hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
96 |
97 | if attn.residual_connection:
98 | hidden_states = hidden_states + residual
99 |
100 | hidden_states = hidden_states / attn.rescale_output_factor
101 |
102 | return hidden_states
103 |
104 |
105 | class GSDecoder(nn.Module):
106 | def __init__(
107 | self,
108 | in_channels: int = 41,
109 | out_channels: int = 14,
110 | up_block_types: Tuple[str, ...] = (
111 | "UpDecoderBlock2D",
112 | "UpDecoderBlock2D",
113 | "UpDecoderBlock2D",
114 | "UpDecoderBlock2D",
115 | ),
116 | block_out_channels: Tuple[int, ...] = (128, 256, 512, 512),
117 | layers_per_block: int = 2,
118 | norm_num_groups: int = 32,
119 | cfg=None,
120 | ):
121 | super(GSDecoder, self).__init__()
122 | self.cfg = cfg
123 | self.conv_in = nn.Conv2d(
124 | in_channels,
125 | block_out_channels[-1],
126 | kernel_size=3,
127 | stride=1,
128 | padding=1,
129 | )
130 |
131 | self.gn = nn.GroupNorm(norm_num_groups, block_out_channels[-1], affine=True)
132 | self.attn_block_pre = nn.ModuleList(
133 | [
134 | Attention(
135 | block_out_channels[-1],
136 | heads=1,
137 | dim_head=block_out_channels[-1],
138 | rescale_output_factor=1,
139 | eps=1e-6,
140 | norm_num_groups=32,
141 | spatial_norm_dim=None,
142 | residual_connection=True,
143 | processor=AttnProcessor2_0_modified(),
144 | bias=True,
145 | upcast_softmax=True,
146 | _from_deprecated_attn_block=True,
147 | dropout=0.0,
148 | )
149 | for _ in range(3)
150 | ]
151 | )
152 |
153 | self.mid_block = UNetMidBlock2D(
154 | in_channels=block_out_channels[-1],
155 | resnet_eps=1e-6,
156 | resnet_act_fn="silu",
157 | output_scale_factor=1,
158 | resnet_time_scale_shift="default",
159 | attention_head_dim=block_out_channels[-1],
160 | resnet_groups=norm_num_groups,
161 | temb_channels=None,
162 | add_attention=False,
163 | )
164 |
165 | self.attn_block_post = nn.ModuleList(
166 | [
167 | Attention(
168 | block_out_channels[-1],
169 | heads=1,
170 | dim_head=block_out_channels[-1],
171 | rescale_output_factor=1,
172 | eps=1e-6,
173 | norm_num_groups=32,
174 | spatial_norm_dim=None,
175 | residual_connection=True,
176 | processor=AttnProcessor2_0_modified(),
177 | bias=True,
178 | upcast_softmax=True,
179 | _from_deprecated_attn_block=True,
180 | dropout=0.0,
181 | )
182 | for _ in range(3)
183 | ]
184 | )
185 |
186 | # up
187 | self.up_blocks = nn.ModuleList([])
188 | reversed_block_out_channels = list(reversed(block_out_channels))
189 | output_channel = reversed_block_out_channels[0]
190 | for i, up_block_type in enumerate(up_block_types):
191 | prev_output_channel = output_channel
192 | output_channel = reversed_block_out_channels[i]
193 |
194 | is_final_block = i == len(block_out_channels) - 1
195 |
196 | up_block = get_up_block(
197 | up_block_type,
198 | num_layers=layers_per_block + 1,
199 | in_channels=prev_output_channel,
200 | out_channels=output_channel,
201 | prev_output_channel=None,
202 | add_upsample=not is_final_block,
203 | resnet_eps=1e-6,
204 | resnet_act_fn="silu",
205 | resnet_groups=norm_num_groups,
206 | attention_head_dim=output_channel,
207 | temb_channels=None,
208 | resnet_time_scale_shift="group",
209 | )
210 | self.up_blocks.append(up_block)
211 |
212 | self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=1e-6)
213 | self.conv_act = nn.SiLU()
214 | self.conv_out = nn.Conv2d(block_out_channels[0], out_channels, 3, padding=1)
215 |
216 | self.splatter = DecoderSplatting()
217 |
218 | def forward(
219 | self, latent_image, latent_depth, target_extrinsics, source_extrinsics, intrinsics, near, far, plucker_ray=None
220 | ):
221 | # calculate plucker rays
222 | lh, lw = latent_image.shape[-2:]
223 | if plucker_ray is None:
224 | ray_embeddings = get_plucker_rays(source_extrinsics, intrinsics, h=lh, w=lw, stride=8)
225 | else:
226 | ray_embeddings = plucker_ray
227 |
228 | if self.training:
229 | ray_embeddings += torch.randn_like(ray_embeddings) * 0.0001  # simulate a small amount of ray noise
230 |
231 | if latent_depth is not None:
232 | x = torch.cat([latent_image, latent_depth, ray_embeddings], dim=2) # channels: 16 + 16 + 9
233 | else:
234 | x = torch.cat([latent_image, ray_embeddings], dim=2)
235 |
236 | b, v, _, h, w = x.size()
237 | x = rearrange(x, "b v c h w -> (b v) c h w") # TODO: patchify?
238 | x = self.conv_in(x)
239 | x = self.gn(x)
240 |
241 | for attn_block in self.attn_block_pre:
242 | x = attn_block(x, num_views=v)
243 |
244 | x = self.mid_block(x)
245 |
246 | for attn_block in self.attn_block_post:
247 | x = attn_block(x, num_views=v)
248 |
249 | # up
250 | for up_block in self.up_blocks:
251 | x = up_block(x)
252 |
253 | x = self.conv_norm_out(x)
254 | x = self.conv_act(x)
255 | x = self.conv_out(x)
256 |
257 | x = rearrange(x, "(b v) c h w -> b v c h w", b=b, v=v) # (b v c_out h w)
258 |
259 | with torch.amp.autocast("cuda", enabled=False):
260 | if self.training:
261 | x_img, x_depth, b_loss, s_loss = self.splatter(
262 | x.float(), target_extrinsics, source_extrinsics, intrinsics, near, far
263 | )
264 | return x_img, repeat(x_depth, "b c h w -> b (c r) h w", r=3), b_loss, s_loss
265 | else:
266 | x = self.splatter.gaussian_adapter(
267 | x.float(), source_extrinsics, intrinsics, near, far
268 | ) # return encoded gaussian primitives
269 |
270 | return x
271 |
272 | def eval_rendering(self, means, covariance, opacity, rgb, target_extrinsics, target_intrinsics, near, far):
273 | # TODO: merge this eval rendering path with the training forward pass
274 | return self.splatter.eval_forward(
275 | means, covariance, opacity, rgb, target_extrinsics, target_intrinsics, near, far
276 | )
277 |
--------------------------------------------------------------------------------
/model/gsdecoder/load_gsdecoder.py:
--------------------------------------------------------------------------------
1 | from model.gsdecoder.gs_decoder_architecture import GSDecoder
2 |
3 |
4 | def create_gsdecoder(cfg):
5 | decoder = GSDecoder(cfg=cfg)
6 | return decoder
7 |
--------------------------------------------------------------------------------
/model/multiview_rf/load_mv_sd3.py:
--------------------------------------------------------------------------------
1 | from argparse import Namespace
2 |
3 | from transformers import CLIPTokenizer, PretrainedConfig, T5TokenizerFast
4 |
5 | from model.multiview_rf.mv_sd3_architecture import MultiViewSD3Transformer
6 |
7 |
8 | def import_model_class_from_model_name_or_path(
9 | pretrained_model_name_or_path: str, revision: str, subfolder: str = "text_encoder"
10 | ):
11 | text_encoder_config = PretrainedConfig.from_pretrained(
12 | pretrained_model_name_or_path,
13 | subfolder=subfolder,
14 | revision=revision,
15 | )
16 | model_class = text_encoder_config.architectures[0]
17 | if model_class == "CLIPTextModelWithProjection":
18 | from transformers import CLIPTextModelWithProjection
19 |
20 | return CLIPTextModelWithProjection
21 | elif model_class == "T5EncoderModel":
22 | from transformers import T5EncoderModel
23 |
24 | return T5EncoderModel
25 | else:
26 | raise ValueError(f"{model_class} is not supported.")
27 |
28 |
29 | # Load an SD3-based multi-view rectified flow (RF) model
30 | def create_sd_multiview_rf_model(num_input_output_channel=38):
31 | args = Namespace()
32 | args.pretrained_model_name_or_path = "stabilityai/stable-diffusion-3-medium-diffusers"
33 | args.revision = None
34 | args.variant = None
35 |
36 | # rf model
37 | rf_transformer = MultiViewSD3Transformer.from_pretrained(
38 | "stabilityai/stable-diffusion-3-medium-diffusers",
39 | subfolder="transformer",
40 | ignore_mismatched_sizes=True,
41 | strict=False,
42 | low_cpu_mem_usage=False,
43 | )
44 | rf_transformer.adjust_output_input_channel_size(num_input_output_channel)
45 |
46 | # tokenizers
47 | tokenizer_one = CLIPTokenizer.from_pretrained(
48 | args.pretrained_model_name_or_path,
49 | subfolder="tokenizer",
50 | revision=args.revision,
51 | low_cpu_mem_usage=True,
52 | )
53 | tokenizer_two = CLIPTokenizer.from_pretrained(
54 | args.pretrained_model_name_or_path,
55 | subfolder="tokenizer_2",
56 | revision=args.revision,
57 | low_cpu_mem_usage=True,
58 | )
59 | tokenizer_three = T5TokenizerFast.from_pretrained(
60 | args.pretrained_model_name_or_path,
61 | subfolder="tokenizer_3",
62 | revision=args.revision,
63 | low_cpu_mem_usage=True,
64 | )
65 |
66 | # Text encoders
67 | text_encoder_cls_one = import_model_class_from_model_name_or_path(args.pretrained_model_name_or_path, args.revision)
68 | text_encoder_cls_two = import_model_class_from_model_name_or_path(
69 | args.pretrained_model_name_or_path, args.revision, subfolder="text_encoder_2"
70 | )
71 | text_encoder_cls_three = import_model_class_from_model_name_or_path(
72 | args.pretrained_model_name_or_path, args.revision, subfolder="text_encoder_3"
73 | )
74 |
75 | def load_text_encoders(class_one, class_two, class_three):
76 | text_encoder_one = class_one.from_pretrained(
77 | args.pretrained_model_name_or_path,
78 | subfolder="text_encoder",
79 | revision=args.revision,
80 | variant=args.variant,
81 | low_cpu_mem_usage=True,
82 | torch_dtype="auto",
83 | )
84 | text_encoder_two = class_two.from_pretrained(
85 | args.pretrained_model_name_or_path,
86 | subfolder="text_encoder_2",
87 | revision=args.revision,
88 | variant=args.variant,
89 | low_cpu_mem_usage=True,
90 | torch_dtype="auto",
91 | )
92 | text_encoder_three = class_three.from_pretrained(
93 | args.pretrained_model_name_or_path,
94 | subfolder="text_encoder_3",
95 | revision=args.revision,
96 | variant=args.variant,
97 | low_cpu_mem_usage=True,
98 | torch_dtype="auto",
99 | )
100 | return text_encoder_one, text_encoder_two, text_encoder_three
101 |
102 | text_encoder_one, text_encoder_two, text_encoder_three = load_text_encoders(
103 | text_encoder_cls_one, text_encoder_cls_two, text_encoder_cls_three
104 | )
105 | tokenizers = [tokenizer_one, tokenizer_two, tokenizer_three]
106 | text_encoders = [text_encoder_one, text_encoder_two, text_encoder_three]
107 |
108 | return rf_transformer, tokenizers, text_encoders
109 |
--------------------------------------------------------------------------------
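A minimal usage sketch of the loader, mirroring `inference/generate.py` (it assumes the repository root is on `PYTHONPATH` and that you are logged in to Hugging Face so the SD3 weights can be fetched):

```python
from model.multiview_rf.load_mv_sd3 import create_sd_multiview_rf_model

# The transformer's input/output channels are widened to 38 = 16 image + 16 depth + 6 ray latents.
model, tokenizers, text_encoders = create_sd_multiview_rf_model(num_input_output_channel=38)
```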
/model/multiview_rf/mv_sd3_architecture.py:
--------------------------------------------------------------------------------
1 | from typing import Any, Dict, List, Optional, Union
2 |
3 | import torch
4 | import torch.nn as nn
5 | from diffusers.configuration_utils import ConfigMixin, register_to_config
6 | from diffusers.loaders import FromOriginalModelMixin, PeftAdapterMixin
7 | from diffusers.models.attention import JointTransformerBlock, _chunked_feed_forward
8 | from diffusers.models.attention_processor import Attention, AttentionProcessor
9 | from diffusers.models.embeddings import (
10 | CombinedTimestepTextProjEmbeddings,
11 | PatchEmbed,
12 | get_3d_sincos_pos_embed,
13 | )
14 | from diffusers.models.modeling_outputs import Transformer2DModelOutput
15 | from diffusers.models.modeling_utils import ModelMixin
16 | from diffusers.models.normalization import AdaLayerNormContinuous
17 | from diffusers.utils import (
18 | USE_PEFT_BACKEND,
19 | is_torch_version,
20 | logging,
21 | scale_lora_layers,
22 | unscale_lora_layers,
23 | )
24 | from einops import rearrange
25 |
26 | logger = logging.get_logger(__name__) # pylint: disable=invalid-name
27 |
28 |
29 | class JointTransformerBlockMultiView(JointTransformerBlock):
30 | def __init__(self, num_views, *args, **kwargs):
31 | super().__init__(*args, **kwargs)
32 | self.num_views = num_views
33 |
34 | def freeze(self):
35 | for param in self.parameters():
36 | param.requires_grad = False
37 |
38 | def unfreeze_view(self):
39 | for param in self.view_layernorm.parameters():
40 | param.requires_grad = True
41 | for param in self.view_adaln.parameters():
42 | param.requires_grad = True
43 | for param in self.view_attn.parameters():
44 | param.requires_grad = True
45 | for param in self.view_fc.parameters():
46 | param.requires_grad = True
47 |
48 | def forward(
49 | self,
50 | hidden_states: torch.FloatTensor,
51 | encoder_hidden_states: torch.FloatTensor,
52 | temb: torch.FloatTensor,
53 | ):
54 | norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(hidden_states, emb=temb)
55 |
56 | # Normal block
57 | if self.context_pre_only:
58 | norm_encoder_hidden_states = self.norm1_context(encoder_hidden_states, temb)
59 | else:
60 | norm_encoder_hidden_states, c_gate_msa, c_shift_mlp, c_scale_mlp, c_gate_mlp = self.norm1_context(
61 | encoder_hidden_states, emb=temb
62 | )
63 |
64 | K, N, M = hidden_states.shape
65 | B = K // self.num_views
66 | norm_hidden_states = rearrange(norm_hidden_states, "(b v) n m -> b (v n) m", b=B, v=self.num_views, n=N, m=M)
67 | norm_encoder_hidden_states = rearrange(
68 | norm_encoder_hidden_states, "(b v) n m -> b (v n) m", b=B, v=self.num_views
69 | )
70 | # Attention.
71 | attn_output, context_attn_output = self.attn(
72 | hidden_states=norm_hidden_states, encoder_hidden_states=norm_encoder_hidden_states
73 | )
74 | attn_output = rearrange(attn_output, "b (v n) m -> (b v) n m", b=B, v=self.num_views)
75 | context_attn_output = rearrange(context_attn_output, "b (v n) m -> (b v) n m", b=B, v=self.num_views)
76 |
77 | # Process attention outputs for the `hidden_states`.
78 | attn_output = gate_msa.unsqueeze(1) * attn_output
79 | hidden_states = hidden_states + attn_output
80 | norm_hidden_states = self.norm2(hidden_states)
81 | norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
82 | if self._chunk_size is not None:
83 | # "feed_forward_chunk_size" can be used to save memory
84 | ff_output = _chunked_feed_forward(self.ff, norm_hidden_states, self._chunk_dim, self._chunk_size)
85 | else:
86 | ff_output = self.ff(norm_hidden_states)
87 | ff_output = gate_mlp.unsqueeze(1) * ff_output
88 | hidden_states = hidden_states + ff_output
89 |
90 | # Process attention outputs for the `encoder_hidden_states`.
91 | if self.context_pre_only:
92 | encoder_hidden_states = None
93 | else:
94 | context_attn_output = c_gate_msa.unsqueeze(1) * context_attn_output
95 | encoder_hidden_states = encoder_hidden_states + context_attn_output
96 |
97 | norm_encoder_hidden_states = self.norm2_context(encoder_hidden_states)
98 | norm_encoder_hidden_states = norm_encoder_hidden_states * (1 + c_scale_mlp[:, None]) + c_shift_mlp[:, None]
99 | if self._chunk_size is not None:
100 | # "feed_forward_chunk_size" can be used to save memory
101 | context_ff_output = _chunked_feed_forward(
102 | self.ff_context, norm_encoder_hidden_states, self._chunk_dim, self._chunk_size
103 | )
104 | else:
105 | context_ff_output = self.ff_context(norm_encoder_hidden_states)
106 | encoder_hidden_states = encoder_hidden_states + c_gate_mlp.unsqueeze(1) * context_ff_output
107 |
108 | return encoder_hidden_states, hidden_states
109 |
110 |
111 | class PatchEmbedMultiView(PatchEmbed):
112 | def __init__(
113 | self,
114 | height: int = 224,
115 | width: int = 224,
116 | patch_size=16,
117 | in_channels=3,
118 | embed_dim=768,
119 | layer_norm=False,
120 | flatten=True,
121 | bias=True,
122 | interpolation_scale=1,
123 | pos_embed_type="sincos",
124 | pos_embed_max_size=None, # For SD3 cropping
125 | num_views=8,
126 | ):
127 | super().__init__(
128 | height=height,
129 | width=width,
130 | patch_size=patch_size,
131 | in_channels=in_channels,
132 | embed_dim=embed_dim,
133 | layer_norm=layer_norm,
134 | flatten=flatten,
135 | bias=bias,
136 | interpolation_scale=interpolation_scale,
137 | pos_embed_type=pos_embed_type,
138 | pos_embed_max_size=pos_embed_max_size,
139 | )
140 | self.num_views = num_views
141 | pos_embed = get_3d_sincos_pos_embed(
142 | embed_dim=embed_dim,
143 | spatial_size=(16, 16),
144 | temporal_size=num_views,
145 | spatial_interpolation_scale=1.0,
146 | temporal_interpolation_scale=1.0,
147 | )
148 | self.register_buffer("pos_embed", torch.from_numpy(pos_embed).float().unsqueeze(0))
149 |
150 | def forward(self, latent):
151 | latent = self.proj(latent)
152 | if self.flatten:
153 | latent = latent.flatten(2).transpose(1, 2) # BCHW -> BNC
154 | if self.layer_norm:
155 | latent = self.norm(latent)
156 | if self.pos_embed is None:
157 | return latent.to(latent.dtype)
158 | pos_embed = self.pos_embed
159 | b = latent.shape[0] // self.num_views
160 | view = self.num_views
161 | latent = rearrange(latent, "(b v) n m -> b v n m", b=b, v=view)
162 | latent = latent + pos_embed.to(latent.dtype)
163 | latent = rearrange(latent, "b v n m -> (b v) n m", b=b, v=view)
164 | return latent
165 |
166 |
167 | class MultiViewSD3Transformer(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin):
168 | """
169 | The `MultiViewSD3Transformer` model is a multi-view extension of the Stable Diffusion 3 transformer (`SD3Transformer2DModel`): it shares the text and timestep conditioning across views and performs joint attention over all view tokens.
170 | """
171 |
172 | _supports_gradient_checkpointing = True
173 |
174 | @register_to_config
175 | def __init__(
176 | self,
177 | sample_size: int = 128,
178 | patch_size: int = 2,
179 | in_channels: int = 16,
180 | num_layers: int = 18,
181 | attention_head_dim: int = 64,
182 | num_attention_heads: int = 18,
183 | joint_attention_dim: int = 4096,
184 | caption_projection_dim: int = 1152,
185 | pooled_projection_dim: int = 2048,
186 | out_channels: int = 16,
187 | pos_embed_max_size: int = 96,
188 | num_views: int = 8,
189 | ):
190 | super().__init__()
191 | default_out_channels = in_channels
192 | self.out_channels = out_channels if out_channels is not None else default_out_channels
193 | self.inner_dim = self.config.num_attention_heads * self.config.attention_head_dim
194 |
195 | self.pos_embed = PatchEmbedMultiView(
196 | height=self.config.sample_size,
197 | width=self.config.sample_size,
198 | patch_size=self.config.patch_size,
199 | in_channels=16,
200 | embed_dim=self.inner_dim,
201 | pos_embed_max_size=pos_embed_max_size, # hard-code for now.
202 | num_views=num_views,
203 | )
204 | self.time_text_embed = CombinedTimestepTextProjEmbeddings(
205 | embedding_dim=self.inner_dim, pooled_projection_dim=self.config.pooled_projection_dim
206 | )
207 |
208 | self.context_embedder = nn.Linear(self.config.joint_attention_dim, self.config.caption_projection_dim)
209 |
210 | # `attention_head_dim` is doubled to account for the mixing.
211 | # It needs to be crafted when we get the actual checkpoints.
212 | self.transformer_blocks = nn.ModuleList(
213 | [
214 | JointTransformerBlockMultiView(
215 | dim=self.inner_dim,
216 | num_attention_heads=self.config.num_attention_heads,
217 | attention_head_dim=self.inner_dim,
218 | context_pre_only=i == num_layers - 1,
219 | num_views=num_views,
220 | )
221 | for i in range(self.config.num_layers)
222 | ]
223 | )
224 | self.patch_size = patch_size
225 | self.norm_out = AdaLayerNormContinuous(self.inner_dim, self.inner_dim, elementwise_affine=False, eps=1e-6)
226 | self.proj_out = nn.Linear(self.inner_dim, patch_size * patch_size * self.out_channels, bias=True)
227 | self.gradient_checkpointing = False
228 |
229 | def freeze(self):
230 | for param in self.parameters():
231 | param.requires_grad = False
232 |
233 | def unfreeze_view(self):
234 | for blk in self.transformer_blocks:
235 | for param in blk.parameters():
236 | param.requires_grad = True
237 |
238 | def adjust_output_input_channel_size(self, new_in_channels: int):
239 | self.config.in_channels = new_in_channels
240 | self.out_channels = new_in_channels
241 | old_conv_layer = self.pos_embed.proj
242 |
243 | # Calculate scaling factor
244 | scaling_factor_conv = (old_conv_layer.in_channels / new_in_channels) ** 0.5
245 |
246 | # Create a new convolutional layer with the desired number of input channels
247 | new_conv_layer = nn.Conv2d(
248 | in_channels=new_in_channels,
249 | out_channels=old_conv_layer.out_channels,
250 | kernel_size=old_conv_layer.kernel_size,
251 | stride=old_conv_layer.stride,
252 | padding=old_conv_layer.padding,
253 | dilation=old_conv_layer.dilation,
254 | groups=old_conv_layer.groups,
255 | bias=old_conv_layer.bias is not None,
256 | )
257 |
258 | with torch.no_grad():
259 | channels_to_copy = min(old_conv_layer.in_channels, new_in_channels)
260 |
261 | # Copy existing weights
262 | new_conv_layer.weight.data[:, :channels_to_copy, :, :] = old_conv_layer.weight.data[
263 | :, :channels_to_copy, :, :
264 | ]
265 |
266 | # Copy existing weights to new input channels via modulo indexing
267 | for i in range(channels_to_copy, new_in_channels):
268 | idx = i % old_conv_layer.in_channels
269 | new_conv_layer.weight.data[:, i : i + 1, :, :] = old_conv_layer.weight.data[:, idx : idx + 1, :, :]
270 |
271 | # Scale the weights
272 | new_conv_layer.weight.mul_(scaling_factor_conv)
273 |
274 | # Copy bias if it exists
275 | if old_conv_layer.bias is not None:
276 | new_conv_layer.bias.data = old_conv_layer.bias.data
277 |
278 | # Replace the old convolutional layer with the new one
279 | self.pos_embed.proj = new_conv_layer
280 |
281 | # Output layer modification
282 | old_linear_layer = self.proj_out # Get the original Linear layer
283 |
284 | # Calculate the new output features for the Linear layer
285 | new_out_features = self.patch_size * self.patch_size * new_in_channels
286 |
287 | # Calculate scaling factor for the linear layer
288 | scaling_factor_linear = (old_linear_layer.out_features / new_out_features) ** 0.5
289 |
290 | # Create a new Linear layer with the desired output channels
291 | new_linear_layer = nn.Linear(
292 | old_linear_layer.in_features, # Keep the input features the same
293 | new_out_features, # New number of output features
294 | bias=old_linear_layer.bias is not None,
295 | )
296 |
297 | with torch.no_grad():
298 | features_to_copy = min(old_linear_layer.out_features, new_out_features)
299 |
300 | # Copy existing weights
301 | new_linear_layer.weight.data[:features_to_copy, :] = old_linear_layer.weight.data[:features_to_copy, :]
302 |
303 | # Copy existing weights to new outputs via modulo indexing
304 | for i in range(features_to_copy, new_out_features):
305 | idx = i % old_linear_layer.out_features
306 | new_linear_layer.weight.data[i, :] = old_linear_layer.weight.data[idx, :]
307 |
308 | # Scale the weights
309 | new_linear_layer.weight.mul_(scaling_factor_linear)
310 |
311 | # Copy existing biases
312 | if old_linear_layer.bias is not None:
313 | new_linear_layer.bias.data[:features_to_copy] = old_linear_layer.bias.data[:features_to_copy]
314 |
315 | # Copy biases via modulo indexing for new outputs
316 | for i in range(features_to_copy, new_out_features):
317 | idx = i % old_linear_layer.out_features
318 | new_linear_layer.bias.data[i] = old_linear_layer.bias.data[idx]
319 |
320 | # Replace the old Linear layer with the new one
321 | self.proj_out = new_linear_layer
322 |
323 | # Copied from diffusers.models.unets.unet_3d_condition.UNet3DConditionModel.enable_forward_chunking
324 | def enable_forward_chunking(self, chunk_size: Optional[int] = None, dim: int = 0) -> None:
325 | """
326 | Enables [feed forward
327 | chunking](https://huggingface.co/blog/reformer#2-chunked-feed-forward-layers).
328 |
329 | Parameters:
330 | chunk_size (`int`, *optional*):
331 | The chunk size of the feed-forward layers. If not specified, will run feed-forward layer individually
332 | over each tensor of dim=`dim`.
333 | dim (`int`, *optional*, defaults to `0`):
334 | The dimension over which the feed-forward computation should be chunked. Choose between dim=0 (batch)
335 | or dim=1 (sequence length).
336 | """
337 | if dim not in [0, 1]:
338 | raise ValueError(f"Make sure to set `dim` to either 0 or 1, not {dim}")
339 |
340 | # By default chunk size is 1
341 | chunk_size = chunk_size or 1
342 |
343 | def fn_recursive_feed_forward(module: torch.nn.Module, chunk_size: int, dim: int):
344 | if hasattr(module, "set_chunk_feed_forward"):
345 | module.set_chunk_feed_forward(chunk_size=chunk_size, dim=dim)
346 |
347 | for child in module.children():
348 | fn_recursive_feed_forward(child, chunk_size, dim)
349 |
350 | for module in self.children():
351 | fn_recursive_feed_forward(module, chunk_size, dim)
352 |
353 | @property
354 | # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
355 | def attn_processors(self) -> Dict[str, AttentionProcessor]:
356 | r"""
357 | Returns:
358 | `dict` of attention processors: A dictionary containing all attention processors used in the model,
359 | indexed by their weight names.
360 | """
361 | # set recursively
362 | processors = {}
363 |
364 | def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
365 | if hasattr(module, "get_processor"):
366 | processors[f"{name}.processor"] = module.get_processor(return_deprecated_lora=True)
367 |
368 | for sub_name, child in module.named_children():
369 | fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
370 |
371 | return processors
372 |
373 | for name, module in self.named_children():
374 | fn_recursive_add_processors(name, module, processors)
375 |
376 | return processors
377 |
378 | # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor
379 | def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
380 | r"""
381 | Sets the attention processor to use to compute attention.
382 |
383 | Parameters:
384 | processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
385 | The instantiated processor class or a dictionary of processor classes that will be set as the processor
386 | for **all** `Attention` layers.
387 |
388 | If `processor` is a dict, the key needs to define the path to the corresponding cross attention
389 | processor. This is strongly recommended when setting trainable attention processors.
390 |
391 | """
392 | count = len(self.attn_processors.keys())
393 |
394 | if isinstance(processor, dict) and len(processor) != count:
395 | raise ValueError(
396 | f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
397 | f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
398 | )
399 |
400 | def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
401 | if hasattr(module, "set_processor"):
402 | if not isinstance(processor, dict):
403 | module.set_processor(processor)
404 | else:
405 | module.set_processor(processor.pop(f"{name}.processor"))
406 |
407 | for sub_name, child in module.named_children():
408 | fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
409 |
410 | for name, module in self.named_children():
411 | fn_recursive_attn_processor(name, module, processor)
412 |
413 | # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.fuse_qkv_projections
414 | def fuse_qkv_projections(self):
415 | """
416 | Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query, key, value)
417 | are fused. For cross-attention modules, key and value projection matrices are fused.
418 |
419 |
420 |
421 | This API is 🧪 experimental.
422 |
423 |
424 | """
425 | self.original_attn_processors = None
426 |
427 | for _, attn_processor in self.attn_processors.items():
428 | if "Added" in str(attn_processor.__class__.__name__):
429 | raise ValueError("`fuse_qkv_projections()` is not supported for models having added KV projections.")
430 |
431 | self.original_attn_processors = self.attn_processors
432 |
433 | for module in self.modules():
434 | if isinstance(module, Attention):
435 | module.fuse_projections(fuse=True)
436 |
437 | # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.unfuse_qkv_projections
438 | def unfuse_qkv_projections(self):
439 | """Disables the fused QKV projection if enabled.
440 |
441 |
442 |
443 | This API is 🧪 experimental.
444 |
445 |
446 |
447 | """
448 | if self.original_attn_processors is not None:
449 | self.set_attn_processor(self.original_attn_processors)
450 |
451 | def _set_gradient_checkpointing(self, module, value=False):
452 | if hasattr(module, "gradient_checkpointing"):
453 | module.gradient_checkpointing = value
454 |
455 | def forward(
456 | self,
457 | hidden_states: torch.FloatTensor,
458 | encoder_hidden_states: torch.FloatTensor = None,
459 | pooled_projections: torch.FloatTensor = None,
460 | timestep: torch.LongTensor = None,
461 | block_controlnet_hidden_states: List = None,
462 | joint_attention_kwargs: Optional[Dict[str, Any]] = None,
463 | return_dict: bool = True,
464 | ) -> Union[torch.FloatTensor, Transformer2DModelOutput]:
465 | """
466 | The [`MultiViewSD3Transformer`] forward method.
467 |
468 | Args:
469 | hidden_states (`torch.FloatTensor` of shape `(batch size, num_views, channel, height, width)`):
470 | Input multi-view `hidden_states`.
471 | encoder_hidden_states (`torch.FloatTensor` of shape `(batch size, sequence_len, embed_dims)`):
472 | Conditional embeddings (embeddings computed from the input conditions such as prompts) to use.
473 | pooled_projections (`torch.FloatTensor` of shape `(batch_size, projection_dim)`): Embeddings projected
474 | from the embeddings of input conditions.
475 | timestep ( `torch.LongTensor`):
476 | Used to indicate denoising step.
477 | block_controlnet_hidden_states: (`list` of `torch.Tensor`):
478 | A list of tensors that if specified are added to the residuals of transformer blocks.
479 | joint_attention_kwargs (`dict`, *optional*):
480 | A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
481 | `self.processor` in
482 | [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
483 | return_dict (`bool`, *optional*, defaults to `True`):
484 | Whether or not to return a [`~models.transformer_2d.Transformer2DModelOutput`] instead of a plain
485 | tuple.
486 |
487 | Returns:
488 | If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a
489 | `tuple` where the first element is the sample tensor.
490 | """
491 | if joint_attention_kwargs is not None:
492 | joint_attention_kwargs = joint_attention_kwargs.copy()
493 | lora_scale = joint_attention_kwargs.pop("scale", 1.0)
494 | else:
495 | lora_scale = 1.0
496 |
497 | if USE_PEFT_BACKEND:
498 | # weight the lora layers by setting `lora_scale` for each PEFT layer
499 | scale_lora_layers(self, lora_scale)
500 | else:
501 | if joint_attention_kwargs is not None and joint_attention_kwargs.get("scale", None) is not None:
502 | logger.warning(
503 | "Passing `scale` via `joint_attention_kwargs` when not using the PEFT backend is ineffective."
504 | )
505 |
506 | b, view, channel, height, width = hidden_states.shape
507 | # 2Dfy
508 | hidden_states = hidden_states.reshape(b * view, channel, height, width)
509 |
510 | hidden_states = self.pos_embed(hidden_states)
511 |
512 | temb = self.time_text_embed(timestep, pooled_projections)
513 | encoder_hidden_states = self.context_embedder(encoder_hidden_states)
514 |
515 | # Multi-view Patch
516 | temb = temb.unsqueeze(1).repeat(1, view, 1)
517 | encoder_hidden_states = encoder_hidden_states.unsqueeze(1).repeat(1, view, 1, 1)
518 |
519 | # 2dfy
520 | temb = temb.reshape(b * view, temb.shape[-1])
521 | encoder_hidden_states = encoder_hidden_states.reshape(
522 | b * view, encoder_hidden_states.shape[-2], encoder_hidden_states.shape[-1]
523 | )
524 |
525 | for index_block, block in enumerate(self.transformer_blocks):
526 | if self.training and self.gradient_checkpointing:
527 |
528 | def create_custom_forward(module, return_dict=None):
529 | def custom_forward(*inputs):
530 | if return_dict is not None:
531 | return module(*inputs, return_dict=return_dict)
532 | else:
533 | return module(*inputs)
534 |
535 | return custom_forward
536 |
537 | ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
538 | encoder_hidden_states, hidden_states = torch.utils.checkpoint.checkpoint(
539 | create_custom_forward(block),
540 | hidden_states,
541 | encoder_hidden_states,
542 | temb,
543 | **ckpt_kwargs,
544 | )
545 |
546 | else:
547 | encoder_hidden_states, hidden_states = block(
548 | hidden_states=hidden_states, encoder_hidden_states=encoder_hidden_states, temb=temb
549 | )
550 |
551 | # controlnet residual
552 | if block_controlnet_hidden_states is not None and block.context_pre_only is False:
553 | interval_control = len(self.transformer_blocks) // len(block_controlnet_hidden_states)
554 | hidden_states = hidden_states + block_controlnet_hidden_states[index_block // interval_control]
555 |
556 | hidden_states = self.norm_out(hidden_states, temb)
557 | hidden_states = self.proj_out(hidden_states)
558 |
559 | # unpatchify
560 | patch_size = self.config.patch_size
561 | height = height // patch_size
562 | width = width // patch_size
563 |
564 | hidden_states = hidden_states.reshape(
565 | shape=(hidden_states.shape[0], height, width, patch_size, patch_size, self.out_channels)
566 | )
567 | hidden_states = torch.einsum("nhwpqc->nchpwq", hidden_states)
568 | output = hidden_states.reshape(
569 | shape=(hidden_states.shape[0], self.out_channels, height * patch_size, width * patch_size)
570 | )
571 |
572 | output = output.reshape(b, view, self.out_channels, height * patch_size, width * patch_size)
573 | if USE_PEFT_BACKEND:
574 | # remove `lora_scale` from each PEFT layer
575 | unscale_lora_layers(self, lora_scale)
576 |
577 | if not return_dict:
578 | return (output,)
579 |
580 | return Transformer2DModelOutput(sample=output)
581 |
--------------------------------------------------------------------------------
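Note: `MultiViewSD3Transformer.forward` consumes a 5-D latent `(batch, views, channels, height, width)`, broadcasts a single text/timestep conditioning to every view, and attends jointly over all view tokens before un-patchifying back to the 5-D shape. Below is a shape-level smoke test with a deliberately tiny configuration (all sizes are assumptions for illustration; the released checkpoint uses the larger defaults, and the latent must be 32x32 here because the positional embedding is hard-coded to a 16x16 patch grid):

```python
import torch
from model.multiview_rf.mv_sd3_architecture import MultiViewSD3Transformer

model = MultiViewSD3Transformer(
    sample_size=32,              # 32x32 latent with patch_size=2 -> 16x16 patch grid
    patch_size=2,
    in_channels=16,
    num_layers=2,                # tiny config, for illustration only
    attention_head_dim=32,
    num_attention_heads=4,       # inner_dim = 4 * 32 = 128
    joint_attention_dim=64,
    caption_projection_dim=128,  # must equal inner_dim
    pooled_projection_dim=64,
    out_channels=16,
    num_views=8,
)

latents = torch.randn(1, 8, 16, 32, 32)   # (batch, views, channels, H, W)
prompt = torch.randn(1, 77, 64)           # (batch, seq_len, joint_attention_dim)
pooled = torch.randn(1, 64)               # (batch, pooled_projection_dim)
timestep = torch.tensor([500])            # one denoising step index per batch element

with torch.no_grad():
    out = model(
        latents, encoder_hidden_states=prompt, pooled_projections=pooled, timestep=timestep
    ).sample
print(out.shape)  # torch.Size([1, 8, 16, 32, 32])
```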
/model/multiview_rf/text_embedding.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 |
4 | def tokenize_prompt(tokenizer, prompt):
5 | text_inputs = tokenizer(
6 | prompt,
7 | padding="max_length",
8 | max_length=77,
9 | truncation=True,
10 | return_tensors="pt",
11 | )
12 | text_input_ids = text_inputs.input_ids
13 | return text_input_ids
14 |
15 |
16 | def _encode_prompt_with_t5(
17 | text_encoder,
18 | tokenizer,
19 | max_sequence_length,
20 | prompt=None,
21 | num_images_per_prompt=1,
22 | device=None,
23 | ):
24 | prompt = [prompt] if isinstance(prompt, str) else prompt
25 | batch_size = len(prompt)
26 |
27 | text_inputs = tokenizer(
28 | prompt,
29 | padding="max_length",
30 | max_length=max_sequence_length,
31 | truncation=True,
32 | add_special_tokens=True,
33 | return_tensors="pt",
34 | )
35 | text_input_ids = text_inputs.input_ids
36 | prompt_embeds = text_encoder(text_input_ids.to(device))[0]
37 |
38 | dtype = text_encoder.dtype
39 | prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
40 |
41 | _, seq_len, _ = prompt_embeds.shape
42 |
43 | # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method
44 | prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
45 | prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
46 |
47 | return prompt_embeds
48 |
49 |
50 | def _encode_prompt_with_clip(
51 | text_encoder,
52 | tokenizer,
53 | prompt: str,
54 | device=None,
55 | num_images_per_prompt: int = 1,
56 | ):
57 | prompt = [prompt] if isinstance(prompt, str) else prompt
58 | batch_size = len(prompt)
59 |
60 | text_inputs = tokenizer(
61 | prompt,
62 | padding="max_length",
63 | max_length=77,
64 | truncation=True,
65 | return_tensors="pt",
66 | )
67 |
68 | text_input_ids = text_inputs.input_ids
69 | prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=True)
70 |
71 | pooled_prompt_embeds = prompt_embeds[0]
72 | prompt_embeds = prompt_embeds.hidden_states[-2]
73 | prompt_embeds = prompt_embeds.to(dtype=text_encoder.dtype, device=device)
74 |
75 | _, seq_len, _ = prompt_embeds.shape
76 |
77 | # duplicate text embeddings for each generation per prompt, using mps friendly method
78 | prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
79 | prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
80 |
81 | pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt, 1)
82 | pooled_prompt_embeds = pooled_prompt_embeds.view(batch_size * num_images_per_prompt, -1)
83 |
84 | return prompt_embeds, pooled_prompt_embeds
85 |
86 |
87 | def encode_prompt(
88 | text_encoders,
89 | tokenizers,
90 | prompt: str,
91 | max_sequence_length,
92 | device=None,
93 | num_images_per_prompt: int = 1,
94 | ):
95 | prompt = [prompt] if isinstance(prompt, str) else prompt
96 |
97 | clip_tokenizers = tokenizers[:2]
98 | clip_text_encoders = text_encoders[:2]
99 |
100 | clip_prompt_embeds_list = []
101 | clip_pooled_prompt_embeds_list = []
102 | for tokenizer, text_encoder in zip(clip_tokenizers, clip_text_encoders):
103 | prompt_embeds, pooled_prompt_embeds = _encode_prompt_with_clip(
104 | text_encoder=text_encoder,
105 | tokenizer=tokenizer,
106 | prompt=prompt,
107 | device=device if device is not None else text_encoder.device,
108 | num_images_per_prompt=num_images_per_prompt,
109 | )
110 | clip_prompt_embeds_list.append(prompt_embeds)
111 | clip_pooled_prompt_embeds_list.append(pooled_prompt_embeds)
112 |
113 | clip_prompt_embeds = torch.cat(clip_prompt_embeds_list, dim=-1)
114 | pooled_prompt_embeds = torch.cat(clip_pooled_prompt_embeds_list, dim=-1)
115 |
116 | t5_prompt_embed = _encode_prompt_with_t5(
117 | text_encoders[-1],
118 | tokenizers[-1],
119 | max_sequence_length,
120 | prompt=prompt,
121 | num_images_per_prompt=num_images_per_prompt,
122 | device=device if device is not None else text_encoders[-1].device,
123 | )
124 |
125 | clip_prompt_embeds = torch.nn.functional.pad(
126 | clip_prompt_embeds, (0, t5_prompt_embed.shape[-1] - clip_prompt_embeds.shape[-1])
127 | )
128 | prompt_embeds = torch.cat([clip_prompt_embeds, t5_prompt_embed], dim=-2)
129 |
130 | return prompt_embeds, pooled_prompt_embeds
131 |
132 |
133 | def compute_text_embeddings(prompt, text_encoders, tokenizers, max_sequence_length, device):
134 | with torch.no_grad():
135 | prompt_embeds, pooled_prompt_embeds = encode_prompt(text_encoders, tokenizers, prompt, max_sequence_length)
136 | prompt_embeds = prompt_embeds.to(device)
137 | pooled_prompt_embeds = pooled_prompt_embeds.to(device)
138 | return prompt_embeds, pooled_prompt_embeds
139 |
--------------------------------------------------------------------------------
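Note: `encode_prompt` follows the usual SD3 recipe: the two CLIP sequence embeddings are concatenated on the channel axis, zero-padded up to the T5 width, and stacked with the T5 embedding along the sequence axis, while the two CLIP pooled vectors are concatenated into `pooled_prompt_embeds`. A minimal sketch of just the pad-and-concat step with dummy tensors (the widths 768/1280/4096 and lengths 77/256 are the commonly used SD3 sizes, assumed here):

```python
import torch
import torch.nn.functional as F

clip_l = torch.randn(1, 77, 768)    # CLIP-L sequence embedding (assumed width)
clip_g = torch.randn(1, 77, 1280)   # CLIP-G sequence embedding (assumed width)
t5 = torch.randn(1, 256, 4096)      # T5 sequence embedding (assumed width/length)

clip = torch.cat([clip_l, clip_g], dim=-1)              # (1, 77, 2048)
clip = F.pad(clip, (0, t5.shape[-1] - clip.shape[-1]))  # zero-pad channels to 4096
prompt_embeds = torch.cat([clip, t5], dim=-2)           # (1, 77 + 256, 4096)
print(prompt_embeds.shape)                              # torch.Size([1, 333, 4096])
```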
/model/refiner/camera_util.py:
--------------------------------------------------------------------------------
1 | import math
2 | import os
3 | from io import BytesIO
4 |
5 | import imageio
6 | import matplotlib.pyplot as plt
7 | import numpy as np
8 | import torch
9 | import tqdm
10 | from einops import einsum, rearrange
11 | from plyfile import PlyData, PlyElement
12 | from scipy.spatial.transform import Rotation as R
13 | from torchvision.utils import save_image
14 |
15 |
16 | def inverse_sigmoid(x):
17 | return torch.log(x / (1 - x))
18 |
19 |
20 | def quaternion_to_matrix(quaternions):
21 | """
22 | Convert rotations given as quaternions to rotation matrices.
23 | Args:
24 | quaternions: quaternions with real part first,
25 | as tensor of shape (..., 4).
26 | Returns:
27 | Rotation matrices as tensor of shape (..., 3, 3).
28 | """
29 | r, i, j, k = torch.unbind(quaternions, -1)
30 | two_s = 2.0 / (quaternions * quaternions).sum(-1)
31 |
32 | o = torch.stack(
33 | (
34 | 1 - two_s * (j * j + k * k),
35 | two_s * (i * j - k * r),
36 | two_s * (i * k + j * r),
37 | two_s * (i * j + k * r),
38 | 1 - two_s * (i * i + k * k),
39 | two_s * (j * k - i * r),
40 | two_s * (i * k - j * r),
41 | two_s * (j * k + i * r),
42 | 1 - two_s * (i * i + j * j),
43 | ),
44 | -1,
45 | )
46 | return o.reshape(quaternions.shape[:-1] + (3, 3))
47 |
48 |
49 | def matrix_to_quaternion(M: torch.Tensor) -> torch.Tensor:
50 | """
51 | Matrix-to-quaternion conversion method. Equation taken from
52 | https://www.euclideanspace.com/maths/geometry/rotations/conversions/matrixToQuaternion/index.htm
53 | Args:
54 | M: rotation matrices, (... x 3 x 3)
55 | Returns:
56 | q: quaternion of shape (... x 4)
57 | """
58 | prefix_shape = M.shape[:-2]
59 | Ms = M.reshape(-1, 3, 3)
60 |
61 | trs = 1 + Ms[:, 0, 0] + Ms[:, 1, 1] + Ms[:, 2, 2]
62 |
63 | Qs = []
64 |
65 | for i in range(Ms.shape[0]):
66 | M = Ms[i]
67 | tr = trs[i]
68 | if tr > 0:
69 | r = torch.sqrt(tr) / 2.0
70 | x = (M[2, 1] - M[1, 2]) / (4 * r)
71 | y = (M[0, 2] - M[2, 0]) / (4 * r)
72 | z = (M[1, 0] - M[0, 1]) / (4 * r)
73 | elif (M[0, 0] > M[1, 1]) and (M[0, 0] > M[2, 2]):
74 | S = torch.sqrt(1.0 + M[0, 0] - M[1, 1] - M[2, 2]) * 2 # S=4*qx
75 | r = (M[2, 1] - M[1, 2]) / S
76 | x = 0.25 * S
77 | y = (M[0, 1] + M[1, 0]) / S
78 | z = (M[0, 2] + M[2, 0]) / S
79 | elif M[1, 1] > M[2, 2]:
80 | S = torch.sqrt(1.0 + M[1, 1] - M[0, 0] - M[2, 2]) * 2 # S=4*qy
81 | r = (M[0, 2] - M[2, 0]) / S
82 | x = (M[0, 1] + M[1, 0]) / S
83 | y = 0.25 * S
84 | z = (M[1, 2] + M[2, 1]) / S
85 | else:
86 | S = torch.sqrt(1.0 + M[2, 2] - M[0, 0] - M[1, 1]) * 2 # S=4*qz
87 | r = (M[1, 0] - M[0, 1]) / S
88 | x = (M[0, 2] + M[2, 0]) / S
89 | y = (M[1, 2] + M[2, 1]) / S
90 | z = 0.25 * S
91 | Q = torch.stack([r, x, y, z], dim=-1)
92 | Qs += [Q]
93 |
94 | return torch.stack(Qs, dim=0).reshape(*prefix_shape, 4)
95 |
96 |
97 | @torch.amp.autocast("cuda", enabled=False)
98 | def quaternion_slerp(q0, q1, fraction, spin: int = 0, shortestpath: bool = True):
99 | """Return spherical linear interpolation between two quaternions.
100 | Args:
101 | quat0: first quaternion
102 | quat1: second quaternion
103 | fraction: how much to interpolate between quat0 vs quat1 (if 0, closer to quat0; if 1, closer to quat1)
104 | spin: how much of an additional spin to place on the interpolation
105 | shortestpath: whether to return the short or long path to rotation
106 | """
107 | d = (q0 * q1).sum(-1)
108 | if shortestpath:
109 | # invert rotation
110 | neg = d < 0.0
111 | d[neg], q1[neg] = -d[neg], -q1[neg]
112 |
113 | d = d.clamp(0, 1.0)
114 |
115 | angle = torch.acos(d) + spin * math.pi
116 | isin = 1.0 / (torch.sin(angle) + 1e-10)
117 | q0_ = q0 * torch.sin((1.0 - fraction) * angle) * isin
118 | q1_ = q1 * torch.sin(fraction * angle) * isin
119 |
120 | q = q0_ + q1_
121 | q[angle < 1e-5] = q0[angle < 1e-5]
122 |
123 | return q
124 |
125 |
126 | def sample_from_two_pose(pose_a, pose_b, fraction, noise_strengths=[0, 0]):
127 | """
128 | Args:
129 | pose_a: first pose
130 | pose_b: second pose
131 | fraction: interpolation weight between pose_a and pose_b (0 = pose_a, 1 = pose_b)
132 | """
133 |
134 | quat_a = matrix_to_quaternion(pose_a[..., :3, :3])
135 | quat_b = matrix_to_quaternion(pose_b[..., :3, :3])
136 |
137 | quaternion = quaternion_slerp(quat_a, quat_b, fraction)
138 | quaternion = torch.nn.functional.normalize(quaternion + torch.randn_like(quaternion) * noise_strengths[0], dim=-1)
139 |
140 | R = quaternion_to_matrix(quaternion)
141 | T = (1 - fraction) * pose_a[..., :3, 3] + fraction * pose_b[..., :3, 3]
142 | T = T + torch.randn_like(T) * noise_strengths[1]
143 |
144 | new_pose = pose_a.clone()
145 | new_pose[..., :3, :3] = R
146 | new_pose[..., :3, 3] = T
147 | return new_pose
148 |
149 |
150 | def sample_from_dense_cameras(dense_cameras, t, noise_strengths=[0, 0, 0, 0]):
151 | _, N, A, B = dense_cameras.shape
152 | _, M = t.shape
153 |
154 | t = t.to(dense_cameras.device)
155 | left = torch.floor(t * (N - 1)).long().clamp(0, N - 2)
156 | right = left + 1
157 | fraction = t * (N - 1) - left
158 | a = torch.gather(dense_cameras, 1, left[..., None].repeat(1, 1, A, B))
159 | b = torch.gather(dense_cameras, 1, right[..., None].repeat(1, 1, A, B))
160 |
161 | new_pose = sample_from_two_pose(a[:, :, :3, 3:], b[:, :, :3, 3:], fraction, noise_strengths=noise_strengths[:2])
162 |
163 | new_ins = (1 - fraction) * a[:, :, :3, :3] + fraction * b[:, :, :3, :3]
164 |
165 | return torch.cat([new_ins, new_pose], dim=-1)
166 |
167 |
168 | def export_ply_for_gaussians(path, gaussians):
169 | xyz, features, opacity, scales, rotations = gaussians
170 |
171 | means3D = xyz.contiguous().float()
172 | opacity = opacity.contiguous().float()
173 | scales = scales.contiguous().float()
174 | rotations = rotations.contiguous().float()
175 | shs = features.contiguous().float() # [N, 1, 3]
176 |
177 | SH_C0 = 0.28209479177387814
178 | means3D, rotations, shs = adjust_gaussians(means3D, rotations, shs, SH_C0, inverse=False)
179 |
180 | opacity = inverse_sigmoid(opacity)
181 | scales = torch.log(scales + 1e-8)
182 |
183 | xyzs = means3D.detach().cpu().numpy()
184 | f_dc = shs.detach().transpose(1, 2).flatten(start_dim=1).contiguous().cpu().numpy()
185 | opacities = opacity.detach().cpu().numpy()
186 | scales = scales.detach().cpu().numpy()
187 | rotations = rotations.detach().cpu().numpy()
188 |
189 | l = ["x", "y", "z"] # noqa: E741
190 | # The 3 DC color channels
191 | for i in range(f_dc.shape[1]):
192 | l.append("f_dc_{}".format(i))
193 | l.append("opacity")
194 | for i in range(scales.shape[1]):
195 | l.append("scale_{}".format(i))
196 | for i in range(rotations.shape[1]):
197 | l.append("rot_{}".format(i))
198 |
199 | dtype_full = [(attribute, "f4") for attribute in l]
200 |
201 | elements = np.empty(xyzs.shape[0], dtype=dtype_full)
202 | attributes = np.concatenate((xyzs, f_dc, opacities, scales, rotations), axis=1)
203 | elements[:] = list(map(tuple, attributes))
204 | el = PlyElement.describe(elements, "vertex")
205 |
206 | PlyData([el]).write(path + '.ply')
207 |
208 | plydata = PlyData([el])
209 |
210 | vert = plydata["vertex"]
211 | sorted_indices = np.argsort(
212 | -np.exp(vert["scale_0"] + vert["scale_1"] + vert["scale_2"]) / (1 + np.exp(-vert["opacity"]))
213 | )
214 | buffer = BytesIO()
215 | for idx in sorted_indices:
216 | v = plydata["vertex"][idx]
217 | position = np.array([v["x"], v["y"], v["z"]], dtype=np.float32)
218 | scales = np.exp(
219 | np.array(
220 | [v["scale_0"], v["scale_1"], v["scale_2"]],
221 | dtype=np.float32,
222 | )
223 | )
224 | rot = np.array(
225 | [v["rot_0"], v["rot_1"], v["rot_2"], v["rot_3"]],
226 | dtype=np.float32,
227 | )
228 | color = np.array(
229 | [
230 | 0.5 + SH_C0 * v["f_dc_0"],
231 | 0.5 + SH_C0 * v["f_dc_1"],
232 | 0.5 + SH_C0 * v["f_dc_2"],
233 | 1 / (1 + np.exp(-v["opacity"])),
234 | ]
235 | )
236 | buffer.write(position.tobytes())
237 | buffer.write(scales.tobytes())
238 | buffer.write((color * 255).clip(0, 255).astype(np.uint8).tobytes())
239 | buffer.write(((rot / np.linalg.norm(rot)) * 128 + 128).clip(0, 255).astype(np.uint8).tobytes())
240 |
241 | with open(path + '.splat', "wb") as f:
242 | f.write(buffer.getvalue())
243 |
244 |
245 | def load_ply_for_gaussians(path, device="cpu"):
246 | plydata = PlyData.read(path)
247 |
248 | xyz = np.stack(
249 | (
250 | np.asarray(plydata.elements[0]["x"]),
251 | np.asarray(plydata.elements[0]["y"]),
252 | np.asarray(plydata.elements[0]["z"]),
253 | ),
254 | axis=1,
255 | )
256 | opacities = np.asarray(plydata.elements[0]["opacity"])[..., np.newaxis]
257 |
258 | print("Number of points at loading : ", xyz.shape[0])
259 |
260 | features_dc = np.zeros((xyz.shape[0], 3, 1))
261 | features_dc[:, 0, 0] = np.asarray(plydata.elements[0]["f_dc_0"])
262 | features_dc[:, 1, 0] = np.asarray(plydata.elements[0]["f_dc_1"])
263 | features_dc[:, 2, 0] = np.asarray(plydata.elements[0]["f_dc_2"])
264 |
265 | scale_names = [p.name for p in plydata.elements[0].properties if p.name.startswith("scale_")]
266 | scales = np.zeros((xyz.shape[0], len(scale_names)))
267 | for idx, attr_name in enumerate(scale_names):
268 | scales[:, idx] = np.asarray(plydata.elements[0][attr_name])
269 |
270 | rot_names = [p.name for p in plydata.elements[0].properties if p.name.startswith("rot")]
271 | rots = np.zeros((xyz.shape[0], len(rot_names)))
272 | for idx, attr_name in enumerate(rot_names):
273 | rots[:, idx] = np.asarray(plydata.elements[0][attr_name])
274 |
275 | xyz = torch.tensor(xyz, dtype=torch.float, device=device)[None]
276 | features = torch.tensor(features_dc, dtype=torch.float, device=device).transpose(1, 2)[None]
277 | opacity = torch.tensor(opacities, dtype=torch.float, device=device)[None]
278 | scales = torch.tensor(scales, dtype=torch.float, device=device)[None]
279 | rotations = torch.tensor(rots, dtype=torch.float, device=device)[None]
280 |
281 | opacity = torch.sigmoid(opacity)
282 | scales = torch.exp(scales)
283 |
284 | SH_C0 = 0.28209479177387814
285 | xyz, rotations, features = adjust_gaussians(xyz, rotations, features, SH_C0, inverse=True)
286 |
287 | return xyz, features, opacity, scales, rotations
288 |
289 |
290 | def adjust_gaussians(means, rotations, shs, SH_C0, inverse):
291 | rot_adjust = torch.tensor(
292 | [
293 | [0, 0, 1],
294 | [-1, 0, 0],
295 | [0, -1, 0],
296 | ],
297 | dtype=torch.float32,
298 | device=means.device,
299 | )
300 |
301 | adjustment = torch.tensor(
302 | R.from_rotvec([0, 0, -45], True).as_matrix(),
303 | dtype=torch.float32,
304 | device=means.device,
305 | )
306 |
307 | rot_adjust = adjustment @ rot_adjust
308 |
309 | if inverse: # load: convert wxyz --> xyzw (rotation), convert shs to precomputed color
310 | rot_adjust = rot_adjust.inverse()
311 | means = einsum(rot_adjust, means, "i j, ... j -> ... i")
312 | rotations = R.from_quat(rotations[0].detach().cpu().numpy(), scalar_first=True).as_matrix()
313 | rotations = rot_adjust.detach().cpu().numpy() @ rotations
314 | rotations = R.from_matrix(rotations).as_quat()
315 | rotations = torch.from_numpy(rotations)[None].to(dtype=torch.float32, device=means.device)
316 | shs = 0.5 + shs * SH_C0
317 |
318 | else: # export: convert xyzw --> wxyz (rotation), convert precomputed color to shs
319 | means = einsum(rot_adjust, means, "i j, ... j -> ... i")
320 | rotations = R.from_quat(rotations.detach().cpu().numpy()).as_matrix()
321 | rotations = rot_adjust.detach().cpu().numpy() @ rotations
322 | rotations = R.from_matrix(rotations).as_quat()
323 | x, y, z, w = rearrange(rotations, "g xyzw -> xyzw g")
324 | rotations = torch.from_numpy(np.stack((w, x, y, z), axis=-1)).to(torch.float32)
325 | shs = (shs - 0.5) / SH_C0
326 |
327 | return means, rotations, shs
328 |
329 |
330 | @torch.no_grad()
331 | def export_video(render_fn, save_path, name, dense_cameras, fps=60, num_frames=720, size=512, device="cuda:0"):
332 | images = []
333 | depths = []
334 |
335 | for i in tqdm.trange(num_frames, desc="Rendering video..."):
336 | t = torch.full((1, 1), fill_value=i / num_frames, device=device)
337 |
338 | camera = sample_from_dense_cameras(dense_cameras, t)
339 |
340 | image, depth = render_fn(camera, size, size)
341 |
342 | images.append(process_image(image.reshape(3, size, size)))
343 | depths.append(process_image(depth.reshape(1, size, size)))
344 |
345 | imageio.mimwrite(os.path.join(save_path, f"{name}.mp4"), images, fps=fps, quality=8, macro_block_size=1)
346 |
347 |
348 | def process_image(image):
349 | return image.permute(1, 2, 0).detach().cpu().mul(1 / 2).add(1 / 2).clamp(0, 1).mul(255).numpy().astype(np.uint8)
350 |
351 |
352 | @torch.no_grad()
353 | def export_mv(render_fn, save_path, dense_cameras, size=256):
354 | num_views = dense_cameras.shape[1]
355 | imgs = []
356 | for i in tqdm.trange(num_views, desc="Rendering images..."):
357 | image = render_fn(dense_cameras[:, i].unsqueeze(1), size, size)[0]
358 | path = os.path.join(save_path, "mv_results")
359 | os.makedirs(path, exist_ok=True)
360 |
361 | path = os.path.join(path, f"refined_render_img_{i}.png")
362 | image = image.reshape(3, size, size).clamp(-1, 1).add(1).mul(1 / 2)
363 | imgs.append(image)
364 | save_image(image, path)
365 |
366 | cmap = plt.get_cmap("hsv")
367 | num_frames = 8
368 | num_rows = 2
369 | num_cols = 4
370 | figsize = (num_cols * 2, num_rows * 2)
371 | fig, axs = plt.subplots(num_rows, num_cols, figsize=figsize)
372 | axs = axs.flatten()
373 | for i in range(num_rows * num_cols):
374 | if i < num_frames:
375 | axs[i].imshow((imgs[i].cpu().numpy().transpose(1, 2, 0) * 255.0).astype(np.uint8))
376 | for s in ["bottom", "top", "left", "right"]:
377 | axs[i].spines[s].set_color(cmap(i / (num_frames)))
378 | axs[i].spines[s].set_linewidth(5)
379 | axs[i].set_xticks([])
380 | axs[i].set_yticks([])
381 | else:
382 | axs[i].axis("off")
383 | plt.tight_layout()
384 | plt.savefig(os.path.join(save_path, "refined_mv_images.pdf"), transparent=True)
385 |
--------------------------------------------------------------------------------
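Note: the quaternion helpers above use the real-first (wxyz) ordering, and `matrix_to_quaternion` resolves the usual sign ambiguity branch by branch. A quick round-trip sanity check (a minimal sketch; assumes the repository's requirements are installed so the module imports cleanly):

```python
import torch
from model.refiner.camera_util import matrix_to_quaternion, quaternion_to_matrix

q = torch.nn.functional.normalize(torch.randn(6, 4), dim=-1)  # random unit quaternions, wxyz
R = quaternion_to_matrix(q)                                    # (6, 3, 3) rotation matrices
q_back = matrix_to_quaternion(R)

# q and -q encode the same rotation, so compare component-wise up to a global sign flip.
assert torch.allclose(q.abs(), q_back.abs(), atol=1e-4)
```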
/model/refiner/gs_util.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from diff_gaussian_rasterization import (
4 | GaussianRasterizationSettings,
5 | GaussianRasterizer,
6 | )
7 | from einops import rearrange
8 |
9 | from model.gsdecoder.cuda_splatting import get_fov, get_projection_matrix
10 |
11 |
12 | def inverse_sigmoid(x):
13 | return torch.log(x / (1 - x))
14 |
15 |
16 | def inverse_softplus(x, beta=1):
17 | return (torch.exp(beta * x) - 1).log() / beta
18 |
19 |
20 | def build_rotation(r): # Note that we follow xyzw format.
21 | norm = torch.sqrt(r[:, 0] * r[:, 0] + r[:, 1] * r[:, 1] + r[:, 2] * r[:, 2] + r[:, 3] * r[:, 3])
22 |
23 | q = r / norm[:, None]
24 |
25 | R = torch.zeros((q.size(0), 3, 3), device=r.device)
26 |
27 | x = q[:, 0]
28 | y = q[:, 1]
29 | z = q[:, 2]
30 | r = q[:, 3]
31 |
32 | R[:, 0, 0] = 1 - 2 * (y * y + z * z)
33 | R[:, 0, 1] = 2 * (x * y - r * z)
34 | R[:, 0, 2] = 2 * (x * z + r * y)
35 | R[:, 1, 0] = 2 * (x * y + r * z)
36 | R[:, 1, 1] = 1 - 2 * (x * x + z * z)
37 | R[:, 1, 2] = 2 * (y * z - r * x)
38 | R[:, 2, 0] = 2 * (x * z - r * y)
39 | R[:, 2, 1] = 2 * (y * z + r * x)
40 | R[:, 2, 2] = 1 - 2 * (x * x + y * y)
41 | return R
42 |
43 |
44 | def rgb2shs(rgb):
45 | SH_C0 = 0.28209479177387814
46 | return (rgb - 0.5) / SH_C0
47 |
48 |
49 | def shstorgb(shs):
50 | SH_C0 = 0.28209479177387814
51 | return (shs * SH_C0 + 0.5).clamp(min=0)
52 |
53 |
54 | class GaussiansManeger:
55 | def __init__(self, xyz, features, opacity, scales, rotations, lrs):
56 | self._xyz = nn.Parameter(xyz.squeeze(0).contiguous().float().detach().clone().requires_grad_(True))
57 | self._features = nn.Parameter(
58 | rgb2shs(features).squeeze(0).contiguous().float().detach().clone().requires_grad_(True)
59 | )
60 | self._opacity = nn.Parameter(
61 | inverse_sigmoid(opacity.squeeze(0)).contiguous().float().detach().clone().requires_grad_(True)
62 | )
63 | self._scales = nn.Parameter(
64 | torch.log(scales.squeeze(0) + 1e-8).contiguous().float().detach().clone().requires_grad_(True)
65 | )
66 | self._rotations = nn.Parameter(rotations.squeeze(0).contiguous().float().detach().clone().requires_grad_(True))
67 |
68 | self.device = self._xyz.device
69 |
70 | self.optimizer = torch.optim.Adam(
71 | [
72 | {"name": "xyz", "params": [self._xyz], "lr": lrs["xyz"]},
73 | {"name": "features", "params": [self._features], "lr": lrs["features"]},
74 | {"name": "opacity", "params": [self._opacity], "lr": lrs["opacity"]},
75 | {"name": "scales", "params": [self._scales], "lr": lrs["scales"]},
76 | {"name": "rotations", "params": [self._rotations], "lr": lrs["rotations"]},
77 | ],
78 | betas=(0.9, 0.99),
79 | lr=0.0,
80 | eps=1e-15,
81 | )
82 |
83 | self.xyz_gradient_accum = torch.zeros((self._xyz.shape[0], 1), device=self.device)
84 | self.denom = torch.zeros((self._xyz.shape[0], 1), device=self.device)
85 | self.max_radii2D = torch.zeros((self._xyz.shape[0],), device=self.device)
86 | self.is_visible = torch.zeros((self._xyz.shape[0],), device=self.device)
87 |
88 | self.percent_dense = 0.003
89 |
90 | def __call__(self):
91 | xyz = self._xyz
92 | features = shstorgb(self._features)
93 | opacity = torch.sigmoid(self._opacity)
94 | scales = torch.exp(self._scales)
95 | rotations = torch.nn.functional.normalize(self._rotations, dim=-1)
96 | return (
97 | xyz.unsqueeze(0),
98 | features.unsqueeze(0),
99 | opacity.unsqueeze(0),
100 | scales.unsqueeze(0),
101 | rotations.unsqueeze(0),
102 | )
103 |
104 | @torch.no_grad()
105 | def densify_and_prune(self, max_grad=4, extent=2, opacity_threshold=0.001):
106 | grads = self.xyz_gradient_accum / self.denom
107 | grads[grads.isnan()] = 0.0
108 |
109 | self.densify_and_clone(grads, max_grad, scene_extent=extent)
110 | self.densify_and_split(grads, max_grad, scene_extent=extent)
111 |
112 | prune_mask = (torch.sigmoid(self._opacity) < opacity_threshold).squeeze()
113 | self.prune_points(prune_mask)
114 | torch.cuda.empty_cache()
115 |
116 | def prune_points(self, mask):
117 | valid_points_mask = ~mask
118 |
119 | optimizable_tensors = self.prune_optimizer(valid_points_mask)
120 | self._xyz = optimizable_tensors["xyz"]
121 | self._features = optimizable_tensors["features"]
122 | self._opacity = optimizable_tensors["opacity"]
123 | self._scales = optimizable_tensors["scales"]
124 | self._rotations = optimizable_tensors["rotations"]
125 |
126 | self.xyz_gradient_accum = self.xyz_gradient_accum[valid_points_mask]
127 | self.denom = self.denom[valid_points_mask]
128 | self.max_radii2D = self.max_radii2D[valid_points_mask]
129 | self.is_visible = self.is_visible[valid_points_mask]
130 |
131 | def prune_optimizer(self, mask):
132 | optimizable_tensors = {}
133 | for group in self.optimizer.param_groups:
134 | stored_state = self.optimizer.state.get(group["params"][0], None)
135 | if stored_state is not None:
136 | stored_state["exp_avg"] = stored_state["exp_avg"][mask]
137 | stored_state["exp_avg_sq"] = stored_state["exp_avg_sq"][mask]
138 |
139 | del self.optimizer.state[group["params"][0]]
140 | group["params"][0] = nn.Parameter((group["params"][0][mask].requires_grad_(True)))
141 | self.optimizer.state[group["params"][0]] = stored_state
142 |
143 | optimizable_tensors[group["name"]] = group["params"][0]
144 | else:
145 | group["params"][0] = nn.Parameter(group["params"][0][mask].requires_grad_(True))
146 | optimizable_tensors[group["name"]] = group["params"][0]
147 | return optimizable_tensors
148 |
149 | def add_points(self, params):
150 | num_points = params["xyz"].shape[0]
151 |
152 | optimizable_tensors = self.cat_tensors_to_optimizer(params)
153 | self._xyz = optimizable_tensors["xyz"]
154 | self._features = optimizable_tensors["features"]
155 | self._opacity = optimizable_tensors["opacity"]
156 | self._scales = optimizable_tensors["scales"]
157 | self._rotations = optimizable_tensors["rotations"]
158 |
159 | self.xyz_gradient_accum = torch.cat([self.xyz_gradient_accum, torch.zeros((num_points, 1), device=self.device)])
160 | self.denom = torch.cat([self.denom, torch.zeros((num_points, 1), device=self.device)])
161 | self.max_radii2D = torch.cat([self.max_radii2D, torch.zeros((num_points,), device=self.device)])
162 | self.is_visible = torch.cat([self.is_visible, torch.zeros((num_points,), device=self.device)])
163 |
164 | def cat_tensors_to_optimizer(self, tensors_dict):
165 | optimizable_tensors = {}
166 | for group in self.optimizer.param_groups:
167 | assert len(group["params"]) == 1
168 | extension_tensor = tensors_dict[group["name"]]
169 | stored_state = self.optimizer.state.get(group["params"][0], None)
170 | if stored_state is not None:
171 | stored_state["exp_avg"] = torch.cat(
172 | (stored_state["exp_avg"], torch.zeros_like(extension_tensor)), dim=0
173 | )
174 | stored_state["exp_avg_sq"] = torch.cat(
175 | (stored_state["exp_avg_sq"], torch.zeros_like(extension_tensor)), dim=0
176 | )
177 |
178 | del self.optimizer.state[group["params"][0]]
179 | group["params"][0] = nn.Parameter(
180 | torch.cat((group["params"][0], extension_tensor), dim=0).requires_grad_(True)
181 | )
182 | self.optimizer.state[group["params"][0]] = stored_state
183 |
184 | optimizable_tensors[group["name"]] = group["params"][0]
185 | else:
186 | group["params"][0] = nn.Parameter(
187 | torch.cat((group["params"][0], extension_tensor), dim=0).requires_grad_(True)
188 | )
189 | optimizable_tensors[group["name"]] = group["params"][0]
190 |
191 | return optimizable_tensors
192 |
193 | def densify_and_split(self, grads, grad_threshold=0.04, scene_extent=2, N=2):
194 | n_init_points = self._xyz.shape[0]
195 |
196 | # Extract points that satisfy the gradient condition
197 | padded_grad = torch.zeros((n_init_points), device=self.device)
198 | padded_grad[: grads.shape[0]] = grads.squeeze()
199 | selected_pts_mask = torch.where(padded_grad >= grad_threshold, True, False)
200 | selected_pts_mask = torch.logical_and(
201 | selected_pts_mask, torch.max(torch.exp(self._scales), dim=1).values > self.percent_dense * scene_extent
202 | )
203 |
204 | stds = torch.exp(self._scales[selected_pts_mask]).repeat(N, 1)
205 | means = torch.zeros((stds.size(0), 3), device=self.device)
206 | samples = torch.normal(mean=means, std=stds)
207 |
208 | rots = build_rotation(self._rotations[selected_pts_mask]).repeat(N, 1, 1)
209 |
210 | new_xyz = torch.bmm(rots, samples.unsqueeze(-1)).squeeze(-1) + self._xyz[selected_pts_mask].repeat(N, 1)
211 | new_scales = torch.log(torch.exp(self._scales[selected_pts_mask]) / (0.8 * N)).repeat(N, 1)
212 | new_rotations = self._rotations[selected_pts_mask].repeat(N, 1)
213 | new_features = self._features[selected_pts_mask].repeat(N, 1, 1)
214 | new_opacity = self._opacity[selected_pts_mask].repeat(N, 1)
215 |
216 | params = {
217 | "xyz": new_xyz,
218 | "features": new_features,
219 | "opacity": new_opacity,
220 | "scales": new_scales,
221 | "rotations": new_rotations,
222 | }
223 |
224 | self.add_points(params)
225 |
226 | prune_filter = torch.cat(
227 | (selected_pts_mask, torch.zeros(N * selected_pts_mask.sum(), device=self.device, dtype=bool))
228 | )
229 | self.prune_points(prune_filter)
230 |
231 | def densify_and_clone(self, grads, grad_threshold=0.02, scene_extent=2):
232 | # Extract points that satisfy the gradient condition
233 | selected_pts_mask = torch.where(torch.norm(grads, dim=-1) >= grad_threshold, True, False)
234 | selected_pts_mask = torch.logical_and(
235 | selected_pts_mask, torch.max(torch.exp(self._scales), dim=1).values <= self.percent_dense * scene_extent
236 | )
237 |
238 | new_xyz = self._xyz[selected_pts_mask]
239 | new_features = self._features[selected_pts_mask]
240 | new_opacity = self._opacity[selected_pts_mask]
241 | new_scales = self._scales[selected_pts_mask]
242 | new_rotations = self._rotations[selected_pts_mask]
243 |
244 | params = {
245 | "xyz": new_xyz,
246 | "features": new_features,
247 | "opacity": new_opacity,
248 | "scales": new_scales,
249 | "rotations": new_rotations,
250 | }
251 |
252 | self.add_points(params)
253 |
254 | def add_densification_stats(self, viewspace_point_tensor, update_filter):
255 | self.xyz_gradient_accum[update_filter] += torch.norm(
256 | viewspace_point_tensor.grad[update_filter, :2], dim=-1, keepdim=True
257 | )
258 | self.denom[update_filter] += 1
259 |
260 |
261 | class GaussianRenderer(nn.Module):
262 | def __init__(self, h, w):
263 | super().__init__()
264 |
265 | self.h = h
266 | self.w = w
267 | self.near = 0.001
268 | self.far = 100
269 |
270 | def get_viewpoint_cameras(self, cameras):
271 | device = cameras.device
272 | K, Rt = cameras.split([3, 4], dim=-1)
273 |
274 | normalized_K = K.clone()
275 | normalized_K[0] *= 1 / (normalized_K[0, 2] * 2) # as cx cy = h/2, w/2
276 | normalized_K[1] *= 1 / (normalized_K[1, 2] * 2) # as cx cy = h/2, w/2
277 |
278 | fov_x, fov_y = get_fov(normalized_K[None]).unbind(dim=-1)
279 | tan_fov_x = (0.5 * fov_x).tan()
280 | tan_fov_y = (0.5 * fov_y).tan()
281 |
282 | projection_matrix = get_projection_matrix(self.near, self.far, fov_x, fov_y)
283 | projection_matrix = rearrange(projection_matrix, "b i j -> b j i")
284 |
285 | bottom = torch.tensor([0, 0, 0, 1], dtype=Rt.dtype, device=device)
286 | homo_poses = torch.cat([Rt, bottom[None]], dim=0)
287 |
288 | view_matrix = rearrange(homo_poses.inverse(), "i j -> () j i")
289 | full_projection = view_matrix @ projection_matrix
290 |
291 | return tan_fov_x, tan_fov_y, view_matrix, full_projection, Rt[:3, 3]
292 |
293 | @torch.amp.autocast("cuda", enabled=False)
294 | def forward(self, cameras, gaussians, h, w, bg="random"):
295 | B, N = cameras.shape[:2]
296 | xyz, features, opacity, scales, rotations = gaussians
297 |
298 | self.radii = []
299 | self.viewspace_points = []
300 |
301 | images = []
302 | depths = []
303 | masks = []
304 |
305 | bg_color = torch.tensor([1, 1, 1], device=cameras.device).float()
306 |
307 | if bg == "random":
308 | bg_color = torch.rand_like(bg_color)
309 |
310 | for i in range(B):
311 | for j in range(N):
312 | tan_fov_x, tan_fov_y, view_matrix, full_projection, campos = self.get_viewpoint_cameras(cameras[i, j])
313 |
314 | mean_gradients = torch.zeros_like(xyz[i], requires_grad=True)
315 |
316 | try:
317 | mean_gradients.retain_grad()
318 | except Exception:
319 | pass
320 |
321 | settings = GaussianRasterizationSettings(
322 | image_height=h,
323 | image_width=w,
324 | tanfovx=tan_fov_x.item(),
325 | tanfovy=tan_fov_y.item(),
326 | bg=bg_color,
327 | scale_modifier=1.0,
328 | viewmatrix=view_matrix,
329 | projmatrix=full_projection,
330 | sh_degree=0,
331 | campos=campos,
332 | prefiltered=False,
333 | debug=False,
334 | )
335 |
336 | rasterizer = GaussianRasterizer(settings)
337 |
338 | image, radii, depth, mask = rasterizer(
339 | means3D=xyz[i],
340 | means2D=mean_gradients,
341 | shs=None,
342 | colors_precomp=features[i].squeeze(1),
343 | opacities=opacity[i],
344 | scales=scales[i],
345 | rotations=rotations[i][:, [3, 0, 1, 2]],
346 | cov3D_precomp=None,
347 | )
348 |
349 | rendered_image = image.clamp(0, 1)
350 | rendered_mask = mask.clamp(0, 1)
351 | rendered_depth = depth + self.far * (1 - rendered_mask)
352 |
353 | images.append(rendered_image)
354 | depths.append(rendered_depth)
355 | masks.append(rendered_mask)
356 |
357 | self.radii.append(radii)
358 | self.viewspace_points.append(mean_gradients)
359 |
360 | images = torch.stack(images, dim=0).unflatten(0, (B, N)) * 2 - 1
361 | depths = torch.stack(depths, dim=0).unflatten(0, (B, N))
362 | masks = torch.stack(masks, dim=0).unflatten(0, (B, N))
363 |
364 | return images, depths, masks
365 |
--------------------------------------------------------------------------------
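Note: `rgb2shs`/`shstorgb` map between RGB colours in [0, 1] and zeroth-order SH (DC) coefficients via the constant `SH_C0`, and the renderer reorders the stored xyzw quaternions to wxyz (`rotations[i][:, [3, 0, 1, 2]]`) before handing them to the rasterizer. A one-line check of the colour round trip (minimal sketch; the constant is copied from the file above):

```python
import torch

SH_C0 = 0.28209479177387814                   # zeroth-order SH basis constant, as in gs_util.py
rgb = torch.rand(4, 1, 3)                     # colours in [0, 1]
shs = (rgb - 0.5) / SH_C0                     # rgb2shs
rgb_back = (shs * SH_C0 + 0.5).clamp(min=0)   # shstorgb
assert torch.allclose(rgb, rgb_back, atol=1e-6)
```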
/model/refiner/sds_pp_refiner.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | import tqdm
5 | from diffusers import DDIMScheduler, StableDiffusionPipeline
6 |
7 | from model.refiner.camera_util import sample_from_dense_cameras
8 | from model.refiner.gs_util import GaussianRenderer, GaussiansManeger
9 |
10 |
11 | class GSRefinerSDSPlusPlus(nn.Module):
12 | def __init__(
13 | self,
14 | sd_model_key="stabilityai/stable-diffusion-2-1-base",
15 | num_views=1,
16 | total_iterations=500,
17 | guidance_scale=100,
18 | min_step_percent=0.02,
19 | max_step_percent=0.75,
20 | lr_scale=0.25,
21 | lrs={"xyz": 0.0001, "features": 0.01, "opacity": 0.05, "scales": 0.01, "rotations": 0.01},
22 | use_lods=True,
23 | lambda_latent_sds=1,
24 | lambda_image_sds=0.1,
25 | lambda_image_variation=0.001,
26 | opacity_threshold=0.001,
27 | img_size=512,
28 | num_densifications=4,
29 | text_templete="$text$",
30 | negative_text_templete="",
31 | ):
32 | super().__init__()
33 |
34 | pipe = StableDiffusionPipeline.from_pretrained(sd_model_key)
35 |
36 | self.tokenizer = pipe.tokenizer
37 | self.text_encoder = pipe.text_encoder.requires_grad_(False)
38 | self.vae = pipe.vae.requires_grad_(False)
39 | self.unet = pipe.unet.requires_grad_(False)
40 |
41 | self.scheduler = DDIMScheduler.from_pretrained(sd_model_key, subfolder="scheduler", local_files_only=True)
42 |
43 | del pipe
44 |
45 | self.num_views = num_views
46 | self.total_iterations = total_iterations
47 | self.guidance_scale = guidance_scale
48 | self.lrs = {key: value * lr_scale for key, value in lrs.items()}
49 |
50 | self.register_buffer("alphas_cumprod", self.scheduler.alphas_cumprod, persistent=False)
51 |
52 | self.device = "cpu"
53 |
54 | self.num_train_timesteps = self.scheduler.config.num_train_timesteps
55 |
56 | self.set_min_max_steps(min_step_percent, max_step_percent)
57 |
58 | self.renderer = GaussianRenderer(img_size, img_size)
59 |
60 | self.text_templete = text_templete
61 | self.negative_text_templete = negative_text_templete
62 |
63 | self.use_lods = use_lods
64 |
65 | self.lambda_latent_sds = lambda_latent_sds
66 | self.lambda_image_sds = lambda_image_sds
67 | self.lambda_image_variation = lambda_image_variation
68 |
69 | self.img_size = img_size
70 |
71 | self.opacity_threshold = opacity_threshold
72 |
73 | self.densification_interval = self.total_iterations // (num_densifications + 1)
74 |
75 | def set_min_max_steps(self, min_step_percent=0.02, max_step_percent=0.98):
76 | self.min_step = int(self.num_train_timesteps * min_step_percent)
77 | self.max_step = int(self.num_train_timesteps * max_step_percent)
78 |
79 | def to(self, device):
80 | self.device = device
81 | return super().to(device)
82 |
83 | @torch.no_grad()
84 | def encode_text(self, texts):
85 | inputs = self.tokenizer(
86 | texts,
87 | padding="max_length",
88 |             truncation=True,
89 | max_length=self.tokenizer.model_max_length,
90 | return_tensors="pt",
91 | )
92 | text_embeddings = self.text_encoder(inputs.input_ids.to(next(self.text_encoder.parameters()).device))[0]
93 | return text_embeddings
94 |
95 | # @torch.amp.autocast("cuda", enabled=False)
96 | def encode_image(self, images):
97 | posterior = self.vae.encode(images).latent_dist
98 | latents = posterior.sample() * self.vae.config.scaling_factor
99 | return latents
100 |
101 | # @torch.amp.autocast("cuda", enabled=False)
102 | def decode_latent(self, latents):
103 | latents = 1 / self.vae.config.scaling_factor * latents
104 | images = self.vae.decode(latents).sample
105 | return images
106 |
107 | def train_step(
108 | self,
109 | images,
110 | t,
111 | text_embeddings,
112 | uncond_text_embeddings,
113 | learnable_text_embeddings,
114 | ):
115 | latents = self.encode_image(images)
116 |
117 | with torch.no_grad():
118 | B = latents.shape[0]
119 | t = t.repeat(self.num_views)
120 |
121 | noise = torch.randn_like(latents)
122 | latents_noisy = self.scheduler.add_noise(latents, noise, t)
123 |
124 | if self.use_lods:
125 | with torch.enable_grad():
126 | noise_pred_learnable = self.unet(
127 | latents_noisy, t, encoder_hidden_states=learnable_text_embeddings
128 | ).sample
129 |
130 | loss_embedding = F.mse_loss(noise_pred_learnable, noise, reduction="mean")
131 | else:
132 | noise_pred_learnable = noise
133 | loss_embedding = 0
134 |
135 | with torch.no_grad():
136 | noise_pred = self.unet(
137 | torch.cat([latents_noisy, latents_noisy], 0),
138 | torch.cat([t, t], 0),
139 | encoder_hidden_states=torch.cat([text_embeddings, uncond_text_embeddings], 0),
140 | ).sample
141 |
142 | noise_pred_cond, noise_pred_uncond = noise_pred.chunk(2, dim=0)
143 | noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_cond - noise_pred_uncond)
144 |
145 | w = (1 - self.alphas_cumprod[t]).view(-1, 1, 1, 1)
146 |
147 | alpha = self.alphas_cumprod[t].view(-1, 1, 1, 1) ** 0.5
148 | sigma = (1 - self.alphas_cumprod[t].view(-1, 1, 1, 1)) ** 0.5
149 |
150 | latents_pred = (latents_noisy - sigma * (noise_pred - noise_pred_learnable + noise)) / alpha
151 | images_pred = self.decode_latent(latents_pred).clamp(-1, 1)
152 |
153 | loss_latent_sds = (
154 | F.mse_loss(latents, latents_pred, reduction="none").sum([1, 2, 3]) * w * alpha / sigma
155 | ).sum() / B
156 | loss_image_sds = (
157 | F.mse_loss(images, images_pred, reduction="none").sum([1, 2, 3]) * w * alpha / sigma
158 | ).sum() / B
159 |
160 | return loss_latent_sds, loss_image_sds, loss_embedding
161 |
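The weighting `w * alpha / sigma` above is what turns this weighted MSE into the familiar SDS update: since `latents_pred` is computed without gradients, `latents - latents_pred` collapses to `(sigma / alpha) * (noise_pred - noise_pred_learnable)`, so the gradient with respect to the latents is `2 * w * (noise_pred - noise_pred_learnable)`, i.e. SDS with the learnable-embedding prediction taking the place of the injected noise. A small self-contained check of this identity (illustration only; random tensors stand in for the UNet outputs):

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
alpha_bar = torch.tensor(0.5)
alpha, sigma = alpha_bar.sqrt(), (1 - alpha_bar).sqrt()
w = 1 - alpha_bar

x0 = torch.randn(1, 4, 8, 8, requires_grad=True)   # stands in for `latents`
noise = torch.randn_like(x0)
noise_pred = torch.randn_like(x0)                  # guided UNet prediction (random stand-in)
noise_pred_learnable = torch.randn_like(x0)        # learnable-embedding prediction (random stand-in)

# Same construction as in train_step, treating everything but x0 as constant.
latents_noisy = alpha * x0.detach() + sigma * noise
latents_pred = (latents_noisy - sigma * (noise_pred - noise_pred_learnable + noise)) / alpha

loss = F.mse_loss(x0, latents_pred, reduction="none").sum() * w * alpha / sigma
loss.backward()

expected_grad = 2 * w * (noise_pred - noise_pred_learnable)
assert torch.allclose(x0.grad, expected_grad, atol=1e-5)
```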
162 | @torch.amp.autocast("cuda", enabled=True)
163 | @torch.enable_grad()
164 | def refine_gaussians(self, gaussians, text, dense_cameras):
165 | gaussians_original = gaussians
166 | xyz, features, opacity, scales, rotations = gaussians
167 |
168 | mask = opacity[..., 0] >= self.opacity_threshold
169 | xyz_original = xyz[mask][None]
170 | features_original = features[mask][None]
171 | opacity_original = opacity[mask][None]
172 | scales_original = scales[mask][None]
173 | rotations_original = rotations[mask][None]
174 |
175 | text = self.text_templete.replace("$text$", text)
176 |
177 | text_embeddings = self.encode_text([text])
178 | uncond_text_embeddings = self.encode_text([self.negative_text_templete.replace("$text$", text)])
179 |
180 | class LearnableTextEmbeddings(nn.Module):
181 | def __init__(self, uncond_text_embeddings):
182 | super().__init__()
183 | self.embeddings = nn.Parameter(torch.zeros_like(uncond_text_embeddings.float().detach().clone()))
184 | self.to(self.embeddings.device)
185 |
186 | def forward(self, cameras):
187 | B = cameras.shape[1]
188 | return self.embeddings.repeat(B, 1, 1)
189 |
190 | _learnable_text_embeddings = LearnableTextEmbeddings(uncond_text_embeddings)
191 |
192 | text_embeddings = text_embeddings.repeat(self.num_views, 1, 1)
193 | uncond_text_embeddings = uncond_text_embeddings.repeat(self.num_views, 1, 1)
194 |
195 | new_gaussians = GaussiansManeger(
196 | xyz_original, features_original, opacity_original, scales_original, rotations_original, self.lrs
197 | )
198 |
199 | optimizer_embeddings = torch.optim.Adam(_learnable_text_embeddings.parameters(), lr=self.lrs["embeddings"])
200 | for i in tqdm.trange(self.total_iterations, desc="Refining..."):
201 | if i % self.densification_interval == 0 and i != 0:
202 | new_gaussians.densify_and_prune(opacity_threshold=self.opacity_threshold)
203 |
204 | with torch.amp.autocast("cuda", enabled=False):
205 | cameras = sample_from_dense_cameras(dense_cameras, torch.rand(1, self.num_views).to(self.device))
206 |
207 | learnable_text_embeddings = _learnable_text_embeddings(cameras)
208 |
209 | with torch.no_grad():
210 | images_original = self.renderer(cameras, gaussians_original, h=self.img_size, w=self.img_size)[0]
211 |
212 |             gaussians = new_gaussians()
213 |             images_pred = self.renderer(cameras, gaussians, h=self.img_size, w=self.img_size)[0]
214 |
215 | anneal_t = int((i / self.total_iterations) ** (1 / 2) * (self.min_step - self.max_step) + self.max_step)
216 | t = torch.full((1,), anneal_t, dtype=torch.long, device=self.device)
217 |
218 | loss_latent_sds, loss_img_sds, loss_embedding = self.train_step(
219 | images_pred.squeeze(0), t, text_embeddings, uncond_text_embeddings, learnable_text_embeddings
220 | )
221 |
222 |             image_variation_loss = F.mse_loss(images_original, images_pred, reduction="sum") / self.num_views
223 |
224 | loss = (
225 | loss_latent_sds * self.lambda_latent_sds +
226 | loss_img_sds * self.lambda_image_sds +
227 | image_variation_loss * self.lambda_image_variation +
228 | loss_embedding
229 | )
230 |
231 | optimizer_embeddings.zero_grad()
232 | new_gaussians.optimizer.zero_grad()
233 |
234 | # self.lambda_scale_regularization
235 | scales = torch.exp(new_gaussians._scales)
236 | big_points_ws = scales.max(dim=1).values > 0.05
237 | loss += 10 * scales[big_points_ws].sum()
238 |
239 | loss.backward()
240 |
241 | new_gaussians.optimizer.step()
242 | optimizer_embeddings.step()
243 |
244 | for radii, viewspace_points in zip(self.renderer.radii, self.renderer.viewspace_points):
245 | visibility_filter = radii > 0
246 | new_gaussians.is_visible[visibility_filter] = 1
247 | new_gaussians.max_radii2D[visibility_filter] = torch.max(
248 | new_gaussians.max_radii2D[visibility_filter], radii[visibility_filter]
249 | )
250 | new_gaussians.add_densification_stats(viewspace_points, visibility_filter)
251 |
252 | gaussians = new_gaussians()
253 | is_visible = new_gaussians.is_visible.bool()
254 | gaussians = [p[:, is_visible].detach() for p in gaussians]
255 |
256 | del new_gaussians
257 | return gaussians
258 |
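One detail worth calling out in `refine_gaussians` is the square-root timestep annealing: early iterations sample large diffusion timesteps (coarse structure), late iterations small ones (fine detail). A standalone restatement of the schedule, assuming the scheduler's usual 1000 training timesteps so that the default percentages give `min_step = 20` and `max_step = 750`:

```python
def annealed_timestep(i, total_iterations, min_step, max_step):
    """Timestep schedule used in the refinement loop: t decays from max_step to min_step as sqrt(i / T)."""
    return int((i / total_iterations) ** 0.5 * (min_step - max_step) + max_step)


# With the defaults above (total_iterations=500, min_step=20, max_step=750):
print([annealed_timestep(i, 500, 20, 750) for i in (0, 125, 250, 499)])
# -> [750, 385, 233, 20]
```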
--------------------------------------------------------------------------------
/model/util.py:
--------------------------------------------------------------------------------
1 | from collections import OrderedDict
2 |
3 | import torch
4 | from diffusers import AutoencoderKL, SD3Transformer2DModel
5 | from transformers import AutoImageProcessor, AutoModelForDepthEstimation
6 |
7 |
8 | def create_vae(cfg):
9 | vae = AutoencoderKL.from_pretrained(
10 | cfg.mv_rf_model.hf_path,
11 | subfolder="vae",
12 | )
13 | return vae
14 |
15 |
16 | def create_sd3_transformer():
17 | sd3 = SD3Transformer2DModel.from_pretrained(
18 | "stabilityai/stable-diffusion-3-medium-diffusers",
19 | subfolder="transformer",
20 | low_cpu_mem_usage=True,
21 | )
22 | return sd3
23 |
24 |
25 | def create_depth(cfg):
26 | """
27 | How to use depth models
28 |
29 | # prepare image for the model
30 | inputs = depth_image_processor(images=image, return_tensors="pt")
31 |
32 | with torch.no_grad():
33 | outputs = model(**inputs)
34 | predicted_depth = outputs.predicted_depth
35 | """
36 | depth_image_processor = AutoImageProcessor.from_pretrained(cfg.depth_encoder.hf_path)
37 | depth_model = AutoModelForDepthEstimation.from_pretrained(cfg.depth_encoder.hf_path)
38 | return depth_image_processor, depth_model
39 |
40 |
41 | def convert_to_buffer(module: torch.nn.Module, persistent: bool = True):
42 | # Recurse over child modules.
43 | for name, child in list(module.named_children()):
44 | convert_to_buffer(child, persistent)
45 |
46 | # Also re-save buffers to change persistence.
47 | for name, parameter_or_buffer in (
48 | *module.named_parameters(recurse=False),
49 | *module.named_buffers(recurse=False),
50 | ):
51 | value = parameter_or_buffer.detach().clone()
52 | delattr(module, name)
53 | module.register_buffer(name, value, persistent=persistent)
54 |
55 |
56 | def update_ema(ema_model, model, decay=0.9999):
57 | """
58 |     Step the EMA model towards the current model (assumes an FSDP-wrapped `model`).
59 | """
60 |
61 | model_context_manager = model.summon_full_params(model)
62 | with model_context_manager:
63 | ema_params = OrderedDict(ema_model.named_parameters())
64 | model_params = OrderedDict(model.named_parameters())
65 | for name, param in model_params.items():
66 | ema_params[name].mul_(decay).add_(param.data, alpha=1 - decay)
67 |
68 |
69 | def stack_depth_images(depth_in):
70 | """
71 | Crawled from https://github.com/prs-eth/Marigold/blob/main/src/trainer/marigold_trainer.py#L395
72 | """
73 | if 4 == len(depth_in.shape):
74 | stacked = depth_in.repeat(1, 3, 1, 1)
75 | elif 3 == len(depth_in.shape):
76 | stacked = depth_in.unsqueeze(1)
77 |         stacked = stacked.repeat(1, 3, 1, 1)
78 | return stacked
79 |
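A minimal sketch of the depth-model usage that the `create_depth` docstring describes, combined with `stack_depth_images`. The checkpoint name and image path below are placeholder assumptions, not necessarily what the repository's Hydra config points to via `cfg.depth_encoder.hf_path`:

```python
import torch
from omegaconf import OmegaConf
from PIL import Image

from model.util import create_depth, stack_depth_images

# Placeholder config; the real value comes from the experiment's Hydra config.
cfg = OmegaConf.create({"depth_encoder": {"hf_path": "LiheYoung/depth-anything-small-hf"}})
depth_image_processor, depth_model = create_depth(cfg)

image = Image.open("example.png").convert("RGB")
inputs = depth_image_processor(images=image, return_tensors="pt")
with torch.no_grad():
    predicted_depth = depth_model(**inputs).predicted_depth   # (1, H, W)

# Replicate the single-channel depth to 3 channels so it can be fed to an RGB encoder.
depth_3ch = stack_depth_images(predicted_depth)               # (1, 3, H, W)
```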
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | hydra-core==1.3.2
2 | transformers==4.40.2
3 | wandb==0.17.0
4 | timm==0.5.4
5 | diffusers==0.30.0
6 | sentencepiece==0.2.0
7 | matplotlib==3.8.4
8 | imageio-ffmpeg==0.5.1
9 | camtools
10 | accelerate
11 | einops
12 | lpips
13 | plyfile
14 | fvcore
15 | carvekit-colab
16 | scikit-image
17 | opencv-python
18 | git+https://github.com/byeongjun-park/diff-gaussian-rasterization
19 | git+https://github.com/facebookresearch/pytorch3d@V0.7.8
--------------------------------------------------------------------------------
/util/camera_visualization.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the BSD-style license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | import os
8 |
9 | import matplotlib.pyplot as plt
10 | import torch
11 | from mpl_toolkits.mplot3d import Axes3D # noqa: F401 unused import
12 |
13 |
14 | def get_camera_wireframe(scale: float = 0.3):
15 | """
16 | Returns a wireframe of a 3D line-plot of a camera symbol.
17 | """
18 | a = 0.5 * torch.tensor([-1, 1, 2])
19 | b = 0.5 * torch.tensor([1, 1, 2])
20 | c = 0.5 * torch.tensor([-1, -1, 2])
21 | d = 0.5 * torch.tensor([1, -1, 2])
22 | C = torch.zeros(3)
23 | camera_points = [a, b, d, c, a, C, b, d, C, c, C]
24 | lines = torch.stack([x.float() for x in camera_points]) * scale
25 | return lines
26 |
27 |
28 | def plot_cameras(ax, cameras, color=None):
29 | """
30 |     Plots a set of `cameras` objects into the matplotlib axis `ax` with
31 | color `color`.
32 | """
33 | cam_wires_canonical = get_camera_wireframe(scale=0.05).cuda()[None]
34 | cam_trans = cameras.get_world_to_view_transform().inverse()
35 | cam_wires_trans = cam_trans.transform_points(cam_wires_canonical)
36 | plot_handles = []
37 | cmap = plt.get_cmap("hsv")
38 |
39 | for i, wire in enumerate(cam_wires_trans):
40 | # the Z and Y axes are flipped intentionally here!
41 | color_ = cmap(i / len(cameras))[:-1] if color is None else color
42 | x_, z_, y_ = wire.detach().cpu().numpy().T.astype(float)
43 | (h,) = ax.plot(x_, y_, z_, color=color_, linewidth=0.5)
44 | plot_handles.append(h)
45 | return cam_wires_trans
46 |
47 |
48 | def plot_camera_scene(cameras, cameras_gt, path):
49 | """
50 | Plots a set of predicted cameras `cameras` and their corresponding
51 |     ground truth locations `cameras_gt`. The figure is saved as
52 |     "pose.pdf" inside the directory given by `path`.
53 | """
54 | fig = plt.figure()
55 | ax = fig.add_subplot(projection="3d")
56 | ax.clear()
57 |
58 | points = plot_cameras(ax, cameras)
59 | points_gt = plot_cameras(ax, cameras_gt, color=(0.1, 0.1, 0.1))
60 | tot_pts = torch.cat([points[:, -1], points_gt[:, -1]], dim=0)
61 |
62 | max_scene = tot_pts.max(dim=0)[0].cpu()
63 | min_scene = tot_pts.min(dim=0)[0].cpu()
64 | ax.set_xticklabels([])
65 | ax.set_yticklabels([])
66 | ax.set_zticklabels([])
67 | ax.set_xlim3d([min_scene[0] - 0.1, max_scene[0] + 0.3])
68 | ax.set_ylim3d([min_scene[2] - 0.1, max_scene[2] + 0.1])
69 | ax.set_zlim3d([min_scene[1] - 0.1, max_scene[1] + 0.1])
70 |
71 | ax.invert_yaxis()
72 |
73 | plt.savefig(os.path.join(path, "pose.pdf"), bbox_inches="tight", pad_inches=0, transparent=True)
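An illustrative driver for `plot_camera_scene` (a sketch only; it assumes a CUDA device, since `plot_cameras` moves the wireframes to `.cuda()`, and it uses random PyTorch3D `PerspectiveCameras` purely to exercise the plotting code):

```python
import os

import torch
from pytorch3d.renderer import PerspectiveCameras

from util.camera_visualization import plot_camera_scene

num_cams = 4
R = torch.eye(3)[None].repeat(num_cams, 1, 1)
T = torch.randn(num_cams, 3)
cameras_pred = PerspectiveCameras(R=R, T=T, device="cuda")
cameras_gt = PerspectiveCameras(R=R, T=T + 0.05 * torch.randn_like(T), device="cuda")

os.makedirs("output", exist_ok=True)
plot_camera_scene(cameras_pred, cameras_gt, path="output")  # writes output/pose.pdf
```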
--------------------------------------------------------------------------------
/util/dist_util.py:
--------------------------------------------------------------------------------
1 | import os
2 | import socket
3 |
4 | import torch
5 | import torch.distributed as dist
6 |
7 | DEVICE = None
8 |
9 |
10 | def setup_dist(args):
11 | if dist.is_initialized():
12 | return
13 |
14 | if os.environ.get("MASTER_ADDR", None) is None:
15 | hostname = socket.gethostbyname(socket.getfqdn())
16 | os.environ["MASTER_ADDR"] = hostname
17 | os.environ["RANK"] = "0"
18 | os.environ["WORLD_SIZE"] = "1"
19 | port = _find_free_port()
20 | os.environ["MASTER_PORT"] = str(port)
21 |
22 | dist.init_process_group("nccl")
23 | assert args.global_batch_size % dist.get_world_size() == 0, "Batch size must be divisible by world size."
24 | rank = dist.get_rank()
25 | device = rank % torch.cuda.device_count() + args.gpu_offset
26 | seed = args.global_seed * dist.get_world_size() + rank
27 | torch.manual_seed(seed)
28 | torch.cuda.set_device(device)
29 | global DEVICE
30 | DEVICE = device
31 | print(f"Starting rank={rank}, seed={seed}, world_size={dist.get_world_size()}, device={DEVICE}")
32 |
33 |
34 | def _find_free_port():
35 |     s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
36 |     try:
37 |         s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
38 |         s.bind(("", 0))
39 | return s.getsockname()[1]
40 | finally:
41 | s.close()
42 |
43 |
44 | def cleanup():
45 | """
46 | End DDP training.
47 | """
48 | dist.destroy_process_group()
49 |
50 |
51 | def device():
52 | if not dist.is_initialized():
53 |         raise RuntimeError("dist_util.device() called before setup_dist() initialized the process group.")
54 | return DEVICE
55 |
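A sketch of how these helpers are typically wired into a training entry point (illustration only; the script name and argument defaults are assumptions, but the three argument names are the ones `setup_dist` reads from `args`):

```python
import argparse

from util import dist_util


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("--global_batch_size", type=int, default=8)
    parser.add_argument("--global_seed", type=int, default=0)
    parser.add_argument("--gpu_offset", type=int, default=0)
    args = parser.parse_args()

    dist_util.setup_dist(args)   # initializes NCCL, seeds each rank, and selects its GPU
    device = dist_util.device()
    # ... build the model and dataloader on `device`, train, checkpoint ...
    dist_util.cleanup()


if __name__ == "__main__":
    main()
```

Launched with, e.g., `torchrun --nproc_per_node=4 <script>.py`; when run as a single process without `torchrun`, `setup_dist` falls back to a one-process group on a free local port.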
--------------------------------------------------------------------------------