├── .gitignore ├── README.md ├── asset └── teasure_figure.png ├── config ├── base_config.yaml └── experiments │ └── generation.yaml ├── download.sh ├── inference └── generate.py ├── model ├── gsdecoder │ ├── camera_embedding.py │ ├── cuda_splatting.py │ ├── decoder_splatting.py │ ├── gs_decoder_architecture.py │ └── load_gsdecoder.py ├── multiview_rf │ ├── load_mv_sd3.py │ ├── mv_sd3_architecture.py │ └── text_embedding.py ├── refiner │ ├── camera_util.py │ ├── gs_util.py │ └── sds_pp_refiner.py └── util.py ├── requirements.txt └── util ├── camera_visualization.py └── dist_util.py /.gitignore: -------------------------------------------------------------------------------- 1 | .vscode 2 | baseline_results 3 | checkpoints 4 | final 5 | final_latest_4 6 | output 7 | point_clouds 8 | results 9 | sample_ablation 10 | sample_scene 11 | samples 12 | teasure 13 | tmp 14 | wandb 15 | __pycache__ -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # [CVPR 2025] SplatFlow: Multi-View Rectified Flow Model for 3D Gaussian Splatting Synthesis 2 | 3 | 4 | 5 | 6 | 7 | 8 | This repository contains the official pytorch implementation of the paper: "SplatFlow: Multi-View Rectified Flow Model for 3D Gaussian Splatting Synthesis". 9 | ![teaser](asset/teasure_figure.png) 10 | 11 | 12 | ## Installation 13 | Our code is tested with Python 3.10, Pytorch 2.4.0, and CUDA 11.8. To install the required packages, run the following command: 14 | ```bash 15 | pip3 install torch==2.4.0 torchvision==0.19.0 torchaudio==2.4.0 --index-url https://download.pytorch.org/whl/cu118 16 | pip install -r requirements.txt 17 | ``` 18 | 19 | ## Model Checkpoints 20 | We have extended the model training period after submission to enhance its performance. Updated model checkpoints are now available, and benchmark results will be revised following the review process. You can download the model checkpoints with the following command: 21 | ``` 22 | bash download.sh 23 | ``` 24 | 25 | 26 | 27 | ## Inference 28 | ### 1. 3DGS Generation 29 | 30 | ```bash 31 | export PYTHONPATH=$(pwd) 32 | huggingface-cli login 33 | python inference/generate.py +experiments=generation inference.generate.prompt="Your prompt here" 34 | ``` 35 | 36 | ## Citation 37 | If you find this repository helpful for your project, please consider citing our work. :) 38 | ``` 39 | @article{go2024splatflow, 40 | title={SplatFlow: Multi-View Rectified Flow Model for 3D Gaussian Splatting Synthesis}, 41 | author={Go, Hyojun and Park, Byeongjun and Jang, Jiho and Kim, Jin-Young and Kwon, Soonwoo and Kim, Changick}, 42 | journal={arXiv preprint arXiv:2411.16443}, 43 | year={2024} 44 | } 45 | ``` 46 | 47 | ## Acknolwedgement 48 | We thank [director3d](https://github.com/imlixinyang/Director3D) 49 | 50 | 51 | ## TODO: 52 | - [ ] Update project page. 53 | - [x] Add model checkpoints 54 | - [x] Code verification. 55 | - [ ] Add more details on the README.md 56 | - [ ] Add the training script 57 | - [ ] Share text annotations for RealEstate10K, DL3DV-10K, MVImgNet, and ACID datasets. 
(Used in the next work, may be integrated with https://github.com/gohyojun15/VideoRFSplat) 58 | -------------------------------------------------------------------------------- /asset/teasure_figure.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gohyojun15/SplatFlow/6964764ca8ac4ea2ca7c6e4e2f662457cec26040/asset/teasure_figure.png -------------------------------------------------------------------------------- /config/base_config.yaml: -------------------------------------------------------------------------------- 1 | 2 | general: 3 | iterations: 0 4 | num_workers: 0 5 | global_batch_size: 0 6 | global_seed: 42 7 | gpu_offset: 0 8 | mixed_precision: true 9 | sampled_view: 8 10 | debug: true 11 | 12 | optim: 13 | lr: 0 14 | weight_decay: 0 15 | warmup_steps: 0 16 | 17 | gsdecoder: 18 | gan_loss: 19 | enable: true 20 | disc_start: 0 21 | disc_weight: 0.1 22 | loading_ckpt_path: Null 23 | 24 | mv_rf_model: 25 | resume_from_ckpt: Null 26 | weighting_scheme: "logit_normal" 27 | logit_mean: 0 28 | logit_std: 1.0 29 | mode_scale: 1.29 30 | precondition_outputs: true 31 | hf_path: "stabilityai/stable-diffusion-3-medium-diffusers" 32 | 33 | depth_encoder: 34 | hf_path: "depth-anything/Depth-Anything-V2-Small-hf" 35 | 36 | defaults: 37 | - _self_ 38 | - override hydra/job_logging: disabled 39 | - override hydra/hydra_logging: disabled 40 | 41 | 42 | # Inference configuration. 43 | inference: 44 | mv_rf_ckpt: checkpoints/mv_rf_ema.pt 45 | gsdecoder_ckpt: checkpoints/gs_decoder.pt 46 | sample: 47 | num_steps: 200 48 | cfg: true 49 | cfg_scale: [7.0, 5.0, 1.0] 50 | stop_ray: 50 51 | sd3_guidance: true 52 | sd3_cfg: true 53 | sd3_scale: 3.0 54 | 55 | hydra: 56 | run: 57 | dir: . 58 | output_subdir: null -------------------------------------------------------------------------------- /config/experiments/generation.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | inference: 4 | # generation 5 | generate: 6 | save_path: "samples" 7 | prompt: "A blue car on a clean street with buildings in the background" 8 | refiner: 9 | args: 10 | sd_model_key: 'stabilityai/stable-diffusion-2-1-base' 11 | num_views: 1 12 | img_size: 512 13 | guidance_scale: 7.5 14 | min_step_percent: 0.02 15 | max_step_percent: 0.5 16 | num_densifications: 4 17 | lr_scale: 0.25 18 | lrs: {'xyz': 2e-4, 'features': 1e-2, 'opacity': 5e-2, 'scales': 1e-3, 'rotations': 1e-2, 'embeddings': 1e-2} 19 | use_lods: True 20 | lambda_latent_sds: 1 21 | lambda_image_sds: 0.1 22 | lambda_image_variation: 0.001 23 | opacity_threshold: 0.001 24 | text_templete: $text$ 25 | negative_text_templete: 'unclear. noisy. point cloud. low-res. low-quality. low-resolution. unrealistic.' 
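    # Assumption: total_iterations sets the number of SDS++ refinement optimization steps run by GSRefinerSDSPlusPlus.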
26 | total_iterations: 1000 -------------------------------------------------------------------------------- /download.sh: -------------------------------------------------------------------------------- 1 | 2 | # Via google drive 3 | # gdown 1Ch9YK0eA7-alMIKK8NxKaoyJ7rWZiqif 4 | # gdown 1BUXCmR7jDTiGfbV55NHf7dZp6GsV5LCf 5 | 6 | 7 | mkdir -p checkpoints 8 | cd checkpoints 9 | wget https://huggingface.co/HJGO/splatflow/resolve/main/gs_decoder.pt?download=true -O gs_decoder.pt 10 | wait 11 | 12 | wget https://huggingface.co/HJGO/splatflow/resolve/main/mv_rf_ema.pt?download=true -O mv_rf_ema.pt 13 | wait -------------------------------------------------------------------------------- /inference/generate.py: -------------------------------------------------------------------------------- 1 | import os 2 | from pathlib import Path 3 | 4 | import hydra 5 | import matplotlib.pyplot as plt 6 | import torch 7 | import torchvision 8 | from diffusers import FlowMatchEulerDiscreteScheduler 9 | from einops import rearrange, repeat 10 | from pytorch3d.utils import cameras_from_opencv_projection 11 | from scipy.spatial.transform import Rotation as R 12 | from tqdm import tqdm 13 | 14 | from model.gsdecoder.camera_embedding import get_plucker_rays, optimize_plucker_ray 15 | from model.gsdecoder.load_gsdecoder import create_gsdecoder 16 | from model.multiview_rf.load_mv_sd3 import create_sd_multiview_rf_model 17 | from model.multiview_rf.text_embedding import compute_text_embeddings 18 | from model.refiner.camera_util import ( 19 | export_mv, 20 | export_ply_for_gaussians, 21 | export_video, 22 | load_ply_for_gaussians, 23 | ) 24 | from model.refiner.sds_pp_refiner import GSRefinerSDSPlusPlus 25 | from model.util import create_sd3_transformer, create_vae 26 | from util import camera_visualization, dist_util 27 | 28 | 29 | def visualize_camera_pose(cameras, path): 30 | # visualize cameras 31 | extrinsic = cameras[0][:, :, 3:] 32 | device = cameras.device 33 | num_view = len(extrinsic) 34 | homo_extrinsic = repeat(torch.eye(4).to(device, dtype=extrinsic.dtype), 'i j -> v i j', v=num_view).clone() 35 | homo_extrinsic[:, :3] = extrinsic 36 | w2c = homo_extrinsic.inverse() 37 | image_size = repeat(torch.tensor([256, 256], device=device), "i -> v i", v=num_view) 38 | cameras_ours = cameras_from_opencv_projection( 39 | R=w2c[:, :3, :3], 40 | tvec=w2c[:, :3, -1], 41 | camera_matrix=cameras[0, :, :, :3], 42 | image_size=image_size, 43 | ) 44 | fig = plt.figure() 45 | ax = fig.add_subplot(projection="3d") 46 | ax.clear() 47 | points = camera_visualization.plot_cameras(ax, cameras_ours) 48 | cc = points[:, -1] 49 | max_scene = cc.max(dim=0)[0].cpu() 50 | min_scene = cc.min(dim=0)[0].cpu() 51 | ax.set_xticklabels([]) 52 | ax.set_yticklabels([]) 53 | ax.set_zticklabels([]) 54 | ax.set_xlim3d([min_scene[0] - 0.1, max_scene[0] + 0.3]) 55 | ax.set_ylim3d([min_scene[2] - 0.1, max_scene[2] + 0.1]) 56 | ax.set_zlim3d([min_scene[1] - 0.1, max_scene[1] + 0.1]) 57 | ax.invert_yaxis() 58 | plt.savefig(os.path.join(path, "pose.pdf"), bbox_inches="tight", pad_inches=0, transparent=True) 59 | 60 | def save_results(path, output, poses, Ks): 61 | means, covariance, opacity, rgb, rotation, scale = output 62 | _, v, h, w, _ = means.shape 63 | 64 | means = rearrange(means, "() v h w xyz -> (v h w) xyz") 65 | opacity = rearrange(opacity, "() v h w o -> (v h w) o") 66 | rgb = rearrange(rgb, "() v h w rgb -> (v h w) rgb") 67 | rotation = rearrange(rotation, "() v h w q -> (v h w) q") 68 | scale = rearrange(scale, "() v h w s -> (v h w) s") 
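    # The decoder predicts each Gaussian's rotation in its source camera's frame;
    # the block below composes those quaternions with the per-view camera-to-world
    # rotations (poses[..., :3, :3]) so the exported PLY stores world-frame rotations.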
69 | source_rotations = repeat(poses[..., :3, :3], "() v i j -> (v h w) i j", h=h, w=w) 70 | 71 | cam_rotation_matrix = R.from_quat(rotation.detach().cpu().numpy()).as_matrix() 72 | world_rotation_matrix = source_rotations.detach().cpu().numpy() @ cam_rotation_matrix 73 | world_rotations = R.from_matrix(world_rotation_matrix).as_quat() 74 | world_rotations = torch.from_numpy(world_rotations).to(source_rotations.device) 75 | 76 | export_ply_for_gaussians(os.path.join(path, 'gaussian'), (means, rgb[:, None], opacity, scale, world_rotations)) 77 | cameras = torch.cat([Ks[0], poses[0, :, :3]], dim=-1) 78 | torch.save(cameras, path / "cameras.pt") 79 | 80 | 81 | def generate_sampling( 82 | text_prompt, 83 | text_encoders, 84 | tokenizers, 85 | model, 86 | gs_decoder, 87 | vae, 88 | noise_scheduler, 89 | stable_diffusion3_transformer, 90 | cfg, 91 | device, 92 | dtype, 93 | ): 94 | """ 95 | Generating a sample, supposing that batch size is 1 96 | """ 97 | batch_size = 1 98 | 99 | if isinstance(text_prompt, str): 100 | text_prompt = [text_prompt] 101 | negative_prompt = [""] 102 | 103 | noise_scheduler.set_timesteps(cfg.inference.sample.num_steps) 104 | timesteps = noise_scheduler.timesteps 105 | latents = torch.randn(batch_size, 8, 38, 32, 32, device=device, dtype=dtype) 106 | 107 | with torch.no_grad(): 108 | prompt_embeds, pooled_prompt_embeds = compute_text_embeddings( 109 | text_prompt, 110 | text_encoders, 111 | tokenizers, 112 | 77, 113 | device, 114 | ) 115 | if cfg.inference.sample.cfg: 116 | negative_prompt_embeds, pooled_negative_prompt_embeds = compute_text_embeddings( 117 | negative_prompt, text_encoders, tokenizers, 77, device=device 118 | ) 119 | prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0) 120 | pooled_prompt_embeds = torch.cat([pooled_negative_prompt_embeds, pooled_prompt_embeds], dim=0) 121 | cfg_scales = [] 122 | cfg_scales.append([cfg.inference.sample.cfg_scale[0]] * 16) # Image 123 | cfg_scales.append([cfg.inference.sample.cfg_scale[1]] * 16) # Depth 124 | cfg_scales.append([cfg.inference.sample.cfg_scale[2]] * 6) # Pose 125 | cfg_scales = torch.tensor(sum(cfg_scales, []), dtype=latents.dtype, device=latents.device) 126 | cfg_scales = repeat(cfg_scales, "i -> () () i () ()") 127 | 128 | for i, t in tqdm(enumerate(timesteps)): 129 | latent_model_input = torch.cat([latents] * 2) if cfg.inference.sample.cfg else latents 130 | timestep = t.expand(latent_model_input.shape[0]).to(device) 131 | 132 | noise_pred = model( 133 | hidden_states=latent_model_input, 134 | timestep=timestep, 135 | encoder_hidden_states=prompt_embeds, 136 | pooled_projections=pooled_prompt_embeds, 137 | return_dict=False, 138 | )[0] 139 | 140 | if cfg.inference.sample.cfg: 141 | noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) 142 | noise_pred = noise_pred_uncond + cfg_scales * (noise_pred_text - noise_pred_uncond) 143 | 144 | if noise_scheduler.step_index is None: 145 | noise_scheduler._init_step_index(t) 146 | 147 | if noise_scheduler.step_index <= cfg.inference.sample.stop_ray: 148 | with torch.enable_grad(): 149 | original_latent = latents - noise_pred * noise_scheduler.sigmas[noise_scheduler.step_index] 150 | poses, Ks, inv_poses = optimize_plucker_ray(original_latent[:, :, -6:]) 151 | clean_ray_latent = get_plucker_rays(poses[:, :, :3], Ks, is_diffusion=True) 152 | 153 | if cfg.inference.sample.sd3_guidance and (i % 3 == 0) and (i < cfg.inference.sample.stop_ray): 154 | image_2d_flatten_latent = rearrange(latent_model_input, "b v c h w -> (b v) c h w") 155 | 
image_timestep = timestep.repeat(latent_model_input.shape[1]) 156 | 157 | if cfg.inference.sample.sd3_cfg: 158 | positive_prompt_embed = prompt_embeds[1].repeat(latent_model_input.shape[1], 1, 1) 159 | negative_prompt_embed = prompt_embeds[0].repeat(latent_model_input.shape[1], 1, 1) 160 | image_prompt_embed = torch.cat([negative_prompt_embed, positive_prompt_embed], dim=0) 161 | 162 | positive_pooled_prompt_embed = pooled_prompt_embeds[1].repeat(latent_model_input.shape[1], 1) 163 | negative_pooled_prompt_embed = pooled_prompt_embeds[0].repeat(latent_model_input.shape[1], 1) 164 | image_pooled_prompt_embed = torch.cat( 165 | [negative_pooled_prompt_embed, positive_pooled_prompt_embed], dim=0 166 | ) 167 | else: 168 | image_prompt_embed = prompt_embeds.repeat(latent_model_input.shape[1], 1, 1) 169 | image_pooled_prompt_embed = pooled_prompt_embeds.repeat(latent_model_input.shape[1], 1) 170 | 171 | sd_noise_pred = stable_diffusion3_transformer( 172 | hidden_states=image_2d_flatten_latent[:, :16, :, :], 173 | timestep=image_timestep, 174 | encoder_hidden_states=image_prompt_embed, 175 | pooled_projections=image_pooled_prompt_embed, 176 | return_dict=False, 177 | )[0] 178 | 179 | if cfg.inference.sample.sd3_cfg: 180 | sd_noise_pred_uncond, sd_noise_pred_text = sd_noise_pred.chunk(2) 181 | sd_noise_pred = sd_noise_pred_uncond + cfg.inference.sample.sd3_scale * ( 182 | sd_noise_pred_text - sd_noise_pred_uncond 183 | ) 184 | sd_noise_pred = rearrange(sd_noise_pred, "(b v) c h w -> b v c h w", v=latent_model_input.shape[1]) 185 | noise_pred[:, :, :16, :, :] = sd_noise_pred 186 | 187 | latents = noise_scheduler.step(noise_pred, t, latents, return_dict=False)[0] 188 | # update ray latent 189 | sigmas = noise_scheduler.sigmas[noise_scheduler.step_index] 190 | noise = torch.randn_like(latents[:, :, 32:]) 191 | noisy_ray_latent = (1.0 - sigmas) * clean_ray_latent + sigmas * noise 192 | latents[:, :, -6:] = noisy_ray_latent 193 | 194 | with torch.inference_mode(): 195 | img_latent, depth_latent, _ = latents.split([16, 16, 6], dim=2) 196 | output = gs_decoder(img_latent, depth_latent, None, poses[:, :, :3], Ks, near=0.05, far=20) 197 | means, covariance, opacity, rgb, rotation, scale = output 198 | # save results 199 | path = Path(f"{cfg.inference.generate.save_path}/{text_prompt[0]}") 200 | os.makedirs(path, exist_ok=True) 201 | 202 | text_prompt_save = path / "text_prompt.txt" 203 | with open(text_prompt_save, "w", encoding="utf-8") as f: 204 | f.write(text_prompt[0]) 205 | 206 | # Save 3DGS and camera parameters 207 | save_results(path, output, poses, Ks) 208 | 209 | ### Save images 210 | rendered_images = gs_decoder.eval_rendering( 211 | means.float(), 212 | covariance.float(), 213 | opacity.float(), 214 | rgb.float(), 215 | poses[:, :, :3].float(), 216 | Ks.float(), 217 | near=0.05, 218 | far=20, 219 | ) 220 | 221 | rendered_images = rendered_images.clamp(min=0, max=1) 222 | transform = torchvision.transforms.ToPILImage() 223 | 224 | mv_results_path = path / "mv_results" 225 | os.makedirs(mv_results_path, exist_ok=True) 226 | 227 | for view, img in enumerate(rendered_images): 228 | transform(img).save(mv_results_path / f"render_img_{view}.png") 229 | 230 | 231 | @hydra.main(config_path="../config", config_name="base_config.yaml", version_base="1.1") 232 | def main(cfg): 233 | dist_util.setup_dist(cfg.general) 234 | device = dist_util.device() 235 | dtype = torch.float16 236 | print(f"Device: {device}") 237 | 238 | path = Path(f"{cfg.inference.generate.save_path}/{cfg.inference.generate.prompt}") 
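    # Two-stage pipeline: Stage 1 samples multi-view latents with the rectified-flow
    # model and decodes them into an initial 3DGS (saved as gaussian.ply / cameras.pt);
    # Stage 2 refines that 3DGS with the SDS++ refiner. Stage 1 is skipped when its
    # outputs already exist under the prompt's save path.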
239 | 240 | if os.path.exists(path / "cameras.pt") and os.path.exists(path / "gaussian.ply"): 241 | print("Step 1 is already done") # Skip generating First step 242 | else: 243 | """ 244 | First Step: Generate a initial 3DGS 245 | """ 246 | # create vae part 247 | vae = create_vae(cfg) 248 | decoder = create_gsdecoder(cfg) 249 | vae, decoder = vae.to(device=device, dtype=dtype), decoder.to(device=device) 250 | decoder.load_state_dict(torch.load(cfg.inference.gsdecoder_ckpt, map_location="cpu", weights_only=False)) 251 | vae.eval(), decoder.eval() 252 | 253 | # MV RF 254 | model, tokenizer, text_encoders = create_sd_multiview_rf_model() 255 | model = model.to(device=device, dtype=dtype) 256 | model.load_state_dict(torch.load(cfg.inference.mv_rf_ckpt, map_location="cpu", weights_only=False)) 257 | model.eval() 258 | 259 | # Text encoders 260 | text_encoders_list = [] 261 | for i, text_encoder in enumerate(text_encoders): 262 | text_encoder.requires_grad_(False) 263 | text_encoder.to(device, dtype=dtype) 264 | text_encoders_list.append(text_encoder.to(device, dtype=dtype)) 265 | text_encoders = text_encoders_list 266 | 267 | noise_scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained( 268 | cfg.mv_rf_model.hf_path, 269 | subfolder="scheduler", 270 | force_download=False, 271 | shift=1.0, 272 | ) 273 | 274 | if cfg.inference.sample.sd3_guidance: 275 | stable_diffusion3_transformer = create_sd3_transformer() 276 | stable_diffusion3_transformer = stable_diffusion3_transformer.to(device, dtype=dtype) 277 | stable_diffusion3_transformer.eval() 278 | else: 279 | stable_diffusion3_transformer = None 280 | 281 | generate_sampling( 282 | cfg.inference.generate.prompt, 283 | text_encoders, 284 | tokenizer, 285 | model, 286 | decoder, 287 | vae, 288 | noise_scheduler, 289 | stable_diffusion3_transformer, 290 | cfg, 291 | device, 292 | dtype, 293 | ) 294 | 295 | """ 296 | Second Step: Refine the 3DGS 297 | """ 298 | refiner = GSRefinerSDSPlusPlus(**cfg.inference.generate.refiner.args) 299 | refiner.to(device) 300 | gaussians = load_ply_for_gaussians(path / "gaussian.ply", device=device) 301 | cameras = torch.load(path / "cameras.pt", weights_only=False).unsqueeze(0) 302 | refined_gaussians = refiner.refine_gaussians(gaussians, cfg.inference.generate.prompt, dense_cameras=cameras) 303 | 304 | visualize_camera_pose(cameras, path) 305 | export_ply_for_gaussians(os.path.join(path, 'refined_gaussian'), [p[0] for p in refined_gaussians]) 306 | 307 | def render_fn(cameras, h=512, w=512): 308 | return refiner.renderer(cameras, refined_gaussians, h=h, w=w, bg=None)[:2] 309 | 310 | export_mv( 311 | render_fn, 312 | path, 313 | cameras, 314 | ) 315 | 316 | export_video(render_fn, path, "refined_video", cameras, device=device) 317 | 318 | 319 | if __name__ == "__main__": 320 | main() 321 | -------------------------------------------------------------------------------- /model/gsdecoder/camera_embedding.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import torch 3 | from einops import rearrange, repeat 4 | from torch import nn 5 | 6 | # Reference: https://github.com/valeoai/LaRa/blob/main/semanticbev/models/components/LaRa_embeddings.py 7 | # We slightly modify the official code to fit in our setting. 8 | 9 | 10 | def meshgrid(spatial_shape, normalized=True, indexing="ij", device=None): 11 | """Create evenly spaced position coordinates for self.spatial_shape with values in [v_min, v_max]. 12 | :param v_min: minimum coordinate value per dimension. 
13 | :param v_max: maximum coordinate value per dimension. 14 | :return: position coordinates tensor of shape (*shape, len(shape)). 15 | """ 16 | if normalized: 17 | axis_coords = [torch.linspace(-1.0, 1.0, steps=s, device=device) for s in spatial_shape] 18 | else: 19 | axis_coords = [torch.linspace(0, s - 1, steps=s, device=device) for s in spatial_shape] 20 | 21 | grid_coords = torch.meshgrid(*axis_coords, indexing=indexing) 22 | 23 | return torch.stack(grid_coords, dim=-1) 24 | 25 | 26 | def get_plucker_rays(extrinsics, intrinsics, h=32, w=32, stride=8, is_diffusion=False): 27 | b, v = extrinsics.shape[:2] 28 | 29 | # Adjust intrinsics scale due to downsizing by input_stride (we take feature maps as input not the raw images) 30 | 31 | updated_intrinsics = intrinsics.clone().unsqueeze(1) if len(intrinsics.shape) == 3 else intrinsics.clone() 32 | updated_intrinsics[..., 0, 0] *= 1 / stride 33 | updated_intrinsics[..., 0, 2] *= 1 / stride 34 | updated_intrinsics[..., 1, 1] *= 1 / stride 35 | updated_intrinsics[..., 1, 2] *= 1 / stride 36 | 37 | # create positionnal encodings 38 | pixel_coords = meshgrid((w, h), normalized=False, indexing="xy", device=extrinsics.device) 39 | ones = torch.ones((h, w, 1), device=extrinsics.device) 40 | 41 | pixel_coords = torch.cat([pixel_coords, ones], dim=-1) # [x, y, 1] vectors of pixel coordinates 42 | pixel_coords = rearrange(pixel_coords, "h w c -> c (h w)") 43 | pixel_coords = repeat(pixel_coords, "... -> b v ...", b=b, v=v) 44 | 45 | # Split extrinsics into rots and trans, rots == c2w, trans == camera center at world coordinate 46 | rots, trans = extrinsics.split([3, 1], dim=-1) 47 | 48 | # pixel_coords.shape = [B, N, 3, K] | N # of cams, K # of pixels 49 | directions = rots @ updated_intrinsics.inverse() @ pixel_coords 50 | directions = directions / directions.norm(dim=2, keepdim=True) 51 | directions = rearrange(directions, "b v c (h w) -> b v c h w", h=h, w=w) 52 | cam_origins = repeat(trans.squeeze(-1), "b v c -> b v c h w", h=h, w=w) 53 | moments = torch.cross(cam_origins, directions, dim=2) 54 | 55 | output = [] if is_diffusion else [cam_origins] 56 | output.append(directions) 57 | output.append(moments) 58 | 59 | return torch.cat(output, dim=2) 60 | 61 | 62 | def optimize_plucker_ray(ray_latent): 63 | b, v, _, h, w = ray_latent.shape 64 | ray_latent = ray_latent.float() # this function does not support bfloat16 type.. 65 | directions, moments = ray_latent.split(3, dim=2) 66 | 67 | # Reverse Process 68 | c = torch.linalg.norm(directions, dim=2, keepdim=True) 69 | origins = torch.cross(directions, moments / c, dim=2) 70 | 71 | new_trans = intersect_skew_lines_high_dim( 72 | rearrange(origins, "b n c h w -> b n (h w) c"), rearrange(directions, "b n c h w -> b n (h w) c") 73 | ) 74 | 75 | # Retrieve target rays 76 | I_intrinsic_ = torch.tensor([[1, 0, h // 2], [0, 1, w // 2], [0, 0, 1]], dtype=ray_latent.dtype, device=c.device) 77 | I_intrinsic = repeat(I_intrinsic_, "i j -> b v i j", b=b, v=v) 78 | I_rot = repeat(torch.eye(3, dtype=ray_latent.dtype, device=c.device), "i j -> b v i j", b=b, v=v) 79 | 80 | # create positionnal encodings 81 | pixel_coords = meshgrid((w, h), normalized=False, indexing="xy", device=ray_latent.device) 82 | ones = torch.ones((h, w, 1), device=ray_latent.device) 83 | 84 | pixel_coords = torch.cat([pixel_coords, ones], dim=-1) # [x, y, 1] vectors of pixel coordinates 85 | pixel_coords = rearrange(pixel_coords, "h w c -> c (h w)") 86 | pixel_coords = repeat(pixel_coords, "... 
-> b v ...", b=b, v=v) 87 | 88 | I_directions = I_rot @ I_intrinsic.inverse() @ pixel_coords 89 | I_directions = I_directions / I_directions.norm(dim=2, keepdim=True) 90 | I_directions = rearrange(I_directions, "b v c (h w) -> b v c h w", h=h, w=w) 91 | 92 | new_rots, new_intrinsics = [], [] 93 | for bb in range(b): 94 | Rs, Ks = [], [] 95 | for vv in range(v): 96 | R, f, pp = compute_optimal_rotation_intrinsics( 97 | I_directions[bb, vv], directions[bb, vv], reproj_threshold=0.2 98 | ) 99 | Rs.append(R) 100 | K = I_intrinsic_.clone() 101 | K[:2, :2] = torch.diag(1 / f) 102 | K[:, -1][:2] += pp 103 | Ks.append(K) 104 | 105 | new_rots.append(torch.stack(Rs)) 106 | new_intrinsics.append(torch.stack(Ks)) 107 | 108 | new_rots = torch.stack(new_rots) 109 | new_intrinsics = torch.stack(new_intrinsics) 110 | 111 | ff = nn.Parameter(new_intrinsics[..., [0, 1], [0, 1]].mean(1)[0], requires_grad=True) 112 | optimizer = torch.optim.Adam([ff], lr=0.01) 113 | X = torch.tensor( 114 | [[1, 0, h // 2], [0, 1, w // 2], [0, 0, 1]], dtype=ray_latent.dtype, device=c.device, requires_grad=True 115 | ) 116 | 117 | # Optimization loop 118 | num_iterations = 10 119 | for i in range(num_iterations + 1): 120 | optimizer.zero_grad() 121 | scale_matrix = torch.diag(torch.cat([ff, torch.ones_like(ff[:1])], dim=0)) 122 | 123 | triu_intrinsics = X @ scale_matrix 124 | re_X = repeat(triu_intrinsics, "i j -> b v i j", b=b, v=v) 125 | 126 | I_directions = I_rot @ re_X.inverse() @ pixel_coords 127 | I_directions = I_directions / I_directions.norm(dim=2, keepdim=True) 128 | I_directions = rearrange(I_directions, "b v c (h w) -> b v c h w", h=h, w=w) 129 | 130 | new_rots = [] 131 | for bb in range(b): 132 | Rs = [] 133 | for vv in range(v): 134 | R = compute_optimal_rotation_alignment(I_directions[bb, vv], directions[bb, vv]) 135 | Rs.append(R) 136 | 137 | new_rots.append(torch.stack(Rs)) 138 | 139 | new_rots = torch.stack(new_rots) 140 | 141 | if i == num_iterations: 142 | break 143 | 144 | I_directions = new_rots.clone() @ re_X.inverse() @ pixel_coords.clone() 145 | I_directions = I_directions / I_directions.norm(dim=2, keepdim=True) 146 | I_directions = rearrange(I_directions, "b v c (h w) -> b v c h w", h=h, w=w) 147 | loss = torch.norm(I_directions - directions, dim=2).mean() 148 | 149 | loss.backward() 150 | optimizer.step() 151 | 152 | new_rots = new_rots.detach().clone() 153 | new_intrinsics = re_X.detach().clone() 154 | 155 | stride = 8 156 | new_intrinsics[..., 0, 0] *= stride 157 | new_intrinsics[..., 0, 2] *= stride 158 | new_intrinsics[..., 1, 1] *= stride 159 | new_intrinsics[..., 1, 2] *= stride 160 | 161 | # normalize camera pose 162 | new_poses = torch.cat([new_rots, new_trans.unsqueeze(-1)], dim=-1) 163 | bottom = torch.tensor([0, 0, 0, 1], dtype=new_poses.dtype, device=new_poses.device) 164 | homo_poses = torch.cat([new_poses, repeat(bottom, "i -> b v () i", b=b, v=v)], dim=-2) 165 | inv_poses = torch.cat([homo_poses.inverse()[:, :, :3], repeat(bottom, "i -> b v () i", b=b, v=v)], dim=-2) 166 | return homo_poses, new_intrinsics, inv_poses 167 | 168 | 169 | # Refer to RayDiffusion 170 | def intersect_skew_lines_high_dim(p, r): 171 | # p : num views x 3 x num points 172 | # Implements https://en.wikipedia.org/wiki/Skew_lines In more than two dimensions 173 | dim = p.shape[-1] 174 | 175 | eye = torch.eye(dim, device=p.device, dtype=p.dtype)[None, None, None] 176 | I_min_cov = eye - (r[..., None] * r[..., None, :]) 177 | sum_proj = I_min_cov.matmul(p[..., None]).sum(dim=-3) 178 | 179 | p_intersect = 
torch.linalg.lstsq(I_min_cov.sum(dim=-3), sum_proj).solution[..., 0] 180 | 181 | return p_intersect 182 | 183 | 184 | def compute_optimal_rotation_intrinsics(rays_origin, rays_target, z_threshold=1e-8, reproj_threshold=0.2): 185 | """ 186 | Note: for some reason, f seems to be 1/f. 187 | 188 | Args: 189 | rays_origin (torch.Tensor): (3, H, W) 190 | rays_target (torch.Tensor): (3, H, W) 191 | z_threshold (float): Threshold for z value to be considered valid. 192 | 193 | Returns: 194 | R (torch.tensor): (3, 3) 195 | focal_length (torch.tensor): (2,) 196 | principal_point (torch.tensor): (2,) 197 | """ 198 | device = rays_origin.device 199 | _, h, w = rays_origin.shape 200 | 201 | rays_origin = rearrange(rays_origin, "c h w -> (h w) c") 202 | rays_target = rearrange(rays_target, "c h w -> (h w) c") 203 | 204 | z_mask = torch.logical_and(torch.abs(rays_target) > z_threshold, torch.abs(rays_origin) > z_threshold)[:, 2] 205 | rays_target = rays_target[z_mask] 206 | rays_origin = rays_origin[z_mask] 207 | rays_origin = rays_origin[:, :2] / rays_origin[:, -1:] 208 | rays_target = rays_target[:, :2] / rays_target[:, -1:] 209 | 210 | A, _ = cv2.findHomography( 211 | rays_origin.cpu().numpy(), 212 | rays_target.cpu().numpy(), 213 | cv2.RANSAC, 214 | reproj_threshold, 215 | ) 216 | A = torch.from_numpy(A).float().to(device) 217 | 218 | if torch.linalg.det(A) < 0: 219 | A = -A 220 | 221 | R, L = ql_decomposition(A) 222 | L = L / L[2][2] 223 | 224 | f = torch.stack((L[0][0], L[1][1])) 225 | pp = torch.stack((L[2][0], L[2][1])) 226 | return R, f, pp 227 | 228 | 229 | def ql_decomposition(A): 230 | P = torch.tensor([[0, 0, 1], [0, 1, 0], [1, 0, 0]], device=A.device).float() 231 | A_tilde = torch.matmul(A, P) 232 | Q_tilde, R_tilde = torch.linalg.qr(A_tilde) 233 | Q = torch.matmul(Q_tilde, P) 234 | L = torch.matmul(torch.matmul(P, R_tilde), P) 235 | d = torch.diag(L) 236 | Q[:, 0] *= torch.sign(d[0]) 237 | Q[:, 1] *= torch.sign(d[1]) 238 | Q[:, 2] *= torch.sign(d[2]) 239 | L[0] *= torch.sign(d[0]) 240 | L[1] *= torch.sign(d[1]) 241 | L[2] *= torch.sign(d[2]) 242 | return Q, L 243 | 244 | 245 | def compute_optimal_rotation_alignment(A, B): 246 | """ 247 | Compute optimal R that minimizes: || A - B @ R ||_F 248 | 249 | Args: 250 | A (torch.Tensor): (3, H, W) 251 | B (torch.Tensor): (3, H, W) 252 | 253 | Returns: 254 | R (torch.tensor): (3, 3) 255 | """ 256 | A = rearrange(A, "c h w -> (h w) c") 257 | B = rearrange(B, "c h w -> (h w) c") 258 | 259 | # normally with R @ B, this would be A @ B.T 260 | H = B.T @ A 261 | U, _, Vh = torch.linalg.svd(H, full_matrices=True) 262 | s = torch.linalg.det(U @ Vh) 263 | S_prime = torch.diag(torch.tensor([1, 1, torch.sign(s)], device=A.device)) 264 | return U @ S_prime @ Vh 265 | -------------------------------------------------------------------------------- /model/gsdecoder/cuda_splatting.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from diff_gaussian_rasterization import ( 3 | GaussianRasterizationSettings, 4 | GaussianRasterizer, 5 | ) 6 | from einops import einsum, rearrange, repeat 7 | 8 | 9 | def get_fov(intrinsics): 10 | intrinsics_inv = intrinsics.inverse() 11 | 12 | def process_vector(vector): 13 | vector = torch.tensor(vector, dtype=torch.float32, device=intrinsics.device) 14 | vector = einsum(intrinsics_inv, vector, "b i j, j -> b i") 15 | return vector / vector.norm(dim=-1, keepdim=True) 16 | 17 | left = process_vector([0, 0.5, 1]) 18 | right = process_vector([1, 0.5, 1]) 19 | top = 
process_vector([0.5, 0, 1]) 20 | bottom = process_vector([0.5, 1, 1]) 21 | fov_x = (left * right).sum(dim=-1).acos() 22 | fov_y = (top * bottom).sum(dim=-1).acos() 23 | return torch.stack((fov_x, fov_y), dim=-1) 24 | 25 | 26 | def get_projection_matrix(near, far, fov_x, fov_y): 27 | """Maps points in the viewing frustum to (-1, 1) on the X/Y axes and (0, 1) on the Z 28 | axis. Differs from the OpenGL version in that Z doesn't have range (-1, 1) after 29 | transformation and that Z is flipped. 30 | """ 31 | tan_fov_x = (0.5 * fov_x).tan() 32 | tan_fov_y = (0.5 * fov_y).tan() 33 | 34 | top = tan_fov_y * near 35 | bottom = -top 36 | right = tan_fov_x * near 37 | left = -right 38 | 39 | (b,) = fov_x.shape 40 | result = torch.zeros((b, 4, 4), dtype=torch.float32, device=fov_x.device) 41 | result[:, 0, 0] = 2 * near / (right - left) 42 | result[:, 1, 1] = 2 * near / (top - bottom) 43 | result[:, 0, 2] = (right + left) / (right - left) 44 | result[:, 1, 2] = (top + bottom) / (top - bottom) 45 | result[:, 3, 2] = 1 46 | result[:, 2, 2] = far / (far - near) 47 | result[:, 2, 3] = -(far * near) / (far - near) 48 | return result 49 | 50 | 51 | def render_cuda( 52 | extrinsics, 53 | intrinsics, 54 | near, 55 | far, 56 | image_shape, 57 | num_views, 58 | background_color, 59 | gaussian_means, 60 | gaussian_covariances, 61 | gaussians_rgb, 62 | gaussian_opacities, 63 | ): 64 | b, _, _ = extrinsics.shape 65 | h, w = image_shape 66 | 67 | update_intrinsics = intrinsics.clone() 68 | update_intrinsics[:, 0] *= 1 / w 69 | update_intrinsics[:, 1] *= 1 / h 70 | 71 | fov_x, fov_y = get_fov(update_intrinsics).unbind(dim=-1) 72 | tan_fov_x = (0.5 * fov_x).tan() 73 | tan_fov_y = (0.5 * fov_y).tan() 74 | 75 | projection_matrix = get_projection_matrix(near, far, fov_x, fov_y) 76 | projection_matrix = rearrange(projection_matrix, "b i j -> b j i") 77 | view_matrix = rearrange(extrinsics.inverse(), "b i j -> b j i") 78 | full_projection = view_matrix @ projection_matrix 79 | 80 | gaussian_means = repeat(gaussian_means, "b n c -> (b v) n c", v=num_views) 81 | gaussian_covariances = repeat(gaussian_covariances, "b n i j -> (b v) n i j", v=num_views) 82 | gaussians_rgb = repeat(gaussians_rgb, "b n c -> (b v) n c", v=num_views) 83 | gaussian_opacities = repeat(gaussian_opacities, "b n c -> (b v) n c", v=num_views) 84 | 85 | all_images = [] 86 | all_depths = [] 87 | for i in range(b): 88 | # Set up a tensor for the gradients of the screen-space means. 89 | mean_gradients = torch.zeros_like(gaussian_means[i], requires_grad=True) 90 | try: 91 | mean_gradients.retain_grad() 92 | except Exception: 93 | pass 94 | 95 | settings = GaussianRasterizationSettings( 96 | image_height=h, 97 | image_width=w, 98 | tanfovx=tan_fov_x[i].item(), 99 | tanfovy=tan_fov_y[i].item(), 100 | bg=background_color[i], 101 | scale_modifier=1.0, 102 | viewmatrix=view_matrix[i], 103 | projmatrix=full_projection[i], 104 | sh_degree=0, 105 | campos=extrinsics[i, :3, 3], 106 | prefiltered=False, # This matches the original usage. 
107 | debug=False, 108 | ) 109 | rasterizer = GaussianRasterizer(settings) 110 | 111 | row, col = torch.triu_indices(3, 3) 112 | 113 | image, _, depth, _ = rasterizer( 114 | means3D=gaussian_means[i], 115 | means2D=mean_gradients, 116 | shs=None, 117 | colors_precomp=gaussians_rgb[i], 118 | opacities=gaussian_opacities[i], 119 | cov3D_precomp=gaussian_covariances[i, :, row, col], 120 | ) 121 | all_images.append(image) 122 | all_depths.append(depth) 123 | 124 | return torch.stack(all_images), torch.stack(all_depths) 125 | -------------------------------------------------------------------------------- /model/gsdecoder/decoder_splatting.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from einops import einsum, rearrange, repeat 3 | from torch import nn 4 | 5 | from .camera_embedding import meshgrid 6 | from .cuda_splatting import render_cuda 7 | 8 | # Reference: https://github.com/dcharatan/pixelsplat/blob/main/src/model/encoder/common/gaussian_adapter.py 9 | # We slightly modify the official code to fit in our setting. 10 | 11 | 12 | # https://github.com/facebookresearch/pytorch3d/blob/main/pytorch3d/transforms/rotation_conversions.py 13 | def quaternion_to_matrix(quaternions, eps: float = 1e-8): 14 | # Order changed to match scipy format! 15 | i, j, k, r = torch.unbind(quaternions, dim=-1) 16 | two_s = 2 / ((quaternions * quaternions).sum(dim=-1) + eps) 17 | 18 | o = torch.stack( 19 | ( 20 | 1 - two_s * (j * j + k * k), 21 | two_s * (i * j - k * r), 22 | two_s * (i * k + j * r), 23 | two_s * (i * j + k * r), 24 | 1 - two_s * (i * i + k * k), 25 | two_s * (j * k - i * r), 26 | two_s * (i * k - j * r), 27 | two_s * (j * k + i * r), 28 | 1 - two_s * (i * i + j * j), 29 | ), 30 | -1, 31 | ) 32 | return rearrange(o, "... (i j) -> ... i j", i=3, j=3) 33 | 34 | 35 | def build_covariance(scale, rotation_xyzw): 36 | scale = scale.diag_embed() 37 | rotation = quaternion_to_matrix(rotation_xyzw) 38 | return rotation @ scale @ rearrange(scale, "... i j -> ... j i") @ rearrange(rotation, "... i j -> ... j i") 39 | 40 | 41 | class DecoderSplatting(nn.Module): 42 | def __init__(self): 43 | super().__init__() 44 | self.register_buffer("background_color", torch.tensor([0.0, 0.0, 0.0], dtype=torch.float32), persistent=False) 45 | self.act_scale = nn.Softplus() 46 | self.act_rgb = nn.Softplus() 47 | 48 | def get_scale_multiplier(self, intrinsics): 49 | pixel_size = torch.ones((2,), dtype=torch.float32, device=intrinsics.device) 50 | xy_multipliers = einsum( 51 | intrinsics[..., :2, :2].inverse(), 52 | pixel_size, 53 | "... i j, j -> ... i", 54 | ) 55 | return xy_multipliers.sum(dim=-1) 56 | 57 | def gaussian_adapter(self, raw_gaussians, extrinsics, intrinsics, near=1, far=100): 58 | b, v, c, h, w = raw_gaussians.shape 59 | if len(intrinsics.shape) == 3: 60 | intrinsics = repeat(intrinsics, "b i j -> b v i j", v=v) 61 | extrinsics = repeat(extrinsics, "b v i j -> b v () () i j") 62 | intrinsics = repeat(intrinsics, "b v i j -> b v () () i j") 63 | raw_gaussians = rearrange(raw_gaussians, "b v c h w -> b v h w c") 64 | 65 | rgb, disp, opacity, scales, rotations, xy_offset = raw_gaussians.split((3, 1, 1, 3, 4, 2), dim=-1) 66 | 67 | # calculate xy_offset and origin/direction for each view. 
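        # Each Gaussian is anchored to a pixel: the predicted xy_offset (squashed to
        # (-0.5, 0.5) via sigmoid) shifts the pixel center, which is unprojected with the
        # inverse intrinsics, rotated into world space by the camera-to-world extrinsics,
        # and pushed along the resulting ray by a depth recovered from the predicted disparity.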
68 | pixel_coords = meshgrid((w, h), normalized=False, indexing="xy", device=raw_gaussians.device) 69 | pixel_coords = repeat(pixel_coords, "h w c -> b v h w c", b=b, v=v) 70 | 71 | coordinates = pixel_coords + (xy_offset.sigmoid() - 0.5) 72 | coordinates = torch.cat([coordinates, torch.ones_like(coordinates[..., :1])], dim=-1) 73 | 74 | directions = einsum(intrinsics.inverse(), coordinates, "... i j, ... j -> ... i") 75 | directions = directions / directions.norm(dim=-1, keepdim=True) 76 | # directions = directions / directions[..., -1:] 77 | directions = torch.cat([directions, torch.zeros_like(directions[..., :1])], dim=-1) 78 | directions = einsum(extrinsics, directions, "... i j, ... j -> ... i") 79 | origins = extrinsics[..., -1].broadcast_to(directions.shape) 80 | 81 | # calculate depth from disparity 82 | depths = 1.0 / (disp.sigmoid() * (1.0 / near - 1.0 / far) + 1.0 / far) 83 | 84 | # calculate all parameters of gaussian splats 85 | means = origins + directions * depths 86 | 87 | multiplier = self.get_scale_multiplier(intrinsics) 88 | scales = self.act_scale(scales) * multiplier[..., None] 89 | 90 | if len(torch.where(scales > 0.05)[0]) > 0: 91 | big_gaussian_reg_loss = torch.mean(scales[torch.where(scales > 0.05)]) 92 | else: 93 | big_gaussian_reg_loss = 0 94 | 95 | if len(torch.where(scales < 1e-6)[0]) > 0: 96 | small_gaussian_reg_loss = torch.mean(-torch.log(scales[torch.where(scales < 1e-6)]) * 0.1) 97 | else: 98 | small_gaussian_reg_loss = 0 99 | 100 | rotations = rotations / (rotations.norm(dim=-1, keepdim=True) + 1e-8) 101 | covariances = build_covariance(scales, rotations) 102 | c2w_rotations = extrinsics[..., :3, :3] 103 | covariances = c2w_rotations @ covariances @ c2w_rotations.transpose(-1, -2) 104 | 105 | opacity = opacity.sigmoid() 106 | rgb = self.act_rgb(rgb) 107 | if self.training: 108 | return means, covariances, opacity, rgb, big_gaussian_reg_loss, small_gaussian_reg_loss 109 | else: 110 | return means, covariances, opacity, rgb, rotations, scales 111 | 112 | def forward(self, raw_gaussians, target_extrinsics, source_extrinsics, intrinsics, near=1, far=100): 113 | b, v = target_extrinsics.shape[:2] 114 | h, w = raw_gaussians.shape[-2:] 115 | 116 | means, covariance, opacity, rgb, b_loss, s_loss = self.gaussian_adapter( 117 | raw_gaussians, source_extrinsics, intrinsics, near, far 118 | ) 119 | lower_ext = torch.tensor([0.0, 0.0, 0.0, 1.0], device=target_extrinsics.device, dtype=torch.float32) 120 | lower_ext = repeat(lower_ext, "i -> b v () i", b=b, v=v) 121 | homo_extrinsics = torch.cat([target_extrinsics, lower_ext], dim=2) 122 | output_image, output_depth = render_cuda( 123 | rearrange(homo_extrinsics, "b v i j -> (b v) i j"), 124 | repeat(intrinsics, "b i j -> (b v) i j", v=v), 125 | near, 126 | far, 127 | (h, w), 128 | v, 129 | repeat(self.background_color, "i -> bv i", bv=b * v), 130 | rearrange(means, "b v h w xyz -> b (v h w) xyz"), 131 | rearrange(covariance, "b v h w i j -> b (v h w) i j"), 132 | rearrange(rgb, "b v h w c -> b (v h w) c"), 133 | rearrange(opacity, "b v h w o -> b (v h w) o"), 134 | ) 135 | return output_image, output_depth, b_loss, s_loss 136 | 137 | def eval_forward(self, means, covariance, opacity, rgb, target_extrinsics, target_intrinsics, near, far): 138 | # TODO: merge eval and training forward logics 139 | _, _, h, w, _ = means.shape 140 | b, target_view_num, _, _ = target_extrinsics.shape 141 | 142 | if len(target_intrinsics.shape) == 3: 143 | target_intrinsics = repeat(target_intrinsics, "b i j -> (b v) i j", v=target_view_num) 
144 | else: 145 | target_intrinsics = repeat(target_intrinsics, "b v i j -> (b v) i j") 146 | 147 | lower_ext = torch.tensor([0.0, 0.0, 0.0, 1.0], device=target_extrinsics.device, dtype=torch.float32) 148 | lower_ext = repeat(lower_ext, "i -> b v () i", b=b, v=target_view_num) 149 | homo_extrinsics = torch.cat([target_extrinsics, lower_ext], dim=2) 150 | output_image, output_depth = render_cuda( 151 | rearrange(homo_extrinsics, "b v i j -> (b v) i j"), 152 | target_intrinsics, 153 | near, 154 | far, 155 | (h, w), 156 | target_view_num, 157 | repeat(self.background_color, "i -> bv i", bv=b * target_view_num), 158 | rearrange(means, "b v h w xyz -> b (v h w) xyz"), 159 | rearrange(covariance, "b v h w i j -> b (v h w) i j"), 160 | rearrange(rgb, "b v h w rgb -> b (v h w) rgb"), 161 | rearrange(opacity, "b v h w o -> b (v h w) o"), 162 | ) 163 | return output_image 164 | -------------------------------------------------------------------------------- /model/gsdecoder/gs_decoder_architecture.py: -------------------------------------------------------------------------------- 1 | from typing import Optional, Tuple 2 | 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | from diffusers.models.attention_processor import Attention 7 | from diffusers.models.unets.unet_2d_blocks import UNetMidBlock2D, get_up_block 8 | from einops import rearrange, repeat 9 | 10 | from .camera_embedding import get_plucker_rays 11 | from .decoder_splatting import DecoderSplatting 12 | 13 | 14 | class AttnProcessor2_0_modified: 15 | r""" 16 | We copied the AttnProcessor2_0 class from the diffusers library 17 | """ 18 | 19 | def __init__(self): 20 | if not hasattr(F, "scaled_dot_product_attention"): 21 | raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.") 22 | 23 | def __call__( 24 | self, 25 | attn: Attention, 26 | hidden_states: torch.Tensor, 27 | encoder_hidden_states: Optional[torch.Tensor] = None, 28 | attention_mask: Optional[torch.Tensor] = None, 29 | temb: Optional[torch.Tensor] = None, 30 | num_views: Optional[torch.Tensor] = None, 31 | *args, 32 | **kwargs, 33 | ) -> torch.Tensor: 34 | residual = hidden_states 35 | if attn.spatial_norm is not None: 36 | hidden_states = attn.spatial_norm(hidden_states, temb) 37 | 38 | input_ndim = hidden_states.ndim 39 | 40 | if input_ndim == 4: 41 | batch_size, channel, height, width = hidden_states.shape 42 | hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2) 43 | 44 | batch_size, sequence_length, _ = ( 45 | hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape 46 | ) 47 | 48 | if attention_mask is not None: 49 | attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) 50 | # scaled_dot_product_attention expects attention_mask shape to be 51 | # (batch, heads, source_length, target_length) 52 | attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1]) 53 | 54 | if attn.group_norm is not None: 55 | hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) 56 | 57 | hidden_states = rearrange(hidden_states, "(b v) t c -> b (v t) c", v=num_views) 58 | batch_size = hidden_states.shape[0] 59 | query = attn.to_q(hidden_states) 60 | 61 | if encoder_hidden_states is None: 62 | encoder_hidden_states = hidden_states 63 | elif attn.norm_cross: 64 | encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states) 65 | 66 | key = 
attn.to_k(encoder_hidden_states) 67 | value = attn.to_v(encoder_hidden_states) 68 | 69 | inner_dim = key.shape[-1] 70 | head_dim = inner_dim // attn.heads 71 | 72 | query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) 73 | 74 | key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) 75 | value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) 76 | 77 | # the output of sdp = (batch, num_heads, seq_len, head_dim) 78 | # TODO: add support for attn.scale when we move to Torch 2.1 79 | hidden_states = F.scaled_dot_product_attention( 80 | query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False 81 | ) 82 | 83 | hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) 84 | hidden_states = hidden_states.to(query.dtype) 85 | 86 | # linear proj 87 | hidden_states = attn.to_out[0](hidden_states) 88 | # dropout 89 | hidden_states = attn.to_out[1](hidden_states) 90 | 91 | hidden_states = rearrange(hidden_states, "b (v t) c -> (b v) t c", v=num_views) 92 | batch_size = hidden_states.shape[0] 93 | 94 | if input_ndim == 4: 95 | hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width) 96 | 97 | if attn.residual_connection: 98 | hidden_states = hidden_states + residual 99 | 100 | hidden_states = hidden_states / attn.rescale_output_factor 101 | 102 | return hidden_states 103 | 104 | 105 | class GSDecoder(nn.Module): 106 | def __init__( 107 | self, 108 | in_channels: int = 41, 109 | out_channels: int = 14, 110 | up_block_types: Tuple[str, ...] = ( 111 | "UpDecoderBlock2D", 112 | "UpDecoderBlock2D", 113 | "UpDecoderBlock2D", 114 | "UpDecoderBlock2D", 115 | ), 116 | block_out_channels: Tuple[int, ...] = (128, 256, 512, 512), 117 | layers_per_block: int = 2, 118 | norm_num_groups: int = 32, 119 | cfg=None, 120 | ): 121 | super(GSDecoder, self).__init__() 122 | self.cfg = cfg 123 | self.conv_in = nn.Conv2d( 124 | in_channels, 125 | block_out_channels[-1], 126 | kernel_size=3, 127 | stride=1, 128 | padding=1, 129 | ) 130 | 131 | self.gn = nn.GroupNorm(norm_num_groups, block_out_channels[-1], affine=True) 132 | self.attn_block_pre = nn.ModuleList( 133 | [ 134 | Attention( 135 | block_out_channels[-1], 136 | heads=1, 137 | dim_head=block_out_channels[-1], 138 | rescale_output_factor=1, 139 | eps=1e-6, 140 | norm_num_groups=32, 141 | spatial_norm_dim=None, 142 | residual_connection=True, 143 | processor=AttnProcessor2_0_modified(), 144 | bias=True, 145 | upcast_softmax=True, 146 | _from_deprecated_attn_block=True, 147 | dropout=0.0, 148 | ) 149 | for _ in range(3) 150 | ] 151 | ) 152 | 153 | self.mid_block = UNetMidBlock2D( 154 | in_channels=block_out_channels[-1], 155 | resnet_eps=1e-6, 156 | resnet_act_fn="silu", 157 | output_scale_factor=1, 158 | resnet_time_scale_shift="default", 159 | attention_head_dim=block_out_channels[-1], 160 | resnet_groups=norm_num_groups, 161 | temb_channels=None, 162 | add_attention=False, 163 | ) 164 | 165 | self.attn_block_post = nn.ModuleList( 166 | [ 167 | Attention( 168 | block_out_channels[-1], 169 | heads=1, 170 | dim_head=block_out_channels[-1], 171 | rescale_output_factor=1, 172 | eps=1e-6, 173 | norm_num_groups=32, 174 | spatial_norm_dim=None, 175 | residual_connection=True, 176 | processor=AttnProcessor2_0_modified(), 177 | bias=True, 178 | upcast_softmax=True, 179 | _from_deprecated_attn_block=True, 180 | dropout=0.0, 181 | ) 182 | for _ in range(3) 183 | ] 184 | ) 185 | 186 | # up 187 | self.up_blocks = nn.ModuleList([]) 188 | 
reversed_block_out_channels = list(reversed(block_out_channels)) 189 | output_channel = reversed_block_out_channels[0] 190 | for i, up_block_type in enumerate(up_block_types): 191 | prev_output_channel = output_channel 192 | output_channel = reversed_block_out_channels[i] 193 | 194 | is_final_block = i == len(block_out_channels) - 1 195 | 196 | up_block = get_up_block( 197 | up_block_type, 198 | num_layers=layers_per_block + 1, 199 | in_channels=prev_output_channel, 200 | out_channels=output_channel, 201 | prev_output_channel=None, 202 | add_upsample=not is_final_block, 203 | resnet_eps=1e-6, 204 | resnet_act_fn="silu", 205 | resnet_groups=norm_num_groups, 206 | attention_head_dim=output_channel, 207 | temb_channels=None, 208 | resnet_time_scale_shift="group", 209 | ) 210 | self.up_blocks.append(up_block) 211 | 212 | self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=1e-6) 213 | self.conv_act = nn.SiLU() 214 | self.conv_out = nn.Conv2d(block_out_channels[0], out_channels, 3, padding=1) 215 | 216 | self.splatter = DecoderSplatting() 217 | 218 | def forward( 219 | self, latent_image, latent_depth, target_extrinsics, source_extrinsics, intrinsics, near, far, plucker_ray=None 220 | ): 221 | # calculate plucker rays 222 | lh, lw = latent_image.shape[-2:] 223 | if plucker_ray is None: 224 | ray_embeddings = get_plucker_rays(source_extrinsics, intrinsics, h=lh, w=lw, stride=8) 225 | else: 226 | ray_embeddings = plucker_ray 227 | 228 | if self.training: 229 | ray_embeddings += torch.randn_like(ray_embeddings) * 0.0001 # simulating some noises 230 | 231 | if latent_depth is not None: 232 | x = torch.cat([latent_image, latent_depth, ray_embeddings], dim=2) # channels: 16 + 16 + 9 233 | else: 234 | x = torch.cat([latent_image, ray_embeddings], dim=2) 235 | 236 | b, v, _, h, w = x.size() 237 | x = rearrange(x, "b v c h w -> (b v) c h w") # TODO: patchify? 
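        # Views are folded into the batch so the convolutional blocks run per view, while the
        # modified attention processors regroup tokens across views for joint multi-view attention.
        # The up blocks upsample the latent back to pixel resolution and conv_out emits 14 raw
        # Gaussian parameters per pixel (rgb 3 + disparity 1 + opacity 1 + scale 3 + rotation 4 +
        # xy_offset 2), which the splatter converts into 3DGS primitives.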
238 | x = self.conv_in(x) 239 | x = self.gn(x) 240 | 241 | for attn_block in self.attn_block_pre: 242 | x = attn_block(x, num_views=v) 243 | 244 | x = self.mid_block(x) 245 | 246 | for attn_block in self.attn_block_post: 247 | x = attn_block(x, num_views=v) 248 | 249 | # up 250 | for up_block in self.up_blocks: 251 | x = up_block(x) 252 | 253 | x = self.conv_norm_out(x) 254 | x = self.conv_act(x) 255 | x = self.conv_out(x) 256 | 257 | x = rearrange(x, "(b v) c h w -> b v c h w", b=b, v=v) # (b v c_out h w) 258 | 259 | with torch.amp.autocast("cuda", enabled=False): 260 | if self.training: 261 | x_img, x_depth, b_loss, s_loss = self.splatter( 262 | x.float(), target_extrinsics, source_extrinsics, intrinsics, near, far 263 | ) 264 | return x_img, repeat(x_depth, "b c h w -> b (c r) h w", r=3), b_loss, s_loss 265 | else: 266 | x = self.splatter.gaussian_adapter( 267 | x.float(), source_extrinsics, intrinsics, near, far 268 | ) # return encoded gaussian primitives 269 | 270 | return x 271 | 272 | def eval_rendering(self, means, covariance, opacity, rgb, target_extrinsics, target_intrinsics, near, far): 273 | # TODO: I should merge this eval rendering loop for this 274 | return self.splatter.eval_forward( 275 | means, covariance, opacity, rgb, target_extrinsics, target_intrinsics, near, far 276 | ) 277 | -------------------------------------------------------------------------------- /model/gsdecoder/load_gsdecoder.py: -------------------------------------------------------------------------------- 1 | from model.gsdecoder.gs_decoder_architecture import GSDecoder 2 | 3 | 4 | def create_gsdecoder(cfg): 5 | decoder = GSDecoder(cfg=cfg) 6 | return decoder 7 | -------------------------------------------------------------------------------- /model/multiview_rf/load_mv_sd3.py: -------------------------------------------------------------------------------- 1 | from argparse import Namespace 2 | 3 | from transformers import CLIPTokenizer, PretrainedConfig, T5TokenizerFast 4 | 5 | from model.multiview_rf.mv_sd3_architecture import MultiViewSD3Transformer 6 | 7 | 8 | def import_model_class_from_model_name_or_path( 9 | pretrained_model_name_or_path: str, revision: str, subfolder: str = "text_encoder" 10 | ): 11 | text_encoder_config = PretrainedConfig.from_pretrained( 12 | pretrained_model_name_or_path, 13 | subfolder=subfolder, 14 | revision=revision, 15 | ) 16 | model_class = text_encoder_config.architectures[0] 17 | if model_class == "CLIPTextModelWithProjection": 18 | from transformers import CLIPTextModelWithProjection 19 | 20 | return CLIPTextModelWithProjection 21 | elif model_class == "T5EncoderModel": 22 | from transformers import T5EncoderModel 23 | 24 | return T5EncoderModel 25 | else: 26 | raise ValueError(f"{model_class} is not supported.") 27 | 28 | 29 | # Load a sd multi-view rf model 30 | def create_sd_multiview_rf_model(num_input_output_channel=38): 31 | args = Namespace() 32 | args.pretrained_model_name_or_path = "stabilityai/stable-diffusion-3-medium-diffusers" 33 | args.revision = None 34 | args.variant = None 35 | 36 | # rf model 37 | rf_transformer = MultiViewSD3Transformer.from_pretrained( 38 | "stabilityai/stable-diffusion-3-medium-diffusers", 39 | subfolder="transformer", 40 | ignore_mismatched_sizes=True, 41 | strict=False, 42 | low_cpu_mem_usage=False, 43 | ) 44 | rf_transformer.adjust_output_input_channel_size(num_input_output_channel) 45 | 46 | # tokenizers 47 | tokenizer_one = CLIPTokenizer.from_pretrained( 48 | args.pretrained_model_name_or_path, 49 | subfolder="tokenizer", 
50 | revision=args.revision, 51 | low_cpu_mem_usage=True, 52 | ) 53 | tokenizer_two = CLIPTokenizer.from_pretrained( 54 | args.pretrained_model_name_or_path, 55 | subfolder="tokenizer_2", 56 | revision=args.revision, 57 | low_cpu_mem_usage=True, 58 | ) 59 | tokenizer_three = T5TokenizerFast.from_pretrained( 60 | args.pretrained_model_name_or_path, 61 | subfolder="tokenizer_3", 62 | revision=args.revision, 63 | low_cpu_mem_usage=True, 64 | ) 65 | 66 | # Text encoders 67 | text_encoder_cls_one = import_model_class_from_model_name_or_path(args.pretrained_model_name_or_path, args.revision) 68 | text_encoder_cls_two = import_model_class_from_model_name_or_path( 69 | args.pretrained_model_name_or_path, args.revision, subfolder="text_encoder_2" 70 | ) 71 | text_encoder_cls_three = import_model_class_from_model_name_or_path( 72 | args.pretrained_model_name_or_path, args.revision, subfolder="text_encoder_3" 73 | ) 74 | 75 | def load_text_encoders(class_one, class_two, class_three): 76 | text_encoder_one = class_one.from_pretrained( 77 | args.pretrained_model_name_or_path, 78 | subfolder="text_encoder", 79 | revision=args.revision, 80 | variant=args.variant, 81 | low_cpu_mem_usage=True, 82 | torch_dtype="auto", 83 | ) 84 | text_encoder_two = class_two.from_pretrained( 85 | args.pretrained_model_name_or_path, 86 | subfolder="text_encoder_2", 87 | revision=args.revision, 88 | variant=args.variant, 89 | low_cpu_mem_usage=True, 90 | torch_dtype="auto", 91 | ) 92 | text_encoder_three = class_three.from_pretrained( 93 | args.pretrained_model_name_or_path, 94 | subfolder="text_encoder_3", 95 | revision=args.revision, 96 | variant=args.variant, 97 | low_cpu_mem_usage=True, 98 | torch_dtype="auto", 99 | ) 100 | return text_encoder_one, text_encoder_two, text_encoder_three 101 | 102 | text_encoder_one, text_encoder_two, text_encoder_three = load_text_encoders( 103 | text_encoder_cls_one, text_encoder_cls_two, text_encoder_cls_three 104 | ) 105 | tokenizers = [tokenizer_one, tokenizer_two, tokenizer_three] 106 | text_encoders = [text_encoder_one, text_encoder_two, text_encoder_three] 107 | 108 | return rf_transformer, tokenizers, text_encoders 109 | -------------------------------------------------------------------------------- /model/multiview_rf/mv_sd3_architecture.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Dict, List, Optional, Union 2 | 3 | import torch 4 | import torch.nn as nn 5 | from diffusers.configuration_utils import ConfigMixin, register_to_config 6 | from diffusers.loaders import FromOriginalModelMixin, PeftAdapterMixin 7 | from diffusers.models.attention import JointTransformerBlock, _chunked_feed_forward 8 | from diffusers.models.attention_processor import Attention, AttentionProcessor 9 | from diffusers.models.embeddings import ( 10 | CombinedTimestepTextProjEmbeddings, 11 | PatchEmbed, 12 | get_3d_sincos_pos_embed, 13 | ) 14 | from diffusers.models.modeling_outputs import Transformer2DModelOutput 15 | from diffusers.models.modeling_utils import ModelMixin 16 | from diffusers.models.normalization import AdaLayerNormContinuous 17 | from diffusers.utils import ( 18 | USE_PEFT_BACKEND, 19 | is_torch_version, 20 | logging, 21 | scale_lora_layers, 22 | unscale_lora_layers, 23 | ) 24 | from einops import rearrange 25 | 26 | logger = logging.get_logger(__name__) # pylint: disable=invalid-name 27 | 28 | 29 | class JointTransformerBlockMultiView(JointTransformerBlock): 30 | def __init__(self, num_views, *args, **kwargs): 31 | 
super().__init__(*args, **kwargs) 32 | self.num_views = num_views 33 | 34 | def freeze(self): 35 | for param in self.parameters(): 36 | param.requires_grad = False 37 | 38 | def unfreeze_view(self): 39 | for param in self.view_layernorm.parameters(): 40 | param.requires_grad = True 41 | for param in self.view_adaln.parameters(): 42 | param.requires_grad = True 43 | for param in self.view_attn.parameters(): 44 | param.requires_grad = True 45 | for param in self.view_fc.parameters(): 46 | param.requires_grad = True 47 | 48 | def forward( 49 | self, 50 | hidden_states: torch.FloatTensor, 51 | encoder_hidden_states: torch.FloatTensor, 52 | temb: torch.FloatTensor, 53 | ): 54 | norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(hidden_states, emb=temb) 55 | 56 | # Normal block 57 | if self.context_pre_only: 58 | norm_encoder_hidden_states = self.norm1_context(encoder_hidden_states, temb) 59 | else: 60 | norm_encoder_hidden_states, c_gate_msa, c_shift_mlp, c_scale_mlp, c_gate_mlp = self.norm1_context( 61 | encoder_hidden_states, emb=temb 62 | ) 63 | 64 | K, N, M = hidden_states.shape 65 | B = K // self.num_views 66 | norm_hidden_states = rearrange(norm_hidden_states, "(b v) n m -> b (v n) m", b=B, v=self.num_views, n=N, m=M) 67 | norm_encoder_hidden_states = rearrange( 68 | norm_encoder_hidden_states, "(b v) n m -> b (v n) m", b=B, v=self.num_views 69 | ) 70 | # Attention. 71 | attn_output, context_attn_output = self.attn( 72 | hidden_states=norm_hidden_states, encoder_hidden_states=norm_encoder_hidden_states 73 | ) 74 | attn_output = rearrange(attn_output, "b (v n) m -> (b v) n m", b=B, v=self.num_views) 75 | context_attn_output = rearrange(context_attn_output, "b (v n) m -> (b v) n m", b=B, v=self.num_views) 76 | 77 | # Process attention outputs for the `hidden_states`. 78 | attn_output = gate_msa.unsqueeze(1) * attn_output 79 | hidden_states = hidden_states + attn_output 80 | norm_hidden_states = self.norm2(hidden_states) 81 | norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None] 82 | if self._chunk_size is not None: 83 | # "feed_forward_chunk_size" can be used to save memory 84 | ff_output = _chunked_feed_forward(self.ff, norm_hidden_states, self._chunk_dim, self._chunk_size) 85 | else: 86 | ff_output = self.ff(norm_hidden_states) 87 | ff_output = gate_mlp.unsqueeze(1) * ff_output 88 | hidden_states = hidden_states + ff_output 89 | 90 | # Process attention outputs for the `encoder_hidden_states`. 
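        # The text (context) stream mirrors the image stream: gated attention residual,
        # AdaLN-modulated norm, then its own feed-forward. In the final transformer block
        # context_pre_only is True, so the text tokens are dropped and only the image
        # stream continues to the output head.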
91 | if self.context_pre_only: 92 | encoder_hidden_states = None 93 | else: 94 | context_attn_output = c_gate_msa.unsqueeze(1) * context_attn_output 95 | encoder_hidden_states = encoder_hidden_states + context_attn_output 96 | 97 | norm_encoder_hidden_states = self.norm2_context(encoder_hidden_states) 98 | norm_encoder_hidden_states = norm_encoder_hidden_states * (1 + c_scale_mlp[:, None]) + c_shift_mlp[:, None] 99 | if self._chunk_size is not None: 100 | # "feed_forward_chunk_size" can be used to save memory 101 | context_ff_output = _chunked_feed_forward( 102 | self.ff_context, norm_encoder_hidden_states, self._chunk_dim, self._chunk_size 103 | ) 104 | else: 105 | context_ff_output = self.ff_context(norm_encoder_hidden_states) 106 | encoder_hidden_states = encoder_hidden_states + c_gate_mlp.unsqueeze(1) * context_ff_output 107 | 108 | return encoder_hidden_states, hidden_states 109 | 110 | 111 | class PatchEmbedMultiView(PatchEmbed): 112 | def __init__( 113 | self, 114 | height: int = 224, 115 | width: int = 224, 116 | patch_size=16, 117 | in_channels=3, 118 | embed_dim=768, 119 | layer_norm=False, 120 | flatten=True, 121 | bias=True, 122 | interpolation_scale=1, 123 | pos_embed_type="sincos", 124 | pos_embed_max_size=None, # For SD3 cropping 125 | num_views=8, 126 | ): 127 | super().__init__( 128 | height=height, 129 | width=width, 130 | patch_size=patch_size, 131 | in_channels=in_channels, 132 | embed_dim=embed_dim, 133 | layer_norm=layer_norm, 134 | flatten=flatten, 135 | bias=bias, 136 | interpolation_scale=interpolation_scale, 137 | pos_embed_type=pos_embed_type, 138 | pos_embed_max_size=pos_embed_max_size, 139 | ) 140 | self.num_views = num_views 141 | pos_embed = get_3d_sincos_pos_embed( 142 | embed_dim=embed_dim, 143 | spatial_size=(16, 16), 144 | temporal_size=num_views, 145 | spatial_interpolation_scale=1.0, 146 | temporal_interpolation_scale=1.0, 147 | ) 148 | self.register_buffer("pos_embed", torch.from_numpy(pos_embed).float().unsqueeze(0)) 149 | 150 | def forward(self, latent): 151 | latent = self.proj(latent) 152 | if self.flatten: 153 | latent = latent.flatten(2).transpose(1, 2) # BCHW -> BNC 154 | if self.layer_norm: 155 | latent = self.norm(latent) 156 | if self.pos_embed is None: 157 | return latent.to(latent.dtype) 158 | pos_embed = self.pos_embed 159 | b = latent.shape[0] // self.num_views 160 | view = self.num_views 161 | latent = rearrange(latent, "(b v) n m -> b v n m", b=b, v=view) 162 | latent = latent + pos_embed.to(latent.dtype) 163 | latent = rearrange(latent, "b v n m -> (b v) n m", b=b, v=view) 164 | return latent 165 | 166 | 167 | class MultiViewSD3Transformer(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin): 168 | """ 169 | The `MultiViewSD3Transformer` model is a multi-view extension of the `StableDiffusion3` model. 
The model is 170 | """ 171 | 172 | _supports_gradient_checkpointing = True 173 | 174 | @register_to_config 175 | def __init__( 176 | self, 177 | sample_size: int = 128, 178 | patch_size: int = 2, 179 | in_channels: int = 16, 180 | num_layers: int = 18, 181 | attention_head_dim: int = 64, 182 | num_attention_heads: int = 18, 183 | joint_attention_dim: int = 4096, 184 | caption_projection_dim: int = 1152, 185 | pooled_projection_dim: int = 2048, 186 | out_channels: int = 16, 187 | pos_embed_max_size: int = 96, 188 | num_views: int = 8, 189 | ): 190 | super().__init__() 191 | default_out_channels = in_channels 192 | self.out_channels = out_channels if out_channels is not None else default_out_channels 193 | self.inner_dim = self.config.num_attention_heads * self.config.attention_head_dim 194 | 195 | self.pos_embed = PatchEmbedMultiView( 196 | height=self.config.sample_size, 197 | width=self.config.sample_size, 198 | patch_size=self.config.patch_size, 199 | in_channels=16, 200 | embed_dim=self.inner_dim, 201 | pos_embed_max_size=pos_embed_max_size, # hard-code for now. 202 | num_views=num_views, 203 | ) 204 | self.time_text_embed = CombinedTimestepTextProjEmbeddings( 205 | embedding_dim=self.inner_dim, pooled_projection_dim=self.config.pooled_projection_dim 206 | ) 207 | 208 | self.context_embedder = nn.Linear(self.config.joint_attention_dim, self.config.caption_projection_dim) 209 | 210 | # `attention_head_dim` is doubled to account for the mixing. 211 | # It needs to crafted when we get the actual checkpoints. 212 | self.transformer_blocks = nn.ModuleList( 213 | [ 214 | JointTransformerBlockMultiView( 215 | dim=self.inner_dim, 216 | num_attention_heads=self.config.num_attention_heads, 217 | attention_head_dim=self.inner_dim, 218 | context_pre_only=i == num_layers - 1, 219 | num_views=num_views, 220 | ) 221 | for i in range(self.config.num_layers) 222 | ] 223 | ) 224 | self.patch_size = patch_size 225 | self.norm_out = AdaLayerNormContinuous(self.inner_dim, self.inner_dim, elementwise_affine=False, eps=1e-6) 226 | self.proj_out = nn.Linear(self.inner_dim, patch_size * patch_size * self.out_channels, bias=True) 227 | self.gradient_checkpointing = False 228 | 229 | def freeze(self): 230 | for param in self.parameters(): 231 | param.requires_grad = False 232 | 233 | def unfreeze_view(self): 234 | for blk in self.transformer_blocks: 235 | for param in blk.parameters(): 236 | param.requires_grad = True 237 | 238 | def adjust_output_input_channel_size(self, new_in_channels: int): 239 | self.config.in_channels = new_in_channels 240 | self.out_channels = new_in_channels 241 | old_conv_layer = self.pos_embed.proj 242 | 243 | # Calculate scaling factor 244 | scaling_factor_conv = (old_conv_layer.in_channels / new_in_channels) ** 0.5 245 | 246 | # Create a new convolutional layer with the desired number of input channels 247 | new_conv_layer = nn.Conv2d( 248 | in_channels=new_in_channels, 249 | out_channels=old_conv_layer.out_channels, 250 | kernel_size=old_conv_layer.kernel_size, 251 | stride=old_conv_layer.stride, 252 | padding=old_conv_layer.padding, 253 | dilation=old_conv_layer.dilation, 254 | groups=old_conv_layer.groups, 255 | bias=old_conv_layer.bias is not None, 256 | ) 257 | 258 | with torch.no_grad(): 259 | channels_to_copy = min(old_conv_layer.in_channels, new_in_channels) 260 | 261 | # Copy existing weights 262 | new_conv_layer.weight.data[:, :channels_to_copy, :, :] = old_conv_layer.weight.data[ 263 | :, :channels_to_copy, :, : 264 | ] 265 | 266 | # Copy existing weights to new input 
channels via modulo indexing 267 | for i in range(channels_to_copy, new_in_channels): 268 | idx = i % old_conv_layer.in_channels 269 | new_conv_layer.weight.data[:, i : i + 1, :, :] = old_conv_layer.weight.data[:, idx : idx + 1, :, :] 270 | 271 | # Scale the weights 272 | new_conv_layer.weight.mul_(scaling_factor_conv) 273 | 274 | # Copy bias if it exists 275 | if old_conv_layer.bias is not None: 276 | new_conv_layer.bias.data = old_conv_layer.bias.data 277 | 278 | # Replace the old convolutional layer with the new one 279 | self.pos_embed.proj = new_conv_layer 280 | 281 | # Output layer modification 282 | old_linear_layer = self.proj_out # Get the original Linear layer 283 | 284 | # Calculate the new output features for the Linear layer 285 | new_out_features = self.patch_size * self.patch_size * new_in_channels 286 | 287 | # Calculate scaling factor for the linear layer 288 | scaling_factor_linear = (old_linear_layer.out_features / new_out_features) ** 0.5 289 | 290 | # Create a new Linear layer with the desired output channels 291 | new_linear_layer = nn.Linear( 292 | old_linear_layer.in_features, # Keep the input features the same 293 | new_out_features, # New number of output features 294 | bias=old_linear_layer.bias is not None, 295 | ) 296 | 297 | with torch.no_grad(): 298 | features_to_copy = min(old_linear_layer.out_features, new_out_features) 299 | 300 | # Copy existing weights 301 | new_linear_layer.weight.data[:features_to_copy, :] = old_linear_layer.weight.data[:features_to_copy, :] 302 | 303 | # Copy existing weights to new outputs via modulo indexing 304 | for i in range(features_to_copy, new_out_features): 305 | idx = i % old_linear_layer.out_features 306 | new_linear_layer.weight.data[i, :] = old_linear_layer.weight.data[idx, :] 307 | 308 | # Scale the weights 309 | new_linear_layer.weight.mul_(scaling_factor_linear) 310 | 311 | # Copy existing biases 312 | if old_linear_layer.bias is not None: 313 | new_linear_layer.bias.data[:features_to_copy] = old_linear_layer.bias.data[:features_to_copy] 314 | 315 | # Copy biases via modulo indexing for new outputs 316 | for i in range(features_to_copy, new_out_features): 317 | idx = i % old_linear_layer.out_features 318 | new_linear_layer.bias.data[i] = old_linear_layer.bias.data[idx] 319 | 320 | # Replace the old Linear layer with the new one 321 | self.proj_out = new_linear_layer 322 | 323 | # Copied from diffusers.models.unets.unet_3d_condition.UNet3DConditionModel.enable_forward_chunking 324 | def enable_forward_chunking(self, chunk_size: Optional[int] = None, dim: int = 0) -> None: 325 | """ 326 | Sets the attention processor to use [feed forward 327 | chunking](https://huggingface.co/blog/reformer#2-chunked-feed-forward-layers). 328 | 329 | Parameters: 330 | chunk_size (`int`, *optional*): 331 | The chunk size of the feed-forward layers. If not specified, will run feed-forward layer individually 332 | over each tensor of dim=`dim`. 333 | dim (`int`, *optional*, defaults to `0`): 334 | The dimension over which the feed-forward computation should be chunked. Choose between dim=0 (batch) 335 | or dim=1 (sequence length). 
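
        Example (illustrative call):
            model.enable_forward_chunking(chunk_size=1, dim=1)  # chunk the feed-forward over the sequence axis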
336 | """ 337 | if dim not in [0, 1]: 338 | raise ValueError(f"Make sure to set `dim` to either 0 or 1, not {dim}") 339 | 340 | # By default chunk size is 1 341 | chunk_size = chunk_size or 1 342 | 343 | def fn_recursive_feed_forward(module: torch.nn.Module, chunk_size: int, dim: int): 344 | if hasattr(module, "set_chunk_feed_forward"): 345 | module.set_chunk_feed_forward(chunk_size=chunk_size, dim=dim) 346 | 347 | for child in module.children(): 348 | fn_recursive_feed_forward(child, chunk_size, dim) 349 | 350 | for module in self.children(): 351 | fn_recursive_feed_forward(module, chunk_size, dim) 352 | 353 | @property 354 | # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors 355 | def attn_processors(self) -> Dict[str, AttentionProcessor]: 356 | r""" 357 | Returns: 358 | `dict` of attention processors: A dictionary containing all attention processors used in the model with 359 | indexed by its weight name. 360 | """ 361 | # set recursively 362 | processors = {} 363 | 364 | def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]): 365 | if hasattr(module, "get_processor"): 366 | processors[f"{name}.processor"] = module.get_processor(return_deprecated_lora=True) 367 | 368 | for sub_name, child in module.named_children(): 369 | fn_recursive_add_processors(f"{name}.{sub_name}", child, processors) 370 | 371 | return processors 372 | 373 | for name, module in self.named_children(): 374 | fn_recursive_add_processors(name, module, processors) 375 | 376 | return processors 377 | 378 | # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor 379 | def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]): 380 | r""" 381 | Sets the attention processor to use to compute attention. 382 | 383 | Parameters: 384 | processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`): 385 | The instantiated processor class or a dictionary of processor classes that will be set as the processor 386 | for **all** `Attention` layers. 387 | 388 | If `processor` is a dict, the key needs to define the path to the corresponding cross attention 389 | processor. This is strongly recommended when setting trainable attention processors. 390 | 391 | """ 392 | count = len(self.attn_processors.keys()) 393 | 394 | if isinstance(processor, dict) and len(processor) != count: 395 | raise ValueError( 396 | f"A dict of processors was passed, but the number of processors {len(processor)} does not match the" 397 | f" number of attention layers: {count}. Please make sure to pass {count} processor classes." 398 | ) 399 | 400 | def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor): 401 | if hasattr(module, "set_processor"): 402 | if not isinstance(processor, dict): 403 | module.set_processor(processor) 404 | else: 405 | module.set_processor(processor.pop(f"{name}.processor")) 406 | 407 | for sub_name, child in module.named_children(): 408 | fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor) 409 | 410 | for name, module in self.named_children(): 411 | fn_recursive_attn_processor(name, module, processor) 412 | 413 | # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.fuse_qkv_projections 414 | def fuse_qkv_projections(self): 415 | """ 416 | Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query, key, value) 417 | are fused. 
For cross-attention modules, key and value projection matrices are fused. 418 | 419 | 420 | 421 | This API is 🧪 experimental. 422 | 423 | 424 | """ 425 | self.original_attn_processors = None 426 | 427 | for _, attn_processor in self.attn_processors.items(): 428 | if "Added" in str(attn_processor.__class__.__name__): 429 | raise ValueError("`fuse_qkv_projections()` is not supported for models having added KV projections.") 430 | 431 | self.original_attn_processors = self.attn_processors 432 | 433 | for module in self.modules(): 434 | if isinstance(module, Attention): 435 | module.fuse_projections(fuse=True) 436 | 437 | # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.unfuse_qkv_projections 438 | def unfuse_qkv_projections(self): 439 | """Disables the fused QKV projection if enabled. 440 | 441 | 442 | 443 | This API is 🧪 experimental. 444 | 445 | 446 | 447 | """ 448 | if self.original_attn_processors is not None: 449 | self.set_attn_processor(self.original_attn_processors) 450 | 451 | def _set_gradient_checkpointing(self, module, value=False): 452 | if hasattr(module, "gradient_checkpointing"): 453 | module.gradient_checkpointing = value 454 | 455 | def forward( 456 | self, 457 | hidden_states: torch.FloatTensor, 458 | encoder_hidden_states: torch.FloatTensor = None, 459 | pooled_projections: torch.FloatTensor = None, 460 | timestep: torch.LongTensor = None, 461 | block_controlnet_hidden_states: List = None, 462 | joint_attention_kwargs: Optional[Dict[str, Any]] = None, 463 | return_dict: bool = True, 464 | ) -> Union[torch.FloatTensor, Transformer2DModelOutput]: 465 | """ 466 | The [`SD3Transformer2DModel`] forward method. 467 | 468 | Args: 469 | hidden_states (`torch.FloatTensor` of shape `(batch size, channel, height, width)`): 470 | Input `hidden_states`. 471 | encoder_hidden_states (`torch.FloatTensor` of shape `(batch size, sequence_len, embed_dims)`): 472 | Conditional embeddings (embeddings computed from the input conditions such as prompts) to use. 473 | pooled_projections (`torch.FloatTensor` of shape `(batch_size, projection_dim)`): Embeddings projected 474 | from the embeddings of input conditions. 475 | timestep ( `torch.LongTensor`): 476 | Used to indicate denoising step. 477 | block_controlnet_hidden_states: (`list` of `torch.Tensor`): 478 | A list of tensors that if specified are added to the residuals of transformer blocks. 479 | joint_attention_kwargs (`dict`, *optional*): 480 | A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under 481 | `self.processor` in 482 | [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). 483 | return_dict (`bool`, *optional*, defaults to `True`): 484 | Whether or not to return a [`~models.transformer_2d.Transformer2DModelOutput`] instead of a plain 485 | tuple. 486 | 487 | Returns: 488 | If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a 489 | `tuple` where the first element is the sample tensor. 
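
        Example (illustrative; the names are placeholders, the view axis must equal `num_views`,
        and the latent resolution must match the hard-coded 16x16 patch grid of
        `PatchEmbedMultiView`, i.e. 32x32 latents with `patch_size=2`):

            sample = torch.randn(1, 8, 16, 32, 32)        # (batch, views, channels, height, width)
            out = model(
                sample,
                encoder_hidden_states=prompt_embeds,      # (batch, seq_len, joint_attention_dim)
                pooled_projections=pooled_embeds,         # (batch, pooled_projection_dim)
                timestep=timesteps,                       # (batch,)
            ).sample                                      # same 5D shape as `sample`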
490 | """ 491 | if joint_attention_kwargs is not None: 492 | joint_attention_kwargs = joint_attention_kwargs.copy() 493 | lora_scale = joint_attention_kwargs.pop("scale", 1.0) 494 | else: 495 | lora_scale = 1.0 496 | 497 | if USE_PEFT_BACKEND: 498 | # weight the lora layers by setting `lora_scale` for each PEFT layer 499 | scale_lora_layers(self, lora_scale) 500 | else: 501 | if joint_attention_kwargs is not None and joint_attention_kwargs.get("scale", None) is not None: 502 | logger.warning( 503 | "Passing `scale` via `joint_attention_kwargs` when not using the PEFT backend is ineffective." 504 | ) 505 | 506 | b, view, channel, height, width = hidden_states.shape 507 | # 2Dfy 508 | hidden_states = hidden_states.reshape(b * view, channel, height, width) 509 | 510 | hidden_states = self.pos_embed(hidden_states) 511 | 512 | temb = self.time_text_embed(timestep, pooled_projections) 513 | encoder_hidden_states = self.context_embedder(encoder_hidden_states) 514 | 515 | # Multi-view Patch 516 | temb = temb.unsqueeze(1).repeat(1, view, 1) 517 | encoder_hidden_states = encoder_hidden_states.unsqueeze(1).repeat(1, view, 1, 1) 518 | 519 | # 2dfy 520 | temb = temb.reshape(b * view, temb.shape[-1]) 521 | encoder_hidden_states = encoder_hidden_states.reshape( 522 | b * view, encoder_hidden_states.shape[-2], encoder_hidden_states.shape[-1] 523 | ) 524 | 525 | for index_block, block in enumerate(self.transformer_blocks): 526 | if self.training and self.gradient_checkpointing: 527 | 528 | def create_custom_forward(module, return_dict=None): 529 | def custom_forward(*inputs): 530 | if return_dict is not None: 531 | return module(*inputs, return_dict=return_dict) 532 | else: 533 | return module(*inputs) 534 | 535 | return custom_forward 536 | 537 | ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} 538 | encoder_hidden_states, hidden_states = torch.utils.checkpoint.checkpoint( 539 | create_custom_forward(block), 540 | hidden_states, 541 | encoder_hidden_states, 542 | temb, 543 | **ckpt_kwargs, 544 | ) 545 | 546 | else: 547 | encoder_hidden_states, hidden_states = block( 548 | hidden_states=hidden_states, encoder_hidden_states=encoder_hidden_states, temb=temb 549 | ) 550 | 551 | # controlnet residual 552 | if block_controlnet_hidden_states is not None and block.context_pre_only is False: 553 | interval_control = len(self.transformer_blocks) // len(block_controlnet_hidden_states) 554 | hidden_states = hidden_states + block_controlnet_hidden_states[index_block // interval_control] 555 | 556 | hidden_states = self.norm_out(hidden_states, temb) 557 | hidden_states = self.proj_out(hidden_states) 558 | 559 | # unpatchify 560 | patch_size = self.config.patch_size 561 | height = height // patch_size 562 | width = width // patch_size 563 | 564 | hidden_states = hidden_states.reshape( 565 | shape=(hidden_states.shape[0], height, width, patch_size, patch_size, self.out_channels) 566 | ) 567 | hidden_states = torch.einsum("nhwpqc->nchpwq", hidden_states) 568 | output = hidden_states.reshape( 569 | shape=(hidden_states.shape[0], self.out_channels, height * patch_size, width * patch_size) 570 | ) 571 | 572 | output = output.reshape(b, view, self.out_channels, height * patch_size, width * patch_size) 573 | if USE_PEFT_BACKEND: 574 | # remove `lora_scale` from each PEFT layer 575 | unscale_lora_layers(self, lora_scale) 576 | 577 | if not return_dict: 578 | return (output,) 579 | 580 | return Transformer2DModelOutput(sample=output) 581 | 
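
# Minimal sketch of the view-mixing trick used by JointTransformerBlockMultiView above: per-view
# token sequences are folded out of the batch dimension so that a single joint-attention call sees
# the tokens of all V views at once, then folded back for the per-view feed-forward layers. The
# shapes (B, V, N, D) below are illustrative and not taken from an actual checkpoint.
if __name__ == "__main__":
    B, V, N, D = 2, 8, 256, 64           # batch, views, tokens per view, hidden dim
    tokens = torch.randn(B * V, N, D)    # multi-view latents flattened into the batch dim

    # Fold the view axis into the sequence axis -> attention spans all views jointly.
    joint = rearrange(tokens, "(b v) n d -> b (v n) d", b=B, v=V)
    assert joint.shape == (B, V * N, D)

    # ... in the real block, the joint attention over the (V * N)-long sequence runs here ...

    # Fold back so the per-view feed-forward layers see (B * V, N, D) again.
    per_view = rearrange(joint, "b (v n) d -> (b v) n d", v=V)
    assert torch.equal(per_view, tokens)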
-------------------------------------------------------------------------------- /model/multiview_rf/text_embedding.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | def tokenize_prompt(tokenizer, prompt): 5 | text_inputs = tokenizer( 6 | prompt, 7 | padding="max_length", 8 | max_length=77, 9 | truncation=True, 10 | return_tensors="pt", 11 | ) 12 | text_input_ids = text_inputs.input_ids 13 | return text_input_ids 14 | 15 | 16 | def _encode_prompt_with_t5( 17 | text_encoder, 18 | tokenizer, 19 | max_sequence_length, 20 | prompt=None, 21 | num_images_per_prompt=1, 22 | device=None, 23 | ): 24 | prompt = [prompt] if isinstance(prompt, str) else prompt 25 | batch_size = len(prompt) 26 | 27 | text_inputs = tokenizer( 28 | prompt, 29 | padding="max_length", 30 | max_length=max_sequence_length, 31 | truncation=True, 32 | add_special_tokens=True, 33 | return_tensors="pt", 34 | ) 35 | text_input_ids = text_inputs.input_ids 36 | prompt_embeds = text_encoder(text_input_ids.to(device))[0] 37 | 38 | dtype = text_encoder.dtype 39 | prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) 40 | 41 | _, seq_len, _ = prompt_embeds.shape 42 | 43 | # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method 44 | prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) 45 | prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) 46 | 47 | return prompt_embeds 48 | 49 | 50 | def _encode_prompt_with_clip( 51 | text_encoder, 52 | tokenizer, 53 | prompt: str, 54 | device=None, 55 | num_images_per_prompt: int = 1, 56 | ): 57 | prompt = [prompt] if isinstance(prompt, str) else prompt 58 | batch_size = len(prompt) 59 | 60 | text_inputs = tokenizer( 61 | prompt, 62 | padding="max_length", 63 | max_length=77, 64 | truncation=True, 65 | return_tensors="pt", 66 | ) 67 | 68 | text_input_ids = text_inputs.input_ids 69 | prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=True) 70 | 71 | pooled_prompt_embeds = prompt_embeds[0] 72 | prompt_embeds = prompt_embeds.hidden_states[-2] 73 | prompt_embeds = prompt_embeds.to(dtype=text_encoder.dtype, device=device) 74 | 75 | _, seq_len, _ = prompt_embeds.shape 76 | 77 | # duplicate text embeddings for each generation per prompt, using mps friendly method 78 | prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) 79 | prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) 80 | 81 | pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt, 1) 82 | pooled_prompt_embeds = pooled_prompt_embeds.view(batch_size * num_images_per_prompt, -1) 83 | 84 | return prompt_embeds, pooled_prompt_embeds 85 | 86 | 87 | def encode_prompt( 88 | text_encoders, 89 | tokenizers, 90 | prompt: str, 91 | max_sequence_length, 92 | device=None, 93 | num_images_per_prompt: int = 1, 94 | ): 95 | prompt = [prompt] if isinstance(prompt, str) else prompt 96 | 97 | clip_tokenizers = tokenizers[:2] 98 | clip_text_encoders = text_encoders[:2] 99 | 100 | clip_prompt_embeds_list = [] 101 | clip_pooled_prompt_embeds_list = [] 102 | for tokenizer, text_encoder in zip(clip_tokenizers, clip_text_encoders): 103 | prompt_embeds, pooled_prompt_embeds = _encode_prompt_with_clip( 104 | text_encoder=text_encoder, 105 | tokenizer=tokenizer, 106 | prompt=prompt, 107 | device=device if device is not None else text_encoder.device, 108 | num_images_per_prompt=num_images_per_prompt, 109 | ) 110 | 
clip_prompt_embeds_list.append(prompt_embeds) 111 | clip_pooled_prompt_embeds_list.append(pooled_prompt_embeds) 112 | 113 | clip_prompt_embeds = torch.cat(clip_prompt_embeds_list, dim=-1) 114 | pooled_prompt_embeds = torch.cat(clip_pooled_prompt_embeds_list, dim=-1) 115 | 116 | t5_prompt_embed = _encode_prompt_with_t5( 117 | text_encoders[-1], 118 | tokenizers[-1], 119 | max_sequence_length, 120 | prompt=prompt, 121 | num_images_per_prompt=num_images_per_prompt, 122 | device=device if device is not None else text_encoders[-1].device, 123 | ) 124 | 125 | clip_prompt_embeds = torch.nn.functional.pad( 126 | clip_prompt_embeds, (0, t5_prompt_embed.shape[-1] - clip_prompt_embeds.shape[-1]) 127 | ) 128 | prompt_embeds = torch.cat([clip_prompt_embeds, t5_prompt_embed], dim=-2) 129 | 130 | return prompt_embeds, pooled_prompt_embeds 131 | 132 | 133 | def compute_text_embeddings(prompt, text_encoders, tokenizers, max_sequence_length, device): 134 | with torch.no_grad(): 135 | prompt_embeds, pooled_prompt_embeds = encode_prompt(text_encoders, tokenizers, prompt, max_sequence_length) 136 | prompt_embeds = prompt_embeds.to(device) 137 | pooled_prompt_embeds = pooled_prompt_embeds.to(device) 138 | return prompt_embeds, pooled_prompt_embeds 139 | -------------------------------------------------------------------------------- /model/refiner/camera_util.py: -------------------------------------------------------------------------------- 1 | import math 2 | import os 3 | from io import BytesIO 4 | 5 | import imageio 6 | import matplotlib.pyplot as plt 7 | import numpy as np 8 | import torch 9 | import tqdm 10 | from einops import einsum, rearrange 11 | from plyfile import PlyData, PlyElement 12 | from scipy.spatial.transform import Rotation as R 13 | from torchvision.utils import save_image 14 | 15 | 16 | def inverse_sigmoid(x): 17 | return torch.log(x / (1 - x)) 18 | 19 | 20 | def quaternion_to_matrix(quaternions): 21 | """ 22 | Convert rotations given as quaternions to rotation matrices. 23 | Args: 24 | quaternions: quaternions with real part first, 25 | as tensor of shape (..., 4). 26 | Returns: 27 | Rotation matrices as tensor of shape (..., 3, 3). 28 | """ 29 | r, i, j, k = torch.unbind(quaternions, -1) 30 | two_s = 2.0 / (quaternions * quaternions).sum(-1) 31 | 32 | o = torch.stack( 33 | ( 34 | 1 - two_s * (j * j + k * k), 35 | two_s * (i * j - k * r), 36 | two_s * (i * k + j * r), 37 | two_s * (i * j + k * r), 38 | 1 - two_s * (i * i + k * k), 39 | two_s * (j * k - i * r), 40 | two_s * (i * k - j * r), 41 | two_s * (j * k + i * r), 42 | 1 - two_s * (i * i + j * j), 43 | ), 44 | -1, 45 | ) 46 | return o.reshape(quaternions.shape[:-1] + (3, 3)) 47 | 48 | 49 | def matrix_to_quaternion(M: torch.Tensor) -> torch.Tensor: 50 | """ 51 | Matrix-to-quaternion conversion method. Equation taken from 52 | https://www.euclideanspace.com/maths/geometry/rotations/conversions/matrixToQuaternion/index.htm 53 | Args: 54 | M: rotation matrices, (... x 3 x 3) 55 | Returns: 56 | q: quaternion of shape (... 
x 4) 57 | """ 58 | prefix_shape = M.shape[:-2] 59 | Ms = M.reshape(-1, 3, 3) 60 | 61 | trs = 1 + Ms[:, 0, 0] + Ms[:, 1, 1] + Ms[:, 2, 2] 62 | 63 | Qs = [] 64 | 65 | for i in range(Ms.shape[0]): 66 | M = Ms[i] 67 | tr = trs[i] 68 | if tr > 0: 69 | r = torch.sqrt(tr) / 2.0 70 | x = (M[2, 1] - M[1, 2]) / (4 * r) 71 | y = (M[0, 2] - M[2, 0]) / (4 * r) 72 | z = (M[1, 0] - M[0, 1]) / (4 * r) 73 | elif (M[0, 0] > M[1, 1]) and (M[0, 0] > M[2, 2]): 74 | S = torch.sqrt(1.0 + M[0, 0] - M[1, 1] - M[2, 2]) * 2 # S=4*qx 75 | r = (M[2, 1] - M[1, 2]) / S 76 | x = 0.25 * S 77 | y = (M[0, 1] + M[1, 0]) / S 78 | z = (M[0, 2] + M[2, 0]) / S 79 | elif M[1, 1] > M[2, 2]: 80 | S = torch.sqrt(1.0 + M[1, 1] - M[0, 0] - M[2, 2]) * 2 # S=4*qy 81 | r = (M[0, 2] - M[2, 0]) / S 82 | x = (M[0, 1] + M[1, 0]) / S 83 | y = 0.25 * S 84 | z = (M[1, 2] + M[2, 1]) / S 85 | else: 86 | S = torch.sqrt(1.0 + M[2, 2] - M[0, 0] - M[1, 1]) * 2 # S=4*qz 87 | r = (M[1, 0] - M[0, 1]) / S 88 | x = (M[0, 2] + M[2, 0]) / S 89 | y = (M[1, 2] + M[2, 1]) / S 90 | z = 0.25 * S 91 | Q = torch.stack([r, x, y, z], dim=-1) 92 | Qs += [Q] 93 | 94 | return torch.stack(Qs, dim=0).reshape(*prefix_shape, 4) 95 | 96 | 97 | @torch.amp.autocast("cuda", enabled=False) 98 | def quaternion_slerp(q0, q1, fraction, spin: int = 0, shortestpath: bool = True): 99 | """Return spherical linear interpolation between two quaternions. 100 | Args: 101 | quat0: first quaternion 102 | quat1: second quaternion 103 | fraction: how much to interpolate between quat0 vs quat1 (if 0, closer to quat0; if 1, closer to quat1) 104 | spin: how much of an additional spin to place on the interpolation 105 | shortestpath: whether to return the short or long path to rotation 106 | """ 107 | d = (q0 * q1).sum(-1) 108 | if shortestpath: 109 | # invert rotation 110 | d[d < 0.0] = -d[d < 0.0] 111 | q1[d < 0.0] = q1[d < 0.0] 112 | 113 | d = d.clamp(0, 1.0) 114 | 115 | angle = torch.acos(d) + spin * math.pi 116 | isin = 1.0 / (torch.sin(angle) + 1e-10) 117 | q0_ = q0 * torch.sin((1.0 - fraction) * angle) * isin 118 | q1_ = q1 * torch.sin(fraction * angle) * isin 119 | 120 | q = q0_ + q1_ 121 | q[angle < 1e-5, :] = q0 122 | 123 | return q 124 | 125 | 126 | def sample_from_two_pose(pose_a, pose_b, fraction, noise_strengths=[0, 0]): 127 | """ 128 | Args: 129 | pose_a: first pose 130 | pose_b: second pose 131 | fraction 132 | """ 133 | 134 | quat_a = matrix_to_quaternion(pose_a[..., :3, :3]) 135 | quat_b = matrix_to_quaternion(pose_b[..., :3, :3]) 136 | 137 | quaternion = quaternion_slerp(quat_a, quat_b, fraction) 138 | quaternion = torch.nn.functional.normalize(quaternion + torch.randn_like(quaternion) * noise_strengths[0], dim=-1) 139 | 140 | R = quaternion_to_matrix(quaternion) 141 | T = (1 - fraction) * pose_a[..., :3, 3] + fraction * pose_b[..., :3, 3] 142 | T = T + torch.randn_like(T) * noise_strengths[1] 143 | 144 | new_pose = pose_a.clone() 145 | new_pose[..., :3, :3] = R 146 | new_pose[..., :3, 3] = T 147 | return new_pose 148 | 149 | 150 | def sample_from_dense_cameras(dense_cameras, t, noise_strengths=[0, 0, 0, 0]): 151 | _, N, A, B = dense_cameras.shape 152 | _, M = t.shape 153 | 154 | t = t.to(dense_cameras.device) 155 | left = torch.floor(t * (N - 1)).long().clamp(0, N - 2) 156 | right = left + 1 157 | fraction = t * (N - 1) - left 158 | a = torch.gather(dense_cameras, 1, left[..., None].repeat(1, 1, A, B)) 159 | b = torch.gather(dense_cameras, 1, right[..., None].repeat(1, 1, A, B)) 160 | 161 | new_pose = sample_from_two_pose(a[:, :, :3, 3:], b[:, :, :3, 3:], fraction, 
noise_strengths=noise_strengths[:2]) 162 | 163 | new_ins = (1 - fraction) * a[:, :, :3, :3] + fraction * b[:, :, :3, :3] 164 | 165 | return torch.cat([new_ins, new_pose], dim=-1) 166 | 167 | 168 | def export_ply_for_gaussians(path, gaussians): 169 | xyz, features, opacity, scales, rotations = gaussians 170 | 171 | means3D = xyz.contiguous().float() 172 | opacity = opacity.contiguous().float() 173 | scales = scales.contiguous().float() 174 | rotations = rotations.contiguous().float() 175 | shs = features.contiguous().float() # [N, 1, 3] 176 | 177 | SH_C0 = 0.28209479177387814 178 | means3D, rotations, shs = adjust_gaussians(means3D, rotations, shs, SH_C0, inverse=False) 179 | 180 | opacity = inverse_sigmoid(opacity) 181 | scales = torch.log(scales + 1e-8) 182 | 183 | xyzs = means3D.detach().cpu().numpy() 184 | f_dc = shs.detach().transpose(1, 2).flatten(start_dim=1).contiguous().cpu().numpy() 185 | opacities = opacity.detach().cpu().numpy() 186 | scales = scales.detach().cpu().numpy() 187 | rotations = rotations.detach().cpu().numpy() 188 | 189 | l = ["x", "y", "z"] # noqa: E741 190 | # All channels except the 3 DC 191 | for i in range(f_dc.shape[1]): 192 | l.append("f_dc_{}".format(i)) 193 | l.append("opacity") 194 | for i in range(scales.shape[1]): 195 | l.append("scale_{}".format(i)) 196 | for i in range(rotations.shape[1]): 197 | l.append("rot_{}".format(i)) 198 | 199 | dtype_full = [(attribute, "f4") for attribute in l] 200 | 201 | elements = np.empty(xyzs.shape[0], dtype=dtype_full) 202 | attributes = np.concatenate((xyzs, f_dc, opacities, scales, rotations), axis=1) 203 | elements[:] = list(map(tuple, attributes)) 204 | el = PlyElement.describe(elements, "vertex") 205 | 206 | PlyData([el]).write(path + '.ply') 207 | 208 | plydata = PlyData([el]) 209 | 210 | vert = plydata["vertex"] 211 | sorted_indices = np.argsort( 212 | -np.exp(vert["scale_0"] + vert["scale_1"] + vert["scale_2"]) / (1 + np.exp(-vert["opacity"])) 213 | ) 214 | buffer = BytesIO() 215 | for idx in sorted_indices: 216 | v = plydata["vertex"][idx] 217 | position = np.array([v["x"], v["y"], v["z"]], dtype=np.float32) 218 | scales = np.exp( 219 | np.array( 220 | [v["scale_0"], v["scale_1"], v["scale_2"]], 221 | dtype=np.float32, 222 | ) 223 | ) 224 | rot = np.array( 225 | [v["rot_0"], v["rot_1"], v["rot_2"], v["rot_3"]], 226 | dtype=np.float32, 227 | ) 228 | color = np.array( 229 | [ 230 | 0.5 + SH_C0 * v["f_dc_0"], 231 | 0.5 + SH_C0 * v["f_dc_1"], 232 | 0.5 + SH_C0 * v["f_dc_2"], 233 | 1 / (1 + np.exp(-v["opacity"])), 234 | ] 235 | ) 236 | buffer.write(position.tobytes()) 237 | buffer.write(scales.tobytes()) 238 | buffer.write((color * 255).clip(0, 255).astype(np.uint8).tobytes()) 239 | buffer.write(((rot / np.linalg.norm(rot)) * 128 + 128).clip(0, 255).astype(np.uint8).tobytes()) 240 | 241 | with open(path + '.splat', "wb") as f: 242 | f.write(buffer.getvalue()) 243 | 244 | 245 | def load_ply_for_gaussians(path, device="cpu"): 246 | plydata = PlyData.read(path) 247 | 248 | xyz = np.stack( 249 | ( 250 | np.asarray(plydata.elements[0]["x"]), 251 | np.asarray(plydata.elements[0]["y"]), 252 | np.asarray(plydata.elements[0]["z"]), 253 | ), 254 | axis=1, 255 | ) 256 | opacities = np.asarray(plydata.elements[0]["opacity"])[..., np.newaxis] 257 | 258 | print("Number of points at loading : ", xyz.shape[0]) 259 | 260 | features_dc = np.zeros((xyz.shape[0], 3, 1)) 261 | features_dc[:, 0, 0] = np.asarray(plydata.elements[0]["f_dc_0"]) 262 | features_dc[:, 1, 0] = np.asarray(plydata.elements[0]["f_dc_1"]) 263 | features_dc[:, 2, 0] 
= np.asarray(plydata.elements[0]["f_dc_2"]) 264 | 265 | scale_names = [p.name for p in plydata.elements[0].properties if p.name.startswith("scale_")] 266 | scales = np.zeros((xyz.shape[0], len(scale_names))) 267 | for idx, attr_name in enumerate(scale_names): 268 | scales[:, idx] = np.asarray(plydata.elements[0][attr_name]) 269 | 270 | rot_names = [p.name for p in plydata.elements[0].properties if p.name.startswith("rot")] 271 | rots = np.zeros((xyz.shape[0], len(rot_names))) 272 | for idx, attr_name in enumerate(rot_names): 273 | rots[:, idx] = np.asarray(plydata.elements[0][attr_name]) 274 | 275 | xyz = torch.tensor(xyz, dtype=torch.float, device=device)[None] 276 | features = torch.tensor(features_dc, dtype=torch.float, device=device).transpose(1, 2)[None] 277 | opacity = torch.tensor(opacities, dtype=torch.float, device=device)[None] 278 | scales = torch.tensor(scales, dtype=torch.float, device=device)[None] 279 | rotations = torch.tensor(rots, dtype=torch.float, device=device)[None] 280 | 281 | opacity = torch.sigmoid(opacity) 282 | scales = torch.exp(scales) 283 | 284 | SH_C0 = 0.28209479177387814 285 | xyz, rotations, features = adjust_gaussians(xyz, rotations, features, SH_C0, inverse=True) 286 | 287 | return xyz, features, opacity, scales, rotations 288 | 289 | 290 | def adjust_gaussians(means, rotations, shs, SH_C0, inverse): 291 | rot_adjust = torch.tensor( 292 | [ 293 | [0, 0, 1], 294 | [-1, 0, 0], 295 | [0, -1, 0], 296 | ], 297 | dtype=torch.float32, 298 | device=means.device, 299 | ) 300 | 301 | adjustment = torch.tensor( 302 | R.from_rotvec([0, 0, -45], True).as_matrix(), 303 | dtype=torch.float32, 304 | device=means.device, 305 | ) 306 | 307 | rot_adjust = adjustment @ rot_adjust 308 | 309 | if inverse: # load: convert wxyz --> xyzw (rotation), convert shs to precomputed color 310 | rot_adjust = rot_adjust.inverse() 311 | means = einsum(rot_adjust, means, "i j, ... j -> ... i") 312 | rotations = R.from_quat(rotations[0].detach().cpu().numpy(), scalar_first=True).as_matrix() 313 | rotations = rot_adjust.detach().cpu().numpy() @ rotations 314 | rotations = R.from_matrix(rotations).as_quat() 315 | rotations = torch.from_numpy(rotations)[None].to(dtype=torch.float32, device=means.device) 316 | shs = 0.5 + shs * SH_C0 317 | 318 | else: # export: convert xyzw --> wxyz (rotation), convert precomputed color to shs 319 | means = einsum(rot_adjust, means, "i j, ... j -> ... 
i") 320 | rotations = R.from_quat(rotations.detach().cpu().numpy()).as_matrix() 321 | rotations = rot_adjust.detach().cpu().numpy() @ rotations 322 | rotations = R.from_matrix(rotations).as_quat() 323 | x, y, z, w = rearrange(rotations, "g xyzw -> xyzw g") 324 | rotations = torch.from_numpy(np.stack((w, x, y, z), axis=-1)).to(torch.float32) 325 | shs = (shs - 0.5) / SH_C0 326 | 327 | return means, rotations, shs 328 | 329 | 330 | @torch.no_grad() 331 | def export_video(render_fn, save_path, name, dense_cameras, fps=60, num_frames=720, size=512, device="cuda:0"): 332 | images = [] 333 | depths = [] 334 | 335 | for i in tqdm.trange(num_frames, desc="Rendering video..."): 336 | t = torch.full((1, 1), fill_value=i / num_frames, device=device) 337 | 338 | camera = sample_from_dense_cameras(dense_cameras, t) 339 | 340 | image, depth = render_fn(camera, size, size) 341 | 342 | images.append(process_image(image.reshape(3, size, size))) 343 | depths.append(process_image(depth.reshape(1, size, size))) 344 | 345 | imageio.mimwrite(os.path.join(save_path, f"{name}.mp4"), images, fps=fps, quality=8, macro_block_size=1) 346 | 347 | 348 | def process_image(image): 349 | return image.permute(1, 2, 0).detach().cpu().mul(1 / 2).add(1 / 2).clamp(0, 1).mul(255).numpy().astype(np.uint8) 350 | 351 | 352 | @torch.no_grad() 353 | def export_mv(render_fn, save_path, dense_cameras, size=256): 354 | num_views = dense_cameras.shape[1] 355 | imgs = [] 356 | for i in tqdm.trange(num_views, desc="Rendering images..."): 357 | image = render_fn(dense_cameras[:, i].unsqueeze(1), size, size)[0] 358 | path = os.path.join(save_path, "mv_results") 359 | os.makedirs(path, exist_ok=True) 360 | 361 | path = os.path.join(path, f"refined_render_img_{i}.png") 362 | image = image.reshape(3, size, size).clamp(-1, 1).add(1).mul(1 / 2) 363 | imgs.append(image) 364 | save_image(image, path) 365 | 366 | cmap = plt.get_cmap("hsv") 367 | num_frames = 8 368 | num_rows = 2 369 | num_cols = 4 370 | figsize = (num_cols * 2, num_rows * 2) 371 | fig, axs = plt.subplots(num_rows, num_cols, figsize=figsize) 372 | axs = axs.flatten() 373 | for i in range(num_rows * num_cols): 374 | if i < num_frames: 375 | axs[i].imshow((imgs[i].cpu().numpy().transpose(1, 2, 0) * 255.0).astype(np.uint8)) 376 | for s in ["bottom", "top", "left", "right"]: 377 | axs[i].spines[s].set_color(cmap(i / (num_frames))) 378 | axs[i].spines[s].set_linewidth(5) 379 | axs[i].set_xticks([]) 380 | axs[i].set_yticks([]) 381 | else: 382 | axs[i].axis("off") 383 | plt.tight_layout() 384 | plt.savefig(os.path.join(save_path, "refined_mv_images.pdf"), transparent=True) 385 | -------------------------------------------------------------------------------- /model/refiner/gs_util.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from diff_gaussian_rasterization import ( 4 | GaussianRasterizationSettings, 5 | GaussianRasterizer, 6 | ) 7 | from einops import rearrange 8 | 9 | from model.gsdecoder.cuda_splatting import get_fov, get_projection_matrix 10 | 11 | 12 | def inverse_sigmoid(x): 13 | return torch.log(x / (1 - x)) 14 | 15 | 16 | def inverse_softplus(x, beta=1): 17 | return (torch.exp(beta * x) - 1).log() / beta 18 | 19 | 20 | def build_rotation(r): # Note that we follow xyzw format. 
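    # Converts (x, y, z, w) quaternions to 3x3 rotation matrices; densify_and_split uses these to
    # rotate the sampled offsets, while GaussianRenderer.forward reorders to (w, x, y, z) via
    # rotations[:, [3, 0, 1, 2]] before passing them to the CUDA rasterizer.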
21 | norm = torch.sqrt(r[:, 0] * r[:, 0] + r[:, 1] * r[:, 1] + r[:, 2] * r[:, 2] + r[:, 3] * r[:, 3]) 22 | 23 | q = r / norm[:, None] 24 | 25 | R = torch.zeros((q.size(0), 3, 3), device=r.device) 26 | 27 | x = q[:, 0] 28 | y = q[:, 1] 29 | z = q[:, 2] 30 | r = q[:, 3] 31 | 32 | R[:, 0, 0] = 1 - 2 * (y * y + z * z) 33 | R[:, 0, 1] = 2 * (x * y - r * z) 34 | R[:, 0, 2] = 2 * (x * z + r * y) 35 | R[:, 1, 0] = 2 * (x * y + r * z) 36 | R[:, 1, 1] = 1 - 2 * (x * x + z * z) 37 | R[:, 1, 2] = 2 * (y * z - r * x) 38 | R[:, 2, 0] = 2 * (x * z - r * y) 39 | R[:, 2, 1] = 2 * (y * z + r * x) 40 | R[:, 2, 2] = 1 - 2 * (x * x + y * y) 41 | return R 42 | 43 | 44 | def rgb2shs(rgb): 45 | SH_C0 = 0.28209479177387814 46 | return (rgb - 0.5) / SH_C0 47 | 48 | 49 | def shstorgb(shs): 50 | SH_C0 = 0.28209479177387814 51 | return (shs * SH_C0 + 0.5).clamp(min=0) 52 | 53 | 54 | class GaussiansManeger: 55 | def __init__(self, xyz, features, opacity, scales, rotations, lrs): 56 | self._xyz = nn.Parameter(xyz.squeeze(0).contiguous().float().detach().clone().requires_grad_(True)) 57 | self._features = nn.Parameter( 58 | rgb2shs(features).squeeze(0).contiguous().float().detach().clone().requires_grad_(True) 59 | ) 60 | self._opacity = nn.Parameter( 61 | inverse_sigmoid(opacity.squeeze(0)).contiguous().float().detach().clone().requires_grad_(True) 62 | ) 63 | self._scales = nn.Parameter( 64 | torch.log(scales.squeeze(0) + 1e-8).contiguous().float().detach().clone().requires_grad_(True) 65 | ) 66 | self._rotations = nn.Parameter(rotations.squeeze(0).contiguous().float().detach().clone().requires_grad_(True)) 67 | 68 | self.device = self._xyz.device 69 | 70 | self.optimizer = torch.optim.Adam( 71 | [ 72 | {"name": "xyz", "params": [self._xyz], "lr": lrs["xyz"]}, 73 | {"name": "features", "params": [self._features], "lr": lrs["features"]}, 74 | {"name": "opacity", "params": [self._opacity], "lr": lrs["opacity"]}, 75 | {"name": "scales", "params": [self._scales], "lr": lrs["scales"]}, 76 | {"name": "rotations", "params": [self._rotations], "lr": lrs["rotations"]}, 77 | ], 78 | betas=(0.9, 0.99), 79 | lr=0.0, 80 | eps=1e-15, 81 | ) 82 | 83 | self.xyz_gradient_accum = torch.zeros((self._xyz.shape[0], 1), device=self.device) 84 | self.denom = torch.zeros((self._xyz.shape[0], 1), device=self.device) 85 | self.max_radii2D = torch.zeros((self._xyz.shape[0],), device=self.device) 86 | self.is_visible = torch.zeros((self._xyz.shape[0],), device=self.device) 87 | 88 | self.percent_dense = 0.003 89 | 90 | def __call__(self): 91 | xyz = self._xyz 92 | features = shstorgb(self._features) 93 | opacity = torch.sigmoid(self._opacity) 94 | scales = torch.exp(self._scales) 95 | rotations = torch.nn.functional.normalize(self._rotations, dim=-1) 96 | return ( 97 | xyz.unsqueeze(0), 98 | features.unsqueeze(0), 99 | opacity.unsqueeze(0), 100 | scales.unsqueeze(0), 101 | rotations.unsqueeze(0), 102 | ) 103 | 104 | @torch.no_grad() 105 | def densify_and_prune(self, max_grad=4, extent=2, opacity_threshold=0.001): 106 | grads = self.xyz_gradient_accum / self.denom 107 | grads[grads.isnan()] = 0.0 108 | 109 | self.densify_and_clone(grads, max_grad, scene_extent=extent) 110 | self.densify_and_split(grads, max_grad, scene_extent=extent) 111 | 112 | prune_mask = (torch.sigmoid(self._opacity) < opacity_threshold).squeeze() 113 | self.prune_points(prune_mask) 114 | torch.cuda.empty_cache() 115 | 116 | def prune_points(self, mask): 117 | valid_points_mask = ~mask 118 | 119 | optimizable_tensors = self.prune_optimizer(valid_points_mask) 120 | self._xyz = 
optimizable_tensors["xyz"] 121 | self._features = optimizable_tensors["features"] 122 | self._opacity = optimizable_tensors["opacity"] 123 | self._scales = optimizable_tensors["scales"] 124 | self._rotations = optimizable_tensors["rotations"] 125 | 126 | self.xyz_gradient_accum = self.xyz_gradient_accum[valid_points_mask] 127 | self.denom = self.denom[valid_points_mask] 128 | self.max_radii2D = self.max_radii2D[valid_points_mask] 129 | self.is_visible = self.is_visible[valid_points_mask] 130 | 131 | def prune_optimizer(self, mask): 132 | optimizable_tensors = {} 133 | for group in self.optimizer.param_groups: 134 | stored_state = self.optimizer.state.get(group["params"][0], None) 135 | if stored_state is not None: 136 | stored_state["exp_avg"] = stored_state["exp_avg"][mask] 137 | stored_state["exp_avg_sq"] = stored_state["exp_avg_sq"][mask] 138 | 139 | del self.optimizer.state[group["params"][0]] 140 | group["params"][0] = nn.Parameter((group["params"][0][mask].requires_grad_(True))) 141 | self.optimizer.state[group["params"][0]] = stored_state 142 | 143 | optimizable_tensors[group["name"]] = group["params"][0] 144 | else: 145 | group["params"][0] = nn.Parameter(group["params"][0][mask].requires_grad_(True)) 146 | optimizable_tensors[group["name"]] = group["params"][0] 147 | return optimizable_tensors 148 | 149 | def add_points(self, params): 150 | num_points = params["xyz"].shape[0] 151 | 152 | optimizable_tensors = self.cat_tensors_to_optimizer(params) 153 | self._xyz = optimizable_tensors["xyz"] 154 | self._features = optimizable_tensors["features"] 155 | self._opacity = optimizable_tensors["opacity"] 156 | self._scales = optimizable_tensors["scales"] 157 | self._rotations = optimizable_tensors["rotations"] 158 | 159 | self.xyz_gradient_accum = torch.cat([self.xyz_gradient_accum, torch.zeros((num_points, 1), device=self.device)]) 160 | self.denom = torch.cat([self.denom, torch.zeros((num_points, 1), device=self.device)]) 161 | self.max_radii2D = torch.cat([self.max_radii2D, torch.zeros((num_points,), device=self.device)]) 162 | self.is_visible = torch.cat([self.is_visible, torch.zeros((num_points,), device=self.device)]) 163 | 164 | def cat_tensors_to_optimizer(self, tensors_dict): 165 | optimizable_tensors = {} 166 | for group in self.optimizer.param_groups: 167 | assert len(group["params"]) == 1 168 | extension_tensor = tensors_dict[group["name"]] 169 | stored_state = self.optimizer.state.get(group["params"][0], None) 170 | if stored_state is not None: 171 | stored_state["exp_avg"] = torch.cat( 172 | (stored_state["exp_avg"], torch.zeros_like(extension_tensor)), dim=0 173 | ) 174 | stored_state["exp_avg_sq"] = torch.cat( 175 | (stored_state["exp_avg_sq"], torch.zeros_like(extension_tensor)), dim=0 176 | ) 177 | 178 | del self.optimizer.state[group["params"][0]] 179 | group["params"][0] = nn.Parameter( 180 | torch.cat((group["params"][0], extension_tensor), dim=0).requires_grad_(True) 181 | ) 182 | self.optimizer.state[group["params"][0]] = stored_state 183 | 184 | optimizable_tensors[group["name"]] = group["params"][0] 185 | else: 186 | group["params"][0] = nn.Parameter( 187 | torch.cat((group["params"][0], extension_tensor), dim=0).requires_grad_(True) 188 | ) 189 | optimizable_tensors[group["name"]] = group["params"][0] 190 | 191 | return optimizable_tensors 192 | 193 | def densify_and_split(self, grads, grad_threshold=0.04, scene_extent=2, N=2): 194 | n_init_points = self._xyz.shape[0] 195 | 196 | # Extract points that satisfy the gradient condition 197 | padded_grad = 
torch.zeros((n_init_points), device=self.device) 198 | padded_grad[: grads.shape[0]] = grads.squeeze() 199 | selected_pts_mask = torch.where(padded_grad >= grad_threshold, True, False) 200 | selected_pts_mask = torch.logical_and( 201 | selected_pts_mask, torch.max(torch.exp(self._scales), dim=1).values > self.percent_dense * scene_extent 202 | ) 203 | 204 | stds = torch.exp(self._scales[selected_pts_mask]).repeat(N, 1) 205 | means = torch.zeros((stds.size(0), 3), device=self.device) 206 | samples = torch.normal(mean=means, std=stds) 207 | 208 | rots = build_rotation(self._rotations[selected_pts_mask]).repeat(N, 1, 1) 209 | 210 | new_xyz = torch.bmm(rots, samples.unsqueeze(-1)).squeeze(-1) + self._xyz[selected_pts_mask].repeat(N, 1) 211 | new_scales = torch.log(torch.exp(self._scales[selected_pts_mask]) / (0.8 * N)).repeat(N, 1) 212 | new_rotations = self._rotations[selected_pts_mask].repeat(N, 1) 213 | new_features = self._features[selected_pts_mask].repeat(N, 1, 1) 214 | new_opacity = self._opacity[selected_pts_mask].repeat(N, 1) 215 | 216 | params = { 217 | "xyz": new_xyz, 218 | "features": new_features, 219 | "opacity": new_opacity, 220 | "scales": new_scales, 221 | "rotations": new_rotations, 222 | } 223 | 224 | self.add_points(params) 225 | 226 | prune_filter = torch.cat( 227 | (selected_pts_mask, torch.zeros(N * selected_pts_mask.sum(), device=self.device, dtype=bool)) 228 | ) 229 | self.prune_points(prune_filter) 230 | 231 | def densify_and_clone(self, grads, grad_threshold=0.02, scene_extent=2): 232 | # Extract points that satisfy the gradient condition 233 | selected_pts_mask = torch.where(torch.norm(grads, dim=-1) >= grad_threshold, True, False) 234 | selected_pts_mask = torch.logical_and( 235 | selected_pts_mask, torch.max(torch.exp(self._scales), dim=1).values <= self.percent_dense * scene_extent 236 | ) 237 | 238 | new_xyz = self._xyz[selected_pts_mask] 239 | new_features = self._features[selected_pts_mask] 240 | new_opacity = self._opacity[selected_pts_mask] 241 | new_scales = self._scales[selected_pts_mask] 242 | new_rotations = self._rotations[selected_pts_mask] 243 | 244 | params = { 245 | "xyz": new_xyz, 246 | "features": new_features, 247 | "opacity": new_opacity, 248 | "scales": new_scales, 249 | "rotations": new_rotations, 250 | } 251 | 252 | self.add_points(params) 253 | 254 | def add_densification_stats(self, viewspace_point_tensor, update_filter): 255 | self.xyz_gradient_accum[update_filter] += torch.norm( 256 | viewspace_point_tensor.grad[update_filter, :2], dim=-1, keepdim=True 257 | ) 258 | self.denom[update_filter] += 1 259 | 260 | 261 | class GaussianRenderer(nn.Module): 262 | def __init__(self, h, w): 263 | super().__init__() 264 | 265 | self.h = h 266 | self.w = w 267 | self.near = 0.001 268 | self.far = 100 269 | 270 | def get_viewpoint_cameras(self, cameras): 271 | device = cameras.device 272 | K, Rt = cameras.split([3, 4], dim=-1) 273 | 274 | normalized_K = K.clone() 275 | normalized_K[0] *= 1 / (normalized_K[0, 2] * 2) # as cx cy = h/2, w/2 276 | normalized_K[1] *= 1 / (normalized_K[1, 2] * 2) # as cx cy = h/2, w/2 277 | 278 | fov_x, fov_y = get_fov(normalized_K[None]).unbind(dim=-1) 279 | tan_fov_x = (0.5 * fov_x).tan() 280 | tan_fov_y = (0.5 * fov_y).tan() 281 | 282 | projection_matrix = get_projection_matrix(self.near, self.far, fov_x, fov_y) 283 | projection_matrix = rearrange(projection_matrix, "b i j -> b j i") 284 | 285 | bottom = torch.tensor([0, 0, 0, 1], dtype=Rt.dtype, device=device) 286 | homo_poses = torch.cat([Rt, bottom[None]], dim=0) 287 | 
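        # Invert the homogeneous camera-to-world pose to obtain the world-to-view matrix; the
        # rearrange below transposes it and adds a batch dim, matching the transposed-matrix
        # convention used by diff-gaussian-rasterization (as in the original 3DGS code).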
288 | view_matrix = rearrange(homo_poses.inverse(), "i j -> () j i") 289 | full_projection = view_matrix @ projection_matrix 290 | 291 | return tan_fov_x, tan_fov_y, view_matrix, full_projection, Rt[:3, 3] 292 | 293 | @torch.amp.autocast("cuda", enabled=False) 294 | def forward(self, cameras, gaussians, h, w, bg="random"): 295 | B, N = cameras.shape[:2] 296 | xyz, features, opacity, scales, rotations = gaussians 297 | 298 | self.radii = [] 299 | self.viewspace_points = [] 300 | 301 | images = [] 302 | depths = [] 303 | masks = [] 304 | 305 | bg_color = torch.tensor([1, 1, 1], device=cameras.device).float() 306 | 307 | if bg == "random": 308 | bg_color = torch.rand_like(bg_color) 309 | 310 | for i in range(B): 311 | for j in range(N): 312 | tan_fov_x, tan_fov_y, view_matrix, full_projection, campos = self.get_viewpoint_cameras(cameras[i, j]) 313 | 314 | mean_gradients = torch.zeros_like(xyz[i], requires_grad=True) 315 | 316 | try: 317 | mean_gradients.retain_grad() 318 | except Exception: 319 | pass 320 | 321 | settings = GaussianRasterizationSettings( 322 | image_height=h, 323 | image_width=w, 324 | tanfovx=tan_fov_x.item(), 325 | tanfovy=tan_fov_y.item(), 326 | bg=bg_color, 327 | scale_modifier=1.0, 328 | viewmatrix=view_matrix, 329 | projmatrix=full_projection, 330 | sh_degree=0, 331 | campos=campos, 332 | prefiltered=False, 333 | debug=False, 334 | ) 335 | 336 | rasterizer = GaussianRasterizer(settings) 337 | 338 | image, radii, depth, mask = rasterizer( 339 | means3D=xyz[i], 340 | means2D=mean_gradients, 341 | shs=None, 342 | colors_precomp=features[i].squeeze(1), 343 | opacities=opacity[i], 344 | scales=scales[i], 345 | rotations=rotations[i][:, [3, 0, 1, 2]], 346 | cov3D_precomp=None, 347 | ) 348 | 349 | rendered_image = image.clamp(0, 1) 350 | rendered_mask = mask.clamp(0, 1) 351 | rendered_depth = depth + self.far * (1 - rendered_mask) 352 | 353 | images.append(rendered_image) 354 | depths.append(rendered_depth) 355 | masks.append(rendered_mask) 356 | 357 | self.radii.append(radii) 358 | self.viewspace_points.append(mean_gradients) 359 | 360 | images = torch.stack(images, dim=0).unflatten(0, (B, N)) * 2 - 1 361 | depths = torch.stack(depths, dim=0).unflatten(0, (B, N)) 362 | masks = torch.stack(masks, dim=0).unflatten(0, (B, N)) 363 | 364 | return images, depths, masks 365 | -------------------------------------------------------------------------------- /model/refiner/sds_pp_refiner.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import tqdm 5 | from diffusers import DDIMScheduler, StableDiffusionPipeline 6 | 7 | from model.refiner.camera_util import sample_from_dense_cameras 8 | from model.refiner.gs_util import GaussianRenderer, GaussiansManeger 9 | 10 | 11 | class GSRefinerSDSPlusPlus(nn.Module): 12 | def __init__( 13 | self, 14 | sd_model_key="stabilityai/stable-diffusion-2-1-base", 15 | num_views=1, 16 | total_iterations=500, 17 | guidance_scale=100, 18 | min_step_percent=0.02, 19 | max_step_percent=0.75, 20 | lr_scale=0.25, 21 | lrs={"xyz": 0.0001, "features": 0.01, "opacity": 0.05, "scales": 0.01, "rotations": 0.01}, 22 | use_lods=True, 23 | lambda_latent_sds=1, 24 | lambda_image_sds=0.1, 25 | lambda_image_variation=0.001, 26 | opacity_threshold=0.001, 27 | img_size=512, 28 | num_densifications=4, 29 | text_templete="$text$", 30 | negative_text_templete="", 31 | ): 32 | super().__init__() 33 | 34 | pipe = 
StableDiffusionPipeline.from_pretrained(sd_model_key) 35 | 36 | self.tokenizer = pipe.tokenizer 37 | self.text_encoder = pipe.text_encoder.requires_grad_(False) 38 | self.vae = pipe.vae.requires_grad_(False) 39 | self.unet = pipe.unet.requires_grad_(False) 40 | 41 | self.scheduler = DDIMScheduler.from_pretrained(sd_model_key, subfolder="scheduler", local_files_only=True) 42 | 43 | del pipe 44 | 45 | self.num_views = num_views 46 | self.total_iterations = total_iterations 47 | self.guidance_scale = guidance_scale 48 | self.lrs = {key: value * lr_scale for key, value in lrs.items()} 49 | 50 | self.register_buffer("alphas_cumprod", self.scheduler.alphas_cumprod, persistent=False) 51 | 52 | self.device = "cpu" 53 | 54 | self.num_train_timesteps = self.scheduler.config.num_train_timesteps 55 | 56 | self.set_min_max_steps(min_step_percent, max_step_percent) 57 | 58 | self.renderer = GaussianRenderer(img_size, img_size) 59 | 60 | self.text_templete = text_templete 61 | self.negative_text_templete = negative_text_templete 62 | 63 | self.use_lods = use_lods 64 | 65 | self.lambda_latent_sds = lambda_latent_sds 66 | self.lambda_image_sds = lambda_image_sds 67 | self.lambda_image_variation = lambda_image_variation 68 | 69 | self.img_size = img_size 70 | 71 | self.opacity_threshold = opacity_threshold 72 | 73 | self.densification_interval = self.total_iterations // (num_densifications + 1) 74 | 75 | def set_min_max_steps(self, min_step_percent=0.02, max_step_percent=0.98): 76 | self.min_step = int(self.num_train_timesteps * min_step_percent) 77 | self.max_step = int(self.num_train_timesteps * max_step_percent) 78 | 79 | def to(self, device): 80 | self.device = device 81 | return super().to(device) 82 | 83 | @torch.no_grad() 84 | def encode_text(self, texts): 85 | inputs = self.tokenizer( 86 | texts, 87 | padding="max_length", 88 | truncation_strategy="longest_first", 89 | max_length=self.tokenizer.model_max_length, 90 | return_tensors="pt", 91 | ) 92 | text_embeddings = self.text_encoder(inputs.input_ids.to(next(self.text_encoder.parameters()).device))[0] 93 | return text_embeddings 94 | 95 | # @torch.amp.autocast("cuda", enabled=False) 96 | def encode_image(self, images): 97 | posterior = self.vae.encode(images).latent_dist 98 | latents = posterior.sample() * self.vae.config.scaling_factor 99 | return latents 100 | 101 | # @torch.amp.autocast("cuda", enabled=False) 102 | def decode_latent(self, latents): 103 | latents = 1 / self.vae.config.scaling_factor * latents 104 | images = self.vae.decode(latents).sample 105 | return images 106 | 107 | def train_step( 108 | self, 109 | images, 110 | t, 111 | text_embeddings, 112 | uncond_text_embeddings, 113 | learnable_text_embeddings, 114 | ): 115 | latents = self.encode_image(images) 116 | 117 | with torch.no_grad(): 118 | B = latents.shape[0] 119 | t = t.repeat(self.num_views) 120 | 121 | noise = torch.randn_like(latents) 122 | latents_noisy = self.scheduler.add_noise(latents, noise, t) 123 | 124 | if self.use_lods: 125 | with torch.enable_grad(): 126 | noise_pred_learnable = self.unet( 127 | latents_noisy, t, encoder_hidden_states=learnable_text_embeddings 128 | ).sample 129 | 130 | loss_embedding = F.mse_loss(noise_pred_learnable, noise, reduction="mean") 131 | else: 132 | noise_pred_learnable = noise 133 | loss_embedding = 0 134 | 135 | with torch.no_grad(): 136 | noise_pred = self.unet( 137 | torch.cat([latents_noisy, latents_noisy], 0), 138 | torch.cat([t, t], 0), 139 | encoder_hidden_states=torch.cat([text_embeddings, uncond_text_embeddings], 0), 140 | 
).sample 141 | 142 | noise_pred_cond, noise_pred_uncond = noise_pred.chunk(2, dim=0) 143 | noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_cond - noise_pred_uncond) 144 | 145 | w = (1 - self.alphas_cumprod[t]).view(-1, 1, 1, 1) 146 | 147 | alpha = self.alphas_cumprod[t].view(-1, 1, 1, 1) ** 0.5 148 | sigma = (1 - self.alphas_cumprod[t].view(-1, 1, 1, 1)) ** 0.5 149 | 150 | latents_pred = (latents_noisy - sigma * (noise_pred - noise_pred_learnable + noise)) / alpha 151 | images_pred = self.decode_latent(latents_pred).clamp(-1, 1) 152 | 153 | loss_latent_sds = ( 154 | F.mse_loss(latents, latents_pred, reduction="none").sum([1, 2, 3]) * w * alpha / sigma 155 | ).sum() / B 156 | loss_image_sds = ( 157 | F.mse_loss(images, images_pred, reduction="none").sum([1, 2, 3]) * w * alpha / sigma 158 | ).sum() / B 159 | 160 | return loss_latent_sds, loss_image_sds, loss_embedding 161 | 162 | @torch.amp.autocast("cuda", enabled=True) 163 | @torch.enable_grad() 164 | def refine_gaussians(self, gaussians, text, dense_cameras): 165 | gaussians_original = gaussians 166 | xyz, features, opacity, scales, rotations = gaussians 167 | 168 | mask = opacity[..., 0] >= self.opacity_threshold 169 | xyz_original = xyz[mask][None] 170 | features_original = features[mask][None] 171 | opacity_original = opacity[mask][None] 172 | scales_original = scales[mask][None] 173 | rotations_original = rotations[mask][None] 174 | 175 | text = self.text_templete.replace("$text$", text) 176 | 177 | text_embeddings = self.encode_text([text]) 178 | uncond_text_embeddings = self.encode_text([self.negative_text_templete.replace("$text$", text)]) 179 | 180 | class LearnableTextEmbeddings(nn.Module): 181 | def __init__(self, uncond_text_embeddings): 182 | super().__init__() 183 | self.embeddings = nn.Parameter(torch.zeros_like(uncond_text_embeddings.float().detach().clone())) 184 | self.to(self.embeddings.device) 185 | 186 | def forward(self, cameras): 187 | B = cameras.shape[1] 188 | return self.embeddings.repeat(B, 1, 1) 189 | 190 | _learnable_text_embeddings = LearnableTextEmbeddings(uncond_text_embeddings) 191 | 192 | text_embeddings = text_embeddings.repeat(self.num_views, 1, 1) 193 | uncond_text_embeddings = uncond_text_embeddings.repeat(self.num_views, 1, 1) 194 | 195 | new_gaussians = GaussiansManeger( 196 | xyz_original, features_original, opacity_original, scales_original, rotations_original, self.lrs 197 | ) 198 | 199 | optimizer_embeddings = torch.optim.Adam(_learnable_text_embeddings.parameters(), lr=self.lrs["embeddings"]) 200 | for i in tqdm.trange(self.total_iterations, desc="Refining..."): 201 | if i % self.densification_interval == 0 and i != 0: 202 | new_gaussians.densify_and_prune(opacity_threshold=self.opacity_threshold) 203 | 204 | with torch.amp.autocast("cuda", enabled=False): 205 | cameras = sample_from_dense_cameras(dense_cameras, torch.rand(1, self.num_views).to(self.device)) 206 | 207 | learnable_text_embeddings = _learnable_text_embeddings(cameras) 208 | 209 | with torch.no_grad(): 210 | images_original = self.renderer(cameras, gaussians_original, h=self.img_size, w=self.img_size)[0] 211 | 212 | gaussians = new_gaussians() 213 | images_pred= self.renderer(cameras, gaussians, h=self.img_size, w=self.img_size)[0] 214 | 215 | anneal_t = int((i / self.total_iterations) ** (1 / 2) * (self.min_step - self.max_step) + self.max_step) 216 | t = torch.full((1,), anneal_t, dtype=torch.long, device=self.device) 217 | 218 | loss_latent_sds, loss_img_sds, loss_embedding = self.train_step( 219 | 
219 |                 images_pred.squeeze(0), t, text_embeddings, uncond_text_embeddings, learnable_text_embeddings
220 |             )
221 |
222 |             image_variation_loss = F.mse_loss(images_original, images_pred, reduction='sum') / self.num_views
223 |
224 |             loss = (
225 |                 loss_latent_sds * self.lambda_latent_sds +
226 |                 loss_img_sds * self.lambda_image_sds +
227 |                 image_variation_loss * self.lambda_image_variation +
228 |                 loss_embedding
229 |             )
230 |
231 |             optimizer_embeddings.zero_grad()
232 |             new_gaussians.optimizer.zero_grad()
233 |
234 |             # scale regularization (hard-coded weight in place of self.lambda_scale_regularization): penalize overly large Gaussians
235 |             scales = torch.exp(new_gaussians._scales)
236 |             big_points_ws = scales.max(dim=1).values > 0.05
237 |             loss += 10 * scales[big_points_ws].sum()
238 |
239 |             loss.backward()
240 |
241 |             new_gaussians.optimizer.step()
242 |             optimizer_embeddings.step()
243 |
244 |             for radii, viewspace_points in zip(self.renderer.radii, self.renderer.viewspace_points):
245 |                 visibility_filter = radii > 0
246 |                 new_gaussians.is_visible[visibility_filter] = 1
247 |                 new_gaussians.max_radii2D[visibility_filter] = torch.max(
248 |                     new_gaussians.max_radii2D[visibility_filter], radii[visibility_filter]
249 |                 )
250 |                 new_gaussians.add_densification_stats(viewspace_points, visibility_filter)
251 |
252 |         gaussians = new_gaussians()
253 |         is_visible = new_gaussians.is_visible.bool()
254 |         gaussians = [p[:, is_visible].detach() for p in gaussians]
255 |
256 |         del new_gaussians
257 |         return gaussians
258 |
--------------------------------------------------------------------------------
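
The refinement loop above anneals the diffusion timestep with a square-root schedule and combines the conditional and unconditional U-Net predictions with classifier-free guidance. A minimal standalone sketch of those two pieces (not part of the repository; the constants 1000 / 0.02 / 0.5 are illustrative values for the number of training timesteps and the min/max step percents):

```python
import torch


def annealed_timestep(i: int, total: int, min_step: int, max_step: int) -> int:
    # Starts near max_step and decays towards min_step as refinement progresses,
    # mirroring the anneal_t computation in refine_gaussians().
    return int((i / total) ** 0.5 * (min_step - max_step) + max_step)


def apply_cfg(noise_pred_cond: torch.Tensor, noise_pred_uncond: torch.Tensor, guidance_scale: float) -> torch.Tensor:
    # Same combination as in train_step(): uncond + scale * (cond - uncond).
    return noise_pred_uncond + guidance_scale * (noise_pred_cond - noise_pred_uncond)


# e.g. 1000 training timesteps with min/max step percents of 0.02 and 0.5:
t_start = annealed_timestep(0, 1000, min_step=20, max_step=500)    # 500
t_mid = annealed_timestep(500, 1000, min_step=20, max_step=500)    # 160
```
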
59 | """ 60 | 61 | model_context_manager = model.summon_full_params(model) 62 | with model_context_manager: 63 | ema_params = OrderedDict(ema_model.named_parameters()) 64 | model_params = OrderedDict(model.named_parameters()) 65 | for name, param in model_params.items(): 66 | ema_params[name].mul_(decay).add_(param.data, alpha=1 - decay) 67 | 68 | 69 | def stack_depth_images(depth_in): 70 | """ 71 | Crawled from https://github.com/prs-eth/Marigold/blob/main/src/trainer/marigold_trainer.py#L395 72 | """ 73 | if 4 == len(depth_in.shape): 74 | stacked = depth_in.repeat(1, 3, 1, 1) 75 | elif 3 == len(depth_in.shape): 76 | stacked = depth_in.unsqueeze(1) 77 | stacked = depth_in.repeat(1, 3, 1, 1) 78 | return stacked 79 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | hydra-core==1.3.2 2 | transformers==4.40.2 3 | wandb==0.17.0 4 | timm==0.5.4 5 | diffusers==0.30.0 6 | sentencepiece==0.2.0 7 | matplotlib==3.8.4 8 | imageio-ffmpeg==0.5.1 9 | camtools 10 | accelerate 11 | einops 12 | lpips 13 | plyfile 14 | fvcore 15 | carvekit-colab 16 | scikit-image 17 | opencv-python 18 | git+https://github.com/byeongjun-park/diff-gaussian-rasterization 19 | git+https://github.com/facebookresearch/pytorch3d@V0.7.8 -------------------------------------------------------------------------------- /util/camera_visualization.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the BSD-style license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import os 8 | 9 | import matplotlib.pyplot as plt 10 | import torch 11 | from mpl_toolkits.mplot3d import Axes3D # noqa: F401 unused import 12 | 13 | 14 | def get_camera_wireframe(scale: float = 0.3): 15 | """ 16 | Returns a wireframe of a 3D line-plot of a camera symbol. 17 | """ 18 | a = 0.5 * torch.tensor([-1, 1, 2]) 19 | b = 0.5 * torch.tensor([1, 1, 2]) 20 | c = 0.5 * torch.tensor([-1, -1, 2]) 21 | d = 0.5 * torch.tensor([1, -1, 2]) 22 | C = torch.zeros(3) 23 | camera_points = [a, b, d, c, a, C, b, d, C, c, C] 24 | lines = torch.stack([x.float() for x in camera_points]) * scale 25 | return lines 26 | 27 | 28 | def plot_cameras(ax, cameras, color=None): 29 | """ 30 | Plots a set of `cameras` objects into the maplotlib axis `ax` with 31 | color `color`. 32 | """ 33 | cam_wires_canonical = get_camera_wireframe(scale=0.05).cuda()[None] 34 | cam_trans = cameras.get_world_to_view_transform().inverse() 35 | cam_wires_trans = cam_trans.transform_points(cam_wires_canonical) 36 | plot_handles = [] 37 | cmap = plt.get_cmap("hsv") 38 | 39 | for i, wire in enumerate(cam_wires_trans): 40 | # the Z and Y axes are flipped intentionally here! 41 | color_ = cmap(i / len(cameras))[:-1] if color is None else color 42 | x_, z_, y_ = wire.detach().cpu().numpy().T.astype(float) 43 | (h,) = ax.plot(x_, y_, z_, color=color_, linewidth=0.5) 44 | plot_handles.append(h) 45 | return cam_wires_trans 46 | 47 | 48 | def plot_camera_scene(cameras, cameras_gt, path): 49 | """ 50 | Plots a set of predicted cameras `cameras` and their corresponding 51 | ground truth locations `cameras_gt`. The plot is named with 52 | a string passed inside the `status` argument. 
53 | """ 54 | fig = plt.figure() 55 | ax = fig.add_subplot(projection="3d") 56 | ax.clear() 57 | 58 | points = plot_cameras(ax, cameras) 59 | points_gt = plot_cameras(ax, cameras_gt, color=(0.1, 0.1, 0.1)) 60 | tot_pts = torch.cat([points[:, -1], points_gt[:, -1]], dim=0) 61 | 62 | max_scene = tot_pts.max(dim=0)[0].cpu() 63 | min_scene = tot_pts.min(dim=0)[0].cpu() 64 | ax.set_xticklabels([]) 65 | ax.set_yticklabels([]) 66 | ax.set_zticklabels([]) 67 | ax.set_xlim3d([min_scene[0] - 0.1, max_scene[0] + 0.3]) 68 | ax.set_ylim3d([min_scene[2] - 0.1, max_scene[2] + 0.1]) 69 | ax.set_zlim3d([min_scene[1] - 0.1, max_scene[1] + 0.1]) 70 | 71 | ax.invert_yaxis() 72 | 73 | plt.savefig(os.path.join(path, "pose.pdf"), bbox_inches="tight", pad_inches=0, transparent=True) -------------------------------------------------------------------------------- /util/dist_util.py: -------------------------------------------------------------------------------- 1 | import os 2 | import socket 3 | 4 | import torch 5 | import torch.distributed as dist 6 | 7 | DEVICE = None 8 | 9 | 10 | def setup_dist(args): 11 | if dist.is_initialized(): 12 | return 13 | 14 | if os.environ.get("MASTER_ADDR", None) is None: 15 | hostname = socket.gethostbyname(socket.getfqdn()) 16 | os.environ["MASTER_ADDR"] = hostname 17 | os.environ["RANK"] = "0" 18 | os.environ["WORLD_SIZE"] = "1" 19 | port = _find_free_port() 20 | os.environ["MASTER_PORT"] = str(port) 21 | 22 | dist.init_process_group("nccl") 23 | assert args.global_batch_size % dist.get_world_size() == 0, "Batch size must be divisible by world size." 24 | rank = dist.get_rank() 25 | device = rank % torch.cuda.device_count() + args.gpu_offset 26 | seed = args.global_seed * dist.get_world_size() + rank 27 | torch.manual_seed(seed) 28 | torch.cuda.set_device(device) 29 | global DEVICE 30 | DEVICE = device 31 | print(f"Starting rank={rank}, seed={seed}, world_size={dist.get_world_size()}, device={DEVICE}") 32 | 33 | 34 | def _find_free_port(): 35 | try: 36 | s = socket.socket(socket.AF_INET, socket.SOCK_STREAM) 37 | s.bind(("", 0)) 38 | s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) 39 | return s.getsockname()[1] 40 | finally: 41 | s.close() 42 | 43 | 44 | def cleanup(): 45 | """ 46 | End DDP training. 47 | """ 48 | dist.destroy_process_group() 49 | 50 | 51 | def device(): 52 | if not dist.is_initialized(): 53 | raise NameError 54 | return DEVICE 55 | --------------------------------------------------------------------------------