├── .gitignore ├── README.md ├── app.py ├── assets ├── examples │ ├── blue_cat.png │ ├── bulldog.png │ ├── ceramic.png │ ├── chair_watermelon.png │ ├── cup_rgba.png │ ├── cute_horse.jpg │ ├── earphone.jpg │ ├── fox.jpg │ ├── fruit_elephant.jpg │ ├── hatsune_miku.png │ ├── ikun_rgba.png │ ├── mailbox.png │ ├── mario.png │ ├── mei_ling_panda.png │ ├── mushroom_teapot.jpg │ ├── pikachu.png │ ├── potplant_rgba.png │ ├── seed_frog.png │ ├── shuai_panda_notail.png │ └── yellow_duck.png ├── hdri │ ├── golden_bay_1k.hdr │ ├── gym_entrance_1k.hdr │ ├── metro_noord_1k.hdr │ └── spaichingen_hill_1k.hdr ├── representaion_evaluation.jpg ├── teaser.webp └── valid_fitting_2048.txt ├── configs ├── inference_dit.yml ├── inference_dit_text.yml ├── train_dit.yml ├── train_dit_text.yml ├── train_fitting.yml └── train_vae.yml ├── datasets ├── prim_volume.py ├── prim_volume_caption.py └── sample_glb.py ├── dva ├── __init__.py ├── attr_dict.py ├── geom.py ├── io.py ├── layers.py ├── losses.py ├── mvp │ ├── extensions │ │ ├── mvpraymarch │ │ │ ├── bvh.cu │ │ │ ├── cudadispatch.h │ │ │ ├── helper_math.h │ │ │ ├── makefile │ │ │ ├── mvpraymarch.cpp │ │ │ ├── mvpraymarch.py │ │ │ ├── mvpraymarch_kernel.cu │ │ │ ├── mvpraymarch_subset_kernel.h │ │ │ ├── primaccum.h │ │ │ ├── primsampler.h │ │ │ ├── primtransf.h │ │ │ ├── setup.py │ │ │ └── utils.h │ │ └── utils │ │ │ ├── helper_math.h │ │ │ ├── makefile │ │ │ ├── setup.py │ │ │ ├── utils.cpp │ │ │ ├── utils.py │ │ │ └── utils_kernel.cu │ └── models │ │ ├── bg │ │ ├── lap.py │ │ └── mlp.py │ │ ├── colorcals │ │ └── colorcal.py │ │ ├── decoders │ │ ├── mvp.py │ │ └── nv.py │ │ ├── encoders │ │ ├── geotex.py │ │ └── image.py │ │ ├── raymarchers │ │ ├── mvpraymarcher.py │ │ └── stepraymarcher.py │ │ ├── utils.py │ │ └── volumetric.py ├── ray_marcher.py ├── scheduler.py ├── utils.py ├── vgg.py └── visualize.py ├── inference.py ├── install.sh ├── models ├── __init__.py ├── attention.py ├── conditioner │ ├── dinov2 │ │ ├── __init__.py │ │ ├── hub │ │ │ ├── __init__.py │ │ │ ├── backbones.py │ │ │ ├── classifiers.py │ │ │ ├── depth │ │ │ │ ├── __init__.py │ │ │ │ ├── decode_heads.py │ │ │ │ ├── encoder_decoder.py │ │ │ │ └── ops.py │ │ │ ├── depthers.py │ │ │ └── utils.py │ │ ├── layers │ │ │ ├── __init__.py │ │ │ ├── attention.py │ │ │ ├── block.py │ │ │ ├── dino_head.py │ │ │ ├── drop_path.py │ │ │ ├── layer_scale.py │ │ │ ├── mlp.py │ │ │ ├── patch_embed.py │ │ │ └── swiglu_ffn.py │ │ └── models │ │ │ ├── __init__.py │ │ │ └── vision_transformer.py │ ├── image.py │ ├── image_dinov2.py │ └── text.py ├── diffusion │ ├── __init__.py │ ├── diffusion_utils.py │ ├── gaussian_diffusion.py │ ├── respace.py │ └── timestep_sampler.py ├── dit_crossattn.py ├── primsdf.py ├── utils.py └── vae3d_dib.py ├── requirements.txt ├── scripts ├── cache_conditioner.py └── cache_vae.py ├── simple-knn ├── ext.cpp ├── setup.py ├── simple_knn.cu ├── simple_knn.h ├── simple_knn │ └── .gitkeep ├── spatial.cu └── spatial.h ├── train_dit.py ├── train_fitting.py ├── train_vae.py └── utils ├── mesh.py ├── meshutils.py ├── op.py ├── typing.py └── uv_unwrap.py /.gitignore: -------------------------------------------------------------------------------- 1 | __pycache__ 2 | build 3 | *.so 4 | runs -------------------------------------------------------------------------------- /assets/examples/blue_cat.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/3DTopia/3DTopia-XL/af602b003e9d137e37e1d883c18b50ed65c06f26/assets/examples/blue_cat.png 
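Editor's note: every component in the YAML files under configs/ (dumped further below) carries a class_name key; dva/io.py, also included in this dump, resolves these with load_from_config, which pops class_name, imports the class, and forwards the remaining keys as constructor kwargs. A minimal sketch of that flow, assuming the configs are read with OmegaConf (the ${...} interpolation syntax suggests it); this is an editor's illustration, not the repository's actual entry point:

    import torch
    from omegaconf import OmegaConf
    from dva.io import load_from_config

    # Assumption: OmegaConf handles the YAML and its ${...} interpolations.
    cfg = OmegaConf.load("configs/inference_dit.yml")

    # Smallest self-contained example: instantiate the optimizer block
    # (class_name: torch.optim.AdamW, lr: 0.0001, weight_decay: 0).
    dummy_params = [torch.nn.Parameter(torch.zeros(1))]
    optimizer = load_from_config(cfg.optimizer, params=dummy_params)
    print(type(optimizer))  # AdamW configured from the YAML

The nested model, vae, conditioner and generator blocks follow the same pattern, which is why each of them carries its own class_name.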
-------------------------------------------------------------------------------- /assets/examples/bulldog.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/3DTopia/3DTopia-XL/af602b003e9d137e37e1d883c18b50ed65c06f26/assets/examples/bulldog.png -------------------------------------------------------------------------------- /assets/examples/ceramic.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/3DTopia/3DTopia-XL/af602b003e9d137e37e1d883c18b50ed65c06f26/assets/examples/ceramic.png -------------------------------------------------------------------------------- /assets/examples/chair_watermelon.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/3DTopia/3DTopia-XL/af602b003e9d137e37e1d883c18b50ed65c06f26/assets/examples/chair_watermelon.png -------------------------------------------------------------------------------- /assets/examples/cup_rgba.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/3DTopia/3DTopia-XL/af602b003e9d137e37e1d883c18b50ed65c06f26/assets/examples/cup_rgba.png -------------------------------------------------------------------------------- /assets/examples/cute_horse.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/3DTopia/3DTopia-XL/af602b003e9d137e37e1d883c18b50ed65c06f26/assets/examples/cute_horse.jpg -------------------------------------------------------------------------------- /assets/examples/earphone.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/3DTopia/3DTopia-XL/af602b003e9d137e37e1d883c18b50ed65c06f26/assets/examples/earphone.jpg -------------------------------------------------------------------------------- /assets/examples/fox.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/3DTopia/3DTopia-XL/af602b003e9d137e37e1d883c18b50ed65c06f26/assets/examples/fox.jpg -------------------------------------------------------------------------------- /assets/examples/fruit_elephant.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/3DTopia/3DTopia-XL/af602b003e9d137e37e1d883c18b50ed65c06f26/assets/examples/fruit_elephant.jpg -------------------------------------------------------------------------------- /assets/examples/hatsune_miku.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/3DTopia/3DTopia-XL/af602b003e9d137e37e1d883c18b50ed65c06f26/assets/examples/hatsune_miku.png -------------------------------------------------------------------------------- /assets/examples/ikun_rgba.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/3DTopia/3DTopia-XL/af602b003e9d137e37e1d883c18b50ed65c06f26/assets/examples/ikun_rgba.png -------------------------------------------------------------------------------- /assets/examples/mailbox.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/3DTopia/3DTopia-XL/af602b003e9d137e37e1d883c18b50ed65c06f26/assets/examples/mailbox.png 
-------------------------------------------------------------------------------- /assets/examples/mario.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/3DTopia/3DTopia-XL/af602b003e9d137e37e1d883c18b50ed65c06f26/assets/examples/mario.png -------------------------------------------------------------------------------- /assets/examples/mei_ling_panda.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/3DTopia/3DTopia-XL/af602b003e9d137e37e1d883c18b50ed65c06f26/assets/examples/mei_ling_panda.png -------------------------------------------------------------------------------- /assets/examples/mushroom_teapot.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/3DTopia/3DTopia-XL/af602b003e9d137e37e1d883c18b50ed65c06f26/assets/examples/mushroom_teapot.jpg -------------------------------------------------------------------------------- /assets/examples/pikachu.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/3DTopia/3DTopia-XL/af602b003e9d137e37e1d883c18b50ed65c06f26/assets/examples/pikachu.png -------------------------------------------------------------------------------- /assets/examples/potplant_rgba.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/3DTopia/3DTopia-XL/af602b003e9d137e37e1d883c18b50ed65c06f26/assets/examples/potplant_rgba.png -------------------------------------------------------------------------------- /assets/examples/seed_frog.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/3DTopia/3DTopia-XL/af602b003e9d137e37e1d883c18b50ed65c06f26/assets/examples/seed_frog.png -------------------------------------------------------------------------------- /assets/examples/shuai_panda_notail.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/3DTopia/3DTopia-XL/af602b003e9d137e37e1d883c18b50ed65c06f26/assets/examples/shuai_panda_notail.png -------------------------------------------------------------------------------- /assets/examples/yellow_duck.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/3DTopia/3DTopia-XL/af602b003e9d137e37e1d883c18b50ed65c06f26/assets/examples/yellow_duck.png -------------------------------------------------------------------------------- /assets/hdri/golden_bay_1k.hdr: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/3DTopia/3DTopia-XL/af602b003e9d137e37e1d883c18b50ed65c06f26/assets/hdri/golden_bay_1k.hdr -------------------------------------------------------------------------------- /assets/hdri/gym_entrance_1k.hdr: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/3DTopia/3DTopia-XL/af602b003e9d137e37e1d883c18b50ed65c06f26/assets/hdri/gym_entrance_1k.hdr -------------------------------------------------------------------------------- /assets/hdri/metro_noord_1k.hdr: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/3DTopia/3DTopia-XL/af602b003e9d137e37e1d883c18b50ed65c06f26/assets/hdri/metro_noord_1k.hdr 
-------------------------------------------------------------------------------- /assets/hdri/spaichingen_hill_1k.hdr: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/3DTopia/3DTopia-XL/af602b003e9d137e37e1d883c18b50ed65c06f26/assets/hdri/spaichingen_hill_1k.hdr -------------------------------------------------------------------------------- /assets/representaion_evaluation.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/3DTopia/3DTopia-XL/af602b003e9d137e37e1d883c18b50ed65c06f26/assets/representaion_evaluation.jpg -------------------------------------------------------------------------------- /assets/teaser.webp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/3DTopia/3DTopia-XL/af602b003e9d137e37e1d883c18b50ed65c06f26/assets/teaser.webp -------------------------------------------------------------------------------- /configs/inference_dit.yml: -------------------------------------------------------------------------------- 1 | debug: False 2 | root_data_dir: ./runs 3 | checkpoint_path: ./pretrained/model_sview_dit_fp16.pt 4 | global_seed: 42 5 | 6 | inference: 7 | input_dir: ./assets/examples 8 | ddim: 25 9 | cfg: 6 10 | seed: ${global_seed} 11 | precision: fp16 12 | export_glb: True 13 | fast_unwrap: False 14 | decimate: 100000 15 | mc_resolution: 256 16 | batch_size: 8192 17 | remesh: False 18 | 19 | image_height: 518 20 | image_width: 518 21 | 22 | model: 23 | class_name: models.primsdf.PrimSDF 24 | num_prims: 2048 25 | dim_feat: 6 26 | prim_shape: 8 27 | init_scale: 0.05 # useless if auto_scale_init == True 28 | sdf2alpha_var: 0.005 29 | auto_scale_init: True 30 | init_sampling: uniform 31 | vae: 32 | class_name: models.vae3d_dib.VAE 33 | in_channels: ${model.dim_feat} 34 | latent_channels: 1 35 | out_channels: ${model.vae.in_channels} 36 | down_channels: [32, 256] 37 | mid_attention: True 38 | up_channels: [256, 32] 39 | layers_per_block: 2 40 | gradient_checkpointing: False 41 | vae_checkpoint_path: ./pretrained/model_vae_fp16.pt 42 | conditioner: 43 | class_name: models.conditioner.image.ImageConditioner 44 | num_prims: ${model.num_prims} 45 | dim_feat: ${model.dim_feat} 46 | prim_shape: ${model.prim_shape} 47 | sample_view: False 48 | encoder_config: 49 | class_name: models.conditioner.image_dinov2.Dinov2Wrapper 50 | model_name: dinov2_vitb14_reg 51 | freeze: True 52 | generator: 53 | class_name: models.dit_crossattn.DiT 54 | seq_length: ${model.num_prims} 55 | in_channels: 68 # equals to model.vae.latent_channels * latent_dim^3 56 | condition_channels: 768 57 | hidden_size: 1152 58 | depth: 28 59 | num_heads: 16 60 | attn_proj_bias: True 61 | cond_drop_prob: 0.1 62 | gradient_checkpointing: False 63 | latent_nf: 1.0 64 | latent_mean: [ 0.0442, -0.0029, -0.0425, -0.0043, -0.4086, -0.2906, -0.7002, -0.0852, -0.4446, -0.6896, -0.7344, -0.3524, -0.5488, -0.4313, -1.1715, -0.0875, -0.6131, -0.3924, -0.7335, -0.3749, 0.4658, -0.0236, 0.8362, 0.3388, 0.0188, 0.5988, -0.1853, 1.1579, 0.6240, 0.0758, 0.9641, 0.6586, 0.6260, 0.2384, 0.7798, 0.8297, -0.6543, -0.4441, -1.3887, -0.0393, -0.9008, -0.8616, -1.7434, -0.1328, -0.8119, -0.8225, -1.8533, -0.0444, -1.0510, -0.5158, -1.1907, -0.5265, 0.2832, 0.6037, 0.5981, 0.5461, 0.4366, 0.4144, 0.7219, 0.5722, 0.5937, 0.5598, 0.9414, 0.7419, 0.2102, 0.3388, 0.4501, 0.5166] 65 | latent_std: [0.0219, 0.3707, 0.3911, 0.3610, 0.7549, 0.7909, 
0.9691, 0.9193, 0.8218, 0.9389, 1.1785, 1.0254, 0.6376, 0.6568, 0.7892, 0.8468, 0.8775, 0.7920, 0.9037, 0.9329, 0.9196, 1.1123, 1.3041, 1.0955, 1.2727, 1.6565, 1.8502, 1.7006, 0.8973, 1.0408, 1.2034, 1.2703, 1.0373, 1.0486, 1.0716, 0.9746, 0.7088, 0.8685, 1.0030, 0.9504, 1.0410, 1.3033, 1.5368, 1.4386, 0.6142, 0.6887, 0.9085, 0.9903, 1.0190, 0.9302, 1.0121, 0.9964, 1.1474, 1.2729, 1.4627, 1.1404, 1.3713, 1.6692, 1.8424, 1.5047, 1.1356, 1.2369, 1.3554, 1.1848, 1.1319, 1.0822, 1.1972, 0.9916] 66 | 67 | diffusion: 68 | timestep_respacing: 69 | noise_schedule: squaredcos_cap_v2 70 | diffusion_steps: 1000 71 | parameterization: v 72 | 73 | rm: 74 | volradius: 10000.0 75 | dt: 1.0 76 | 77 | optimizer: 78 | class_name: torch.optim.AdamW 79 | lr: 0.0001 80 | weight_decay: 0 81 | 82 | scheduler: 83 | class_name: dva.scheduler.CosineWarmupScheduler 84 | warmup_iters: 3000 85 | max_iters: 200000 86 | 87 | train: 88 | batch_size: 8 89 | n_workers: 4 90 | n_epochs: 1000 91 | log_every_n_steps: 50 92 | summary_every_n_steps: 10000 93 | ckpt_every_n_steps: 10000 94 | amp: False 95 | precision: tf32 96 | 97 | tag: 3dtopia-xl-sview 98 | output_dir: ${root_data_dir}/inference/${tag} 99 | -------------------------------------------------------------------------------- /configs/inference_dit_text.yml: -------------------------------------------------------------------------------- 1 | debug: False 2 | root_data_dir: ./runs 3 | checkpoint_path: ./pretrained/scaleup_text_ckpt_backup_fp16.pt 4 | global_seed: 42 5 | 6 | inference: 7 | input_dir: ./assets/examples 8 | ddim: 25 9 | cfg: 6 10 | seed: ${global_seed} 11 | precision: fp16 12 | export_glb: True 13 | fast_unwrap: False 14 | decimate: 100000 15 | mc_resolution: 256 16 | batch_size: 8192 17 | remesh: False 18 | 19 | image_height: 518 20 | image_width: 518 21 | 22 | model: 23 | class_name: models.primsdf.PrimSDF 24 | num_prims: 2048 25 | dim_feat: 6 26 | prim_shape: 8 27 | init_scale: 0.05 # useless if auto_scale_init == True 28 | sdf2alpha_var: 0.005 29 | auto_scale_init: True 30 | init_sampling: uniform 31 | vae: 32 | class_name: models.vae3d_dib.VAE 33 | in_channels: ${model.dim_feat} 34 | latent_channels: 1 35 | out_channels: ${model.vae.in_channels} 36 | down_channels: [32, 256] 37 | mid_attention: True 38 | up_channels: [256, 32] 39 | layers_per_block: 2 40 | gradient_checkpointing: False 41 | vae_checkpoint_path: ./pretrained/model_vae_fp16.pt 42 | conditioner: 43 | class_name: models.conditioner.text.TextConditioner 44 | encoder_config: 45 | class_name: models.conditioner.text.CLIPTextEncoder 46 | pretrained_path: ./pretrained/open_clip_pytorch_model.bin 47 | model_spec: ViT-L-14 48 | generator: 49 | class_name: models.dit_crossattn.DiT 50 | seq_length: ${model.num_prims} 51 | in_channels: 68 # equals to model.vae.latent_channels * latent_dim^3 52 | condition_channels: 768 53 | hidden_size: 1152 54 | depth: 28 55 | num_heads: 16 56 | attn_proj_bias: True 57 | cond_drop_prob: 0.1 58 | gradient_checkpointing: False 59 | latent_nf: 1.0 60 | latent_mean: [ 0.0442, -0.0029, -0.0425, -0.0043, -0.4086, -0.2906, -0.7002, -0.0852, -0.4446, -0.6896, -0.7344, -0.3524, -0.5488, -0.4313, -1.1715, -0.0875, -0.6131, -0.3924, -0.7335, -0.3749, 0.4658, -0.0236, 0.8362, 0.3388, 0.0188, 0.5988, -0.1853, 1.1579, 0.6240, 0.0758, 0.9641, 0.6586, 0.6260, 0.2384, 0.7798, 0.8297, -0.6543, -0.4441, -1.3887, -0.0393, -0.9008, -0.8616, -1.7434, -0.1328, -0.8119, -0.8225, -1.8533, -0.0444, -1.0510, -0.5158, -1.1907, -0.5265, 0.2832, 0.6037, 0.5981, 0.5461, 0.4366, 0.4144, 
0.7219, 0.5722, 0.5937, 0.5598, 0.9414, 0.7419, 0.2102, 0.3388, 0.4501, 0.5166] 61 | latent_std: [0.0219, 0.3707, 0.3911, 0.3610, 0.7549, 0.7909, 0.9691, 0.9193, 0.8218, 0.9389, 1.1785, 1.0254, 0.6376, 0.6568, 0.7892, 0.8468, 0.8775, 0.7920, 0.9037, 0.9329, 0.9196, 1.1123, 1.3041, 1.0955, 1.2727, 1.6565, 1.8502, 1.7006, 0.8973, 1.0408, 1.2034, 1.2703, 1.0373, 1.0486, 1.0716, 0.9746, 0.7088, 0.8685, 1.0030, 0.9504, 1.0410, 1.3033, 1.5368, 1.4386, 0.6142, 0.6887, 0.9085, 0.9903, 1.0190, 0.9302, 1.0121, 0.9964, 1.1474, 1.2729, 1.4627, 1.1404, 1.3713, 1.6692, 1.8424, 1.5047, 1.1356, 1.2369, 1.3554, 1.1848, 1.1319, 1.0822, 1.1972, 0.9916] 62 | 63 | diffusion: 64 | timestep_respacing: 65 | noise_schedule: squaredcos_cap_v2 66 | diffusion_steps: 1000 67 | parameterization: v 68 | 69 | rm: 70 | volradius: 10000.0 71 | dt: 1.0 72 | 73 | optimizer: 74 | class_name: torch.optim.AdamW 75 | lr: 0.0001 76 | weight_decay: 0 77 | 78 | scheduler: 79 | class_name: dva.scheduler.CosineWarmupScheduler 80 | warmup_iters: 3000 81 | max_iters: 200000 82 | 83 | train: 84 | batch_size: 8 85 | n_workers: 4 86 | n_epochs: 1000 87 | log_every_n_steps: 50 88 | summary_every_n_steps: 10000 89 | ckpt_every_n_steps: 10000 90 | amp: False 91 | precision: tf32 92 | 93 | tag: 3dtopia-xl-text 94 | output_dir: ${root_data_dir}/inference/${tag} 95 | -------------------------------------------------------------------------------- /configs/train_dit.yml: -------------------------------------------------------------------------------- 1 | debug: False 2 | root_data_dir: ./runs 3 | checkpoint_path: 4 | global_seed: 42 5 | 6 | image_height: 518 7 | image_width: 518 8 | 9 | model: 10 | class_name: models.primsdf.PrimSDF 11 | num_prims: 2048 12 | dim_feat: 6 13 | prim_shape: 8 14 | init_scale: 0.05 # useless if auto_scale_init == True 15 | sdf2alpha_var: 0.005 16 | auto_scale_init: True 17 | init_sampling: uniform 18 | vae: 19 | class_name: models.vae3d_dib.VAE 20 | in_channels: ${model.dim_feat} 21 | latent_channels: 1 22 | out_channels: ${model.vae.in_channels} 23 | down_channels: [32, 256] 24 | mid_attention: True 25 | up_channels: [256, 32] 26 | layers_per_block: 2 27 | gradient_checkpointing: False 28 | vae_checkpoint_path: ./pretrained/model_vae_fp16.pt 29 | conditioner: 30 | class_name: models.conditioner.image.DummyImageConditioner 31 | num_prims: ${model.num_prims} 32 | dim_feat: ${model.dim_feat} 33 | prim_shape: ${model.prim_shape} 34 | sample_view: False 35 | encoder_config: 36 | class_name: models.conditioner.image_dinov2.Dinov2Wrapper 37 | model_name: dinov2_vitb14_reg 38 | freeze: True 39 | generator: 40 | class_name: models.dit_crossattn.DiT 41 | seq_length: ${model.num_prims} 42 | in_channels: 68 # equals to model.vae.latent_channels * latent_dim^3 43 | condition_channels: 768 44 | hidden_size: 1152 45 | depth: 28 46 | num_heads: 16 47 | attn_proj_bias: True 48 | cond_drop_prob: 0.1 49 | gradient_checkpointing: False 50 | latent_nf: 1.0 51 | latent_mean: [ 0.0442, -0.0029, -0.0425, -0.0043, -0.4086, -0.2906, -0.7002, -0.0852, -0.4446, -0.6896, -0.7344, -0.3524, -0.5488, -0.4313, -1.1715, -0.0875, -0.6131, -0.3924, -0.7335, -0.3749, 0.4658, -0.0236, 0.8362, 0.3388, 0.0188, 0.5988, -0.1853, 1.1579, 0.6240, 0.0758, 0.9641, 0.6586, 0.6260, 0.2384, 0.7798, 0.8297, -0.6543, -0.4441, -1.3887, -0.0393, -0.9008, -0.8616, -1.7434, -0.1328, -0.8119, -0.8225, -1.8533, -0.0444, -1.0510, -0.5158, -1.1907, -0.5265, 0.2832, 0.6037, 0.5981, 0.5461, 0.4366, 0.4144, 0.7219, 0.5722, 0.5937, 0.5598, 0.9414, 0.7419, 0.2102, 0.3388, 
0.4501, 0.5166] 52 | latent_std: [0.0219, 0.3707, 0.3911, 0.3610, 0.7549, 0.7909, 0.9691, 0.9193, 0.8218, 0.9389, 1.1785, 1.0254, 0.6376, 0.6568, 0.7892, 0.8468, 0.8775, 0.7920, 0.9037, 0.9329, 0.9196, 1.1123, 1.3041, 1.0955, 1.2727, 1.6565, 1.8502, 1.7006, 0.8973, 1.0408, 1.2034, 1.2703, 1.0373, 1.0486, 1.0716, 0.9746, 0.7088, 0.8685, 1.0030, 0.9504, 1.0410, 1.3033, 1.5368, 1.4386, 0.6142, 0.6887, 0.9085, 0.9903, 1.0190, 0.9302, 1.0121, 0.9964, 1.1474, 1.2729, 1.4627, 1.1404, 1.3713, 1.6692, 1.8424, 1.5047, 1.1356, 1.2369, 1.3554, 1.1848, 1.1319, 1.0822, 1.1972, 0.9916] 53 | 54 | diffusion: 55 | timestep_respacing: 56 | noise_schedule: squaredcos_cap_v2 57 | diffusion_steps: 1000 58 | parameterization: v 59 | 60 | rm: 61 | volradius: 10000.0 62 | dt: 1.0 63 | 64 | optimizer: 65 | class_name: torch.optim.AdamW 66 | lr: 0.0001 67 | weight_decay: 0 68 | 69 | scheduler: 70 | class_name: dva.scheduler.CosineWarmupScheduler 71 | warmup_iters: 3000 72 | max_iters: 200000 73 | 74 | dataset: 75 | class_name: datasets.prim_volume.AllCacheManifoldDataset 76 | manifold_url_template: ./data/obj-psdf-2048-scaleup-fitting/{folder}{key}.pt 77 | vaecache_url_template: ./data/klvae_2048_scaleup_cache/vae-{folder}{key}.pt 78 | cond_url_template: ./data/obj-2048-518reso-dino-cond/{folder}{key}.pt 79 | obj_name_list_path: ./assets/valid_fitting_2048.txt 80 | num_prims: ${model.num_prims} 81 | dim_feat: ${model.dim_feat} 82 | prim_shape: ${model.prim_shape} 83 | incl_srt: False 84 | 85 | train: 86 | batch_size: 8 87 | n_workers: 4 88 | n_epochs: 1000 89 | log_every_n_steps: 50 90 | summary_every_n_steps: 10000 91 | ckpt_every_n_steps: 10000 92 | amp: False 93 | precision: tf32 94 | 95 | tag: 3dtopia-xl-sview 96 | output_dir: ${root_data_dir}/train/${tag} 97 | -------------------------------------------------------------------------------- /configs/train_dit_text.yml: -------------------------------------------------------------------------------- 1 | debug: False 2 | root_data_dir: ./runs 3 | checkpoint_path: 4 | global_seed: 42 5 | 6 | image_height: 518 7 | image_width: 518 8 | 9 | model: 10 | class_name: models.primsdf.PrimSDF 11 | num_prims: 2048 12 | dim_feat: 6 13 | prim_shape: 8 14 | init_scale: 0.05 # useless if auto_scale_init == True 15 | sdf2alpha_var: 0.005 16 | auto_scale_init: True 17 | init_sampling: uniform 18 | vae: 19 | class_name: models.vae3d_dib.VAE 20 | in_channels: ${model.dim_feat} 21 | latent_channels: 1 22 | out_channels: ${model.vae.in_channels} 23 | down_channels: [32, 256] 24 | mid_attention: True 25 | up_channels: [256, 32] 26 | layers_per_block: 2 27 | gradient_checkpointing: False 28 | vae_checkpoint_path: ./pretrained/model_vae_fp16.pt 29 | conditioner: 30 | class_name: models.conditioner.text.TextConditioner 31 | encoder_config: 32 | class_name: models.conditioner.text.CLIPTextEncoder 33 | pretrained_path: ./pretrained/open_clip_pytorch_model.bin 34 | model_spec: ViT-L-14 35 | generator: 36 | class_name: models.dit_crossattn.DiT 37 | seq_length: ${model.num_prims} 38 | in_channels: 68 # equals to model.vae.latent_channels * latent_dim^3 39 | condition_channels: 768 40 | hidden_size: 1152 41 | depth: 28 42 | num_heads: 16 43 | attn_proj_bias: True 44 | cond_drop_prob: 0.1 45 | gradient_checkpointing: False 46 | latent_nf: 1.0 47 | latent_mean: [ 0.0442, -0.0029, -0.0425, -0.0043, -0.4086, -0.2906, -0.7002, -0.0852, -0.4446, -0.6896, -0.7344, -0.3524, -0.5488, -0.4313, -1.1715, -0.0875, -0.6131, -0.3924, -0.7335, -0.3749, 0.4658, -0.0236, 0.8362, 0.3388, 0.0188, 0.5988, -0.1853, 
1.1579, 0.6240, 0.0758, 0.9641, 0.6586, 0.6260, 0.2384, 0.7798, 0.8297, -0.6543, -0.4441, -1.3887, -0.0393, -0.9008, -0.8616, -1.7434, -0.1328, -0.8119, -0.8225, -1.8533, -0.0444, -1.0510, -0.5158, -1.1907, -0.5265, 0.2832, 0.6037, 0.5981, 0.5461, 0.4366, 0.4144, 0.7219, 0.5722, 0.5937, 0.5598, 0.9414, 0.7419, 0.2102, 0.3388, 0.4501, 0.5166] 48 | latent_std: [0.0219, 0.3707, 0.3911, 0.3610, 0.7549, 0.7909, 0.9691, 0.9193, 0.8218, 0.9389, 1.1785, 1.0254, 0.6376, 0.6568, 0.7892, 0.8468, 0.8775, 0.7920, 0.9037, 0.9329, 0.9196, 1.1123, 1.3041, 1.0955, 1.2727, 1.6565, 1.8502, 1.7006, 0.8973, 1.0408, 1.2034, 1.2703, 1.0373, 1.0486, 1.0716, 0.9746, 0.7088, 0.8685, 1.0030, 0.9504, 1.0410, 1.3033, 1.5368, 1.4386, 0.6142, 0.6887, 0.9085, 0.9903, 1.0190, 0.9302, 1.0121, 0.9964, 1.1474, 1.2729, 1.4627, 1.1404, 1.3713, 1.6692, 1.8424, 1.5047, 1.1356, 1.2369, 1.3554, 1.1848, 1.1319, 1.0822, 1.1972, 0.9916] 49 | 50 | diffusion: 51 | timestep_respacing: 52 | noise_schedule: squaredcos_cap_v2 53 | diffusion_steps: 1000 54 | parameterization: v 55 | 56 | rm: 57 | volradius: 10000.0 58 | dt: 1.0 59 | 60 | optimizer: 61 | class_name: torch.optim.AdamW 62 | lr: 0.0001 63 | weight_decay: 0 64 | 65 | scheduler: 66 | class_name: dva.scheduler.CosineWarmupScheduler 67 | warmup_iters: 3000 68 | max_iters: 200000 69 | 70 | dataset: 71 | class_name: datasets.prim_volume.AllCacheManifoldDataset 72 | manifold_url_template: ./data/obj-psdf-2048-scaleup-fitting/{folder}{key}.pt 73 | vaecache_url_template: ./data/klvae_2048_scaleup_cache/vae-{folder}{key}.pt 74 | cond_url_template: ./data/obj-2048-518reso-dino-cond/{folder}{key}.pt 75 | obj_name_list_path: ./assets/valid_fitting_2048.txt 76 | num_prims: ${model.num_prims} 77 | dim_feat: ${model.dim_feat} 78 | prim_shape: ${model.prim_shape} 79 | incl_srt: False 80 | 81 | train: 82 | batch_size: 8 83 | n_workers: 4 84 | n_epochs: 1000 85 | log_every_n_steps: 50 86 | summary_every_n_steps: 10000 87 | ckpt_every_n_steps: 10000 88 | amp: False 89 | precision: tf32 90 | 91 | tag: 3dtopia-xl-text 92 | output_dir: ${root_data_dir}/train/${tag} 93 | -------------------------------------------------------------------------------- /configs/train_fitting.yml: -------------------------------------------------------------------------------- 1 | debug: False 2 | root_data_dir: ./runs 3 | 4 | image_height: 1024 5 | image_width: 1024 6 | 7 | model: 8 | class_name: models.primsdf.PrimSDF 9 | num_prims: 2048 10 | dim_feat: 6 11 | prim_shape: 8 12 | init_scale: 0.05 # useless if auto_scale_init == True 13 | sdf2alpha_var: 0.005 14 | auto_scale_init: True 15 | init_sampling: uniform 16 | 17 | rm: 18 | volradius: 10000.0 19 | dt: 1.0 20 | 21 | optimizer: 22 | class_name: torch.optim.Adam 23 | lr: 0.0001 24 | 25 | loss: 26 | class_name: dva.losses.PrimSDFLoss 27 | weights: 28 | sdf_l1: 10 29 | rgb_l1: 1 30 | mat_l1: 1 31 | shape_opt_steps: ${train.shape_fit_steps} 32 | tex_opt_steps: ${train.tex_fit_steps} 33 | 34 | dataset: 35 | class_name: datasets.sample_glb.SampleSDFTexMatMesh 36 | mesh_file_path: ./data/old_school_drill.glb 37 | num_surface_samples: 300000 38 | num_near_samples: 200000 39 | num_rand_samples: 100000 40 | use_rand_sample: False 41 | sample_std: 0.01 42 | chunk_size: 16000 43 | 44 | train: 45 | batch_size: 1 46 | shape_fit_steps: 1000 47 | tex_fit_steps: 2000 48 | n_workers: 8 49 | n_epochs: 200 50 | n_max_iters: 10000 51 | log_every_n_steps: 1000 52 | summary_every_n_steps: 2000 53 | ckpt_every_n_steps: 2000 54 | gradient_clip_value: 5.0 55 | save_fp16: True 56 | 57 | tag: 
primx 58 | output_dir: ${root_data_dir}/training-fitting/${tag} 59 | -------------------------------------------------------------------------------- /configs/train_vae.yml: -------------------------------------------------------------------------------- 1 | debug: False 2 | root_data_dir: ./runs 3 | checkpoint_path: 4 | global_seed: 42 5 | 6 | image_height: 1024 7 | image_width: 1024 8 | 9 | model: 10 | class_name: models.primsdf.PrimSDF 11 | num_prims: 2048 12 | dim_feat: 6 13 | prim_shape: 8 14 | init_scale: 0.05 # useless if auto_scale_init == True 15 | sdf2alpha_var: 0.005 16 | auto_scale_init: True 17 | init_sampling: uniform 18 | vae: 19 | class_name: models.vae3d_dib.VAE 20 | in_channels: ${model.dim_feat} 21 | latent_channels: 1 22 | out_channels: ${model.vae.in_channels} 23 | down_channels: [32, 256] 24 | mid_attention: True 25 | up_channels: [256, 32] 26 | layers_per_block: 2 27 | gradient_checkpointing: False 28 | 29 | rm: 30 | volradius: 10000.0 31 | dt: 1.0 32 | 33 | optimizer: 34 | class_name: torch.optim.Adam 35 | lr: 0.0001 36 | 37 | loss: 38 | class_name: dva.losses.VAESepLoss 39 | weights: 40 | sdf: 1 41 | rgb: 1 42 | mat: 1 43 | kl: 0.0005 44 | 45 | dataset: 46 | class_name: datasets.prim_volume.ManifoldDataset 47 | manifold_url_template: ./data/obj-psdf-2048-scaleup-fitting/{folder}{key}.pt 48 | obj_name_list_path: ./assets/valid_fitting_2048.txt 49 | num_prims: ${model.num_prims} 50 | dim_feat: ${model.dim_feat} 51 | prim_shape: ${model.prim_shape} 52 | incl_srt: False 53 | 54 | train: 55 | batch_size: 4 56 | n_workers: 8 57 | n_epochs: 200 58 | log_every_n_steps: 50 59 | summary_every_n_steps: 5000 60 | ckpt_every_n_steps: 5000 61 | amp: True 62 | 63 | tag: vae 64 | output_dir: ${root_data_dir}/train/${tag} 65 | -------------------------------------------------------------------------------- /datasets/prim_volume_caption.py: -------------------------------------------------------------------------------- 1 | import io 2 | import os 3 | import torch 4 | import numpy as np 5 | from torch.utils.data.dataset import Dataset 6 | import open_clip 7 | 8 | 9 | import logging 10 | logger = logging.getLogger(__name__) 11 | 12 | class VAECacheManifoldDataset(Dataset): 13 | def __init__( 14 | self, 15 | manifold_url_template, 16 | vaecache_url_template, 17 | obj_name_list_path, 18 | caption_list_path, 19 | num_prims, 20 | dim_feat, 21 | prim_shape, 22 | lan_model_spec="ViT-L-14", 23 | incl_srt=False, 24 | device='cpu', 25 | **kwargs, 26 | ): 27 | super().__init__() 28 | assert os.path.exists(obj_name_list_path) 29 | assert os.path.exists(caption_list_path) 30 | with open(obj_name_list_path, 'r') as f: 31 | obj_name_list = f.readlines() 32 | with open(caption_list_path, 'r') as f: 33 | caption_list = f.readlines() 34 | self.manifold_url_template = manifold_url_template 35 | self.vaecache_url_template = vaecache_url_template 36 | self.obj_list = obj_name_list 37 | self.caption_list = caption_list 38 | # we assume the order of object is same in obj_list and caption_list 39 | assert len(self.obj_list) == len(self.caption_list), "len(obj_list)={} is not equal to len(caption_list)={}".format(len(self.obj_list), len(self.caption_list)) 40 | self.num_prims = num_prims 41 | self.dim_feat = dim_feat 42 | self.prim_shape = prim_shape 43 | self.incl_srt = incl_srt 44 | self.tokenizer = open_clip.get_tokenizer(lan_model_spec) 45 | self.device = device 46 | 47 | def __len__(self): 48 | return len(self.obj_list) 49 | 50 | def __getitem__(self, index): 51 | sample = {} 52 | obj_meta = 
self.obj_list[index] 53 | caption_meta = self.caption_list[index] 54 | folder, key = obj_meta[:-1].split("/") 55 | caption_key, caption = caption_meta[:-1].split("@", 1) 56 | assert caption_key == key 57 | sample['folder'] = folder 58 | sample['key'] = key 59 | sample['caption_raw'] = caption 60 | sample['caption_token'] = self.tokenizer([caption])[0] 61 | manifold_obj_path = self.manifold_url_template.format(folder=folder, key=key) 62 | vaecache_path = self.vaecache_url_template.format(folder=folder, key=key) 63 | try: 64 | ckpt = torch.load(manifold_obj_path, map_location=self.device) 65 | weights_dict = ckpt['model_state_dict'] 66 | srt_param = weights_dict['srt_param'] 67 | feat_param = weights_dict['feat_param'] 68 | if torch.isnan(srt_param).any() or torch.isnan(feat_param).any(): 69 | raise ValueError 70 | 71 | vae_ckpt = torch.load(vaecache_path, map_location=self.device) 72 | if torch.isnan(vae_ckpt).any(): 73 | raise ValueError 74 | except: 75 | srt_param = torch.zeros(self.num_prims, 4) 76 | feat_param = torch.zeros(self.num_prims, self.dim_feat * self.prim_shape ** 3) 77 | vae_ckpt = torch.zeros(self.num_prims, 2, 4, 4, 4) 78 | sample['vae_cache'] = vae_ckpt.float() 79 | srt_param = srt_param.float() 80 | feat_param = feat_param.float() 81 | feat_param = feat_param.reshape(self.num_prims, self.dim_feat, self.prim_shape ** 3) 82 | feat_param[:, 1:, :] = torch.clip(feat_param[:, 1:, :], min=0.0, max=1.0) 83 | feat_param = feat_param.reshape(self.num_prims, self.dim_feat * self.prim_shape ** 3) 84 | # sample['srt_param'] = srt_param 85 | # sample['feat_param'] = feat_param 86 | sample['input_param'] = torch.concat([srt_param, feat_param], dim = -1) 87 | 88 | normalized_srt_param = srt_param.clone() 89 | normalized_srt_param[:, 0:1] = (normalized_srt_param[:, 0:1] - 0.05) * 10 # heuristic normalization 90 | normalized_srt_param = normalized_srt_param[..., None, None, None].repeat(1, 1, self.prim_shape, self.prim_shape, self.prim_shape) 91 | 92 | # [nprims, 6, 8, 8, 8] 93 | normalized_feat_param = feat_param.clone().reshape(self.num_prims, self.dim_feat, self.prim_shape, self.prim_shape, self.prim_shape) 94 | # sdf heuristic normalization 95 | normalized_feat_param[:, 0:1, ...] *= 5 96 | # color, mat normalization [0, 1] -> [-1, 1] 97 | normalized_feat_param[:, 1:, ...] = normalized_feat_param[:, 1:, ...] * 2. - 1. 
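        # Editor's note (annotation, not part of this file): the heuristics above
        # roughly balance the channel ranges before the volume becomes the training
        # target ('gt' below): the per-primitive scale (presumably channel 0 of
        # srt_param, centred at the 0.05 init_scale used in the configs) is mapped
        # by (x - 0.05) * 10, the SDF channel is amplified by 5, and the
        # colour/material channels go from [0, 1] to [-1, 1]. A hypothetical
        # inverse, not defined in this repository, that a decoder would apply to
        # recover raw primitive (srt/feat) values:
        #
        #     def denormalize(norm_srt, norm_feat):
        #         srt = norm_srt.clone()
        #         srt[:, 0:1] = srt[:, 0:1] / 10. + 0.05           # undo (x - 0.05) * 10
        #         feat = norm_feat.clone()
        #         feat[:, 0:1, ...] = feat[:, 0:1, ...] / 5.       # undo sdf * 5
        #         feat[:, 1:, ...] = (feat[:, 1:, ...] + 1.) / 2.  # undo x * 2 - 1
        #         return srt, feat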
98 | 99 | # [nprims, 10, 8, 8, 8] 100 | if self.incl_srt: 101 | sample['gt'] = torch.concat([normalized_srt_param, normalized_feat_param], dim = 1) 102 | else: 103 | sample['gt'] = normalized_feat_param 104 | return sample -------------------------------------------------------------------------------- /datasets/sample_glb.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import trimesh 3 | import numpy as np 4 | from torch.utils.data.dataset import Dataset 5 | from pytorch3d.structures import Meshes 6 | from pytorch3d.ops import sample_points_from_meshes 7 | 8 | from utils.mesh import Mesh 9 | from utils.meshutils import scale_to_unit_cube, rotation_matrix 10 | from dva.geom import GeometryModule 11 | import cubvh 12 | import os 13 | 14 | import logging 15 | 16 | logger = logging.getLogger(__name__) 17 | 18 | class SampleSDFTexMatMesh(Dataset): 19 | def __init__( 20 | self, 21 | mesh_file_path, 22 | glb_f=None, 23 | num_surface_samples=300000, 24 | num_near_samples=200000, 25 | sample_std=0.01, 26 | chunk_size=1024, 27 | is_train=True, 28 | **kwargs, 29 | ): 30 | super().__init__() 31 | if isinstance(mesh_file_path, str) and os.path.exists(mesh_file_path): 32 | assert mesh_file_path.endswith("glb") 33 | if glb_f is not None: 34 | _data = trimesh.load(glb_f, file_type='glb') 35 | else: 36 | _data = trimesh.load(mesh_file_path) 37 | self.chunk_size = chunk_size 38 | device = "cpu" 39 | # always convert scene to mesh, and apply all transforms... 40 | if isinstance(_data, trimesh.Scene): 41 | # print(f"[INFO] load trimesh: concatenating {len(_data.geometry)} meshes.") 42 | _concat = [] 43 | # loop the scene graph and apply transform to each mesh 44 | scene_graph = _data.graph.to_flattened() # dict {name: {transform: 4x4 mat, geometry: str}} 45 | for k, v in scene_graph.items(): 46 | name = v['geometry'] 47 | if name in _data.geometry and isinstance(_data.geometry[name], trimesh.Trimesh): 48 | transform = v['transform'] 49 | _concat.append(_data.geometry[name].apply_transform(transform)) 50 | # we do not concatenate here 51 | _mesh_list = _concat 52 | else: 53 | _mesh_list = [_data] 54 | 55 | _asset_list = [] 56 | _valid_mesh_list = [] 57 | max_xyz_list = [] 58 | min_xyz_list = [] 59 | sampling_weights = [] 60 | for each_mesh in _mesh_list: 61 | _asset = Mesh.parse_trimesh_data(each_mesh, device=device) 62 | # clean faces have less than 3 connected components 63 | tmp_mesh = trimesh.Trimesh(_asset.v, _asset.f, process=False) 64 | cc = trimesh.graph.connected_components(tmp_mesh.face_adjacency, min_len=3) 65 | if not len(cc) > 0: 66 | if _asset.v.shape[0] > 5: 67 | _valid_mesh_list.append(each_mesh) 68 | _asset_list.append(_asset) 69 | sampling_weights.append(tmp_mesh.area.item()) 70 | max_xyz_list.append(_asset.v.max(0)[0]) 71 | min_xyz_list.append(_asset.v.min(0)[0]) 72 | else: 73 | logger.info(f"Less than 3 connected components found! 
Drop trimesh element with vertices shape:{tmp_mesh.vertices.shape}") 74 | continue 75 | _valid_mesh_list.append(each_mesh) 76 | cc_mask = np.zeros(len(tmp_mesh.faces), dtype=np.bool_) 77 | cc_mask[np.concatenate(cc)] = True 78 | tmp_mesh.update_faces(cc_mask) 79 | # remove unreferenced vertices, update vertices and texture coordinates accordingly 80 | referenced = np.zeros(len(tmp_mesh.vertices), dtype=bool) 81 | referenced[tmp_mesh.faces] = True 82 | inverse = np.zeros(len(tmp_mesh.vertices), dtype=np.int64) 83 | inverse[referenced] = np.arange(referenced.sum()) 84 | tmp_mesh.update_vertices(mask=referenced, inverse=inverse) 85 | # update texture coordinates 86 | updated_vt = _asset.vt[referenced, :].clone() 87 | _asset.vt = updated_vt 88 | # renormalize vertices to unit cube after outliers removal 89 | _asset.v = torch.from_numpy(tmp_mesh.vertices).float() 90 | _asset.f = torch.from_numpy(tmp_mesh.faces).long() 91 | _asset.ft = torch.from_numpy(tmp_mesh.faces).long() 92 | _asset_list.append(_asset) 93 | # use sum of face area as weights 94 | sampling_weights.append(tmp_mesh.area.item()) 95 | max_xyz_list.append(_asset.v.max(0)[0]) 96 | min_xyz_list.append(_asset.v.min(0)[0]) 97 | # scale to unit cube 98 | global_max_xyz, _ = torch.stack(max_xyz_list).max(0) 99 | global_min_xyz, _ = torch.stack(min_xyz_list).min(0) 100 | bb_centroid = (global_max_xyz + global_min_xyz) / 2. 101 | global_scale_max = (global_max_xyz - global_min_xyz).max() 102 | for ast in _asset_list: 103 | zero_mean_pts = ast.v.clone() - bb_centroid 104 | ast.v = zero_mean_pts * (1.8 / global_scale_max) 105 | 106 | self.asset_list = _asset_list 107 | _merged_mesh = trimesh.util.concatenate(_valid_mesh_list) 108 | _merged_vertices = torch.from_numpy(_merged_mesh.vertices).to(torch.float32) 109 | _merged_vertices = scale_to_unit_cube(_merged_vertices) 110 | _merged_mesh.vertices = _merged_vertices 111 | _merged_faces = torch.from_numpy(_merged_mesh.faces).to(torch.long) 112 | self.mesh_obj = _merged_mesh 113 | self.f_sdf = cubvh.cuBVH(_merged_vertices.cuda(), _merged_faces.cuda()) 114 | self.mesh_p3d_obj = Meshes([_merged_vertices], [_merged_faces]) 115 | surface_samples = sample_points_from_meshes(self.mesh_p3d_obj, num_surface_samples) 116 | near_samples = sample_points_from_meshes(self.mesh_p3d_obj, num_near_samples) 117 | near_samples = near_samples + torch.rand_like(near_samples) * sample_std 118 | self.sampled_points = torch.concat([surface_samples, near_samples], dim=1)[0] 119 | self.sampled_sdf = self.f_sdf.signed_distance(self.sampled_points, return_uvw=False, mode='raystab')[0].cpu()[..., None] * (-1) 120 | 121 | # instantiation of geometry function 122 | self.geo_fn_list = [] 123 | sampling_weights = torch.Tensor(sampling_weights) 124 | self.sampling_weights = sampling_weights / torch.sum(sampling_weights) 125 | self.num_sampled_pts = (self.sampling_weights * (num_surface_samples + num_near_samples)).to(torch.int) 126 | if not torch.sum(self.num_sampled_pts) == (num_surface_samples + num_near_samples): 127 | diff = num_surface_samples + num_near_samples - torch.sum(self.num_sampled_pts) 128 | self.num_sampled_pts[-1] += diff 129 | sampled_gt_tex = [] 130 | sampled_gt_pts = [] 131 | sampled_gt_mat = [] 132 | for idx, ast in enumerate(_asset_list): 133 | topology = { 134 | "v": ast.v.to(torch.float32), 135 | "vi": ast.f.to(torch.long), 136 | "vti": ast.ft.to(torch.long), 137 | "vt": ast.vt.to(torch.float32), 138 | "n_verts": ast.v.shape[0], 139 | } 140 | # assert ast.albedo.shape[:2] == 
ast.metallicRoughness.shape[:2] 141 | geo_fn = GeometryModule( 142 | v=topology['v'], 143 | vi=topology['vi'], 144 | vt=topology['vt'], 145 | vti=topology['vti'], 146 | impaint=False, 147 | uv_size=ast.albedo.shape[:2], 148 | ) 149 | self.geo_fn_list.append(geo_fn) 150 | num_sampled_pts = self.num_sampled_pts[idx] 151 | if num_sampled_pts == 0: 152 | continue 153 | sampled_texture, sampled_pts = geo_fn.rand_sample_3d_uv(num_sampled_pts, ast.albedo) 154 | sampled_material, _ = geo_fn.sample_uv_from_3dpts(sampled_pts, ast.metallicRoughness) 155 | sampled_gt_tex.append(torch.from_numpy(sampled_texture)) 156 | sampled_gt_pts.append(torch.from_numpy(np.array(sampled_pts, dtype=np.float32))) 157 | sampled_gt_mat.append(torch.from_numpy(sampled_material[..., -2:])) 158 | 159 | self.sampled_tex = torch.concat(sampled_gt_tex, dim=0) 160 | self.sampled_tex_points = torch.concat(sampled_gt_pts, dim=0) 161 | self.sampled_mat = torch.concat(sampled_gt_mat, dim=0) 162 | self.idx_list = np.arange(self.sampled_sdf.shape[0]) 163 | assert self.sampled_tex.shape[0] == self.sampled_tex_points.shape[0] 164 | assert self.sampled_sdf.shape[0] == self.sampled_points.shape[0] 165 | assert self.sampled_points.shape[0] == self.sampled_tex_points.shape[0] 166 | 167 | def __len__(self): 168 | return self.sampled_sdf.shape[0] 169 | 170 | def __getitem__(self, index): 171 | idxs = np.random.choice(self.idx_list, self.chunk_size) 172 | sample = {} 173 | sample['pts'] = self.sampled_points[idxs, :] 174 | sample['sdf'] = self.sampled_sdf[idxs, :] 175 | sample['tex_pts'] = self.sampled_tex_points[idxs, :] 176 | sample['tex'] = self.sampled_tex[idxs, :] 177 | sample['mat'] = self.sampled_mat[idxs, :] 178 | return sample -------------------------------------------------------------------------------- /dva/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | -------------------------------------------------------------------------------- /dva/attr_dict.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
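# Editor's sketch (hypothetical usage, not a file in this repository): the
# SampleSDFTexMatMesh dataset from datasets/sample_glb.py above pairs with the
# dataset/train blocks of configs/train_fitting.yml. Every __getitem__ call already
# draws a random chunk of chunk_size points with keys 'pts', 'sdf', 'tex_pts',
# 'tex' and 'mat', which is why the fitting config uses batch_size 1. The actual
# entry point (train_fitting.py) is not included in this dump.
#
#     from torch.utils.data import DataLoader
#     from datasets.sample_glb import SampleSDFTexMatMesh
#
#     dataset = SampleSDFTexMatMesh(
#         mesh_file_path="./data/old_school_drill.glb",  # values from train_fitting.yml
#         num_surface_samples=300000,
#         num_near_samples=200000,
#         sample_std=0.01,
#         chunk_size=16000,
#     )
#     loader = DataLoader(dataset, batch_size=1, num_workers=8)
#     batch = next(iter(loader))
#     # batch['pts']: [1, 16000, 3], batch['sdf']: [1, 16000, 1]; CUDA required for cubvh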
6 | 7 | import json 8 | 9 | 10 | class AttrDict: 11 | def __init__(self, entries): 12 | self.add_entries_(entries) 13 | 14 | def keys(self): 15 | return self.__dict__.keys() 16 | 17 | def values(self): 18 | return self.__dict__.values() 19 | 20 | def __getitem__(self, key): 21 | return self.__dict__[key] 22 | 23 | def __setitem__(self, key, value): 24 | self.__dict__[key] = value 25 | 26 | def __delitem__(self, key): 27 | return self.__dict__.__delitem__(key) 28 | 29 | def __contains__(self, key): 30 | return key in self.__dict__ 31 | 32 | def __repr__(self): 33 | return self.__dict__.__repr__() 34 | 35 | def __getattr__(self, attr): 36 | if attr.startswith("__"): 37 | return self.__getattribute__(attr) 38 | return self.__dict__[attr] 39 | 40 | def items(self): 41 | return self.__dict__.items() 42 | 43 | def __iter__(self): 44 | return iter(self.items()) 45 | 46 | def add_entries_(self, entries, overwrite=True): 47 | for key, value in entries.items(): 48 | if key not in self.__dict__: 49 | if isinstance(value, dict): 50 | self.__dict__[key] = AttrDict(value) 51 | else: 52 | self.__dict__[key] = value 53 | else: 54 | if isinstance(value, dict): 55 | self.__dict__[key].add_entries_(entries=value, overwrite=overwrite) 56 | elif overwrite or self.__dict__[key] is None: 57 | self.__dict__[key] = value 58 | 59 | def serialize(self): 60 | return json.dumps(self, default=self.obj_to_dict, indent=4) 61 | 62 | def obj_to_dict(self, obj): 63 | return obj.__dict__ 64 | 65 | def get(self, key, default=None): 66 | return self.__dict__.get(key, default) 67 | -------------------------------------------------------------------------------- /dva/io.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
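# Editor's example (annotation, not part of dva/io.py): dva.attr_dict.AttrDict,
# defined above, recursively wraps nested dicts and exposes keys as attributes;
# add_entries_ merges another dict and only overwrites existing values when
# overwrite=True (the default).
#
#     from dva.attr_dict import AttrDict
#
#     cfg = AttrDict({"train": {"batch_size": 8}, "tag": "demo"})
#     print(cfg.train.batch_size)          # 8; nested dicts become AttrDicts
#     cfg.add_entries_({"train": {"batch_size": 4}}, overwrite=False)
#     print(cfg["train"]["batch_size"])    # still 8; overwrite disabled
#     print(cfg.serialize())               # JSON dump of the whole tree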
6 | 7 | import json 8 | import cv2 9 | import numpy as np 10 | import copy 11 | import importlib 12 | from typing import Any, Dict 13 | 14 | def load_module(module_name, class_name=None, silent: bool = False): 15 | module = importlib.import_module(module_name) 16 | return getattr(module, class_name) if class_name else module 17 | 18 | 19 | def load_class(class_name): 20 | return load_module(*class_name.rsplit(".", 1)) 21 | 22 | 23 | def load_from_config(config, **kwargs): 24 | """Instantiate an object given a config and arguments.""" 25 | assert "class_name" in config and "module_name" not in config 26 | config = copy.deepcopy(config) 27 | class_name = config.pop("class_name") 28 | object_class = load_class(class_name) 29 | return object_class(**config, **kwargs) 30 | 31 | 32 | def load_opencv_calib(extrin_path, intrin_path): 33 | cameras = {} 34 | 35 | fse = cv2.FileStorage() 36 | fse.open(extrin_path, cv2.FileStorage_READ) 37 | 38 | fsi = cv2.FileStorage() 39 | fsi.open(intrin_path, cv2.FileStorage_READ) 40 | 41 | names = [ 42 | fse.getNode("names").at(c).string() for c in range(fse.getNode("names").size()) 43 | ] 44 | 45 | for camera in names: 46 | rot = fse.getNode(f"R_{camera}").mat() 47 | R = fse.getNode(f"Rot_{camera}").mat() 48 | T = fse.getNode(f"T_{camera}").mat() 49 | R_pred = cv2.Rodrigues(rot)[0] 50 | assert np.all(np.isclose(R_pred, R)) 51 | K = fsi.getNode(f"K_{camera}").mat() 52 | cameras[camera] = { 53 | "Rt": np.concatenate([R, T], axis=1).astype(np.float32), 54 | "K": K.astype(np.float32), 55 | } 56 | return cameras -------------------------------------------------------------------------------- /dva/layers.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import torch as th 8 | import torch.nn as nn 9 | 10 | import numpy as np 11 | 12 | from dva.mvp.models.utils import Conv2dWN, Conv2dWNUB, ConvTranspose2dWNUB, initmod 13 | 14 | 15 | class ConvBlock(nn.Module): 16 | def __init__( 17 | self, 18 | in_channels, 19 | out_channels, 20 | size, 21 | lrelu_slope=0.2, 22 | kernel_size=3, 23 | padding=1, 24 | wnorm_dim=0, 25 | ): 26 | super().__init__() 27 | 28 | self.conv_resize = Conv2dWN(in_channels, out_channels, kernel_size=1) 29 | self.conv1 = Conv2dWNUB( 30 | in_channels, 31 | in_channels, 32 | kernel_size=kernel_size, 33 | padding=padding, 34 | height=size, 35 | width=size, 36 | ) 37 | 38 | self.lrelu1 = nn.LeakyReLU(lrelu_slope) 39 | self.conv2 = Conv2dWNUB( 40 | in_channels, 41 | out_channels, 42 | kernel_size=kernel_size, 43 | padding=padding, 44 | height=size, 45 | width=size, 46 | ) 47 | self.lrelu2 = nn.LeakyReLU(lrelu_slope) 48 | 49 | def forward(self, x): 50 | x_skip = self.conv_resize(x) 51 | x = self.conv1(x) 52 | x = self.lrelu1(x) 53 | x = self.conv2(x) 54 | x = self.lrelu2(x) 55 | return x + x_skip 56 | 57 | 58 | def tile2d(x, size: int): 59 | """Tile a given set of features into a convolutional map. 60 | 61 | Args: 62 | x: float tensor of shape [N, F] 63 | size: int or a tuple 64 | 65 | Returns: 66 | a feature map [N, F, size[0], size[1]] 67 | """ 68 | # size = size if isinstance(size, tuple) else (size, size) 69 | # NOTE: expecting only int here (!!!) 
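    # Editor's note: e.g. an input of shape [N, F] becomes [N, F, size, size];
    # expand() broadcasts the two new singleton dims as a view, without copying.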
70 | return x[:, :, np.newaxis, np.newaxis].expand(-1, -1, size, size) 71 | 72 | 73 | def weights_initializer(m, alpha: float = 1.0): 74 | return initmod(m, nn.init.calculate_gain("leaky_relu", alpha)) 75 | 76 | 77 | class UNetWB(nn.Module): 78 | def __init__( 79 | self, 80 | in_channels, 81 | out_channels, 82 | size, 83 | n_init_ftrs=8, 84 | out_scale=0.1, 85 | ): 86 | # super().__init__(*args, **kwargs) 87 | super().__init__() 88 | 89 | self.out_scale = 0.1 90 | 91 | F = n_init_ftrs 92 | 93 | # TODO: allow changing the size? 94 | self.size = size 95 | 96 | self.down1 = nn.Sequential( 97 | Conv2dWNUB(in_channels, F, self.size // 2, self.size // 2, 4, 2, 1), 98 | nn.LeakyReLU(0.2), 99 | ) 100 | self.down2 = nn.Sequential( 101 | Conv2dWNUB(F, 2 * F, self.size // 4, self.size // 4, 4, 2, 1), 102 | nn.LeakyReLU(0.2), 103 | ) 104 | self.down3 = nn.Sequential( 105 | Conv2dWNUB(2 * F, 4 * F, self.size // 8, self.size // 8, 4, 2, 1), 106 | nn.LeakyReLU(0.2), 107 | ) 108 | self.down4 = nn.Sequential( 109 | Conv2dWNUB(4 * F, 8 * F, self.size // 16, self.size // 16, 4, 2, 1), 110 | nn.LeakyReLU(0.2), 111 | ) 112 | self.down5 = nn.Sequential( 113 | Conv2dWNUB(8 * F, 16 * F, self.size // 32, self.size // 32, 4, 2, 1), 114 | nn.LeakyReLU(0.2), 115 | ) 116 | self.up1 = nn.Sequential( 117 | ConvTranspose2dWNUB( 118 | 16 * F, 8 * F, self.size // 16, self.size // 16, 4, 2, 1 119 | ), 120 | nn.LeakyReLU(0.2), 121 | ) 122 | self.up2 = nn.Sequential( 123 | ConvTranspose2dWNUB(8 * F, 4 * F, self.size // 8, self.size // 8, 4, 2, 1), 124 | nn.LeakyReLU(0.2), 125 | ) 126 | self.up3 = nn.Sequential( 127 | ConvTranspose2dWNUB(4 * F, 2 * F, self.size // 4, self.size // 4, 4, 2, 1), 128 | nn.LeakyReLU(0.2), 129 | ) 130 | self.up4 = nn.Sequential( 131 | ConvTranspose2dWNUB(2 * F, F, self.size // 2, self.size // 2, 4, 2, 1), 132 | nn.LeakyReLU(0.2), 133 | ) 134 | self.up5 = nn.Sequential( 135 | ConvTranspose2dWNUB(F, F, self.size, self.size, 4, 2, 1), nn.LeakyReLU(0.2) 136 | ) 137 | self.out = Conv2dWNUB( 138 | F + in_channels, out_channels, self.size, self.size, kernel_size=1 139 | ) 140 | self.apply(lambda x: initmod(x, 0.2)) 141 | initmod(self.out, 1.0) 142 | 143 | def forward(self, x): 144 | x1 = x 145 | x2 = self.down1(x1) 146 | x3 = self.down2(x2) 147 | x4 = self.down3(x3) 148 | x5 = self.down4(x4) 149 | x6 = self.down5(x5) 150 | # TODO: switch to concat? 151 | x = self.up1(x6) + x5 152 | x = self.up2(x) + x4 153 | x = self.up3(x) + x3 154 | x = self.up4(x) + x2 155 | x = self.up5(x) 156 | x = th.cat([x, x1], dim=1) 157 | return self.out(x) * self.out_scale 158 | -------------------------------------------------------------------------------- /dva/mvp/extensions/mvpraymarch/cudadispatch.h: -------------------------------------------------------------------------------- 1 | // Copyright (c) Meta Platforms, Inc. and affiliates. 2 | // All rights reserved. 3 | // 4 | // This source code is licensed under the license found in the 5 | // LICENSE file in the root directory of this source tree. 
6 | 7 | #ifndef cudadispatch_h_ 8 | #define cudadispatch_h_ 9 | 10 | #include 11 | #include 12 | #include 13 | 14 | template 15 | struct get_base { 16 | typedef T type; 17 | }; 18 | 19 | template 20 | struct get_base::value>::type> { 21 | typedef std::shared_ptr type; 22 | }; 23 | 24 | template struct is_shared_ptr : std::false_type {}; 25 | template struct is_shared_ptr> : std::true_type {}; 26 | 27 | template 28 | auto convert_shptr_impl2(std::shared_ptr t) { 29 | return *static_cast(t.get()); 30 | } 31 | 32 | template 33 | auto convert_shptr_impl(T&& t, std::false_type) { 34 | return convert_shptr_impl2(t); 35 | } 36 | 37 | template 38 | auto convert_shptr_impl(T&& t, std::true_type) { 39 | return std::forward(t); 40 | } 41 | 42 | template 43 | auto convert_shptr(T&& t) { 44 | return convert_shptr_impl(std::forward(t), std::is_same{}); 45 | } 46 | 47 | template 48 | struct cudacall { 49 | struct functbase { 50 | virtual ~functbase() {} 51 | virtual void call(dim3, dim3, cudaStream_t, ArgsIn...) const = 0; 52 | }; 53 | 54 | template 55 | struct funct : public functbase { 56 | std::function fn; 57 | funct(void(*fn_)(ArgsOut...)) : fn(fn_) { } 58 | void call(dim3 gridsize, dim3 blocksize, cudaStream_t stream, ArgsIn... args) const { 59 | void (*const*kfunc)(ArgsOut...) = fn.template target(); 60 | (*kfunc)<<>>( 61 | std::forward(convert_shptr(std::forward(args)))...); 62 | } 63 | }; 64 | 65 | std::shared_ptr fn; 66 | 67 | template 68 | cudacall(void(*fn_)(ArgsOut...)) : fn(std::make_shared>(fn_)) { } 69 | 70 | template 71 | void call(dim3 gridsize, dim3 blocksize, cudaStream_t stream, ArgsTmp&&... args) const { 72 | fn->call(gridsize, blocksize, stream, std::forward(args)...); 73 | } 74 | }; 75 | 76 | template 77 | struct binder { 78 | F f; T t; 79 | template 80 | auto operator()(Args&&... args) const 81 | -> decltype(f(t, std::forward(args)...)) { 82 | return f(t, std::forward(args)...); 83 | } 84 | }; 85 | 86 | template 87 | binder::type 88 | , typename std::decay::type> BindFirst(F&& f, T&& t) { 89 | return { std::forward(f), std::forward(t) }; 90 | } 91 | 92 | template 93 | auto make_cudacall_(void(*fn)(ArgsOut...)) { 94 | return BindFirst( 95 | std::mem_fn(&cudacall::type...>::template call::type...>), 96 | cudacall::type...>(fn)); 97 | } 98 | 99 | template 100 | std::function::type...)> make_cudacall(void(*fn)(ArgsOut...)) { 101 | return std::function::type...)>(make_cudacall_(fn)); 102 | } 103 | 104 | #endif 105 | -------------------------------------------------------------------------------- /dva/mvp/extensions/mvpraymarch/makefile: -------------------------------------------------------------------------------- 1 | all: 2 | python setup.py build_ext --inplace 3 | -------------------------------------------------------------------------------- /dva/mvp/extensions/mvpraymarch/mvpraymarch_kernel.cu: -------------------------------------------------------------------------------- 1 | // Copyright (c) Meta Platforms, Inc. and affiliates. 2 | // All rights reserved. 3 | // 4 | // This source code is licensed under the license found in the 5 | // LICENSE file in the root directory of this source tree. 
6 | 7 | #include 8 | #include 9 | #include 10 | #include 11 | #include 12 | #include 13 | #include 14 | 15 | #include "helper_math.h" 16 | 17 | #include "cudadispatch.h" 18 | 19 | #include "utils.h" 20 | 21 | #include "primtransf.h" 22 | #include "primsampler.h" 23 | #include "primaccum.h" 24 | 25 | #include "mvpraymarch_subset_kernel.h" 26 | 27 | typedef std::shared_ptr PrimTransfDataBase_ptr; 28 | typedef std::shared_ptr PrimSamplerDataBase_ptr; 29 | typedef std::shared_ptr PrimAccumDataBase_ptr; 30 | typedef std::function mapfn_t; 34 | typedef RaySubsetFixedBVH raysubset_t; 35 | 36 | void raymarch_forward_cuda( 37 | int N, int H, int W, int K, 38 | float * rayposim, 39 | float * raydirim, 40 | float stepsize, 41 | float * tminmaxim, 42 | 43 | int * sortedobjid, 44 | int * nodechildren, 45 | float * nodeaabb, 46 | float * primpos, 47 | float * primrot, 48 | float * primscale, 49 | 50 | int TD, int TH, int TW, 51 | float * tplate, 52 | int WD, int WH, int WW, 53 | float * warp, 54 | 55 | float * rayrgbaim, 56 | float * raysatim, 57 | int * raytermim, 58 | 59 | int algorithm, 60 | bool sortboxes, 61 | int maxhitboxes, 62 | bool synchitboxes, 63 | bool chlast, 64 | float fadescale, 65 | float fadeexp, 66 | int accum, 67 | float termthresh, 68 | int griddim, int blocksizex, int blocksizey, 69 | cudaStream_t stream) { 70 | dim3 blocksize(blocksizex, blocksizey); 71 | dim3 gridsize; 72 | gridsize = dim3( 73 | (W + blocksize.x - 1) / blocksize.x, 74 | (H + blocksize.y - 1) / blocksize.y, 75 | N); 76 | 77 | std::shared_ptr primtransf_data; 78 | primtransf_data = std::make_shared(PrimTransfSRT::Data{ 79 | PrimTransfDataBase{}, 80 | K, (float3*)primpos, nullptr, 81 | K * 3, (float3*)primrot, nullptr, 82 | K, (float3*)primscale, nullptr}); 83 | std::shared_ptr primsampler_data; 84 | if (algorithm == 1) { 85 | primsampler_data = std::make_shared::Data>(PrimSamplerTW::Data{ 86 | PrimSamplerDataBase{}, 87 | fadescale, fadeexp, 88 | K * TD * TH * TW * 4, TD, TH, TW, tplate, nullptr, 89 | K * WD * WH * WW * 3, WD, WH, WW, warp, nullptr}); 90 | } else { 91 | primsampler_data = std::make_shared::Data>(PrimSamplerTW::Data{ 92 | PrimSamplerDataBase{}, 93 | fadescale, fadeexp, 94 | K * TD * TH * TW * 4, TD, TH, TW, tplate, nullptr, 95 | 0, 0, 0, 0, nullptr, nullptr}); 96 | } 97 | std::shared_ptr primaccum_data = std::make_shared(PrimAccumAdditive::Data{ 98 | PrimAccumDataBase{}, 99 | termthresh, H * W, W, 1, (float4*)rayrgbaim, nullptr, (float3*)raysatim}); 100 | 101 | std::map dispatcher = { 102 | {0, make_cudacall(raymarch_subset_forward_kernel<512, 4, raysubset_t, PrimTransfSRT, PrimSamplerTW, PrimAccumAdditive>)}, 103 | {1, make_cudacall(raymarch_subset_forward_kernel<512, 4, raysubset_t, PrimTransfSRT, PrimSamplerTW, PrimAccumAdditive>)}}; 104 | 105 | auto iter = dispatcher.find(algorithm); 106 | if (iter != dispatcher.end()) { 107 | (iter->second)( 108 | gridsize, blocksize, stream, 109 | N, H, W, K, 110 | reinterpret_cast(rayposim), 111 | reinterpret_cast(raydirim), 112 | stepsize, 113 | reinterpret_cast(tminmaxim), 114 | reinterpret_cast(sortedobjid), 115 | reinterpret_cast(nodechildren), 116 | reinterpret_cast(nodeaabb), 117 | primtransf_data, 118 | primsampler_data, 119 | primaccum_data); 120 | } 121 | } 122 | 123 | void raymarch_backward_cuda( 124 | int N, int H, int W, int K, 125 | float * rayposim, 126 | float * raydirim, 127 | float stepsize, 128 | float * tminmaxim, 129 | int * sortedobjid, 130 | int * nodechildren, 131 | float * nodeaabb, 132 | 133 | float * primpos, 134 | float * 
grad_primpos, 135 | float * primrot, 136 | float * grad_primrot, 137 | float * primscale, 138 | float * grad_primscale, 139 | 140 | int TD, int TH, int TW, 141 | float * tplate, 142 | float * grad_tplate, 143 | int WD, int WH, int WW, 144 | float * warp, 145 | float * grad_warp, 146 | 147 | float * rayrgbaim, 148 | float * grad_rayrgba, 149 | float * raysatim, 150 | int * raytermim, 151 | 152 | int algorithm, bool sortboxes, int maxhitboxes, bool synchitboxes, 153 | bool chlast, float fadescale, float fadeexp, int accum, float termthresh, 154 | int griddim, int blocksizex, int blocksizey, 155 | 156 | cudaStream_t stream) { 157 | dim3 blocksize(blocksizex, blocksizey); 158 | dim3 gridsize; 159 | gridsize = dim3( 160 | (W + blocksize.x - 1) / blocksize.x, 161 | (H + blocksize.y - 1) / blocksize.y, 162 | N); 163 | 164 | std::shared_ptr primtransf_data; 165 | primtransf_data = std::make_shared(PrimTransfSRT::Data{ 166 | PrimTransfDataBase{}, 167 | K, (float3*)primpos, (float3*)grad_primpos, 168 | K * 3, (float3*)primrot, (float3*)grad_primrot, 169 | K, (float3*)primscale, (float3*)grad_primscale}); 170 | std::shared_ptr primsampler_data; 171 | if (algorithm == 1) { 172 | primsampler_data = std::make_shared::Data>(PrimSamplerTW::Data{ 173 | PrimSamplerDataBase{}, 174 | fadescale, fadeexp, 175 | K * TD * TH * TW * 4, TD, TH, TW, tplate, grad_tplate, 176 | K * WD * WH * WW * 3, WD, WH, WW, warp, grad_warp}); 177 | } else { 178 | primsampler_data = std::make_shared::Data>(PrimSamplerTW::Data{ 179 | PrimSamplerDataBase{}, 180 | fadescale, fadeexp, 181 | K * TD * TH * TW * 4, TD, TH, TW, tplate, grad_tplate, 182 | 0, 0, 0, 0, nullptr, nullptr}); 183 | } 184 | std::shared_ptr primaccum_data = std::make_shared(PrimAccumAdditive::Data{ 185 | PrimAccumDataBase{}, 186 | termthresh, H * W, W, 1, (float4*)rayrgbaim, (float4*)grad_rayrgba, (float3*)raysatim}); 187 | 188 | std::map dispatcher = { 189 | {0, make_cudacall(raymarch_subset_backward_kernel, PrimAccumAdditive>)}, 190 | {1, make_cudacall(raymarch_subset_backward_kernel, PrimAccumAdditive>)}}; 191 | 192 | auto iter = dispatcher.find(algorithm); 193 | if (iter != dispatcher.end()) { 194 | (iter->second)( 195 | gridsize, blocksize, stream, 196 | N, H, W, K, 197 | reinterpret_cast(rayposim), 198 | reinterpret_cast(raydirim), 199 | stepsize, 200 | reinterpret_cast(tminmaxim), 201 | reinterpret_cast(sortedobjid), 202 | reinterpret_cast(nodechildren), 203 | reinterpret_cast(nodeaabb), 204 | primtransf_data, 205 | primsampler_data, 206 | primaccum_data); 207 | } 208 | } 209 | -------------------------------------------------------------------------------- /dva/mvp/extensions/mvpraymarch/mvpraymarch_subset_kernel.h: -------------------------------------------------------------------------------- 1 | // Copyright (c) Meta Platforms, Inc. and affiliates. 2 | // All rights reserved. 3 | // 4 | // This source code is licensed under the license found in the 5 | // LICENSE file in the root directory of this source tree. 
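The launch configuration used by raymarch_forward_cuda and raymarch_backward_cuda above assigns one thread per pixel and one grid slice per batch element, rounding the grid up so partial blocks still cover the image. A minimal sketch of that ceil-division sizing (the block sizes shown are placeholders; the real values arrive as the blocksizex/blocksizey arguments):

def launch_dims(N, H, W, blocksizex=16, blocksizey=16):
    # grid.x covers the width, grid.y the height, grid.z the batch dimension
    grid = ((W + blocksizex - 1) // blocksizex,
            (H + blocksizey - 1) // blocksizey,
            N)
    block = (blocksizex, blocksizey, 1)
    return grid, block

print(launch_dims(2, 1024, 667))    # ((42, 64, 2), (16, 16, 1))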
6 | 7 | template< 8 | int maxhitboxes, 9 | int nwarps, 10 | class RaySubsetT=RaySubsetFixedBVH, 11 | class PrimTransfT=PrimTransfSRT, 12 | class PrimSamplerT=PrimSamplerTW, 13 | class PrimAccumT=PrimAccumAdditive> 14 | __global__ void raymarch_subset_forward_kernel( 15 | int N, int H, int W, int K, 16 | float3 * rayposim, 17 | float3 * raydirim, 18 | float stepsize, 19 | float2 * tminmaxim, 20 | int * sortedobjid, 21 | int2 * nodechildren, 22 | float3 * nodeaabb, 23 | typename PrimTransfT::Data primtransf_data, 24 | typename PrimSamplerT::Data primsampler_data, 25 | typename PrimAccumT::Data primaccum_data 26 | ) { 27 | int w = blockIdx.x * blockDim.x + threadIdx.x; 28 | int h = blockIdx.y * blockDim.y + threadIdx.y; 29 | int n = blockIdx.z; 30 | bool validthread = (w < W) && (h < H) && (n 0 ? 1 : maxhitboxes]; 55 | __shared__ int hitboxes_sh[nwarps > 0 ? maxhitboxes * nwarps : 1]; 56 | int * hitboxes_ptr = nwarps > 0 ? hitboxes_sh + maxhitboxes * warpid : hitboxes; 57 | int nhitboxes = 0; 58 | 59 | // find raytminmax 60 | float2 rtminmax = make_float2(std::numeric_limits::infinity(), -std::numeric_limits::infinity()); 61 | RaySubsetT::forward(warpmask, K, raypos, raydir, tminmax, rtminmax, 62 | sortedobjid, nodechildren, nodeaabb, 63 | primtransf_data, hitboxes_ptr, nhitboxes); 64 | rtminmax.x = max(rtminmax.x, tminmax.x); 65 | rtminmax.y = min(rtminmax.y, tminmax.y); 66 | __syncwarp(warpmask); 67 | 68 | float t = tminmax.x; 69 | raypos = raypos + raydir * tminmax.x; 70 | 71 | int incs = floor((rtminmax.x - t) / stepsize); 72 | t += incs * stepsize; 73 | raypos += raydir * incs * stepsize; 74 | 75 | PrimAccumT pa; 76 | 77 | while (!__all_sync(warpmask, t > rtminmax.y + 1e-5f || pa.is_done())) { 78 | for (int ks = 0; ks < nhitboxes; ++ks) { 79 | int k = hitboxes_ptr[ks]; 80 | 81 | // compute primitive-relative coordinate 82 | PrimTransfT pt; 83 | float3 samplepos = pt.forward(primtransf_data, k, raypos); 84 | 85 | if (pt.valid(samplepos) && !pa.is_done() && t < rtminmax.y + 1e-5f) { 86 | // sample 87 | PrimSamplerT ps; 88 | float4 sample = ps.forward(primsampler_data, k, samplepos); 89 | 90 | // accumulate 91 | pa.forward_prim(primaccum_data, sample, stepsize); 92 | } 93 | } 94 | 95 | // update position 96 | t += stepsize; 97 | raypos += raydir * stepsize; 98 | } 99 | 100 | pa.write(primaccum_data); 101 | } 102 | 103 | template < 104 | bool forwarddir, 105 | int maxhitboxes, 106 | int nwarps, 107 | class RaySubsetT=RaySubsetFixedBVH, 108 | class PrimTransfT=PrimTransfSRT, 109 | class PrimSamplerT=PrimSamplerTW, 110 | class PrimAccumT=PrimAccumAdditive> 111 | __global__ void raymarch_subset_backward_kernel( 112 | int N, int H, int W, int K, 113 | float3 * rayposim, 114 | float3 * raydirim, 115 | float stepsize, 116 | float2 * tminmaxim, 117 | int * sortedobjid, 118 | int2 * nodechildren, 119 | float3 * nodeaabb, 120 | typename PrimTransfT::Data primtransf_data, 121 | typename PrimSamplerT::Data primsampler_data, 122 | typename PrimAccumT::Data primaccum_data 123 | ) { 124 | int w = blockIdx.x * blockDim.x + threadIdx.x; 125 | int h = blockIdx.y * blockDim.y + threadIdx.y; 126 | int n = blockIdx.z; 127 | bool validthread = (w < W) && (h < H) && (n 0 ? 1 : maxhitboxes]; 155 | __shared__ int hitboxes_sh[nwarps > 0 ? maxhitboxes * nwarps : 1]; 156 | int * hitboxes_ptr = nwarps > 0 ? 
hitboxes_sh + maxhitboxes * warpid : hitboxes; 157 | int nhitboxes = 0; 158 | 159 | // find raytminmax 160 | float2 rtminmax = make_float2(std::numeric_limits::infinity(), -std::numeric_limits::infinity()); 161 | RaySubsetT::forward(warpmask, K, raypos, raydir, tminmax, rtminmax, 162 | sortedobjid, nodechildren, nodeaabb, 163 | primtransf_data, hitboxes_ptr, nhitboxes); 164 | rtminmax.x = max(rtminmax.x, tminmax.x); 165 | rtminmax.y = min(rtminmax.y, tminmax.y); 166 | __syncwarp(warpmask); 167 | 168 | // set up raymarching position 169 | float t = tminmax.x; 170 | raypos = raypos + raydir * tminmax.x; 171 | 172 | int incs = floor((rtminmax.x - t) / stepsize); 173 | t += incs * stepsize; 174 | raypos += raydir * incs * stepsize; 175 | 176 | if (!forwarddir) { 177 | int nsteps = pa.get_nsteps(); 178 | t += nsteps * stepsize; 179 | raypos += raydir * nsteps * stepsize; 180 | } 181 | 182 | while (__any_sync(warpmask, ( 183 | (forwarddir && t < rtminmax.y + 1e-5f || 184 | !forwarddir && t > rtminmax.x - 1e-5f) && 185 | !pa.is_done()))) { 186 | for (int ks = 0; ks < nhitboxes; ++ks) { 187 | int k = hitboxes_ptr[forwarddir ? ks : nhitboxes - ks - 1]; 188 | 189 | PrimTransfT pt; 190 | float3 samplepos = pt.forward(primtransf_data, k, raypos); 191 | 192 | bool evalprim = pt.valid(samplepos) && !pa.is_done() && t < rtminmax.y + 1e-5f; 193 | 194 | float3 dL_samplepos = make_float3(0.f); 195 | if (evalprim) { 196 | PrimSamplerT ps; 197 | float4 sample = ps.forward(primsampler_data, k, samplepos); 198 | 199 | float4 dL_sample = pa.forwardbackward_prim(primaccum_data, sample, stepsize); 200 | 201 | dL_samplepos = ps.backward(primsampler_data, k, samplepos, sample, dL_sample, validthread); 202 | } 203 | 204 | if (__any_sync(warpmask, evalprim)) { 205 | pt.backward(primtransf_data, k, samplepos, dL_samplepos, validthread && evalprim); 206 | } 207 | } 208 | 209 | if (forwarddir) { 210 | t += stepsize; 211 | raypos += raydir * stepsize; 212 | } else { 213 | t -= stepsize; 214 | raypos -= raydir * stepsize; 215 | } 216 | } 217 | } 218 | 219 | -------------------------------------------------------------------------------- /dva/mvp/extensions/mvpraymarch/primaccum.h: -------------------------------------------------------------------------------- 1 | // Copyright (c) Meta Platforms, Inc. and affiliates. 2 | // All rights reserved. 3 | // 4 | // This source code is licensed under the license found in the 5 | // LICENSE file in the root directory of this source tree. 
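The marching loop in raymarch_subset_forward_kernel above proceeds per warp: BVH traversal narrows the ray's [tmin, tmax] to [rtmin, rtmax] and collects candidate primitives, the start point is advanced a whole number of steps onto the shared step grid, and the ray then steps forward, sampling and compositing every hit primitive, until all lanes are past their exit point or saturated. A simplified single-ray version of that control flow; sample_prim and accumulate are placeholder callbacks, not kernel symbols:

def march_ray(raypos, raydir, tminmax, rtminmax, hitboxes, stepsize, sample_prim, accumulate):
    tmin, tmax = tminmax
    rtmin, rtmax = max(rtminmax[0], tmin), min(rtminmax[1], tmax)

    # start at tmin, then jump forward a whole number of steps toward rtmin
    t = tmin
    pos = [raypos[i] + raydir[i] * tmin for i in range(3)]
    incs = int((rtmin - t) // stepsize)
    t += incs * stepsize
    pos = [pos[i] + raydir[i] * incs * stepsize for i in range(3)]

    done = False
    while not done and t <= rtmax + 1e-5:
        for k in hitboxes:
            sample = sample_prim(k, pos)             # None when pos is outside primitive k
            if sample is not None and not done:
                done = accumulate(sample, stepsize)  # returns True once the ray saturates
        t += stepsize
        pos = [pos[i] + raydir[i] * stepsize for i in range(3)]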
6 | 7 | #ifndef MVPRAYMARCHER_PRIMACCUM_H_ 8 | #define MVPRAYMARCHER_PRIMACCUM_H_ 9 | 10 | struct PrimAccumDataBase { 11 | typedef PrimAccumDataBase base; 12 | }; 13 | 14 | struct PrimAccumAdditive { 15 | struct Data : public PrimAccumDataBase { 16 | float termthresh; 17 | 18 | int nstride, hstride, wstride; 19 | float4 * rayrgbaim; 20 | float4 * grad_rayrgbaim; 21 | float3 * raysatim; 22 | 23 | __forceinline__ __device__ void n_stride(int n, int h, int w) { 24 | rayrgbaim += n * nstride + h * hstride + w * wstride; 25 | grad_rayrgbaim += n * nstride + h * hstride + w * wstride; 26 | if (raysatim) { 27 | raysatim += n * nstride + h * hstride + w * wstride; 28 | } 29 | } 30 | }; 31 | 32 | float4 rayrgba; 33 | float3 raysat; 34 | bool sat; 35 | float4 dL_rayrgba; 36 | 37 | __forceinline__ __device__ PrimAccumAdditive() : 38 | rayrgba(make_float4(0.f)), 39 | raysat(make_float3(-1.f)), 40 | sat(false) { 41 | } 42 | 43 | __forceinline__ __device__ bool is_done() const { 44 | return sat; 45 | } 46 | 47 | __forceinline__ __device__ int get_nsteps() const { 48 | return 0; 49 | } 50 | 51 | __forceinline__ __device__ void write(const Data & data) { 52 | *data.rayrgbaim = rayrgba; 53 | if (data.raysatim) { 54 | *data.raysatim = raysat; 55 | } 56 | } 57 | 58 | __forceinline__ __device__ void read(const Data & data) { 59 | dL_rayrgba = *data.grad_rayrgbaim; 60 | raysat = *data.raysatim; 61 | } 62 | 63 | __forceinline__ __device__ void forward_prim(const Data & data, float4 sample, float stepsize) { 64 | // accumulate 65 | float3 rgb = make_float3(sample); 66 | float alpha = sample.w; 67 | float newalpha = rayrgba.w + alpha * stepsize; 68 | float contrib = fminf(newalpha, 1.f) - rayrgba.w; 69 | 70 | rayrgba += make_float4(rgb, 1.f) * contrib; 71 | 72 | if (newalpha >= 1.f) { 73 | // save saturation point 74 | if (!sat) { 75 | raysat = rgb; 76 | } 77 | sat = true; 78 | } 79 | } 80 | 81 | __forceinline__ __device__ float4 forwardbackward_prim(const Data & data, float4 sample, float stepsize) { 82 | float3 rgb = make_float3(sample); 83 | float4 rgb1 = make_float4(rgb, 1.f); 84 | sample.w *= stepsize; 85 | 86 | bool thissat = rayrgba.w + sample.w >= 1.f; 87 | sat = sat || thissat; 88 | 89 | float weight = sat ? (1.f - rayrgba.w) : sample.w; 90 | 91 | float3 dL_rgb = weight * make_float3(dL_rayrgba); 92 | float dL_alpha = sat ? 0.f : 93 | stepsize * dot(rgb1 - (raysat.x > -1.f ? make_float4(raysat, 1.f) : make_float4(0.f)), dL_rayrgba); 94 | 95 | rayrgba += make_float4(rgb, 1.f) * weight; 96 | 97 | return make_float4(dL_rgb, dL_alpha); 98 | } 99 | }; 100 | 101 | #endif 102 | -------------------------------------------------------------------------------- /dva/mvp/extensions/mvpraymarch/primsampler.h: -------------------------------------------------------------------------------- 1 | // Copyright (c) Meta Platforms, Inc. and affiliates. 2 | // All rights reserved. 3 | // 4 | // This source code is licensed under the license found in the 5 | // LICENSE file in the root directory of this source tree. 
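PrimAccumAdditive above composites samples front to back with a hard clamp at alpha = 1: each sample contributes contrib = min(alpha + a*stepsize, 1) - alpha, and the color of the sample that pushes alpha past 1 is stored in raysat so the backward pass can recover the saturation point. A small single-ray sketch of the forward rule:

def accum_additive_forward(samples, stepsize):
    # samples: (r, g, b, a) tuples along the ray, front to back
    rgba = [0.0, 0.0, 0.0, 0.0]
    raysat = None
    for r, g, b, a in samples:
        newalpha = rgba[3] + a * stepsize
        contrib = min(newalpha, 1.0) - rgba[3]
        rgba = [rgba[0] + r * contrib, rgba[1] + g * contrib,
                rgba[2] + b * contrib, rgba[3] + contrib]
        if newalpha >= 1.0:
            raysat = (r, g, b)   # remember the saturating sample's color
            break                # is_done(): saturated rays stop accumulating
    return rgba, raysat

print(accum_additive_forward([(1, 0, 0, 3.0), (0, 1, 0, 3.0)], stepsize=0.25))
# ([0.75, 0.25, 0.0, 1.0], (0, 1, 0))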
6 | 7 | #ifndef MVPRAYMARCHER_PRIMSAMPLER_H_ 8 | #define MVPRAYMARCHER_PRIMSAMPLER_H_ 9 | 10 | struct PrimSamplerDataBase { 11 | typedef PrimSamplerDataBase base; 12 | }; 13 | 14 | template< 15 | bool dowarp, 16 | template class GridSamplerT=GridSamplerChlast> 17 | struct PrimSamplerTW { 18 | struct Data : public PrimSamplerDataBase { 19 | float fadescale, fadeexp; 20 | 21 | int tplate_nstride; 22 | int TD, TH, TW; 23 | float * tplate; 24 | float * grad_tplate; 25 | 26 | int warp_nstride; 27 | int WD, WH, WW; 28 | float * warp; 29 | float * grad_warp; 30 | 31 | __forceinline__ __device__ void n_stride(int n) { 32 | tplate += n * tplate_nstride; 33 | grad_tplate += n * tplate_nstride; 34 | warp += n * warp_nstride; 35 | grad_warp += n * warp_nstride; 36 | } 37 | }; 38 | 39 | float fade; 40 | float * tplate_ptr; 41 | float * warp_ptr; 42 | float3 yy1; 43 | 44 | __forceinline__ __device__ float4 forward( 45 | const Data & data, 46 | int k, 47 | float3 y0) { 48 | fade = __expf(-data.fadescale * ( 49 | __powf(abs(y0.x), data.fadeexp) + 50 | __powf(abs(y0.y), data.fadeexp) + 51 | __powf(abs(y0.z), data.fadeexp))); 52 | 53 | if (dowarp) { 54 | warp_ptr = data.warp + (k * 3 * data.WD * data.WH * data.WW); 55 | yy1 = GridSamplerT::forward(3, data.WD, data.WH, data.WW, warp_ptr, y0, false); 56 | } else { 57 | yy1 = y0; 58 | } 59 | 60 | tplate_ptr = data.tplate + (k * 4 * data.TD * data.TH * data.TW); 61 | float4 sample = GridSamplerT::forward(4, data.TD, data.TH, data.TW, tplate_ptr, yy1, false); 62 | 63 | sample.w *= fade; 64 | 65 | return sample; 66 | } 67 | 68 | __forceinline__ __device__ float3 backward(const Data & data, int k, float3 y0, 69 | float4 sample, float4 dL_sample, bool validthread) { 70 | float3 dfade_y0 = -(data.fadescale * data.fadeexp) * make_float3( 71 | __powf(abs(y0.x), data.fadeexp - 1.f) * (y0.x > 0.f ? 1.f : -1.f), 72 | __powf(abs(y0.y), data.fadeexp - 1.f) * (y0.y > 0.f ? 1.f : -1.f), 73 | __powf(abs(y0.z), data.fadeexp - 1.f) * (y0.z > 0.f ? 1.f : -1.f)); 74 | float3 dL_y0 = dfade_y0 * sample.w * dL_sample.w; 75 | 76 | dL_sample.w *= fade; 77 | 78 | float * grad_tplate_ptr = data.grad_tplate + (k * 4 * data.TD * data.TH * data.TW); 79 | float3 dL_y1 = GridSamplerT::backward(4, data.TD, data.TH, data.TW, 80 | tplate_ptr, grad_tplate_ptr, yy1, validthread ? dL_sample : make_float4(0.f), false); 81 | 82 | if (dowarp) { 83 | float * grad_warp_ptr = data.grad_warp + (k * 3 * data.WD * data.WH * data.WW); 84 | dL_y0 += GridSamplerT::backward(3, data.WD, data.WH, data.WW, 85 | warp_ptr, grad_warp_ptr, y0, validthread ? dL_y1 : make_float3(0.f), false); 86 | } else { 87 | dL_y0 += dL_y1; 88 | } 89 | 90 | return dL_y0; 91 | } 92 | }; 93 | 94 | #endif 95 | -------------------------------------------------------------------------------- /dva/mvp/extensions/mvpraymarch/primtransf.h: -------------------------------------------------------------------------------- 1 | // Copyright (c) Meta Platforms, Inc. and affiliates. 2 | // All rights reserved. 3 | // 4 | // This source code is licensed under the license found in the 5 | // LICENSE file in the root directory of this source tree. 
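PrimSamplerTW above windows each primitive's opacity with fade = exp(-fadescale * (|x|^fadeexp + |y|^fadeexp + |z|^fadeexp)) in primitive-local coordinates, optionally redirects the sample position through a warp volume, and then trilinearly samples the RGBA template. A hedged PyTorch sketch of the same idea, using F.grid_sample in place of the hand-written channels-last sampler (the fadescale/fadeexp defaults are illustrative):

import torch
import torch.nn.functional as F

def sample_primitive(template, warp, y0, fadescale=8.0, fadeexp=8.0):
    # template: [1, 4, TD, TH, TW] RGBA volume; warp: [1, 3, WD, WH, WW] or None
    # y0: length-3 tensor, a primitive-local coordinate in [-1, 1]^3
    fade = torch.exp(-fadescale * (y0.abs() ** fadeexp).sum())
    grid = y0.view(1, 1, 1, 1, 3)
    if warp is not None:
        grid = F.grid_sample(warp, grid, align_corners=False).view(1, 1, 1, 1, 3)
    rgba = F.grid_sample(template, grid, align_corners=False).view(4)
    return torch.cat([rgba[:3], rgba[3:] * fade])   # fade only attenuates opacity

rgba = sample_primitive(torch.rand(1, 4, 8, 8, 8), None, torch.tensor([0.1, -0.2, 0.3]))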
6 | 7 | #ifndef MVPRAYMARCHER_PRIMTRANSF_H_ 8 | #define MVPRAYMARCHER_PRIMTRANSF_H_ 9 | 10 | #include "utils.h" 11 | 12 | __forceinline__ __device__ void compute_aabb_srt( 13 | float3 pt, float3 pr0, float3 pr1, float3 pr2, float3 ps, 14 | float3 & pmin, float3 & pmax) { 15 | float3 p; 16 | p = make_float3(-1.f, -1.f, -1.f) / ps; 17 | p = make_float3(dot(p, pr0), dot(p, pr1), dot(p, pr2)) + pt; 18 | 19 | pmin = p; 20 | pmax = p; 21 | 22 | p = make_float3(1.f, -1.f, -1.f) / ps; 23 | p = make_float3(dot(p, pr0), dot(p, pr1), dot(p, pr2)) + pt; 24 | 25 | pmin = fminf(pmin, p); 26 | pmax = fmaxf(pmax, p); 27 | 28 | p = make_float3(-1.f, 1.f, -1.f) / ps; 29 | p = make_float3(dot(p, pr0), dot(p, pr1), dot(p, pr2)) + pt; 30 | 31 | pmin = fminf(pmin, p); 32 | pmax = fmaxf(pmax, p); 33 | 34 | p = make_float3(1.f, 1.f, -1.f) / ps; 35 | p = make_float3(dot(p, pr0), dot(p, pr1), dot(p, pr2)) + pt; 36 | 37 | pmin = fminf(pmin, p); 38 | pmax = fmaxf(pmax, p); 39 | 40 | p = make_float3(-1.f, -1.f, 1.f) / ps; 41 | p = make_float3(dot(p, pr0), dot(p, pr1), dot(p, pr2)) + pt; 42 | 43 | pmin = fminf(pmin, p); 44 | pmax = fmaxf(pmax, p); 45 | 46 | p = make_float3(1.f, -1.f, 1.f) / ps; 47 | p = make_float3(dot(p, pr0), dot(p, pr1), dot(p, pr2)) + pt; 48 | 49 | pmin = fminf(pmin, p); 50 | pmax = fmaxf(pmax, p); 51 | 52 | p = make_float3(-1.f, 1.f, 1.f) / ps; 53 | p = make_float3(dot(p, pr0), dot(p, pr1), dot(p, pr2)) + pt; 54 | 55 | pmin = fminf(pmin, p); 56 | pmax = fmaxf(pmax, p); 57 | 58 | p = make_float3(1.f, 1.f, 1.f) / ps; 59 | p = make_float3(dot(p, pr0), dot(p, pr1), dot(p, pr2)) + pt; 60 | 61 | pmin = fminf(pmin, p); 62 | pmax = fmaxf(pmax, p); 63 | } 64 | 65 | struct PrimTransfDataBase { 66 | typedef PrimTransfDataBase base; 67 | }; 68 | 69 | struct PrimTransfSRT { 70 | struct Data : public PrimTransfDataBase { 71 | int primpos_nstride; 72 | float3 * primpos; 73 | float3 * grad_primpos; 74 | int primrot_nstride; 75 | float3 * primrot; 76 | float3 * grad_primrot; 77 | int primscale_nstride; 78 | float3 * primscale; 79 | float3 * grad_primscale; 80 | 81 | __forceinline__ __device__ void n_stride(int n) { 82 | primpos += n * primpos_nstride; 83 | grad_primpos += n * primpos_nstride; 84 | primrot += n * primrot_nstride; 85 | grad_primrot += n * primrot_nstride; 86 | primscale += n * primscale_nstride; 87 | grad_primscale += n * primscale_nstride; 88 | } 89 | 90 | __forceinline__ __device__ float3 get_center(int n, int k) { 91 | return primpos[n * primpos_nstride + k]; 92 | } 93 | 94 | __forceinline__ __device__ void compute_aabb(int n, int k, float3 & pmin, float3 & pmax) { 95 | float3 pt = primpos[n * primpos_nstride + k]; 96 | float3 pr0 = primrot[n * primrot_nstride + k * 3 + 0]; 97 | float3 pr1 = primrot[n * primrot_nstride + k * 3 + 1]; 98 | float3 pr2 = primrot[n * primrot_nstride + k * 3 + 2]; 99 | float3 ps = primscale[n * primscale_nstride + k]; 100 | 101 | compute_aabb_srt(pt, pr0, pr1, pr2, ps, pmin, pmax); 102 | } 103 | }; 104 | 105 | float3 xmt; 106 | float3 pr0; 107 | float3 pr1; 108 | float3 pr2; 109 | float3 rxmt; 110 | float3 ps; 111 | 112 | static __forceinline__ __device__ bool valid(float3 pos) { 113 | return ( 114 | pos.x > -1.f && pos.x < 1.f && 115 | pos.y > -1.f && pos.y < 1.f && 116 | pos.z > -1.f && pos.z < 1.f); 117 | } 118 | 119 | __forceinline__ __device__ float3 forward( 120 | const Data & data, 121 | int k, 122 | float3 x) { 123 | float3 pt = data.primpos[k]; 124 | pr0 = data.primrot[(k) * 3 + 0]; 125 | pr1 = data.primrot[(k) * 3 + 1]; 126 | pr2 = data.primrot[(k) * 3 + 2]; 
127 | ps = data.primscale[k]; 128 | xmt = x - pt; 129 | rxmt = pr0 * xmt.x + pr1 * xmt.y + pr2 * xmt.z; 130 | float3 y0 = rxmt * ps; 131 | return y0; 132 | } 133 | 134 | static __forceinline__ __device__ void forward2( 135 | const Data & data, 136 | int k, 137 | float3 r, float3 d, float3 & rout, float3 & dout) { 138 | float3 pt = data.primpos[k]; 139 | float3 pr0 = data.primrot[k * 3 + 0]; 140 | float3 pr1 = data.primrot[k * 3 + 1]; 141 | float3 pr2 = data.primrot[k * 3 + 2]; 142 | float3 ps = data.primscale[k]; 143 | float3 xmt = r - pt; 144 | float3 dmt = d; 145 | float3 rxmt = pr0 * xmt.x; 146 | float3 rdmt = pr0 * dmt.x; 147 | rxmt += pr1 * xmt.y; 148 | rdmt += pr1 * dmt.y; 149 | rxmt += pr2 * xmt.z; 150 | rdmt += pr2 * dmt.z; 151 | rout = rxmt * ps; 152 | dout = rdmt * ps; 153 | } 154 | 155 | __forceinline__ __device__ void backward(const Data & data, int k, float3 x, float3 dL_y0, bool validthread) { 156 | fastAtomicAdd((float*)data.grad_primscale + k * 3 + 0, validthread ? rxmt.x * dL_y0.x : 0.f); 157 | fastAtomicAdd((float*)data.grad_primscale + k * 3 + 1, validthread ? rxmt.y * dL_y0.y : 0.f); 158 | fastAtomicAdd((float*)data.grad_primscale + k * 3 + 2, validthread ? rxmt.z * dL_y0.z : 0.f); 159 | 160 | dL_y0 *= ps; 161 | float3 gpr0 = xmt.x * dL_y0; 162 | fastAtomicAdd((float*)data.grad_primrot + (k * 3 + 0) * 3 + 0, validthread ? gpr0.x : 0.f); 163 | fastAtomicAdd((float*)data.grad_primrot + (k * 3 + 0) * 3 + 1, validthread ? gpr0.y : 0.f); 164 | fastAtomicAdd((float*)data.grad_primrot + (k * 3 + 0) * 3 + 2, validthread ? gpr0.z : 0.f); 165 | 166 | float3 gpr1 = xmt.y * dL_y0; 167 | fastAtomicAdd((float*)data.grad_primrot + (k * 3 + 1) * 3 + 0, validthread ? gpr1.x : 0.f); 168 | fastAtomicAdd((float*)data.grad_primrot + (k * 3 + 1) * 3 + 1, validthread ? gpr1.y : 0.f); 169 | fastAtomicAdd((float*)data.grad_primrot + (k * 3 + 1) * 3 + 2, validthread ? gpr1.z : 0.f); 170 | 171 | float3 gpr2 = xmt.z * dL_y0; 172 | fastAtomicAdd((float*)data.grad_primrot + (k * 3 + 2) * 3 + 0, validthread ? gpr2.x : 0.f); 173 | fastAtomicAdd((float*)data.grad_primrot + (k * 3 + 2) * 3 + 1, validthread ? gpr2.y : 0.f); 174 | fastAtomicAdd((float*)data.grad_primrot + (k * 3 + 2) * 3 + 2, validthread ? gpr2.z : 0.f); 175 | 176 | fastAtomicAdd((float*)data.grad_primpos + k * 3 + 0, validthread ? -dot(pr0, dL_y0) : 0.f); 177 | fastAtomicAdd((float*)data.grad_primpos + k * 3 + 1, validthread ? -dot(pr1, dL_y0) : 0.f); 178 | fastAtomicAdd((float*)data.grad_primpos + k * 3 + 2, validthread ? -dot(pr2, dL_y0) : 0.f); 179 | } 180 | }; 181 | 182 | #endif 183 | -------------------------------------------------------------------------------- /dva/mvp/extensions/mvpraymarch/setup.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
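PrimTransfSRT above maps a world-space point into a primitive's unit cube: subtract the primitive center, apply the stored rotation as pr0*dx + pr1*dy + pr2*dz (the transpose of the row-stored 3x3 matrix), then scale per axis; a sample is only evaluated when the result lies inside (-1, 1)^3. A small NumPy sketch of forward() and valid():

import numpy as np

def prim_transform_srt(x, primpos, primrot, primscale, k):
    # x: world point (3,); primpos: (K, 3); primrot: (K, 3, 3) row-stored; primscale: (K, 3)
    xmt = x - primpos[k]
    rxmt = primrot[k].T @ xmt     # equals pr0 * xmt.x + pr1 * xmt.y + pr2 * xmt.z
    return rxmt * primscale[k]    # per-axis scale into the unit cube

def valid(y0):
    return bool(np.all(np.abs(y0) < 1.0))

y0 = prim_transform_srt(np.array([0.1, 0.2, 0.3]),
                        np.zeros((8, 3)), np.tile(np.eye(3), (8, 1, 1)), np.ones((8, 3)), k=0)
print(y0, valid(y0))    # [0.1 0.2 0.3] True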
6 | 7 | from setuptools import setup 8 | 9 | from torch.utils.cpp_extension import CUDAExtension, BuildExtension 10 | 11 | if __name__ == "__main__": 12 | import torch 13 | setup( 14 | name="mvpraymarch", 15 | ext_modules=[ 16 | CUDAExtension( 17 | "mvpraymarchlib", 18 | sources=["mvpraymarch.cpp", "mvpraymarch_kernel.cu", "bvh.cu"], 19 | extra_compile_args={ 20 | "nvcc": [ 21 | "-use_fast_math", 22 | "-arch=sm_70", 23 | "-std=c++17", 24 | "-lineinfo", 25 | ] 26 | } 27 | ) 28 | ], 29 | cmdclass={"build_ext": BuildExtension} 30 | ) 31 | -------------------------------------------------------------------------------- /dva/mvp/extensions/utils/makefile: -------------------------------------------------------------------------------- 1 | all: 2 | python setup.py build_ext --inplace 3 | -------------------------------------------------------------------------------- /dva/mvp/extensions/utils/setup.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | from setuptools import setup 8 | 9 | from torch.utils.cpp_extension import CUDAExtension, BuildExtension 10 | 11 | if __name__ == "__main__": 12 | import torch 13 | setup( 14 | name="utils", 15 | ext_modules=[ 16 | CUDAExtension( 17 | "utilslib", 18 | sources=["utils.cpp", "utils_kernel.cu"], 19 | extra_compile_args={ 20 | "nvcc": [ 21 | "-arch=sm_70", 22 | "-std=c++14", 23 | "-lineinfo", 24 | ] 25 | } 26 | ) 27 | ], 28 | cmdclass={"build_ext": BuildExtension} 29 | ) 30 | -------------------------------------------------------------------------------- /dva/mvp/extensions/utils/utils.cpp: -------------------------------------------------------------------------------- 1 | // Copyright (c) Meta Platforms, Inc. and affiliates. 2 | // All rights reserved. 3 | // 4 | // This source code is licensed under the license found in the 5 | // LICENSE file in the root directory of this source tree. 
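Both setup.py files above compile their extension in place (python setup.py build_ext --inplace, driven by the makefiles and install.sh), after which mvpraymarchlib and utilslib are imported like ordinary modules. A quick hedged check that the built libraries are importable, assuming their build directories are on sys.path:

import importlib

for name in ("mvpraymarchlib", "utilslib"):   # module names from the CUDAExtension(...) declarations above
    try:
        importlib.import_module(name)
        print(name, "ok")
    except ImportError as exc:
        print(name, "not built:", exc)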
6 | 7 | #include 8 | #include 9 | 10 | #include 11 | 12 | void compute_raydirs_forward_cuda( 13 | int N, int H, int W, 14 | float * viewposim, 15 | float * viewrotim, 16 | float * focalim, 17 | float * princptim, 18 | float * pixelcoordsim, 19 | float volradius, 20 | float * raypos, 21 | float * raydir, 22 | float * tminmax, 23 | cudaStream_t stream); 24 | 25 | void compute_raydirs_backward_cuda( 26 | int N, int H, int W, 27 | float * viewposim, 28 | float * viewrotim, 29 | float * focalim, 30 | float * princptim, 31 | float * pixelcoordsim, 32 | float volradius, 33 | float * raypos, 34 | float * raydir, 35 | float * tminmax, 36 | float * grad_viewposim, 37 | float * grad_viewrotim, 38 | float * grad_focalim, 39 | float * grad_princptim, 40 | cudaStream_t stream); 41 | 42 | #define CHECK_CUDA(x) AT_ASSERTM(x.is_cuda(), #x " must be a CUDA tensor") 43 | #define CHECK_CONTIGUOUS(x) AT_ASSERTM(x.is_contiguous(), #x " must be contiguous") 44 | #define CHECK_INPUT(x) CHECK_CUDA((x)); CHECK_CONTIGUOUS((x)) 45 | 46 | std::vector compute_raydirs_forward( 47 | torch::Tensor viewposim, 48 | torch::Tensor viewrotim, 49 | torch::Tensor focalim, 50 | torch::Tensor princptim, 51 | torch::optional pixelcoordsim, 52 | int W, int H, 53 | float volradius, 54 | torch::Tensor rayposim, 55 | torch::Tensor raydirim, 56 | torch::Tensor tminmaxim) { 57 | CHECK_INPUT(viewposim); 58 | CHECK_INPUT(viewrotim); 59 | CHECK_INPUT(focalim); 60 | CHECK_INPUT(princptim); 61 | if (pixelcoordsim) { CHECK_INPUT(*pixelcoordsim); } 62 | CHECK_INPUT(rayposim); 63 | CHECK_INPUT(raydirim); 64 | CHECK_INPUT(tminmaxim); 65 | 66 | int N = viewposim.size(0); 67 | assert(!pixelcoordsim || (pixelcoordsim.size(1) == H && pixelcoordsim.size(2) == W)); 68 | 69 | compute_raydirs_forward_cuda(N, H, W, 70 | reinterpret_cast(viewposim.data_ptr()), 71 | reinterpret_cast(viewrotim.data_ptr()), 72 | reinterpret_cast(focalim.data_ptr()), 73 | reinterpret_cast(princptim.data_ptr()), 74 | pixelcoordsim ? reinterpret_cast(pixelcoordsim->data_ptr()) : nullptr, 75 | volradius, 76 | reinterpret_cast(rayposim.data_ptr()), 77 | reinterpret_cast(raydirim.data_ptr()), 78 | reinterpret_cast(tminmaxim.data_ptr()), 79 | 0); 80 | 81 | return {}; 82 | } 83 | 84 | std::vector compute_raydirs_backward( 85 | torch::Tensor viewposim, 86 | torch::Tensor viewrotim, 87 | torch::Tensor focalim, 88 | torch::Tensor princptim, 89 | torch::optional pixelcoordsim, 90 | int W, int H, 91 | float volradius, 92 | torch::Tensor rayposim, 93 | torch::Tensor raydirim, 94 | torch::Tensor tminmaxim, 95 | torch::Tensor grad_viewpos, 96 | torch::Tensor grad_viewrot, 97 | torch::Tensor grad_focal, 98 | torch::Tensor grad_princpt) { 99 | CHECK_INPUT(viewposim); 100 | CHECK_INPUT(viewrotim); 101 | CHECK_INPUT(focalim); 102 | CHECK_INPUT(princptim); 103 | if (pixelcoordsim) { CHECK_INPUT(*pixelcoordsim); } 104 | CHECK_INPUT(rayposim); 105 | CHECK_INPUT(raydirim); 106 | CHECK_INPUT(tminmaxim); 107 | CHECK_INPUT(grad_viewpos); 108 | CHECK_INPUT(grad_viewrot); 109 | CHECK_INPUT(grad_focal); 110 | CHECK_INPUT(grad_princpt); 111 | 112 | int N = viewposim.size(0); 113 | assert(!pixelcoordsim || (pixelcoordsim.size(1) == H && pixelcoordsim.size(2) == W)); 114 | 115 | compute_raydirs_backward_cuda(N, H, W, 116 | reinterpret_cast(viewposim.data_ptr()), 117 | reinterpret_cast(viewrotim.data_ptr()), 118 | reinterpret_cast(focalim.data_ptr()), 119 | reinterpret_cast(princptim.data_ptr()), 120 | pixelcoordsim ? 
reinterpret_cast(pixelcoordsim->data_ptr()) : nullptr, 121 | volradius, 122 | reinterpret_cast(rayposim.data_ptr()), 123 | reinterpret_cast(raydirim.data_ptr()), 124 | reinterpret_cast(tminmaxim.data_ptr()), 125 | reinterpret_cast(grad_viewpos.data_ptr()), 126 | reinterpret_cast(grad_viewrot.data_ptr()), 127 | reinterpret_cast(grad_focal.data_ptr()), 128 | reinterpret_cast(grad_princpt.data_ptr()), 129 | 0); 130 | 131 | return {}; 132 | } 133 | 134 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 135 | m.def("compute_raydirs_forward", &compute_raydirs_forward, "raydirs forward (CUDA)"); 136 | m.def("compute_raydirs_backward", &compute_raydirs_backward, "raydirs backward (CUDA)"); 137 | } 138 | -------------------------------------------------------------------------------- /dva/mvp/extensions/utils/utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import numpy as np 8 | import time 9 | 10 | import torch 11 | import torch.nn as nn 12 | from torch.autograd import Function 13 | import torch.nn.functional as F 14 | 15 | try: 16 | from . import utilslib 17 | except: 18 | import utilslib 19 | 20 | class ComputeRaydirs(Function): 21 | @staticmethod 22 | def forward(self, viewpos, viewrot, focal, princpt, pixelcoords, volradius): 23 | for tensor in [viewpos, viewrot, focal, princpt, pixelcoords]: 24 | assert tensor.is_contiguous() 25 | 26 | N = viewpos.size(0) 27 | if isinstance(pixelcoords, tuple): 28 | W, H = pixelcoords 29 | pixelcoords = None 30 | else: 31 | H = pixelcoords.size(1) 32 | W = pixelcoords.size(2) 33 | 34 | raypos = torch.empty((N, H, W, 3), device=viewpos.device) 35 | raydirs = torch.empty((N, H, W, 3), device=viewpos.device) 36 | tminmax = torch.empty((N, H, W, 2), device=viewpos.device) 37 | utilslib.compute_raydirs_forward(viewpos, viewrot, focal, princpt, 38 | pixelcoords, W, H, volradius, raypos, raydirs, tminmax) 39 | 40 | return raypos, raydirs, tminmax 41 | 42 | @staticmethod 43 | def backward(self, grad_raydirs, grad_tminmax): 44 | return None, None, None, None, None, None 45 | 46 | def compute_raydirs(viewpos, viewrot, focal, princpt, pixelcoords, volradius): 47 | raypos, raydirs, tminmax = ComputeRaydirs.apply(viewpos, viewrot, focal, princpt, pixelcoords, volradius) 48 | return raypos, raydirs, tminmax 49 | 50 | class Rodrigues(nn.Module): 51 | def __init__(self): 52 | super(Rodrigues, self).__init__() 53 | 54 | def forward(self, rvec): 55 | theta = torch.sqrt(1e-5 + torch.sum(rvec ** 2, dim=1)) 56 | rvec = rvec / theta[:, None] 57 | costh = torch.cos(theta) 58 | sinth = torch.sin(theta) 59 | return torch.stack(( 60 | rvec[:, 0] ** 2 + (1. - rvec[:, 0] ** 2) * costh, 61 | rvec[:, 0] * rvec[:, 1] * (1. - costh) - rvec[:, 2] * sinth, 62 | rvec[:, 0] * rvec[:, 2] * (1. - costh) + rvec[:, 1] * sinth, 63 | 64 | rvec[:, 0] * rvec[:, 1] * (1. - costh) + rvec[:, 2] * sinth, 65 | rvec[:, 1] ** 2 + (1. - rvec[:, 1] ** 2) * costh, 66 | rvec[:, 1] * rvec[:, 2] * (1. - costh) - rvec[:, 0] * sinth, 67 | 68 | rvec[:, 0] * rvec[:, 2] * (1. - costh) - rvec[:, 1] * sinth, 69 | rvec[:, 1] * rvec[:, 2] * (1. - costh) + rvec[:, 0] * sinth, 70 | rvec[:, 2] ** 2 + (1. 
- rvec[:, 2] ** 2) * costh), dim=1).view(-1, 3, 3) 71 | 72 | def gradcheck(): 73 | N = 2 74 | H = 64 75 | W = 64 76 | k3 = 4 77 | K = k3*k3*k3 78 | 79 | M = 32 80 | volradius = 1. 81 | 82 | # generate random inputs 83 | torch.manual_seed(1113) 84 | 85 | rodrigues = Rodrigues() 86 | 87 | _viewpos = torch.tensor([[-0.0, 0.0, -4.] for n in range(N)], device="cuda") + torch.randn(N, 3, device="cuda") * 0.1 88 | viewrvec = torch.randn(N, 3, device="cuda") * 0.01 89 | _viewrot = rodrigues(viewrvec) 90 | 91 | _focal = torch.tensor([[W*4.0, W*4.0] for n in range(N)], device="cuda") 92 | _princpt = torch.tensor([[W*0.5, H*0.5] for n in range(N)], device="cuda") 93 | pixely, pixelx = torch.meshgrid(torch.arange(H, device="cuda").float(), torch.arange(W, device="cuda").float()) 94 | _pixelcoords = torch.stack([pixelx, pixely], dim=-1)[None, :, :, :].repeat(N, 1, 1, 1) 95 | 96 | _viewpos = _viewpos.contiguous().detach().clone() 97 | _viewpos.requires_grad = True 98 | _viewrot = _viewrot.contiguous().detach().clone() 99 | _viewrot.requires_grad = True 100 | _focal = _focal.contiguous().detach().clone() 101 | _focal.requires_grad = True 102 | _princpt = _princpt.contiguous().detach().clone() 103 | _princpt.requires_grad = True 104 | _pixelcoords = _pixelcoords.contiguous().detach().clone() 105 | _pixelcoords.requires_grad = True 106 | 107 | max_len = 6.0 108 | _stepsize = max_len / 15.5 109 | 110 | params = [_viewpos, _viewrot, _focal, _princpt] 111 | paramnames = ["viewpos", "viewrot", "focal", "princpt"] 112 | 113 | ########################### run pytorch version ########################### 114 | 115 | viewpos = _viewpos 116 | viewrot = _viewrot 117 | focal = _focal 118 | princpt = _princpt 119 | pixelcoords = _pixelcoords 120 | 121 | raypos = viewpos[:, None, None, :].repeat(1, H, W, 1) 122 | 123 | raydir = (pixelcoords - princpt[:, None, None, :]) / focal[:, None, None, :] 124 | raydir = torch.cat([raydir, torch.ones_like(raydir[:, :, :, 0:1])], dim=-1) 125 | raydir = torch.sum(viewrot[:, None, None, :, :] * raydir[:, :, :, :, None], dim=-2) 126 | raydir = raydir / torch.sqrt(torch.sum(raydir ** 2, dim=-1, keepdim=True)) 127 | 128 | t1 = (-1. - viewpos[:, None, None, :]) / raydir 129 | t2 = ( 1. - viewpos[:, None, None, :]) / raydir 130 | tmin = torch.max(torch.min(t1[..., 0], t2[..., 0]), 131 | torch.max(torch.min(t1[..., 1], t2[..., 1]), 132 | torch.min(t1[..., 2], t2[..., 2]))).clamp(min=0.) 
133 | tmax = torch.min(torch.max(t1[..., 0], t2[..., 0]), 134 | torch.min(torch.max(t1[..., 1], t2[..., 1]), 135 | torch.max(t1[..., 2], t2[..., 2]))) 136 | 137 | tminmax = torch.stack([tmin, tmax], dim=-1) 138 | 139 | sample0 = raydir 140 | 141 | torch.cuda.synchronize() 142 | time1 = time.time() 143 | 144 | sample0.backward(torch.ones_like(sample0)) 145 | 146 | torch.cuda.synchronize() 147 | time2 = time.time() 148 | 149 | grads0 = [p.grad.detach().clone() if p.grad is not None else None for p in params] 150 | 151 | for p in params: 152 | if p.grad is not None: 153 | p.grad.detach_() 154 | p.grad.zero_() 155 | 156 | ############################## run cuda version ########################### 157 | 158 | viewpos = _viewpos 159 | viewrot = _viewrot 160 | focal = _focal 161 | princpt = _princpt 162 | pixelcoords = _pixelcoords 163 | 164 | niter = 1 165 | 166 | for p in params: 167 | if p.grad is not None: 168 | p.grad.detach_() 169 | p.grad.zero_() 170 | t0 = time.time() 171 | torch.cuda.synchronize() 172 | 173 | sample1 = compute_raydirs(viewpos, viewrot, focal, princpt, pixelcoords, volradius)[1] 174 | 175 | t1 = time.time() 176 | torch.cuda.synchronize() 177 | 178 | print("-----------------------------------------------------------------") 179 | print("{:>10} {:>10} {:>10} {:>10} {:>10} {:>10}".format("", "maxabsdiff", "dp", "index", "py", "cuda")) 180 | ind = torch.argmax(torch.abs(sample0 - sample1)) 181 | print("{:<10} {:>10.5} {:>10.5} {:>10} {:>10.5} {:>10.5}".format( 182 | "fwd", 183 | torch.max(torch.abs(sample0 - sample1)).item(), 184 | (torch.sum(sample0 * sample1) / torch.sqrt(torch.sum(sample0 * sample0) * torch.sum(sample1 * sample1))).item(), 185 | ind.item(), 186 | sample0.view(-1)[ind].item(), 187 | sample1.view(-1)[ind].item())) 188 | 189 | sample1.backward(torch.ones_like(sample1), retain_graph=True) 190 | 191 | torch.cuda.synchronize() 192 | t2 = time.time() 193 | 194 | 195 | print("{:<10} {:10.5} {:10.5} {:10.5}".format("time", tf / niter, tb / niter, (tf + tb) / niter)) 196 | grads1 = [p.grad.detach().clone() if p.grad is not None else None for p in params] 197 | 198 | ############# compare results ############# 199 | 200 | for p, g0, g1 in zip(paramnames, grads0, grads1): 201 | ind = torch.argmax(torch.abs(g0 - g1)) 202 | print("{:<10} {:>10.5} {:>10.5} {:>10} {:>10.5} {:>10.5}".format( 203 | p, 204 | torch.max(torch.abs(g0 - g1)).item(), 205 | (torch.sum(g0 * g1) / torch.sqrt(torch.sum(g0 * g0) * torch.sum(g1 * g1))).item(), 206 | ind.item(), 207 | g0.view(-1)[ind].item(), 208 | g1.view(-1)[ind].item())) 209 | 210 | if __name__ == "__main__": 211 | gradcheck() 212 | -------------------------------------------------------------------------------- /dva/mvp/extensions/utils/utils_kernel.cu: -------------------------------------------------------------------------------- 1 | // Copyright (c) Meta Platforms, Inc. and affiliates. 2 | // All rights reserved. 3 | // 4 | // This source code is licensed under the license found in the 5 | // LICENSE file in the root directory of this source tree. 
6 | 7 | #include 8 | #include 9 | #include 10 | #include 11 | #include 12 | 13 | #include "helper_math.h" 14 | 15 | __global__ void compute_raydirs_forward_kernel( 16 | int N, int H, int W, 17 | float3 * viewposim, 18 | float3 * viewrotim, 19 | float2 * focalim, 20 | float2 * princptim, 21 | float2 * pixelcoordsim, 22 | float volradius, 23 | float3 * rayposim, 24 | float3 * raydirim, 25 | float2 * tminmaxim 26 | ) { 27 | bool validthread = false; 28 | int w, h, n; 29 | w = blockIdx.x * blockDim.x + threadIdx.x; 30 | h = (blockIdx.y * blockDim.y + threadIdx.y)%H; 31 | n = (blockIdx.y * blockDim.y + threadIdx.y)/H; 32 | validthread = (w < W) && (h < H) && (n>>( 122 | N, H, W, 123 | reinterpret_cast(viewposim), 124 | reinterpret_cast(viewrotim), 125 | reinterpret_cast(focalim), 126 | reinterpret_cast(princptim), 127 | reinterpret_cast(pixelcoordsim), 128 | volradius, 129 | reinterpret_cast(rayposim), 130 | reinterpret_cast(raydirim), 131 | reinterpret_cast(tminmaxim)); 132 | } 133 | 134 | void compute_raydirs_backward_cuda( 135 | int N, int H, int W, 136 | float * viewposim, 137 | float * viewrotim, 138 | float * focalim, 139 | float * princptim, 140 | float * pixelcoordsim, 141 | float volradius, 142 | float * rayposim, 143 | float * raydirim, 144 | float * tminmaxim, 145 | float * grad_viewposim, 146 | float * grad_viewrotim, 147 | float * grad_focalim, 148 | float * grad_princptim, 149 | cudaStream_t stream) { 150 | int blocksizex = 16; 151 | int blocksizey = 16; 152 | dim3 blocksize(blocksizex, blocksizey); 153 | dim3 gridsize; 154 | gridsize = dim3( 155 | (W + blocksize.x - 1) / blocksize.x, 156 | (N*H + blocksize.y - 1) / blocksize.y); 157 | 158 | auto fn = compute_raydirs_backward_kernel; 159 | fn<<>>( 160 | N, H, W, 161 | reinterpret_cast(viewposim), 162 | reinterpret_cast(viewrotim), 163 | reinterpret_cast(focalim), 164 | reinterpret_cast(princptim), 165 | reinterpret_cast(pixelcoordsim), 166 | volradius, 167 | reinterpret_cast(rayposim), 168 | reinterpret_cast(raydirim), 169 | reinterpret_cast(tminmaxim), 170 | reinterpret_cast(grad_viewposim), 171 | reinterpret_cast(grad_viewrotim), 172 | reinterpret_cast(grad_focalim), 173 | reinterpret_cast(grad_princptim)); 174 | } 175 | -------------------------------------------------------------------------------- /dva/mvp/models/bg/lap.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
6 | from typing import Optional 7 | 8 | import torch 9 | import torch.nn as nn 10 | import torch.nn.functional as F 11 | 12 | import models.utils 13 | 14 | class ImageMod(nn.Module): 15 | def __init__(self, width, height, depth, buf=False): 16 | super(ImageMod, self).__init__() 17 | 18 | if buf: 19 | self.register_buffer("image", torch.randn(1, 3, depth, height, width) * 0.001, persistent=False) 20 | else: 21 | self.image = nn.Parameter(torch.randn(1, 3, depth, height, width) * 0.001) 22 | 23 | def forward(self, samplecoords): 24 | image = self.image.expand(samplecoords.size(0), -1, -1, -1, -1) 25 | return F.grid_sample(image, samplecoords, align_corners=True) 26 | 27 | class LapImage(nn.Module): 28 | def __init__(self, width, height, depth, levels, startlevel=0, buftop=False, align_corners=True): 29 | super(LapImage, self).__init__() 30 | 31 | self.width : int = int(width) 32 | self.height : int = int(height) 33 | self.levels = levels 34 | self.startlevel = startlevel 35 | self.align_corners = align_corners 36 | 37 | self.pyr = nn.ModuleList( 38 | [ImageMod(self.width // 2 ** i, self.height // 2 ** i, depth) 39 | for i in list(range(startlevel, levels - 1))[::-1]] + 40 | ([ImageMod(self.width, self.height, depth, buf=True)] if buftop else [])) 41 | self.pyr0 = ImageMod(self.width // 2 ** (levels - 1), self.height // 2 ** (levels - 1), depth) 42 | 43 | def forward(self, samplecoords): 44 | image = self.pyr0(samplecoords) 45 | 46 | for i, layer in enumerate(self.pyr): 47 | image = image + layer(samplecoords) 48 | 49 | return image 50 | 51 | class BGModel(nn.Module): 52 | def __init__(self, width, height, allcameras, bgdict=True, trainstart=0, 53 | levels=5, startlevel=0, buftop=False, align_corners=True): 54 | super(BGModel, self).__init__() 55 | 56 | self.allcameras = allcameras 57 | self.trainstart = trainstart 58 | 59 | if trainstart > -1: 60 | self.lap = LapImage(width, height, len(allcameras), levels=levels, 61 | startlevel=startlevel, buftop=buftop, 62 | align_corners=align_corners) 63 | 64 | def forward( 65 | self, 66 | bg : Optional[torch.Tensor]=None, 67 | camindex : Optional[torch.Tensor]=None, 68 | raypos : Optional[torch.Tensor]=None, 69 | rayposend : Optional[torch.Tensor]=None, 70 | raydir : Optional[torch.Tensor]=None, 71 | samplecoords : Optional[torch.Tensor]=None, 72 | trainiter : float=-1): 73 | if self.trainstart > -1 and trainiter >= self.trainstart and camindex is not None: 74 | assert samplecoords is not None 75 | assert camindex is not None 76 | 77 | samplecoordscam = torch.cat([ 78 | samplecoords[:, None, :, :, :], # [B, 1, H, W, 2] 79 | ((camindex[:, None, None, None, None] * 2.) / (len(self.allcameras) - 1.) - 1.) 80 | .expand(-1, -1, samplecoords.size(1), samplecoords.size(2), -1)], 81 | dim=-1) # [B, 1, H, W, 3] 82 | lap = self.lap(samplecoordscam)[:, :, 0, :, :] 83 | else: 84 | lap = None 85 | 86 | if lap is None: 87 | return None 88 | else: 89 | return F.softplus(lap) 90 | -------------------------------------------------------------------------------- /dva/mvp/models/bg/mlp.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
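LapImage above stores the background as a coarse-to-fine stack of learnable images and sums their grid_sample lookups, so low-resolution levels carry smooth color while finer levels add detail. A reduced 2D sketch of that summed-pyramid lookup (the real module is 3D, with one depth slice per camera):

import torch
import torch.nn as nn
import torch.nn.functional as F

class TinyLapImage(nn.Module):
    def __init__(self, width, height, levels=4):
        super().__init__()
        # one learnable image per level, halving the resolution each level
        self.levels = nn.ParameterList([
            nn.Parameter(torch.randn(1, 3, height // 2 ** i, width // 2 ** i) * 0.001)
            for i in range(levels)])

    def forward(self, samplecoords):              # samplecoords: [B, H, W, 2] in [-1, 1]
        out = 0.0
        for img in self.levels:
            img = img.expand(samplecoords.size(0), -1, -1, -1)
            out = out + F.grid_sample(img, samplecoords, align_corners=True)
        return out                                 # [B, 3, H, W]

bg = TinyLapImage(64, 32)(torch.rand(2, 16, 16, 2) * 2 - 1)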
6 | import numpy as np 7 | 8 | import torch 9 | import torch.nn as nn 10 | import torch.nn.functional as F 11 | 12 | from models.utils import BufferDict, Conv2dELR 13 | 14 | class BGModel(nn.Module): 15 | def __init__(self, width, height, allcameras, bgdict=True, demod=True, trainstart=0): 16 | super(BGModel, self).__init__() 17 | 18 | self.allcameras = allcameras 19 | self.trainstart = trainstart 20 | 21 | if bgdict: 22 | self.bg = BufferDict({k: torch.ones(3, height, width) for k in allcameras}) 23 | else: 24 | self.bg = None 25 | 26 | if trainstart > -1: 27 | self.mlp1 = nn.Sequential( 28 | Conv2dELR(60+24, 256, 1, 1, 0, demod="demod" if demod else None), nn.LeakyReLU(0.2), 29 | Conv2dELR( 256, 256, 1, 1, 0, demod="demod" if demod else None), nn.LeakyReLU(0.2), 30 | Conv2dELR( 256, 256, 1, 1, 0, demod="demod" if demod else None), nn.LeakyReLU(0.2), 31 | Conv2dELR( 256, 256, 1, 1, 0, demod="demod" if demod else None), nn.LeakyReLU(0.2), 32 | Conv2dELR( 256, 256, 1, 1, 0, demod="demod" if demod else None)) 33 | 34 | self.mlp2 = nn.Sequential( 35 | Conv2dELR(60+24+256, 256, 1, 1, 0, demod="demod" if demod else None), nn.LeakyReLU(0.2), 36 | Conv2dELR( 256, 256, 1, 1, 0, demod="demod" if demod else None), nn.LeakyReLU(0.2), 37 | Conv2dELR( 256, 256, 1, 1, 0, demod="demod" if demod else None), nn.LeakyReLU(0.2), 38 | Conv2dELR( 256, 3, 1, 1, 0, demod=False)) 39 | 40 | def forward(self, bg=None, camindex=None, raypos=None, rayposend=None, 41 | raydir=None, samplecoords=None, trainiter=-1, **kwargs): 42 | if self.trainstart > -1 and trainiter >= self.trainstart:# and camindex is not None: 43 | # generate position encoding 44 | posenc = torch.cat([ 45 | torch.sin(2 ** i * np.pi * rayposend[:, :, :, :]) 46 | for i in range(10)] + [ 47 | torch.cos(2 ** i * np.pi * rayposend[:, :, :, :]) 48 | for i in range(10)], dim=-1).permute(0, 3, 1, 2) 49 | 50 | direnc = torch.cat([ 51 | torch.sin(2 ** i * np.pi * raydir[:, :, :, :]) 52 | for i in range(4)] + [ 53 | torch.cos(2 ** i * np.pi * raydir[:, :, :, :]) 54 | for i in range(4)], dim=-1).permute(0, 3, 1, 2) 55 | 56 | decout = torch.cat([posenc, direnc], dim=1) 57 | decout = self.mlp1(decout) 58 | 59 | decout = torch.cat([posenc, direnc, decout], dim=1) 60 | decout = self.mlp2(decout) 61 | else: 62 | decout = None 63 | 64 | if bg is None and self.bg is not None and camindex is not None: 65 | bg = torch.stack([self.bg[self.allcameras[camindex[i].item()]] for i in range(camindex.size(0))], dim=0) 66 | else: 67 | bg = None 68 | 69 | if bg is not None and samplecoords is not None: 70 | if samplecoords.size()[1:3] != bg.size()[2:4]: 71 | bg = F.grid_sample(bg, samplecoords, align_corners=False) 72 | 73 | if decout is not None: 74 | if bg is not None: 75 | return F.softplus(bg + decout) 76 | else: 77 | return F.softplus(decout) 78 | else: 79 | if bg is not None: 80 | return F.softplus(bg) 81 | else: 82 | return None 83 | -------------------------------------------------------------------------------- /dva/mvp/models/colorcals/colorcal.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
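The MLP background above conditions on a Fourier encoding of the ray end point (10 frequency bands) and ray direction (4 bands): for each band i it appends sin(2^i * pi * x) and cos(2^i * pi * x) across all channels, which is where the 60 + 24 input width comes from. A standalone sketch of that encoding:

import math
import torch

def posenc(x, num_bands):
    # x: [..., C] -> [..., 2 * num_bands * C], sin terms first, then cos terms
    sins = [torch.sin(2 ** i * math.pi * x) for i in range(num_bands)]
    coss = [torch.cos(2 ** i * math.pi * x) for i in range(num_bands)]
    return torch.cat(sins + coss, dim=-1)

rayposend = torch.rand(2, 8, 8, 3)
raydir = torch.rand(2, 8, 8, 3)
feat = torch.cat([posenc(rayposend, 10), posenc(raydir, 4)], dim=-1)   # [2, 8, 8, 84]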
6 | import torch 7 | import torch.nn as nn 8 | import torch.nn.functional as F 9 | 10 | class Colorcal(nn.Module): 11 | def __init__(self, allcameras): 12 | super(Colorcal, self).__init__() 13 | 14 | self.allcameras = allcameras 15 | 16 | self.weight = nn.Parameter( 17 | torch.ones(len(self.allcameras), 3)) 18 | self.bias = nn.Parameter( 19 | torch.zeros(len(self.allcameras), 3)) 20 | 21 | def forward(self, image, camindex): 22 | # collect weights 23 | weight = self.weight[camindex] 24 | bias = self.bias[camindex] 25 | 26 | # reshape 27 | b = image.size(0) 28 | groups = b * 3 29 | image = image.view(1, -1, image.size(2), image.size(3)) 30 | weight = weight.view(-1, 1, 1, 1) 31 | bias = bias.view(-1) 32 | 33 | # conv 34 | result = F.conv2d(image, weight, bias, groups=groups) 35 | 36 | # unshape 37 | result = result.view(b, 3, image.size(2), image.size(3)) 38 | return result 39 | 40 | def parameters(self): 41 | for p in super(Colorcal, self).parameters(): 42 | if p.requires_grad: 43 | yield p 44 | -------------------------------------------------------------------------------- /dva/mvp/models/encoders/geotex.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | from typing import Optional, List 7 | 8 | import numpy as np 9 | 10 | import torch 11 | import torch.nn as nn 12 | 13 | from models.utils import LinearELR, Conv2dELR 14 | 15 | class Encoder(torch.nn.Module): 16 | def __init__(self, latentdim=256, hiq=True, texin=True, 17 | conv=Conv2dELR, lin=LinearELR, 18 | demod=True, texsize=1024, vertsize=21918): 19 | super(Encoder, self).__init__() 20 | 21 | self.latentdim = latentdim 22 | 23 | self.vertbranch = lin(vertsize, 256, norm="demod", act=nn.LeakyReLU(0.2)) 24 | if texin: 25 | cm = 2 if hiq else 1 26 | 27 | layers = [] 28 | chout = 128*cm 29 | chin = 128*cm 30 | nlayers = int(np.log2(texsize)) - 2 31 | for i in range(nlayers): 32 | if i == nlayers - 1: 33 | chin = 3 34 | layers.append( 35 | conv(chin, chout, 4, 2, 1, norm="demod" if demod else None, act=nn.LeakyReLU(0.2))) 36 | if chin == chout: 37 | chin = chout // 2 38 | else: 39 | chout = chin 40 | 41 | self.texbranch1 = nn.Sequential(*(layers[::-1])) 42 | 43 | self.texbranch2 = lin(cm*128*4*4, 256, norm="demod", act=nn.LeakyReLU(0.2)) 44 | self.mu = lin(512, self.latentdim) 45 | self.logstd = lin(512, self.latentdim) 46 | else: 47 | self.mu = lin(256, self.latentdim) 48 | self.logstd = lin(256, self.latentdim) 49 | 50 | def forward(self, verts, texture : Optional[torch.Tensor]=None, losslist : Optional[List[str]]=None): 51 | assert losslist is not None 52 | 53 | x = self.vertbranch(verts.view(verts.size(0), -1)) 54 | if texture is not None: 55 | texture = self.texbranch1(texture).reshape(verts.size(0), -1) 56 | texture = self.texbranch2(texture) 57 | x = torch.cat([x, texture], dim=1) 58 | 59 | mu, logstd = self.mu(x) * 0.1, self.logstd(x) * 0.01 60 | if self.training: 61 | z = mu + torch.exp(logstd) * torch.randn_like(logstd) 62 | else: 63 | z = mu 64 | 65 | losses = {} 66 | if "kldiv" in losslist: 67 | losses["kldiv"] = torch.mean(-0.5 - logstd + 0.5 * mu ** 2 + 0.5 * torch.exp(2 * logstd), dim=-1) 68 | 69 | return {"encoding": z}, losses 70 | -------------------------------------------------------------------------------- /dva/mvp/models/encoders/image.py: 
-------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | from typing import Optional, List 7 | 8 | import torch 9 | import torch.nn as nn 10 | 11 | from models.utils import LinearELR, Conv2dELR, Downsample2d 12 | 13 | class Encoder(torch.nn.Module): 14 | def __init__(self, ninputs, size, nlayers=7, conv=Conv2dELR, lin=LinearELR): 15 | super(Encoder, self).__init__() 16 | 17 | self.ninputs = ninputs 18 | height, width = size 19 | self.nlayers = nlayers 20 | 21 | ypad = ((height + 2 ** nlayers - 1) // 2 ** nlayers) * 2 ** nlayers - height 22 | xpad = ((width + 2 ** nlayers - 1) // 2 ** nlayers) * 2 ** nlayers - width 23 | self.pad = nn.ZeroPad2d((xpad // 2, xpad - xpad // 2, ypad // 2, ypad - ypad // 2)) 24 | 25 | self.downwidth = ((width + 2 ** nlayers - 1) // 2 ** nlayers) 26 | self.downheight = ((height + 2 ** nlayers - 1) // 2 ** nlayers) 27 | 28 | # compile layers 29 | layers = [] 30 | inch, outch = 3, 64 31 | for i in range(nlayers): 32 | layers.append(conv(inch, outch, 4, 2, 1, norm="demod", act=nn.LeakyReLU(0.2))) 33 | 34 | if inch == outch: 35 | outch = inch * 2 36 | else: 37 | inch = outch 38 | if outch > 256: 39 | outch = 256 40 | 41 | self.down1 = nn.ModuleList([nn.Sequential(*layers) 42 | for i in range(self.ninputs)]) 43 | self.down2 = lin(256 * self.ninputs * self.downwidth * self.downheight, 512, norm="demod", act=nn.LeakyReLU(0.2)) 44 | self.mu = lin(512, 256) 45 | self.logstd = lin(512, 256) 46 | 47 | def forward(self, x, losslist : Optional[List[str]]=None): 48 | assert losslist is not None 49 | 50 | x = self.pad(x) 51 | x = [self.down1[i](x[:, i*3:(i+1)*3, :, :]).view(x.size(0), 256 * self.downwidth * self.downheight) 52 | for i in range(self.ninputs)] 53 | x = torch.cat(x, dim=1) 54 | x = self.down2(x) 55 | 56 | mu, logstd = self.mu(x) * 0.1, self.logstd(x) * 0.01 57 | if self.training: 58 | z = mu + torch.exp(logstd) * torch.randn(*logstd.size(), device=logstd.device) 59 | else: 60 | z = mu 61 | 62 | losses = {} 63 | if "kldiv" in losslist: 64 | losses["kldiv"] = torch.mean(-0.5 - logstd + 0.5 * mu ** 2 + 0.5 * torch.exp(2 * logstd), dim=-1) 65 | 66 | return {"encoding": z}, losses 67 | -------------------------------------------------------------------------------- /dva/mvp/models/raymarchers/mvpraymarcher.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
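Both encoders above (geotex.py and image.py) are variational: they predict mu and logstd, sample z = mu + exp(logstd) * eps during training, and add a per-dimension KL penalty against a unit Gaussian, -0.5 - logstd + 0.5*mu^2 + 0.5*exp(2*logstd), averaged over the latent. A tiny reference sketch of that term:

import torch

def kl_to_standard_normal(mu, logstd):
    # KL( N(mu, exp(logstd)^2) || N(0, 1) ), closed form, averaged over the latent dim
    return torch.mean(-0.5 - logstd + 0.5 * mu ** 2 + 0.5 * torch.exp(2 * logstd), dim=-1)

mu, logstd = torch.zeros(4, 256), torch.zeros(4, 256)
print(kl_to_standard_normal(mu, logstd))   # all zeros: matching the prior gives zero KL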
6 | """ Raymarcher for a mixture of volumetric primitives """ 7 | import os 8 | import itertools 9 | import time 10 | import numpy as np 11 | 12 | import torch 13 | import torch.nn as nn 14 | import torch.nn.functional as F 15 | 16 | from extensions.mvpraymarch.mvpraymarch import mvpraymarch 17 | 18 | class Raymarcher(nn.Module): 19 | def __init__(self, volradius): 20 | super(Raymarcher, self).__init__() 21 | 22 | self.volradius = volradius 23 | 24 | def forward(self, raypos, raydir, tminmax, decout, 25 | encoding=None, renderoptions={}, trainiter=-1, evaliter=-1, 26 | rayterm=None, 27 | **kwargs): 28 | 29 | # rescale world 30 | dt = renderoptions["dt"] / self.volradius 31 | 32 | rayrgba = mvpraymarch(raypos, raydir, dt, tminmax, 33 | (decout["primpos"], decout["primrot"], decout["primscale"]), 34 | template=decout["template"], 35 | warp=decout["warp"] if "warp" in decout else None, 36 | rayterm=rayterm, 37 | **{k:v for k, v in renderoptions.items() if k in mvpraymarch.__code__.co_varnames}) 38 | 39 | return rayrgba.permute(0, 3, 1, 2), {} 40 | -------------------------------------------------------------------------------- /dva/mvp/models/raymarchers/stepraymarcher.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | """ Raymarching in pure pytorch """ 7 | import torch 8 | import torch.nn as nn 9 | import torch.nn.functional as F 10 | 11 | class Raymarcher(nn.Module): 12 | def __init__(self, volradius): 13 | super(Raymarcher, self).__init__() 14 | 15 | self.volradius = volradius 16 | 17 | def forward(self, raypos, raydir, tminmax, decout, 18 | encoding=None, renderoptions={}, **kwargs): 19 | 20 | dt = renderoptions["dt"] / self.volradius 21 | 22 | tminmax = torch.floor(tminmax / dt) * dt 23 | 24 | t = tminmax[..., 0] + 0. 25 | raypos = raypos + raydir * t[..., None] 26 | 27 | rayrgb = torch.zeros_like(raypos.permute(0, 3, 1, 2)) # NCHW 28 | if "multaccum" in renderoptions and renderoptions["multaccum"]: 29 | lograyalpha = torch.zeros_like(rayrgb[:, 0:1, :, :]) # NCHW 30 | else: 31 | rayalpha = torch.zeros_like(rayrgb[:, 0:1, :, :]) # NCHW 32 | 33 | # raymarch 34 | done = torch.zeros_like(t).bool() 35 | while not done.all(): 36 | valid = torch.prod((raypos > -1.) * (raypos < 1.), dim=-1).float() 37 | samplepos = F.grid_sample(decout["warp"][:, 0], raypos[:, None, :, :, :], align_corners=True).permute(0, 2, 3, 4, 1) 38 | val = F.grid_sample(decout["template"][:, 0], samplepos, align_corners=True)[:, :, 0, :, :] 39 | val = val * valid[:, None, :, :] 40 | sample_rgb, sample_alpha = val[:, :3, :, :], val[:, 3:, :, :] 41 | 42 | done = done | ((t + dt) >= tminmax[..., 1]) 43 | 44 | if "multaccum" in renderoptions and renderoptions["multaccum"]: 45 | contrib = torch.exp(-lograyalpha) * (1. - torch.exp(-sample_alpha * dt)) 46 | 47 | rayrgb = rayrgb + sample_rgb * contrib 48 | lograyalpha = lograyalpha + sample_alpha * dt 49 | else: 50 | contrib = ((rayalpha + sample_alpha * dt).clamp(max=1.) - rayalpha) 51 | 52 | rayrgb = rayrgb + sample_rgb * contrib 53 | rayalpha = rayalpha + contrib 54 | 55 | raypos = raypos + raydir * dt 56 | t = t + dt 57 | 58 | if "multaccum" in renderoptions and renderoptions["multaccum"]: 59 | rayalpha = 1. 
- torch.exp(-lograyalpha) 60 | 61 | rayrgba = torch.cat([rayrgb, rayalpha], dim=1) 62 | return rayrgba, {} 63 | -------------------------------------------------------------------------------- /dva/scheduler.py: -------------------------------------------------------------------------------- 1 | import math 2 | from torch.optim.lr_scheduler import LRScheduler 3 | 4 | class CosineWarmupScheduler(LRScheduler): 5 | def __init__(self, optimizer, warmup_iters: int, max_iters: int, initial_lr: float = 1e-10, last_iter: int = -1): 6 | self.warmup_iters = warmup_iters 7 | self.max_iters = max_iters 8 | self.initial_lr = initial_lr 9 | super().__init__(optimizer, last_iter) 10 | 11 | def get_lr(self): 12 | if self._step_count <= self.warmup_iters: 13 | return [ 14 | self.initial_lr + (base_lr - self.initial_lr) * self._step_count / self.warmup_iters 15 | for base_lr in self.base_lrs] 16 | else: 17 | cos_iter = self._step_count - self.warmup_iters 18 | cos_max_iter = self.max_iters - self.warmup_iters 19 | cos_theta = cos_iter / cos_max_iter * math.pi 20 | cos_lr = [base_lr * (1 + math.cos(cos_theta)) / 2 for base_lr in self.base_lrs] 21 | return cos_lr 22 | -------------------------------------------------------------------------------- /dva/utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import numpy as np 8 | 9 | import torch as th 10 | import torch.nn as nn 11 | import torch.nn.functional as F 12 | 13 | import cv2 14 | 15 | 16 | def label_image( 17 | image, 18 | label, 19 | font_scale=1.0, 20 | font_thickness=1, 21 | label_origin=(10, 64), 22 | font_color=(255, 255, 255), 23 | font=cv2.FONT_HERSHEY_SIMPLEX, 24 | ): 25 | text_size, baseline = cv2.getTextSize(label, font, font_scale, font_thickness) 26 | image[ 27 | label_origin[1] - text_size[1] : label_origin[1] + baseline, 28 | label_origin[0] : label_origin[0] + text_size[0], 29 | ] = (255 - font_color[0], 255 - font_color[1], 255 - font_color[2]) 30 | cv2.putText( 31 | image, label, label_origin, font, font_scale, font_color, font_thickness 32 | ) 33 | return image 34 | 35 | 36 | def to_device(values, device=None, non_blocking=True): 37 | """Transfer a set of values to the device. 38 | Args: 39 | values: a nested dict/list/tuple of tensors 40 | device: argument to `to()` for the underlying vector 41 | NOTE: 42 | if the device is not specified, using `th.cuda()` 43 | """ 44 | if device is None: 45 | device = th.device("cuda") 46 | 47 | if isinstance(values, dict): 48 | return {k: to_device(v, device=device) for k, v in values.items()} 49 | elif isinstance(values, tuple): 50 | return tuple(to_device(v, device=device) for v in values) 51 | elif isinstance(values, list): 52 | return [to_device(v, device=device) for v in values] 53 | elif isinstance(values, th.Tensor): 54 | return values.to(device, non_blocking=non_blocking) 55 | elif isinstance(values, nn.Module): 56 | return values.to(device) 57 | elif isinstance(values, np.ndarray): 58 | return th.from_numpy(values).to(device) 59 | else: 60 | return values 61 | -------------------------------------------------------------------------------- /dva/vgg.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 
3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import torch as th 8 | import torch.nn as nn 9 | from torchvision.models import vgg19 10 | import torch.nn.functional as F 11 | import logging 12 | 13 | logger = logging.getLogger(__name__) 14 | 15 | 16 | class Vgg19(nn.Module): 17 | def __init__(self, requires_grad=False): 18 | super(Vgg19, self).__init__() 19 | vgg19_network = vgg19(pretrained=True) 20 | # vgg19_network.load_state_dict(state_dict) 21 | vgg_pretrained_features = vgg19_network.features 22 | self.slice1 = nn.Sequential() 23 | self.slice2 = nn.Sequential() 24 | self.slice3 = nn.Sequential() 25 | self.slice4 = nn.Sequential() 26 | self.slice5 = nn.Sequential() 27 | for x in range(2): 28 | self.slice1.add_module(str(x), vgg_pretrained_features[x]) 29 | for x in range(2, 7): 30 | self.slice2.add_module(str(x), vgg_pretrained_features[x]) 31 | for x in range(7, 12): 32 | self.slice3.add_module(str(x), vgg_pretrained_features[x]) 33 | for x in range(12, 21): 34 | self.slice4.add_module(str(x), vgg_pretrained_features[x]) 35 | for x in range(21, 30): 36 | self.slice5.add_module(str(x), vgg_pretrained_features[x]) 37 | if not requires_grad: 38 | for param in self.parameters(): 39 | param.requires_grad = False 40 | 41 | def forward(self, X): 42 | h_relu1 = self.slice1(X) 43 | h_relu2 = self.slice2(h_relu1) 44 | h_relu3 = self.slice3(h_relu2) 45 | h_relu4 = self.slice4(h_relu3) 46 | h_relu5 = self.slice5(h_relu4) 47 | out = [h_relu1, h_relu2, h_relu3, h_relu4, h_relu5] 48 | return out 49 | 50 | 51 | class VGGLossMasked(nn.Module): 52 | def __init__(self, weights=None): 53 | super().__init__() 54 | self.vgg = Vgg19() 55 | if weights is None: 56 | # self.weights = [1.0 / 32, 1.0 / 16, 1.0 / 8, 1.0 / 4, 1.0] 57 | self.weights = [20.0, 5.0, 0.9, 0.5, 0.5] 58 | else: 59 | self.weights = weights 60 | 61 | def normalize(self, batch): 62 | mean = batch.new_tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1) 63 | std = batch.new_tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1) 64 | return ((batch / 255.0).clamp(0.0, 1.0) - mean) / std 65 | 66 | def forward(self, x_rgb, y_rgb, mask): 67 | 68 | x_norm = self.normalize(x_rgb) 69 | y_norm = self.normalize(y_rgb) 70 | 71 | x_vgg = self.vgg(x_norm) 72 | y_vgg = self.vgg(y_norm) 73 | loss = 0 74 | for i in range(len(x_vgg)): 75 | if isinstance(mask, th.Tensor): 76 | m = F.interpolate( 77 | mask, size=(x_vgg[i].shape[-2], x_vgg[i].shape[-1]), mode="bilinear" 78 | ).detach() 79 | else: 80 | m = mask 81 | 82 | vx = x_vgg[i] * m 83 | vy = y_vgg[i] * m 84 | 85 | loss += self.weights[i] * (vx - vy).abs().mean() 86 | 87 | # logger.info( 88 | # f"loss for {i}, {loss.item()} vx={vx.shape} vy={vy.shape} {vx.max()} {vy.max()}" 89 | # ) 90 | return loss 91 | -------------------------------------------------------------------------------- /install.sh: -------------------------------------------------------------------------------- 1 | CURRENT=$(pwd) 2 | cd dva/mvp/extensions/mvpraymarch 3 | make -j4 4 | cd ../utils 5 | make -j4 6 | cd ${CURRENT} 7 | pip install ./simple-knn 8 | git clone https://github.com/ashawkey/cubvh 9 | cd cubvh 10 | pip install . 
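# Optional sanity check (illustrative; assumes the built extensions are importable as `cubvh` and `simple_knn`):
# python -c "import cubvh; from simple_knn import _C; print('CUDA extensions OK')"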
11 | cd ${CURRENT} -------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/3DTopia/3DTopia-XL/af602b003e9d137e37e1d883c18b50ed65c06f26/models/__init__.py -------------------------------------------------------------------------------- /models/attention.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # 3 | # This source code is licensed under the Apache License, Version 2.0 4 | # found in the LICENSE file in the root directory of this source tree. 5 | 6 | # References: 7 | # https://github.com/facebookresearch/dino/blob/master/vision_transformer.py 8 | # https://github.com/rwightman/pytorch-image-models/tree/master/timm/models/vision_transformer.py 9 | 10 | import os 11 | import warnings 12 | 13 | import torch 14 | from torch import nn 15 | from torch.utils.checkpoint import checkpoint 16 | 17 | from xformers.ops import memory_efficient_attention, unbind 18 | 19 | 20 | class MemEffAttention(nn.Module): 21 | def __init__( 22 | self, 23 | dim: int, 24 | num_heads: int = 8, 25 | qkv_bias: bool = False, 26 | proj_bias: bool = True, 27 | attn_drop: float = 0.0, 28 | proj_drop: float = 0.0, 29 | gradient_checkpointing: bool = False, 30 | ) -> None: 31 | super().__init__() 32 | self.num_heads = num_heads 33 | head_dim = dim // num_heads 34 | self.scale = head_dim**-0.5 35 | self.gradient_checkpointing = gradient_checkpointing 36 | 37 | self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) 38 | self.attn_drop = nn.Dropout(attn_drop) 39 | self.proj = nn.Linear(dim, dim, bias=proj_bias) 40 | self.proj_drop = nn.Dropout(proj_drop) 41 | 42 | def forward(self, x: torch.Tensor, attn_bias=None) -> torch.Tensor: 43 | if self.training and self.gradient_checkpointing: 44 | return checkpoint(self._forward, x, attn_bias, use_reentrant=False) 45 | else: 46 | return self._forward(x, attn_bias) 47 | 48 | def _forward(self, x: torch.Tensor, attn_bias=None) -> torch.Tensor: 49 | B, N, C = x.shape 50 | qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads) 51 | 52 | q, k, v = unbind(qkv, 2) 53 | 54 | x = memory_efficient_attention(q, k, v, attn_bias=attn_bias) 55 | x = x.reshape([B, N, C]) 56 | 57 | x = self.proj(x) 58 | x = self.proj_drop(x) 59 | return x 60 | 61 | 62 | class MemEffCrossAttention(nn.Module): 63 | def __init__( 64 | self, 65 | dim: int, 66 | dim_q: int, 67 | dim_k: int, 68 | dim_v: int, 69 | num_heads: int = 8, 70 | qkv_bias: bool = False, 71 | proj_bias: bool = True, 72 | attn_drop: float = 0.0, 73 | proj_drop: float = 0.0, 74 | gradient_checkpointing: bool = False, 75 | ) -> None: 76 | super().__init__() 77 | self.dim = dim 78 | self.num_heads = num_heads 79 | head_dim = dim // num_heads 80 | self.scale = head_dim**-0.5 81 | self.gradient_checkpointing = gradient_checkpointing 82 | 83 | self.to_q = nn.Linear(dim_q, dim, bias=qkv_bias) 84 | self.to_k = nn.Linear(dim_k, dim, bias=qkv_bias) 85 | self.to_v = nn.Linear(dim_v, dim, bias=qkv_bias) 86 | self.attn_drop = nn.Dropout(attn_drop) 87 | self.proj = nn.Linear(dim, dim, bias=proj_bias) 88 | self.proj_drop = nn.Dropout(proj_drop) 89 | 90 | def forward(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, attn_bias=None) -> torch.Tensor: 91 | if self.training and self.gradient_checkpointing: 92 | return checkpoint(self._forward, q, k, v, attn_bias, use_reentrant=False) 93 | else: 
94 | return self._forward(q, k, v, attn_bias) 95 | 96 | def _forward(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, attn_bias=None) -> torch.Tensor: 97 | # q: [B, N, Cq] 98 | # k: [B, M, Ck] 99 | # v: [B, M, Cv] 100 | # return: [B, N, C] 101 | 102 | B, N, _ = q.shape 103 | M = k.shape[1] 104 | 105 | q = self.scale * self.to_q(q).reshape(B, N, self.num_heads, self.dim // self.num_heads) # [B, N, nh, C/nh] 106 | k = self.to_k(k).reshape(B, M, self.num_heads, self.dim // self.num_heads) # [B, M, nh, C/nh] 107 | v = self.to_v(v).reshape(B, M, self.num_heads, self.dim // self.num_heads) # [B, M, nh, C/nh] 108 | 109 | x = memory_efficient_attention(q, k, v, attn_bias=attn_bias) 110 | x = x.reshape(B, N, -1) 111 | 112 | x = self.proj(x) 113 | x = self.proj_drop(x) 114 | return x -------------------------------------------------------------------------------- /models/conditioner/dinov2/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/3DTopia/3DTopia-XL/af602b003e9d137e37e1d883c18b50ed65c06f26/models/conditioner/dinov2/__init__.py -------------------------------------------------------------------------------- /models/conditioner/dinov2/hub/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # 3 | # This source code is licensed under the Apache License, Version 2.0 4 | # found in the LICENSE file in the root directory of this source tree. 5 | -------------------------------------------------------------------------------- /models/conditioner/dinov2/hub/backbones.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # 3 | # This source code is licensed under the Apache License, Version 2.0 4 | # found in the LICENSE file in the root directory of this source tree. 
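# Illustrative usage of the factory functions defined below (the `modulation_dim` kwarg is the extra
# hook consumed by models/conditioner/image_dinov2.py; the 1024 below is an assumed example value):
#   from models.conditioner.dinov2.hub.backbones import dinov2_vitb14
#   backbone = dinov2_vitb14(pretrained=False)                       # plain ViT-B/14, no weight download
#   modulated = dinov2_vitb14(pretrained=True, modulation_dim=1024)  # modulated variant, weights loaded non-strictly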
5 | 6 | from enum import Enum 7 | from typing import Union 8 | 9 | import torch 10 | 11 | from .utils import _DINOV2_BASE_URL, _make_dinov2_model_name 12 | 13 | 14 | class Weights(Enum): 15 | LVD142M = "LVD142M" 16 | 17 | 18 | def _make_dinov2_model( 19 | *, 20 | arch_name: str = "vit_large", 21 | img_size: int = 518, 22 | patch_size: int = 14, 23 | init_values: float = 1.0, 24 | ffn_layer: str = "mlp", 25 | block_chunks: int = 0, 26 | num_register_tokens: int = 0, 27 | interpolate_antialias: bool = False, 28 | interpolate_offset: float = 0.1, 29 | pretrained: bool = True, 30 | weights: Union[Weights, str] = Weights.LVD142M, 31 | **kwargs, 32 | ): 33 | from ..models import vision_transformer as vits 34 | 35 | if isinstance(weights, str): 36 | try: 37 | weights = Weights[weights] 38 | except KeyError: 39 | raise AssertionError(f"Unsupported weights: {weights}") 40 | 41 | model_base_name = _make_dinov2_model_name(arch_name, patch_size) 42 | vit_kwargs = dict( 43 | img_size=img_size, 44 | patch_size=patch_size, 45 | init_values=init_values, 46 | ffn_layer=ffn_layer, 47 | block_chunks=block_chunks, 48 | num_register_tokens=num_register_tokens, 49 | interpolate_antialias=interpolate_antialias, 50 | interpolate_offset=interpolate_offset, 51 | ) 52 | vit_kwargs.update(**kwargs) 53 | model = vits.__dict__[arch_name](**vit_kwargs) 54 | 55 | if pretrained: 56 | model_full_name = _make_dinov2_model_name(arch_name, patch_size, num_register_tokens) 57 | url = _DINOV2_BASE_URL + f"/{model_base_name}/{model_full_name}_pretrain.pth" 58 | state_dict = torch.hub.load_state_dict_from_url(url, map_location="cpu") 59 | 60 | state_dict = {k: v for k, v in state_dict.items() if 'mask_token' not in k} # DDP concern 61 | if vit_kwargs.get("modulation_dim") is not None: 62 | state_dict = { 63 | k.replace('norm1', 'norm1.norm').replace('norm2', 'norm2.norm'): v 64 | for k, v in state_dict.items() 65 | } 66 | model.load_state_dict(state_dict, strict=False) 67 | else: 68 | model.load_state_dict(state_dict, strict=True) 69 | # ******************************************************** 70 | 71 | return model 72 | 73 | 74 | def dinov2_vits14(*, pretrained: bool = True, weights: Union[Weights, str] = Weights.LVD142M, **kwargs): 75 | """ 76 | DINOv2 ViT-S/14 model (optionally) pretrained on the LVD-142M dataset. 77 | """ 78 | return _make_dinov2_model(arch_name="vit_small", pretrained=pretrained, weights=weights, **kwargs) 79 | 80 | 81 | def dinov2_vitb14(*, pretrained: bool = True, weights: Union[Weights, str] = Weights.LVD142M, **kwargs): 82 | """ 83 | DINOv2 ViT-B/14 model (optionally) pretrained on the LVD-142M dataset. 84 | """ 85 | return _make_dinov2_model(arch_name="vit_base", pretrained=pretrained, weights=weights, **kwargs) 86 | 87 | 88 | def dinov2_vitl14(*, pretrained: bool = True, weights: Union[Weights, str] = Weights.LVD142M, **kwargs): 89 | """ 90 | DINOv2 ViT-L/14 model (optionally) pretrained on the LVD-142M dataset. 91 | """ 92 | return _make_dinov2_model(arch_name="vit_large", pretrained=pretrained, weights=weights, **kwargs) 93 | 94 | 95 | def dinov2_vitg14(*, pretrained: bool = True, weights: Union[Weights, str] = Weights.LVD142M, **kwargs): 96 | """ 97 | DINOv2 ViT-g/14 model (optionally) pretrained on the LVD-142M dataset. 
98 | """ 99 | return _make_dinov2_model( 100 | arch_name="vit_giant2", 101 | ffn_layer="swiglufused", 102 | weights=weights, 103 | pretrained=pretrained, 104 | **kwargs, 105 | ) 106 | 107 | 108 | def dinov2_vits14_reg(*, pretrained: bool = True, weights: Union[Weights, str] = Weights.LVD142M, **kwargs): 109 | """ 110 | DINOv2 ViT-S/14 model with registers (optionally) pretrained on the LVD-142M dataset. 111 | """ 112 | return _make_dinov2_model( 113 | arch_name="vit_small", 114 | pretrained=pretrained, 115 | weights=weights, 116 | num_register_tokens=4, 117 | interpolate_antialias=True, 118 | interpolate_offset=0.0, 119 | **kwargs, 120 | ) 121 | 122 | 123 | def dinov2_vitb14_reg(*, pretrained: bool = True, weights: Union[Weights, str] = Weights.LVD142M, **kwargs): 124 | """ 125 | DINOv2 ViT-B/14 model with registers (optionally) pretrained on the LVD-142M dataset. 126 | """ 127 | return _make_dinov2_model( 128 | arch_name="vit_base", 129 | pretrained=pretrained, 130 | weights=weights, 131 | num_register_tokens=4, 132 | interpolate_antialias=True, 133 | interpolate_offset=0.0, 134 | **kwargs, 135 | ) 136 | 137 | 138 | def dinov2_vitl14_reg(*, pretrained: bool = True, weights: Union[Weights, str] = Weights.LVD142M, **kwargs): 139 | """ 140 | DINOv2 ViT-L/14 model with registers (optionally) pretrained on the LVD-142M dataset. 141 | """ 142 | return _make_dinov2_model( 143 | arch_name="vit_large", 144 | pretrained=pretrained, 145 | weights=weights, 146 | num_register_tokens=4, 147 | interpolate_antialias=True, 148 | interpolate_offset=0.0, 149 | **kwargs, 150 | ) 151 | 152 | 153 | def dinov2_vitg14_reg(*, pretrained: bool = True, weights: Union[Weights, str] = Weights.LVD142M, **kwargs): 154 | """ 155 | DINOv2 ViT-g/14 model with registers (optionally) pretrained on the LVD-142M dataset. 156 | """ 157 | return _make_dinov2_model( 158 | arch_name="vit_giant2", 159 | ffn_layer="swiglufused", 160 | weights=weights, 161 | pretrained=pretrained, 162 | num_register_tokens=4, 163 | interpolate_antialias=True, 164 | interpolate_offset=0.0, 165 | **kwargs, 166 | ) 167 | -------------------------------------------------------------------------------- /models/conditioner/dinov2/hub/depth/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # 3 | # This source code is licensed under the Apache License, Version 2.0 4 | # found in the LICENSE file in the root directory of this source tree. 5 | 6 | from .decode_heads import BNHead, DPTHead 7 | from .encoder_decoder import DepthEncoderDecoder 8 | -------------------------------------------------------------------------------- /models/conditioner/dinov2/hub/depth/ops.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # 3 | # This source code is licensed under the Apache License, Version 2.0 4 | # found in the LICENSE file in the root directory of this source tree. 
5 | 6 | import warnings 7 | 8 | import torch.nn.functional as F 9 | 10 | 11 | def resize(input, size=None, scale_factor=None, mode="nearest", align_corners=None, warning=False): 12 | if warning: 13 | if size is not None and align_corners: 14 | input_h, input_w = tuple(int(x) for x in input.shape[2:]) 15 | output_h, output_w = tuple(int(x) for x in size) 16 | if output_h > input_h or output_w > output_h: 17 | if ( 18 | (output_h > 1 and output_w > 1 and input_h > 1 and input_w > 1) 19 | and (output_h - 1) % (input_h - 1) 20 | and (output_w - 1) % (input_w - 1) 21 | ): 22 | warnings.warn( 23 | f"When align_corners={align_corners}, " 24 | "the output would more aligned if " 25 | f"input size {(input_h, input_w)} is `x+1` and " 26 | f"out size {(output_h, output_w)} is `nx+1`" 27 | ) 28 | return F.interpolate(input, size, scale_factor, mode, align_corners) 29 | -------------------------------------------------------------------------------- /models/conditioner/dinov2/hub/utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # 3 | # This source code is licensed under the Apache License, Version 2.0 4 | # found in the LICENSE file in the root directory of this source tree. 5 | 6 | import itertools 7 | import math 8 | 9 | import torch 10 | import torch.nn as nn 11 | import torch.nn.functional as F 12 | 13 | 14 | _DINOV2_BASE_URL = "https://dl.fbaipublicfiles.com/dinov2" 15 | 16 | 17 | def _make_dinov2_model_name(arch_name: str, patch_size: int, num_register_tokens: int = 0) -> str: 18 | compact_arch_name = arch_name.replace("_", "")[:4] 19 | registers_suffix = f"_reg{num_register_tokens}" if num_register_tokens else "" 20 | return f"dinov2_{compact_arch_name}{patch_size}{registers_suffix}" 21 | 22 | 23 | class CenterPadding(nn.Module): 24 | def __init__(self, multiple): 25 | super().__init__() 26 | self.multiple = multiple 27 | 28 | def _get_pad(self, size): 29 | new_size = math.ceil(size / self.multiple) * self.multiple 30 | pad_size = new_size - size 31 | pad_size_left = pad_size // 2 32 | pad_size_right = pad_size - pad_size_left 33 | return pad_size_left, pad_size_right 34 | 35 | @torch.inference_mode() 36 | def forward(self, x): 37 | pads = list(itertools.chain.from_iterable(self._get_pad(m) for m in x.shape[:1:-1])) 38 | output = F.pad(x, pads) 39 | return output 40 | -------------------------------------------------------------------------------- /models/conditioner/dinov2/layers/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # 3 | # This source code is licensed under the Apache License, Version 2.0 4 | # found in the LICENSE file in the root directory of this source tree. 5 | 6 | from .dino_head import DINOHead 7 | from .mlp import Mlp 8 | from .patch_embed import PatchEmbed 9 | from .swiglu_ffn import SwiGLUFFN, SwiGLUFFNFused 10 | 11 | from .block import Block, BlockWithModulation 12 | from .attention import MemEffAttention 13 | -------------------------------------------------------------------------------- /models/conditioner/dinov2/layers/attention.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # 3 | # This source code is licensed under the Apache License, Version 2.0 4 | # found in the LICENSE file in the root directory of this source tree. 
5 | 6 | # References: 7 | # https://github.com/facebookresearch/dino/blob/master/vision_transformer.py 8 | # https://github.com/rwightman/pytorch-image-models/tree/master/timm/models/vision_transformer.py 9 | 10 | import logging 11 | import os 12 | import warnings 13 | 14 | from torch import Tensor 15 | from torch import nn 16 | 17 | 18 | logger = logging.getLogger("dinov2") 19 | 20 | 21 | XFORMERS_ENABLED = os.environ.get("XFORMERS_DISABLED") is None 22 | try: 23 | if XFORMERS_ENABLED: 24 | from xformers.ops import memory_efficient_attention, unbind 25 | 26 | XFORMERS_AVAILABLE = True 27 | warnings.warn("xFormers is available (Attention)") 28 | else: 29 | warnings.warn("xFormers is disabled (Attention)") 30 | raise ImportError 31 | except ImportError: 32 | XFORMERS_AVAILABLE = False 33 | warnings.warn("xFormers is not available (Attention)") 34 | 35 | 36 | class Attention(nn.Module): 37 | def __init__( 38 | self, 39 | dim: int, 40 | num_heads: int = 8, 41 | qkv_bias: bool = False, 42 | proj_bias: bool = True, 43 | attn_drop: float = 0.0, 44 | proj_drop: float = 0.0, 45 | ) -> None: 46 | super().__init__() 47 | self.num_heads = num_heads 48 | head_dim = dim // num_heads 49 | self.scale = head_dim**-0.5 50 | 51 | self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) 52 | self.attn_drop = nn.Dropout(attn_drop) 53 | self.proj = nn.Linear(dim, dim, bias=proj_bias) 54 | self.proj_drop = nn.Dropout(proj_drop) 55 | 56 | def forward(self, x: Tensor) -> Tensor: 57 | B, N, C = x.shape 58 | qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) 59 | 60 | q, k, v = qkv[0] * self.scale, qkv[1], qkv[2] 61 | attn = q @ k.transpose(-2, -1) 62 | 63 | attn = attn.softmax(dim=-1) 64 | attn = self.attn_drop(attn) 65 | 66 | x = (attn @ v).transpose(1, 2).reshape(B, N, C) 67 | x = self.proj(x) 68 | x = self.proj_drop(x) 69 | return x 70 | 71 | 72 | class MemEffAttention(Attention): 73 | def forward(self, x: Tensor, attn_bias=None) -> Tensor: 74 | if not XFORMERS_AVAILABLE: 75 | if attn_bias is not None: 76 | raise AssertionError("xFormers is required for using nested tensors") 77 | return super().forward(x) 78 | 79 | B, N, C = x.shape 80 | qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads) 81 | 82 | q, k, v = unbind(qkv, 2) 83 | 84 | x = memory_efficient_attention(q, k, v, attn_bias=attn_bias) 85 | x = x.reshape([B, N, C]) 86 | 87 | x = self.proj(x) 88 | x = self.proj_drop(x) 89 | return x 90 | -------------------------------------------------------------------------------- /models/conditioner/dinov2/layers/dino_head.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # 3 | # This source code is licensed under the Apache License, Version 2.0 4 | # found in the LICENSE file in the root directory of this source tree. 
5 | 6 | import torch 7 | import torch.nn as nn 8 | from torch.nn.init import trunc_normal_ 9 | from torch.nn.utils import weight_norm 10 | 11 | 12 | class DINOHead(nn.Module): 13 | def __init__( 14 | self, 15 | in_dim, 16 | out_dim, 17 | use_bn=False, 18 | nlayers=3, 19 | hidden_dim=2048, 20 | bottleneck_dim=256, 21 | mlp_bias=True, 22 | ): 23 | super().__init__() 24 | nlayers = max(nlayers, 1) 25 | self.mlp = _build_mlp(nlayers, in_dim, bottleneck_dim, hidden_dim=hidden_dim, use_bn=use_bn, bias=mlp_bias) 26 | self.apply(self._init_weights) 27 | self.last_layer = weight_norm(nn.Linear(bottleneck_dim, out_dim, bias=False)) 28 | self.last_layer.weight_g.data.fill_(1) 29 | 30 | def _init_weights(self, m): 31 | if isinstance(m, nn.Linear): 32 | trunc_normal_(m.weight, std=0.02) 33 | if isinstance(m, nn.Linear) and m.bias is not None: 34 | nn.init.constant_(m.bias, 0) 35 | 36 | def forward(self, x): 37 | x = self.mlp(x) 38 | eps = 1e-6 if x.dtype == torch.float16 else 1e-12 39 | x = nn.functional.normalize(x, dim=-1, p=2, eps=eps) 40 | x = self.last_layer(x) 41 | return x 42 | 43 | 44 | def _build_mlp(nlayers, in_dim, bottleneck_dim, hidden_dim=None, use_bn=False, bias=True): 45 | if nlayers == 1: 46 | return nn.Linear(in_dim, bottleneck_dim, bias=bias) 47 | else: 48 | layers = [nn.Linear(in_dim, hidden_dim, bias=bias)] 49 | if use_bn: 50 | layers.append(nn.BatchNorm1d(hidden_dim)) 51 | layers.append(nn.GELU()) 52 | for _ in range(nlayers - 2): 53 | layers.append(nn.Linear(hidden_dim, hidden_dim, bias=bias)) 54 | if use_bn: 55 | layers.append(nn.BatchNorm1d(hidden_dim)) 56 | layers.append(nn.GELU()) 57 | layers.append(nn.Linear(hidden_dim, bottleneck_dim, bias=bias)) 58 | return nn.Sequential(*layers) 59 | -------------------------------------------------------------------------------- /models/conditioner/dinov2/layers/drop_path.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # 3 | # This source code is licensed under the Apache License, Version 2.0 4 | # found in the LICENSE file in the root directory of this source tree. 5 | 6 | # References: 7 | # https://github.com/facebookresearch/dino/blob/master/vision_transformer.py 8 | # https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/drop.py 9 | 10 | 11 | from torch import nn 12 | 13 | 14 | def drop_path(x, drop_prob: float = 0.0, training: bool = False): 15 | if drop_prob == 0.0 or not training: 16 | return x 17 | keep_prob = 1 - drop_prob 18 | shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets 19 | random_tensor = x.new_empty(shape).bernoulli_(keep_prob) 20 | if keep_prob > 0.0: 21 | random_tensor.div_(keep_prob) 22 | output = x * random_tensor 23 | return output 24 | 25 | 26 | class DropPath(nn.Module): 27 | """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).""" 28 | 29 | def __init__(self, drop_prob=None): 30 | super(DropPath, self).__init__() 31 | self.drop_prob = drop_prob 32 | 33 | def forward(self, x): 34 | return drop_path(x, self.drop_prob, self.training) 35 | -------------------------------------------------------------------------------- /models/conditioner/dinov2/layers/layer_scale.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 
2 | # 3 | # This source code is licensed under the Apache License, Version 2.0 4 | # found in the LICENSE file in the root directory of this source tree. 5 | 6 | # Modified from: https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/vision_transformer.py#L103-L110 7 | 8 | from typing import Union 9 | 10 | import torch 11 | from torch import Tensor 12 | from torch import nn 13 | 14 | 15 | class LayerScale(nn.Module): 16 | def __init__( 17 | self, 18 | dim: int, 19 | init_values: Union[float, Tensor] = 1e-5, 20 | inplace: bool = False, 21 | ) -> None: 22 | super().__init__() 23 | self.inplace = inplace 24 | self.gamma = nn.Parameter(init_values * torch.ones(dim)) 25 | 26 | def forward(self, x: Tensor) -> Tensor: 27 | return x.mul_(self.gamma) if self.inplace else x * self.gamma 28 | -------------------------------------------------------------------------------- /models/conditioner/dinov2/layers/mlp.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # 3 | # This source code is licensed under the Apache License, Version 2.0 4 | # found in the LICENSE file in the root directory of this source tree. 5 | 6 | # References: 7 | # https://github.com/facebookresearch/dino/blob/master/vision_transformer.py 8 | # https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/mlp.py 9 | 10 | 11 | from typing import Callable, Optional 12 | 13 | from torch import Tensor, nn 14 | 15 | 16 | class Mlp(nn.Module): 17 | def __init__( 18 | self, 19 | in_features: int, 20 | hidden_features: Optional[int] = None, 21 | out_features: Optional[int] = None, 22 | act_layer: Callable[..., nn.Module] = nn.GELU, 23 | drop: float = 0.0, 24 | bias: bool = True, 25 | ) -> None: 26 | super().__init__() 27 | out_features = out_features or in_features 28 | hidden_features = hidden_features or in_features 29 | self.fc1 = nn.Linear(in_features, hidden_features, bias=bias) 30 | self.act = act_layer() 31 | self.fc2 = nn.Linear(hidden_features, out_features, bias=bias) 32 | self.drop = nn.Dropout(drop) 33 | 34 | def forward(self, x: Tensor) -> Tensor: 35 | x = self.fc1(x) 36 | x = self.act(x) 37 | x = self.drop(x) 38 | x = self.fc2(x) 39 | x = self.drop(x) 40 | return x 41 | -------------------------------------------------------------------------------- /models/conditioner/dinov2/layers/patch_embed.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # 3 | # This source code is licensed under the Apache License, Version 2.0 4 | # found in the LICENSE file in the root directory of this source tree. 5 | 6 | # References: 7 | # https://github.com/facebookresearch/dino/blob/master/vision_transformer.py 8 | # https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/patch_embed.py 9 | 10 | from typing import Callable, Optional, Tuple, Union 11 | 12 | from torch import Tensor 13 | import torch.nn as nn 14 | 15 | 16 | def make_2tuple(x): 17 | if isinstance(x, tuple): 18 | assert len(x) == 2 19 | return x 20 | 21 | assert isinstance(x, int) 22 | return (x, x) 23 | 24 | 25 | class PatchEmbed(nn.Module): 26 | """ 27 | 2D image to patch embedding: (B,C,H,W) -> (B,N,D) 28 | 29 | Args: 30 | img_size: Image size. 31 | patch_size: Patch token size. 32 | in_chans: Number of input image channels. 33 | embed_dim: Number of linear projection output channels. 34 | norm_layer: Normalization layer. 
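flatten_embedding: If True (default), return tokens flattened to (B, N, D); otherwise keep the (B, H, W, D) patch grid.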
35 | """ 36 | 37 | def __init__( 38 | self, 39 | img_size: Union[int, Tuple[int, int]] = 224, 40 | patch_size: Union[int, Tuple[int, int]] = 16, 41 | in_chans: int = 3, 42 | embed_dim: int = 768, 43 | norm_layer: Optional[Callable] = None, 44 | flatten_embedding: bool = True, 45 | ) -> None: 46 | super().__init__() 47 | 48 | image_HW = make_2tuple(img_size) 49 | patch_HW = make_2tuple(patch_size) 50 | patch_grid_size = ( 51 | image_HW[0] // patch_HW[0], 52 | image_HW[1] // patch_HW[1], 53 | ) 54 | 55 | self.img_size = image_HW 56 | self.patch_size = patch_HW 57 | self.patches_resolution = patch_grid_size 58 | self.num_patches = patch_grid_size[0] * patch_grid_size[1] 59 | 60 | self.in_chans = in_chans 61 | self.embed_dim = embed_dim 62 | 63 | self.flatten_embedding = flatten_embedding 64 | 65 | self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_HW, stride=patch_HW) 66 | self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity() 67 | 68 | def forward(self, x: Tensor) -> Tensor: 69 | _, _, H, W = x.shape 70 | patch_H, patch_W = self.patch_size 71 | 72 | assert H % patch_H == 0, f"Input image height {H} is not a multiple of patch height {patch_H}" 73 | assert W % patch_W == 0, f"Input image width {W} is not a multiple of patch width: {patch_W}" 74 | 75 | x = self.proj(x) # B C H W 76 | H, W = x.size(2), x.size(3) 77 | x = x.flatten(2).transpose(1, 2) # B HW C 78 | x = self.norm(x) 79 | if not self.flatten_embedding: 80 | x = x.reshape(-1, H, W, self.embed_dim) # B H W C 81 | return x 82 | 83 | def flops(self) -> float: 84 | Ho, Wo = self.patches_resolution 85 | flops = Ho * Wo * self.embed_dim * self.in_chans * (self.patch_size[0] * self.patch_size[1]) 86 | if self.norm is not None: 87 | flops += Ho * Wo * self.embed_dim 88 | return flops 89 | -------------------------------------------------------------------------------- /models/conditioner/dinov2/layers/swiglu_ffn.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # 3 | # This source code is licensed under the Apache License, Version 2.0 4 | # found in the LICENSE file in the root directory of this source tree. 
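# Note on the variants below: SwiGLUFFN projects the input to 2*hidden features, splits them into
# (x1, x2), and computes silu(x1) * x2 before the w3 output projection; SwiGLUFFNFused is the same
# layer backed by the fused xFormers SwiGLU kernel whenever xFormers is available.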
5 | 6 | import os 7 | from typing import Callable, Optional 8 | import warnings 9 | 10 | from torch import Tensor, nn 11 | import torch.nn.functional as F 12 | 13 | 14 | class SwiGLUFFN(nn.Module): 15 | def __init__( 16 | self, 17 | in_features: int, 18 | hidden_features: Optional[int] = None, 19 | out_features: Optional[int] = None, 20 | act_layer: Callable[..., nn.Module] = None, 21 | drop: float = 0.0, 22 | bias: bool = True, 23 | ) -> None: 24 | super().__init__() 25 | out_features = out_features or in_features 26 | hidden_features = hidden_features or in_features 27 | self.w12 = nn.Linear(in_features, 2 * hidden_features, bias=bias) 28 | self.w3 = nn.Linear(hidden_features, out_features, bias=bias) 29 | 30 | def forward(self, x: Tensor) -> Tensor: 31 | x12 = self.w12(x) 32 | x1, x2 = x12.chunk(2, dim=-1) 33 | hidden = F.silu(x1) * x2 34 | return self.w3(hidden) 35 | 36 | 37 | XFORMERS_ENABLED = os.environ.get("XFORMERS_DISABLED") is None 38 | try: 39 | if XFORMERS_ENABLED: 40 | from xformers.ops import SwiGLU 41 | 42 | XFORMERS_AVAILABLE = True 43 | warnings.warn("xFormers is available (SwiGLU)") 44 | else: 45 | warnings.warn("xFormers is disabled (SwiGLU)") 46 | raise ImportError 47 | except ImportError: 48 | SwiGLU = SwiGLUFFN 49 | XFORMERS_AVAILABLE = False 50 | 51 | warnings.warn("xFormers is not available (SwiGLU)") 52 | 53 | 54 | class SwiGLUFFNFused(SwiGLU): 55 | def __init__( 56 | self, 57 | in_features: int, 58 | hidden_features: Optional[int] = None, 59 | out_features: Optional[int] = None, 60 | act_layer: Callable[..., nn.Module] = None, 61 | drop: float = 0.0, 62 | bias: bool = True, 63 | ) -> None: 64 | out_features = out_features or in_features 65 | hidden_features = hidden_features or in_features 66 | hidden_features = (int(hidden_features * 2 / 3) + 7) // 8 * 8 67 | super().__init__( 68 | in_features=in_features, 69 | hidden_features=hidden_features, 70 | out_features=out_features, 71 | bias=bias, 72 | ) 73 | -------------------------------------------------------------------------------- /models/conditioner/dinov2/models/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # 3 | # This source code is licensed under the Apache License, Version 2.0 4 | # found in the LICENSE file in the root directory of this source tree. 5 | 6 | import logging 7 | 8 | from . 
import vision_transformer as vits 9 | 10 | 11 | logger = logging.getLogger("dinov2") 12 | 13 | 14 | def build_model(args, only_teacher=False, img_size=224): 15 | args.arch = args.arch.removesuffix("_memeff") 16 | if "vit" in args.arch: 17 | vit_kwargs = dict( 18 | img_size=img_size, 19 | patch_size=args.patch_size, 20 | init_values=args.layerscale, 21 | ffn_layer=args.ffn_layer, 22 | block_chunks=args.block_chunks, 23 | qkv_bias=args.qkv_bias, 24 | proj_bias=args.proj_bias, 25 | ffn_bias=args.ffn_bias, 26 | num_register_tokens=args.num_register_tokens, 27 | interpolate_offset=args.interpolate_offset, 28 | interpolate_antialias=args.interpolate_antialias, 29 | ) 30 | teacher = vits.__dict__[args.arch](**vit_kwargs) 31 | if only_teacher: 32 | return teacher, teacher.embed_dim 33 | student = vits.__dict__[args.arch]( 34 | **vit_kwargs, 35 | drop_path_rate=args.drop_path_rate, 36 | drop_path_uniform=args.drop_path_uniform, 37 | ) 38 | embed_dim = student.embed_dim 39 | return student, teacher, embed_dim 40 | 41 | 42 | def build_model_from_cfg(cfg, only_teacher=False): 43 | return build_model(cfg.student, only_teacher=only_teacher, img_size=cfg.crops.global_crops_size) 44 | -------------------------------------------------------------------------------- /models/conditioner/image_dinov2.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torchvision.transforms import Compose, Resize, InterpolationMode, Normalize 4 | 5 | import logging 6 | logger = logging.getLogger(__name__) 7 | 8 | 9 | 10 | 11 | class Dinov2Wrapper(nn.Module): 12 | """ 13 | Dino v2 wrapper using original implementation, hacked with modulation. 14 | """ 15 | def __init__(self, model_name: str, modulation_dim: int = None, freeze: bool = True): 16 | super().__init__() 17 | self.modulation_dim = modulation_dim 18 | self.model = self._build_dinov2(model_name, modulation_dim=modulation_dim) 19 | self.preprocess = Compose([ 20 | Resize(self.model.patch_embed.img_size[0], interpolation=InterpolationMode.BICUBIC), 21 | Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)), 22 | ]) 23 | if freeze: 24 | if modulation_dim is not None: 25 | raise ValueError("Modulated Dinov2 requires training, freezing is not allowed.") 26 | self._freeze() 27 | 28 | def _freeze(self): 29 | logger.warning(f"======== Freezing Dinov2Wrapper ========") 30 | self.model.eval() 31 | for name, param in self.model.named_parameters(): 32 | param.requires_grad = False 33 | 34 | @staticmethod 35 | def _build_dinov2(model_name: str, modulation_dim: int = None, pretrained: bool = True): 36 | from importlib import import_module 37 | dinov2_hub = import_module(".dinov2.hub.backbones", package=__package__) 38 | model_fn = getattr(dinov2_hub, model_name) 39 | logger.info(f"Modulation dim for Dinov2 is {modulation_dim}.") 40 | model = model_fn(modulation_dim=modulation_dim, pretrained=pretrained) 41 | return model 42 | 43 | # @torch.compile 44 | def forward(self, image: torch.Tensor, mod: torch.Tensor = None): 45 | # image: [N, H, W, C] -- need to be permuted!!! 46 | # mod: [N, D] or None 47 | assert image.shape[-1] == 3 48 | image = image.permute(0, 3, 1, 2) / 255. 49 | image = self.preprocess(image) 50 | if self.modulation_dim is None: 51 | assert mod is None, "Unexpected modulation input in dinov2 forward." 52 | outs = self.model(image, is_training=True) 53 | else: 54 | assert mod is not None, "Modulation input is required in modulated dinov2 forward." 
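# `mod` is forwarded into the modulated DINOv2 call below; judging from the norm1/norm2 key remapping
# in hub/backbones.py, it is presumably injected through modulated norm layers inside each block.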
55 | outs = self.model(image, mod=mod, is_training=True) 56 | ret = torch.cat([ 57 | outs["x_norm_clstoken"].unsqueeze(dim=1), 58 | outs["x_norm_patchtokens"], 59 | ], dim=1) 60 | # ret in [B, 1370, 384] 61 | return ret 62 | -------------------------------------------------------------------------------- /models/conditioner/text.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | import open_clip 6 | from dva.io import load_from_config 7 | 8 | class TextConditioner(nn.Module): 9 | def __init__( 10 | self, 11 | encoder_config, 12 | ): 13 | super().__init__() 14 | self.encoder = load_from_config(encoder_config) 15 | 16 | @torch.no_grad() 17 | def forward(self, batch, rm, amp=False, precision_dtype=torch.float32): 18 | assert 'caption_token' in batch, "No tokenized caption in current batch for text conditions" 19 | caption_token = batch['caption_token'] 20 | with torch.autocast(device_type='cuda', dtype=precision_dtype, enabled=amp): 21 | results = self.encoder(caption_token) 22 | return results 23 | 24 | class CLIPTextEncoder(nn.Module): 25 | def __init__( 26 | self, 27 | pretrained_path: str, 28 | model_spec: str = 'ViT-L-14', 29 | ): 30 | super().__init__() 31 | self.model, _, _ = open_clip.create_model_and_transforms(model_spec, pretrained=pretrained_path) 32 | self.model.eval() 33 | 34 | @torch.no_grad() 35 | def forward(self, text): 36 | text_features = self.model.encode_text(text) 37 | text_features /= text_features.norm(dim=-1, keepdim=True) 38 | return text_features[:, None, :] 39 | -------------------------------------------------------------------------------- /models/diffusion/__init__.py: -------------------------------------------------------------------------------- 1 | # Modified from OpenAI's diffusion repos 2 | # GLIDE: https://github.com/openai/glide-text2im/blob/main/glide_text2im/gaussian_diffusion.py 3 | # ADM: https://github.com/openai/guided-diffusion/blob/main/guided_diffusion 4 | # IDDPM: https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py 5 | 6 | from . 
import gaussian_diffusion as gd 7 | from .respace import SpacedDiffusion, space_timesteps 8 | 9 | 10 | def create_diffusion( 11 | timestep_respacing, 12 | noise_schedule="linear", 13 | use_kl=False, 14 | sigma_small=False, 15 | parameterization="eps", 16 | learn_sigma=True, 17 | rescale_learned_sigmas=False, 18 | diffusion_steps=1000 19 | ): 20 | betas = gd.get_named_beta_schedule(noise_schedule, diffusion_steps) 21 | if use_kl: 22 | loss_type = gd.LossType.RESCALED_KL 23 | elif rescale_learned_sigmas: 24 | loss_type = gd.LossType.RESCALED_MSE 25 | else: 26 | loss_type = gd.LossType.MSE 27 | if timestep_respacing is None or timestep_respacing == "": 28 | timestep_respacing = [diffusion_steps] 29 | if parameterization == "eps": 30 | model_mean_type = gd.ModelMeanType.EPSILON 31 | elif parameterization == "xstart": 32 | model_mean_type = gd.ModelMeanType.START_X 33 | elif parameterization == "v": 34 | model_mean_type = gd.ModelMeanType.VELOCITY 35 | else: 36 | raise NotImplementedError("Model Mean Type {} is not supported!".format(parameterization)) 37 | return SpacedDiffusion( 38 | use_timesteps=space_timesteps(diffusion_steps, timestep_respacing), 39 | betas=betas, 40 | model_mean_type=model_mean_type, 41 | model_var_type=( 42 | ( 43 | gd.ModelVarType.FIXED_LARGE 44 | if not sigma_small 45 | else gd.ModelVarType.FIXED_SMALL 46 | ) 47 | if not learn_sigma 48 | else gd.ModelVarType.LEARNED_RANGE 49 | ), 50 | loss_type=loss_type 51 | # rescale_timesteps=rescale_timesteps, 52 | ) 53 | -------------------------------------------------------------------------------- /models/diffusion/diffusion_utils.py: -------------------------------------------------------------------------------- 1 | # Modified from OpenAI's diffusion repos 2 | # GLIDE: https://github.com/openai/glide-text2im/blob/main/glide_text2im/gaussian_diffusion.py 3 | # ADM: https://github.com/openai/guided-diffusion/blob/main/guided_diffusion 4 | # IDDPM: https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py 5 | 6 | import torch as th 7 | import numpy as np 8 | 9 | 10 | def normal_kl(mean1, logvar1, mean2, logvar2): 11 | """ 12 | Compute the KL divergence between two gaussians. 13 | Shapes are automatically broadcasted, so batches can be compared to 14 | scalars, among other use cases. 15 | """ 16 | tensor = None 17 | for obj in (mean1, logvar1, mean2, logvar2): 18 | if isinstance(obj, th.Tensor): 19 | tensor = obj 20 | break 21 | assert tensor is not None, "at least one argument must be a Tensor" 22 | 23 | # Force variances to be Tensors. Broadcasting helps convert scalars to 24 | # Tensors, but it does not work for th.exp(). 25 | logvar1, logvar2 = [ 26 | x if isinstance(x, th.Tensor) else th.tensor(x).to(tensor) 27 | for x in (logvar1, logvar2) 28 | ] 29 | 30 | return 0.5 * ( 31 | -1.0 32 | + logvar2 33 | - logvar1 34 | + th.exp(logvar1 - logvar2) 35 | + ((mean1 - mean2) ** 2) * th.exp(-logvar2) 36 | ) 37 | 38 | 39 | def approx_standard_normal_cdf(x): 40 | """ 41 | A fast approximation of the cumulative distribution function of the 42 | standard normal. 43 | """ 44 | return 0.5 * (1.0 + th.tanh(np.sqrt(2.0 / np.pi) * (x + 0.044715 * th.pow(x, 3)))) 45 | 46 | 47 | def continuous_gaussian_log_likelihood(x, *, means, log_scales): 48 | """ 49 | Compute the log-likelihood of a continuous Gaussian distribution. 50 | :param x: the targets 51 | :param means: the Gaussian mean Tensor. 52 | :param log_scales: the Gaussian log stddev Tensor. 
53 | :return: a tensor like x of log probabilities (in nats). 54 | """ 55 | centered_x = x - means 56 | inv_stdv = th.exp(-log_scales) 57 | normalized_x = centered_x * inv_stdv 58 | log_probs = th.distributions.Normal(th.zeros_like(x), th.ones_like(x)).log_prob(normalized_x) 59 | return log_probs 60 | 61 | 62 | def discretized_gaussian_log_likelihood(x, *, means, log_scales): 63 | """ 64 | Compute the log-likelihood of a Gaussian distribution discretizing to a 65 | given image. 66 | :param x: the target images. It is assumed that this was uint8 values, 67 | rescaled to the range [-1, 1]. 68 | :param means: the Gaussian mean Tensor. 69 | :param log_scales: the Gaussian log stddev Tensor. 70 | :return: a tensor like x of log probabilities (in nats). 71 | """ 72 | assert x.shape == means.shape == log_scales.shape 73 | centered_x = x - means 74 | inv_stdv = th.exp(-log_scales) 75 | plus_in = inv_stdv * (centered_x + 1.0 / 255.0) 76 | cdf_plus = approx_standard_normal_cdf(plus_in) 77 | min_in = inv_stdv * (centered_x - 1.0 / 255.0) 78 | cdf_min = approx_standard_normal_cdf(min_in) 79 | log_cdf_plus = th.log(cdf_plus.clamp(min=1e-12)) 80 | log_one_minus_cdf_min = th.log((1.0 - cdf_min).clamp(min=1e-12)) 81 | cdf_delta = cdf_plus - cdf_min 82 | log_probs = th.where( 83 | x < -0.999, 84 | log_cdf_plus, 85 | th.where(x > 0.999, log_one_minus_cdf_min, th.log(cdf_delta.clamp(min=1e-12))), 86 | ) 87 | assert log_probs.shape == x.shape 88 | return log_probs 89 | -------------------------------------------------------------------------------- /models/diffusion/respace.py: -------------------------------------------------------------------------------- 1 | # Modified from OpenAI's diffusion repos 2 | # GLIDE: https://github.com/openai/glide-text2im/blob/main/glide_text2im/gaussian_diffusion.py 3 | # ADM: https://github.com/openai/guided-diffusion/blob/main/guided_diffusion 4 | # IDDPM: https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py 5 | 6 | import numpy as np 7 | import torch as th 8 | 9 | from .gaussian_diffusion import GaussianDiffusion 10 | 11 | 12 | def space_timesteps(num_timesteps, section_counts): 13 | """ 14 | Create a list of timesteps to use from an original diffusion process, 15 | given the number of timesteps we want to take from equally-sized portions 16 | of the original process. 17 | For example, if there's 300 timesteps and the section counts are [10,15,20] 18 | then the first 100 timesteps are strided to be 10 timesteps, the second 100 19 | are strided to be 15 timesteps, and the final 100 are strided to be 20. 20 | If the stride is a string starting with "ddim", then the fixed striding 21 | from the DDIM paper is used, and only one section is allowed. 22 | :param num_timesteps: the number of diffusion steps in the original 23 | process to divide up. 24 | :param section_counts: either a list of numbers, or a string containing 25 | comma-separated numbers, indicating the step count 26 | per section. As a special case, use "ddimN" where N 27 | is a number of steps to use the striding from the 28 | DDIM paper. 29 | :return: a set of diffusion steps from the original process to use. 
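For example (illustrative), space_timesteps(300, "ddim25") keeps every 12th step, i.e. {0, 12, ..., 288}, which is exactly 25 timesteps.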
30 | """ 31 | if isinstance(section_counts, str): 32 | if section_counts.startswith("ddim"): 33 | desired_count = int(section_counts[len("ddim") :]) 34 | for i in range(1, num_timesteps): 35 | if len(range(0, num_timesteps, i)) == desired_count: 36 | return set(range(0, num_timesteps, i)) 37 | raise ValueError( 38 | f"cannot create exactly {num_timesteps} steps with an integer stride" 39 | ) 40 | section_counts = [int(x) for x in section_counts.split(",")] 41 | size_per = num_timesteps // len(section_counts) 42 | extra = num_timesteps % len(section_counts) 43 | start_idx = 0 44 | all_steps = [] 45 | for i, section_count in enumerate(section_counts): 46 | size = size_per + (1 if i < extra else 0) 47 | if size < section_count: 48 | raise ValueError( 49 | f"cannot divide section of {size} steps into {section_count}" 50 | ) 51 | if section_count <= 1: 52 | frac_stride = 1 53 | else: 54 | frac_stride = (size - 1) / (section_count - 1) 55 | cur_idx = 0.0 56 | taken_steps = [] 57 | for _ in range(section_count): 58 | taken_steps.append(start_idx + round(cur_idx)) 59 | cur_idx += frac_stride 60 | all_steps += taken_steps 61 | start_idx += size 62 | return set(all_steps) 63 | 64 | 65 | class SpacedDiffusion(GaussianDiffusion): 66 | """ 67 | A diffusion process which can skip steps in a base diffusion process. 68 | :param use_timesteps: a collection (sequence or set) of timesteps from the 69 | original diffusion process to retain. 70 | :param kwargs: the kwargs to create the base diffusion process. 71 | """ 72 | 73 | def __init__(self, use_timesteps, **kwargs): 74 | self.use_timesteps = set(use_timesteps) 75 | self.timestep_map = [] 76 | self.original_num_steps = len(kwargs["betas"]) 77 | 78 | base_diffusion = GaussianDiffusion(**kwargs) # pylint: disable=missing-kwoa 79 | last_alpha_cumprod = 1.0 80 | new_betas = [] 81 | for i, alpha_cumprod in enumerate(base_diffusion.alphas_cumprod): 82 | if i in self.use_timesteps: 83 | new_betas.append(1 - alpha_cumprod / last_alpha_cumprod) 84 | last_alpha_cumprod = alpha_cumprod 85 | self.timestep_map.append(i) 86 | kwargs["betas"] = np.array(new_betas) 87 | super().__init__(**kwargs) 88 | 89 | def p_mean_variance( 90 | self, model, *args, **kwargs 91 | ): # pylint: disable=signature-differs 92 | return super().p_mean_variance(self._wrap_model(model), *args, **kwargs) 93 | 94 | def training_losses( 95 | self, model, *args, **kwargs 96 | ): # pylint: disable=signature-differs 97 | return super().training_losses(self._wrap_model(model), *args, **kwargs) 98 | 99 | def condition_mean(self, cond_fn, *args, **kwargs): 100 | return super().condition_mean(self._wrap_model(cond_fn), *args, **kwargs) 101 | 102 | def condition_score(self, cond_fn, *args, **kwargs): 103 | return super().condition_score(self._wrap_model(cond_fn), *args, **kwargs) 104 | 105 | def _wrap_model(self, model): 106 | if isinstance(model, _WrappedModel): 107 | return model 108 | return _WrappedModel( 109 | model, self.timestep_map, self.original_num_steps 110 | ) 111 | 112 | def _scale_timesteps(self, t): 113 | # Scaling is done by the wrapped model. 
114 | return t 115 | 116 | 117 | class _WrappedModel: 118 | def __init__(self, model, timestep_map, original_num_steps): 119 | self.model = model 120 | self.timestep_map = timestep_map 121 | # self.rescale_timesteps = rescale_timesteps 122 | self.original_num_steps = original_num_steps 123 | 124 | def __call__(self, x, ts, **kwargs): 125 | map_tensor = th.tensor(self.timestep_map, device=ts.device, dtype=ts.dtype) 126 | new_ts = map_tensor[ts] 127 | # if self.rescale_timesteps: 128 | # new_ts = new_ts.float() * (1000.0 / self.original_num_steps) 129 | return self.model(x, new_ts, **kwargs) 130 | -------------------------------------------------------------------------------- /models/diffusion/timestep_sampler.py: -------------------------------------------------------------------------------- 1 | # Modified from OpenAI's diffusion repos 2 | # GLIDE: https://github.com/openai/glide-text2im/blob/main/glide_text2im/gaussian_diffusion.py 3 | # ADM: https://github.com/openai/guided-diffusion/blob/main/guided_diffusion 4 | # IDDPM: https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py 5 | 6 | from abc import ABC, abstractmethod 7 | 8 | import numpy as np 9 | import torch as th 10 | import torch.distributed as dist 11 | 12 | 13 | def create_named_schedule_sampler(name, diffusion): 14 | """ 15 | Create a ScheduleSampler from a library of pre-defined samplers. 16 | :param name: the name of the sampler. 17 | :param diffusion: the diffusion object to sample for. 18 | """ 19 | if name == "uniform": 20 | return UniformSampler(diffusion) 21 | elif name == "loss-second-moment": 22 | return LossSecondMomentResampler(diffusion) 23 | else: 24 | raise NotImplementedError(f"unknown schedule sampler: {name}") 25 | 26 | 27 | class ScheduleSampler(ABC): 28 | """ 29 | A distribution over timesteps in the diffusion process, intended to reduce 30 | variance of the objective. 31 | By default, samplers perform unbiased importance sampling, in which the 32 | objective's mean is unchanged. 33 | However, subclasses may override sample() to change how the resampled 34 | terms are reweighted, allowing for actual changes in the objective. 35 | """ 36 | 37 | @abstractmethod 38 | def weights(self): 39 | """ 40 | Get a numpy array of weights, one per diffusion step. 41 | The weights needn't be normalized, but must be positive. 42 | """ 43 | 44 | def sample(self, batch_size, device): 45 | """ 46 | Importance-sample timesteps for a batch. 47 | :param batch_size: the number of timesteps. 48 | :param device: the torch device to save to. 49 | :return: a tuple (timesteps, weights): 50 | - timesteps: a tensor of timestep indices. 51 | - weights: a tensor of weights to scale the resulting losses. 52 | """ 53 | w = self.weights() 54 | p = w / np.sum(w) 55 | indices_np = np.random.choice(len(p), size=(batch_size,), p=p) 56 | indices = th.from_numpy(indices_np).long().to(device) 57 | weights_np = 1 / (len(p) * p[indices_np]) 58 | weights = th.from_numpy(weights_np).float().to(device) 59 | return indices, weights 60 | 61 | 62 | class UniformSampler(ScheduleSampler): 63 | def __init__(self, diffusion): 64 | self.diffusion = diffusion 65 | self._weights = np.ones([diffusion.num_timesteps]) 66 | 67 | def weights(self): 68 | return self._weights 69 | 70 | 71 | class LossAwareSampler(ScheduleSampler): 72 | def update_with_local_losses(self, local_ts, local_losses): 73 | """ 74 | Update the reweighting using losses from a model. 
75 | Call this method from each rank with a batch of timesteps and the 76 | corresponding losses for each of those timesteps. 77 | This method will perform synchronization to make sure all of the ranks 78 | maintain the exact same reweighting. 79 | :param local_ts: an integer Tensor of timesteps. 80 | :param local_losses: a 1D Tensor of losses. 81 | """ 82 | batch_sizes = [ 83 | th.tensor([0], dtype=th.int32, device=local_ts.device) 84 | for _ in range(dist.get_world_size()) 85 | ] 86 | dist.all_gather( 87 | batch_sizes, 88 | th.tensor([len(local_ts)], dtype=th.int32, device=local_ts.device), 89 | ) 90 | 91 | # Pad all_gather batches to be the maximum batch size. 92 | batch_sizes = [x.item() for x in batch_sizes] 93 | max_bs = max(batch_sizes) 94 | 95 | timestep_batches = [th.zeros(max_bs).to(local_ts) for bs in batch_sizes] 96 | loss_batches = [th.zeros(max_bs).to(local_losses) for bs in batch_sizes] 97 | dist.all_gather(timestep_batches, local_ts) 98 | dist.all_gather(loss_batches, local_losses) 99 | timesteps = [ 100 | x.item() for y, bs in zip(timestep_batches, batch_sizes) for x in y[:bs] 101 | ] 102 | losses = [x.item() for y, bs in zip(loss_batches, batch_sizes) for x in y[:bs]] 103 | self.update_with_all_losses(timesteps, losses) 104 | 105 | @abstractmethod 106 | def update_with_all_losses(self, ts, losses): 107 | """ 108 | Update the reweighting using losses from a model. 109 | Sub-classes should override this method to update the reweighting 110 | using losses from the model. 111 | This method directly updates the reweighting without synchronizing 112 | between workers. It is called by update_with_local_losses from all 113 | ranks with identical arguments. Thus, it should have deterministic 114 | behavior to maintain state across workers. 115 | :param ts: a list of int timesteps. 116 | :param losses: a list of float losses, one per timestep. 117 | """ 118 | 119 | 120 | class LossSecondMomentResampler(LossAwareSampler): 121 | def __init__(self, diffusion, history_per_term=10, uniform_prob=0.001): 122 | self.diffusion = diffusion 123 | self.history_per_term = history_per_term 124 | self.uniform_prob = uniform_prob 125 | self._loss_history = np.zeros( 126 | [diffusion.num_timesteps, history_per_term], dtype=np.float64 127 | ) 128 | self._loss_counts = np.zeros([diffusion.num_timesteps], dtype=np.int64) 129 | 130 | def weights(self): 131 | if not self._warmed_up(): 132 | return np.ones([self.diffusion.num_timesteps], dtype=np.float64) 133 | weights = np.sqrt(np.mean(self._loss_history ** 2, axis=-1)) 134 | weights /= np.sum(weights) 135 | weights *= 1 - self.uniform_prob 136 | weights += self.uniform_prob / len(weights) 137 | return weights 138 | 139 | def update_with_all_losses(self, ts, losses): 140 | for t, loss in zip(ts, losses): 141 | if self._loss_counts[t] == self.history_per_term: 142 | # Shift out the oldest loss term.
143 | self._loss_history[t, :-1] = self._loss_history[t, 1:] 144 | self._loss_history[t, -1] = loss 145 | else: 146 | self._loss_history[t, self._loss_counts[t]] = loss 147 | self._loss_counts[t] += 1 148 | 149 | def _warmed_up(self): 150 | return (self._loss_counts == self.history_per_term).all() 151 | -------------------------------------------------------------------------------- /models/utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch.utils.checkpoint import checkpoint 4 | import numpy as np 5 | import math 6 | from functools import partial 7 | from itertools import repeat 8 | import collections.abc 9 | from .attention import MemEffAttention 10 | # From PyTorch internals 11 | def _ntuple(n): 12 | def parse(x): 13 | if isinstance(x, collections.abc.Iterable) and not isinstance(x, str): 14 | return tuple(x) 15 | return tuple(repeat(x, n)) 16 | return parse 17 | to_2tuple = _ntuple(2) 18 | 19 | def modulate(x, shift, scale): 20 | return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1) 21 | 22 | 23 | ################################################################################# 24 | # Embedding Layers for Timesteps and Class Labels # 25 | ################################################################################# 26 | 27 | class TimestepEmbedder(nn.Module): 28 | """ 29 | Embeds scalar timesteps into vector representations. 30 | """ 31 | def __init__(self, hidden_size, frequency_embedding_size=256): 32 | super().__init__() 33 | self.mlp = nn.Sequential( 34 | nn.Linear(frequency_embedding_size, hidden_size, bias=True), 35 | nn.SiLU(), 36 | nn.Linear(hidden_size, hidden_size, bias=True), 37 | ) 38 | self.frequency_embedding_size = frequency_embedding_size 39 | 40 | @staticmethod 41 | def timestep_embedding(t, dim, max_period=10000): 42 | """ 43 | Create sinusoidal timestep embeddings. 44 | :param t: a 1-D Tensor of N indices, one per batch element. 45 | These may be fractional. 46 | :param dim: the dimension of the output. 47 | :param max_period: controls the minimum frequency of the embeddings. 48 | :return: an (N, D) Tensor of positional embeddings.
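Example (illustrative): timestep_embedding(torch.tensor([0, 250, 999]), 256) returns a (3, 256) tensor whose first 128 columns are cosines and last 128 are sines of the scaled timesteps.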
49 | """ 50 | # https://github.com/openai/glide-text2im/blob/main/glide_text2im/nn.py 51 | half = dim // 2 52 | freqs = torch.exp( 53 | -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half 54 | ).to(device=t.device) 55 | args = t[:, None].float() * freqs[None] 56 | embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) 57 | if dim % 2: 58 | embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1) 59 | return embedding 60 | 61 | def forward(self, t): 62 | t_freq = self.timestep_embedding(t, self.frequency_embedding_size) 63 | t_emb = self.mlp(t_freq) 64 | return t_emb 65 | 66 | class Mlp(nn.Module): 67 | """ MLP as used in Vision Transformer, MLP-Mixer and related networks 68 | """ 69 | def __init__( 70 | self, 71 | in_features, 72 | hidden_features=None, 73 | out_features=None, 74 | act_layer=nn.GELU, 75 | norm_layer=None, 76 | bias=True, 77 | drop=0., 78 | use_conv=False, 79 | ): 80 | super().__init__() 81 | out_features = out_features or in_features 82 | hidden_features = hidden_features or in_features 83 | bias = to_2tuple(bias) 84 | drop_probs = to_2tuple(drop) 85 | linear_layer = partial(nn.Conv2d, kernel_size=1) if use_conv else nn.Linear 86 | 87 | self.fc1 = linear_layer(in_features, hidden_features, bias=bias[0]) 88 | self.act = act_layer() 89 | self.drop1 = nn.Dropout(drop_probs[0]) 90 | self.norm = norm_layer(hidden_features) if norm_layer is not None else nn.Identity() 91 | self.fc2 = linear_layer(hidden_features, out_features, bias=bias[1]) 92 | self.drop2 = nn.Dropout(drop_probs[1]) 93 | 94 | def forward(self, x): 95 | x = self.fc1(x) 96 | x = self.act(x) 97 | x = self.drop1(x) 98 | x = self.norm(x) 99 | x = self.fc2(x) 100 | x = self.drop2(x) 101 | return x -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | einops 2 | omegaconf 3 | opencv-python 4 | libigl 5 | trimesh==4.2.0 6 | pygltflib 7 | pymeshlab==0.2 8 | PyMCubes 9 | xatlas 10 | git+https://github.com/NVlabs/nvdiffrast/ 11 | scikit-learn 12 | open_clip_torch 13 | triton==2.1.0 14 | rembg 15 | gradio 16 | tqdm 17 | transformers==4.40.1 18 | diffusers==0.19.3 19 | ninja 20 | imageio 21 | imageio-ffmpeg 22 | gradio-litmodel3d==0.0.1 23 | jaxtyping==0.2.31 -------------------------------------------------------------------------------- /scripts/cache_conditioner.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import io 4 | 5 | import torch 6 | import numpy as np 7 | from omegaconf import OmegaConf 8 | 9 | from torch.utils.data import DataLoader 10 | 11 | from dva.io import load_from_config 12 | from dva.ray_marcher import RayMarcher 13 | from dva.utils import to_device 14 | 15 | import logging 16 | from time import time 17 | logger = logging.getLogger("cache_conditioner.py") 18 | 19 | def main(config): 20 | logging.basicConfig(level=logging.INFO) 21 | 22 | rank = 0 23 | seed = 42 24 | device = 0 25 | torch.manual_seed(seed) 26 | torch.cuda.set_device(device) 27 | is_master = rank == 0 28 | 29 | dataset = load_from_config(config.dataset) 30 | conditioner = load_from_config(config.model.conditioner) 31 | conditioner = conditioner.to(device) 32 | 33 | # computing values for the given viewpoints 34 | rm = RayMarcher( 35 | config.image_height, 36 | config.image_width, 37 | **config.rm, 38 | ).to(device) 39 | 40 | loader = DataLoader( 41 | dataset, 42 | 
batch_size=config.train.get("batch_size", 4), 43 | pin_memory=True, 44 | num_workers=config.train.get("n_workers", 1), 45 | drop_last=False, 46 | ) 47 | 48 | conditioner.eval() 49 | iteration = 0 50 | for b, batch in enumerate(loader): 51 | logger.info(f"Iteration {iteration}") 52 | batch = to_device(batch, device) 53 | bs = batch['gt'].shape[0] 54 | with torch.no_grad(): 55 | y = conditioner(batch, rm, amp=False, precision_dtype=torch.float32) 56 | for bidx in range(bs): 57 | fitted_param = y[bidx, ...].clone() 58 | folder = batch['folder'][bidx] 59 | key = batch['key'][bidx] 60 | fitted_param_url = "./data/obj-2048-518reso-dino-cond/{}{}.pt".format(folder, key) 61 | torch.save(fitted_param, fitted_param_url) 62 | iteration += 1 63 | 64 | 65 | if __name__ == "__main__": 66 | torch.backends.cudnn.benchmark = True 67 | # manually enable tf32 to get speedup on A100 GPUs 68 | torch.backends.cuda.matmul.allow_tf32 = True 69 | torch.backends.cudnn.allow_tf32 = True 70 | # set config 71 | config = OmegaConf.load(str(sys.argv[1])) 72 | config_cli = OmegaConf.from_cli(args_list=sys.argv[2:]) 73 | if config_cli: 74 | logger.info("overriding with following values from args:") 75 | logger.info(OmegaConf.to_yaml(config_cli)) 76 | config = OmegaConf.merge(config, config_cli) 77 | 78 | main(config) 79 | -------------------------------------------------------------------------------- /scripts/cache_vae.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import io 4 | 5 | import torch 6 | import numpy as np 7 | from omegaconf import OmegaConf 8 | 9 | from torch.utils.data import DataLoader 10 | 11 | from dva.io import load_from_config 12 | from dva.utils import to_device 13 | 14 | import logging 15 | from time import time 16 | logger = logging.getLogger("cache_vae.py") 17 | 18 | def main(config): 19 | logging.basicConfig(level=logging.INFO) 20 | 21 | rank = 0 22 | seed = 42 23 | device = 0 24 | torch.manual_seed(seed) 25 | torch.cuda.set_device(device) 26 | is_master = rank == 0 27 | 28 | dataset = load_from_config(config.dataset) 29 | vae = load_from_config(config.model.vae) 30 | vae_state_dict = torch.load(config.model.vae_checkpoint_path, map_location='cpu') 31 | vae.load_state_dict(vae_state_dict['model_state_dict']) 32 | 33 | vae = vae.to(device) 34 | loader = DataLoader( 35 | dataset, 36 | batch_size=config.train.get("batch_size", 4), 37 | pin_memory=True, 38 | num_workers=config.train.get("n_workers", 1), 39 | drop_last=False, 40 | ) 41 | 42 | vae.eval() 43 | iteration = 0 44 | for b, batch in enumerate(loader): 45 | logger.info(f"Iteration {iteration}") 46 | batch = to_device(batch, device) 47 | bs = batch['gt'].shape[0] 48 | with torch.no_grad(): 49 | latent = vae.encode(batch['gt'].reshape(bs*config.model.num_prims, config.model.dim_feat, config.model.prim_shape, config.model.prim_shape, config.model.prim_shape)).parameters 50 | latent = latent.reshape(bs, config.model.num_prims, *latent.shape[1:]).detach() 51 | 52 | for bidx in range(bs): 53 | fitted_param = latent[bidx, ...].clone() 54 | folder = batch['folder'][bidx] 55 | key = batch['key'][bidx] 56 | fitted_param_url = "./data/klvae_2048_scaleup_cache/vae-{}{}.pt".format(folder, key) 57 | torch.save(fitted_param, fitted_param_url) 58 | iteration += 1 59 | 60 | 61 | if __name__ == "__main__": 62 | torch.backends.cudnn.benchmark = True 63 | # manually enable tf32 to get speedup on A100 GPUs 64 | torch.backends.cuda.matmul.allow_tf32 = True 65 | torch.backends.cudnn.allow_tf32 = 
True 66 | # set config 67 | config = OmegaConf.load(str(sys.argv[1])) 68 | config_cli = OmegaConf.from_cli(args_list=sys.argv[2:]) 69 | if config_cli: 70 | logger.info("overriding with following values from args:") 71 | logger.info(OmegaConf.to_yaml(config_cli)) 72 | config = OmegaConf.merge(config, config_cli) 73 | 74 | main(config) 75 | -------------------------------------------------------------------------------- /simple-knn/ext.cpp: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (C) 2023, Inria 3 | * GRAPHDECO research group, https://team.inria.fr/graphdeco 4 | * All rights reserved. 5 | * 6 | * This software is free for non-commercial, research and evaluation use 7 | * under the terms of the LICENSE.md file. 8 | * 9 | * For inquiries contact george.drettakis@inria.fr 10 | */ 11 | 12 | #include <torch/extension.h> 13 | #include "spatial.h" 14 | 15 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 16 | m.def("distCUDA2", &distCUDA2); 17 | } 18 | -------------------------------------------------------------------------------- /simple-knn/setup.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (C) 2023, Inria 3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco 4 | # All rights reserved. 5 | # 6 | # This software is free for non-commercial, research and evaluation use 7 | # under the terms of the LICENSE.md file. 8 | # 9 | # For inquiries contact george.drettakis@inria.fr 10 | # 11 | 12 | from setuptools import setup 13 | from torch.utils.cpp_extension import CUDAExtension, BuildExtension 14 | import os 15 | 16 | cxx_compiler_flags = [] 17 | 18 | if os.name == 'nt': 19 | cxx_compiler_flags.append("/wd4624") 20 | 21 | setup( 22 | name="simple_knn", 23 | ext_modules=[ 24 | CUDAExtension( 25 | name="simple_knn._C", 26 | sources=[ 27 | "spatial.cu", 28 | "simple_knn.cu", 29 | "ext.cpp"], 30 | extra_compile_args={"nvcc": [ 31 | "-use_fast_math", 32 | "-arch=sm_70", 33 | "-std=c++17", 34 | "-lineinfo", 35 | ], "cxx": cxx_compiler_flags}) 36 | ], 37 | cmdclass={ 38 | 'build_ext': BuildExtension 39 | } 40 | ) 41 | -------------------------------------------------------------------------------- /simple-knn/simple_knn.cu: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (C) 2023, Inria 3 | * GRAPHDECO research group, https://team.inria.fr/graphdeco 4 | * All rights reserved. 5 | * 6 | * This software is free for non-commercial, research and evaluation use 7 | * under the terms of the LICENSE.md file.
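 *
 * For every input point, this file computes the mean of the squared
 * distances to its three nearest neighbours: points are sorted by Morton
 * code, reduced into boxes of BOX_SIZE consecutive points, and boxes that
 * cannot contain a candidate closer than the current third-best distance
 * are skipped during the search.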
8 | * 9 | * For inquiries contact george.drettakis@inria.fr 10 | */ 11 | 12 | #define BOX_SIZE 1024 13 | 14 | #include "cuda_runtime.h" 15 | #include "device_launch_parameters.h" 16 | #include "simple_knn.h" 17 | #include <cub/cub.cuh> 18 | #include <cub/device/device_radix_sort.cuh> 19 | #include <vector> 20 | #include <cuda_runtime_api.h> 21 | #include <thrust/device_vector.h> 22 | #include <thrust/sequence.h> 23 | #define __CUDACC__ 24 | #include <cooperative_groups.h> 25 | #include <cooperative_groups/reduce.h> 26 | 27 | namespace cg = cooperative_groups; 28 | 29 | struct CustomMin 30 | { 31 | __device__ __forceinline__ 32 | float3 operator()(const float3& a, const float3& b) const { 33 | return { min(a.x, b.x), min(a.y, b.y), min(a.z, b.z) }; 34 | } 35 | }; 36 | 37 | struct CustomMax 38 | { 39 | __device__ __forceinline__ 40 | float3 operator()(const float3& a, const float3& b) const { 41 | return { max(a.x, b.x), max(a.y, b.y), max(a.z, b.z) }; 42 | } 43 | }; 44 | 45 | __host__ __device__ uint32_t prepMorton(uint32_t x) 46 | { 47 | x = (x | (x << 16)) & 0x030000FF; 48 | x = (x | (x << 8)) & 0x0300F00F; 49 | x = (x | (x << 4)) & 0x030C30C3; 50 | x = (x | (x << 2)) & 0x09249249; 51 | return x; 52 | } 53 | 54 | __host__ __device__ uint32_t coord2Morton(float3 coord, float3 minn, float3 maxx) 55 | { 56 | uint32_t x = prepMorton(((coord.x - minn.x) / (maxx.x - minn.x)) * ((1 << 10) - 1)); 57 | uint32_t y = prepMorton(((coord.y - minn.y) / (maxx.y - minn.y)) * ((1 << 10) - 1)); 58 | uint32_t z = prepMorton(((coord.z - minn.z) / (maxx.z - minn.z)) * ((1 << 10) - 1)); 59 | 60 | return x | (y << 1) | (z << 2); 61 | } 62 | 63 | __global__ void coord2Morton(int P, const float3* points, float3 minn, float3 maxx, uint32_t* codes) 64 | { 65 | auto idx = cg::this_grid().thread_rank(); 66 | if (idx >= P) 67 | return; 68 | 69 | codes[idx] = coord2Morton(points[idx], minn, maxx); 70 | } 71 | 72 | struct MinMax 73 | { 74 | float3 minn; 75 | float3 maxx; 76 | }; 77 | 78 | __global__ void boxMinMax(uint32_t P, float3* points, uint32_t* indices, MinMax* boxes) 79 | { 80 | auto idx = cg::this_grid().thread_rank(); 81 | 82 | MinMax me; 83 | if (idx < P) 84 | { 85 | me.minn = points[indices[idx]]; 86 | me.maxx = points[indices[idx]]; 87 | } 88 | else 89 | { 90 | me.minn = { FLT_MAX, FLT_MAX, FLT_MAX }; 91 | me.maxx = { -FLT_MAX,-FLT_MAX,-FLT_MAX }; 92 | } 93 | 94 | __shared__ MinMax redResult[BOX_SIZE]; 95 | 96 | for (int off = BOX_SIZE / 2; off >= 1; off /= 2) 97 | { 98 | if (threadIdx.x < 2 * off) 99 | redResult[threadIdx.x] = me; 100 | __syncthreads(); 101 | 102 | if (threadIdx.x < off) 103 | { 104 | MinMax other = redResult[threadIdx.x + off]; 105 | me.minn.x = min(me.minn.x, other.minn.x); 106 | me.minn.y = min(me.minn.y, other.minn.y); 107 | me.minn.z = min(me.minn.z, other.minn.z); 108 | me.maxx.x = max(me.maxx.x, other.maxx.x); 109 | me.maxx.y = max(me.maxx.y, other.maxx.y); 110 | me.maxx.z = max(me.maxx.z, other.maxx.z); 111 | } 112 | __syncthreads(); 113 | } 114 | 115 | if (threadIdx.x == 0) 116 | boxes[blockIdx.x] = me; 117 | } 118 | 119 | __device__ __host__ float distBoxPoint(const MinMax& box, const float3& p) 120 | { 121 | float3 diff = { 0, 0, 0 }; 122 | if (p.x < box.minn.x || p.x > box.maxx.x) 123 | diff.x = min(abs(p.x - box.minn.x), abs(p.x - box.maxx.x)); 124 | if (p.y < box.minn.y || p.y > box.maxx.y) 125 | diff.y = min(abs(p.y - box.minn.y), abs(p.y - box.maxx.y)); 126 | if (p.z < box.minn.z || p.z > box.maxx.z) 127 | diff.z = min(abs(p.z - box.minn.z), abs(p.z - box.maxx.z)); 128 | return diff.x * diff.x + diff.y * diff.y + diff.z * diff.z; 129 | } 130 | 131 | template<int K> 132 | __device__ void updateKBest(const float3& ref, const float3& point, float* knn) 133 | {
134 | float3 d = { point.x - ref.x, point.y - ref.y, point.z - ref.z }; 135 | float dist = d.x * d.x + d.y * d.y + d.z * d.z; 136 | for (int j = 0; j < K; j++) 137 | { 138 | if (knn[j] > dist) 139 | { 140 | float t = knn[j]; 141 | knn[j] = dist; 142 | dist = t; 143 | } 144 | } 145 | } 146 | 147 | __global__ void boxMeanDist(uint32_t P, float3* points, uint32_t* indices, MinMax* boxes, float* dists) 148 | { 149 | int idx = cg::this_grid().thread_rank(); 150 | if (idx >= P) 151 | return; 152 | 153 | float3 point = points[indices[idx]]; 154 | float best[3] = { FLT_MAX, FLT_MAX, FLT_MAX }; 155 | 156 | for (int i = max(0, idx - 3); i <= min(P - 1, idx + 3); i++) 157 | { 158 | if (i == idx) 159 | continue; 160 | updateKBest<3>(point, points[indices[i]], best); 161 | } 162 | 163 | float reject = best[2]; 164 | best[0] = FLT_MAX; 165 | best[1] = FLT_MAX; 166 | best[2] = FLT_MAX; 167 | 168 | for (int b = 0; b < (P + BOX_SIZE - 1) / BOX_SIZE; b++) 169 | { 170 | MinMax box = boxes[b]; 171 | float dist = distBoxPoint(box, point); 172 | if (dist > reject || dist > best[2]) 173 | continue; 174 | 175 | for (int i = b * BOX_SIZE; i < min(P, (b + 1) * BOX_SIZE); i++) 176 | { 177 | if (i == idx) 178 | continue; 179 | updateKBest<3>(point, points[indices[i]], best); 180 | } 181 | } 182 | dists[indices[idx]] = (best[0] + best[1] + best[2]) / 3.0f; 183 | } 184 | 185 | void SimpleKNN::knn(int P, float3* points, float* meanDists) 186 | { 187 | float3* result; 188 | cudaMalloc(&result, sizeof(float3)); 189 | size_t temp_storage_bytes; 190 | 191 | float3 init = { 0, 0, 0 }, minn, maxx; 192 | 193 | cub::DeviceReduce::Reduce(nullptr, temp_storage_bytes, points, result, P, CustomMin(), init); 194 | thrust::device_vector<char> temp_storage(temp_storage_bytes); 195 | 196 | cub::DeviceReduce::Reduce(temp_storage.data().get(), temp_storage_bytes, points, result, P, CustomMin(), init); 197 | cudaMemcpy(&minn, result, sizeof(float3), cudaMemcpyDeviceToHost); 198 | 199 | cub::DeviceReduce::Reduce(temp_storage.data().get(), temp_storage_bytes, points, result, P, CustomMax(), init); 200 | cudaMemcpy(&maxx, result, sizeof(float3), cudaMemcpyDeviceToHost); 201 | 202 | thrust::device_vector<uint32_t> morton(P); 203 | thrust::device_vector<uint32_t> morton_sorted(P); 204 | coord2Morton << <(P + 255) / 256, 256 >> > (P, points, minn, maxx, morton.data().get()); 205 | 206 | thrust::device_vector<uint32_t> indices(P); 207 | thrust::sequence(indices.begin(), indices.end()); 208 | thrust::device_vector<uint32_t> indices_sorted(P); 209 | 210 | cub::DeviceRadixSort::SortPairs(nullptr, temp_storage_bytes, morton.data().get(), morton_sorted.data().get(), indices.data().get(), indices_sorted.data().get(), P); 211 | temp_storage.resize(temp_storage_bytes); 212 | 213 | cub::DeviceRadixSort::SortPairs(temp_storage.data().get(), temp_storage_bytes, morton.data().get(), morton_sorted.data().get(), indices.data().get(), indices_sorted.data().get(), P); 214 | 215 | uint32_t num_boxes = (P + BOX_SIZE - 1) / BOX_SIZE; 216 | thrust::device_vector<MinMax> boxes(num_boxes); 217 | boxMinMax << <num_boxes, BOX_SIZE >> > (P, points, indices_sorted.data().get(), boxes.data().get()); 218 | boxMeanDist << <num_boxes, BOX_SIZE >> > (P, points, indices_sorted.data().get(), boxes.data().get(), meanDists); 219 | 220 | cudaFree(result); 221 | } -------------------------------------------------------------------------------- /simple-knn/simple_knn.h: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (C) 2023, Inria 3 | * GRAPHDECO research group, https://team.inria.fr/graphdeco 4 | * All rights
reserved. 5 | * 6 | * This software is free for non-commercial, research and evaluation use 7 | * under the terms of the LICENSE.md file. 8 | * 9 | * For inquiries contact george.drettakis@inria.fr 10 | */ 11 | 12 | #ifndef SIMPLEKNN_H_INCLUDED 13 | #define SIMPLEKNN_H_INCLUDED 14 | 15 | class SimpleKNN 16 | { 17 | public: 18 | static void knn(int P, float3* points, float* meanDists); 19 | }; 20 | 21 | #endif -------------------------------------------------------------------------------- /simple-knn/simple_knn/.gitkeep: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/3DTopia/3DTopia-XL/af602b003e9d137e37e1d883c18b50ed65c06f26/simple-knn/simple_knn/.gitkeep -------------------------------------------------------------------------------- /simple-knn/spatial.cu: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (C) 2023, Inria 3 | * GRAPHDECO research group, https://team.inria.fr/graphdeco 4 | * All rights reserved. 5 | * 6 | * This software is free for non-commercial, research and evaluation use 7 | * under the terms of the LICENSE.md file. 8 | * 9 | * For inquiries contact george.drettakis@inria.fr 10 | */ 11 | 12 | #include "spatial.h" 13 | #include "simple_knn.h" 14 | 15 | torch::Tensor 16 | distCUDA2(const torch::Tensor& points) 17 | { 18 | const int P = points.size(0); 19 | 20 | auto float_opts = points.options().dtype(torch::kFloat32); 21 | torch::Tensor means = torch::full({P}, 0.0, float_opts); 22 | 23 | SimpleKNN::knn(P, (float3*)points.contiguous().data<float>(), means.contiguous().data<float>()); 24 | 25 | return means; 26 | } -------------------------------------------------------------------------------- /simple-knn/spatial.h: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (C) 2023, Inria 3 | * GRAPHDECO research group, https://team.inria.fr/graphdeco 4 | * All rights reserved. 5 | * 6 | * This software is free for non-commercial, research and evaluation use 7 | * under the terms of the LICENSE.md file.
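 *
 * distCUDA2(points) takes an (N, 3) float32 CUDA tensor and returns an (N,)
 * tensor holding, for each point, the mean squared distance to its three
 * nearest neighbours. Illustrative Python usage (the extension is built as
 * simple_knn._C by simple-knn/setup.py):
 *
 *     from simple_knn._C import distCUDA2
 *     mean_sq_dist = distCUDA2(points)  # points: (N, 3) float32 tensor on the GPU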
8 | * 9 | * For inquiries contact george.drettakis@inria.fr 10 | */ 11 | 12 | #include 13 | 14 | torch::Tensor distCUDA2(const torch::Tensor& points); -------------------------------------------------------------------------------- /train_fitting.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | 4 | import torch 5 | import numpy as np 6 | from omegaconf import OmegaConf 7 | 8 | import torch.distributed as dist 9 | from torch.utils.data import DataLoader 10 | 11 | from dva.ray_marcher import RayMarcher 12 | from dva.io import load_from_config 13 | from dva.losses import process_losses 14 | from dva.utils import to_device 15 | from dva.visualize import render_primsdf, visualize_primsdf_box 16 | 17 | import logging 18 | 19 | device = torch.device("cuda") 20 | 21 | logger = logging.getLogger("train_fitting.py") 22 | 23 | def main(config): 24 | dist.init_process_group("nccl") 25 | 26 | logging.basicConfig(level=logging.INFO) 27 | 28 | local_rank = int(os.environ["LOCAL_RANK"]) 29 | device = torch.device(f"cuda:{local_rank}") 30 | torch.cuda.set_device(device) 31 | 32 | os.makedirs(f"{config.output_dir}/checkpoints", exist_ok=True) 33 | OmegaConf.save(config, f"{config.output_dir}/config.yml") 34 | logger.info(f"saving results to {config.output_dir}") 35 | logger.info(f"starting training with the config: {OmegaConf.to_yaml(config)}") 36 | 37 | dataset = load_from_config(config.dataset) 38 | train_sampler = torch.utils.data.distributed.DistributedSampler(dataset) 39 | loader = DataLoader( 40 | dataset, 41 | batch_size=config.train.get("batch_size", 4), 42 | pin_memory=False, 43 | sampler=train_sampler, 44 | num_workers=config.train.get("n_workers", 1), 45 | drop_last=True, 46 | worker_init_fn=lambda _: np.random.seed(), 47 | ) 48 | 49 | model = load_from_config(config.model, mesh_obj=dataset.mesh_obj, f_sdf=dataset.f_sdf, geo_fn=dataset.geo_fn_list, asset_list=dataset.asset_list) 50 | model = model.to(device) 51 | 52 | # computing values for the given viewpoints 53 | rm = RayMarcher( 54 | config.image_height, 55 | config.image_width, 56 | **config.rm, 57 | ).to(device) 58 | 59 | loss_fn = load_from_config(config.loss).to(device) 60 | optimizer = load_from_config(config.optimizer, params=model.parameters()) 61 | iteration = 0 62 | model.train() 63 | # stage 1, optimizing SDF only 64 | while True: 65 | if iteration >= config.train.shape_fit_steps: 66 | model.eval() 67 | visualize_primsdf_box("{}/{:06d}_box_nosampling.png".format(config.output_dir, iteration), model, rm, device) 68 | render_primsdf("{}/{:06d}_rendering.png".format(config.output_dir, iteration), model, rm, device) 69 | model.train() 70 | break 71 | for b, batch in enumerate(loader): 72 | batch = to_device(batch, device) 73 | for k, v in batch.items(): 74 | batch[k] = v.reshape(config.train.batch_size * config.dataset.chunk_size, *v.shape[2:]) 75 | 76 | if local_rank == 0 and batch is None: 77 | logger.info(f"batch {b} is None, skipping") 78 | continue 79 | 80 | if local_rank == 0 and iteration >= config.train.shape_fit_steps: 81 | logger.info(f"stopping after {config.train.shape_fit_steps}") 82 | break 83 | 84 | batch['pts'].requires_grad_(True) 85 | preds = model(batch['pts']) 86 | preds['prim_scale'] = (1 / model.scale.reshape(1, model.num_prims, 1).repeat(1, 1, 3)) 87 | 88 | loss, loss_dict = loss_fn(batch, preds, iteration) 89 | _loss_dict = process_losses(loss_dict) 90 | 91 | if torch.isnan(loss): 92 | loss_str = " ".join([f"{k}={v:.4f}" for k, v in 
_loss_dict.items()]) 93 | logger.warning(f"some of the losses is NaN, skipping: {loss_str}") 94 | continue 95 | 96 | optimizer.zero_grad() 97 | loss.backward() 98 | optimizer.step() 99 | 100 | if local_rank == 0 and iteration % config.train.log_every_n_steps == 0: 101 | loss_str = " ".join([f"{k}={v:.4f}" for k, v in _loss_dict.items()]) 102 | logger.info(f"iter={iteration}: {loss_str}") 103 | 104 | if ( 105 | local_rank == 0 106 | # and iteration 107 | and iteration % config.train.summary_every_n_steps == 0 108 | ): 109 | logger.info( 110 | f"saving summary to {config.output_dir} after {iteration} steps" 111 | ) 112 | 113 | iteration += 1 114 | 115 | pass 116 | 117 | # stage 2, optimizing texture 118 | optimizer_tex = load_from_config(config.optimizer, params=[model.feat_param]) 119 | while True: 120 | if iteration >= config.train.tex_fit_steps: 121 | if (local_rank == 0): 122 | logger.info(f"Texture Optimization Done: saving checkpoint after {iteration} steps") 123 | model.eval() 124 | visualize_primsdf_box("{}/{:06d}_box_nosampling.png".format(config.output_dir, iteration), model, rm, device) 125 | render_primsdf("{}/{:06d}_rendering.png".format(config.output_dir, iteration), model, rm, device) 126 | model.train() 127 | if config.train.save_fp16: 128 | model = model.half() 129 | params = { 130 | "model_state_dict": model.state_dict(), 131 | } 132 | torch.save(params, f"{config.output_dir}/checkpoints/tex-{iteration:06d}.pt") 133 | break 134 | for b, batch in enumerate(loader): 135 | batch = to_device(batch, device) 136 | for k, v in batch.items(): 137 | batch[k] = v.reshape(config.train.batch_size * config.dataset.chunk_size, *v.shape[2:]) 138 | 139 | if local_rank == 0 and batch is None: 140 | logger.info(f"batch {b} is None, skipping") 141 | continue 142 | 143 | if local_rank == 0 and iteration >= config.train.tex_fit_steps: 144 | logger.info(f"stopping after {config.train.tex_fit_steps}") 145 | break 146 | 147 | preds = model(batch['tex_pts']) 148 | preds['prim_scale'] = (1 / model.scale.reshape(1, model.num_prims, 1).repeat(1, 1, 3)) 149 | 150 | loss, loss_dict = loss_fn(batch, preds, iteration) 151 | _loss_dict = process_losses(loss_dict) 152 | 153 | if torch.isnan(loss): 154 | loss_str = " ".join([f"{k}={v:.4f}" for k, v in _loss_dict.items()]) 155 | logger.warning(f"some of the losses is NaN, skipping: {loss_str}") 156 | continue 157 | 158 | optimizer_tex.zero_grad() 159 | loss.backward() 160 | optimizer_tex.step() 161 | 162 | if local_rank == 0 and iteration % config.train.log_every_n_steps == 0: 163 | loss_str = " ".join([f"{k}={v:.4f}" for k, v in _loss_dict.items()]) 164 | logger.info(f"iter={iteration}: {loss_str}") 165 | iteration += 1 166 | pass 167 | 168 | 169 | if __name__ == "__main__": 170 | torch.backends.cudnn.benchmark = True 171 | # set config 172 | config = OmegaConf.load(str(sys.argv[1])) 173 | config_cli = OmegaConf.from_cli(args_list=sys.argv[2:]) 174 | if config_cli: 175 | logger.info("overriding with following values from args:") 176 | logger.info(OmegaConf.to_yaml(config_cli)) 177 | config = OmegaConf.merge(config, config_cli) 178 | 179 | main(config) 180 | -------------------------------------------------------------------------------- /train_vae.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import io 4 | 5 | import torch 6 | import numpy as np 7 | from omegaconf import OmegaConf 8 | 9 | import torch.distributed as dist 10 | from torch.nn.parallel import DistributedDataParallel as DDP 
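# Distributed VAE training over NCCL; a typical launch (illustrative) is:
#   torchrun --nproc_per_node=<num_gpus> train_vae.py configs/train_vae.yml key=value ...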
11 | from torch.utils.data import DataLoader 12 | 13 | from dva.ray_marcher import RayMarcher 14 | from dva.io import load_from_config 15 | from dva.losses import process_losses 16 | from dva.utils import to_device 17 | from dva.visualize import visualize_primvolume 18 | 19 | import logging 20 | import time 21 | logger = logging.getLogger("train_ae.py") 22 | 23 | def main(config): 24 | dist.init_process_group("nccl") 25 | logging.basicConfig(level=logging.INFO) 26 | 27 | rank = int(os.environ["RANK"]) 28 | assert rank == dist.get_rank() 29 | device = int(os.environ["LOCAL_RANK"]) 30 | seed = config.global_seed * dist.get_world_size() + rank 31 | torch.manual_seed(seed) 32 | torch.cuda.set_device(device) 33 | logger.info(f"Starting rank={rank}, seed={seed}, world_size={dist.get_world_size()}.") 34 | is_master = rank == 0 35 | 36 | if is_master: 37 | os.makedirs(f"{config.output_dir}/checkpoints", exist_ok=True) 38 | OmegaConf.save(config, f"{config.output_dir}/config.yml") 39 | logger.info(f"saving results to {config.output_dir}") 40 | logger.info(f"starting training with the config: {OmegaConf.to_yaml(config)}") 41 | 42 | amp = config.train.amp 43 | scaler = torch.cuda.amp.GradScaler() if amp else None 44 | dataset = load_from_config(config.dataset) 45 | if not config.dataset.incl_srt: 46 | assert config.model.vae.in_channels == config.model.dim_feat 47 | model = load_from_config(config.model.vae) 48 | if config.checkpoint_path: 49 | state_dict = torch.load(config.checkpoint_path, map_location='cpu') 50 | model.load_state_dict(state_dict['model_state_dict']) 51 | iteration = 0 52 | model = DDP(model.to(device), device_ids=[device]) 53 | 54 | # computing values for the given viewpoints 55 | rm = RayMarcher( 56 | config.image_height, 57 | config.image_width, 58 | **config.rm, 59 | ).to(device) 60 | 61 | loss_fn = load_from_config(config.loss).to(device) 62 | optimizer = load_from_config(config.optimizer, params=model.parameters()) 63 | 64 | train_sampler = torch.utils.data.distributed.DistributedSampler( 65 | dataset, 66 | num_replicas=dist.get_world_size(), 67 | rank=rank, 68 | shuffle=True, 69 | seed=config.global_seed, 70 | ) 71 | loader = DataLoader( 72 | dataset, 73 | batch_size=config.train.get("batch_size", 4), 74 | pin_memory=True, 75 | sampler=train_sampler, 76 | num_workers=config.train.get("n_workers", 1), 77 | drop_last=True, 78 | ) 79 | 80 | model.train() 81 | for epoch in range(config.train.n_epochs): 82 | train_sampler.set_epoch(epoch) 83 | if is_master: 84 | ts = time.time() 85 | for b, batch in enumerate(loader): 86 | with torch.cuda.amp.autocast(enabled=amp): 87 | batch = to_device(batch, device) 88 | if is_master: 89 | te = time.time() 90 | data_time = te - ts 91 | ts = te 92 | bs = batch['gt'].shape[0] 93 | batch['gt'] = batch['gt'].reshape(bs * config.model.num_prims, config.model.vae.in_channels, config.model.prim_shape, config.model.prim_shape, config.model.prim_shape) 94 | 95 | preds = {} 96 | recon, posterior = model(batch['gt']) 97 | preds['recon'] = recon 98 | preds['posterior'] = posterior 99 | 100 | loss, loss_dict = loss_fn(batch, preds, iteration) 101 | _loss_dict = process_losses(loss_dict) 102 | 103 | if is_master: 104 | te = time.time() 105 | model_time = te - ts 106 | ts = te 107 | 108 | if torch.isnan(loss): 109 | loss_str = " ".join([f"{k}={v:.4f}" for k, v in _loss_dict.items()]) 110 | logger.warning(f"some of the losses is NaN, skipping: {loss_str}") 111 | continue 112 | 113 | optimizer.zero_grad() 114 | if amp: 115 | assert scaler is not None 116 | 
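# AMP: scale the loss before backward so fp16 gradients do not underflow;
# scaler.step() unscales them before the optimizer update and scaler.update()
# adjusts the scale factor for the next iteration.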
scaler.scale(loss).backward() 117 | scaler.step(optimizer) 118 | scaler.update() 119 | else: 120 | loss.backward() 121 | optimizer.step() 122 | 123 | if is_master: 124 | te = time.time() 125 | bp_time = te - ts 126 | ts = te 127 | 128 | if is_master and iteration % config.train.log_every_n_steps == 0: 129 | loss_str = " ".join([f"{k}={v:.4f}" for k, v in _loss_dict.items()]) 130 | logger.info(f"epoch={epoch}, iter={iteration}[data={data_time:.3f}|model={model_time:.3f}|bp={bp_time:.3f}]: {loss_str}") 131 | 132 | if iteration % config.train.summary_every_n_steps == 0: 133 | if is_master: 134 | if config.dataset.incl_srt: 135 | recon_srt_param = torch.mean(recon[:, 0:4, ...], dim=[2,3,4]).reshape(bs, config.model.num_prims, 4) 136 | # invert normalization 137 | recon_srt_param[..., 0:1] = recon_srt_param[..., 0:1] / 10. + 0.05 138 | recon_feat_param = recon[:, 4:, ...] 139 | else: 140 | recon_srt_param = batch['input_param'][:, :, :4].reshape(bs, config.model.num_prims, 4) 141 | recon_feat_param = recon 142 | # invert normalization 143 | recon_feat_param[:, 0:1, ...] /= 5. 144 | recon_feat_param[:, 1:, ...] = (recon_feat_param[:, 1:, ...] + 1) / 2. 145 | recon_feat_param = recon_feat_param.reshape(bs, config.model.num_prims, -1) 146 | recon_param = torch.concat([recon_srt_param, recon_feat_param], dim=-1) 147 | visualize_primvolume("{}/{:06d}_recon.jpg".format(config.output_dir, iteration), batch, recon_param, rm, device) 148 | visualize_primvolume("{}/{:06d}_gt.jpg".format(config.output_dir, iteration), batch, batch['input_param'], rm, device) 149 | logger.info(f"saving checkpoint after {iteration} steps") 150 | params = { 151 | "model_state_dict": model.module.state_dict(), 152 | "epoch": epoch, 153 | "iteration": iteration, 154 | "optimizer": optimizer.state_dict(), 155 | } 156 | torch.save(params, f"{config.output_dir}/checkpoints/latest.pt") 157 | dist.barrier() 158 | 159 | iteration += 1 160 | 161 | if __name__ == "__main__": 162 | torch.backends.cudnn.benchmark = True 163 | torch.backends.cuda.matmul.allow_tf32 = True 164 | torch.backends.cudnn.allow_tf32 = True 165 | # set config 166 | config = OmegaConf.load(str(sys.argv[1])) 167 | config_cli = OmegaConf.from_cli(args_list=sys.argv[2:]) 168 | if config_cli: 169 | logger.info("overriding with following values from args:") 170 | logger.info(OmegaConf.to_yaml(config_cli)) 171 | config = OmegaConf.merge(config, config_cli) 172 | 173 | main(config) 174 | -------------------------------------------------------------------------------- /utils/op.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | from .typing import * 4 | 5 | # torch / numpy math utils 6 | def dot(x: Union[Tensor, ndarray], y: Union[Tensor, ndarray]) -> Union[Tensor, ndarray]: 7 | """dot product (along the last dim). 8 | 9 | Args: 10 | x (Union[Tensor, ndarray]): x, [..., C] 11 | y (Union[Tensor, ndarray]): y, [..., C] 12 | 13 | Returns: 14 | Union[Tensor, ndarray]: x dot y, [..., 1] 15 | """ 16 | if isinstance(x, np.ndarray): 17 | return np.sum(x * y, -1, keepdims=True) 18 | else: 19 | return torch.sum(x * y, -1, keepdim=True) 20 | 21 | def length(x: Union[Tensor, ndarray], eps=1e-20) -> Union[Tensor, ndarray]: 22 | """length of an array (along the last dim). 23 | 24 | Args: 25 | x (Union[Tensor, ndarray]): x, [..., C] 26 | eps (float, optional): eps. Defaults to 1e-20. 
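Clamping by eps keeps the result strictly positive even for all-zero
inputs, which is what lets safe_normalize below divide by it without
producing NaNs.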
27 | 28 | Returns: 29 | Union[Tensor, ndarray]: length, [..., 1] 30 | """ 31 | if isinstance(x, np.ndarray): 32 | return np.sqrt(np.maximum(np.sum(x * x, axis=-1, keepdims=True), eps)) 33 | else: 34 | return torch.sqrt(torch.clamp(dot(x, x), min=eps)) 35 | 36 | def safe_normalize(x: Union[Tensor, ndarray], eps=1e-20) -> Union[Tensor, ndarray]: 37 | """normalize an array (along the last dim). 38 | 39 | Args: 40 | x (Union[Tensor, ndarray]): x, [..., C] 41 | eps (float, optional): eps. Defaults to 1e-20. 42 | 43 | Returns: 44 | Union[Tensor, ndarray]: normalized x, [..., C] 45 | """ 46 | 47 | return x / length(x, eps) -------------------------------------------------------------------------------- /utils/typing.py: -------------------------------------------------------------------------------- 1 | # ref: https://mypy.readthedocs.io/en/stable/cheat_sheet_py3.html 2 | from typing import Sequence, List, Tuple, Dict, Any, Optional, Union, Literal, Callable 3 | from torch import Tensor 4 | from numpy import ndarray --------------------------------------------------------------------------------
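A minimal usage sketch for the backend-agnostic helpers in utils/op.py above (assuming the repository root is on PYTHONPATH; the random inputs below are purely illustrative): dot, length and safe_normalize accept either NumPy arrays or Torch tensors and return the same type, so downstream geometry code does not need to branch on the array backend.

import numpy as np
import torch
from utils.op import dot, length, safe_normalize

v_np = np.random.randn(4, 3)          # NumPy input  -> NumPy output
v_th = torch.randn(4, 3)              # Torch input  -> Torch output

print(length(v_np).shape)             # (4, 1)
print(length(v_th).shape)             # torch.Size([4, 1])

d = safe_normalize(v_th)              # unit-length rows, shape (4, 3)
print(dot(d, d).squeeze(-1))          # approximately 1.0 for every row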