├── viewdiff ├── __init__.py ├── data │ ├── __init__.py │ ├── co3d │ │ ├── __init__.py │ │ ├── save_recentered_sequences.py │ │ ├── generate_co3d_dreambooth_data.py │ │ ├── generate_blip2_captions.py │ │ └── util.py │ └── create_video_from_image_folder.py ├── metrics │ ├── __init__.py │ └── image_metrics.py ├── model │ ├── __init__.py │ ├── projection │ │ ├── util.py │ │ ├── fastplane │ │ │ └── fastplane_module.py │ │ ├── layer.py │ │ └── voxel_proj.py │ ├── custom_transformer_2d.py │ └── custom_attention_processor.py ├── scripts │ ├── __init__.py │ ├── misc │ │ ├── __init__.py │ │ ├── create_masked_images.py │ │ ├── calculate_mean_image_stats.py │ │ ├── process_nerfstudio_to_sdfstudio.py │ │ └── export_nerf_transforms.py │ ├── test │ │ ├── test_spherical_360_256x256.sh │ │ ├── eval_single_image_input.sh │ │ ├── test_sliding_window_smooth_alternating_theta_30_360_256x256.sh │ │ └── test_sliding_window_smooth_alternating_theta_60_360_256x256.sh │ ├── train_small.sh │ └── train.sh ├── convert_checkpoint_to_model.py └── io_util.py ├── docs └── teaser.jpg ├── .gitignore ├── requirements.txt ├── CONTRIBUTING.md ├── CODE_OF_CONDUCT.md ├── README.MD └── LICENSE /viewdiff/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /viewdiff/data/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /viewdiff/metrics/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /viewdiff/model/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /viewdiff/scripts/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /viewdiff/data/co3d/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /viewdiff/scripts/misc/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /docs/teaser.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/facebookresearch/ViewDiff/HEAD/docs/teaser.jpg -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | output/ 2 | outputs/ 3 | runs/ 4 | checkpoints/ 5 | wandb/ 6 | 7 | .idea 8 | .vscode 9 | __pycache__/ 10 | 11 | *.pt 12 | *egg-info/ 13 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | accelerate==0.23.0 2 | bitsandbytes==0.41.2.post2 3 | cachetools==5.3.0 4 | carvekit_colab==4.1.0 5 | dacite==1.8.1 6 | diffusers==0.21.2 7 | ffmpeg_python==0.2.0 8 | imageio>=2.27 9 | lpips==0.1.4 10 | numpy==1.24.2 11 | omegaconf==2.3.0 12 | opencv_python>=4.8.1.78 13 | packaging==23.2 14 | Pillow==10.3.0 15 | 
PyYAML==6.0.1 16 | scikit-image==0.20.0 17 | scipy==1.10.0 18 | torch==2.3.2 19 | torchvision==0.15.2 20 | tqdm==4.66.3 21 | transformers==4.48.0 22 | tyro==0.5.7 23 | xformers 24 | tensorboard 25 | -------------------------------------------------------------------------------- /viewdiff/data/co3d/save_recentered_sequences.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | 3 | import tyro 4 | from .co3d_dataset import CO3DConfig, CO3DDataset 5 | 6 | 7 | def save_recentered( 8 | dataset_config: CO3DConfig, 9 | recompute: bool = False, 10 | ): 11 | # make sure the important fields are set correctly 12 | dataset_config.dataset_args.load_point_clouds = True 13 | dataset_config.batch.load_recentered = False 14 | dataset_config.batch.need_mask_augmentations = False 15 | 16 | # Get the dataset: parse CO3Dv2 17 | dataset = CO3DDataset(dataset_config) 18 | 19 | # save recentered data 20 | dataset.recenter_sequences(recompute=recompute) 21 | 22 | 23 | if __name__ == "__main__": 24 | tyro.cli(save_recentered) 25 | -------------------------------------------------------------------------------- /viewdiff/scripts/test/test_spherical_360_256x256.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # Copyright (c) Meta Platforms, Inc. and affiliates. 3 | 4 | export CO3DV2_DATASET_ROOT=$1 5 | 6 | python -m viewdiff.test \ 7 | --run-config.pretrained_model_name_or_path "$2" \ 8 | --run-config.output_dir $3 \ 9 | --run-config.n_input_images "$4" \ 10 | --run-config.num_inference_steps 10 \ 11 | --run_config.scheduler_type "unipc" \ 12 | --dataset-config.co3d-root $CO3DV2_DATASET_ROOT \ 13 | --dataset-config.category "$5" \ 14 | --dataset-config.dataset_args.pick_sequence "$6" \ 15 | --dataset-config.batch.other_selection "sequence" \ 16 | --dataset-config.batch.sequence_offset 1 \ 17 | --dataset-config.batch.load_recentered \ 18 | --dataset-config.batch.use_blip_prompt \ 19 | --dataset-config.batch.replace_pose_with_spherical_start_phi 0 \ 20 | --dataset-config.batch.replace_pose_with_spherical_end_phi 360 \ 21 | --dataset-config.batch.crop "foreground" \ 22 | --dataset-config.batch.image_width 256 \ 23 | --dataset-config.batch.image_height 256 \ 24 | -------------------------------------------------------------------------------- /viewdiff/scripts/test/eval_single_image_input.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # Copyright (c) Meta Platforms, Inc. and affiliates. 
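#
# Usage (argument meaning inferred from how $1-$4 are referenced below):
#   ./eval_single_image_input.sh <co3d_root> <model_path> <output_dir> <category>
#   $1: CO3D dataset root, exported as CO3DV2_DATASET_ROOT
#   $2: path to the exported/pretrained model
#   $3: output directory
#   $4: CO3D category, also used to build the text prompt "a $4"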
3 | 4 | export CO3DV2_DATASET_ROOT=$1 5 | 6 | python -m viewdiff.test \ 7 | --run-config.pretrained_model_name_or_path "$2" \ 8 | --run-config.output_dir $3 \ 9 | --run-config.n_input_images 5 \ 10 | --run-config.sliding_window.input_condition_mode "dataset" \ 11 | --run_config.sliding_window.input_condition_n_images 1 \ 12 | --run-config.num_inference_steps 10 \ 13 | --run_config.scheduler_type "unipc" \ 14 | --run_config.max_steps 200 \ 15 | --run_config.guidance_scale 1.0 \ 16 | --run_config.n_repeat_generation 20 \ 17 | --dataset-config.co3d-root $CO3DV2_DATASET_ROOT \ 18 | --dataset-config.category "$4" \ 19 | --dataset-config.batch.other_selection "random" \ 20 | --dataset-config.batch.load_recentered \ 21 | --dataset-config.batch.crop "foreground" \ 22 | --dataset-config.batch.prompt "a $4" \ 23 | --dataset-config.batch.image_width 256 \ 24 | --dataset-config.batch.image_height 256 \ 25 | --dataset-config.dataset_args.load_masks 26 | -------------------------------------------------------------------------------- /viewdiff/data/create_video_from_image_folder.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | 3 | import argparse 4 | import ffmpeg 5 | 6 | 7 | def main(args): 8 | ffmpeg_bin = "/usr/bin/ffmpeg" 9 | ( 10 | ffmpeg.input( 11 | f"{args.image_folder}/{args.file_name_pattern_glob}", pattern_type="glob", framerate=args.framerate 12 | ) 13 | .output(args.output_path, **{"codec:v": "libx264", "crf": 19}, loglevel="quiet") 14 | .run(cmd=ffmpeg_bin, overwrite_output=True) 15 | ) 16 | 17 | 18 | if __name__ == "__main__": 19 | parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter) 20 | 21 | # GENERAL CONFIG 22 | group = parser.add_argument_group("general") 23 | group.add_argument("--image_folder", required=True) 24 | group.add_argument("--file_name_pattern_glob", required=False, default="*.png") 25 | group.add_argument("--framerate", required=False, type=int, default=15) 26 | group.add_argument("--output_path", required=False, type=str, default="video.mp4") 27 | 28 | args = parser.parse_args() 29 | 30 | main(args) 31 | -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # Contributing to ViewDiff 2 | We want to make contributing to this project as easy and transparent as 3 | possible. 4 | 5 | ## Pull Requests 6 | We actively welcome your pull requests. 7 | 8 | 1. Fork the repo and create your branch from `main`. 9 | 2. If you've added code that should be tested, add tests. 10 | 3. If you've changed APIs, update the documentation. 11 | 4. Ensure the test suite passes. 12 | 5. Make sure your code lints. 13 | 6. If you haven't already, complete the Contributor License Agreement ("CLA"). 14 | 15 | ## Contributor License Agreement ("CLA") 16 | In order to accept your pull request, we need you to submit a CLA. You only need 17 | to do this once to work on any of Facebook's open source projects. 18 | 19 | Complete your CLA here: 20 | 21 | ## Issues 22 | We use GitHub issues to track public bugs. Please ensure your description is 23 | clear and has sufficient instructions to be able to reproduce the issue. 24 | 25 | Facebook has a [bounty program](https://www.facebook.com/whitehat/) for the safe 26 | disclosure of security bugs. In those cases, please go through the process 27 | outlined on that page and do not file a public issue. 
28 | 29 | ## License 30 | By contributing to ViewDiff, you agree that your contributions will be licensed 31 | under the LICENSE file in the root directory of this source tree. -------------------------------------------------------------------------------- /viewdiff/scripts/test/test_sliding_window_smooth_alternating_theta_30_360_256x256.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # Copyright (c) Meta Platforms, Inc. and affiliates. 3 | 4 | export CO3DV2_DATASET_ROOT=$1 5 | 6 | python -m viewdiff.test \ 7 | --run-config.pretrained_model_name_or_path $2 \ 8 | --run-config.output_dir $3 \ 9 | --run-config.n_input_images "10" \ 10 | --run-config.create_nerf_exports \ 11 | --run-config.save.pred_gif \ 12 | --run-config.sliding_window.is_active \ 13 | --run-config.sliding_window.create_smooth_video \ 14 | --run-config.sliding_window.repeat_first_n_steps 1 \ 15 | --run-config.sliding_window.n_full_batches_to_save 1 \ 16 | --run-config.sliding_window.perc_add_images_to_save 0.5 \ 17 | --run-config.sliding_window.max_degrees 60 \ 18 | --run-config.sliding_window.degree_increment 50 \ 19 | --run-config.sliding_window.first_theta 30.0 \ 20 | --run-config.sliding_window.min_theta 30.0 \ 21 | --run-config.sliding_window.max_theta 30.0 \ 22 | --run-config.sliding_window.first_radius 4.0 \ 23 | --run-config.sliding_window.min_radius 4.0 \ 24 | --run-config.sliding_window.max_radius 4.0 \ 25 | --run-config.num_inference_steps 10 \ 26 | --run_config.scheduler_type "unipc" \ 27 | --run_config.max_steps 6 \ 28 | --dataset-config.co3d-root $CO3DV2_DATASET_ROOT \ 29 | --dataset-config.category $4 \ 30 | --dataset-config.batch.other_selection "sequence" \ 31 | --dataset-config.batch.sequence_offset 1 \ 32 | --dataset-config.batch.load_recentered \ 33 | --dataset-config.batch.use_blip_prompt \ 34 | --dataset-config.batch.crop "foreground" \ 35 | --dataset-config.batch.image_width 256 \ 36 | --dataset-config.batch.image_height 256 \ 37 | --dataset-config.seed 500 38 | -------------------------------------------------------------------------------- /viewdiff/scripts/test/test_sliding_window_smooth_alternating_theta_60_360_256x256.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # Copyright (c) Meta Platforms, Inc. and affiliates. 
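#
# Usage (argument meaning inferred from how $1-$4 are referenced below):
#   ./test_sliding_window_smooth_alternating_theta_60_360_256x256.sh <co3d_root> <model_path> <output_dir> <category>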
3 | 4 | export CO3DV2_DATASET_ROOT=$1 5 | 6 | python -m viewdiff.test \ 7 | --run-config.pretrained_model_name_or_path $2 \ 8 | --run-config.output_dir $3 \ 9 | --run-config.n_input_images "10" \ 10 | --run-config.create_nerf_exports \ 11 | --run-config.save.pred_gif \ 12 | --run-config.sliding_window.is_active \ 13 | --run-config.sliding_window.create_smooth_video \ 14 | --run-config.sliding_window.repeat_first_n_steps 1 \ 15 | --run-config.sliding_window.n_full_batches_to_save 1 \ 16 | --run-config.sliding_window.perc_add_images_to_save 0.5 \ 17 | --run-config.sliding_window.max_degrees 60 \ 18 | --run-config.sliding_window.degree_increment 50 \ 19 | --run-config.sliding_window.first_theta 60.0 \ 20 | --run-config.sliding_window.min_theta 60.0 \ 21 | --run-config.sliding_window.max_theta 60.0 \ 22 | --run-config.sliding_window.first_radius 4.0 \ 23 | --run-config.sliding_window.min_radius 4.0 \ 24 | --run-config.sliding_window.max_radius 4.0 \ 25 | --run-config.num_inference_steps 10 \ 26 | --run_config.scheduler_type "unipc" \ 27 | --run_config.max_steps 6 \ 28 | --dataset-config.co3d-root $CO3DV2_DATASET_ROOT \ 29 | --dataset-config.category $4 \ 30 | --dataset-config.batch.other_selection "sequence" \ 31 | --dataset-config.batch.sequence_offset 1 \ 32 | --dataset-config.batch.load_recentered \ 33 | --dataset-config.batch.use_blip_prompt \ 34 | --dataset-config.batch.crop "foreground" \ 35 | --dataset-config.batch.image_width 256 \ 36 | --dataset-config.batch.image_height 256 \ 37 | --dataset-config.seed 500 38 | -------------------------------------------------------------------------------- /viewdiff/metrics/image_metrics.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | 3 | import torch 4 | import lpips 5 | import skimage 6 | 7 | 8 | def load_lpips_vgg_model(lpips_vgg_model_path: str): 9 | return lpips.LPIPS(net='vgg', model_path=lpips_vgg_model_path) 10 | 11 | 12 | def calc_psnr_ssim_lpips(src: torch.Tensor, target: torch.Tensor, lpips_vgg_model: lpips.LPIPS): 13 | """ 14 | 15 | :param src: (n_batches, n_images_per_batch, C, H, W) 16 | :param target: (n_batches, n_images_per_batch, C, H, W) 17 | :param lpips_vgg_model: 18 | :return: 19 | """ 20 | l2_criterion = torch.nn.MSELoss(reduction='none') 21 | 22 | psnrs = [] 23 | lpipses = [] 24 | ssims = [] 25 | for batch_idx in range(src.shape[0]): 26 | # ================ PSNR measurement ================ 27 | # don't average across frames 28 | l2_loss = l2_criterion(src[batch_idx], target[batch_idx]).mean(dim=[1, 2, 3]) 29 | psnr = -10 * torch.log10(l2_loss) 30 | psnrs.extend(list(x.item() for x in psnr)) 31 | 32 | # ================ LPIPS measurement ================ 33 | lpips = lpips_vgg_model(src[batch_idx], target[batch_idx], normalize=True) 34 | lpipses.extend(list(x.item() for x in lpips)) 35 | 36 | # ================ SSIM measurement ============= 37 | for view_idx in range(src.shape[1]): 38 | ssim = skimage.metrics.structural_similarity( 39 | src[batch_idx, view_idx].cpu().numpy(), 40 | target[batch_idx, view_idx].cpu().numpy(), 41 | data_range=1, 42 | channel_axis=0 43 | ) 44 | ssims.append(float(ssim)) 45 | 46 | return psnrs, lpipses, ssims 47 | -------------------------------------------------------------------------------- /viewdiff/model/projection/util.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 
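#
# Module notes (summarized from the implementations below):
# - get_pixel_grids(height, width) returns a (3, height*width) grid of homogeneous pixel
#   coordinates (x, y, 1); reverse=True flips both axes to the PyTorch3D convention (+X left, +Y up).
# - project_batch(points, K, world2cam) projects world-space points of shape (B, 3, P) into the
#   image and returns a (B, 3, P) tensor of (x_pix, y_pix, depth); points with |depth| < eps are
#   marked invalid by writing -10 into their entries.
# - screen_to_ndc(x, h, w) converts pixel coordinates to PyTorch3D NDC, keeping the shorter image
#   side in [-1, 1].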
2 | 3 | import torch 4 | 5 | 6 | def get_pixel_grids(height, width, reverse=False): 7 | if reverse: 8 | # specify as +X left, +Y up (e.g. Pytorch3D convention, see here: https://github.com/facebookresearch/pytorch3d/blob/main/docs/notes/cameras.md) 9 | x_linspace = torch.linspace(width - 1, 0, width).view(1, width).expand(height, width) 10 | y_linspace = torch.linspace(height - 1, 0, height).view(height, 1).expand(height, width) 11 | else: 12 | # specify as +X right, +Y down 13 | x_linspace = torch.linspace(0, width - 1, width).view(1, width).expand(height, width) 14 | y_linspace = torch.linspace(0, height - 1, height).view(height, 1).expand(height, width) 15 | x_coordinates = x_linspace.contiguous().view(-1).contiguous() 16 | y_coordinates = y_linspace.contiguous().view(-1).contiguous() 17 | 18 | ones = torch.ones(height * width) 19 | indices_grid = torch.stack([x_coordinates, y_coordinates, ones], dim=0) 20 | return indices_grid 21 | 22 | 23 | def project_batch(points: torch.Tensor, K: torch.Tensor, world2cam: torch.Tensor, eps=1e-6) -> torch.Tensor: 24 | """ 25 | 26 | Args: 27 | points (torch.Tensor): (batch_size, 3, P) 28 | world2cam (torch.Tensor): (batch_size, 4, 4) 29 | K (torch.Tensor): (batch_size, 3, 3) 30 | 31 | Returns: 32 | torch.Tensor: (xy in pixels, depth) 33 | """ 34 | cam_points = world2cam[..., :3, :3].bmm(points) + world2cam[..., :3, 3:4] 35 | xy_proj = K.bmm(cam_points) 36 | 37 | zs = xy_proj[..., 2:3, :] 38 | mask = (zs.abs() < eps).detach() 39 | zs[mask] = eps 40 | sampler = torch.cat((xy_proj[..., 0:2, :] / zs, xy_proj[..., 2:3, :]), dim=1) 41 | 42 | # Remove invalid zs that cause nans 43 | sampler[mask.repeat(1, 3, 1)] = -10 44 | 45 | return sampler 46 | 47 | 48 | def screen_to_ndc(x, h, w): 49 | # convert as specified here: https://github.com/facebookresearch/pytorch3d/blob/main/docs/notes/cameras.md 50 | sampler = torch.clone(x) 51 | if h > w: 52 | # W from [-1, 1], H from [-s, s] where s=H/W 53 | sampler[..., 0:1] = (sampler[..., 0:1] + 0.5) / w * 2.0 - 1.0 54 | sampler[..., 1:2] = ((sampler[..., 1:2] + 0.5) / h * 2.0 - 1.0) * h / w 55 | else: 56 | # H from [-1, 1], W from [-s, s] where s=W/H 57 | sampler[..., 0:1] = ((sampler[..., 0:1] + 0.5) / w * 2.0 - 1.0) * w / h 58 | sampler[..., 1:2] = (sampler[..., 1:2] + 0.5) / h * 2.0 - 1.0 59 | return sampler 60 | -------------------------------------------------------------------------------- /viewdiff/scripts/train_small.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # Copyright (c) Meta Platforms, Inc. and affiliates. 
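#
# Usage (argument meaning inferred from how $1-$4 are referenced below):
#   ./train_small.sh <co3d_root> <pretrained_model> <output_dir> <category>
#   $1: CO3D dataset root, exported as CO3DV2_DATASET_ROOT
#   $2: base model for --finetune-config.io.pretrained_model_name_or_path
#   $3: output directory for --finetune-config.io.output_dir
#   $4: training category for --dataset-config.category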
3 | 4 | export CO3DV2_DATASET_ROOT=$1 5 | 6 | accelerate launch --mixed_precision="no" -m viewdiff.train \ 7 | --finetune-config.io.pretrained_model_name_or_path $2 \ 8 | --finetune-config.io.output_dir $3 \ 9 | --finetune-config.io.experiment_name "train_teddybear" \ 10 | --finetune-config.training.mixed_precision "no" \ 11 | --finetune-config.training.dataloader_num_workers "0" \ 12 | --finetune-config.training.num_train_epochs "1000" \ 13 | --finetune-config.training.train_batch_size "1" \ 14 | --finetune-config.training.dreambooth_prior_preservation_loss_weight "0.1" \ 15 | --finetune_config.training.noise_prediction_type "epsilon" \ 16 | --finetune_config.training.prob_images_not_noisy "0.25" \ 17 | --finetune_config.training.max_num_images_not_noisy "2" \ 18 | --finetune_config.training.validation_epochs "1" \ 19 | --finetune_config.training.dreambooth_prior_preservation_every_nth "1" \ 20 | --finetune-config.optimizer.learning_rate "5e-5" \ 21 | --finetune-config.optimizer.vol_rend_learning_rate "1e-3" \ 22 | --finetune-config.optimizer.vol_rend_adam_weight_decay "0.0" \ 23 | --finetune-config.optimizer.gradient_accumulation_steps "1" \ 24 | --finetune-config.optimizer.max_grad_norm "5e-3" \ 25 | --finetune-config.cross_frame_attention.to_k_other_frames "2" \ 26 | --finetune-config.cross_frame_attention.random_others \ 27 | --finetune-config.cross_frame_attention.with_self_attention \ 28 | --finetune-config.cross_frame_attention.use_temb_cond \ 29 | --finetune-config.cross_frame_attention.mode "pretrained" \ 30 | --finetune-config.cross_frame_attention.n_cfa_down_blocks "0" \ 31 | --finetune-config.cross_frame_attention.n_cfa_up_blocks "0" \ 32 | --finetune-config.cross_frame_attention.unproj_reproj_mode "with_cfa" \ 33 | --finetune-config.cross_frame_attention.num_3d_layers "1" \ 34 | --finetune-config.cross_frame_attention.dim_3d_latent "16" \ 35 | --finetune-config.cross_frame_attention.dim_3d_grid "64" \ 36 | --finetune-config.cross_frame_attention.n_novel_images "1" \ 37 | --finetune-config.cross_frame_attention.vol_rend_proj_in_mode "multiple" \ 38 | --finetune-config.cross_frame_attention.vol_rend_proj_out_mode "multiple" \ 39 | --finetune-config.cross_frame_attention.vol_rend_aggregator_mode "ibrnet" \ 40 | --finetune-config.cross_frame_attention.last_layer_mode "no_residual_connection" \ 41 | --finetune_config.cross_frame_attention.vol_rend_model_background \ 42 | --finetune_config.cross_frame_attention.vol_rend_background_grid_percentage "0.5" \ 43 | --finetune-config.model.pose_cond_mode "sa-ca" \ 44 | --finetune-config.model.pose_cond_coord_space "absolute" \ 45 | --finetune-config.model.pose_cond_lora_rank "64" \ 46 | --finetune-config.model.n_input_images "3" \ 47 | --dataset-config.co3d-root $CO3DV2_DATASET_ROOT \ 48 | --dataset-config.category $4 \ 49 | --dataset-config.max_sequences 50 \ 50 | --dataset-config.batch.load_recentered \ 51 | --dataset-config.batch.use_blip_prompt \ 52 | --dataset-config.batch.crop "random" \ 53 | --dataset-config.batch.image_width "256" \ 54 | --dataset-config.batch.image_height "256" \ 55 | --dataset-config.batch.other_selection "mix" \ 56 | --validation-dataset-config.co3d-root $CO3DV2_DATASET_ROOT \ 57 | --validation-dataset-config.category "teddybear" \ 58 | --validation-dataset-config.max_sequences "1" \ 59 | --validation-dataset-config.batch.load_recentered \ 60 | --validation-dataset-config.batch.use_blip_prompt \ 61 | --validation-dataset-config.batch.crop "random" \ 62 | --validation-dataset-config.batch.image_width "256" \ 63 
| --validation-dataset-config.batch.image_height "256" \ 64 | --validation-dataset-config.dataset_args.n_frames_per_sequence "3" 65 | -------------------------------------------------------------------------------- /viewdiff/scripts/train.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # Copyright (c) Meta Platforms, Inc. and affiliates. 3 | 4 | export CO3DV2_DATASET_ROOT=$1 5 | 6 | accelerate launch --mixed_precision="no" --multi_gpu -m viewdiff.train \ 7 | --finetune-config.io.pretrained_model_name_or_path $2 \ 8 | --finetune-config.io.output_dir $3 \ 9 | --finetune-config.io.experiment_name "train_teddybear" \ 10 | --finetune-config.training.mixed_precision "no" \ 11 | --finetune-config.training.dataloader_num_workers "0" \ 12 | --finetune-config.training.num_train_epochs "1000" \ 13 | --finetune-config.training.train_batch_size "4" \ 14 | --finetune-config.training.dreambooth_prior_preservation_loss_weight "0.1" \ 15 | --finetune_config.training.noise_prediction_type "epsilon" \ 16 | --finetune_config.training.prob_images_not_noisy "0.25" \ 17 | --finetune_config.training.max_num_images_not_noisy "2" \ 18 | --finetune_config.training.validation_epochs "1" \ 19 | --finetune_config.training.dreambooth_prior_preservation_every_nth "1" \ 20 | --finetune-config.optimizer.learning_rate "5e-5" \ 21 | --finetune-config.optimizer.vol_rend_learning_rate "1e-3" \ 22 | --finetune-config.optimizer.vol_rend_adam_weight_decay "0.0" \ 23 | --finetune-config.optimizer.gradient_accumulation_steps "8" \ 24 | --finetune-config.optimizer.max_grad_norm "5e-3" \ 25 | --finetune-config.cross_frame_attention.to_k_other_frames "4" \ 26 | --finetune-config.cross_frame_attention.random_others \ 27 | --finetune-config.cross_frame_attention.with_self_attention \ 28 | --finetune-config.cross_frame_attention.use_temb_cond \ 29 | --finetune-config.cross_frame_attention.mode "pretrained" \ 30 | --finetune-config.cross_frame_attention.n_cfa_down_blocks "1" \ 31 | --finetune-config.cross_frame_attention.n_cfa_up_blocks "1" \ 32 | --finetune-config.cross_frame_attention.unproj_reproj_mode "with_cfa" \ 33 | --finetune-config.cross_frame_attention.num_3d_layers "5" \ 34 | --finetune-config.cross_frame_attention.dim_3d_latent "16" \ 35 | --finetune-config.cross_frame_attention.dim_3d_grid "128" \ 36 | --finetune-config.cross_frame_attention.n_novel_images "1" \ 37 | --finetune-config.cross_frame_attention.vol_rend_proj_in_mode "multiple" \ 38 | --finetune-config.cross_frame_attention.vol_rend_proj_out_mode "multiple" \ 39 | --finetune-config.cross_frame_attention.vol_rend_aggregator_mode "ibrnet" \ 40 | --finetune-config.cross_frame_attention.last_layer_mode "no_residual_connection" \ 41 | --finetune_config.cross_frame_attention.vol_rend_model_background \ 42 | --finetune_config.cross_frame_attention.vol_rend_background_grid_percentage "0.5" \ 43 | --finetune-config.model.pose_cond_mode "sa-ca" \ 44 | --finetune-config.model.pose_cond_coord_space "absolute" \ 45 | --finetune-config.model.pose_cond_lora_rank "64" \ 46 | --finetune-config.model.n_input_images "5" \ 47 | --dataset-config.co3d-root $CO3DV2_DATASET_ROOT \ 48 | --dataset-config.category $4 \ 49 | --dataset-config.max_sequences "500" \ 50 | --dataset-config.batch.load_recentered \ 51 | --dataset-config.batch.use_blip_prompt \ 52 | --dataset-config.batch.crop "random" \ 53 | --dataset-config.batch.image_width "256" \ 54 | --dataset-config.batch.image_height "256" \ 55 | --dataset-config.batch.other_selection 
"mix" \ 56 | --validation-dataset-config.co3d-root $CO3DV2_DATASET_ROOT \ 57 | --validation-dataset-config.category "teddybear" \ 58 | --validation-dataset-config.max_sequences "1" \ 59 | --validation-dataset-config.batch.load_recentered \ 60 | --validation-dataset-config.batch.use_blip_prompt \ 61 | --validation-dataset-config.batch.crop "random" \ 62 | --validation-dataset-config.batch.image_width "256" \ 63 | --validation-dataset-config.batch.image_height "256" \ 64 | --validation-dataset-config.dataset_args.n_frames_per_sequence "5" 65 | -------------------------------------------------------------------------------- /CODE_OF_CONDUCT.md: -------------------------------------------------------------------------------- 1 | # Code of Conduct 2 | 3 | ## Our Pledge 4 | 5 | In the interest of fostering an open and welcoming environment, we as 6 | contributors and maintainers pledge to make participation in our project and 7 | our community a harassment-free experience for everyone, regardless of age, body 8 | size, disability, ethnicity, sex characteristics, gender identity and expression, 9 | level of experience, education, socio-economic status, nationality, personal 10 | appearance, race, religion, or sexual identity and orientation. 11 | 12 | ## Our Standards 13 | 14 | Examples of behavior that contributes to creating a positive environment 15 | include: 16 | 17 | * Using welcoming and inclusive language 18 | * Being respectful of differing viewpoints and experiences 19 | * Gracefully accepting constructive criticism 20 | * Focusing on what is best for the community 21 | * Showing empathy towards other community members 22 | 23 | Examples of unacceptable behavior by participants include: 24 | 25 | * The use of sexualized language or imagery and unwelcome sexual attention or 26 | advances 27 | * Trolling, insulting/derogatory comments, and personal or political attacks 28 | * Public or private harassment 29 | * Publishing others' private information, such as a physical or electronic 30 | address, without explicit permission 31 | * Other conduct which could reasonably be considered inappropriate in a 32 | professional setting 33 | 34 | ## Our Responsibilities 35 | 36 | Project maintainers are responsible for clarifying the standards of acceptable 37 | behavior and are expected to take appropriate and fair corrective action in 38 | response to any instances of unacceptable behavior. 39 | 40 | Project maintainers have the right and responsibility to remove, edit, or 41 | reject comments, commits, code, wiki edits, issues, and other contributions 42 | that are not aligned to this Code of Conduct, or to ban temporarily or 43 | permanently any contributor for other behaviors that they deem inappropriate, 44 | threatening, offensive, or harmful. 45 | 46 | ## Scope 47 | 48 | This Code of Conduct applies within all project spaces, and it also applies when 49 | an individual is representing the project or its community in public spaces. 50 | Examples of representing a project or community include using an official 51 | project e-mail address, posting via an official social media account, or acting 52 | as an appointed representative at an online or offline event. Representation of 53 | a project may be further defined and clarified by project maintainers. 54 | 55 | This Code of Conduct also applies outside the project spaces when there is a 56 | reasonable belief that an individual's behavior may have a negative impact on 57 | the project or its community. 
58 | 59 | ## Enforcement 60 | 61 | Instances of abusive, harassing, or otherwise unacceptable behavior may be 62 | reported by contacting the project team at . All 63 | complaints will be reviewed and investigated and will result in a response that 64 | is deemed necessary and appropriate to the circumstances. The project team is 65 | obligated to maintain confidentiality with regard to the reporter of an incident. 66 | Further details of specific enforcement policies may be posted separately. 67 | 68 | Project maintainers who do not follow or enforce the Code of Conduct in good 69 | faith may face temporary or permanent repercussions as determined by other 70 | members of the project's leadership. 71 | 72 | ## Attribution 73 | 74 | This Code of Conduct is adapted from the [Contributor Covenant][homepage], version 1.4, 75 | available at https://www.contributor-covenant.org/version/1/4/code-of-conduct.html 76 | 77 | [homepage]: https://www.contributor-covenant.org 78 | 79 | For answers to common questions about this code of conduct, see 80 | https://www.contributor-covenant.org/faq 81 | -------------------------------------------------------------------------------- /viewdiff/data/co3d/generate_co3d_dreambooth_data.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | 3 | from typing import Tuple 4 | import os 5 | import random 6 | from tqdm.auto import tqdm 7 | import json 8 | import torch 9 | import tyro 10 | from diffusers import StableDiffusionPipeline, DPMSolverMultistepScheduler 11 | 12 | 13 | def load_sd(pretrained_model_name_or_path: str = "stabilityai/stable-diffusion-2-1-base", device: str = "cuda") -> StableDiffusionPipeline: 14 | pipe = StableDiffusionPipeline.from_pretrained(pretrained_model_name_or_path, torch_dtype=torch.float16).to(device) 15 | pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config) 16 | pipe.set_progress_bar_config(disable=True) 17 | return pipe 18 | 19 | 20 | def load_prompts(prompt_file: str): 21 | blip_prompts_dict = {} 22 | with open(prompt_file, "r") as ff: 23 | blip_prompts = json.load(ff) 24 | for category, sequence_dict in blip_prompts.items(): 25 | if category not in blip_prompts_dict: 26 | blip_prompts_dict[category] = {} 27 | for sequence, prompts in sequence_dict.items(): 28 | if sequence not in blip_prompts_dict[category]: 29 | blip_prompts_dict[category][sequence] = [] 30 | blip_prompts_dict[category][sequence].extend(prompts) 31 | 32 | return blip_prompts_dict 33 | 34 | 35 | def main( 36 | prompt_file: str, 37 | output_path: str, 38 | pretrained_model_name_or_path: str = "stabilityai/stable-diffusion-2-1-base", 39 | device: str = "cuda", 40 | max_sequences_per_category: int = 300, 41 | max_prompts_per_sequence: int = 2, 42 | num_images_per_prompt: int = 2, 43 | selected_categories: Tuple[str, ...] 
= () 44 | ): 45 | # load 46 | pipe = load_sd(pretrained_model_name_or_path, device) 47 | prompts = load_prompts(prompt_file) 48 | 49 | # setup output 50 | os.makedirs(output_path, exist_ok=True) 51 | 52 | # iterate all categories 53 | for category, sequence_dict in prompts.items(): 54 | # check if category is selected 55 | if len(selected_categories) > 0 and category not in selected_categories: 56 | continue 57 | 58 | # setup output dir 59 | category_out = os.path.join(output_path, category) 60 | os.makedirs(category_out, exist_ok=True) 61 | image_to_prompt = {} 62 | 63 | # subsample sequences 64 | sequences = sequence_dict.keys() 65 | n_sequences = len(sequences) 66 | sequences = random.sample( 67 | sequences, 68 | k=min(n_sequences, max_sequences_per_category), 69 | ) 70 | 71 | # generate for all remaining sequences 72 | for s in tqdm(sequences, desc=f"Generate for {category}"): 73 | # subsample prompts 74 | prompts = sequence_dict[s] 75 | n_prompts = len(prompts) 76 | prompts = random.sample( 77 | prompts, 78 | k=min(n_prompts, max_prompts_per_sequence), 79 | ) 80 | 81 | # generate + save next images 82 | image_counter = 0 83 | for p in prompts: 84 | images = pipe(p, num_images_per_prompt=num_images_per_prompt).images 85 | for img in images: 86 | image_name = f"{s}_{image_counter}" 87 | out_file = os.path.join(category_out, f"{image_name}.jpg") 88 | image_to_prompt[image_name] = p 89 | with open(out_file, "wb") as f: 90 | img.save(f) 91 | image_counter += 1 92 | 93 | # save image_to_prompt file 94 | image_to_prompt_file = os.path.join(category_out, "image_to_prompt.json") 95 | with open(image_to_prompt_file, "w") as f: 96 | json.dump(image_to_prompt, f, indent=4) 97 | 98 | 99 | if __name__ == '__main__': 100 | tyro.cli(main) 101 | -------------------------------------------------------------------------------- /viewdiff/data/co3d/generate_blip2_captions.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 
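#
# Generates BLIP-2 captions for CO3D frames. As implemented in generate_blip2_captions() and
# save_captions_file() below, the result is a timestamped JSON file of the form
# {category: {sequence_name: [caption, ...]}}, with intermediate snapshots written every
# 5000 batches so a crashed run does not lose all progress.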
2 | 3 | import torch 4 | from torch.utils.data import DataLoader 5 | from datetime import datetime 6 | import os 7 | import json 8 | import tyro 9 | from tqdm.auto import tqdm 10 | from transformers import Blip2Processor, Blip2ForConditionalGeneration 11 | 12 | from ...io_util import torch_to_pil 13 | from .co3d_dataset import CO3DConfig, CO3DDataset 14 | 15 | 16 | def load_blip2_model(pretrained_model_name_or_path: str = "Salesforce/blip2-opt-2.7b", device: str = "cuda"): 17 | processor = Blip2Processor.from_pretrained(pretrained_model_name_or_path) 18 | model = Blip2ForConditionalGeneration.from_pretrained(pretrained_model_name_or_path, torch_dtype=torch.float16) 19 | model = model.to(device) 20 | 21 | return model, processor 22 | 23 | 24 | def save_captions_file(captions, output_file: str, intermediate: bool = False): 25 | date_time = datetime.now().strftime("%d.%m.%Y_%H:%M:%S.%f") 26 | output_file_parts = output_file.split(".") 27 | output_file_without_suffix = ".".join(output_file_parts[:-1]) 28 | output_file_without_suffix += f"_{date_time}" 29 | if intermediate: 30 | output_file_without_suffix += f"_intermediate" 31 | output_file_with_time = f"{output_file_without_suffix}.{output_file_parts[-1]}" 32 | with open(output_file_with_time, "w") as f: 33 | json.dump(captions, f, indent=4) 34 | 35 | 36 | @torch.no_grad() 37 | def generate_blip2_captions( 38 | dataset_config: CO3DConfig, 39 | pretrained_model_name_or_path: str = "Salesforce/blip2-opt-2.7b", 40 | device: str = "cuda", 41 | batch_size: int = 4, 42 | output_file: str = "co3d_blip2_captions.json", 43 | ): 44 | # load blip2 model 45 | model, processor = load_blip2_model(pretrained_model_name_or_path, device) 46 | 47 | # make sure the important fields are set correctly 48 | dataset_config.dataset_args.load_point_clouds = False 49 | dataset_config.batch.load_recentered = False 50 | dataset_config.batch.need_mask_augmentations = False 51 | dataset_config.batch.n_parallel_images = 1 52 | dataset_config.dataset_args.n_frames_per_sequence = 5 53 | 54 | # can make it square here already s.t. 
collate works - processor makes it square anyways 55 | dataset_config.batch.crop = "resize" 56 | dataset_config.batch.image_height = 512 57 | dataset_config.batch.image_width = 512 58 | 59 | # Get the dataset: parse CO3Dv2 60 | dataset = CO3DDataset(dataset_config) 61 | dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=0) 62 | 63 | # loop over data 64 | captions = {} 65 | for idx, batch in enumerate(tqdm(dataloader, desc="Generate Captions")): 66 | # get sequence, category from batch 67 | sequences = [os.path.basename(x) for x in batch["root"]] 68 | categories = [os.path.basename(os.path.dirname(x)) for x in batch["root"]] 69 | 70 | # get image from batch 71 | images = batch["images"] # (batch_size, K=1, C, H, W) 72 | images = images.squeeze() # (batch_size, C, H, W) 73 | images = [torch_to_pil(x) for x in images] # processor expects PIL images 74 | 75 | # run captioning 76 | inputs = processor(images=images, return_tensors="pt").to(device, torch.float16) 77 | generated_ids = model.generate(**inputs) 78 | generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True) 79 | generated_text = [s.strip() for s in generated_text] 80 | 81 | # save captions 82 | for c, s, p in zip(categories, sequences, generated_text): 83 | if c not in captions: 84 | captions[c] = {} 85 | if s not in captions[c]: 86 | captions[c][s] = [] 87 | captions[c][s].append(p) 88 | 89 | # save intermediate outputs in case this crashes at some point 90 | if idx % 5000 == 0: 91 | save_captions_file(captions, output_file, intermediate=True) 92 | 93 | # save final file 94 | save_captions_file(captions, output_file) 95 | 96 | 97 | if __name__ == "__main__": 98 | tyro.cli(generate_blip2_captions) 99 | -------------------------------------------------------------------------------- /viewdiff/scripts/misc/create_masked_images.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | 3 | from typing import List, Dict, Any 4 | import cv2 5 | import tyro 6 | import os 7 | from tqdm.auto import tqdm 8 | import numpy as np 9 | import io 10 | import torch 11 | from PIL import Image 12 | 13 | class BackgroundRemoval: 14 | def __init__(self, device='cuda'): 15 | from carvekit.api.high import HiInterface 16 | self.interface = HiInterface( 17 | object_type="object", # Can be "object" or "hairs-like". 18 | batch_size_seg=5, 19 | batch_size_matting=1, 20 | device=device, 21 | seg_mask_size=640, # Use 640 for Tracer B7 and 320 for U2Net 22 | matting_mask_size=2048, 23 | trimap_prob_threshold=231, 24 | trimap_dilation=30, 25 | trimap_erosion_iters=5, 26 | fp16=True, 27 | ) 28 | 29 | @torch.no_grad() 30 | def __call__(self, image): 31 | # image: [H, W, 3] array in [0, 255]. 
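# Note: the carvekit HiInterface returns an RGBA PIL image whose alpha channel is 0 for
# background pixels; the bg_mask computation below relies on that fourth channel.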
32 | image = Image.fromarray(image) 33 | image = self.interface([image])[0] 34 | image = np.array(image) 35 | bg_mask = image[..., 3:4] == 0 36 | return image, bg_mask 37 | 38 | 39 | def load_carvekit_bkgd_removal_model(checkpoint_dir: str, device: str = "cuda"): 40 | if checkpoint_dir is not None: 41 | import carvekit.ml.files as cmf 42 | from pathlib import Path 43 | cmf.checkpoints_dir = Path(checkpoint_dir) 44 | return BackgroundRemoval(device=device) 45 | 46 | 47 | def remove_background(image, mask_predictor): 48 | # predict mask 49 | rgba, mask = mask_predictor(image) # [H, W, 4] 50 | 51 | # remove salt&pepper noise from mask (a common artifact of this method) 52 | kernel = np.ones((5, 5), np.uint8) 53 | mask = cv2.dilate(mask.astype(np.uint8), kernel, iterations=1) 54 | mask = cv2.erode(mask, kernel, iterations=1) 55 | mask = mask.astype(bool)[..., None] 56 | 57 | # white background 58 | rgb = rgba[..., :3] * (1 - mask) + 255 * mask 59 | rgb = rgb.astype(np.uint8) 60 | 61 | return rgb, mask 62 | 63 | 64 | def segment_and_save_carvekit(image_path, mask_path, masked_image_path, mask_predictor): 65 | # read image 66 | with open(image_path, "rb") as f: 67 | array = np.asarray(bytearray(f.read()), dtype=np.uint8) 68 | image = cv2.imdecode(array, cv2.IMREAD_UNCHANGED) 69 | 70 | # convert image 71 | if image.shape[-1] == 4: 72 | image = cv2.cvtColor(image, cv2.COLOR_BGRA2RGB) 73 | else: 74 | image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) 75 | 76 | rgb, mask = remove_background(image, mask_predictor) 77 | 78 | # convert mask to uint8 79 | mask = (1 - mask[..., 0].astype(np.uint8)) * 255 80 | 81 | # save rgb image 82 | with open(masked_image_path, "wb") as f: 83 | is_success, buffer = cv2.imencode(".png", cv2.cvtColor(rgb, cv2.COLOR_RGB2BGR)) 84 | if not is_success: 85 | print("Could not write", mask_path) 86 | else: 87 | io_buf = io.BytesIO(buffer) 88 | f.write(io_buf.getbuffer()) 89 | 90 | # save mask 91 | with open(mask_path, "wb") as f: 92 | is_success, buffer = cv2.imencode(".png", mask) 93 | if not is_success: 94 | print("Could not write", mask_path) 95 | else: 96 | io_buf = io.BytesIO(buffer) 97 | f.write(io_buf.getbuffer()) 98 | 99 | 100 | def main(run_folder: str, carvekit_checkpoint_dir: str, runs_offset: int = 0, max_n_runs: int = -1): 101 | # get all runs 102 | runs = os.listdir(run_folder) 103 | runs = [os.path.join(run_folder, f, "images") for f in runs] 104 | runs = [f for f in runs if os.path.isdir(f)] 105 | 106 | # filter runs 107 | runs = runs[runs_offset:runs_offset+max_n_runs] 108 | print("mask these runs", runs) 109 | 110 | # load carvekit model 111 | carvekit_model = load_carvekit_bkgd_removal_model(carvekit_checkpoint_dir) 112 | 113 | for input_image_folder in tqdm(runs, desc="Create masked images"): 114 | # setup output dir for masks 115 | output_mask_folder = os.path.join(input_image_folder, "masks") 116 | os.makedirs(output_mask_folder, exist_ok=True) 117 | 118 | # setup output dir for images 119 | output_image_folder = os.path.join(input_image_folder, "masked_images") 120 | os.makedirs(output_image_folder, exist_ok=True) 121 | 122 | # segment each image 123 | images = [f for f in os.listdir(input_image_folder) if ".png" in f and "pred_file" in f] 124 | for image in tqdm(images, desc="mask image", leave=True): 125 | segment_and_save_carvekit( 126 | image_path=os.path.join(input_image_folder, image), 127 | mask_path=os.path.join(output_mask_folder, image), 128 | masked_image_path=os.path.join(output_image_folder, image), 129 | mask_predictor=carvekit_model 130 | ) 
131 | 132 | 133 | if __name__ == '__main__': 134 | tyro.cli(main) 135 | -------------------------------------------------------------------------------- /viewdiff/convert_checkpoint_to_model.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | 3 | import os 4 | import json 5 | import tyro 6 | import shutil 7 | 8 | from accelerate import Accelerator 9 | from accelerate.utils import ProjectConfiguration 10 | 11 | from .model.custom_stable_diffusion_pipeline import CustomStableDiffusionPipeline 12 | 13 | from .train_util import FinetuneConfig, unet_attn_processors_state_dict, load_models 14 | from diffusers.loaders import LoraLoaderMixin 15 | 16 | from dacite import from_dict, Config 17 | 18 | from .train import update_model 19 | 20 | 21 | def convert_checkpoint_to_model( 22 | checkpoint_path: str, keep_config_output_dir: bool = False, pretrained_model_name_or_path: str = "stabilityai/stable-diffusion-2-1-base" 23 | ): 24 | # load config from checkpoint_path 25 | if checkpoint_path[-1] == "/": 26 | # cannot have trailing slash in checkpoint path for dirname 27 | checkpoint_path = checkpoint_path[:-1] 28 | root_dir = os.path.dirname(checkpoint_path) 29 | ckpt_name = os.path.basename(checkpoint_path) 30 | if checkpoint_path[-1] != "/": 31 | # need trailing slash in checkpoint path 32 | checkpoint_path += "/" 33 | config_path = os.path.join(root_dir, "config.json") 34 | 35 | if not os.path.isfile(str(config_path)): 36 | raise ValueError("cannot find config.json in ", config_path) 37 | 38 | with open(config_path, "r") as f: 39 | config_data = json.load(f) 40 | finetune_config = from_dict(FinetuneConfig, data=config_data, config=Config(cast=[tuple, int])) 41 | 42 | if not keep_config_output_dir: 43 | finetune_config.io.output_dir = os.path.join(root_dir, f"saved_model_from_{ckpt_name}") 44 | 45 | if pretrained_model_name_or_path is not None: 46 | finetune_config.io.pretrained_model_name_or_path = pretrained_model_name_or_path 47 | 48 | # setup run 49 | accelerator_project_config = ProjectConfiguration( 50 | project_dir=finetune_config.io.output_dir, 51 | ) 52 | accelerator = Accelerator( 53 | gradient_accumulation_steps=finetune_config.optimizer.gradient_accumulation_steps, 54 | mixed_precision=finetune_config.training.mixed_precision, 55 | project_config=accelerator_project_config, 56 | ) 57 | 58 | # Load models. 
59 | _, _, text_encoder, vae, orig_unet = load_models( 60 | finetune_config.io.pretrained_model_name_or_path, revision=finetune_config.io.revision 61 | ) 62 | unet, unet_lora_parameters = update_model(finetune_config, orig_unet) 63 | model_cls = type(unet) 64 | orig_model_cls = type(orig_unet) 65 | 66 | def load_model_hook(models, input_dir): 67 | for _ in range(len(models)): 68 | # pop models so that they are not loaded again 69 | model = models.pop() 70 | 71 | # load diffusers style into model 72 | if isinstance(model, model_cls): 73 | in_dir = os.path.join(input_dir, "unet") 74 | load_model = model_cls.from_pretrained(in_dir) 75 | model.register_to_config(**load_model.config) 76 | model.load_state_dict(load_model.state_dict(), strict=unet_lora_parameters is None) 77 | del load_model 78 | 79 | if unet_lora_parameters is not None: 80 | try: 81 | lora_state_dict, network_alpha = LoraLoaderMixin.lora_state_dict(in_dir, weight_name="pytorch_lora_weights.safetensors") 82 | except: 83 | lora_state_dict, network_alpha = LoraLoaderMixin.lora_state_dict(in_dir, weight_name="pytorch_lora_weights.bin") 84 | lora_state_dict = {k.replace("unet.", ""): v for k, v in lora_state_dict.items()} 85 | model.load_state_dict(lora_state_dict, strict=False) 86 | print("Loaded LoRA weights into model") 87 | elif isinstance(model, orig_model_cls): 88 | in_dir = os.path.join(input_dir, "orig_unet") 89 | load_model = orig_model_cls.from_pretrained(in_dir) 90 | model.register_to_config(**load_model.config) 91 | model.load_state_dict(load_model.state_dict()) 92 | del load_model 93 | print("Loaded orig_unet model") 94 | else: 95 | raise ValueError(f"unexpected load model: {model.__class__}") 96 | 97 | accelerator.register_load_state_pre_hook(load_model_hook) 98 | unet = accelerator.prepare(unet) 99 | 100 | # load in the weights and states from a previous save 101 | accelerator.print(f"Load checkpoint {checkpoint_path}") 102 | accelerator.load_state(checkpoint_path) 103 | 104 | accelerator.wait_for_everyone() 105 | if accelerator.is_main_process: 106 | accelerator.print(f"Save model at", finetune_config.io.output_dir) 107 | # Create the pipeline using the trained modules and save it. 108 | unet = accelerator.unwrap_model(unet) 109 | 110 | pipeline = CustomStableDiffusionPipeline.from_pretrained( 111 | finetune_config.io.pretrained_model_name_or_path, 112 | text_encoder=text_encoder, 113 | vae=vae, 114 | unet=unet, 115 | revision=finetune_config.io.revision, 116 | ) 117 | pipeline.save_pretrained(finetune_config.io.output_dir) 118 | 119 | # save lora layers 120 | unet_lora_layers = unet_attn_processors_state_dict(unet) 121 | LoraLoaderMixin.save_lora_weights(save_directory=os.path.join(finetune_config.io.output_dir, "unet"), unet_lora_layers=unet_lora_layers) 122 | 123 | # copy config to new dir 124 | shutil.copy(config_path, os.path.join(finetune_config.io.output_dir, "config.json")) 125 | 126 | accelerator.print("Finished.") 127 | 128 | 129 | if __name__ == "__main__": 130 | tyro.cli(convert_checkpoint_to_model) 131 | -------------------------------------------------------------------------------- /README.MD: -------------------------------------------------------------------------------- 1 | # ViewDiff: 3D-Consistent Image Generation with Text-to-Image Models 2 | ViewDiff generates high-quality, multi-view consistent images of a real-world 3D object in authentic surroundings. 3 | 4 | This is the official repository that contains source code for the CVPR 2024 paper [ViewDiff](https://lukashoel.github.io/ViewDiff/). 
5 | 6 | [[arXiv](https://arxiv.org/abs/2403.01807)] [[Project Page](https://lukashoel.github.io/ViewDiff/)] [[Video](https://youtu.be/SdjoCqHzMMk)] 7 | 8 | ![Teaser](docs/teaser.jpg "ViewDiff") 9 | 10 | If you find ViewDiff useful for your work please cite: 11 | ``` 12 | @inproceedings{hoellein2024viewdiff, 13 | title={ViewDiff: 3D-Consistent Image Generation with Text-To-Image Models}, 14 | author={H{\"o}llein, Lukas and Bo\v{z}i\v{c}, Alja\v{z} and M{\"u}ller, Norman and Novotny, David and Tseng, Hung-Yu and Richardt, Christian and Zollh{\"o}fer, Michael and Nie{\ss}ner, Matthias}, 15 | booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition}, 16 | year={2024} 17 | } 18 | ``` 19 | 20 | ## Installation 21 | 22 | Create a conda environment with all required dependencies: 23 | 24 | ``` 25 | conda create -n viewdiff python=3.10 26 | conda activate viewdiff 27 | pip install -r requirements.txt 28 | ``` 29 | 30 | Then install Pytorch3D by following the official instructions. For example, to install Pytorch3D on Linux (tested with Pytorch3D 0.7.4): 31 | 32 | ``` 33 | conda install -c fvcore -c iopath -c conda-forge fvcore iopath 34 | pip install "git+https://github.com/facebookresearch/pytorch3d.git@stable" 35 | ``` 36 | 37 | Then manually update triton to the required version: 38 | 39 | ``` 40 | pip install --upgrade triton==2.1.0 41 | ``` 42 | 43 | ## Data Preparation 44 | 45 | - Download CO3D categories that you would like to train on. Follow the official instructions here: https://github.com/facebookresearch/co3d. You should end up with a directory structure like this: 46 | 47 | ``` 48 | 49 | /teddybear 50 | /hydrant 51 | /donut 52 | /apple 53 | ... 54 | ``` 55 | 56 | - Generate BLIP2 text captions from the images for each category: 57 | 58 | ``` 59 | export CO3DV2_DATASET_ROOT= 60 | python -m viewdiff.data.co3d.generate_blip2_captions --dataset-config.co3d_root --output_file /co3d_blip2_captions.json 61 | ``` 62 | 63 | - Generate the prior preservation (aka Dreambooth) dataset for each category: 64 | 65 | ``` 66 | export CO3DV2_DATASET_ROOT= 67 | python -m viewdiff.data.co3d.generate_co3d_dreambooth_data --prompt_file /co3d_blip2_captions.json --output_path /dreambooth_prior_preservation_dataset 68 | ``` 69 | 70 | - Recenter the poses of each object, such that the object lies within the unit cube: 71 | 72 | ``` 73 | export CO3DV2_DATASET_ROOT= 74 | python -m viewdiff.data.co3d.save_recentered_sequences --dataset-config.co3d_root 75 | ``` 76 | 77 | ## Training 78 | 79 | Execute the following script (requires 2x A100 80GB GPUs): 80 | 81 | ``` 82 | ./viewdiff/scripts/train.sh "stabilityai/stable-diffusion-2-1-base" outputs/train 83 | ``` 84 | 85 | If you only have a smaller GPU available and want to sanity check that everything is working, you can execute this script (e.g. on a RTX 3090 GPU): 86 | 87 | ``` 88 | ./viewdiff/scripts/train_small.sh "stabilityai/stable-diffusion-2-1-base" outputs/train 89 | ``` 90 | 91 | In our experiments, we train the model for 60K iterations. 
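Both training scripts read their inputs positionally (see the `$1`-`$4` references inside `train.sh` and `train_small.sh`): the CO3D dataset root, the base model, the output directory, and the training category. As a sketch, with illustrative placeholder names in angle brackets, a full invocation looks like:

```
./viewdiff/scripts/train.sh <co3d_root> "stabilityai/stable-diffusion-2-1-base" outputs/train <category>
```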
92 | 93 | ## Evaluation 94 | 95 | First, export a trained model to a runnable checkpoint: 96 | 97 | ``` 98 | python -m viewdiff.convert_checkpoint_to_model --checkpoint-path 99 | ``` 100 | 101 | ### Create 360 Degree Images In A Single Batch 102 | 103 | Execute the following script: 104 | 105 | ``` 106 | ./viewdiff/scripts/test/test_spherical_360_256x256.sh outputs/single-batch-uncond-generation 107 | ``` 108 | 109 | This creates ```num_images``` images of a single object in a single forward pass of the model (first row in the teaser image). 110 | In total ```num_steps``` objects will be created. 111 | 112 | ### Create Smooth Rendering Around An Object (Elevation=60 Degrees) 113 | 114 | Execute the following script: 115 | 116 | ``` 117 | ./viewdiff/scripts/test/test_sliding_window_smooth_alternating_theta_60_360_256x256.sh outputs/smooth-autoregressive-theta-60 118 | ``` 119 | 120 | This creates a video rendering of an object in a spherical trajectory at 60 degrees elevation. 121 | 122 | ### Create Smooth Rendering Around An Object (Elevation=30 Degrees) 123 | 124 | Execute the following script: 125 | 126 | ``` 127 | ./viewdiff/scripts/test/test_sliding_window_smooth_alternating_theta_30_360_256x256.sh outputs/smooth-autoregressive-theta-30 128 | ``` 129 | 130 | This creates a video rendering of an object in a spherical trajectory at 30 degrees elevation. 131 | 132 | ### Condition Generation On Single Input Image 133 | 134 | Execute the following script: 135 | 136 | ``` 137 | ./viewdiff/scripts/test/eval_single_image_input.sh outputs/single-image-eval 138 | ``` 139 | 140 | This renders novel views for an object of the test set given a single image input. 141 | We also save the quantitative metrics PSNR, SSIM, LPIPS in the output directory. 142 | 143 | ### Optimize a NeRF 144 | 145 | We provide an easy way to train a NeRF from our generated images. 146 | When creating a smooth rendering, we save a ```transforms.json``` file in the standard NeRF convention, that can be used to optimize a NeRF for the generated object. 147 | It can be used with standard NeRF frameworks like [Instant-NGP](https://github.com/NVlabs/instant-ngp) or [NeRFStudio](https://github.com/nerfstudio-project/nerfstudio). 148 | 149 | ## LICENSE 150 | 151 | The majority of this repository is licensed under CC-BY-NC, however portions of the project are available under separate license terms: 152 | - diffusers is licensed under the Apache 2.0 license. We use the repository to extend the default U-Net architecture, by adapting the model definition found in the original library. 153 | 154 | - Fastplane (aka [Lightplane](https://lightplane.github.io/)) module for memory-efficient volumetric rendering is licensed under the BSD License. Please consider citing the original paper if you find the module useful: 155 | ``` 156 | @article{cao2024lightplane, 157 | author = {Ang Cao and Justin Johnson and Andrea Vedaldi and David Novotny}, 158 | title = {Lightplane: Highly-Scalable Components for Neural 3D Fields}, 159 | journal = {arXiv}, 160 | year = {2024}, 161 | } 162 | ``` 163 | -------------------------------------------------------------------------------- /viewdiff/scripts/misc/calculate_mean_image_stats.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 
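#
# Aggregates the per-batch "*metrics*" JSON files in input_path (PSNR / LPIPS / SSIM) into three
# summaries, following the logic in main() below: per-frame averages over repeated generations,
# per-frame best scores, and scores computed from the averaged image. The results are written to
# combined_avg_metrics.json, combined_best_metrics.json and combined_avg-image_metrics.json,
# each including per-metric mean/std totals.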
2 | 3 | import numpy as np 4 | import torch 5 | import os 6 | import tyro 7 | import json 8 | 9 | 10 | @torch.no_grad() 11 | def main( 12 | input_path: str 13 | ): 14 | # get all metrics from several batches 15 | metrics_files = [os.path.join(input_path, x) for x in os.listdir(input_path) if "metrics" in x and "combined" not in x] 16 | 17 | # create per-frame average and maximum metrics, as well as the metrics from the "average-image" (see viewset-diffusion for more details...) 18 | avg_metrics_dict = {} 19 | best_metrics_dict = {} 20 | avg_image_metrics_dict = {} 21 | invalid_indices_dict = {} 22 | all_offset = 0 23 | for mf in metrics_files: 24 | # open metric file & read header infos 25 | with open(mf, "r") as f: 26 | mf = json.load(f) 27 | batch_size = mf["batch-size"] 28 | n_repeat_generation = mf["n_repeat_generation"] 29 | if "n_images_per_batch" in mf and "n_known_images_per_batch" in mf: 30 | n_images_per_batch = mf["n_images_per_batch"] - mf["n_known_images_per_batch"] 31 | else: 32 | n_images_per_batch = 1 33 | 34 | # these metrics should be in the file, loop over them... 35 | for metric_type in ["psnr", "lpips", "ssim"]: 36 | # create empty dicts for each metric in the final dicts 37 | if metric_type not in avg_metrics_dict: 38 | avg_metrics_dict[metric_type] = {} 39 | if metric_type not in best_metrics_dict: 40 | best_metrics_dict[metric_type] = {} 41 | if metric_type not in avg_image_metrics_dict: 42 | avg_image_metrics_dict[metric_type] = {} 43 | 44 | # go through all per-frame versions 45 | all_scores = mf[metric_type]["all"] 46 | assert len(all_scores) == n_repeat_generation 47 | for score_dict in all_scores: 48 | for key, val_list in score_dict.items(): 49 | assert len(val_list) == batch_size * n_images_per_batch 50 | 51 | # mark inf values as invalid 52 | for val_idx in range(len(val_list)): 53 | if val_list[val_idx] == float("inf"): 54 | if metric_type not in invalid_indices_dict: 55 | invalid_indices_dict[metric_type] = {} 56 | if key not in invalid_indices_dict[metric_type]: 57 | invalid_indices_dict[metric_type][key] = [] 58 | invalid_indices_dict[metric_type][key].append(val_idx + all_offset) 59 | 60 | # select best result per-frame from "all" 61 | if key not in best_metrics_dict[metric_type]: 62 | assert all_offset == 0 63 | best_metrics_dict[metric_type][key] = [] 64 | if len(best_metrics_dict[metric_type][key]) < (all_offset + batch_size * n_images_per_batch): 65 | best_metrics_dict[metric_type][key].extend([x for x in val_list]) 66 | else: 67 | for val_idx in range(len(val_list)): 68 | if metric_type == "psnr" or metric_type == "ssim": 69 | if val_list[val_idx] > best_metrics_dict[metric_type][key][all_offset + val_idx]: 70 | best_metrics_dict[metric_type][key][all_offset + val_idx] = val_list[val_idx] 71 | else: 72 | if val_list[val_idx] < best_metrics_dict[metric_type][key][all_offset + val_idx]: 73 | best_metrics_dict[metric_type][key][all_offset + val_idx] = val_list[val_idx] 74 | 75 | # sum results per-frame from "all" 76 | if key not in avg_metrics_dict[metric_type]: 77 | assert all_offset == 0 78 | avg_metrics_dict[metric_type][key] = [] 79 | if len(avg_metrics_dict[metric_type][key]) < (all_offset + batch_size * n_images_per_batch): 80 | avg_metrics_dict[metric_type][key].extend([x for x in val_list]) 81 | else: 82 | for val_idx in range(len(val_list)): 83 | avg_metrics_dict[metric_type][key][all_offset + val_idx] += val_list[val_idx] 84 | 85 | # calc per-frame average 86 | for key in avg_metrics_dict[metric_type].keys(): 87 | for val_idx in 
range(batch_size * n_images_per_batch): 88 | avg_metrics_dict[metric_type][key][all_offset + val_idx] /= n_repeat_generation 89 | 90 | # concat together the scores from "from-avg-image" 91 | avg_scores = mf[metric_type]["from-avg-image"] 92 | for key, val_list in avg_scores.items(): 93 | if key not in avg_image_metrics_dict[metric_type]: 94 | avg_image_metrics_dict[metric_type][key] = [] 95 | val_list = [x for x in val_list if x != float("inf")] # filter inf values from val_list 96 | avg_image_metrics_dict[metric_type][key].extend([x for x in val_list]) 97 | 98 | # next values from next metrics dict should be written at this offset 99 | all_offset += batch_size * n_images_per_batch 100 | 101 | # calc average across frames for all dicts 102 | for dict_key, metrics_dict in zip(["avg", "best", "avg-image"], [avg_metrics_dict, best_metrics_dict, avg_image_metrics_dict]): 103 | total_dict = {} 104 | for metric_type, metric_dict in metrics_dict.items(): 105 | out_metric_type = f"0-total-{metric_type}" 106 | total_dict[out_metric_type] = {} 107 | for key, val_list in metric_dict.items(): 108 | # filter inf values 109 | if dict_key == "avg" or dict_key == "best": 110 | if metric_type in invalid_indices_dict and key in invalid_indices_dict[metric_type]: 111 | print(f"ignore indices for {dict_key}-{metric_type}-{key}: {invalid_indices_dict[metric_type][key]}. values: {[val_list[x] for x in invalid_indices_dict[metric_type][key]]}") 112 | val_list = [x for i, x in enumerate(val_list) if i not in invalid_indices_dict[metric_type][key]] 113 | 114 | # calc mean/std 115 | val_list = np.array(val_list) 116 | mean_val = val_list.mean() 117 | std_val = val_list.std() 118 | total_dict[out_metric_type][key] = { 119 | "mean": mean_val, 120 | "std": std_val 121 | } 122 | metrics_dict = {**total_dict, **metrics_dict} 123 | 124 | # save file 125 | output_file_path = os.path.join(input_path, f"combined_{dict_key}_metrics.json") 126 | with open(output_file_path, "w") as f: 127 | json.dump(metrics_dict, f, indent=4, sort_keys=True) 128 | 129 | 130 | if __name__ == '__main__': 131 | tyro.cli(main) 132 | -------------------------------------------------------------------------------- /viewdiff/data/co3d/util.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | 3 | from typing import Tuple 4 | from omegaconf import DictConfig 5 | 6 | import os 7 | import torch 8 | 9 | from pytorch3d.implicitron.dataset.json_index_dataset_map_provider_v2 import JsonIndexDatasetMapProviderV2 10 | from pytorch3d.implicitron.dataset.json_index_dataset import JsonIndexDataset 11 | from pytorch3d.implicitron.tools.config import expand_args_fields 12 | 13 | 14 | def json_index_dataset_load_category(dataset_root: str, category: str, dataset_args: DictConfig): 15 | # adapted from https://github.com/facebookresearch/pytorch3d/blob/main/pytorch3d/implicitron/dataset/json_index_dataset_map_provider_v2.py 16 | frame_file = os.path.join( 17 | dataset_root, 18 | category, 19 | "frame_annotations.jgz", 20 | ) 21 | sequence_file = os.path.join( 22 | dataset_root, 23 | category, 24 | "sequence_annotations.jgz", 25 | ) 26 | 27 | if not os.path.isfile(frame_file): 28 | # The frame_file does not exist. 29 | # Most probably the user has not specified the root folder. 30 | raise ValueError( 31 | f"Looking for frame annotations in {frame_file}." + " Please specify a correct dataset_root folder." 
32 | ) 33 | 34 | # setup the common dataset arguments 35 | common_dataset_kwargs = { 36 | **dataset_args, 37 | "dataset_root": dataset_root, 38 | "frame_annotations_file": frame_file, 39 | "sequence_annotations_file": sequence_file, 40 | "subsets": None, 41 | "subset_lists_file": "", 42 | } 43 | 44 | # get the used dataset type 45 | expand_args_fields(JsonIndexDataset) 46 | dataset = JsonIndexDataset(**common_dataset_kwargs) 47 | 48 | return dataset 49 | 50 | 51 | def get_dataset(co3d_root, category, subset, split, **dataset_kw_args): 52 | print(f"start parse dataset for category={category}, subset={subset}") 53 | dataset_args = DictConfig(dataset_kw_args) 54 | 55 | if subset is None: 56 | # directly load JsonIndexDataset, do not use subset 57 | dataset = json_index_dataset_load_category(dataset_root=co3d_root, category=category, dataset_args=dataset_args) 58 | else: 59 | # use subset with JsonIndexDatasetMapProviderV2 60 | expand_args_fields(JsonIndexDatasetMapProviderV2) 61 | dataset_map = JsonIndexDatasetMapProviderV2( 62 | category=category, # load this category 63 | subset_name=subset, # load all sequences/frames that are specified in this subset 64 | test_on_train=False, # want to load the actual test data 65 | only_test_set=False, # want to have train/val/test splits accessible 66 | load_eval_batches=False, # for generation do not need eval batches, rather go through "test" split of each sequence 67 | dataset_JsonIndexDataset_args=dataset_args, 68 | ).get_dataset_map() 69 | dataset = dataset_map[split] 70 | print("finish parse dataset") 71 | 72 | return dataset 73 | 74 | 75 | def has_pointcloud(co3d_root: str, category: str, sequence_name: str) -> bool: 76 | """checks if the specified sequence has a pointcloud object in the dataset. 77 | 78 | Args: 79 | co3d_root (str): root dir of dataset 80 | category (str): category of sequence 81 | sequence_name (str): sequence to check 82 | 83 | Returns: 84 | bool: True if the pointcloud exists, else False.
85 | """ 86 | return os.path.exists(os.path.join(co3d_root, category, sequence_name, "pointcloud.ply")) 87 | 88 | 89 | def get_crop_around_mask(mask, th, tw): 90 | # sanity checks 91 | h, w = mask.shape 92 | assert ( 93 | h >= th and w >= tw 94 | ), f"crop height/width must not be larger than original height/width: orig=({h}, {w}), crop=({th}, {tw})" 95 | if th > h or tw > w: 96 | raise ValueError(f"Required crop size {(th, tw)} is larger than input image size {(h, w)}") 97 | 98 | # get mask center coordinate 99 | coords = torch.nonzero(mask).float() 100 | mean_coord = torch.mean(coords, dim=0).int() 101 | 102 | # get top/left corner of crop rectangle 103 | top = max(0, mean_coord[0] - th // 2) 104 | left = max(0, mean_coord[1] - tw // 2) 105 | 106 | # check bounds 107 | top -= max(0, top + th - h) 108 | left -= max(0, left + tw - w) 109 | 110 | return top, left, th, tw 111 | 112 | 113 | def adjust_crop_size(orig_hw, crop_hw): 114 | # extract values 115 | h, w = orig_hw 116 | th, tw = crop_hw 117 | 118 | # adjust crop_size such that the larger crop_size has the same size as orig_size 119 | scale = min(h / th, w / tw) 120 | new_h = int(th * scale) 121 | new_w = int(tw * scale) 122 | 123 | return new_h, new_w 124 | 125 | 126 | def scale_intrinsics(K: torch.Tensor, orig_hw: Tuple[int, int], resized_hw: Tuple[int, int]) -> torch.Tensor: 127 | scaling_factor_h = resized_hw[0] / orig_hw[0] 128 | scaling_factor_w = resized_hw[1] / orig_hw[1] 129 | 130 | K = K.clone() 131 | K[..., 0, 0] = K[..., 0, 0] * scaling_factor_w 132 | K[..., 1, 1] = K[..., 1, 1] * scaling_factor_h 133 | 134 | # K[..., 0, 2] = K[..., 0, 2] * scaling_factor_w 135 | # K[..., 1, 2] = K[..., 1, 2] * scaling_factor_h 136 | 137 | # We assume opencv-convention ((0, 0) refers to the center of the top-left pixel and (-0.5, -0.5) is the top-left corner of the image-plane): 138 | # We need to scale the principal offset with the 0.5 add/sub like here. 139 | # see this for explanation: https://dsp.stackexchange.com/questions/6055/how-does-resizing-an-image-affect-the-intrinsic-camera-matrix 140 | K[..., 0, 2] = (K[..., 0, 2] + 0.5) * scaling_factor_w - 0.5 141 | K[..., 1, 2] = (K[..., 1, 2] + 0.5) * scaling_factor_h - 0.5 142 | 143 | return K 144 | 145 | 146 | def scale_depth( 147 | depth: torch.Tensor, 148 | orig_K: torch.Tensor, 149 | orig_hw: Tuple[int, int], 150 | resized_K: torch.Tensor, 151 | resized_hw: Tuple[int, int], 152 | ) -> torch.Tensor: 153 | # downsampling depth w/ nearest interpolation results in wrong coordinates 154 | # this projects pixel centers from the resized to the original image-plane and grid_samples the depth value at those positions (using nearest interpolation). 
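# In other words, for every pixel center (u, v) of the resized image the code below computes
#   x = (u - cx_r) / fx_r,  y = (v - cy_r) / fy_r,  u_o = fx_o * x + cx_o,  v_o = fy_o * y + cy_o,
# where the *_r values come from `resized_K` and the *_o values from `orig_K` (these short names
# are used only in this comment). The resulting (u_o, v_o) is then mapped from [-0.5, size - 0.5]
# to the [-1, 1] range expected by `grid_sample(..., align_corners=False)`, and the original depth
# is sampled at those locations with nearest-neighbor interpolation.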
155 | 156 | # construct coordinate grid 157 | # convention: (0, 0) is the center of the pixel 158 | x = torch.arange(resized_hw[1], device=depth.device, dtype=torch.float32) 159 | y = torch.arange(resized_hw[0], device=depth.device, dtype=torch.float32) 160 | 161 | grid_x, grid_y = torch.meshgrid(x, y, indexing="xy") 162 | grid_xy = torch.cat( 163 | [ 164 | grid_x.view(resized_hw[0], resized_hw[1], 1), 165 | grid_y.view(resized_hw[0], resized_hw[1], 1), 166 | ], 167 | dim=2, 168 | ) 169 | grid_xy = grid_xy[None].repeat(depth.shape[0], 1, 1, 1) 170 | 171 | # apply inverse scaled intrinsics 172 | grid_xy[..., 0] = (1.0 / resized_K[..., 0, 0][..., None, None]) * ( 173 | grid_xy[..., 0] - resized_K[..., 0, 2][..., None, None] 174 | ) 175 | grid_xy[..., 1] = (1.0 / resized_K[..., 1, 1][..., None, None]) * ( 176 | grid_xy[..., 1] - resized_K[..., 1, 2][..., None, None] 177 | ) 178 | 179 | # apply original intrinsics 180 | grid_xy[..., 0] = orig_K[..., 0, 0][..., None, None] * grid_xy[..., 0] + orig_K[..., 0, 2][..., None, None] 181 | grid_xy[..., 1] = orig_K[..., 1, 1][..., None, None] * grid_xy[..., 1] + orig_K[..., 1, 2][..., None, None] 182 | 183 | # go from [-0.5, orig_hw - 0.5] to [-1, 1] 184 | grid_xy[..., 0] = (grid_xy[..., 0] + 0.5) / orig_hw[1] * 2 - 1.0 185 | grid_xy[..., 1] = (grid_xy[..., 1] + 0.5) / orig_hw[0] * 2 - 1.0 186 | 187 | # grid_sample with coordinates in orig screen-space 188 | scaled_depth = torch.nn.functional.grid_sample( 189 | depth.unsqueeze(1), 190 | grid_xy, 191 | mode="nearest", 192 | padding_mode="zeros", 193 | align_corners=False, 194 | ).squeeze(1) 195 | 196 | return scaled_depth 197 | 198 | 199 | def scale_camera_center(world2cam: torch.Tensor, scale: torch.Tensor) -> torch.Tensor: 200 | x = world2cam.clone() 201 | x[..., :3, 3:4] *= scale[..., None, None] 202 | 203 | return x 204 | 205 | 206 | def scale_bbox(bbox: torch.Tensor) -> torch.Tensor: 207 | # bbox assumed to be (N, 2, 3) 208 | largest_side_length = torch.max(bbox[:, 1] - bbox[:, 0], dim=1).values # e.g. 1.0 209 | scale = 2 / largest_side_length # want largest side to be 2.0 210 | bbox = bbox * scale[..., None, None] # scale bbox uniformly 211 | return bbox, scale -------------------------------------------------------------------------------- /viewdiff/model/projection/fastplane/fastplane_module.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 (c) Meta Platforms, Inc. and affiliates. 2 | # Inspired by Lightplane: https://github.com/facebookresearch/lightplane/tree/main 3 | # 4 | # BSD License 5 | # 6 | # For Lightplane software 7 | # 8 | # Copyright (c) Meta Platforms, Inc. and affiliates. 9 | # 10 | # Redistribution and use in source and binary forms, with or without modification, 11 | # are permitted provided that the following conditions are met: 12 | # 13 | # * Redistributions of source code must retain the above copyright notice, this 14 | # list of conditions and the following disclaimer. 15 | # 16 | # * Redistributions in binary form must reproduce the above copyright notice, 17 | # this list of conditions and the following disclaimer in the documentation 18 | # and/or other materials provided with the distribution. 19 | # 20 | # * Neither the name Meta nor the names of its contributors may be used to 21 | # endorse or promote products derived from this software without specific 22 | # prior written permission. 
23 | # 24 | # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND 25 | # ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED 26 | # WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 27 | # DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR 28 | # ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES 29 | # (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; 30 | # LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON 31 | # ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT 32 | # (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS 33 | # SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 34 | 35 | import torch 36 | import math 37 | from typing import Optional, Tuple, Union 38 | from .fastplane_sig_function import ( 39 | fastplane, 40 | FastplaneShapeRepresentation, 41 | N_LAYERS, 42 | FastplaneActivationFun, 43 | ) 44 | 45 | 46 | MIN_KERNEL_RENDER_DIM = 16 47 | 48 | 49 | class FastplaneModule(torch.nn.Module): 50 | def __init__( 51 | self, 52 | mlp_n_hidden: int, 53 | render_dim: int, 54 | num_samples: int, 55 | num_samples_inf: int = 0, 56 | opacity_init_bias: float = -5.0, 57 | gain: float = 1.0, 58 | BLOCK_SIZE: int = 16, 59 | transmittance_thr: float = 0.0, 60 | mask_out_of_bounds_samples: bool = False, 61 | inject_noise_sigma: float = 0.0, 62 | inject_noise_seed: Optional[int] = None, 63 | contract_coords: bool = False, 64 | contract_perc_foreground: float = 0.5, 65 | disparity_at_inf: float = 1e-5, 66 | shape_representation: FastplaneShapeRepresentation = FastplaneShapeRepresentation.TRIPLANE, 67 | activation_fun: FastplaneActivationFun = FastplaneActivationFun.SOFTPLUS, 68 | bg_color: Union[Tuple[float, ...], float] = 0.0, 69 | ): 70 | super().__init__() 71 | 72 | self.num_samples = num_samples 73 | self.num_samples_inf = num_samples_inf 74 | self.opacity_init_bias = opacity_init_bias 75 | self.gain = gain 76 | self.BLOCK_SIZE = BLOCK_SIZE 77 | self.transmittance_thr = transmittance_thr 78 | self.mask_out_of_bounds_samples = mask_out_of_bounds_samples 79 | self.inject_noise_sigma = inject_noise_sigma 80 | self.inject_noise_seed = inject_noise_seed 81 | self.contract_coords = contract_coords 82 | self.contract_perc_foreground = contract_perc_foreground 83 | self.disparity_at_inf = disparity_at_inf 84 | self.shape_representation = shape_representation 85 | self.activation_fun = activation_fun 86 | self.render_dim = render_dim 87 | 88 | kernel_render_dim = max(MIN_KERNEL_RENDER_DIM, render_dim) 89 | self.mlp_weights = torch.nn.Parameter(torch.zeros(N_LAYERS, mlp_n_hidden, mlp_n_hidden)) 90 | for i in range(N_LAYERS): 91 | torch.nn.init.xavier_uniform_(self.mlp_weights.data[i]) 92 | self.mlp_biases = torch.nn.Parameter(torch.zeros(N_LAYERS, mlp_n_hidden)) 93 | self.weight_opacity = torch.nn.Parameter( 94 | torch.rand(mlp_n_hidden) * (2 / math.sqrt(mlp_n_hidden)) - 1 / math.sqrt(mlp_n_hidden) 95 | ) # xavier init 96 | self.bias_opacity = torch.nn.Parameter(torch.zeros(1) + opacity_init_bias) 97 | self.weight_color = torch.nn.Parameter(torch.zeros(mlp_n_hidden, max(render_dim, kernel_render_dim))) 98 | torch.nn.init.xavier_uniform_(self.weight_color.data) 99 | self.bias_color = torch.nn.Parameter(torch.zeros(kernel_render_dim)) 100 | self.register_buffer("bg_color", self._process_bg_color(bg_color)) 101 | 102 | def _process_bg_color(self, bg_color: 
Union[Tuple[float, ...], float]) -> torch.Tensor: 103 | if isinstance(bg_color, float): 104 | bg_color = torch.tensor([bg_color] * self.render_dim, dtype=torch.float) 105 | elif not isinstance(bg_color, torch.Tensor): 106 | bg_color = torch.tensor(bg_color, dtype=torch.float) 107 | assert len(bg_color) == self.render_dim 108 | return bg_color 109 | 110 | def forward( 111 | self, 112 | rays: torch.Tensor, 113 | centers: torch.Tensor, 114 | rays_encoding: torch.Tensor, 115 | near: torch.Tensor, 116 | far: torch.Tensor, 117 | v: Optional[torch.Tensor] = None, # voxel grid input 118 | v_color: Optional[torch.Tensor] = None, 119 | xy: Optional[torch.Tensor] = None, # triplane input 120 | yz: Optional[torch.Tensor] = None, 121 | zx: Optional[torch.Tensor] = None, 122 | xy_color: Optional[torch.Tensor] = None, 123 | yz_color: Optional[torch.Tensor] = None, 124 | zx_color: Optional[torch.Tensor] = None, 125 | inject_noise_sigma: Optional[float] = None, 126 | bg_color: Union[Tuple[float, ...], float, None] = None, 127 | ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: 128 | inject_noise_sigma = self.inject_noise_sigma if inject_noise_sigma is None else inject_noise_sigma 129 | 130 | device = rays.device 131 | 132 | # check inputs 133 | if xy is not None: 134 | # triplane input 135 | assert v is None 136 | assert v_color is None 137 | xy_or_v = xy 138 | xy_color_or_v_color = xy_color 139 | else: 140 | # voxel grid input 141 | assert v is not None 142 | for x_ in [xy, yz, zx, xy_color, yz_color, zx_color]: 143 | assert x_ is None 144 | xy_or_v = v 145 | xy_color_or_v_color = v_color 146 | 147 | bg_color = (self.bg_color if bg_color is None else self._process_bg_color(bg_color)).to(device) 148 | 149 | ray_length_render, negative_log_transmittance, feature_render = fastplane( 150 | xy=xy_or_v, 151 | yz=yz, 152 | zx=zx, 153 | xy_color=xy_color_or_v_color, 154 | yz_color=yz_color, 155 | zx_color=zx_color, 156 | weights=self.mlp_weights, 157 | biases=self.mlp_biases, 158 | weight_opacity=self.weight_opacity, 159 | bias_opacity=self.bias_opacity, 160 | weight_color=self.weight_color, 161 | bias_color=self.bias_color, 162 | rays=rays, 163 | centers=centers, 164 | rays_encoding=rays_encoding, 165 | near=near, 166 | far=far, 167 | num_samples=self.num_samples, 168 | num_samples_inf=self.num_samples_inf, 169 | gain=self.gain, 170 | BLOCK_SIZE=self.BLOCK_SIZE, 171 | transmittance_thr=self.transmittance_thr, 172 | mask_out_of_bounds_samples=self.mask_out_of_bounds_samples, 173 | inject_noise_sigma=inject_noise_sigma, 174 | inject_noise_seed=self.inject_noise_seed, 175 | contract_coords=self.contract_coords, 176 | contract_perc_foreground=self.contract_perc_foreground, 177 | disparity_at_inf=self.disparity_at_inf, 178 | shape_representation=self.shape_representation, 179 | activation_fun=self.activation_fun, 180 | ) 181 | 182 | mask = 1 - torch.exp(-negative_log_transmittance) 183 | 184 | # apply the bg color 185 | feature_render = feature_render[..., : self.render_dim] + (1 - mask[..., None]) * bg_color 186 | 187 | return feature_render, mask, ray_length_render 188 | -------------------------------------------------------------------------------- /viewdiff/scripts/misc/process_nerfstudio_to_sdfstudio.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 
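# This script reads a nerfstudio-style folder (a `transforms.json` plus `images/` and `masks/`
# sub-folders), normalizes the camera poses into a unit cube for indoor/object scenes, optionally
# center-crops and resizes the images for monocular priors, and writes per-frame rgb/mask files
# together with a `meta_data.json` in the sdfstudio convention.
# A hypothetical invocation (the paths are placeholders, not taken from this repository):
#   python viewdiff/scripts/misc/process_nerfstudio_to_sdfstudio.py \
#       --data <nerfstudio-export-dir> --output-dir <sdfstudio-export-dir> --scene-type object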
2 | 3 | import argparse 4 | import json 5 | import os 6 | import cv2 7 | import numpy as np 8 | import PIL 9 | from pathlib import Path 10 | from PIL import Image 11 | from torchvision import transforms 12 | 13 | 14 | def main(args): 15 | """ 16 | Given data that follows the nerfstudio format such as the output from colmap or polycam, 17 | convert to a format that sdfstudio will ingest 18 | """ 19 | output_dir = args.output_dir 20 | input_dir = args.input_dir 21 | os.makedirs(output_dir, exist_ok=True) 22 | 23 | with open(os.path.join(input_dir, "transforms.json"), "r") as f: 24 | cam_params = json.load(f) 25 | 26 | # === load camera intrinsics and poses === 27 | cam_intrinsics = [] 28 | frames = cam_params["frames"] 29 | poses = [] 30 | image_paths = [] 31 | mask_paths = [] 32 | # only load images with corresponding pose info 33 | # currently in random order??, probably need to sort 34 | for frame in frames: 35 | # load intrinsics 36 | cam_intrinsics.append(np.array([ 37 | [frame["fl_x"], 0, frame["cx"]], 38 | [0, frame["fl_y"], frame["cy"]], 39 | [0, 0, 1]])) 40 | 41 | # load poses 42 | # OpenGL/Blender convention, needs to change to COLMAP/OpenCV convention 43 | # https://docs.nerf.studio/en/latest/quickstart/data_conventions.html 44 | # IGNORED for now 45 | c2w = np.array(frame["transform_matrix"]).reshape(4, 4) 46 | c2w[0:3, 1:3] *= -1 47 | poses.append(c2w) 48 | 49 | # load images 50 | file_path = Path(frame["file_path"]) 51 | img_path = os.path.join(input_dir, "images", file_path.name) 52 | assert os.path.exists(img_path) 53 | image_paths.append(img_path) 54 | 55 | # load masks 56 | mask_path = os.path.join(input_dir, "masks", f"{file_path.stem}.png") 57 | assert os.path.exists(mask_path) 58 | mask_paths.append(mask_path) 59 | 60 | # Check correctness 61 | assert len(poses) == len(image_paths) 62 | assert len(mask_paths) == len(image_paths) 63 | assert len(poses) == len(cam_intrinsics) 64 | 65 | # Filter invalid poses 66 | poses = np.array(poses) 67 | valid_poses = np.isfinite(poses).all(axis=2).all(axis=1) 68 | min_vertices = poses[:, :3, 3][valid_poses].min(axis=0) 69 | max_vertices = poses[:, :3, 3][valid_poses].max(axis=0) 70 | 71 | # === Normalize the scene === 72 | if args.scene_type in ["indoor", "object"]: 73 | # Enlarge bbox by 1.05 for object scene and by 5.0 for indoor scene 74 | # TODO: Adaptively estimate `scene_scale_mult` based on depth-map or point-cloud prior 75 | if not args.scene_scale_mult: 76 | args.scene_scale_mult = 1.05 if args.scene_type == "object" else 5.0 77 | scene_scale = 2.0 / (np.max(max_vertices - min_vertices) * args.scene_scale_mult) 78 | scene_center = (min_vertices + max_vertices) / 2.0 79 | # normalize pose to unit cube 80 | poses[:, :3, 3] -= scene_center 81 | poses[:, :3, 3] *= scene_scale 82 | # calculate scale matrix 83 | scale_mat = np.eye(4).astype(np.float32) 84 | scale_mat[:3, 3] -= scene_center 85 | scale_mat[:3] *= scene_scale 86 | scale_mat = np.linalg.inv(scale_mat) 87 | else: 88 | scene_scale = 1.0 89 | scale_mat = np.eye(4).astype(np.float32) 90 | 91 | # === Construct the scene box === 92 | if args.scene_type == "indoor": 93 | scene_box = { 94 | "aabb": [[-1, -1, -1], [1, 1, 1]], 95 | "near": 0.05, 96 | "far": 2.5, 97 | "radius": 1.0, 98 | "collider_type": "box", 99 | } 100 | elif args.scene_type == "object": 101 | scene_box = { 102 | "aabb": [[-1, -1, -1], [1, 1, 1]], 103 | "near": 0.6, # 0.05 104 | "far": 2.0, 105 | "radius": 1.0, 106 | "collider_type": "near_far", 107 | } 108 | elif args.scene_type == "unbound": 109 | # TODO: 
case-by-case near far based on depth prior 110 | # such as colmap sparse points or sensor depths 111 | scene_box = { 112 | "aabb": [min_vertices.tolist(), max_vertices.tolist()], 113 | "near": 0.05, 114 | "far": 2.5 * np.max(max_vertices - min_vertices), 115 | "radius": np.min(max_vertices - min_vertices) / 2.0, 116 | "collider_type": "box", 117 | } 118 | else: 119 | raise NotImplementedError("unknown scene_type", args.scene_type) 120 | 121 | # === Resize the images and intrinsics === 122 | # Only resize the images when we want to use mono prior 123 | with open(image_paths[0], "rb") as f: 124 | array = np.asarray(bytearray(f.read()), dtype=np.uint8) 125 | sample_img = cv2.imdecode(array, cv2.IMREAD_UNCHANGED) 126 | h, w, _ = sample_img.shape 127 | if args.mono_prior: 128 | # get smallest side to generate square crop 129 | target_crop = min(h, w) 130 | tar_h = tar_w = 384 * args.crop_mult 131 | rgb_trans = transforms.Compose( 132 | [ 133 | transforms.CenterCrop(target_crop), 134 | transforms.Resize((tar_h, tar_w), interpolation=PIL.Image.BILINEAR) 135 | ] 136 | ) 137 | mask_trans = transforms.Compose( 138 | [ 139 | transforms.CenterCrop(target_crop), 140 | transforms.Resize((tar_h, tar_w), interpolation=PIL.Image.NEAREST) 141 | ] 142 | ) 143 | 144 | # Update camera intrinsics 145 | offset_x = (w - target_crop) * 0.5 146 | offset_y = (h - target_crop) * 0.5 147 | resize_factor = tar_h / target_crop 148 | for intrinsics in cam_intrinsics: 149 | # center crop by min_dim 150 | intrinsics[0, 2] -= offset_x 151 | intrinsics[1, 2] -= offset_y 152 | # resize from min_dim x min_dim -> to 384 x 384 153 | intrinsics[:2, :] *= resize_factor 154 | 155 | # Do nothing if we don't want to use mono prior 156 | else: 157 | tar_h, tar_w = h, w 158 | rgb_trans = transforms.Compose([]) 159 | mask_trans = transforms.Compose([]) 160 | 161 | # === Construct the frames in the meta_data.json === 162 | frames = [] 163 | out_index = 0 164 | for idx, (valid, pose, image_path) in enumerate(zip(valid_poses, poses, image_paths)): 165 | if not valid: 166 | continue 167 | 168 | # save rgb image 169 | out_img_name = f"{out_index:06d}_rgb.png" 170 | out_img_path = os.path.join(output_dir, out_img_name) 171 | with open(image_path, "rb") as f: 172 | img = Image.open(f) 173 | img_tensor = rgb_trans(img) 174 | with open(out_img_path, "wb") as f2: 175 | img_tensor.save(f2) 176 | 177 | frame = { 178 | "rgb_path": out_img_name, 179 | "camtoworld": pose.tolist(), 180 | "intrinsics": cam_intrinsics[idx].tolist() 181 | } 182 | 183 | # load mask 184 | mask_path = mask_paths[idx] 185 | out_mask_name = f"{out_index:06d}_foreground_mask.png" 186 | out_mask_path = os.path.join(output_dir, out_mask_name) 187 | with open(mask_path, "rb") as f: 188 | mask_PIL = Image.open(f) 189 | new_mask = mask_trans(mask_PIL) 190 | with open(out_mask_path, "wb") as f2: 191 | new_mask.save(f2) 192 | frame["foreground_mask"] = out_mask_name 193 | 194 | if args.mono_prior: 195 | frame["mono_depth_path"] = out_img_name.replace("_rgb.png", "_depth.npy") 196 | frame["mono_normal_path"] = out_img_name.replace("_rgb.png", "_normal.npy") 197 | 198 | frames.append(frame) 199 | out_index += 1 200 | 201 | # === Construct and export the metadata === 202 | meta_data = { 203 | "camera_model": "OPENCV", 204 | "height": tar_h, 205 | "width": tar_w, 206 | "has_mono_prior": args.mono_prior, 207 | "has_sensor_depth": False, 208 | "has_foreground_mask": True, 209 | "pairs": None, 210 | "worldtogt": scale_mat.tolist(), 211 | "scene_box": scene_box, 212 | "frames": frames, 213 | 
} 214 | with open(os.path.join(output_dir, "meta_data.json"), "w") as f: 215 | json.dump(meta_data, f, indent=4) 216 | 217 | # === Generate mono priors using omnidata === 218 | if args.mono_prior: 219 | assert os.path.exists(args.pretrained_models), "Pretrained model path not found" 220 | assert os.path.exists(args.omnidata_path), "omnidata path not found" 221 | # generate mono depth and normal 222 | print("Generating mono depth...") 223 | os.system( 224 | f"python extract_monocular_cues.py \ 225 | --omnidata_path {args.omnidata_path} \ 226 | --pretrained_model {args.pretrained_models} \ 227 | --img_path {output_dir} --output_path {output_dir} \ 228 | --task depth" 229 | ) 230 | print("Generating mono normal...") 231 | os.system( 232 | f"python extract_monocular_cues.py \ 233 | --omnidata_path {args.omnidata_path} \ 234 | --pretrained_model {args.pretrained_models} \ 235 | --img_path {output_dir} --output_path {output_dir} \ 236 | --task normal" 237 | ) 238 | 239 | 240 | if __name__ == "__main__": 241 | parser = argparse.ArgumentParser(description="preprocess nerfstudio dataset to sdfstudio dataset, " 242 | "currently supports colmap and polycam") 243 | 244 | parser.add_argument("--data", dest="input_dir", required=True, help="path to nerfstudio data directory") 245 | parser.add_argument("--output-dir", dest="output_dir", required=True, help="path to output data directory") 246 | parser.add_argument("--scene-type", dest="scene_type", required=True, choices=["indoor", "object", "unbound"], 247 | help="The scene will be normalized into a unit sphere when selecting indoor or object.") 248 | parser.add_argument("--scene-scale-mult", dest="scene_scale_mult", type=float, default=None, 249 | help="The bounding box of the scene is first calculated from the camera positions, " 250 | "then multiplied by scene_scale_mult") 251 | parser.add_argument("--mono-prior", dest="mono_prior", action="store_true", 252 | help="Whether to generate mono-prior depths and normals. " 253 | "If enabled, the images will be cropped to 384*384") 254 | parser.add_argument("--crop-mult", dest="crop_mult", type=int, default=1, 255 | help="image size will be resized to crop_mult*384; only takes effect when enabling mono-prior") 256 | parser.add_argument("--omnidata-path", dest="omnidata_path", 257 | default="/omnidata/omnidata_tools/torch", 258 | help="path to omnidata model") 259 | parser.add_argument("--pretrained-models", dest="pretrained_models", 260 | default="/omnidata_tools/torch/pretrained_models/", 261 | help="path to pretrained models") 262 | 263 | args = parser.parse_args() 264 | 265 | main(args) 266 | -------------------------------------------------------------------------------- /viewdiff/scripts/misc/export_nerf_transforms.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
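# This script collects the `cams_XXXX` files and `pred_file_step_XXXX_...` images that the test
# scripts write into `input_path`, converts the stored OPENCV world2cam poses into OpenGL
# cam2world matrices, and exports one folder per saved step with `images/`, `masks/`,
# `masked_images/` (foreground segmentation via carvekit) and a nerfstudio-style `transforms.json`.
# It can optionally combine all steps, render smooth videos/gifs, and convert each export into the
# sdfstudio format via `process_nerfstudio_to_sdfstudio`.
# A hypothetical invocation from the repository root (the render directory is a placeholder;
# `tyro` should expose the arguments as kebab-case flags):
#   python -m viewdiff.scripts.misc.export_nerf_transforms --input-path outputs/smooth-autoregressive-theta-60/<render-dir> --combine-all --create-smooth-video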
2 | 3 | from typing import Literal 4 | import os 5 | import json 6 | import copy 7 | import shutil 8 | from argparse import Namespace 9 | from tqdm.auto import tqdm 10 | from PIL import Image 11 | import numpy as np 12 | import torch 13 | import tyro 14 | from ...data.create_video_from_image_folder import main as create_video_from_image_folder 15 | from .create_masked_images import load_carvekit_bkgd_removal_model, segment_and_save_carvekit 16 | from .process_nerfstudio_to_sdfstudio import main as process_nerfstudio_to_sdfstudio 17 | 18 | import imageio 19 | import shutil 20 | 21 | 22 | def get_transforms_header(): 23 | return { 24 | "camera_model": "OPENCV", 25 | "aabb_scale": 1.0, 26 | "frames": [] 27 | } 28 | 29 | 30 | def save_smooth_video( 31 | image_folder: str, 32 | n_images_per_batch: int = 10, 33 | framerate: int = 15, 34 | skip_first_n_steps: int = 0, 35 | sort_type: Literal["alternating", "interleaving"] = "interleaving", 36 | ): 37 | images = os.listdir(image_folder) 38 | images = [x for x in images if "pred_file_" in x and "cond_" not in x] 39 | 40 | def sort_images(image_list: str): 41 | # get step,frame for each image 42 | def get_attr(x): 43 | parts = x.split("_") 44 | step = int(parts[3]) 45 | frame = int(parts[-1].split(".")[0]) 46 | return step, frame 47 | 48 | attrs = [get_attr(x) for x in image_list] 49 | 50 | # split into steps 51 | images_by_step = {} 52 | for (step, frame), img in zip(attrs, image_list): 53 | if step not in images_by_step: 54 | images_by_step[step] = [] 55 | images_by_step[step].append((frame, img)) 56 | 57 | # sort each step img_list 58 | for step, img_list in images_by_step.items(): 59 | images_by_step[step] = sorted(img_list, key=lambda x: x[0]) 60 | 61 | # combine 62 | final_list = [] 63 | n_steps = len(images_by_step.keys()) 64 | if sort_type == "interleaving": 65 | # sorting: from each step the _0000.png then _0001.png, ... 66 | for frame_idx in range(n_images_per_batch): 67 | for step_idx in range(skip_first_n_steps, n_steps): 68 | final_list.append(images_by_step[step_idx][frame_idx][1]) 69 | elif sort_type == "alternating": 70 | # sorting: first all odd steps in ascending order: _0000.png, _0001.png, ... 71 | # second all even steps in descending order: _0009.png, _0008.png, ... 
72 | odd_step_list = [] 73 | even_step_list = [] 74 | for step_idx in range(skip_first_n_steps, n_steps): 75 | is_even = (step_idx % 2) == 0 76 | l = even_step_list if is_even else odd_step_list 77 | for frame_idx in range(n_images_per_batch): 78 | l.append(images_by_step[step_idx][frame_idx][1]) 79 | final_list = [*odd_step_list, *even_step_list[::-1]] 80 | 81 | return final_list 82 | 83 | images = sort_images(images) 84 | images = [os.path.join(image_folder, x) for x in images] 85 | 86 | # copy images to tmp folder 87 | temp_dir = os.path.join(image_folder, f"tmp") 88 | file_paths = [] 89 | if not os.path.exists(temp_dir): 90 | os.makedirs(temp_dir) 91 | for i, x in enumerate(images): 92 | file_name = f"{i:04d}_{os.path.basename(x)}_{i:04d}.png" 93 | file_path = os.path.join(str(temp_dir), file_name) 94 | file_paths.append(file_path) 95 | shutil.copy(x, file_path) 96 | 97 | # create video from tmp folder with images now in the correct order in that folder 98 | video_out_name = "smooth_render.mp4" 99 | temp_video_out_path = os.path.join(temp_dir, video_out_name) 100 | 101 | video_args = Namespace( 102 | **{ 103 | "image_folder": temp_dir, 104 | "file_name_pattern_glob": "*.png", 105 | "output_path": temp_video_out_path, 106 | "framerate": framerate, 107 | } 108 | ) 109 | create_video_from_image_folder(video_args) 110 | 111 | # create gif from tmp folder with images now in the correct order in that folder 112 | gif_out_name = "smooth_render.gif" 113 | temp_gif_out_path = os.path.join(temp_dir, gif_out_name) 114 | with imageio.get_writer(temp_gif_out_path, mode='I', duration=1.0 / framerate) as writer: 115 | for filename in file_paths: 116 | image = imageio.imread(filename) 117 | writer.append_data(image) 118 | 119 | # move video/gif out of tmp folder 120 | video_out_path = os.path.join(image_folder, video_out_name) 121 | gif_out_path = os.path.join(image_folder, gif_out_name) 122 | shutil.copy(temp_video_out_path, video_out_path) 123 | shutil.copy(temp_gif_out_path, gif_out_path) 124 | 125 | # remove tmp folder 126 | shutil.rmtree(temp_dir) 127 | 128 | 129 | def main( 130 | input_path: str, 131 | output_path: str = None, 132 | step: int = -1, 133 | combine_all: bool = False, 134 | skip_first_n_steps: int = 0, 135 | create_smooth_video: bool = False, 136 | smooth_video_sort_type: Literal["alternating", "interleaving"] = "interleaving", 137 | smooth_video_framerate: int = 15, 138 | n_images_per_batch: int = 10, 139 | carvekit_checkpoint_dir: str = None, 140 | ): 141 | # get all cams and images files 142 | files = os.listdir(input_path) 143 | 144 | # determine which steps should be processed 145 | if step > -1: 146 | steps = [step] 147 | else: 148 | steps = [int(f.split(".")[0].split("_")[1]) for f in files if "cams" in f] 149 | steps = sorted(steps) 150 | steps = steps[skip_first_n_steps:] 151 | 152 | # where to save 153 | if output_path is None: 154 | output_path = os.path.join(input_path, "exported_nerf_convention") 155 | 156 | # get mask segmentation model 157 | carvekit_model = load_carvekit_bkgd_removal_model(carvekit_checkpoint_dir) 158 | 159 | # separately create transforms for all steps 160 | step_output_paths = [] 161 | frame_dicts = [] 162 | masked_images_frame_dicts = [] 163 | for step in tqdm(steps, desc="Export files to NeRF convention"): 164 | # get corresponding files 165 | step_str = f"{step:04d}" 166 | cam_file = [os.path.join(input_path, f) for f in files if f"cams_{step_str}" in f] 167 | assert len(cam_file) == 1, f"found more than one possible cam_file: {cam_file}" 168 | 
cam_file = cam_file[0] 169 | image_files = [os.path.join(input_path, f) for f in files if f"pred_file_step_{step_str}" in f] 170 | 171 | # prepare transforms 172 | transforms_dict = get_transforms_header() 173 | transforms_dict_masked_images = get_transforms_header() 174 | 175 | step_output_path = os.path.join(output_path, step_str) 176 | step_output_paths.append(step_output_path) 177 | 178 | images_output_folder = "images" 179 | images_output_path = os.path.join(step_output_path, images_output_folder) 180 | os.makedirs(images_output_path, exist_ok=True) 181 | 182 | masked_images_output_folder = "masked_images" 183 | masked_images_output_path = os.path.join(step_output_path, masked_images_output_folder) 184 | os.makedirs(masked_images_output_path, exist_ok=True) 185 | 186 | masks_output_folder = "masks" 187 | masks_output_path = os.path.join(step_output_path, masks_output_folder) 188 | os.makedirs(masks_output_path, exist_ok=True) 189 | 190 | # load cams 191 | with open(cam_file, "rb") as f: 192 | cam = torch.load(f) 193 | poses = cam["poses"] 194 | intrs = cam["intrs"] 195 | 196 | # load each (cam, intr, image) tuple and add it to transforms 197 | for key in poses.keys(): 198 | frame_dict = {} 199 | 200 | # add image file and h/w 201 | image_file_path = [f for f in image_files if key in f] 202 | if len(image_file_path) != 1: 203 | continue 204 | image_file_path = image_file_path[0] 205 | image_file_name = os.path.basename(image_file_path) 206 | image_file_output_path = os.path.join(images_output_path, image_file_name) 207 | shutil.copy(image_file_path, image_file_output_path) 208 | frame_dict["file_path"] = os.path.join(images_output_folder, image_file_name) 209 | with open(image_file_path, "rb") as f: 210 | image = np.array(Image.open(f)) 211 | frame_dict["h"] = image.shape[0] 212 | frame_dict["w"] = image.shape[1] 213 | 214 | # add intr 215 | intr = intrs[key].numpy() 216 | frame_dict["fl_x"] = float(intr[0, 0]) 217 | frame_dict["fl_y"] = float(intr[1, 1]) 218 | frame_dict["cx"] = float(intr[0, 2]) 219 | frame_dict["cy"] = float(intr[1, 2]) 220 | 221 | # convert pose from OPENCV world2cam to OPEN_GL cam2world (it's what nerfstudio expects: https://docs.nerf.studio/quickstart/data_conventions.html) 222 | pose = poses[key].numpy() 223 | 224 | # first, convert from world2cam to cam2world 225 | R = pose[:3, :3] 226 | T = pose[:3, 3:4] 227 | Rinv = R.T 228 | Tinv = -Rinv @ T 229 | pose_cam2world = np.concatenate([Rinv, Tinv], axis=1) 230 | pose_cam2world = np.concatenate([pose_cam2world, pose[3:4]], axis=0) # add hom 231 | 232 | # second, invert y/z coordinate 233 | pose_cam2world[:3, 1:3] *= -1 234 | 235 | frame_dict["transform_matrix"] = pose_cam2world.tolist() 236 | 237 | # save mask files 238 | masked_image_path = os.path.join(masked_images_output_path, image_file_name) 239 | mask_path = os.path.join(masks_output_path, image_file_name) 240 | if not os.path.exists(mask_path) or not os.path.exists(masked_image_path): 241 | segment_and_save_carvekit(image_file_output_path, mask_path, masked_image_path, mask_predictor=carvekit_model) 242 | frame_dict["mask_path"] = os.path.join(masks_output_folder, image_file_name) 243 | 244 | # add frame to unmasked transforms 245 | transforms_dict["frames"].append(frame_dict) 246 | frame_dicts.append((frame_dict, step_output_path)) 247 | 248 | # replace image with masked image 249 | masked_image_frame_dict = copy.deepcopy(frame_dict) 250 | masked_image_frame_dict["file_path"] = os.path.join(masked_images_output_folder, image_file_name) 251 | 252 | # add 
frame to masked transforms 253 | transforms_dict_masked_images["frames"].append(masked_image_frame_dict) 254 | masked_images_frame_dicts.append((masked_image_frame_dict, step_output_path)) 255 | 256 | # save transforms to file 257 | with open(os.path.join(step_output_path, "transforms.json"), "w") as f: 258 | json.dump(transforms_dict, f, indent=4) 259 | 260 | # save transforms to file 261 | with open(os.path.join(step_output_path, "transforms_masked_images.json"), "w") as f: 262 | json.dump(transforms_dict_masked_images, f, indent=4) 263 | 264 | # create a combined transforms from all steps 265 | if combine_all: 266 | step_output_path = os.path.join(output_path, "combined_all") 267 | step_output_paths.append(step_output_path) 268 | 269 | # prepare transforms 270 | transforms_dict = get_transforms_header() 271 | transforms_dict_masked_images = get_transforms_header() 272 | 273 | images_output_path = os.path.join(step_output_path, "images") 274 | os.makedirs(images_output_path, exist_ok=True) 275 | 276 | masked_images_output_path = os.path.join(step_output_path, "masked_images") 277 | os.makedirs(masked_images_output_path, exist_ok=True) 278 | 279 | masks_output_path = os.path.join(step_output_path, "masks") 280 | os.makedirs(masks_output_path, exist_ok=True) 281 | 282 | # copy together all files from all steps 283 | for (frame_dict, frame_dict_output_path), (masked_image_frame_dict, _) in zip(frame_dicts, masked_images_frame_dicts): 284 | image_file_name = os.path.basename(frame_dict["file_path"]) 285 | if "cond" in image_file_name: 286 | continue 287 | 288 | # copy rgb image 289 | shutil.copy( 290 | os.path.join(frame_dict_output_path, frame_dict["file_path"]), 291 | os.path.join(images_output_path, image_file_name) 292 | ) 293 | 294 | # copy mask image 295 | shutil.copy( 296 | os.path.join(frame_dict_output_path, frame_dict["mask_path"]), 297 | os.path.join(masks_output_path, image_file_name) 298 | ) 299 | 300 | # copy masked rgb image 301 | shutil.copy( 302 | os.path.join(frame_dict_output_path, masked_image_frame_dict["file_path"]), 303 | os.path.join(masked_images_output_path, image_file_name) 304 | ) 305 | 306 | # add frame 307 | transforms_dict["frames"].append(frame_dict) 308 | transforms_dict_masked_images["frames"].append(masked_image_frame_dict) 309 | 310 | # save transforms to file 311 | with open(os.path.join(step_output_path, "transforms.json"), "w") as f: 312 | json.dump(transforms_dict, f, indent=4) 313 | 314 | with open(os.path.join(step_output_path, "transforms_masked_images.json"), "w") as f: 315 | json.dump(transforms_dict_masked_images, f, indent=4) 316 | 317 | # create smooth videos 318 | if create_smooth_video: 319 | save_smooth_video( 320 | image_folder=masked_images_output_path, 321 | n_images_per_batch=n_images_per_batch, 322 | framerate=smooth_video_framerate, 323 | skip_first_n_steps=skip_first_n_steps, 324 | sort_type=smooth_video_sort_type 325 | ) 326 | save_smooth_video( 327 | image_folder=images_output_path, 328 | n_images_per_batch=n_images_per_batch, 329 | framerate=smooth_video_framerate, 330 | skip_first_n_steps=skip_first_n_steps, 331 | sort_type=smooth_video_sort_type 332 | ) 333 | 334 | # convert to sdfstudio format (without mono-cues) 335 | for step_output_path in tqdm(step_output_paths, desc="Convert to sdfstudio format"): 336 | args = { 337 | "input_dir": step_output_path, 338 | "output_dir": os.path.join(step_output_path, "sdfstudio-format"), 339 | "scene_type": "object", 340 | "scene_scale_mult": None, 341 | "mono_prior": False, 342 | 
"crop_mult": 1, 343 | "omnidata_path": None, 344 | "pretrained_models": None, 345 | } 346 | args = Namespace(**args) 347 | process_nerfstudio_to_sdfstudio(args) 348 | 349 | print("Done! The exported nerf data has been saved in:", output_path) 350 | 351 | 352 | if __name__ == '__main__': 353 | tyro.cli(main) 354 | -------------------------------------------------------------------------------- /viewdiff/model/custom_transformer_2d.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # This file is partially based on the diffusers library, which licensed the code under the following license: 3 | 4 | # Copyright 2023 The HuggingFace Team. All rights reserved. 5 | # 6 | # Licensed under the Apache License, Version 2.0 (the "License"); 7 | # you may not use this file except in compliance with the License. 8 | # You may obtain a copy of the License at 9 | # 10 | # http://www.apache.org/licenses/LICENSE-2.0 11 | # 12 | # Unless required by applicable law or agreed to in writing, software 13 | # distributed under the License is distributed on an "AS IS" BASIS, 14 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 15 | # See the License for the specific language governing permissions and 16 | # limitations under the License. 17 | from typing import Any, Dict, Optional, Literal 18 | 19 | import torch 20 | from torch import nn 21 | 22 | from diffusers.configuration_utils import ConfigMixin, register_to_config 23 | from diffusers.utils import deprecate 24 | from diffusers.models.modeling_utils import ModelMixin 25 | from diffusers.models.transformer_2d import Transformer2DModelOutput 26 | 27 | from .custom_attention import BasicTransformerWithCrossFrameAttentionBlock 28 | 29 | from .projection.layer import UnprojReprojLayer 30 | 31 | 32 | class Transformer2DSelfAttnCrossAttnCrossFrameAttnModel(ModelMixin, ConfigMixin): 33 | """ 34 | Transformer model with self attention, cross attention, and cross-frame attention for image-like data. Takes continuous (actual 35 | embeddings) inputs. 36 | The final layer is initialized to be zero as suggested in ControlNet paper (https://arxiv.org/pdf/2302.05543.pdf). 37 | 38 | First, project the input (aka embedding) and reshape to b, t, d. Then apply standard 39 | transformer action. Finally, reshape to image. 40 | 41 | Parameters: 42 | num_attention_heads (`int`, *optional*, defaults to 16): The number of heads to use for multi-head attention. 43 | attention_head_dim (`int`, *optional*, defaults to 88): The number of channels in each head. 44 | in_channels (`int`, *optional*): 45 | Pass if the input is continuous. The number of channels in the input and output. 46 | num_layers (`int`, *optional*, defaults to 1): The number of layers of Transformer blocks to use. 47 | dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use. 48 | cross_attention_dim (`int`, *optional*): The number of encoder_hidden_states dimensions to use. 49 | sample_size (`int`, *optional*): Pass if the input is discrete. The width of the latent images. 50 | Note that this is fixed at training time as it is used for learning a number of position embeddings. See 51 | `ImagePositionalEmbeddings`. 52 | num_vector_embeds (`int`, *optional*): 53 | Pass if the input is discrete. The number of classes of the vector embeddings of the latent pixels. 54 | Includes the class for the masked latent pixel. 
55 | activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward. 56 | num_embeds_ada_norm ( `int`, *optional*): Pass if at least one of the norm_layers is `AdaLayerNorm`. 57 | The number of diffusion steps used during training. Note that this is fixed at training time as it is used 58 | to learn a number of embeddings that are added to the hidden states. During inference, you can denoise for 59 | up to but not more than steps than `num_embeds_ada_norm`. 60 | attention_bias (`bool`, *optional*): 61 | Configure if the TransformerBlocks' attention should contain a bias parameter. 62 | """ 63 | 64 | @register_to_config 65 | def __init__( 66 | self, 67 | num_attention_heads: int = 16, 68 | attention_head_dim: int = 88, 69 | in_channels: Optional[int] = None, 70 | out_channels: Optional[int] = None, 71 | num_layers: int = 1, 72 | cross_attention_dim: Optional[int] = None, 73 | dropout: float = 0.0, 74 | norm_num_groups: int = 32, 75 | attention_bias: bool = False, 76 | activation_fn: str = "geglu", 77 | num_embeds_ada_norm: Optional[int] = None, 78 | use_linear_projection: bool = False, 79 | only_cross_attention: bool = False, 80 | upcast_attention: bool = False, 81 | norm_type: str = "layer_norm", 82 | norm_elementwise_affine: bool = True, 83 | # new arguments 84 | n_input_images: int = 5, 85 | to_k_other_frames: int = 4, 86 | random_others: bool = False, 87 | last_layer_mode: Literal["none", "zero-conv", "alpha"] = "none", 88 | use_lora_in_cfa: bool = False, 89 | use_temb_in_lora: bool = False, 90 | temb_size: int = 1280, 91 | temb_out_size: int = 10, 92 | pose_cond_dim=10, 93 | rank=4, 94 | network_alpha=None, 95 | use_cfa: bool = True, 96 | use_unproj_reproj: bool = False, 97 | num_3d_layers: int = 1, 98 | dim_3d_latent: int = 32, 99 | dim_3d_grid: int = 64, 100 | n_novel_images: int = 1, 101 | vol_rend_proj_in_mode: Literal["single", "multiple", "unet"] = "unet", 102 | vol_rend_proj_out_mode: Literal["single", "multiple"] = "multiple", 103 | vol_rend_aggregator_mode: Literal["mean", "ibrnet"] = "ibrnet", 104 | vol_rend_model_background: bool = False, 105 | vol_rend_background_grid_percentage: float = 0.5, 106 | vol_rend_disparity_at_inf: float = 1e-3, 107 | ): 108 | super().__init__() 109 | self.use_linear_projection = use_linear_projection 110 | self.num_attention_heads = num_attention_heads 111 | self.attention_head_dim = attention_head_dim 112 | inner_dim = num_attention_heads * attention_head_dim 113 | 114 | if norm_type == "layer_norm" and num_embeds_ada_norm is not None: 115 | deprecation_message = ( 116 | f"The configuration file of this model: {self.__class__} is outdated. `norm_type` is either not set or" 117 | " incorrectly set to `'layer_norm'`.Make sure to set `norm_type` to `'ada_norm'` in the config." 118 | " Please make sure to update the config accordingly as leaving `norm_type` might led to incorrect" 119 | " results in future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it" 120 | " would be very nice if you could open a Pull request for the `transformer/config.json` file" 121 | ) 122 | deprecate("norm_type!=num_embeds_ada_norm", "1.0.0", deprecation_message, standard_warn=False) 123 | norm_type = "ada_norm" 124 | 125 | # 2. 
Define input layers 126 | self.in_channels = in_channels 127 | 128 | self.norm = torch.nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=1e-6, affine=True) 129 | if use_linear_projection: 130 | self.proj_in = nn.Linear(in_channels, inner_dim) 131 | else: 132 | self.proj_in = nn.Conv2d(in_channels, inner_dim, kernel_size=1, stride=1, padding=0) 133 | 134 | # 3. Define transformers blocks 135 | self.transformer_blocks = nn.ModuleList( 136 | [ 137 | BasicTransformerWithCrossFrameAttentionBlock( 138 | inner_dim, 139 | num_attention_heads, 140 | attention_head_dim, 141 | dropout=dropout, 142 | activation_fn=activation_fn, 143 | num_embeds_ada_norm=num_embeds_ada_norm, 144 | cross_attention_dim=cross_attention_dim, 145 | attention_bias=attention_bias, 146 | upcast_attention=upcast_attention, 147 | norm_type=norm_type, 148 | norm_elementwise_affine=norm_elementwise_affine, 149 | only_cross_attention=only_cross_attention, 150 | n_input_images=n_input_images, 151 | to_k_other_frames=to_k_other_frames, 152 | random_others=random_others, 153 | last_layer_mode=last_layer_mode, 154 | use_lora_in_cfa=use_lora_in_cfa, 155 | use_temb_in_lora=use_temb_in_lora, 156 | temb_size=temb_size, 157 | temb_out_size=temb_out_size, 158 | pose_cond_dim=pose_cond_dim, 159 | rank=rank, 160 | network_alpha=network_alpha, 161 | use_cfa=use_cfa, 162 | use_unproj_reproj=use_unproj_reproj, 163 | num_3d_layers=num_3d_layers, 164 | dim_3d_latent=dim_3d_latent, 165 | dim_3d_grid=dim_3d_grid, 166 | n_novel_images=n_novel_images, 167 | vol_rend_proj_in_mode=vol_rend_proj_in_mode, 168 | vol_rend_proj_out_mode=vol_rend_proj_out_mode, 169 | vol_rend_aggregator_mode=vol_rend_aggregator_mode, 170 | vol_rend_model_background=vol_rend_model_background, 171 | vol_rend_background_grid_percentage=vol_rend_background_grid_percentage, 172 | vol_rend_disparity_at_inf=vol_rend_disparity_at_inf, 173 | ) 174 | for d in range(num_layers) 175 | ] 176 | ) 177 | 178 | # 4. 
Define output layers 179 | self.out_channels = in_channels if out_channels is None else out_channels 180 | # TODO: should use out_channels for continuous projections 181 | if use_linear_projection: 182 | self.proj_out = nn.Linear(inner_dim, in_channels) 183 | else: 184 | self.proj_out = nn.Conv2d(inner_dim, in_channels, kernel_size=1, stride=1, padding=0) 185 | 186 | def get_cross_frame_parameters(self, vol_rend_mode: Literal["with", "without", "only"] = "with"): 187 | params = [] 188 | for b in self.transformer_blocks: 189 | if isinstance(b, BasicTransformerWithCrossFrameAttentionBlock): 190 | params.extend(b.get_cross_frame_parameters(vol_rend_mode=vol_rend_mode)) 191 | return params 192 | 193 | def get_last_layer_params(self): 194 | params = [] 195 | for b in self.transformer_blocks: 196 | if isinstance(b, BasicTransformerWithCrossFrameAttentionBlock): 197 | params.extend(b.get_last_layer_params()) 198 | return params 199 | 200 | def get_other_parameters(self): 201 | params = [ 202 | *list(self.norm.parameters()), 203 | *list(self.proj_in.parameters()), 204 | *list(self.proj_out.parameters()), 205 | ] 206 | 207 | for b in self.transformer_blocks: 208 | if isinstance(b, BasicTransformerWithCrossFrameAttentionBlock): 209 | params.extend(b.get_other_parameters()) 210 | else: 211 | params.extend(list(b.parameters())) 212 | return params 213 | 214 | def forward( 215 | self, 216 | hidden_states: torch.Tensor, 217 | encoder_hidden_states: Optional[torch.Tensor] = None, 218 | timestep: Optional[torch.LongTensor] = None, 219 | class_labels: Optional[torch.LongTensor] = None, 220 | cross_attention_kwargs: Dict[str, Any] = None, 221 | attention_mask: Optional[torch.Tensor] = None, 222 | encoder_attention_mask: Optional[torch.Tensor] = None, 223 | unproj_reproj_kwargs: Dict[str, Any] = None, 224 | return_dict: bool = True, 225 | ): 226 | """ 227 | Args: 228 | hidden_states ( When discrete, `torch.LongTensor` of shape `(batch size, num latent pixels)`. 229 | When continuous, `torch.FloatTensor` of shape `(batch size, channel, height, width)`): Input 230 | hidden_states 231 | encoder_hidden_states ( `torch.FloatTensor` of shape `(batch size, sequence len, embed dims)`, *optional*): 232 | Conditional embeddings for cross attention layer. If not given, cross-attention defaults to 233 | self-attention. 234 | timestep ( `torch.LongTensor`, *optional*): 235 | Optional timestep to be applied as an embedding in AdaLayerNorm's. Used to indicate denoising step. 236 | class_labels ( `torch.LongTensor` of shape `(batch size, num classes)`, *optional*): 237 | Optional class labels to be applied as an embedding in AdaLayerZeroNorm. Used to indicate class labels 238 | conditioning. 239 | encoder_attention_mask ( `torch.Tensor`, *optional* ). 240 | Cross-attention mask, applied to encoder_hidden_states. Two formats supported: 241 | Mask `(batch, sequence_length)` True = keep, False = discard. Bias `(batch, 1, sequence_length)` 0 242 | = keep, -10000 = discard. 243 | If ndim == 2: will be interpreted as a mask, then converted into a bias consistent with the format 244 | above. This bias will be added to the cross-attention scores. 245 | return_dict (`bool`, *optional*, defaults to `True`): 246 | Whether or not to return a [`models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain tuple. 247 | 248 | Returns: 249 | [`~models.transformer_2d.Transformer2DModelOutput`] or `tuple`: 250 | [`~models.transformer_2d.Transformer2DModelOutput`] if `return_dict` is True, otherwise a `tuple`. 
When 251 | returning a tuple, the first element is the sample tensor. 252 | """ 253 | # ensure attention_mask is a bias, and give it a singleton query_tokens dimension. 254 | # we may have done this conversion already, e.g. if we came here via UNet2DConditionModel#forward. 255 | # we can tell by counting dims; if ndim == 2: it's a mask rather than a bias. 256 | # expects mask of shape: 257 | # [batch, key_tokens] 258 | # adds singleton query_tokens dimension: 259 | # [batch, 1, key_tokens] 260 | # this helps to broadcast it as a bias over attention scores, which will be in one of the following shapes: 261 | # [batch, heads, query_tokens, key_tokens] (e.g. torch sdp attn) 262 | # [batch * heads, query_tokens, key_tokens] (e.g. xformers or classic attn) 263 | if attention_mask is not None and attention_mask.ndim == 2: 264 | # assume that mask is expressed as: 265 | # (1 = keep, 0 = discard) 266 | # convert mask into a bias that can be added to attention scores: 267 | # (keep = +0, discard = -10000.0) 268 | attention_mask = (1 - attention_mask.to(hidden_states.dtype)) * -10000.0 269 | attention_mask = attention_mask.unsqueeze(1) 270 | 271 | # convert encoder_attention_mask to a bias the same way we do for attention_mask 272 | if encoder_attention_mask is not None and encoder_attention_mask.ndim == 2: 273 | encoder_attention_mask = (1 - encoder_attention_mask.to(hidden_states.dtype)) * -10000.0 274 | encoder_attention_mask = encoder_attention_mask.unsqueeze(1) 275 | 276 | # 1. Input 277 | batch, _, height, width = hidden_states.shape 278 | residual = hidden_states 279 | 280 | hidden_states = self.norm(hidden_states) 281 | if not self.use_linear_projection: 282 | hidden_states = self.proj_in(hidden_states) 283 | inner_dim = hidden_states.shape[1] 284 | hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * width, inner_dim) 285 | else: 286 | inner_dim = hidden_states.shape[1] 287 | hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * width, inner_dim) 288 | hidden_states = self.proj_in(hidden_states) 289 | 290 | # 2. Blocks 291 | for block in self.transformer_blocks: 292 | hidden_states = block( 293 | hidden_states, 294 | attention_mask=attention_mask, 295 | encoder_hidden_states=encoder_hidden_states, 296 | encoder_attention_mask=encoder_attention_mask, 297 | timestep=timestep, 298 | cross_attention_kwargs=cross_attention_kwargs, 299 | class_labels=class_labels, 300 | unproj_reproj_kwargs=unproj_reproj_kwargs, 301 | ) 302 | 303 | # 3. Output 304 | if not self.use_linear_projection: 305 | hidden_states = hidden_states.reshape(batch, height, width, inner_dim).permute(0, 3, 1, 2).contiguous() 306 | hidden_states = self.proj_out(hidden_states) 307 | else: 308 | hidden_states = self.proj_out(hidden_states) 309 | hidden_states = hidden_states.reshape(batch, height, width, inner_dim).permute(0, 3, 1, 2).contiguous() 310 | 311 | output = hidden_states + residual 312 | 313 | if not return_dict: 314 | return (output,) 315 | 316 | return Transformer2DModelOutput(sample=output) 317 | -------------------------------------------------------------------------------- /viewdiff/model/projection/layer.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 
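# Rough overview of this module: `UnprojReprojLayer` (defined below) takes per-view latent
# features from the U-Net, reduces them with `proj_in_2d` (a 1x1 conv, a small conv stack, or a
# 2D UNet depending on `proj_in_mode`), unprojects them into a dense 3D voxel grid, aggregates the
# per-frame grids (mean or an IBRNet-style aggregator), optionally refines the grid with 3D ResNet
# blocks or a 3D UNet, volume-renders it back into every view with the Fastplane renderer, and
# finally maps the rendered features back to the latent channel dimension with separate
# foreground/background projections (`proj_out_2d_fg` / `proj_out_2d_bg`).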
2 | 3 | from typing import Tuple, Literal 4 | import torch 5 | from torch import nn 6 | 7 | from ...data.co3d.util import scale_intrinsics 8 | 9 | 10 | from diffusers.models.unet_2d import UNet2DModel 11 | 12 | from .util import get_pixel_grids 13 | from .voxel_proj import ( 14 | build_cost_volume, 15 | mean_aggregate_cost_volumes, 16 | get_rays_in_unit_cube, 17 | IBRNet_Aggregator, 18 | ) 19 | from ..custom_unet_3d import ResnetBlock3D, UNet3DModel 20 | from ..custom_attention_processor import collapse_batch, expand_batch 21 | 22 | from .fastplane.fastplane_module import FastplaneModule, FastplaneShapeRepresentation 23 | 24 | 25 | class ConvBlock(nn.Module): 26 | def __init__(self, channels_in: int, channels_out: int, kernel_size: int = 3, padding: int = 1, norm: bool = False): 27 | super().__init__() 28 | 29 | self.conv = nn.Conv2d(channels_in, channels_out, kernel_size=kernel_size, padding=padding) 30 | self.relu = nn.ReLU() 31 | self.norm = norm 32 | if norm: 33 | # self.norm = nn.InstanceNorm2d(channels_out, affine=True) 34 | self.norm = nn.LayerNorm([channels_out, 64, 64]) 35 | 36 | def forward(self, x): 37 | x = self.conv(x) 38 | if self.norm: 39 | x = self.norm(x) 40 | x = self.relu(x) 41 | return x 42 | 43 | 44 | class UnprojReprojLayer(nn.Module): 45 | def __init__( 46 | self, 47 | latent_channels: int = 320, 48 | num_3d_layers: int = 1, 49 | dim_3d_latent: int = 32, 50 | dim_3d_grid: int = 64, 51 | vol_rend_num_samples_per_ray: int = 128, 52 | vol_rend_near: float = 0.5, 53 | vol_rend_far: float = 4.5, 54 | vol_rend_model_background: bool = False, 55 | vol_rend_background_grid_percentage: float = 0.5, 56 | vol_rend_disparity_at_inf: float = 1e-3, 57 | n_novel_images: int = 0, 58 | proj_in_mode: Literal["single", "multiple", "unet"] = "single", 59 | proj_out_mode: Literal["single", "multiple"] = "multiple", 60 | aggregator_mode: Literal["mean", "ibrnet"] = "ibrnet", 61 | use_temb: bool = False, 62 | temb_dim: int = 1280, 63 | ): 64 | super().__init__() 65 | self.proj_in_mode = proj_in_mode 66 | self.proj_out_mode = proj_out_mode 67 | self.n_novel_images = n_novel_images 68 | self.dim_3d_grid = dim_3d_grid 69 | self.use_3d_net = num_3d_layers > 0 70 | self.use_3d_unet = num_3d_layers >= 3 71 | self.vol_rend_model_background = vol_rend_model_background 72 | self.vol_rend_background_grid_percentage = vol_rend_background_grid_percentage 73 | 74 | if self.use_3d_unet: 75 | assert ( 76 | num_3d_layers - 1 77 | ) % 2 == 0, "if num_3d_layers >=3 we want to construct a UNet. So specify an odd number of num_3d_layers, e.g. 3,5,7,9" 78 | n_blocks = (num_3d_layers - 1) // 2 79 | block_out_channels = [dim_3d_latent * i for i in range(1, n_blocks + 1)] 80 | self.blocks_3d = UNet3DModel( 81 | in_channels=dim_3d_latent, 82 | out_channels=dim_3d_latent, 83 | down_block_types=["DownBlock3D"] * n_blocks, 84 | up_block_types=["UpBlock3D"] * n_blocks, 85 | block_out_channels=block_out_channels, 86 | layers_per_block=1, 87 | norm_num_groups=min(32, dim_3d_latent), 88 | ) 89 | else: 90 | self.blocks_3d = nn.ModuleList( 91 | [ResnetBlock3D(in_channels=dim_3d_latent, groups=min(32, dim_3d_latent))] * num_3d_layers 92 | ) 93 | # renderer 94 | # turn incoming SD-latent-features into 3D features (e.g. 
normalization, feature-reduction) 95 | if proj_in_mode == "single": 96 | self.proj_in_2d = nn.Conv2d(latent_channels, dim_3d_latent, kernel_size=1, padding=0) 97 | elif proj_in_mode == "multiple": 98 | self.proj_in_2d = nn.Sequential( 99 | ConvBlock(latent_channels, latent_channels, kernel_size=3, padding=1), 100 | ConvBlock(latent_channels, dim_3d_latent, kernel_size=3, padding=1), 101 | nn.Conv2d(dim_3d_latent, dim_3d_latent, kernel_size=1, padding=0), 102 | ) 103 | elif proj_in_mode == "unet": 104 | n_blocks = 3 105 | block_out_channels = [dim_3d_latent * i for i in range(1, n_blocks + 1)] 106 | self.proj_in_2d = UNet2DModel( 107 | in_channels=latent_channels, 108 | out_channels=dim_3d_latent, 109 | down_block_types=["DownBlock2D"] * n_blocks, 110 | up_block_types=["UpBlock2D"] * n_blocks, 111 | block_out_channels=block_out_channels, 112 | layers_per_block=1, 113 | norm_num_groups=min(32, dim_3d_latent), 114 | add_attention=False, 115 | ) 116 | else: 117 | raise NotImplementedError("proj_in_mode", proj_in_mode) 118 | 119 | # turn projected 3D features into SD-latent-features (e.g. de-normalization, feature-increase) 120 | # it is necessary because vol-renderer outputs in [0, 1] because it needs sigmoid in the end to converge in general 121 | # however, latent features can be in arbitrary floating point feature value 122 | # this scale fct should learn to convert back to the arbitrary range 123 | # it should be a conv_1x1 to not turn this into a neural renderer that could destroy the 3D consistency (instead: only scale) 124 | # we have a linear and a non-linear scale fct and allow to choose them according to the --proj_out_mode flag 125 | # the background will only use a nonlinear_scale_fct if the background should actually be modeled (e.g. if --vol_rend_model_background is set) 126 | 127 | def linear_scale_fct(): 128 | return nn.Conv2d(dim_3d_latent, latent_channels, kernel_size=1, padding=0) 129 | 130 | def nonlinear_scale_fct(): 131 | return nn.Sequential( 132 | ConvBlock(dim_3d_latent, dim_3d_latent, kernel_size=1, padding=0), 133 | ConvBlock(dim_3d_latent, latent_channels, kernel_size=1, padding=0), 134 | nn.Conv2d(latent_channels, latent_channels, kernel_size=1, padding=0), 135 | ) 136 | 137 | if proj_out_mode == "single": 138 | self.proj_out_2d_fg = linear_scale_fct() 139 | self.proj_out_2d_bg = linear_scale_fct() 140 | elif proj_out_mode == "multiple": 141 | self.proj_out_2d_fg = nonlinear_scale_fct() 142 | if vol_rend_model_background: 143 | self.proj_out_2d_bg = nonlinear_scale_fct() 144 | else: 145 | self.proj_out_2d_bg = linear_scale_fct() 146 | else: 147 | raise NotImplementedError("proj_out_mode", proj_out_mode) 148 | 149 | # reduce multiple per-frame dense voxel-grids into a single dense voxel-grid 150 | self.aggregator_mode = aggregator_mode 151 | if aggregator_mode == "ibrnet": 152 | self.ibrnet_aggregator = IBRNet_Aggregator(feature_dim=dim_3d_latent, kernel_size=1, padding=0, use_temb=use_temb, temb_dim=temb_dim) 153 | 154 | # ray-direction encoder for volume_renderer 155 | self.linear_ray = nn.Linear(3, dim_3d_latent) 156 | nn.init.xavier_uniform_(self.linear_ray.weight.data) 157 | self.linear_ray.bias.data *= 0.0 158 | 159 | # volume renderer of the dense voxel-grid 160 | self.vol_rend_near = vol_rend_near 161 | self.vol_rend_far = vol_rend_far 162 | if vol_rend_model_background: 163 | self.volume_renderer = FastplaneModule( 164 | mlp_n_hidden=dim_3d_latent, 165 | render_dim=dim_3d_latent, 166 | num_samples=vol_rend_num_samples_per_ray, 167 | bg_color=1.0, 168 | 
shape_representation=FastplaneShapeRepresentation.VOXEL_GRID, 169 | mask_out_of_bounds_samples=True, 170 | num_samples_inf=vol_rend_num_samples_per_ray, 171 | contract_coords=True, 172 | contract_perc_foreground=1.0 - vol_rend_background_grid_percentage, 173 | disparity_at_inf=vol_rend_disparity_at_inf, 174 | inject_noise_sigma=0.0, 175 | ) 176 | else: 177 | self.volume_renderer = FastplaneModule( 178 | mlp_n_hidden=dim_3d_latent, 179 | render_dim=dim_3d_latent, 180 | num_samples=vol_rend_num_samples_per_ray, 181 | bg_color=1.0, 182 | shape_representation=FastplaneShapeRepresentation.VOXEL_GRID, 183 | mask_out_of_bounds_samples=True, 184 | ) 185 | 186 | def get_volume_renderer_params(self): 187 | params = [] 188 | 189 | params.extend(list(self.volume_renderer.parameters())) 190 | 191 | return params 192 | 193 | def get_other_params(self): 194 | params = [] 195 | for n, p in self.named_parameters(): 196 | if "volume_renderer" not in n: 197 | params.append(p) 198 | 199 | return params 200 | 201 | def forward( 202 | self, 203 | latents: torch.Tensor, 204 | pose: torch.Tensor, 205 | K: torch.Tensor, 206 | orig_hw: Tuple[int, int], 207 | timestep: torch.Tensor = None, 208 | temb: torch.Tensor = None, 209 | bbox: torch.Tensor = None, 210 | deactivate_view_dependent_rendering: bool = False 211 | ): 212 | """ 213 | Args: 214 | latents (torch.Tensor): (batch_size, num_images, C, h', w') 215 | pose (torch.Tensor): (batch_size, num_images, 4, 4) 216 | K (torch.Tensor): (batch_size, num_images, 3, 3) 217 | orig_hw (Tuple[int, int]): same across all batches 218 | temb (torch.Tensor, optional): (batch_size, num_images, temb_dim). Defaults to None. 219 | bbox (torch.Tensor, optional): (batch_size, 2, 3). Defaults to None. 220 | 221 | Returns: 222 | _type_: _description_ 223 | """ 224 | # downscale to latent dimension 225 | batch_size = latents.shape[0] 226 | num_images = latents.shape[1] 227 | latent_hw = latents.shape[3:] 228 | K_scaled = scale_intrinsics(K, orig_hw, latent_hw) 229 | K = K_scaled 230 | 231 | # lazy init canonical rays 232 | if not hasattr(self, "canonical_rays") or self.canonical_rays.shape[1] != (latent_hw[0] * latent_hw[1]): 233 | self.canonical_rays = get_pixel_grids(latent_hw[0], latent_hw[1]).to(latents.device).float() 234 | 235 | # get ray information in the correct world-space 236 | rays, centers, near_t, far_t, scale = get_rays_in_unit_cube( 237 | bbox, 238 | pose, 239 | K, 240 | self.canonical_rays, 241 | default_near=self.vol_rend_near, 242 | default_far=self.vol_rend_far, 243 | use_ray_aabb=not self.vol_rend_model_background, 244 | ) 245 | 246 | # reduce feature dim 247 | n_known_images = latents.shape[1] - self.n_novel_images 248 | features = collapse_batch(latents[:, :n_known_images]) 249 | 250 | if self.proj_in_mode == "unet": 251 | assert timestep is not None 252 | features = self.proj_in_2d(features, collapse_batch(timestep[:, :n_known_images])).sample 253 | else: 254 | features = self.proj_in_2d(features) 255 | 256 | features = expand_batch(features, n_known_images) 257 | features = features.to( 258 | rays.dtype 259 | ) # rays have fp32, want features to have it too, e.g. 
build_cost_volume always in highest precision 260 | 261 | # concat features with rays to get per-voxel results for both 262 | if self.aggregator_mode == "ibrnet": 263 | rays_for_cost_volume = ( 264 | rays[:, :n_known_images] 265 | .permute(0, 1, 3, 2) 266 | .reshape(batch_size, n_known_images, 3, latent_hw[0], latent_hw[1]) 267 | ) 268 | features = torch.cat([features, rays_for_cost_volume], dim=2) 269 | 270 | # unproj to dense voxel grid (per-frame) 271 | features, weights, points, voxel_depth = build_cost_volume( 272 | features, 273 | pose[:, :n_known_images], 274 | K[:, :n_known_images], 275 | bbox, 276 | grid_dim=self.dim_3d_grid, 277 | contract_background=self.vol_rend_model_background, 278 | contract_background_percentage=self.vol_rend_background_grid_percentage, 279 | ) 280 | 281 | # aggregate per-frame grids into single 3D grid 282 | features = features.to(latents.dtype) # 3D network should be in half-precision if specified 283 | agg_temb = temb[:, :n_known_images] if temb is not None else None 284 | 285 | if self.aggregator_mode == "ibrnet": 286 | features, voxel_dir = torch.split(features, [features.shape[2] - 3, 3], dim=2) 287 | features = self.ibrnet_aggregator(features, weights, voxel_depth, voxel_dir, agg_temb) 288 | elif self.aggregator_mode == "mean": 289 | features, _ = mean_aggregate_cost_volumes(features, weights) 290 | 291 | # apply 3D layers on grid 292 | if self.use_3d_unet: 293 | assert timestep is not None 294 | # give last timestep as input --> assumption is that non-noisy images are never in last position (e.g. sliding window inputs always are the first images) 295 | features = self.blocks_3d(features, timestep[:, -1]).sample 296 | else: 297 | for block in self.blocks_3d: 298 | # give last timestep as input --> assumption is that non-noisy images are never in last position (e.g. 
sliding window inputs always are the first images) 299 | features = block(features, temb[:, -1] if temb is not None else None) 300 | 301 | features = features.permute(0, 2, 3, 4, 1) 302 | 303 | # collapse (num_images, h*w) to render all rays jointly 304 | rays = rays.reshape(batch_size, -1, rays.shape[3]) 305 | centers = centers.reshape(batch_size, -1, centers.shape[3]) 306 | near_t = near_t.reshape(batch_size, -1) 307 | far_t = far_t.reshape(batch_size, -1) 308 | 309 | # volume-render (grid, rays) 310 | # only supports fp32 (there are some triton kernel impls that cast to float32, so make it the dtype from the very beginning) 311 | with torch.autocast("cuda", enabled=False): 312 | if deactivate_view_dependent_rendering: 313 | dummy_pose = torch.tensor([ 314 | [0, 1, 0, 0], 315 | [-0.7071, 0, 0.7071, 0], 316 | [0.7071, 0, 0.7071, 3.0], 317 | [0, 0, 0, 1.0], 318 | ], device=pose.device, dtype=pose.dtype) 319 | dummy_pose = dummy_pose[None, None].repeat(pose.shape[0], pose.shape[1], 1, 1) 320 | dummy_rays, _, _, _, _ = get_rays_in_unit_cube( 321 | bbox, 322 | dummy_pose, 323 | K, 324 | self.canonical_rays, 325 | default_near=self.vol_rend_near, 326 | default_far=self.vol_rend_far, 327 | use_ray_aabb=not self.vol_rend_model_background, 328 | ) 329 | dummy_rays = dummy_rays.reshape(batch_size, -1, dummy_rays.shape[3]) 330 | rays_encoding = self.linear_ray(dummy_rays).float() 331 | else: 332 | rays_encoding = self.linear_ray(rays).float() 333 | 334 | projected_latents, projected_mask, projected_depth = self.volume_renderer( 335 | v=features.float(), 336 | rays_encoding=rays_encoding, 337 | rays=rays, 338 | centers=centers, 339 | near=near_t, 340 | far=far_t, 341 | ) 342 | 343 | projected_latents = projected_latents.to(latents.dtype) 344 | projected_mask = projected_mask.to(latents.dtype) 345 | projected_depth = projected_depth.to(latents.dtype) 346 | 347 | # reshape back to (num_images, h*w) 348 | hw = latent_hw[0] * latent_hw[1] 349 | projected_latents = projected_latents.reshape(batch_size, num_images, hw, projected_latents.shape[2]) 350 | projected_depth = projected_depth.reshape(batch_size, num_images, hw) 351 | projected_mask = projected_mask.reshape(batch_size, num_images, hw) 352 | 353 | # reshape to (batch_size, num_images, C, h, w) 354 | projected_latents = projected_latents.permute(0, 1, 3, 2) 355 | projected_latents = projected_latents.reshape( 356 | batch_size, 357 | num_images, 358 | projected_latents.shape[2], 359 | latent_hw[0], 360 | latent_hw[1], 361 | ) 362 | projected_depth = projected_depth.reshape(batch_size, num_images, latent_hw[0], latent_hw[1]) 363 | projected_depth = projected_depth.unsqueeze(2) 364 | projected_mask = projected_mask.reshape(batch_size, num_images, latent_hw[0], latent_hw[1]) 365 | projected_mask = projected_mask.unsqueeze(2) 366 | 367 | # proj-out (back to larger channels, back from 0..1 to correct output range) 368 | # have separate de-normalization layers for fg and bg 369 | projected_latents = collapse_batch(projected_latents) 370 | p_fg = self.proj_out_2d_fg(projected_latents) 371 | p_bg = self.proj_out_2d_bg(projected_latents) 372 | m = collapse_batch(projected_mask).repeat(1, p_fg.shape[1], 1, 1) 373 | projected_latents = m * p_fg + (1 - m) * p_bg 374 | projected_latents = expand_batch(projected_latents, num_images) 375 | 376 | return projected_latents, projected_mask, projected_depth 377 | -------------------------------------------------------------------------------- /LICENSE: 
-------------------------------------------------------------------------------- 1 | Attribution-NonCommercial-ShareAlike 4.0 International 2 | 3 | ======================================================================= 4 | 5 | Creative Commons Corporation ("Creative Commons") is not a law firm and 6 | does not provide legal services or legal advice. Distribution of 7 | Creative Commons public licenses does not create a lawyer-client or 8 | other relationship. Creative Commons makes its licenses and related 9 | information available on an "as-is" basis. Creative Commons gives no 10 | warranties regarding its licenses, any material licensed under their 11 | terms and conditions, or any related information. Creative Commons 12 | disclaims all liability for damages resulting from their use to the 13 | fullest extent possible. 14 | 15 | Using Creative Commons Public Licenses 16 | 17 | Creative Commons public licenses provide a standard set of terms and 18 | conditions that creators and other rights holders may use to share 19 | original works of authorship and other material subject to copyright 20 | and certain other rights specified in the public license below. The 21 | following considerations are for informational purposes only, are not 22 | exhaustive, and do not form part of our licenses. 23 | 24 | Considerations for licensors: Our public licenses are 25 | intended for use by those authorized to give the public 26 | permission to use material in ways otherwise restricted by 27 | copyright and certain other rights. Our licenses are 28 | irrevocable. Licensors should read and understand the terms 29 | and conditions of the license they choose before applying it. 30 | Licensors should also secure all rights necessary before 31 | applying our licenses so that the public can reuse the 32 | material as expected. Licensors should clearly mark any 33 | material not subject to the license. This includes other CC- 34 | licensed material, or material used under an exception or 35 | limitation to copyright. More considerations for licensors: 36 | wiki.creativecommons.org/Considerations_for_licensors 37 | 38 | Considerations for the public: By using one of our public 39 | licenses, a licensor grants the public permission to use the 40 | licensed material under specified terms and conditions. If 41 | the licensor's permission is not necessary for any reason--for 42 | example, because of any applicable exception or limitation to 43 | copyright--then that use is not regulated by the license. Our 44 | licenses grant only permissions under copyright and certain 45 | other rights that a licensor has authority to grant. Use of 46 | the licensed material may still be restricted for other 47 | reasons, including because others have copyright or other 48 | rights in the material. A licensor may make special requests, 49 | such as asking that all changes be marked or described. 50 | Although not required by our licenses, you are encouraged to 51 | respect those requests where reasonable. 
More considerations 52 | for the public: 53 | wiki.creativecommons.org/Considerations_for_licensees 54 | 55 | ======================================================================= 56 | 57 | Creative Commons Attribution-NonCommercial-ShareAlike 4.0 International 58 | Public License 59 | 60 | By exercising the Licensed Rights (defined below), You accept and agree 61 | to be bound by the terms and conditions of this Creative Commons 62 | Attribution-NonCommercial-ShareAlike 4.0 International Public License 63 | ("Public License"). To the extent this Public License may be 64 | interpreted as a contract, You are granted the Licensed Rights in 65 | consideration of Your acceptance of these terms and conditions, and the 66 | Licensor grants You such rights in consideration of benefits the 67 | Licensor receives from making the Licensed Material available under 68 | these terms and conditions. 69 | 70 | 71 | Section 1 -- Definitions. 72 | 73 | a. Adapted Material means material subject to Copyright and Similar 74 | Rights that is derived from or based upon the Licensed Material 75 | and in which the Licensed Material is translated, altered, 76 | arranged, transformed, or otherwise modified in a manner requiring 77 | permission under the Copyright and Similar Rights held by the 78 | Licensor. For purposes of this Public License, where the Licensed 79 | Material is a musical work, performance, or sound recording, 80 | Adapted Material is always produced where the Licensed Material is 81 | synched in timed relation with a moving image. 82 | 83 | b. Adapter's License means the license You apply to Your Copyright 84 | and Similar Rights in Your contributions to Adapted Material in 85 | accordance with the terms and conditions of this Public License. 86 | 87 | c. BY-NC-SA Compatible License means a license listed at 88 | creativecommons.org/compatiblelicenses, approved by Creative 89 | Commons as essentially the equivalent of this Public License. 90 | 91 | d. Copyright and Similar Rights means copyright and/or similar rights 92 | closely related to copyright including, without limitation, 93 | performance, broadcast, sound recording, and Sui Generis Database 94 | Rights, without regard to how the rights are labeled or 95 | categorized. For purposes of this Public License, the rights 96 | specified in Section 2(b)(1)-(2) are not Copyright and Similar 97 | Rights. 98 | 99 | e. Effective Technological Measures means those measures that, in the 100 | absence of proper authority, may not be circumvented under laws 101 | fulfilling obligations under Article 11 of the WIPO Copyright 102 | Treaty adopted on December 20, 1996, and/or similar international 103 | agreements. 104 | 105 | f. Exceptions and Limitations means fair use, fair dealing, and/or 106 | any other exception or limitation to Copyright and Similar Rights 107 | that applies to Your use of the Licensed Material. 108 | 109 | g. License Elements means the license attributes listed in the name 110 | of a Creative Commons Public License. The License Elements of this 111 | Public License are Attribution, NonCommercial, and ShareAlike. 112 | 113 | h. Licensed Material means the artistic or literary work, database, 114 | or other material to which the Licensor applied this Public 115 | License. 116 | 117 | i. 
Licensed Rights means the rights granted to You subject to the 118 | terms and conditions of this Public License, which are limited to 119 | all Copyright and Similar Rights that apply to Your use of the 120 | Licensed Material and that the Licensor has authority to license. 121 | 122 | j. Licensor means the individual(s) or entity(ies) granting rights 123 | under this Public License. 124 | 125 | k. NonCommercial means not primarily intended for or directed towards 126 | commercial advantage or monetary compensation. For purposes of 127 | this Public License, the exchange of the Licensed Material for 128 | other material subject to Copyright and Similar Rights by digital 129 | file-sharing or similar means is NonCommercial provided there is 130 | no payment of monetary compensation in connection with the 131 | exchange. 132 | 133 | l. Share means to provide material to the public by any means or 134 | process that requires permission under the Licensed Rights, such 135 | as reproduction, public display, public performance, distribution, 136 | dissemination, communication, or importation, and to make material 137 | available to the public including in ways that members of the 138 | public may access the material from a place and at a time 139 | individually chosen by them. 140 | 141 | m. Sui Generis Database Rights means rights other than copyright 142 | resulting from Directive 96/9/EC of the European Parliament and of 143 | the Council of 11 March 1996 on the legal protection of databases, 144 | as amended and/or succeeded, as well as other essentially 145 | equivalent rights anywhere in the world. 146 | 147 | n. You means the individual or entity exercising the Licensed Rights 148 | under this Public License. Your has a corresponding meaning. 149 | 150 | 151 | Section 2 -- Scope. 152 | 153 | a. License grant. 154 | 155 | 1. Subject to the terms and conditions of this Public License, 156 | the Licensor hereby grants You a worldwide, royalty-free, 157 | non-sublicensable, non-exclusive, irrevocable license to 158 | exercise the Licensed Rights in the Licensed Material to: 159 | 160 | a. reproduce and Share the Licensed Material, in whole or 161 | in part, for NonCommercial purposes only; and 162 | 163 | b. produce, reproduce, and Share Adapted Material for 164 | NonCommercial purposes only. 165 | 166 | 2. Exceptions and Limitations. For the avoidance of doubt, where 167 | Exceptions and Limitations apply to Your use, this Public 168 | License does not apply, and You do not need to comply with 169 | its terms and conditions. 170 | 171 | 3. Term. The term of this Public License is specified in Section 172 | 6(a). 173 | 174 | 4. Media and formats; technical modifications allowed. The 175 | Licensor authorizes You to exercise the Licensed Rights in 176 | all media and formats whether now known or hereafter created, 177 | and to make technical modifications necessary to do so. The 178 | Licensor waives and/or agrees not to assert any right or 179 | authority to forbid You from making technical modifications 180 | necessary to exercise the Licensed Rights, including 181 | technical modifications necessary to circumvent Effective 182 | Technological Measures. For purposes of this Public License, 183 | simply making modifications authorized by this Section 2(a) 184 | (4) never produces Adapted Material. 185 | 186 | 5. Downstream recipients. 187 | 188 | a. Offer from the Licensor -- Licensed Material. 
Every 189 | recipient of the Licensed Material automatically 190 | receives an offer from the Licensor to exercise the 191 | Licensed Rights under the terms and conditions of this 192 | Public License. 193 | 194 | b. Additional offer from the Licensor -- Adapted Material. 195 | Every recipient of Adapted Material from You 196 | automatically receives an offer from the Licensor to 197 | exercise the Licensed Rights in the Adapted Material 198 | under the conditions of the Adapter's License You apply. 199 | 200 | c. No downstream restrictions. You may not offer or impose 201 | any additional or different terms or conditions on, or 202 | apply any Effective Technological Measures to, the 203 | Licensed Material if doing so restricts exercise of the 204 | Licensed Rights by any recipient of the Licensed 205 | Material. 206 | 207 | 6. No endorsement. Nothing in this Public License constitutes or 208 | may be construed as permission to assert or imply that You 209 | are, or that Your use of the Licensed Material is, connected 210 | with, or sponsored, endorsed, or granted official status by, 211 | the Licensor or others designated to receive attribution as 212 | provided in Section 3(a)(1)(A)(i). 213 | 214 | b. Other rights. 215 | 216 | 1. Moral rights, such as the right of integrity, are not 217 | licensed under this Public License, nor are publicity, 218 | privacy, and/or other similar personality rights; however, to 219 | the extent possible, the Licensor waives and/or agrees not to 220 | assert any such rights held by the Licensor to the limited 221 | extent necessary to allow You to exercise the Licensed 222 | Rights, but not otherwise. 223 | 224 | 2. Patent and trademark rights are not licensed under this 225 | Public License. 226 | 227 | 3. To the extent possible, the Licensor waives any right to 228 | collect royalties from You for the exercise of the Licensed 229 | Rights, whether directly or through a collecting society 230 | under any voluntary or waivable statutory or compulsory 231 | licensing scheme. In all other cases the Licensor expressly 232 | reserves any right to collect such royalties, including when 233 | the Licensed Material is used other than for NonCommercial 234 | purposes. 235 | 236 | 237 | Section 3 -- License Conditions. 238 | 239 | Your exercise of the Licensed Rights is expressly made subject to the 240 | following conditions. 241 | 242 | a. Attribution. 243 | 244 | 1. If You Share the Licensed Material (including in modified 245 | form), You must: 246 | 247 | a. retain the following if it is supplied by the Licensor 248 | with the Licensed Material: 249 | 250 | i. identification of the creator(s) of the Licensed 251 | Material and any others designated to receive 252 | attribution, in any reasonable manner requested by 253 | the Licensor (including by pseudonym if 254 | designated); 255 | 256 | ii. a copyright notice; 257 | 258 | iii. a notice that refers to this Public License; 259 | 260 | iv. a notice that refers to the disclaimer of 261 | warranties; 262 | 263 | v. a URI or hyperlink to the Licensed Material to the 264 | extent reasonably practicable; 265 | 266 | b. indicate if You modified the Licensed Material and 267 | retain an indication of any previous modifications; and 268 | 269 | c. indicate the Licensed Material is licensed under this 270 | Public License, and include the text of, or the URI or 271 | hyperlink to, this Public License. 272 | 273 | 2. 
You may satisfy the conditions in Section 3(a)(1) in any 274 | reasonable manner based on the medium, means, and context in 275 | which You Share the Licensed Material. For example, it may be 276 | reasonable to satisfy the conditions by providing a URI or 277 | hyperlink to a resource that includes the required 278 | information. 279 | 3. If requested by the Licensor, You must remove any of the 280 | information required by Section 3(a)(1)(A) to the extent 281 | reasonably practicable. 282 | 283 | b. ShareAlike. 284 | 285 | In addition to the conditions in Section 3(a), if You Share 286 | Adapted Material You produce, the following conditions also apply. 287 | 288 | 1. The Adapter's License You apply must be a Creative Commons 289 | license with the same License Elements, this version or 290 | later, or a BY-NC-SA Compatible License. 291 | 292 | 2. You must include the text of, or the URI or hyperlink to, the 293 | Adapter's License You apply. You may satisfy this condition 294 | in any reasonable manner based on the medium, means, and 295 | context in which You Share Adapted Material. 296 | 297 | 3. You may not offer or impose any additional or different terms 298 | or conditions on, or apply any Effective Technological 299 | Measures to, Adapted Material that restrict exercise of the 300 | rights granted under the Adapter's License You apply. 301 | 302 | 303 | Section 4 -- Sui Generis Database Rights. 304 | 305 | Where the Licensed Rights include Sui Generis Database Rights that 306 | apply to Your use of the Licensed Material: 307 | 308 | a. for the avoidance of doubt, Section 2(a)(1) grants You the right 309 | to extract, reuse, reproduce, and Share all or a substantial 310 | portion of the contents of the database for NonCommercial purposes 311 | only; 312 | 313 | b. if You include all or a substantial portion of the database 314 | contents in a database in which You have Sui Generis Database 315 | Rights, then the database in which You have Sui Generis Database 316 | Rights (but not its individual contents) is Adapted Material, 317 | including for purposes of Section 3(b); and 318 | 319 | c. You must comply with the conditions in Section 3(a) if You Share 320 | all or a substantial portion of the contents of the database. 321 | 322 | For the avoidance of doubt, this Section 4 supplements and does not 323 | replace Your obligations under this Public License where the Licensed 324 | Rights include other Copyright and Similar Rights. 325 | 326 | 327 | Section 5 -- Disclaimer of Warranties and Limitation of Liability. 328 | 329 | a. UNLESS OTHERWISE SEPARATELY UNDERTAKEN BY THE LICENSOR, TO THE 330 | EXTENT POSSIBLE, THE LICENSOR OFFERS THE LICENSED MATERIAL AS-IS 331 | AND AS-AVAILABLE, AND MAKES NO REPRESENTATIONS OR WARRANTIES OF 332 | ANY KIND CONCERNING THE LICENSED MATERIAL, WHETHER EXPRESS, 333 | IMPLIED, STATUTORY, OR OTHER. THIS INCLUDES, WITHOUT LIMITATION, 334 | WARRANTIES OF TITLE, MERCHANTABILITY, FITNESS FOR A PARTICULAR 335 | PURPOSE, NON-INFRINGEMENT, ABSENCE OF LATENT OR OTHER DEFECTS, 336 | ACCURACY, OR THE PRESENCE OR ABSENCE OF ERRORS, WHETHER OR NOT 337 | KNOWN OR DISCOVERABLE. WHERE DISCLAIMERS OF WARRANTIES ARE NOT 338 | ALLOWED IN FULL OR IN PART, THIS DISCLAIMER MAY NOT APPLY TO YOU. 339 | 340 | b. 
TO THE EXTENT POSSIBLE, IN NO EVENT WILL THE LICENSOR BE LIABLE 341 | TO YOU ON ANY LEGAL THEORY (INCLUDING, WITHOUT LIMITATION, 342 | NEGLIGENCE) OR OTHERWISE FOR ANY DIRECT, SPECIAL, INDIRECT, 343 | INCIDENTAL, CONSEQUENTIAL, PUNITIVE, EXEMPLARY, OR OTHER LOSSES, 344 | COSTS, EXPENSES, OR DAMAGES ARISING OUT OF THIS PUBLIC LICENSE OR 345 | USE OF THE LICENSED MATERIAL, EVEN IF THE LICENSOR HAS BEEN 346 | ADVISED OF THE POSSIBILITY OF SUCH LOSSES, COSTS, EXPENSES, OR 347 | DAMAGES. WHERE A LIMITATION OF LIABILITY IS NOT ALLOWED IN FULL OR 348 | IN PART, THIS LIMITATION MAY NOT APPLY TO YOU. 349 | 350 | c. The disclaimer of warranties and limitation of liability provided 351 | above shall be interpreted in a manner that, to the extent 352 | possible, most closely approximates an absolute disclaimer and 353 | waiver of all liability. 354 | 355 | 356 | Section 6 -- Term and Termination. 357 | 358 | a. This Public License applies for the term of the Copyright and 359 | Similar Rights licensed here. However, if You fail to comply with 360 | this Public License, then Your rights under this Public License 361 | terminate automatically. 362 | 363 | b. Where Your right to use the Licensed Material has terminated under 364 | Section 6(a), it reinstates: 365 | 366 | 1. automatically as of the date the violation is cured, provided 367 | it is cured within 30 days of Your discovery of the 368 | violation; or 369 | 370 | 2. upon express reinstatement by the Licensor. 371 | 372 | For the avoidance of doubt, this Section 6(b) does not affect any 373 | right the Licensor may have to seek remedies for Your violations 374 | of this Public License. 375 | 376 | c. For the avoidance of doubt, the Licensor may also offer the 377 | Licensed Material under separate terms or conditions or stop 378 | distributing the Licensed Material at any time; however, doing so 379 | will not terminate this Public License. 380 | 381 | d. Sections 1, 5, 6, 7, and 8 survive termination of this Public 382 | License. 383 | 384 | 385 | Section 7 -- Other Terms and Conditions. 386 | 387 | a. The Licensor shall not be bound by any additional or different 388 | terms or conditions communicated by You unless expressly agreed. 389 | 390 | b. Any arrangements, understandings, or agreements regarding the 391 | Licensed Material not stated herein are separate from and 392 | independent of the terms and conditions of this Public License. 393 | 394 | 395 | Section 8 -- Interpretation. 396 | 397 | a. For the avoidance of doubt, this Public License does not, and 398 | shall not be interpreted to, reduce, limit, restrict, or impose 399 | conditions on any use of the Licensed Material that could lawfully 400 | be made without permission under this Public License. 401 | 402 | b. To the extent possible, if any provision of this Public License is 403 | deemed unenforceable, it shall be automatically reformed to the 404 | minimum extent necessary to make it enforceable. If the provision 405 | cannot be reformed, it shall be severed from this Public License 406 | without affecting the enforceability of the remaining terms and 407 | conditions. 408 | 409 | c. No term or condition of this Public License will be waived and no 410 | failure to comply consented to unless expressly agreed to by the 411 | Licensor. 412 | 413 | d. 
Nothing in this Public License constitutes or may be interpreted 414 | as a limitation upon, or waiver of, any privileges and immunities 415 | that apply to the Licensor or You, including from the legal 416 | processes of any jurisdiction or authority. 417 | -------------------------------------------------------------------------------- /viewdiff/io_util.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | 3 | import os 4 | from argparse import Namespace 5 | import json 6 | import imageio 7 | import numpy as np 8 | from datetime import datetime 9 | from typing import List, Dict, Tuple, Literal, Optional 10 | from dataclasses import dataclass 11 | from collections.abc import MutableMapping, Iterable 12 | 13 | from PIL import Image 14 | import torch 15 | from torch.nn.functional import interpolate 16 | 17 | from torchvision.utils import make_grid 18 | 19 | from .data.co3d.co3d_dataset import CO3DConfig 20 | 21 | from .data.create_video_from_image_folder import main as create_video_from_image_folder 22 | 23 | 24 | @dataclass 25 | class SaveConfig: 26 | """Which file types should be saved in #save_inference_outputs().""" 27 | 28 | image_grids: bool = False 29 | pred_files: bool = True 30 | pred_video: bool = True 31 | pred_gif: bool = False 32 | denoise_files: bool = False 33 | denoise_video: bool = False 34 | cams: bool = True 35 | prompts: bool = True 36 | rendered_depth: bool = False 37 | cond_files: bool = False 38 | image_metrics: bool = True 39 | 40 | 41 | @dataclass 42 | class IOConfig: 43 | """Arguments for IO.""" 44 | 45 | save: SaveConfig = SaveConfig() 46 | 47 | pretrained_model_name_or_path: str = "stabilityai/stable-diffusion-2-1-base" 48 | """Path to pretrained model or model identifier from huggingface.co/models""" 49 | 50 | revision: Optional[str] = None 51 | """Revision of pretrained model identifier from huggingface.co/models.""" 52 | 53 | output_dir: str = "output" 54 | """The output directory where the model predictions and checkpoints will be written.""" 55 | 56 | experiment_name: Optional[str] = None 57 | """If this is set, will use this instead of the datetime string as identifier for the experiment.""" 58 | 59 | logging_dir: str = "logs" 60 | """[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to 61 | *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***.""" 62 | 63 | log_images_every_nth: int = 500 64 | """log images every nth step""" 65 | 66 | report_to: Literal["tensorboard", "custom_tensorboard"] = "custom_tensorboard" 67 | """The integration to report the results and logs to. Supported platforms are `"tensorboard"`""" 68 | 69 | checkpointing_steps: int = 500 70 | """Save a checkpoint of the training state every X updates. These checkpoints are only suitable for resuming 71 | training using `--resume_from_checkpoint`.""" 72 | 73 | checkpoints_total_limit: int = 2 74 | """Max number of checkpoints to store.""" 75 | 76 | resume_from_checkpoint: Optional[str] = None 77 | """Whether training should be resumed from a previous checkpoint.
Use a path saved by 78 | ' `--checkpointing_steps`, or `"latest"` to automatically select the last available checkpoint.""" 79 | 80 | automatic_checkpoint_resume: bool = False 81 | 82 | 83 | def setup_output_directories( 84 | io_config: IOConfig, 85 | model_config: dataclass, 86 | dataset_config: CO3DConfig, 87 | is_train: bool = False, 88 | ): 89 | """Creates output directories for an experiment as specified in the provided configs. 90 | 91 | Args: 92 | io_config (dataclass): _description_ 93 | model_config (dataclass): _description_ 94 | dataset_config (CO3Dv2MaskAugmentationConfig): _description_ 95 | is_train (bool, optional): _description_. Defaults to False. 96 | """ 97 | 98 | if isinstance(dataset_config, CO3DConfig): 99 | # append category to output_dir 100 | category_name = "all" if dataset_config.category is None else dataset_config.category 101 | io_config.output_dir = os.path.join(io_config.output_dir, category_name) 102 | 103 | # append n_sequences to output_dir 104 | if dataset_config.max_sequences > -1: 105 | io_config.output_dir = os.path.join(io_config.output_dir, f"{dataset_config.max_sequences}_sequences") 106 | 107 | # append picked_sequences to output_dir 108 | io_config.output_dir = os.path.join(io_config.output_dir, ",".join(dataset_config.dataset_args.pick_sequence)) 109 | 110 | # append subset to output_dir 111 | io_config.output_dir = os.path.join( 112 | io_config.output_dir, "subset_all" if dataset_config.subset is None else f"subset_{dataset_config.subset}" 113 | ) 114 | else: 115 | raise NotImplementedError("unsupported dataset config", type(dataset_config)) 116 | 117 | # append input/output information to output_dir 118 | io_config.output_dir = os.path.join( 119 | io_config.output_dir, f"input_{model_config.n_input_images}" 120 | ) 121 | 122 | # append train/test information to output_dir 123 | io_config.output_dir = os.path.join(io_config.output_dir, "train" if is_train else "test") 124 | 125 | # append experiment name or current datetime to output_dir 126 | has_exp_name = hasattr(io_config, "experiment_name") and io_config.experiment_name is not None 127 | if has_exp_name: 128 | io_config.output_dir = os.path.join(io_config.output_dir, io_config.experiment_name) 129 | if not has_exp_name or not is_train: 130 | date_time = datetime.now().strftime("%d.%m.%Y_%H:%M:%S.%f") 131 | io_config.output_dir = os.path.join(io_config.output_dir, date_time) 132 | 133 | 134 | def make_output_directories(io_config: IOConfig): 135 | if not os.path.exists(io_config.output_dir): 136 | os.makedirs(io_config.output_dir) 137 | io_config.image_out_dir = os.path.join(io_config.output_dir, "images") 138 | io_config.rendered_depth_out_dir = os.path.join(io_config.output_dir, "rendered_depth") 139 | io_config.stats_out_dir = os.path.join(io_config.output_dir, "stats") 140 | if not os.path.exists(io_config.image_out_dir): 141 | os.makedirs(io_config.image_out_dir) 142 | if not os.path.exists(io_config.rendered_depth_out_dir): 143 | os.makedirs(io_config.rendered_depth_out_dir) 144 | if not os.path.exists(io_config.stats_out_dir): 145 | os.makedirs(io_config.stats_out_dir) 146 | 147 | 148 | def make_image_grid(*images) -> torch.Tensor: 149 | """Returns a grid of image tuples for the given batch that can be used for logging. 150 | Each row represents one batch sample with K pairs. 151 | Each tuple is the horizontal concatenation of the list of images. 152 | 153 | Args: 154 | images: list of images where each value is a tensor of shape (N, K, C, H, W). 
Expects all tensors in the same value range, s.t. normalization maps them to 0..1 155 | 156 | Returns: 157 | torch.Tensor: the grid image represented as tensor in range [0..1] 158 | """ 159 | N, K = images[0].shape[:2] 160 | combined_image_list = [] 161 | for n in range(N): 162 | for k in range(K): 163 | combined_image = torch.cat([img[n, k, :3].detach().cpu().float() for img in images], dim=-1) 164 | combined_image_list.append(combined_image) 165 | combined_image = make_grid(combined_image_list, nrow=K, normalize=True) 166 | 167 | return combined_image 168 | 169 | 170 | def make_pred_grid(pred_images: List[torch.Tensor], *images) -> Tuple[torch.Tensor, torch.Tensor]: 171 | """Returns a grid of image tuples for the given batch and predictions that can be used for logging. 172 | Each row represents one batch sample with K triplets. 173 | Each tuple is the horizontal concatenation of the list of images and the prediction. 174 | 175 | Args: 176 | images: list of images where each value is a tensor of shape (N, K, C, H, W). Expects all tensors in the same value range, s.t. normalization maps them to 0..1 177 | pred_images (List[torch.Tensor]): tensor of shape (N, K, C, H', W') in range [0..1]. 178 | 179 | Returns: 180 | torch.Tensor: the grid image represented as tensor in range [0..1] 181 | """ 182 | 183 | N, K = images[0].shape[:2] 184 | 185 | # normalize and resize input images 186 | converted_images = [] 187 | for img in images: 188 | # reshape images if necessary 189 | if img.shape[-2:] != pred_images.shape[-2:]: 190 | img = img.view(N * K, *img.shape[2:]) 191 | img = interpolate(img, pred_images.shape[-2:], mode="bilinear") 192 | img = img.view(N, K, *img.shape[1:]) 193 | 194 | # normalize to 0..1 195 | img = norm_0_1(img) 196 | 197 | converted_images.append(img) 198 | 199 | return make_image_grid(*converted_images, pred_images) 200 | 201 | 202 | def norm_0_1(x: torch.Tensor) -> torch.Tensor: 203 | min_val = x.min() 204 | max_val = x.max() 205 | return (x - min_val) / (max_val - min_val) 206 | 207 | 208 | def torch_to_numpy(img: torch.Tensor) -> np.ndarray: 209 | """Converts a tensor to a np.ndarray. 210 | 211 | Args: 212 | img (torch.Tensor): tensor in arbitrary range of type float32 213 | 214 | Returns: 215 | np.ndarray: np image of type uint8 in range [0..255] 216 | """ 217 | # normalize to [0, 1] 218 | img = norm_0_1(img) 219 | 220 | # to [0, 255] 221 | img = (img * 255.0).to(torch.uint8) 222 | 223 | # (C, H, W) torch.Tensor to (H, W, C) np.array 224 | img = img.permute(1, 2, 0).cpu().numpy() 225 | 226 | return img 227 | 228 | 229 | def torch_to_pil(img: torch.Tensor) -> Image: 230 | """Converts a tensor to a PIL Image. 231 | 232 | Args: 233 | img (torch.Tensor): tensor in arbitrary range 234 | 235 | Returns: 236 | Image: PIL Image in range [0..255] 237 | """ 238 | img = torch_to_numpy(img) 239 | 240 | # np.array to PIL.Image 241 | img = Image.fromarray(img) 242 | 243 | return img 244 | 245 | 246 | def convert_to_tensorboard_dict(x: Dict) -> Dict: 247 | """Converts a dictionary to a format supported for logging in tensorboard. 248 | That is: flattens the dictionary, replaces None with "", and replaces lists with str(list). 
249 | 250 | Args: 251 | x (Dict): the dict to convert 252 | 253 | Returns: 254 | Dict: the converted dict 255 | """ 256 | 257 | def flatten(dictionary, parent_key="", separator="_"): 258 | items = [] 259 | for key, value in dictionary.items(): 260 | new_key = parent_key + separator + key if parent_key else key 261 | if isinstance(value, MutableMapping): 262 | items.extend(flatten(value, new_key, separator=separator).items()) 263 | else: 264 | items.append((new_key, value)) 265 | return dict(items) 266 | 267 | # hierarchical dicts cannot be logged in tb, replace with flattened dict 268 | x = flatten(x) 269 | 270 | # None cannot be logged in tb, replace with "" 271 | x = {k: v if v is not None else "" for k, v in x.items()} 272 | 273 | # lists cannot be logged in tb, replace with str repr 274 | x = {k: ", ".join(map(str, v)) if isinstance(v, Iterable) else v for k, v in x.items()} 275 | 276 | return x 277 | 278 | 279 | def save_inference_outputs( 280 | batch: Dict[str, torch.Tensor], 281 | output, 282 | io_config: IOConfig, 283 | writer, 284 | step: int = 0, 285 | prefix: str = None, 286 | ): 287 | """Saves one batch/prediction to output folders and tensorboard. 288 | 289 | Args: 290 | batch (Dict[str, torch.Tensor]): the input batch 291 | output: the predicted output 292 | io_config (IOConfig): specifying output folder location 293 | writer (SummaryWriter): the tensorboard logger 294 | step (int, optional): the index of the batch. Defaults to 0. 295 | prefix (str, optional): prefix to use for all file/tensorboard outputs. Defaults to None. 296 | """ 297 | if prefix is not None: 298 | tb_prefix = f"{prefix}/" 299 | prefix = f"{prefix}_" 300 | else: 301 | tb_prefix = "" 302 | prefix = "" 303 | 304 | # parse input to log in the batch 305 | batch_input_list = [] 306 | keys = ["images"] 307 | for frame_idx in keys: 308 | if frame_idx in batch: 309 | batch_input_list.append(batch[frame_idx]) 310 | 311 | # save image grid 312 | if io_config.save.image_grids: 313 | pred_image_grid = make_pred_grid(output.images, *batch_input_list) 314 | writer.add_image(f"{tb_prefix}Images", pred_image_grid, global_step=step) 315 | pred_image_grid_pil = torch_to_pil(pred_image_grid) 316 | with open(os.path.join(io_config.image_out_dir, f"{prefix}image_grid_{step:04d}.png"), "wb") as f: 317 | pred_image_grid_pil.save(f) 318 | 319 | # add image diffusion process 320 | root_output_path = io_config.image_out_dir 321 | file_patterns = [] 322 | if io_config.save.denoise_files and hasattr(output, "image_list"): 323 | # save image denoise predictions as separate files 324 | file_patterns.append("denoise_files_") 325 | for time_idx, img in enumerate(output.image_list): 326 | img = make_image_grid(img) 327 | writer.add_image(f"{tb_prefix}/Denoise/{time_idx}", img, global_step=step) 328 | img = torch_to_pil(img) 329 | with open(os.path.join(root_output_path, f"{prefix}denoise_files_{step:04d}_{time_idx:04d}.png"), "wb") as f: 330 | img.save(f) 331 | 332 | if io_config.save.denoise_video: 333 | # save image denoise predictions as video 334 | file_patterns.append("denoise_video_") 335 | output_path = os.path.join(root_output_path, f"{prefix}denoise_video_{step:04d}.mp4") 336 | video_args = Namespace( 337 | **{ 338 | "image_folder": root_output_path, 339 | "file_name_pattern_glob": f"{prefix}denoise_files_{step:04d}_*.png", 340 | "output_path": output_path, 341 | "framerate": 10, 342 | } 343 | ) 344 | create_video_from_image_folder(video_args) 345 | 346 | # create list of pil images for attention logging and filename logging 
(shared) 347 | # save poses/intrs to dict 348 | N, K = output.images.shape[:2] 349 | pil_images = [] 350 | cams = { 351 | "poses": {}, 352 | "intrs": {} 353 | } 354 | for n in range(N): 355 | for frame_idx in range(K): 356 | # save the predictions using their original filenames 357 | file_name = batch["file_names"][frame_idx][n] 358 | sequence = os.path.basename(batch["root"][n]) 359 | key = f"step_{step:04d}_seq_{sequence}_file_{file_name}_frame_{frame_idx:04d}" 360 | if io_config.save.pred_files: 361 | file_patterns.append("pred_file_") 362 | if io_config.save.cond_files or "cond_" not in file_name: 363 | # convert to pil 364 | img = torch_to_pil(output.images[n, frame_idx].detach().cpu()) 365 | pil_images.append(img) 366 | 367 | # save image as file 368 | with open( 369 | os.path.join(root_output_path, f"{prefix}pred_file_{key}.png"), 370 | "wb", 371 | ) as f: 372 | img.save(f) 373 | 374 | # save cams in dict 375 | cams["poses"][key] = batch["pose"][n, frame_idx] 376 | cams["intrs"][key] = batch["K"][n, frame_idx] 377 | 378 | if io_config.save.pred_video: 379 | # save filename predictions as video 380 | file_patterns.append("pred_video_") 381 | output_path = os.path.join(root_output_path, f"{prefix}pred_video_{step:04d}.mp4") 382 | video_args = Namespace( 383 | **{ 384 | "image_folder": root_output_path, 385 | "file_name_pattern_glob": f"{prefix}pred_file_step_{step:04d}_*.png", 386 | "output_path": output_path, 387 | "framerate": 5, 388 | } 389 | ) 390 | create_video_from_image_folder(video_args) 391 | 392 | if io_config.save.pred_gif: 393 | file_patterns.append("pred_gif_") 394 | output_path = os.path.join(root_output_path, f"{prefix}pred_gif_{step:04d}.gif") 395 | with imageio.get_writer(output_path, mode='I', duration=1.0 / 15.0) as writer: 396 | for im in pil_images: 397 | writer.append_data(np.array(im)) 398 | 399 | # save poses/intrinsics 400 | if io_config.save.cams: 401 | file_patterns.append("cams_") 402 | with open( 403 | os.path.join(root_output_path, f"{prefix}cams_{step:04d}.pt"), 404 | "wb", 405 | ) as f: 406 | torch.save(cams, f) 407 | 408 | # save prompts 409 | if io_config.save.prompts: 410 | file_patterns.append("prompts_") 411 | with open( 412 | os.path.join(root_output_path, f"{prefix}prompts_{step:04d}.txt"), 413 | "w", 414 | ) as f: 415 | for p in batch["prompt"]: 416 | f.write(f"{p}\n") 417 | 418 | # save image metrics 419 | if io_config.save.image_metrics and hasattr(output, "image_metrics"): 420 | with open(os.path.join(io_config.stats_out_dir, f"image_metrics_{step}.json"), "w") as f: 421 | json.dump(output.image_metrics, f, indent=4) 422 | 423 | # save rendered_depth and rendered_mask 424 | if ( 425 | io_config.save.rendered_depth 426 | and hasattr(output, "rendered_depth") 427 | and output.rendered_depth is not None 428 | and hasattr(output, "rendered_mask") 429 | and output.rendered_mask is not None 430 | ): 431 | root_out = io_config.rendered_depth_out_dir 432 | 433 | # t goes over the timesteps, have one list of "rendered_depth/mask per layer" per timestep 434 | per_layer_image_list = {} 435 | for t, (depth_per_layer, mask_per_layer) in enumerate(zip(output.rendered_depth, output.rendered_mask)): 436 | if depth_per_layer is None or mask_per_layer is None: 437 | continue 438 | # i goes over the layers, have one "rendered_depth/mask" per layer 439 | for i, (d, m) in enumerate(zip(depth_per_layer, mask_per_layer)): 440 | # get grid 441 | rendered_depth_mask_grid = make_image_grid(norm_0_1(d), norm_0_1(m)) 442 | 443 | # write to tb 444 | 
writer.add_image(f"{tb_prefix}Rendered-Depth-Mask/{i}/{t}", rendered_depth_mask_grid, global_step=step) 445 | 446 | # convert to pil 447 | rendered_depth_mask_grid_pil = torch_to_pil(rendered_depth_mask_grid) 448 | 449 | # save for video 450 | if i not in per_layer_image_list: 451 | per_layer_image_list[i] = [] 452 | per_layer_image_list[i].append(rendered_depth_mask_grid_pil) 453 | 454 | # save individual files to disk 455 | with open( 456 | os.path.join( 457 | root_out, 458 | f"{prefix}rendered_depth_mask_grid_{step:04d}_layer_{i:04d}_step_{t:04d}.png", 459 | ), 460 | "wb", 461 | ) as f: 462 | rendered_depth_mask_grid_pil.save(f) 463 | 464 | # save video per layer 465 | for frame_idx, image_list in per_layer_image_list.items(): 466 | video_out = os.path.join(root_out, f"diff_process_{step:04d}_{frame_idx}.mp4") 467 | 468 | # save all files locally 469 | for t, img in enumerate(image_list): 470 | file_out = os.path.join( 471 | root_out, f"{prefix}rendered_depth_mask_grid_{step:04d}_layer_{frame_idx:04d}_step_{t:04d}.png" 472 | ) 473 | if not os.path.exists(file_out): 474 | with open(file_out, "wb") as f: 475 | img.save(f) 476 | 477 | # create video locally 478 | video_args = Namespace( 479 | **{ 480 | "image_folder": root_out, 481 | "file_name_pattern_glob": f"{prefix}rendered_depth_mask_grid_{step:04d}_layer_{frame_idx:04d}_step_*.png", 482 | "output_path": video_out, 483 | "framerate": 5, 484 | } 485 | ) 486 | create_video_from_image_folder(video_args) 487 | 488 | 489 | def create_videos(io: IOConfig, prefix: str = None): 490 | if prefix is not None: 491 | prefix = f"{prefix}_" 492 | else: 493 | prefix = "" 494 | 495 | root_output_path = io.image_out_dir 496 | output_path = os.path.join(root_output_path, "image_grid.mp4") 497 | image_folder = io.image_out_dir 498 | if image_folder[-1] != "/": 499 | image_folder += "/" 500 | video_args = Namespace( 501 | **{ 502 | "image_folder": image_folder, 503 | "file_name_pattern_glob": f"{prefix}image_grid_*.png", 504 | "output_path": output_path, 505 | "framerate": 15, 506 | } 507 | ) 508 | create_video_from_image_folder(video_args) 509 | video_args.file_name_pattern_glob = f"{prefix}pred_*.png" 510 | video_args.output_path = os.path.join(root_output_path, "pred.mp4") 511 | create_video_from_image_folder(video_args) 512 | -------------------------------------------------------------------------------- /viewdiff/model/projection/voxel_proj.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 
2 | 3 | from typing import Literal, Tuple 4 | import torch 5 | from torch import nn 6 | import torch.nn.functional as F 7 | 8 | from ..custom_attention_processor import collapse_batch, expand_batch 9 | from .util import screen_to_ndc, project_batch 10 | from ...data.co3d.util import scale_bbox, scale_camera_center 11 | 12 | 13 | torch._C._jit_set_profiling_executor(False) 14 | torch._C._jit_set_profiling_mode(False) 15 | 16 | 17 | @torch.jit.script 18 | def fused_mean_variance(x, weight): 19 | mean = torch.sum(x * weight, dim=1, keepdim=True) 20 | var = torch.sum(weight * (x - mean) ** 2, dim=1, keepdim=True) 21 | return mean, var 22 | 23 | 24 | # adapted from: https://github.com/googleinterns/IBRNet/blob/master/ibrnet/mlp_network.py 25 | class IBRNet_Aggregator(nn.Module): 26 | def __init__(self, feature_dim: int = 32, anti_alias_pooling: bool = False, kernel_size: int = 1, padding: int = 0, use_temb: bool = False, temb_dim: int = 1280): 27 | super().__init__() 28 | 29 | self.anti_alias_pooling = anti_alias_pooling 30 | if self.anti_alias_pooling: 31 | self.s = nn.Parameter(torch.tensor(0.2), requires_grad=True) 32 | activation_func = nn.ELU(inplace=True) 33 | 34 | # turn voxel_depth, rays into encoding to add to features 35 | self.ray_depth_encoder = nn.Sequential( 36 | nn.Conv3d(4, feature_dim // 2, kernel_size=kernel_size, padding=padding), 37 | activation_func, 38 | nn.Conv3d(feature_dim // 2, feature_dim, kernel_size=kernel_size, padding=padding), 39 | activation_func, 40 | ) 41 | 42 | # turn time embedding into encoding to add to features 43 | self.use_temb = use_temb 44 | if use_temb: 45 | self.temb_encoder = nn.Sequential( 46 | nn.Linear(temb_dim, temb_dim // 2), 47 | activation_func, 48 | nn.Linear(temb_dim // 2, feature_dim), 49 | activation_func, 50 | ) 51 | 52 | # shared part of feature/weight encoding 53 | self.base_fc = nn.Sequential( 54 | nn.Conv3d(feature_dim * 3, feature_dim * 2, kernel_size=kernel_size, padding=padding), 55 | activation_func, 56 | nn.Conv3d(feature_dim * 2, feature_dim, kernel_size=kernel_size, padding=padding), 57 | activation_func, 58 | ) 59 | 60 | # compute first part of averaging weights, final features 61 | self.vis_fc = nn.Sequential( 62 | nn.Conv3d(feature_dim, feature_dim, kernel_size=kernel_size, padding=padding), 63 | activation_func, 64 | nn.Conv3d(feature_dim, feature_dim + 1, kernel_size=kernel_size, padding=padding), 65 | activation_func, 66 | ) 67 | 68 | # compute second part of averaging weights 69 | self.vis_fc2 = nn.Sequential( 70 | nn.Conv3d(feature_dim, feature_dim, kernel_size=kernel_size, padding=padding), 71 | activation_func, 72 | nn.Conv3d(feature_dim, 1, kernel_size=kernel_size, padding=padding), 73 | nn.Sigmoid(), 74 | ) 75 | 76 | # combine (weight, mean, var) into final grid 77 | self.statistics_out = nn.Sequential( 78 | nn.Conv3d(2 * feature_dim + 1, feature_dim, kernel_size=kernel_size, padding=padding), 79 | activation_func, 80 | ) 81 | 82 | def forward(self, features: torch.Tensor, mask: torch.Tensor, voxel_depth: torch.Tensor, voxel_dir: torch.Tensor, temb: torch.Tensor = None): 83 | """ 84 | 85 | Args: 86 | features (torch.Tensor): tensor of shape (batch_size, num_images, feature_dim, grid_dim, grid_dim, grid_dim) 87 | ray_diff (torch.Tensor): tensor of shape (batch_size, num_images, 3, grid_dim, grid_dim, grid_dim) 88 | mask (torch.Tensor): tensor of shape (batch_size, num_images, 1, grid_dim, grid_dim, grid_dim) 89 | temb (torch.Tensor) tensor of shape (batch_size, num_images, temb_dim) 90 | """ 91 | num_images = 
features.shape[1] 92 | 93 | # add ray encoding and depth 94 | ray_depth_enc = torch.cat([voxel_dir, voxel_depth], dim=2) 95 | ray_depth_enc = collapse_batch(ray_depth_enc) 96 | ray_depth_enc = self.ray_depth_encoder(ray_depth_enc) 97 | ray_depth_enc = expand_batch(ray_depth_enc, num_images) 98 | features = features + ray_depth_enc 99 | 100 | # add temb encoding 101 | if self.use_temb: 102 | temb_enc = self.temb_encoder(temb) 103 | temb_enc = temb_enc.unsqueeze(-1).unsqueeze(-1).unsqueeze(-1) 104 | temb_enc = temb_enc.repeat(1, 1, 1, *features.shape[3:]).contiguous() 105 | features = features + temb_enc 106 | 107 | if self.anti_alias_pooling: 108 | raise NotImplementedError() 109 | else: 110 | weight = mask / (torch.sum(mask, dim=1, keepdim=True) + 1e-8) 111 | 112 | # compute mean and variance across different views for each voxel (== same as aggregate_cost_volume(agg_fn="mean")) 113 | mean, var = fused_mean_variance( 114 | features, weight 115 | ) # (batch_size, 1, feature_dim, grid_dim, grid_dim, grid_dim) 116 | globalfeat = torch.cat([mean, var], dim=2) # (batch_size, 1, 2*feature_dim, grid_dim, grid_dim, grid_dim) 117 | 118 | # combine each voxel with the globalfeat across all views 119 | # (batch_size, num_images, 3*feature_dim, grid_dim, grid_dim, grid_dim) 120 | x = torch.cat([globalfeat.expand(-1, num_images, -1, -1, -1, -1), features], dim=2) 121 | 122 | # encode base_fc: shared part of feature/weight encoding 123 | x = collapse_batch(x) # (batch_size * num_images, 3*feature_dim, grid_dim, grid_dim, grid_dim) 124 | weight = collapse_batch(weight) 125 | mask = collapse_batch(mask) 126 | x = self.base_fc(x) 127 | 128 | # get averaging weights, final features 129 | x_vis = self.vis_fc(x * weight) 130 | x_res, vis = torch.split(x_vis, [x_vis.shape[1] - 1, 1], dim=1) 131 | vis = F.sigmoid(vis) * mask 132 | x = x + x_res 133 | vis = self.vis_fc2(x * vis) * mask 134 | 135 | # compute weighted average 136 | vis = expand_batch(vis, num_images) 137 | x = expand_batch(x, num_images) 138 | weight = vis / (torch.sum(vis, dim=1, keepdim=True) + 1e-8) 139 | mean, var = fused_mean_variance(x, weight) 140 | 141 | # combine (mean, var, weight) and let a final custom layer transform it into the feature grid 142 | # (batch_size, 2*feature_dim + 1, grid_dim, grid_dim, grid_dim) 143 | globalfeat = torch.cat([mean.squeeze(1), var.squeeze(1), weight.mean(dim=1)], dim=1) 144 | globalfeat = self.statistics_out(globalfeat) 145 | 146 | return globalfeat 147 | 148 | 149 | def get_grid_to_world_space_matrix(bbox: torch.Tensor) -> torch.Tensor: 150 | """Calculates the matrix that transforms grid-space coordinates to world-space coordinates. 151 | World-space is defined through the bbox. We calculate a matrix that translates and uniformly scales the grid-space to fit into the bbox. 152 | Grid-space is defined in [-1, 1] where the extrema refer to the corners of the grids, e.g. voxels refer to the centers shifted by half-pixel coordinate. 153 | 154 | Args: 155 | bbox (torch.Tensor): tensor of shape (B, 2, 3) giving the min-xyz and max-xyz bbox corners. 
156 | 157 | Returns: 158 | torch.Tensor: the grid_to_world_space_matrix of shape (4, 4) 159 | """ 160 | bbox_min = bbox[:, 0] 161 | bbox_max = bbox[:, 1] 162 | 163 | # scale from 2.0 to largest_side_length 164 | largest_side_length = (bbox_max - bbox_min).max(dim=1).values 165 | uniform_scale = largest_side_length / 2.0 166 | 167 | # translate from center = (0, 0) to bbox_center 168 | bbox_center = (bbox_min + bbox_max) / 2 169 | 170 | # build final matrix combining scale and translation 171 | grid2world = torch.zeros(bbox.shape[0], 4, 4, device=bbox.device) 172 | grid2world[:, 0, 0] = uniform_scale 173 | grid2world[:, 1, 1] = uniform_scale 174 | grid2world[:, 2, 2] = uniform_scale 175 | grid2world[:, 3, 3] = 1 176 | grid2world[:, :3, 3] = bbox_center 177 | 178 | return grid2world 179 | 180 | 181 | def _contract_pi_inv(x, perc_foreground: float = 0.5): 182 | max_index = torch.argmax(x.abs(), dim=1, keepdim=True) 183 | n = torch.gather(x, dim=1, index=max_index) 184 | p = 1.0 / perc_foreground 185 | n_inv = torch.where(n > 0, -(p - 1) / (n - p), -(p - 1) / (n + p)) 186 | x_inv = torch.where((x.abs() - n).abs() <= 1e-7, n_inv.repeat(1, x.shape[1]), n_inv.abs() * x) 187 | x_inv = torch.where(n.abs() <= 1.0, x, x_inv) 188 | return x_inv 189 | 190 | 191 | @torch.autocast("cuda", enabled=False) 192 | def build_cost_volume( 193 | features: torch.Tensor, 194 | poses_world2cam: torch.Tensor, 195 | intrinsics: torch.Tensor, 196 | bbox: torch.Tensor, 197 | grid_dim: int = 128, 198 | feature_sampling_mode: Literal["bilinear", "nearest", "bicubic"] = "bilinear", 199 | depth: torch.Tensor = None, 200 | depth_threshold: float = 1e-8, 201 | contract_background: bool = False, 202 | contract_background_percentage: float = 0.5, 203 | ) -> Tuple[torch.Tensor, torch.Tensor]: 204 | """Create a per-frame cost-volume of the features. 205 | 206 | Args: 207 | features (torch.Tensor): (batch_size, num_images, feature_dim, height, width) 208 | poses_world2cam (torch.Tensor): (batch_size, num_images, 4, 4) 209 | intrinsics (torch.Tensor): (batch_size, num_images, 4, 4) - already downsampled to respect (height, width). 210 | bbox (torch.Tensor): (batch_size, 2, 3) 211 | world_space_transform (torch.Tensor): (batch_size, 3, 3) 212 | grid_dim (int): voxel-grid dimension in xyz 213 | feature_sampling_mode (Literal["bilinear", "nearest", "bicubic"], optional): sampling mode for interpolation of features. Defaults to "bilinear". 214 | depth (torch.Tensor, optional): (batch_size, num_images, height, width). depth map to build up the cost-volume. Defaults to None. 215 | depth_threshold (float, optional): if depth is given, uses this threshold to filter out voxels that are not close enough to GT depth. Defaults to 1e-3. 216 | 217 | Returns: 218 | (torch.Tensor: cost-volume per frame, shape (batch_size, num_images, feature_dim, grid_dim, grid_dim, grid_dim), 219 | torch.Tensor: weights of cost-volume per frame, shape (batch_size, num_images, 1, grid_dim, grid_dim, grid_dim)) 220 | """ 221 | # get shape info 222 | batch_size = features.shape[0] 223 | num_images = features.shape[1] 224 | feature_dim = features.shape[2] 225 | height = features.shape[3] 226 | width = features.shape[4] 227 | 228 | # Generate voxel indices. 
--> xyz coordinates in [0...grid_dim - 1] 229 | x = torch.arange(grid_dim, dtype=poses_world2cam.dtype, device=poses_world2cam.device) 230 | y = torch.arange(grid_dim, dtype=poses_world2cam.dtype, device=poses_world2cam.device) 231 | z = torch.arange(grid_dim, dtype=poses_world2cam.dtype, device=poses_world2cam.device) 232 | 233 | grid_x, grid_y, grid_z = torch.meshgrid(x, y, z, indexing="xy") 234 | grid_xyz = torch.cat( 235 | [ 236 | grid_x.view(grid_dim, grid_dim, grid_dim, 1), 237 | grid_y.view(grid_dim, grid_dim, grid_dim, 1), 238 | grid_z.view(grid_dim, grid_dim, grid_dim, 1), 239 | ], 240 | dim=3, 241 | ) 242 | 243 | grid_xyz = grid_xyz.view(grid_dim * grid_dim * grid_dim, 3).contiguous() 244 | num_voxels = grid_xyz.shape[0] 245 | 246 | # convert grid coordinates to [-1, 1]. 247 | # convention: coordinates refer to voxel centers and the centers are at half-pixel coordinates 248 | grid_xyz = (grid_xyz + 0.5) / grid_dim * 2.0 - 1.0 249 | 250 | if contract_background: 251 | # what is the valid range of coordinates in foreground/background separation 252 | # e.g. if we have 50% foreground, 50% background, then it goes from [-2, 2] and the values in [-1, 1] are foreground 253 | # e.g. if we have 80% foreground, 20% background, then it goes from [-1.25, 1.25] and the values in [-1, 1] are foreground 254 | # we invert this step here 255 | grid_xyz = grid_xyz * (1.0 / contract_background_percentage) 256 | 257 | # see MERF (https://arxiv.org/pdf/2302.12249.pdf): we do not want to store features at the ill-defined regions of the contraction function 258 | invalid_contract_voxels_mask = (grid_xyz > 1.0).sum(dim=-1) > 1 259 | invalid_contract_voxels_mask = invalid_contract_voxels_mask.view(grid_dim, grid_dim, grid_dim).contiguous() 260 | 261 | # invert the contraction --> store image features for voxels at their true world-space coordinates 262 | grid_xyz = _contract_pi_inv(grid_xyz, contract_background_percentage) 263 | 264 | # get grid2world matrices 265 | grid2world = get_grid_to_world_space_matrix(bbox).to(poses_world2cam.device) 266 | 267 | # convert grid-space points to world-space points 268 | world_xyz = grid_xyz[None].repeat(batch_size, 1, 1).transpose(1, 2) # (batch_size, 3, num_voxels) 269 | world_xyz = grid2world[:, :3, :3].bmm(world_xyz) + grid2world[:, :3, 3:4] 270 | 271 | # We process all samples simultaneously 272 | intrinsics = collapse_batch(intrinsics) 273 | poses_world2cam = collapse_batch(poses_world2cam) 274 | features = collapse_batch(features) 275 | grid2world = grid2world.repeat_interleave(num_images, dim=0) 276 | grid_xyz = ( 277 | grid_xyz[None].repeat(batch_size * num_images, 1, 1).transpose(1, 2) 278 | ) # (batch_size * num_images, 3, num_voxels) 279 | 280 | # Project voxels to screen/ndc space. 281 | grid2cam = poses_world2cam.bmm(grid2world) 282 | sampler = project_batch(grid_xyz, intrinsics, grid2cam) 283 | sampler = sampler.permute(0, 2, 1) 284 | sampler = screen_to_ndc(sampler, height, width) 285 | 286 | # Mark valid pixels. 287 | valid_pixels = ( 288 | (sampler[..., 0:1] >= -1) 289 | & (sampler[..., 0:1] <= 1) 290 | & (sampler[..., 1:2] >= -1) 291 | & (sampler[..., 1:2] <= 1) 292 | & (sampler[..., 2:3] > 0) 293 | ) 294 | valid_pixels = valid_pixels.repeat(1, 1, 3) 295 | sampler[~valid_pixels] = -10 296 | sampler[~valid_pixels] = -10 297 | 298 | # Interpolate features.
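# The first two channels of sampler are the NDC sampling coordinates in [-1, 1] and sampler[..., 2]
# is the projected depth; voxels marked invalid were pushed to -10 so that grid_sample with
# padding_mode="zeros" returns zeros for them, while queried_weights is derived from valid_pixels below.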
299 | def make_query_pixels(): 300 | query_pixels = torch.stack([sampler[..., 0], sampler[..., 1]], dim=-1) 301 | query_pixels = query_pixels.view(batch_size * num_images, 1, num_voxels, 2).contiguous() 302 | return query_pixels 303 | 304 | # Mark valid pixels based on GT depth 305 | if depth is not None: 306 | depth = collapse_batch(depth) 307 | queried_depth = torch.nn.functional.grid_sample( 308 | depth.view(batch_size * num_images, 1, height, width), 309 | make_query_pixels(), 310 | mode="nearest", 311 | padding_mode="zeros", 312 | align_corners=False, 313 | ).squeeze() 314 | diff = (queried_depth - sampler[..., 2]) ** 2 315 | invalid_depth_pixels = (diff < depth_threshold)[..., None].repeat(1, 1, 3) 316 | valid_pixels &= invalid_depth_pixels 317 | sampler[~valid_pixels] = -10 318 | sampler[~valid_pixels] = -10 319 | 320 | # Sample features at computed pixel locations. 321 | query_pixels = make_query_pixels() 322 | queried_features = torch.nn.functional.grid_sample( 323 | features, 324 | query_pixels, 325 | mode=feature_sampling_mode, 326 | padding_mode="zeros", 327 | align_corners=False, 328 | ) 329 | queried_features = queried_features.view(batch_size * num_images, feature_dim, num_voxels).contiguous() 330 | 331 | # Set invalid values. 332 | valid_pixels = valid_pixels[..., 0] 333 | queried_weights = valid_pixels.float() 334 | queried_features = queried_features.permute(1, 0, 2) 335 | queried_features[:, ~valid_pixels] = 0 336 | 337 | # Unflatten to xyz grid_dim shape. 338 | queried_weights = queried_weights.view(batch_size * num_images, grid_dim, grid_dim, grid_dim).contiguous() 339 | queried_weights = queried_weights.unsqueeze(1) 340 | 341 | voxel_depth = sampler[..., 2].view(batch_size * num_images, grid_dim, grid_dim, grid_dim).contiguous() 342 | voxel_depth = voxel_depth.unsqueeze(1) 343 | 344 | queried_features = queried_features.view(feature_dim, batch_size * num_images, grid_dim, grid_dim, grid_dim).contiguous() 345 | queried_features = queried_features.transpose(0, 1) 346 | 347 | # Expand to batch_size, num_images. 348 | queried_weights = expand_batch(queried_weights, num_images) 349 | voxel_depth = expand_batch(voxel_depth, num_images) 350 | queried_features = expand_batch(queried_features, num_images) 351 | 352 | # mask out queried_features for those voxels that are invalid in the contraction 353 | if contract_background: 354 | queried_features[:, :, :, invalid_contract_voxels_mask] = 0 355 | 356 | return queried_features, queried_weights, world_xyz, voxel_depth 357 | 358 | 359 | def sparsify_cost_volume(features: torch.Tensor, weights: torch.Tensor, points: torch.Tensor): 360 | points = points.reshape( 361 | points.shape[0], points.shape[1], *features.shape[2:] 362 | ) # (batch_size, 3, grid_dim, grid_dim, grid_dim) 363 | 364 | feature_list = [] 365 | points_list = [] 366 | remaining_points = 0 367 | for i in range(features.shape[0]): 368 | # sparsify features 369 | m = weights[i].bool().squeeze() # (grid_dim, grid_dim, grid_dim) 370 | feature_list.append(features[i, :, m]) # (C, P) 371 | 372 | # sparsify world_points 373 | sparse_points = points[i, :, m] 374 | points_list.append(sparse_points) # (3, P) 375 | remaining_points += sparse_points.shape[1] 376 | 377 | return feature_list, points_list, remaining_points 378 | 379 | 380 | def mean_aggregate_cost_volumes( 381 | features: torch.Tensor, 382 | weights: torch.Tensor, 383 | ) -> torch.Tensor: 384 | """Aggregates multiple per-frame feature_grids into a single feature_grid using the specified aggregation function. 
385 | 386 | Args: 387 | features (torch.Tensor): tensor of shape (batch_size, num_images, feature_dim, grid_dim, grid_dim, grid_dim) 388 | weights (torch.Tensor): tensor of shape (batch_size, num_images, 1, grid_dim, grid_dim, grid_dim) 389 | 390 | Returns: 391 | torch.Tensor: the aggregated feature_grid of shape (batch_size, feature_dim, grid_dim, grid_dim, grid_dim) 392 | """ 393 | features = features.sum(dim=1) # (batch_size, C, grid_dim, grid_dim, grid_dim) 394 | weights = weights.sum(dim=1) # (batch_size, 1, grid_dim, grid_dim, grid_dim) 395 | features = features / (weights + 1e-8) 396 | return features, weights 397 | 398 | 399 | def get_rays_for_view(world2cam: torch.Tensor, K: torch.Tensor, canonical_rays: torch.Tensor): 400 | rays = canonical_rays[None].repeat(world2cam.shape[0], 1, 1) # (batch_size, 3, h*w) 401 | 402 | R_inv = world2cam[..., :3, :3].transpose(1, 2) # (batch_size, 3, 3) 403 | t_inv = -R_inv @ world2cam[..., :3, 3:4] # (batch_size, 3, 1) 404 | 405 | rays = K.inverse().bmm(rays) 406 | rays = R_inv.bmm(rays) + t_inv 407 | centers = t_inv.expand(rays.shape) # (batch_size, 3, h*w) 408 | rays = rays - centers 409 | 410 | rays = rays.permute(0, 2, 1) # (batch_size, h*w, 3) 411 | centers = centers.permute(0, 2, 1) # (batch_size, h*w, 3) 412 | 413 | return rays, centers 414 | 415 | 416 | def ray_aabb_intersection( 417 | ray_d: torch.Tensor, ray_o: torch.Tensor, bbox: torch.Tensor 418 | ) -> Tuple[torch.Tensor, torch.Tensor]: 419 | # ray_d, ray_o: (batch_size, h*w, 3) 420 | # bbox: (batch_size, 2, 3) 421 | # t_min, t_max: (batch_size, h*w,) 422 | vec = torch.where(ray_d == 0, torch.full_like(ray_d, 1e-6), ray_d) 423 | bbox_min = bbox[:, 0][:, None] 424 | bbox_max = bbox[:, 1][:, None] 425 | rate_a = (bbox_min - ray_o) / vec 426 | rate_b = (bbox_max - ray_o) / vec 427 | t_min = torch.minimum(rate_a, rate_b).amax(-1).clamp(min=-1e5, max=1e5) 428 | t_max = torch.maximum(rate_a, rate_b).amin(-1).clamp(min=-1e5, max=1e5) 429 | 430 | return t_min, t_max 431 | 432 | 433 | @torch.autocast("cuda", enabled=False) 434 | def get_rays_in_unit_cube( 435 | bbox: torch.Tensor, 436 | pose: torch.Tensor, 437 | K: torch.Tensor, 438 | canonical_rays: torch.Tensor, 439 | use_ray_aabb: bool = True, 440 | default_near: float = 0.5, 441 | default_far: float = 5.0, 442 | ): 443 | """Always do it in high-precision, e.g. disable fp16 here. 444 | 445 | Args: 446 | bbox (torch.Tensor): _description_ 447 | pose (torch.Tensor): _description_ 448 | K (torch.Tensor): _description_ 449 | canonical_rays (torch.Tensor): _description_ 450 | default_near (float, optional): _description_. Defaults to 0.5. 451 | default_far (float, optional): _description_. Defaults to 5.0. 452 | """ 453 | num_images = pose.shape[1] 454 | 455 | # our volume-renderer assumes that the world-space is [-1, 1]^3 and voxel-centers are at half-grid coordinates 456 | # our build_cost_volume() returns voxels within the bbox (sampled such that bbox-extrema are the boundaries of the voxels, e.g. voxel-centers are at half-grid coordinates) 457 | # we need to scale the bbox and poses to lie in (-1, 1). we can then re-use the voxel-features as they are! 
458 | bbox, scale = scale_bbox(bbox) 459 | scale_pose = scale[:, None].repeat(1, num_images).view(-1).contiguous() 460 | pose = scale_camera_center(collapse_batch(pose), scale_pose) 461 | bbox = bbox.unsqueeze(1).repeat(1, num_images, 1, 1) 462 | bbox = collapse_batch(bbox) 463 | 464 | # get rays 465 | rays, centers = get_rays_for_view(pose, collapse_batch(K), canonical_rays) 466 | 467 | # get near/far sampling points along the ray 468 | if use_ray_aabb: 469 | near_t, far_t = ray_aabb_intersection(rays, centers, bbox) 470 | mask_out_of_bbox = far_t <= near_t 471 | near_t[mask_out_of_bbox] = default_near 472 | far_t[mask_out_of_bbox] = default_far 473 | else: 474 | R = pose[:, :3, :3] 475 | T = pose[:, :3, 3:4] 476 | camera_center = -R.transpose(-2, -1).bmm(T) 477 | dist_camera_center_origin = torch.linalg.vector_norm(camera_center, dim=1) 478 | near_t = (dist_camera_center_origin - 4.0).clamp_min(0.0).repeat(1, rays.shape[1]) 479 | far_t = (dist_camera_center_origin + 4.0).repeat(1, rays.shape[1]) 480 | 481 | rays = expand_batch(rays, num_images) 482 | centers = expand_batch(centers, num_images) 483 | near_t = expand_batch(near_t, num_images) 484 | far_t = expand_batch(far_t, num_images) 485 | 486 | return rays, centers, near_t, far_t, scale 487 | -------------------------------------------------------------------------------- /viewdiff/model/custom_attention_processor.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # This file is partially based on the diffusers library, which licensed the code under the following license: 3 | 4 | # Copyright 2023 The HuggingFace Team. All rights reserved. 5 | # 6 | # Licensed under the Apache License, Version 2.0 (the "License"); 7 | # you may not use this file except in compliance with the License. 8 | # You may obtain a copy of the License at 9 | # 10 | # http://www.apache.org/licenses/LICENSE-2.0 11 | # 12 | # Unless required by applicable law or agreed to in writing, software 13 | # distributed under the License is distributed on an "AS IS" BASIS, 14 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 15 | # See the License for the specific language governing permissions and 16 | # limitations under the License. 17 | import random 18 | 19 | import torch 20 | import torch.nn.functional as F 21 | 22 | from diffusers.models.attention import Attention 23 | from diffusers.models.attention_processor import LoRALinearLayer 24 | 25 | 26 | def expand_batch(x: torch.Tensor, frames_per_batch: int) -> torch.Tensor: 27 | n = x.shape[0] 28 | other_dims = x.shape[1:] 29 | return x.reshape(n // frames_per_batch, frames_per_batch, *other_dims) 30 | 31 | 32 | def collapse_batch(x: torch.Tensor) -> torch.Tensor: 33 | n, k = x.shape[:2] 34 | other_dims = x.shape[2:] 35 | return x.reshape(n * k, *other_dims) 36 | 37 | 38 | class CrossFrameAttentionProcessor2_0(torch.nn.Module): 39 | """ 40 | Processor for implementing scaled dot-product attention between multiple images within each batch (enabled by default if you're using PyTorch 2.0). 
41 | """ 42 | 43 | def __init__( 44 | self, 45 | n_input_images: int = 5, 46 | to_k_other_frames: int = 0, 47 | with_self_attention: bool = True, 48 | random_others: bool = False, 49 | use_lora_in_cfa: bool = False, 50 | use_temb_in_lora: bool = False, 51 | temb_size: int = 1280, 52 | temb_out_size: int = 8, 53 | hidden_size: int = 320, 54 | pose_cond_dim=8, 55 | rank=4, 56 | network_alpha=None, 57 | ): 58 | """ 59 | Args: 60 | n_input_images (int, optional): How many images are in one batch. Defaults to 5. 61 | to_k_other_frames (int, optional): How many of the other images in a batch to use as key/value. Defaults to 0. 62 | with_self_attention (bool, optional): If the key/value of the query image should be appended. Defaults to True. 63 | random_others (bool, optional): If True, will select the k_other_frames randomly, otherwise sequentially. Defaults to False. 64 | """ 65 | super().__init__() 66 | if not hasattr(F, "scaled_dot_product_attention"): 67 | raise ImportError( 68 | "CrossFrameAttentionProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0." 69 | ) 70 | 71 | if not 0 <= to_k_other_frames + with_self_attention <= n_input_images: 72 | raise ValueError( 73 | f"Need 0 <= to_k_other_frames + with_self_attention <= n_input_images, but got: to_k_other_frames={to_k_other_frames}, with_self_attention={with_self_attention}, n_input_images={n_input_images}" 74 | ) 75 | 76 | self.set_config(n_input_images, to_k_other_frames, with_self_attention, random_others) 77 | 78 | self.set_save_attention_matrix(False) 79 | 80 | # init lora-specific layers 81 | self.use_lora_in_cfa = use_lora_in_cfa 82 | self.use_temb_in_lora = use_temb_in_lora 83 | self.hidden_size = hidden_size 84 | self.pose_cond_dim = pose_cond_dim 85 | self.rank = rank 86 | self.network_alpha = network_alpha 87 | 88 | if use_lora_in_cfa: 89 | lora_in_size = hidden_size + pose_cond_dim 90 | if use_temb_in_lora: 91 | self.temb_proj = torch.nn.Sequential( 92 | torch.nn.Linear(temb_size, temb_size // 2), 93 | torch.nn.ELU(inplace=True), 94 | torch.nn.Linear(temb_size // 2, temb_out_size), 95 | torch.nn.ELU(inplace=True), 96 | ) 97 | lora_in_size += temb_out_size 98 | 99 | self.to_q_lora = LoRALinearLayer(lora_in_size, hidden_size, rank, network_alpha) 100 | self.to_k_lora = LoRALinearLayer(lora_in_size, hidden_size, rank, network_alpha) 101 | self.to_v_lora = LoRALinearLayer(lora_in_size, hidden_size, rank, network_alpha) 102 | self.to_out_lora = LoRALinearLayer(hidden_size, hidden_size, rank, network_alpha) 103 | 104 | def set_config(self, 105 | n_input_images: int, 106 | to_k_other_frames: int = 0, 107 | with_self_attention: bool = True, 108 | random_others: bool = False, 109 | ): 110 | self.n = n_input_images 111 | 112 | # [[1, 2, 3, 4], [0, 2, 3, 4], [0, 1, 3, 4], [0, 1, 2, 4], [0, 1, 2, 3]] 113 | self.ids = [[k for k in range(self.n) if k != i] for i in range(self.n)] 114 | for i in range(self.n): 115 | if random_others: 116 | random.shuffle(self.ids[i]) 117 | self.ids[i] = self.ids[i][:to_k_other_frames] 118 | self.ids = torch.tensor(self.ids, dtype=torch.long) 119 | 120 | if with_self_attention: 121 | # [0, 1, 2, 3, 4] 122 | self_attention_ids = torch.tensor([i for i in range(self.n)], dtype=torch.long) 123 | 124 | # [[0, 1, 2, 3, 4], [1, 0, 2, 3, 4], [2, 0, 1, 3, 4], [3, 0, 1, 2, 4], [4, 0, 1, 2, 3]] 125 | self.ids = torch.cat([self_attention_ids[..., None], self.ids], dim=-1) 126 | 127 | self.k = self.ids.shape[1] # how many of the frames in a batch to attend to. 
1 < k <= n 128 | 129 | def set_save_attention_matrix(self, save: bool, on_cpu: bool = False, only_uncond: bool = True): 130 | self.do_save_attention_matrix = save 131 | self.save_attention_matrix_on_cpu = on_cpu 132 | self.save_only_uncond_batch = only_uncond 133 | 134 | @torch.no_grad() 135 | def save_attention_matrix( 136 | self, attn: Attention, query: torch.Tensor, key: torch.Tensor, attention_mask: torch.Tensor 137 | ): 138 | if self.do_save_attention_matrix: 139 | if self.save_only_uncond_batch: 140 | batches_of_frames = query.shape[0] // self.n 141 | assert ( 142 | batches_of_frames == 2 143 | ), "only support save_attention_matrix for batch_size=1 with classifier-free-guidance" 144 | query = query.reshape(batches_of_frames, self.n, *query.shape[1:]) 145 | key = key.reshape(batches_of_frames, self.n, *key.shape[1:]) 146 | # Filter out unconditional TODO FIXME or should we use [0]? 147 | query = query[1] 148 | key = key[1] 149 | 150 | N, heads = query.shape[:2] 151 | query = query.reshape(N * heads, *query.shape[2:]) 152 | key = key.reshape(N * heads, *key.shape[2:]) 153 | 154 | if self.save_attention_matrix_on_cpu: 155 | query = query.cpu() 156 | key = key.cpu() 157 | attention_mask = attention_mask.cpu() if attention_mask is not None else None 158 | 159 | self.attention_probs = attn.get_attention_scores(query, key, attention_mask) 160 | self.attention_probs = self.attention_probs.reshape( 161 | N, heads, *self.attention_probs.shape[1:] 162 | ) # (N, attn.heads, query_dim, key_dim) 163 | 164 | def __call__( 165 | self, 166 | attn: Attention, 167 | hidden_states, 168 | encoder_hidden_states=None, 169 | attention_mask=None, 170 | temb=None, 171 | scale=1.0, 172 | pose_cond=None, 173 | ): 174 | residual = hidden_states 175 | 176 | if attn.spatial_norm is not None: 177 | hidden_states = attn.spatial_norm(hidden_states, temb) 178 | 179 | input_ndim = hidden_states.ndim 180 | 181 | if input_ndim == 4: 182 | batch_size, channel, height, width = hidden_states.shape 183 | hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2) 184 | 185 | batch_size, sequence_length, _ = ( 186 | hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape 187 | ) 188 | 189 | if attention_mask is not None: 190 | attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) 191 | # scaled_dot_product_attention expects attention_mask shape to be 192 | # (batch, heads, source_length, target_length) 193 | attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1]) 194 | 195 | if attn.group_norm is not None: 196 | hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) 197 | 198 | # prepare encoder_hidden_states 199 | is_self_attention = encoder_hidden_states is None 200 | use_lora_in_cfa = self.use_lora_in_cfa and pose_cond is not None and is_self_attention 201 | if is_self_attention: 202 | encoder_hidden_states = hidden_states 203 | elif attn.norm_cross: 204 | encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states) 205 | 206 | if use_lora_in_cfa: 207 | # prepare pose_cond for lora 208 | pose_cond = ( 209 | pose_cond[:, None, :] 210 | .repeat(1, hidden_states.shape[1], 1) 211 | .to(device=hidden_states.device, dtype=hidden_states.dtype) 212 | ) 213 | lora_cond = [hidden_states, pose_cond] 214 | 215 | # prepare temb for lora 216 | if self.use_temb_in_lora: 217 | temb = self.temb_proj(temb) 218 | temb = ( 219 | temb[:, None, :] 220 | .repeat(1, 
hidden_states.shape[1], 1) 221 | .to(device=hidden_states.device, dtype=hidden_states.dtype) 222 | ) 223 | lora_cond.append(temb) 224 | 225 | # construct final lora_cond tensor 226 | lora_cond = torch.cat(lora_cond, dim=-1) 227 | 228 | # encode with lora -- encoder_hidden_states is the same as hidden_states 229 | query = attn.to_q(hidden_states) + scale * self.to_q_lora(lora_cond) 230 | key = attn.to_k(encoder_hidden_states) + scale * self.to_k_lora(lora_cond) 231 | value = attn.to_v(encoder_hidden_states) + scale * self.to_v_lora(lora_cond) 232 | else: 233 | # encode without lora 234 | query = attn.to_q(hidden_states) 235 | key = attn.to_k(encoder_hidden_states) 236 | value = attn.to_v(encoder_hidden_states) 237 | 238 | # we want to change key/value only in case of self_attention. cross_attention refers to text-conditioning which should remain unchanged. 239 | if is_self_attention: 240 | # update key/value to contain the values of other frames within the same batch 241 | batches_of_frames = key.size()[0] // self.n 242 | 243 | # rearrange keys to have batch and frames in the 1st and 2nd dims respectively 244 | key = expand_batch(key, self.n) 245 | key = key[:, self.ids] # (batches_of_frames, self.n, self.k, sequence_length, inner_dim) 246 | key = key.view( 247 | batches_of_frames, self.n, -1, key.shape[-1] 248 | ).contiguous() # (batches_of_frames, self.n, self.k * sequence_length, inner_dim) 249 | 250 | # rearrange values to have batch and frames in the 1st and 2nd dims respectively 251 | value = expand_batch(value, self.n) 252 | value = value[:, self.ids] 253 | value = value.view(batches_of_frames, self.n, -1, value.shape[-1]).contiguous() 254 | 255 | # rearrange back to original shape 256 | key = collapse_batch(key) 257 | value = collapse_batch(value) 258 | 259 | inner_dim = key.shape[-1] 260 | head_dim = inner_dim // attn.heads 261 | query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2).contiguous() 262 | key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2).contiguous() 263 | value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2).contiguous() 264 | 265 | # the output of sdp = (batch, num_heads, seq_len, head_dim) 266 | # TODO: add support for attn.scale when we move to Torch 2.1 267 | self.save_attention_matrix(attn, query, key, attention_mask) 268 | hidden_states = F.scaled_dot_product_attention( 269 | query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False 270 | ) 271 | 272 | hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) 273 | hidden_states = hidden_states.to(query.dtype) 274 | 275 | # linear proj 276 | if use_lora_in_cfa: 277 | hidden_states = attn.to_out[0](hidden_states) + scale * self.to_out_lora(hidden_states) 278 | else: 279 | hidden_states = attn.to_out[0](hidden_states) 280 | # dropout 281 | hidden_states = attn.to_out[1](hidden_states) 282 | 283 | if input_ndim == 4: 284 | hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width) 285 | 286 | if attn.residual_connection: 287 | hidden_states = hidden_states + residual 288 | 289 | hidden_states = hidden_states / attn.rescale_output_factor 290 | 291 | return hidden_states 292 | 293 | 294 | class PoseCondLoRAAttnProcessor2_0(torch.nn.Module): 295 | r""" 296 | Processor for implementing the pose-conditioned LoRA attention mechanism using PyTorch 2.0's memory-efficient scaled dot-product 297 | attention. 
298 | Can be used for self-attention and cross-attention as an alternative to Zero-123 style of pose conditioning (which requires finetuning the whole model). 299 | 300 | Args: 301 | hidden_size (`int`): 302 | The hidden size of the attention layer. 303 | cross_attention_dim (`int`, *optional*): 304 | The number of channels in the `encoder_hidden_states`. 305 | pose_cond_dim (`int`, *optional*): 306 | The number of channels in the pose_conditioning. 307 | rank (`int`, defaults to 4): 308 | The dimension of the LoRA update matrices. 309 | network_alpha (`int`, *optional*): 310 | Equivalent to `alpha` but it's usage is specific to Kohya (A1111) style LoRAs. 311 | """ 312 | 313 | def __init__(self, hidden_size, cross_attention_dim=None, pose_cond_dim=8, rank=4, network_alpha=None): 314 | super().__init__() 315 | if not hasattr(F, "scaled_dot_product_attention"): 316 | raise ImportError( 317 | "PoseCondLoRAAttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0." 318 | ) 319 | 320 | self.hidden_size = hidden_size 321 | self.cross_attention_dim = cross_attention_dim 322 | self.pose_cond_dim = pose_cond_dim 323 | self.rank = rank 324 | 325 | self.to_q_lora = LoRALinearLayer(hidden_size + pose_cond_dim, hidden_size, rank, network_alpha) 326 | self.to_k_lora = LoRALinearLayer( 327 | (cross_attention_dim or hidden_size) + pose_cond_dim, hidden_size, rank, network_alpha 328 | ) 329 | self.to_v_lora = LoRALinearLayer( 330 | (cross_attention_dim or hidden_size) + pose_cond_dim, hidden_size, rank, network_alpha 331 | ) 332 | self.to_out_lora = LoRALinearLayer(hidden_size, hidden_size, rank, network_alpha) 333 | 334 | def __call__( 335 | self, attn: Attention, hidden_states, encoder_hidden_states=None, attention_mask=None, temb=None, scale=1.0, pose_cond=None 336 | ): 337 | if pose_cond is None: 338 | raise ValueError("pose_cond cannot be None") 339 | 340 | residual = hidden_states 341 | 342 | if attn.spatial_norm is not None: 343 | hidden_states = attn.spatial_norm(hidden_states, temb) 344 | 345 | input_ndim = hidden_states.ndim 346 | 347 | if input_ndim == 4: 348 | batch_size, channel, height, width = hidden_states.shape 349 | hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2) 350 | 351 | batch_size, sequence_length, _ = ( 352 | hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape 353 | ) 354 | 355 | if attention_mask is not None: 356 | attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) 357 | # scaled_dot_product_attention expects attention_mask shape to be 358 | # (batch, heads, source_length, target_length) 359 | attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1]) 360 | 361 | if attn.group_norm is not None: 362 | hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) 363 | 364 | # prepare pose_cond for query 365 | q_pose_cond = ( 366 | pose_cond[:, None, :] 367 | .repeat(1, hidden_states.shape[1], 1) 368 | .to(device=hidden_states.device, dtype=hidden_states.dtype) 369 | ) 370 | q_lora_cond = torch.cat([hidden_states, q_pose_cond], dim=-1) 371 | 372 | query = attn.to_q(hidden_states) + scale * self.to_q_lora(q_lora_cond) 373 | 374 | if encoder_hidden_states is None: 375 | encoder_hidden_states = hidden_states 376 | elif attn.norm_cross: 377 | encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states) 378 | 379 | # prepare pose_cond for key/value 380 | kv_pose_cond = ( 381 | pose_cond[:, 
None, :] 382 | .repeat(1, encoder_hidden_states.shape[1], 1) 383 | .to(device=encoder_hidden_states.device, dtype=encoder_hidden_states.dtype) 384 | ) 385 | kv_lora_cond = torch.cat([encoder_hidden_states, kv_pose_cond], dim=-1) 386 | 387 | key = attn.to_k(encoder_hidden_states) + scale * self.to_k_lora(kv_lora_cond) 388 | value = attn.to_v(encoder_hidden_states) + scale * self.to_v_lora(kv_lora_cond) 389 | 390 | inner_dim = key.shape[-1] 391 | head_dim = inner_dim // attn.heads 392 | query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2).contiguous() 393 | key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2).contiguous() 394 | value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2).contiguous() 395 | 396 | # TODO: add support for attn.scale when we move to Torch 2.1 397 | hidden_states = F.scaled_dot_product_attention( 398 | query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False 399 | ) 400 | hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) 401 | hidden_states = hidden_states.to(query.dtype) 402 | 403 | # linear proj 404 | hidden_states = attn.to_out[0](hidden_states) + scale * self.to_out_lora(hidden_states) 405 | # dropout 406 | hidden_states = attn.to_out[1](hidden_states) 407 | 408 | if input_ndim == 4: 409 | hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width) 410 | 411 | if attn.residual_connection: 412 | hidden_states = hidden_states + residual 413 | 414 | hidden_states = hidden_states / attn.rescale_output_factor 415 | 416 | return hidden_states 417 | 418 | 419 | class CustomAttnProcessor2_0: 420 | r""" 421 | Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0). 
422 | """ 423 | 424 | def __init__(self): 425 | if not hasattr(F, "scaled_dot_product_attention"): 426 | raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.") 427 | 428 | def __call__( 429 | self, attn: Attention, hidden_states, encoder_hidden_states=None, attention_mask=None, temb=None, **kwargs 430 | ): 431 | residual = hidden_states 432 | 433 | if attn.spatial_norm is not None: 434 | hidden_states = attn.spatial_norm(hidden_states, temb) 435 | 436 | input_ndim = hidden_states.ndim 437 | 438 | if input_ndim == 4: 439 | batch_size, channel, height, width = hidden_states.shape 440 | hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2) 441 | 442 | batch_size, sequence_length, _ = ( 443 | hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape 444 | ) 445 | 446 | if attention_mask is not None: 447 | attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) 448 | # scaled_dot_product_attention expects attention_mask shape to be 449 | # (batch, heads, source_length, target_length) 450 | attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1]) 451 | 452 | if attn.group_norm is not None: 453 | hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) 454 | 455 | query = attn.to_q(hidden_states) 456 | 457 | if encoder_hidden_states is None: 458 | encoder_hidden_states = hidden_states 459 | elif attn.norm_cross: 460 | encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states) 461 | 462 | key = attn.to_k(encoder_hidden_states) 463 | value = attn.to_v(encoder_hidden_states) 464 | 465 | inner_dim = key.shape[-1] 466 | head_dim = inner_dim // attn.heads 467 | query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2).contiguous() 468 | key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2).contiguous() 469 | value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2).contiguous() 470 | 471 | # the output of sdp = (batch, num_heads, seq_len, head_dim) 472 | # TODO: add support for attn.scale when we move to Torch 2.1 473 | hidden_states = F.scaled_dot_product_attention( 474 | query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False 475 | ) 476 | 477 | hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) 478 | hidden_states = hidden_states.to(query.dtype) 479 | 480 | # linear proj 481 | hidden_states = attn.to_out[0](hidden_states) 482 | # dropout 483 | hidden_states = attn.to_out[1](hidden_states) 484 | 485 | if input_ndim == 4: 486 | hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width) 487 | 488 | if attn.residual_connection: 489 | hidden_states = hidden_states + residual 490 | 491 | hidden_states = hidden_states / attn.rescale_output_factor 492 | 493 | return hidden_states 494 | --------------------------------------------------------------------------------