├── .gitignore
├── README.md
├── configs
│   ├── callbacks
│   │   └── default.yaml
│   ├── category
│   │   ├── airplane.yaml
│   │   ├── chair.yaml
│   │   └── table.yaml
│   ├── debug
│   │   └── default.yaml
│   ├── hydra
│   │   └── default.yaml
│   ├── logger
│   │   ├── multi_loggers.yaml
│   │   └── tensorboard.yaml
│   ├── model
│   │   ├── default.yaml
│   │   ├── lang_default.yaml
│   │   ├── lang_phase1.yaml
│   │   ├── lang_phase2.yaml
│   │   ├── phase1.yaml
│   │   └── phase2.yaml
│   ├── paths
│   │   └── default.yaml
│   ├── train.yaml
│   └── trainer
│       └── default.yaml
├── docs
│   └── images
│       └── salad_teaser.png
├── environment.yml
├── notebooks
│   ├── part_completion_demo.ipynb
│   ├── text_generation_demo.ipynb
│   ├── uncleaned_demo
│   │   ├── part_mixing.py
│   │   ├── pvd
│   │   │   ├── __init__.py
│   │   │   ├── configs
│   │   │   │   ├── callbacks
│   │   │   │   │   ├── default.yaml
│   │   │   │   │   └── fast.yaml
│   │   │   │   ├── debug
│   │   │   │   │   ├── default.yaml
│   │   │   │   │   ├── fast.yaml
│   │   │   │   │   └── super_fast.yaml
│   │   │   │   ├── experiment
│   │   │   │   │   ├── airplane_phase1.yaml
│   │   │   │   │   ├── airplane_phase1_sym.yaml
│   │   │   │   │   ├── airplane_phase2.yaml
│   │   │   │   │   ├── airplane_single_phase.yaml
│   │   │   │   │   ├── airplane_za.yaml
│   │   │   │   │   ├── free_guidance.yaml
│   │   │   │   │   ├── lang_phase1.yaml
│   │   │   │   │   ├── lang_phase1_zc.yaml
│   │   │   │   │   ├── lang_phase2.yaml
│   │   │   │   │   ├── lang_single_phase.yaml
│   │   │   │   │   ├── phase1-airplane.yaml
│   │   │   │   │   ├── phase1-pos_enc-airplane.yaml
│   │   │   │   │   ├── phase1-pos_enc.yaml
│   │   │   │   │   ├── phase1-scaled-vectors-sym-airplane.yaml
│   │   │   │   │   ├── phase1-scaled-vectors-sym.yaml
│   │   │   │   │   ├── phase1-scaled-vectors.yaml
│   │   │   │   │   ├── phase1-sym-airplane.yaml
│   │   │   │   │   ├── phase1-sym.yaml
│   │   │   │   │   ├── phase1-table.yaml
│   │   │   │   │   ├── phase1.yaml
│   │   │   │   │   ├── phase1_sym.yaml
│   │   │   │   │   ├── phase2-airplane.yaml
│   │   │   │   │   ├── phase2-scaled-vectors.yaml
│   │   │   │   │   ├── phase2-table.yaml
│   │   │   │   │   ├── phase2.yaml
│   │   │   │   │   ├── reverse_phase.yaml
│   │   │   │   │   ├── simclr-chair.yaml
│   │   │   │   │   ├── single_phase-airplane.yaml
│   │   │   │   │   ├── single_phase-pos_enc-airplane.yaml
│   │   │   │   │   ├── single_phase-pos_enc.yaml
│   │   │   │   │   ├── single_phase-table.yaml
│   │   │   │   │   ├── single_phase.yaml
│   │   │   │   │   ├── single_phase_zb-airplane.yaml
│   │   │   │   │   ├── single_phase_zb.yaml
│   │   │   │   │   ├── spaghetti_condition_self_attention.yaml
│   │   │   │   │   ├── za-airplane.yaml
│   │   │   │   │   └── za.yaml
│   │   │   │   ├── hydra
│   │   │   │   │   └── default.yaml
│   │   │   │   ├── logger
│   │   │   │   │   ├── multi_loggers.yaml
│   │   │   │   │   └── tensorboard.yaml
│   │   │   │   ├── model
│   │   │   │   │   ├── ldm_big_phase1.yaml
│   │   │   │   │   ├── ldm_big_phase1_language.yaml
│   │   │   │   │   ├── ldm_big_phase1_pos_enc.yaml
│   │   │   │   │   ├── ldm_big_phase2.yaml
│   │   │   │   │   ├── ldm_big_phase2_language.yaml
│   │   │   │   │   ├── ldm_big_single_phase-pos_enc.yaml
│   │   │   │   │   ├── ldm_big_single_phase.yaml
│   │   │   │   │   ├── ldm_big_single_phase_language.yaml
│   │   │   │   │   ├── ldm_big_single_phase_zb.yaml
│   │   │   │   │   ├── ldm_big_za.yaml
│   │   │   │   │   ├── ldm_default.yaml
│   │   │   │   │   ├── ldm_lang_default.yaml
│   │   │   │   │   ├── simclr.yaml
│   │   │   │   │   └── vae.yaml
│   │   │   │   ├── paths
│   │   │   │   │   └── default.yaml
│   │   │   │   ├── simclr.yaml
│   │   │   │   ├── train.yaml
│   │   │   │   └── trainer
│   │   │   │       └── default.yaml
│   │   │   ├── dataset.py
│   │   │   ├── diffusion
│   │   │   │   ├── __init__.py
│   │   │   │   ├── baselines
│   │   │   │   │   └── single_phase.py
│   │   │   │   ├── common.py
│   │   │   │   ├── language_cond.py
│   │   │   │   ├── ldm.py
│   │   │   │   ├── network.py
│   │   │   │   ├── phase1.py
│   │   │   │   ├── phase1_sym.py
│   │   │   │   ├── reverse_model.py
│   │   │   │   └── single_phase.py
│   │   │   ├── embedding
│   │   │   │   └── model.py
│   │   │   ├── misc
│   │   │   │   └── extract_gmms.py
│   │   │   ├── modules
│   │   │   │   ├── __init__.py
│   │   │   │   ├── pointnet2_ops_lib
│   │   │   │   │   ├── MANIFEST.in
│   │   │   │   │   ├── pointnet2_ops
│   │   │   │   │   │   ├── __init__.py
│   │   │   │   │   │   ├── _ext-src
│   │   │   │   │   │   │   ├── include
│   │   │   │   │   │   │   │   ├── ball_query.h
│   │   │   │   │   │   │   │   ├── cuda_utils.h
│   │   │   │   │   │   │   │   ├── group_points.h
│   │   │   │   │   │   │   │   ├── interpolate.h
│   │   │   │   │   │   │   │   ├── sampling.h
│   │   │   │   │   │   │   │   └── utils.h
│   │   │   │   │   │   │   └── src
│   │   │   │   │   │   │       ├── ball_query.cpp
│   │   │   │   │   │   │       ├── ball_query_gpu.cu
│   │   │   │   │   │   │       ├── bindings.cpp
│   │   │   │   │   │   │       ├── group_points.cpp
│   │   │   │   │   │   │       ├── group_points_gpu.cu
│   │   │   │   │   │   │       ├── interpolate.cpp
│   │   │   │   │   │   │       ├── interpolate_gpu.cu
│   │   │   │   │   │   │       ├── sampling.cpp
│   │   │   │   │   │   │       └── sampling_gpu.cu
│   │   │   │   │   │   ├── _version.py
│   │   │   │   │   │   ├── pointnet2_modules.py
│   │   │   │   │   │   └── pointnet2_utils.py
│   │   │   │   │   └── setup.py
│   │   │   │   ├── simple_module.py
│   │   │   │   └── transformer.py
│   │   │   ├── spaghetti_github
│   │   │   │   ├── __init__.py
│   │   │   │   ├── constants.py
│   │   │   │   ├── custom_types.py
│   │   │   │   ├── models
│   │   │   │   │   ├── __init__.py
│   │   │   │   │   ├── gm_utils.py
│   │   │   │   │   ├── mlp_models.py
│   │   │   │   │   ├── models_utils.py
│   │   │   │   │   ├── occ_gmm.py
│   │   │   │   │   ├── occ_loss.py
│   │   │   │   │   └── transformer.py
│   │   │   │   ├── options.py
│   │   │   │   ├── ui
│   │   │   │   │   ├── __init__.py
│   │   │   │   │   ├── gaussian_status.py
│   │   │   │   │   ├── inference_processing.py
│   │   │   │   │   ├── interactors.py
│   │   │   │   │   ├── mock_keyboard.py
│   │   │   │   │   ├── occ_inference.py
│   │   │   │   │   ├── ui_controllers.py
│   │   │   │   │   ├── ui_main.py
│   │   │   │   │   └── ui_utils.py
│   │   │   │   └── utils
│   │   │   │       ├── files_utils.py
│   │   │   │       ├── mcubes_meshing.py
│   │   │   │       ├── mesh_utils.py
│   │   │   │       ├── myparse.py
│   │   │   │       ├── rotation_utils.py
│   │   │   │       └── train_utils.py
│   │   │   ├── train_clr.py
│   │   │   ├── train_ldm.py
│   │   │   ├── train_vae.py
│   │   │   ├── utils
│   │   │   │   ├── __init__.py
│   │   │   │   ├── contrastive_utils.py
│   │   │   │   ├── ldm_utils.py
│   │   │   │   ├── sde_utils.py
│   │   │   │   ├── spaghetti_utils.py
│   │   │   │   ├── spaghetti_v2_utils.py
│   │   │   │   ├── train_utils.py
│   │   │   │   └── transform.py
│   │   │   └── vae
│   │   │       ├── __init__.py
│   │   │       ├── ae.py
│   │   │       ├── pointnet2_utils.py
│   │   │       └── vae.py
│   │   ├── sdedit.py
│   │   ├── shape_editing.py
│   │   └── utils
│   │       ├── __init__.py
│   │       ├── common.py
│   │       ├── gmm_edit.py
│   │       ├── io_utils.py
│   │       ├── mixing.py
│   │       └── vis.py
│   └── unconditional_generation_demo.ipynb
├── salad
│   ├── data
│   │   └── dataset.py
│   ├── model_components
│   │   ├── lstm.py
│   │   ├── network.py
│   │   ├── simple_module.py
│   │   ├── transformer.py
│   │   └── variance_schedule.py
│   ├── models
│   │   ├── __init__.py
│   │   ├── base_model.py
│   │   ├── language_phase1.py
│   │   ├── language_phase2.py
│   │   ├── phase1.py
│   │   └── phase2.py
│   ├── spaghetti
│   │   ├── .gitignore
│   │   ├── LICENSE
│   │   ├── README.md
│   │   ├── __init__.py
│   │   ├── assets
│   │   │   ├── checkpoints
│   │   │   │   ├── .gitkeep
│   │   │   │   └── README.md
│   │   │   ├── mesh
│   │   │   │   └── example.obj
│   │   │   ├── readme_resources
│   │   │   │   └── teaser-01.png
│   │   │   └── ui_resources
│   │   │       ├── icons-03.png
│   │   │       ├── icons-04.png
│   │   │       ├── icons-05.png
│   │   │       ├── icons-06.png
│   │   │       ├── icons-10.png
│   │   │       ├── icons-11.png
│   │   │       ├── icons-12.png
│   │   │       ├── icons-13.png
│   │   │       └── icons-14.png
│   │   ├── constants.py
│   │   ├── custom_types.py
│   │   ├── models
│   │   │   ├── __init__.py
│   │   │   ├── gm_utils.py
│   │   │   ├── mlp_models.py
│   │   │   ├── models_utils.py
│   │   │   ├── occ_gmm.py
│   │   │   ├── occ_loss.py
│   │   │   └── transformer.py
│   │   ├── options.py
│   │   ├── ui
│   │   │   ├── __init__.py
│   │   │   ├── gaussian_status.py
│   │   │   ├── inference_processing.py
│   │   │   ├── interactors.py
│   │   │   ├── mock_keyboard.py
│   │   │   ├── occ_inference.py
│   │   │   ├── ui_controllers.py
│   │   │   ├── ui_main.py
│   │   │   └── ui_utils.py
│   │   └── utils
│   │       ├── __init__.py
│   │       ├── files_utils.py
│   │       ├── mcubes_meshing.py
│   │       ├── mesh_utils.py
│   │       ├── myparse.py
│   │       ├── rotation_utils.py
│   │       └── train_utils.py
│   └── utils
│       ├── __init__.py
│       ├── fresnelvis.py
│       ├── imageutil.py
│       ├── logutil.py
│       ├── meshutil.py
│       ├── nputil.py
│       ├── paths.py
│       ├── shapeglot_util.py
│       ├── spaghetti_util.py
│       ├── sysutil.py
│       ├── thutil.py
│       ├── train_util.py
│       └── visutil.py
├── setup.py
└── train.py
/.gitignore:
--------------------------------------------------------------------------------
salad.egg-info/
checkpoints/
results/
data/
**/__pycache__/
.DS_Store
**/*.pyc
**/.ipynb_checkpoints
!salad/data/
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# 🥗 SALAD: Part-Level Latent Diffusion for 3D Shape Generation and Manipulation, ICCV 2023

![teaser](./docs/images/salad_teaser.png)

[**arXiv**](https://arxiv.org/abs/2303.12236) | [**Project Page**](https://salad3d.github.io/)

[Juil Koo\*](https://63days.github.io/), [Seungwoo Yoo\*](https://dvelopery0115.github.io/), [Minh Hieu Nguyen\*](https://min-hieu.github.io/), [Minhyuk Sung](https://mhsung.github.io/)
\* denotes equal contribution.

### 🎉 This paper was accepted to ICCV 2023!

# Introduction
This repository contains the official implementation of 🥗 **SALAD: Part-Level Latent Diffusion for 3D Shape Generation and Manipulation**.
**SALAD** is a 3D shape diffusion model that generates high-quality shapes and shows zero-shot capability in various shape manipulation applications. More results can be viewed with a 3D viewer on the [project webpage](https://salad3d.github.io).

[//]: # (### Abstract)
> We present a cascaded diffusion model based on a part-level implicit 3D representation. Our model achieves state-of-the-art generation quality and also enables part-level shape editing and manipulation without any additional training in a conditional setup. Diffusion models have demonstrated impressive capabilities in data generation as well as zero-shot completion and editing via a guided reverse process. Recent research on 3D diffusion models has focused on improving their generation capabilities with various data representations, while the absence of structural information has limited their capability in completion and editing tasks. We thus propose our novel diffusion model using a part-level implicit representation. To effectively learn diffusion with high-dimensional embedding vectors of parts, we propose a cascaded framework, learning diffusion first on a low-dimensional subspace encoding extrinsic parameters of parts and then on the other high-dimensional subspace encoding intrinsic attributes. In the experiments, we demonstrate that our method outperforms previous ones in both generation and part-level completion and manipulation tasks.

# UPDATES

- [x] Training code of SALAD with the chair class. (May 8th, 2023)
- [x] Training code of *text-conditioned* SALAD. (May 8th, 2023)
- [x] Demo of *text-guided* shape generation. (May 8th, 2023)
- [x] Generation code for more classes. (Sep. 17th, 2023)
- [x] Demo of part completion.
- [ ] Demo of part mixing and refinement.
- [ ] Demo of text-guided part completion.


# Get Started

## Installation

We have tested the code with Python 3.9, CUDA 11.3, and PyTorch 1.12.1+cu113.

```
git clone https://github.com/63days/salad/
cd SALAD
conda env create -f environment.yml
conda activate salad
pip install -e .
```

## Data and Model Checkpoints
We provide ready-to-use data and pre-trained SALAD and SPAGHETTI checkpoints [here](http://onedrive.live.com/?cid=0ce615b143fc4bdc&id=0CE615B143FC4BDC%212507&resid=0CE615B143FC4BDC%212507&ithint=folder&migratedtospo=true&redeem=aHR0cHM6Ly8xZHJ2Lm1zL2YvYy8wY2U2MTViMTQzZmM0YmRjL1F0eExfRU94RmVZZ2dBekxDUUFBQUFBQU42dzI0U1RrV2VpRXd3&v=validatepermission).
Unzip the files and put them under the corresponding directories. The SPAGHETTI checkpoint directories should be under the `SALAD/salad/spaghetti/assets/checkpoints/` directory.


# Training
## Unconditional Generation

You can train models for unconditional generation with airplanes, chairs, or tables.
```
python train.py model={phase1, phase2} category={airplane, chair, table}
```

## Text-Guided Generation

For text-guided generation, train a first-phase model and a second-phase model:
```
python train.py model={lang_phase1, lang_phase2}
```

# Demo
## Unconditional Shape Generation

We provide the demo code for unconditional generation with airplanes, chairs, and tables at `notebooks/unconditional_generation_demo.ipynb`.
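
Under the hood, the demo follows the cascade described in the abstract: phase 1 samples per-part extrinsics, and phase 2 samples per-part intrinsics conditioned on them. Below is a minimal sketch of that loop; the helper names `phase1_model`, `phase2_model`, and `spaghetti` are illustrative placeholders, not the notebook's actual API.

```python
# Illustrative sketch of SALAD's two-phase sampling; see the notebook for the real API.
import torch

phase1_model = phase2_model = spaghetti = ...  # load pre-trained checkpoints here

num_shapes, num_parts = 4, 16  # a shape is represented as a set of parts

# Phase 1: sample per-part extrinsics (16-dim Gaussian parameters per part).
extrinsics = phase1_model.sample(num_shapes)                      # (B, num_parts, 16)

# Phase 2: sample 512-dim per-part intrinsics conditioned on the extrinsics.
intrinsics = phase2_model.sample(num_shapes, context=extrinsics)  # (B, num_parts, 512)

# Decode both subspaces into meshes with the pre-trained SPAGHETTI decoder.
meshes = spaghetti.decode(extrinsics, intrinsics)
```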

## Text-Guided Shape Generation

We also provide a text-guided shape generation demo at `notebooks/text_generation_demo.ipynb`.


# Citation
If you find our work useful, please consider citing:

```bibtex
@article{koo2023salad,
  title={{SALAD}: Part-Level Latent Diffusion for 3D Shape Generation and Manipulation},
  author={Koo, Juil and Yoo, Seungwoo and Nguyen, Minh Hieu and Sung, Minhyuk},
  year={2023},
  journal={arXiv preprint arXiv:2303.12236},
}
```

# Reference
We borrow the [SPAGHETTI](https://amirhertz.github.io/spaghetti/) code from the [official implementation](https://github.com/amirhertz/spaghetti).
--------------------------------------------------------------------------------
/configs/callbacks/default.yaml:
--------------------------------------------------------------------------------
model_summary:
  _target_: pytorch_lightning.callbacks.RichModelSummary
  max_depth: 1

rich_progress_bar:
  _target_: pytorch_lightning.callbacks.RichProgressBar

model_checkpoint:
  _target_: pytorch_lightning.callbacks.ModelCheckpoint
  monitor: "train/loss"
  mode: "min"
  save_top_k: 1
  save_last: true
  verbose: true
  dirpath: "checkpoints/"
  filename: "epoch={epoch:02d}-train_loss={train/loss:.4f}"
  #every_n_epochs: 200
  auto_insert_metric_name: false

#lr_monitor:
#  _target_: pytorch_lightning.callbacks.LearningRateMonitor
#  logging_interval: "step"
--------------------------------------------------------------------------------
/configs/category/airplane.yaml:
--------------------------------------------------------------------------------
latent_path: spaghetti_airplane_latents.hdf5 # file path relative to the salad.utils.paths.DATA_DIR directory.
spaghetti_tag: "airplanes"
--------------------------------------------------------------------------------
/configs/category/chair.yaml:
--------------------------------------------------------------------------------
latent_path: spaghetti_chair_latents.hdf5 # file path relative to the salad.utils.paths.DATA_DIR directory.
spaghetti_tag: "chairs_large"
--------------------------------------------------------------------------------
/configs/category/table.yaml:
--------------------------------------------------------------------------------
latent_path: spaghetti_table_latents.hdf5 # file path relative to the salad.utils.paths.DATA_DIR directory.
spaghetti_tag: "tables"
--------------------------------------------------------------------------------
/configs/debug/default.yaml:
--------------------------------------------------------------------------------
# @package _global_
#
callbacks: null
logger: null

hydra:
  job_logging:
    root:
      level: DEBUG

trainer:
  max_epochs: 1
--------------------------------------------------------------------------------
/configs/hydra/default.yaml:
--------------------------------------------------------------------------------
run:
  dir: ${paths.save_dir}

sweep:
  dir: ${paths.save_dir}
  subdir: ${hydra.job.num}
--------------------------------------------------------------------------------
/configs/logger/multi_loggers.yaml:
--------------------------------------------------------------------------------
wandb:
  _target_: pytorch_lightning.loggers.wandb.WandbLogger
  project: "spaghetti-ldm"
  name: ${name}
  save_dir: "."
  entity: "geometry"

tensorboard:
  _target_: pytorch_lightning.loggers.tensorboard.TensorBoardLogger
  save_dir: "tensorboard/"
  name: null
  version: ${name}
  log_graph: false
  default_hp_metric: true
  prefix: ""
--------------------------------------------------------------------------------
/configs/logger/tensorboard.yaml:
--------------------------------------------------------------------------------
tensorboard:
  _target_: pytorch_lightning.loggers.tensorboard.TensorBoardLogger
  save_dir: "tensorboard/"
  name: null
  #version: ${name}
  log_graph: true
  default_hp_metric: false
  prefix: ""
--------------------------------------------------------------------------------
/configs/model/default.yaml:
--------------------------------------------------------------------------------
network: null

variance_schedule:
  _target_: salad.model_components.variance_schedule.VarianceSchedule
  num_steps: &time_steps 1000
  beta_1: 1e-4
  beta_T: 0.05
  mode: linear

# optimizer
lr: 1e-4
batch_size: ${batch_size}

# dataset
dataset_kwargs:
  data_path: ${category.latent_path}
  repeat: 3
  data_keys: ["s_j_affine"]

# model
num_timesteps: *time_steps
faster: true
validation_step: 10
no_run_validation: false
spaghetti_tag: ${category.spaghetti_tag}
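
For reference, the `variance_schedule` block above parameterizes a standard DDPM-style linear schedule. A minimal sketch of what those hyperparameters mean (illustrative only; the repository's actual class lives in `salad/model_components/variance_schedule.py`):

```python
# Linear variance schedule matching the config above:
# num_steps=1000, beta_1=1e-4, beta_T=0.05, mode=linear.
import torch

num_steps, beta_1, beta_T = 1000, 1e-4, 0.05
betas = torch.linspace(beta_1, beta_T, num_steps)  # beta_t for t = 1..T
alphas = 1.0 - betas
alpha_bars = torch.cumprod(alphas, dim=0)          # cumulative product over steps

# Forward diffusion of clean latents x0 at step t:
# x_t = sqrt(alpha_bar_t) * x0 + sqrt(1 - alpha_bar_t) * eps,  eps ~ N(0, I)
```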
--------------------------------------------------------------------------------
/configs/model/lang_default.yaml:
--------------------------------------------------------------------------------
defaults:
  - default.yaml

# dataset
dataset_kwargs:
  data_path: ${category.latent_path}
  repeat: 1
  only_easy_context: false
  global_normalization: ${model.global_normalization}

# model
global_normalization: false
text_encoder_freeze: false
use_lstm: true
classifier_free_guidance: true
conditioning_dropout_prob: 0.2
--------------------------------------------------------------------------------
/configs/model/lang_phase1.yaml:
--------------------------------------------------------------------------------
defaults:
  - lang_default.yaml

_target_: salad.models.language_phase1.LangPhase1Model
network:
  _target_: salad.model_components.network.CondDiffNetwork
  input_dim: 16
  residual: true
  context_dim: 768
  context_embedding_dim: 1024
  embedding_dim: 512
  encoder_use_time: false
  encoder_type: pointwise
  decoder_type: transformer_encoder
  enc_num_layers: 2
  dec_num_layers: 6
  use_timestep_embedder: true
  timestep_embedder_dim: 128

global_normalization: partial

dataset_kwargs:
  data_keys: ["g_js_affine"]
--------------------------------------------------------------------------------
/configs/model/lang_phase2.yaml:
--------------------------------------------------------------------------------
defaults:
  - lang_default.yaml

_target_: salad.models.language_phase2.LangPhase2Model
network:
  _target_: salad.model_components.network.CondDiffNetwork
  input_dim: 512
  residual: true
  context_dim: 784 # concatenation of the 768-dim language feature and the 16-dim Gaussian.
  context_embedding_dim: 1024
  embedding_dim: 512
  encoder_use_time: false
  encoder_type: transformer
  decoder_type: transformer_encoder
  enc_num_layers: 6
  dec_num_layers: 6
  use_timestep_embedder: true
  timestep_embedder_dim: 128

dataset_kwargs:
  data_keys: ["s_j_affine", "g_js_affine"]
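
The `context_dim: 784` above is the 768-dim text feature concatenated with the 16-dim per-part Gaussian (768 + 16 = 784). A small illustrative sketch of that assembly (tensor names are hypothetical):

```python
import torch

B, num_parts = 8, 16
lang_feat = torch.randn(B, 768)            # pooled text-encoder feature per shape
gaussians = torch.randn(B, num_parts, 16)  # phase-1 extrinsics per part

# Broadcast the shape-level text feature to every part, then concatenate.
context = torch.cat(
    [lang_feat[:, None, :].expand(-1, num_parts, -1), gaussians], dim=-1
)
assert context.shape == (B, num_parts, 784)  # 768 + 16 = context_dim
```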
--------------------------------------------------------------------------------
/configs/model/phase1.yaml:
--------------------------------------------------------------------------------
defaults:
  - default.yaml

_target_: salad.models.phase1.Phase1Model

network:
  _target_: salad.model_components.network.UnCondDiffNetwork
  input_dim: 16
  embedding_dim: 512
  num_heads: 4
  use_timestep_embedder: true
  timestep_embedder_dim: 128
  enc_num_layers: 6
  residual: true
  encoder_type: transformer
  attn_dropout: 0.0

global_normalization: partial # normalize pi and eigenvalues.

dataset_kwargs:
  data_keys: ["g_js_affine"]
  global_normalization: ${model.global_normalization}
--------------------------------------------------------------------------------
/configs/model/phase2.yaml:
--------------------------------------------------------------------------------
defaults:
  - default.yaml

_target_: salad.models.phase2.Phase2Model
network:
  _target_: salad.model_components.network.CondDiffNetwork
  input_dim: 512
  residual: true
  context_dim: 16 # Gaussian condition dim.
  context_embedding_dim: 512
  embedding_dim: 512
  encoder_use_time: false
  encoder_type: transformer
  decoder_type: transformer_encoder # we don't use cross-attention.
  enc_num_layers: 6
  dec_num_layers: 6
  use_timestep_embedder: true
  timestep_embedder_dim: 128

dataset_kwargs:
  data_keys: ["s_j_affine", "g_js_affine"]
--------------------------------------------------------------------------------
/configs/paths/default.yaml:
--------------------------------------------------------------------------------
work_dir: ${hydra:runtime.cwd}
save_dir: ${hydra:runtime.cwd}/results/${name}/${now:%m%d_%H%M%S}
--------------------------------------------------------------------------------
/configs/train.yaml:
--------------------------------------------------------------------------------
# @package _global_

defaults:
  - _self_
  - category: chair.yaml
  - model: phase1.yaml
  - callbacks: default.yaml
  - logger: tensorboard.yaml # multi_loggers.yaml
  - trainer: default.yaml
  - paths: default.yaml
  - hydra: default.yaml
  - debug: null

name: default
seed: 63

gpus: [0]
epochs: 5000
batch_size: 32

resume_from_checkpoint: null
--------------------------------------------------------------------------------
/configs/trainer/default.yaml:
--------------------------------------------------------------------------------
_target_: pytorch_lightning.Trainer
accelerator: gpu
devices: ${gpus}
max_epochs: ${epochs}
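
The `_target_` entries above are materialized by Hydra at runtime. A minimal sketch of the kind of entry point that composes these config groups (illustrative; the repository's actual `train.py` may differ in detail):

```python
# Sketch of a Hydra entry point over configs/train.yaml; not the repo's train.py.
import hydra
from omegaconf import DictConfig


@hydra.main(config_path="configs", config_name="train")
def main(cfg: DictConfig) -> None:
    # Recursively builds the object named by each _target_ key.
    model = hydra.utils.instantiate(cfg.model)      # e.g., salad.models.phase1.Phase1Model
    trainer = hydra.utils.instantiate(cfg.trainer)  # pytorch_lightning.Trainer
    trainer.fit(model)


if __name__ == "__main__":
    # Launched as, e.g.: python train.py model=phase2 category=airplane batch_size=16
    main()
```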
Path("/home/juil/projects/3D_CRISPR/pointnet-vae-diffusion/pvd/results/phase2/augment_final_0214/0214_202607/eval-02-23-143029/meshes") 17 | our_airplane_path = Path("/home/juil/projects/3D_CRISPR/pointnet-vae-diffusion/pvd/results/phase2-airplane/augment_final_0214/0214_202550/eval-02-27-013138/meshes") 18 | -------------------------------------------------------------------------------- /notebooks/uncleaned_demo/pvd/configs/callbacks/default.yaml: -------------------------------------------------------------------------------- 1 | model_summary: 2 | _target_: pytorch_lightning.callbacks.RichModelSummary 3 | max_depth: 1 4 | 5 | rich_progress_bar: 6 | _target_: pytorch_lightning.callbacks.RichProgressBar 7 | 8 | model_checkpoint: 9 | _target_: pytorch_lightning.callbacks.ModelCheckpoint 10 | monitor: "train/loss" 11 | mode: "min" 12 | save_top_k: -1 13 | save_last: true 14 | verbose: true 15 | dirpath: "checkpoints/" 16 | filename: "epoch={epoch:02d}-val_loss={val/loss:.4f}" 17 | every_n_epochs: 200 18 | auto_insert_metric_name: false 19 | 20 | lr_monitor: 21 | _target_: pytorch_lightning.callbacks.LearningRateMonitor 22 | logging_interval: "step" 23 | 24 | -------------------------------------------------------------------------------- /notebooks/uncleaned_demo/pvd/configs/callbacks/fast.yaml: -------------------------------------------------------------------------------- 1 | model_summary: 2 | _target_: pytorch_lightning.callbacks.RichModelSummary 3 | max_depth: 2 4 | 5 | rich_progress_bar: 6 | _target_: pytorch_lightning.callbacks.RichProgressBar 7 | 8 | model_checkpoint: 9 | _target_: pytorch_lightning.callbacks.ModelCheckpoint 10 | monitor: "train/loss" 11 | mode: "min" 12 | save_top_k: -1 13 | save_last: true 14 | verbose: true 15 | dirpath: "checkpoints/" 16 | filename: "epoch={epoch:02d}-val_loss={val/loss:.4f}" 17 | every_n_epochs: 200 18 | auto_insert_metric_name: false 19 | -------------------------------------------------------------------------------- /notebooks/uncleaned_demo/pvd/configs/debug/default.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | # 3 | callbacks: null 4 | logger: null 5 | 6 | hydra: 7 | job_logging: 8 | root: 9 | level: DEBUG 10 | 11 | trainer: 12 | max_epochs: 1 13 | -------------------------------------------------------------------------------- /notebooks/uncleaned_demo/pvd/configs/debug/fast.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | # 3 | model: 4 | no_run_validation: true 5 | -------------------------------------------------------------------------------- /notebooks/uncleaned_demo/pvd/configs/debug/super_fast.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | # 3 | defaults: 4 | - override /logger: tensorboard.yaml 5 | - override /callbacks: fast.yaml 6 | logger: null 7 | model: 8 | no_run_validation: true 9 | -------------------------------------------------------------------------------- /notebooks/uncleaned_demo/pvd/configs/experiment/airplane_phase1.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | defaults: 3 | - override /model: ldm_big_phase1.yaml 4 | - override /paths: default.yaml 5 | 6 | paths: 7 | #### 8 | # save_dir: ${paths.work_dir}/results/phase1/${name}/${now:%m%d_%H%M%S} 9 | save_dir: ${paths.work_dir}/results/airplane-phase1/${name}/${now:%m%d_%H%M%S} 10 | #### 11 | 12 
--------------------------------------------------------------------------------
/notebooks/uncleaned_demo/pvd/configs/experiment/airplane_phase1_sym.yaml:
--------------------------------------------------------------------------------
# @package _global_
defaults:
  - override /model: ldm_big_phase1.yaml
  - override /paths: default.yaml


#### < Seungwoo >
# Change log directory
paths:
  # save_dir: ${paths.work_dir}/results/phase1/${name}/${now:%m%d_%H%M%S}
  # save_dir: ${paths.work_dir}/results/phase1_sym/${name}/${now:%m%d_%H%M%S}
  save_dir: ${paths.work_dir}/results/airplane-phase1_sym/${name}/${now:%m%d_%H%M%S}
####


#### < Seungwoo >
# Override config in 'data/spaghetti_author' to load the proper data
data:
  latent_path: "/home/dreamy1534/Projects/iccv2023-spaghetti/part_ldm/pvd/data/airplane_1775_half.hdf5"
  spaghetti_tag: "airplanes"
####


#### < Seungwoo >
# Override configs in 'ldm_big_phase1' to use the symmetric model.
model:
  _target_: pvd.diffusion.phase1_sym.GaussianSALDM
  eigen_value_clipping: true
####

logger:
  wandb:
    # < Seungwoo >
    # Modified project name to prevent overwrite.
    # project: "spaghetti-gaus"
    project: "sw_spaghetti-gaus-sym"
--------------------------------------------------------------------------------
/notebooks/uncleaned_demo/pvd/configs/experiment/airplane_phase2.yaml:
--------------------------------------------------------------------------------
# @package _global_
defaults:
  - override /model: ldm_big_phase2.yaml
  - override /paths: default.yaml

paths:
  ####
  # save_dir: ${paths.work_dir}/results/phase2/${name}/${now:%m%d_%H%M%S}
  save_dir: ${paths.work_dir}/results/airplane-phase2/${name}/${now:%m%d_%H%M%S}
  ####

####
data:
  latent_path: "/home/dreamy1534/Projects/iccv2023-spaghetti/part_ldm/pvd/data/airplane_1775.hdf5"
  spaghetti_tag: "airplanes"
####

logger:
  wandb:
    # < Seungwoo >
    # Modified project name to prevent overwrite.
    # project: "spaghetti-phase2"
    # < Seungwoo >
    # Modified project name for airplane data
    # project: "sw_spaghetti-phase2"
    project: "sw_spaghetti-airplane-phase2"
--------------------------------------------------------------------------------
/notebooks/uncleaned_demo/pvd/configs/experiment/airplane_single_phase.yaml:
--------------------------------------------------------------------------------
# @package _global_
defaults:
  - override /model: ldm_big_phase1.yaml
  - override /paths: default.yaml

paths:
  save_dir: ${paths.work_dir}/results/airplane-single_phase/${name}/${now:%m%d_%H%M%S}

model:
  _target_: pvd.diffusion.baselines.single_phase.SinglePhaseSALDM
  network:
    _target_: pvd.diffusion.baselines.single_phase.LatentSelfAttentionNetwork
    input_dim: 528

  #### < Seungwoo >
  # Adjust the number of validation steps
  validation_step: 5
  ####

  dataset_kwargs:
    data_keys: ["concat"]

data:
  ####
  # latent_path: "/home/dreamy1534/Projects/iccv2023-spaghetti/part_ldm/pvd/data/chair_6755_concat.hdf5"
  latent_path: "/home/dreamy1534/Projects/iccv2023-spaghetti/part_ldm/pvd/data/airplane_1775_concat.hdf5"
  spaghetti_tag: "airplanes"
  ####

logger:
  wandb:
    ####
    # project: "spaghetti-single_phase"
    project: "spaghetti-airplane-single_phase"
    ####
--------------------------------------------------------------------------------
/notebooks/uncleaned_demo/pvd/configs/experiment/airplane_za.yaml:
--------------------------------------------------------------------------------
# @package _global_
defaults:
  - override /model: ldm_big_za.yaml
  - override /paths: default.yaml

paths:
  ####
  # save_dir: ${paths.work_dir}/results/za/${name}/${now:%m%d_%H%M%S}
  save_dir: ${paths.work_dir}/results/airplane-za/${name}/${now:%m%d_%H%M%S}
  ####

####
data:
  latent_path: "/home/dreamy1534/Projects/iccv2023-spaghetti/part_ldm/pvd/data/airplane_1775.hdf5"
  spaghetti_tag: "airplanes"
####

logger:
  wandb:
    ####
    # project: "spaghetti-za"
    project: "sw_spaghetti-airplane-za"
    ####
--------------------------------------------------------------------------------
/notebooks/uncleaned_demo/pvd/configs/experiment/free_guidance.yaml:
--------------------------------------------------------------------------------
# @package _global_
defaults:
  - override /model: ldm_big_phase2.yaml
  - override /paths: default.yaml

paths:
  save_dir: ${paths.work_dir}/results/free_guidance/${name}/${now:%m%d_%H%M%S}

model:
  classifier_free_guidance: true
  draw_tsne: false
--------------------------------------------------------------------------------
/notebooks/uncleaned_demo/pvd/configs/experiment/lang_phase1.yaml:
--------------------------------------------------------------------------------
# @package _global_
defaults:
  - override /model: ldm_big_phase1_language.yaml
  - override /paths: default.yaml
  #- override /data: spaghetti_half_chair.yaml
  - override /data: spaghetti_author.yaml
  - override /callbacks: default.yaml

paths:
  save_dir: ${paths.work_dir}/results/phase1-lang/${name}/${now:%m%d_%H%M%S}

logger:
  wandb:
    project: "spaghetti-phase1-lang"


callbacks:
  model_checkpoint:
    every_n_epochs: 5
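
Several configs in this group toggle `classifier_free_guidance` and `conditioning_dropout_prob`. For orientation, here is a minimal sketch of the usual classifier-free guidance recipe those flags refer to; the `network` call signature and the zeroed-context convention are hypothetical, not taken from this codebase.

```python
import torch

def guided_eps(network, x_t, t, context, guidance_scale=2.0):
    # Training drops the context with prob. conditioning_dropout_prob, so the
    # network also learns an unconditional branch (modeled here as a zeroed context).
    eps_uncond = network(x_t, t, context=torch.zeros_like(context))
    eps_cond = network(x_t, t, context=context)
    # Extrapolate from the unconditional toward the conditional prediction.
    return eps_uncond + guidance_scale * (eps_cond - eps_uncond)
```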
--------------------------------------------------------------------------------
/notebooks/uncleaned_demo/pvd/configs/experiment/lang_phase1_zc.yaml:
--------------------------------------------------------------------------------
# @package _global_
defaults:
  - override /model: ldm_big_phase1_language.yaml
  - override /paths: default.yaml
  #- override /data: spaghetti_half_chair.yaml
  - override /data: spaghetti_author.yaml
  - override /callbacks: default.yaml

paths:
  save_dir: ${paths.work_dir}/results/phase1-lang/${name}/${now:%m%d_%H%M%S}

logger:
  wandb:
    project: "spaghetti-phase1-lang"

model:
  dataset_kwargs:
    data_keys: ["Zc"]
  use_zc: true
  network:
    input_dim: 512


callbacks:
  model_checkpoint:
    every_n_epochs: 5
--------------------------------------------------------------------------------
/notebooks/uncleaned_demo/pvd/configs/experiment/lang_phase2.yaml:
--------------------------------------------------------------------------------
# @package _global_
defaults:
  - override /model: ldm_big_phase2_language.yaml
  - override /paths: default.yaml
  - override /data: spaghetti_author.yaml
  - override /callbacks: default.yaml

paths:
  save_dir: ${paths.work_dir}/results/phase2-lang/${name}/${now:%m%d_%H%M%S}

logger:
  wandb:
    project: "spaghetti-phase2-lang"

callbacks:
  model_checkpoint:
    every_n_epochs: 5
--------------------------------------------------------------------------------
/notebooks/uncleaned_demo/pvd/configs/experiment/lang_single_phase.yaml:
--------------------------------------------------------------------------------
# @package _global_
defaults:
  - override /model: ldm_big_single_phase_language.yaml
  - override /paths: default.yaml
  - override /callbacks: default.yaml

paths:
  save_dir: ${paths.work_dir}/results/single_phase-lang/${name}/${now:%m%d_%H%M%S}

logger:
  wandb:
    project: "spaghetti-single_phase-lang"

callbacks:
  model_checkpoint:
    every_n_epochs: 5
--------------------------------------------------------------------------------
/notebooks/uncleaned_demo/pvd/configs/experiment/phase1-airplane.yaml:
--------------------------------------------------------------------------------
# @package _global_
defaults:
  - override /model: ldm_big_phase1.yaml
  - override /paths: default.yaml
  - override /data: spaghetti_author_airplane.yaml

paths:
  save_dir: ${paths.work_dir}/results/phase1-airplane/${name}/${now:%m%d_%H%M%S}

logger:
  wandb:
    project: "spaghetti-gaus-airplane"

model:
  spaghetti_tag: "airplanes"
--------------------------------------------------------------------------------
/notebooks/uncleaned_demo/pvd/configs/experiment/phase1-pos_enc-airplane.yaml:
--------------------------------------------------------------------------------
# @package _global_
defaults:
  - override /model: ldm_big_phase1_pos_enc.yaml
  - override /paths: default.yaml
  - override /data: spaghetti_author_airplane.yaml

paths:
  save_dir: ${paths.work_dir}/results/phase1-pos_enc-airplane/${name}/${now:%m%d_%H%M%S}

logger:
  wandb:
    project: "spaghetti-gaus-pos_enc-airplane"
model:
  spaghetti_tag: "airplanes"
--------------------------------------------------------------------------------
/notebooks/uncleaned_demo/pvd/configs/experiment/phase1-pos_enc.yaml:
--------------------------------------------------------------------------------
# @package _global_
defaults:
  - override /model: ldm_big_phase1_pos_enc.yaml
  - override /paths: default.yaml

paths:
  save_dir: ${paths.work_dir}/results/phase1-pos_enc/${name}/${now:%m%d_%H%M%S}

logger:
  wandb:
    project: "spaghetti-gaus"
--------------------------------------------------------------------------------
/notebooks/uncleaned_demo/pvd/configs/experiment/phase1-scaled-vectors-sym-airplane.yaml:
--------------------------------------------------------------------------------
# @package _global_
defaults:
  - override /model: ldm_big_phase1.yaml
  - override /paths: default.yaml
  - override /data: spaghetti_half_airplane.yaml

paths:
  save_dir: ${paths.work_dir}/results/phase1-sym-scaled-vectors-airplane/${name}/${now:%m%d_%H%M%S}

logger:
  wandb:
    project: "spaghetti-gaus-sym-scaled-vectors-airplane"

model:
  _target_: pvd.diffusion.phase1_sym.GaussianSymSALDM
  use_symmetric: &sym true
  network:
    input_dim: 13
    scale_eigenvectors: &scale_vectors true
  dataset_kwargs:
    scale_eigenvectors: *scale_vectors
  use_scaled_eigenvectors: *scale_vectors
  spaghetti_tag: "airplanes"
--------------------------------------------------------------------------------
/notebooks/uncleaned_demo/pvd/configs/experiment/phase1-scaled-vectors-sym.yaml:
--------------------------------------------------------------------------------
# @package _global_
defaults:
  - override /model: ldm_big_phase1.yaml
  - override /paths: default.yaml
  - override /data: spaghetti_half_chair.yaml

paths:
  save_dir: ${paths.work_dir}/results/phase1-sym-scaled-vectors/${name}/${now:%m%d_%H%M%S}

logger:
  wandb:
    project: "spaghetti-gaus-sym-scaled-vectors"

model:
  _target_: pvd.diffusion.phase1_sym.GaussianSymSALDM
  use_symmetric: &sym true
  network:
    input_dim: 13
    scale_eigenvectors: &scale_vectors true
  dataset_kwargs:
    scale_eigenvectors: *scale_vectors
  use_scaled_eigenvectors: *scale_vectors
--------------------------------------------------------------------------------
/notebooks/uncleaned_demo/pvd/configs/experiment/phase1-scaled-vectors.yaml:
--------------------------------------------------------------------------------
# @package _global_
defaults:
  - override /model: ldm_big_phase1.yaml
  - override /paths: default.yaml

paths:
  save_dir: ${paths.work_dir}/results/phase1-scaled-vectors/${name}/${now:%m%d_%H%M%S}

logger:
  wandb:
    project: "spaghetti-phase1-scaled-vectors"

model:
  #variance_schedule:
  #  num_steps: 2000
  #  beta_1: 1e-6
  #  beta_T: 0.02
  #  mode: linear

  network:
    input_dim: 13
    scale_eigenvectors: &scale_vectors true
  dataset_kwargs:
    scale_eigenvectors: *scale_vectors

  use_scaled_eigenvectors: *scale_vectors
--------------------------------------------------------------------------------
/notebooks/uncleaned_demo/pvd/configs/experiment/phase1-sym-airplane.yaml:
--------------------------------------------------------------------------------
# @package _global_
defaults:
  - override /model: ldm_big_phase1.yaml
  - override /paths: default.yaml
  - override /data: spaghetti_half_airplane.yaml

paths:
  save_dir: ${paths.work_dir}/results/phase1-sym-airplane/${name}/${now:%m%d_%H%M%S}

logger:
  wandb:
    project: "spaghetti-gaus-sym-airplane"

model:
  _target_: pvd.diffusion.phase1_sym.GaussianSymSALDM
  use_symmetric: &sym true
  spaghetti_tag: "airplanes"
--------------------------------------------------------------------------------
/notebooks/uncleaned_demo/pvd/configs/experiment/phase1-sym.yaml:
--------------------------------------------------------------------------------
# @package _global_
defaults:
  - override /model: ldm_big_phase1.yaml
  - override /paths: default.yaml
  - override /data: spaghetti_half_chair.yaml

paths:
  save_dir: ${paths.work_dir}/results/phase1-sym/${name}/${now:%m%d_%H%M%S}

logger:
  wandb:
    project: "spaghetti-gaus-sym"

model:
  _target_: pvd.diffusion.phase1_sym.GaussianSymSALDM
  use_symmetric: &sym true
--------------------------------------------------------------------------------
/notebooks/uncleaned_demo/pvd/configs/experiment/phase1-table.yaml:
--------------------------------------------------------------------------------
# @package _global_
defaults:
  - override /model: ldm_big_phase1.yaml
  - override /paths: default.yaml
  - override /data: spaghetti_table.yaml

paths:
  save_dir: ${paths.work_dir}/results/phase1-table/${name}/${now:%m%d_%H%M%S}

logger:
  wandb:
    project: "spaghetti-gaus-table"

model:
  spaghetti_tag: "tables"
--------------------------------------------------------------------------------
/notebooks/uncleaned_demo/pvd/configs/experiment/phase1.yaml:
--------------------------------------------------------------------------------
# @package _global_
defaults:
  - override /model: ldm_big_phase1.yaml
  - override /paths: default.yaml

paths:
  save_dir: ${paths.work_dir}/results/phase1/${name}/${now:%m%d_%H%M%S}

logger:
  wandb:
    # < Seungwoo >
    # Modified project name to prevent overwrite.
    # project: "spaghetti-gaus"
    project: "sw_spaghetti-gaus"
--------------------------------------------------------------------------------
/notebooks/uncleaned_demo/pvd/configs/experiment/phase1_sym.yaml:
--------------------------------------------------------------------------------
# @package _global_
defaults:
  - override /model: ldm_big_phase1.yaml
  - override /paths: default.yaml

#### < Seungwoo >
paths:
  # save_dir: ${paths.work_dir}/results/phase1/${name}/${now:%m%d_%H%M%S}
  save_dir: ${paths.work_dir}/results/phase1_sym/${name}/${now:%m%d_%H%M%S}
####


#### < Seungwoo >
# Override config in 'data/spaghetti_author' to load the proper data
data:
  latent_path: "/home/dreamy1534/Projects/iccv2023-spaghetti/part_ldm/pvd/data/6755_2023-01-19-15-02-36_half.hdf5"
  spaghetti_tag: "chairs_large"
####


#### < Seungwoo >
# Override configs in 'ldm_big_phase1' to use the symmetric model.
model:
  _target_: pvd.diffusion.phase1_sym.GaussianSALDM
  eigen_value_clipping: true
####

logger:
  wandb:
    # < Seungwoo >
    # Modified project name to prevent overwrite.
    # project: "spaghetti-gaus"
    project: "sw_spaghetti-gaus-sym"
--------------------------------------------------------------------------------
/notebooks/uncleaned_demo/pvd/configs/experiment/phase2-airplane.yaml:
--------------------------------------------------------------------------------
# @package _global_
defaults:
  - override /model: ldm_big_phase2.yaml
  - override /paths: default.yaml
  - override /data: spaghetti_author_airplane.yaml

paths:
  save_dir: ${paths.work_dir}/results/phase2-airplane/${name}/${now:%m%d_%H%M%S}

logger:
  wandb:
    project: "spaghetti-phase2-airplane"

model:
  spaghetti_tag: "airplanes"
--------------------------------------------------------------------------------
/notebooks/uncleaned_demo/pvd/configs/experiment/phase2-scaled-vectors.yaml:
--------------------------------------------------------------------------------
# @package _global_
defaults:
  - override /model: ldm_big_phase2.yaml
  - override /paths: default.yaml

paths:
  save_dir: ${paths.work_dir}/results/phase2-scaled-vectors/${name}/${now:%m%d_%H%M%S}

logger:
  wandb:
    project: "spaghetti-phase2-scaled-vectors"

model:
  condition_var_sched:
    num_steps: 2000
    beta_1: 1e-6
    beta_T: 0.02
  max_cond_t: 4

  network:
    context_dim: 13
    scale_eigenvectors: &scale_vectors true
  dataset_kwargs:
    scale_eigenvectors: *scale_vectors

  use_scaled_eigenvectors: *scale_vectors
  augment_condition: true
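
The `condition_var_sched` / `augment_condition` / `max_cond_t` keys above suggest that the Gaussian condition is perturbed with a few forward-diffusion steps before being fed to the phase-2 network, as a robustness augmentation. A hypothetical sketch of that idea (not taken from the repository's code):

```python
import torch

def augment_condition(gaussians, alpha_bars, max_cond_t=4):
    # Sample a small timestep per shape and noise the condition accordingly.
    t = torch.randint(0, max_cond_t, (gaussians.shape[0],))
    a_bar = alpha_bars[t].view(-1, 1, 1)
    noise = torch.randn_like(gaussians)
    return a_bar.sqrt() * gaussians + (1.0 - a_bar).sqrt() * noise
```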
--------------------------------------------------------------------------------
/notebooks/uncleaned_demo/pvd/configs/experiment/phase2-table.yaml:
--------------------------------------------------------------------------------
# @package _global_
defaults:
  - override /model: ldm_big_phase2.yaml
  - override /paths: default.yaml
  - override /data: spaghetti_table.yaml

paths:
  save_dir: ${paths.work_dir}/results/phase2-table/${name}/${now:%m%d_%H%M%S}

logger:
  wandb:
    project: "spaghetti-phase2-table"

model:
  spaghetti_tag: "tables"
--------------------------------------------------------------------------------
/notebooks/uncleaned_demo/pvd/configs/experiment/phase2.yaml:
--------------------------------------------------------------------------------
# @package _global_
defaults:
  - override /model: ldm_big_phase2.yaml
  - override /paths: default.yaml

paths:
  save_dir: ${paths.work_dir}/results/phase2/${name}/${now:%m%d_%H%M%S}

logger:
  wandb:
    ####
    # < Seungwoo >
    # Modified project name to prevent overwrite.
    # project: "spaghetti-phase2"
    project: "sw_spaghetti-phase2"
    ####
--------------------------------------------------------------------------------
/notebooks/uncleaned_demo/pvd/configs/experiment/reverse_phase.yaml:
--------------------------------------------------------------------------------
# @package _global_
defaults:
  - override /model: ldm_big_phase2.yaml
  - override /paths: default.yaml

paths:
  save_dir: ${paths.work_dir}/results/reverse_phase/${name}/${now:%m%d_%H%M%S}

logger:
  wandb:
    project: "spaghetti-reverse_phase"


model:
  _target_: pvd.diffusion.reverse_model.ReverseConditionSALDM
  network:
    input_dim: 16
    residual: true
    context_dim: 512

  dataset_kwargs:
    data_keys: ["g_js_affine", "s_j_affine"]
--------------------------------------------------------------------------------
/notebooks/uncleaned_demo/pvd/configs/experiment/simclr-chair.yaml:
--------------------------------------------------------------------------------
# @package _global_
defaults:
  - override /model: simclr.yaml
  - override /paths: default.yaml
  - override /data: spaghetti_author.yaml

paths:
  save_dir: ${paths.work_dir}/results/simclr-chair/${name}/${now:%m%d_%H%M%S}

logger:
  wandb:
    project: "simclr-chair"

callbacks:
  model_checkpoint:
    every_n_epochs: null
    save_top_k: 1
    monitor: "val/loss"
--------------------------------------------------------------------------------
/notebooks/uncleaned_demo/pvd/configs/experiment/single_phase-airplane.yaml:
--------------------------------------------------------------------------------
# @package _global_
defaults:
  - override /model: ldm_big_single_phase.yaml
  - override /paths: default.yaml
  - override /data: spaghetti_author_airplane.yaml

paths:
  save_dir: ${paths.work_dir}/results/single_phase-airplane/${name}/${now:%m%d_%H%M%S}

logger:
  wandb:
    project: "spaghetti-single-phase-airplane"

model:
  spaghetti_tag: "airplanes"
--------------------------------------------------------------------------------
/notebooks/uncleaned_demo/pvd/configs/experiment/single_phase-pos_enc-airplane.yaml:
--------------------------------------------------------------------------------
# @package _global_
defaults:
  - override /model: ldm_big_single_phase-pos_enc.yaml
  - override /paths: default.yaml
  - override /data: spaghetti_author_airplane.yaml

paths:
  save_dir: ${paths.work_dir}/results/single_phase-airplane-pos_enc/${name}/${now:%m%d_%H%M%S}

logger:
  wandb:
    project: "spaghetti-single-phase-airplane-pos_enc"

model:
  spaghetti_tag: "airplanes"
--------------------------------------------------------------------------------
/notebooks/uncleaned_demo/pvd/configs/experiment/single_phase-pos_enc.yaml:
--------------------------------------------------------------------------------
# @package _global_
defaults:
  - override /model: ldm_big_single_phase-pos_enc.yaml
  - override /paths: default.yaml

paths:
  save_dir: ${paths.work_dir}/results/single_phase-pos_enc/${name}/${now:%m%d_%H%M%S}

logger:
  wandb:
    project: "spaghetti-single-phase-pos_enc"
--------------------------------------------------------------------------------
/notebooks/uncleaned_demo/pvd/configs/experiment/single_phase-table.yaml:
--------------------------------------------------------------------------------
# @package _global_
defaults:
  - override /model: ldm_big_single_phase.yaml
  - override /paths: default.yaml
  - override /data: spaghetti_table.yaml

paths:
  save_dir: ${paths.work_dir}/results/single_phase-table/${name}/${now:%m%d_%H%M%S}

logger:
  wandb:
    project: "spaghetti-single-phase-table"

model:
  spaghetti_tag: "tables"
--------------------------------------------------------------------------------
/notebooks/uncleaned_demo/pvd/configs/experiment/single_phase.yaml:
--------------------------------------------------------------------------------
# @package _global_
defaults:
  - override /model: ldm_big_single_phase.yaml
  - override /paths: default.yaml

paths:
  save_dir: ${paths.work_dir}/results/single_phase/${name}/${now:%m%d_%H%M%S}

logger:
  wandb:
    project: "spaghetti-single-phase"
--------------------------------------------------------------------------------
/notebooks/uncleaned_demo/pvd/configs/experiment/single_phase_zb-airplane.yaml:
--------------------------------------------------------------------------------
# @package _global_
defaults:
  - override /model: ldm_big_single_phase_zb.yaml
  - override /paths: default.yaml
  - override /data: spaghetti_author_airplane.yaml

paths:
  save_dir: ${paths.work_dir}/results/single_phase-zb-airplane/${name}/${now:%m%d_%H%M%S}

logger:
  wandb:
    project: "spaghetti-single-phase-zb-airplane"

model:
  spaghetti_tag: "airplanes"

data:
  latent_path: /home/juil/projects/3D_CRISPR/data/spaghetti_airplane_inversion/3236-03-02-112852-inversion_with_replacement_with_zb.hdf5
--------------------------------------------------------------------------------
/notebooks/uncleaned_demo/pvd/configs/experiment/single_phase_zb.yaml:
--------------------------------------------------------------------------------
# @package _global_
defaults:
  - override /model: ldm_big_single_phase_zb.yaml
  - override /paths: default.yaml

paths:
  save_dir: ${paths.work_dir}/results/single_phase-zb/${name}/${now:%m%d_%H%M%S}

logger:
  wandb:
    project: "spaghetti-single-phase-zb"
--------------------------------------------------------------------------------
/notebooks/uncleaned_demo/pvd/configs/experiment/spaghetti_condition_self_attention.yaml:
--------------------------------------------------------------------------------
# @package _global_

defaults:
  - override /model: ldm_big_phase2.yaml
  - override /paths: default.yaml

paths:
  save_dir: ${paths.work_dir}/results/spaghetti_condition_sa/${name}/${now:%m%d_%H%M%S}
--------------------------------------------------------------------------------
/notebooks/uncleaned_demo/pvd/configs/experiment/za-airplane.yaml:
--------------------------------------------------------------------------------
# @package _global_
defaults:
  - override /model: ldm_big_za.yaml
  - override /paths: default.yaml
  - override /data: spaghetti_author_airplane.yaml

paths:
  save_dir: ${paths.work_dir}/results/za-airplane/${name}/${now:%m%d_%H%M%S}

logger:
  wandb:
    project: "spaghetti-za-airplane"

model:
  spaghetti_tag: "airplanes"
--------------------------------------------------------------------------------
/notebooks/uncleaned_demo/pvd/configs/experiment/za.yaml:
--------------------------------------------------------------------------------
# @package _global_
defaults:
  - override /model: ldm_big_za.yaml
  - override /paths: default.yaml

paths:
  save_dir: ${paths.work_dir}/results/za/${name}/${now:%m%d_%H%M%S}

logger:
  wandb:
    project: "spaghetti-za"
--------------------------------------------------------------------------------
/notebooks/uncleaned_demo/pvd/configs/hydra/default.yaml:
--------------------------------------------------------------------------------
#defaults:
#- override hydra_logging: colorlog
#- override job_logging: colorlog

run:
  dir: ${paths.save_dir}

sweep:
  dir: ${paths.save_dir}
  subdir: ${hydra.job.num}
--------------------------------------------------------------------------------
/notebooks/uncleaned_demo/pvd/configs/logger/multi_loggers.yaml:
--------------------------------------------------------------------------------
wandb:
  _target_: pytorch_lightning.loggers.wandb.WandbLogger
  project: "spaghetti-ldm"
  name: ${name}
  save_dir: "."
  entity: "geometry"

tensorboard:
  _target_: pytorch_lightning.loggers.tensorboard.TensorBoardLogger
  save_dir: "tensorboard/"
  name: null
  version: ${name}
  log_graph: false
  default_hp_metric: true
  prefix: ""
--------------------------------------------------------------------------------
/notebooks/uncleaned_demo/pvd/configs/logger/tensorboard.yaml:
--------------------------------------------------------------------------------
tensorboard:
  _target_: pytorch_lightning.loggers.tensorboard.TensorBoardLogger
  save_dir: "tensorboard/"
  name: null
  #version: ${name}
  log_graph: true
  default_hp_metric: false
  prefix: ""
--------------------------------------------------------------------------------
/notebooks/uncleaned_demo/pvd/configs/model/ldm_big_phase1.yaml:
--------------------------------------------------------------------------------
defaults:
  - ldm_default.yaml

_target_: pvd.diffusion.phase1.GaussianSALDM
use_scaled_eigenvectors: false
network:
  _target_: pvd.diffusion.network.LatentSelfAttentionNetwork
  input_dim: 16
  embedding_dim: 512
  num_heads: 4
  use_timestep_embedder: true
  timestep_embedder_dim: 128
  enc_num_layers: 6
  residual: true
  scale_eigenvectors: ${model.use_scaled_eigenvectors} #*scale_vectors
  encoder_type: transformer

validation_step: 1
eigen_value_clipping: false
eigen_value_clip_val: 1e-4

global_normalization: false

dataset_kwargs:
  data_keys: ["g_js_affine"]
  scale_eigenvectors: ${model.use_scaled_eigenvectors} #*scale_vectors
  global_normalization: ${model.global_normalization}
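
For orientation, the fields above describe a set-to-set transformer denoiser over the 16-dim per-part tokens. Below is a minimal sketch consistent with those hyperparameters; it is illustrative and is not `pvd.diffusion.network.LatentSelfAttentionNetwork` itself.

```python
# Sketch of a phase-1-style set denoiser: input_dim=16, embedding_dim=512,
# enc_num_layers=6, num_heads=4, timestep_embedder_dim=128.
import math
import torch
import torch.nn as nn

class TimestepEmbedder(nn.Module):
    def __init__(self, dim=128):
        super().__init__()
        self.dim = dim

    def forward(self, t):  # t: (B,) integer timesteps
        half = self.dim // 2
        freqs = torch.exp(-math.log(10000.0) * torch.arange(half) / half)
        args = t[:, None].float() * freqs[None]
        return torch.cat([args.sin(), args.cos()], dim=-1)  # (B, dim)

class SetDenoiser(nn.Module):
    def __init__(self, input_dim=16, embed_dim=512, num_layers=6, num_heads=4):
        super().__init__()
        self.time_embed = TimestepEmbedder(128)
        self.proj_in = nn.Linear(input_dim + 128, embed_dim)
        layer = nn.TransformerEncoderLayer(embed_dim, num_heads, batch_first=True)
        self.encoder = nn.TransformerEncoder(layer, num_layers)
        self.proj_out = nn.Linear(embed_dim, input_dim)

    def forward(self, x, t):  # x: (B, num_parts, input_dim)
        # Broadcast the timestep embedding to every part token.
        temb = self.time_embed(t)[:, None].expand(-1, x.shape[1], -1)
        h = self.proj_in(torch.cat([x, temb], dim=-1))
        return self.proj_out(self.encoder(h))  # predicted per-part noise
```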
14 | encoder_type: pointwise 15 | decoder_type: transformer_encoder 16 | enc_num_layers: 2 17 | dec_num_layers: 6 18 | use_timestep_embedder: true 19 | timestep_embedder_dim: 128 20 | scale_eigenvectors: &scale_vectors false 21 | use_pos_encoding: false 22 | 23 | classifier_free_guidance: false 24 | conditioning_dropout_prob: 0.2 25 | 26 | dataset_kwargs: 27 | data_keys: ["g_js_affine"] 28 | #lang_data_path: /home/juil/projects/3D_CRISPR/data/text2shape/T5embed2spaghetti_aug_ids.h5 29 | #lang_data_path: /home/juil/projects/3D_CRISPR/data/shapeglot/spaghetti_game_data.csv 30 | -------------------------------------------------------------------------------- /notebooks/uncleaned_demo/pvd/configs/model/ldm_big_phase1_pos_enc.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - ldm_default.yaml 3 | 4 | _target_: pvd.diffusion.phase1.GaussianSALDM 5 | use_scaled_eigenvectors: false 6 | network: 7 | _target_: pvd.diffusion.network.PosEncSelfAttentionNetwork 8 | input_dim: 16 9 | embedding_dim: 512 10 | num_heads: 4 11 | use_timestep_embedder: true 12 | timestep_embedder_dim: 128 13 | enc_num_layers: 6 14 | residual: true 15 | scale_eigenvectors: ${model.use_scaled_eigenvectors} #*scale_vectors 16 | 17 | validation_step: 5 18 | eigen_value_clipping: false 19 | eigen_value_clip_val: 1e-4 20 | 21 | global_normalization: false 22 | 23 | dataset_kwargs: 24 | data_keys: ["g_js_affine"] 25 | scale_eigenvectors: ${model.use_scaled_eigenvectors} #*scale_vectors 26 | global_normalization: ${model.global_normalization} 27 | -------------------------------------------------------------------------------- /notebooks/uncleaned_demo/pvd/configs/model/ldm_big_phase2.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - ldm_default.yaml 3 | 4 | _target_: pvd.diffusion.ldm.SpaghettiConditionSALDM 5 | network: 6 | _target_: pvd.diffusion.network.LatentConditionSelfAttentionNetwork 7 | input_dim: 512 8 | residual: true 9 | context_dim: 16 10 | context_embedding_dim: 512 11 | embedding_dim: 512 12 | encoder_use_time: false 13 | encoder_type: transformer 14 | decoder_type: transformer_encoder 15 | enc_num_layers: 6 16 | dec_num_layers: 6 17 | use_timestep_embedder: true 18 | timestep_embedder_dim: 128 19 | scale_eigenvectors: &scale_vectors false 20 | 21 | classifier_free_guidance: false 22 | conditioning_dropout_prob: 0.0 23 | conditioning_dropout_level: shape 24 | 25 | augment_condition: false 26 | phase1_ckpt_path: null 27 | augment_timestep: 100 28 | mu_noise_scale: 1 29 | eigenvectors_noise_scale: 1 30 | pi_noise_scale: 1 31 | eigenvalues_noise_scale: 1 32 | 33 | use_scaled_eigenvectors: *scale_vectors 34 | sj_global_normalization: false 35 | 36 | dataset_kwargs: 37 | data_keys: ["s_j_affine", "g_js_affine"] 38 | scale_eigenvectors: *scale_vectors 39 | sj_global_normalization: ${model.sj_global_normalization} 40 | -------------------------------------------------------------------------------- /notebooks/uncleaned_demo/pvd/configs/model/ldm_big_phase2_language.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - ldm_lang_default.yaml 3 | 4 | #_target_: pvd.diffusion.language_cond.LanguageConditionIntrinsicLDM 5 | _target_: pvd.diffusion.language_cond.ShapeglotPhase2 6 | global_normalization: false 7 | network: 8 | _target_: pvd.diffusion.network.LatentConditionSelfAttentionNetwork 9 | input_dim: 512 10 | residual: true 11 | context_dim: 784 12 | 
context_embedding_dim: 1024 13 | embedding_dim: 512 14 | encoder_use_time: false 15 | encoder_type: transformer 16 | decoder_type: transformer_encoder 17 | enc_num_layers: 6 18 | dec_num_layers: 6 19 | use_timestep_embedder: true 20 | timestep_embedder_dim: 128 21 | scale_eigenvectors: &scale_vectors false 22 | 23 | classifier_free_guidance: false 24 | conditioning_dropout_prob: 0.2 25 | 26 | dataset_kwargs: 27 | data_keys: ["s_j_affine", "g_js_affine"] 28 | #lang_data_path: /home/juil/projects/3D_CRISPR/data/text2shape/T5embed2spaghetti_aug_ids.h5 29 | #lang_data_path: /home/juil/projects/3D_CRISPR/data/shapeglot/spaghetti_game_data.csv 30 | -------------------------------------------------------------------------------- /notebooks/uncleaned_demo/pvd/configs/model/ldm_big_single_phase-pos_enc.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - ldm_default.yaml 3 | 4 | _target_: pvd.diffusion.single_phase.SingleSALDM 5 | network: 6 | _target_: pvd.diffusion.network.PosEncSelfAttentionNetwork 7 | input_dim: 528 8 | embedding_dim: 512 9 | num_heads: 4 10 | use_timestep_embedder: true 11 | timestep_embedder_dim: 128 12 | enc_num_layers: 6 13 | residual: true 14 | scale_eigenvectors: &scale_vectors false 15 | 16 | validation_step: 10 17 | eigen_value_clipping: false 18 | eigen_value_clip_val: 1e-4 19 | use_scaled_eigenvectors: *scale_vectors 20 | global_normalization: false 21 | 22 | dataset_kwargs: 23 | data_keys: ["s_j_affine", "g_js_affine"] 24 | scale_eigenvectors: *scale_vectors 25 | concat_data: true 26 | global_normalization: ${model.global_normalization} 27 | -------------------------------------------------------------------------------- /notebooks/uncleaned_demo/pvd/configs/model/ldm_big_single_phase.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - ldm_default.yaml 3 | 4 | _target_: pvd.diffusion.single_phase.SingleSALDM 5 | network: 6 | _target_: pvd.diffusion.network.LatentSelfAttentionNetwork 7 | input_dim: 528 8 | embedding_dim: 512 9 | num_heads: 4 10 | use_timestep_embedder: true 11 | timestep_embedder_dim: 128 12 | enc_num_layers: 6 13 | residual: true 14 | scale_eigenvectors: &scale_vectors false 15 | 16 | validation_step: 10 17 | eigen_value_clipping: false 18 | eigen_value_clip_val: 1e-4 19 | use_scaled_eigenvectors: *scale_vectors 20 | global_normalization: false 21 | 22 | dataset_kwargs: 23 | data_keys: ["s_j_affine", "g_js_affine"] 24 | scale_eigenvectors: *scale_vectors 25 | concat_data: true 26 | global_normalization: ${model.global_normalization} 27 | -------------------------------------------------------------------------------- /notebooks/uncleaned_demo/pvd/configs/model/ldm_big_single_phase_language.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - ldm_lang_default.yaml 3 | 4 | _target_: pvd.diffusion.language_cond.ShapeglotSinglePhase 5 | network: 6 | _target_: pvd.diffusion.network.LatentConditionSelfAttentionNetwork 7 | input_dim: 528 8 | residual: true 9 | context_dim: 768 10 | context_embedding_dim: 1024 11 | embedding_dim: 512 12 | encoder_use_time: false 13 | encoder_type: pointwise 14 | decoder_type: transformer_encoder 15 | enc_num_layers: 2 16 | dec_num_layers: 6 17 | use_timestep_embedder: true 18 | timestep_embedder_dim: 128 19 | use_pos_encoding: false 20 | 21 | classifier_free_guidance: false 22 | conditioning_dropout_prob: 0.2 23 | global_normalization: false 24 | 25 | 
dataset_kwargs: 26 | data_keys: ["s_j_affine", "g_js_affine"] 27 | #lang_data_path: /home/juil/projects/3D_CRISPR/data/shapeglot/spaghetti_game_data.csv 28 | concat_data: true 29 | global_normalization: ${model.global_normalization} 30 | -------------------------------------------------------------------------------- /notebooks/uncleaned_demo/pvd/configs/model/ldm_big_single_phase_zb.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - ldm_default.yaml 3 | 4 | _target_: pvd.diffusion.single_phase.SingleZbSALDM 5 | network: 6 | _target_: pvd.diffusion.network.LatentSelfAttentionNetwork 7 | input_dim: 512 8 | embedding_dim: 512 9 | num_heads: 4 10 | use_timestep_embedder: true 11 | timestep_embedder_dim: 128 12 | enc_num_layers: 6 13 | residual: true 14 | scale_eigenvectors: &scale_vectors false 15 | use_pos_encoding: false 16 | 17 | validation_step: 10 18 | eigen_value_clipping: false 19 | eigen_value_clip_val: 1e-4 20 | use_scaled_eigenvectors: *scale_vectors 21 | global_normalization: false 22 | 23 | dataset_kwargs: 24 | data_keys: ["Zb"] 25 | scale_eigenvectors: *scale_vectors 26 | global_normalization: ${model.global_normalization} 27 | -------------------------------------------------------------------------------- /notebooks/uncleaned_demo/pvd/configs/model/ldm_big_za.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - ldm_default.yaml 3 | 4 | _target_: pvd.diffusion.phase1.ZaSALDM 5 | network: 6 | _target_: pvd.diffusion.network.LatentLinearNetwork 7 | input_dim: 256 8 | embedding_dim: 512 9 | use_timestep_embedder: true 10 | timestep_embedder_dim: 128 11 | enc_num_layers: 6 12 | residual: true 13 | 14 | validation_step: 1 15 | 16 | dataset_kwargs: 17 | data_keys: ["za"] 18 | -------------------------------------------------------------------------------- /notebooks/uncleaned_demo/pvd/configs/model/ldm_default.yaml: -------------------------------------------------------------------------------- 1 | _target_: pvd.diffusion.ldm.SpaghettiSALDM 2 | network: null 3 | 4 | variance_schedule: 5 | _target_: pvd.diffusion.common.VarianceSchedule 6 | num_steps: &time_steps 1000 7 | beta_1: 1e-4 8 | beta_T: 0.05 9 | mode: linear 10 | 11 | # optimizer 12 | lr: 1e-4 13 | batch_size: ${batch_size} 14 | 15 | # dataset 16 | dataset_kwargs: 17 | data_path: ${data.latent_path} 18 | repeat: 3 19 | data_keys: ["s_j_affine"] 20 | 21 | # model 22 | num_timesteps: *time_steps 23 | faster: true 24 | validation_step: 10 25 | no_run_validation: false 26 | spaghetti_tag: "chairs_large" # or airplanes 27 | -------------------------------------------------------------------------------- /notebooks/uncleaned_demo/pvd/configs/model/ldm_lang_default.yaml: -------------------------------------------------------------------------------- 1 | _target_: pvd.diffusion.ldm.SpaghettiSALDM 2 | network: null 3 | 4 | variance_schedule: 5 | _target_: pvd.diffusion.common.VarianceSchedule 6 | num_steps: &time_steps 1000 7 | beta_1: 1e-4 8 | beta_T: 0.05 9 | mode: linear 10 | 11 | # optimizer 12 | lr: 1e-4 13 | batch_size: ${batch_size} 14 | 15 | global_normalization: partial 16 | # dataset 17 | dataset_kwargs: 18 | data_path: ${data.latent_path} 19 | repeat: 1 20 | data_keys: ["s_j_affine"] 21 | lang_data_path: /home/juil/projects/3D_CRISPR/data/shapeglot/autosdf_spaghetti_game_data_v2.csv 22 | only_correct: false 23 | only_easy_context: false 24 | max_dataset_size: 10000000 25 | global_normalization: 
${model.global_normalization} 26 | 27 | # model 28 | num_timesteps: *time_steps 29 | faster: true 30 | validation_step: 5 31 | no_run_validation: false 32 | spaghetti_tag: "chairs_large" # or airplanes 33 | text_encoder_freeze: true 34 | text_encoder_return_seq: false 35 | use_lstm: false 36 | use_partglot_data: false 37 | -------------------------------------------------------------------------------- /notebooks/uncleaned_demo/pvd/configs/model/simclr.yaml: -------------------------------------------------------------------------------- 1 | _target_: pvd.embedding.model.CLIPEmbedder 2 | ext_embedder: 3 | _target_: pvd.embedding.model.TFEmbedder 4 | dim_in: 16 5 | dim_h: 512 6 | dim_out: 512 7 | num_layers: 6 8 | num_heads: 4 9 | 10 | int_embedder: 11 | _target_: pvd.embedding.model.TFEmbedder 12 | dim_in: 512 13 | dim_h: 512 14 | dim_out: 512 15 | num_layers: 6 16 | num_heads: 4 17 | 18 | temperature: 3 19 | lr: 1e-4 20 | batch_size: ${batch_size} 21 | 22 | dataset_kwargs: 23 | data_path: ${data.latent_path} 24 | repeat: 3 25 | data_keys: ["s_j_affine", "g_js_affine"] 26 | -------------------------------------------------------------------------------- /notebooks/uncleaned_demo/pvd/configs/model/vae.yaml: -------------------------------------------------------------------------------- 1 | hydra: 2 | run: 3 | dir: ${save_dir} 4 | 5 | defaults: 6 | - _self_ 7 | 8 | trainer: 9 | _target_: pytorch_lightning.Trainer 10 | accelerator: "gpu" 11 | devices: ${gpus} 12 | max_epochs: ${epochs} 13 | 14 | work_dir: ${hydra:runtime.cwd} 15 | save_dir: ${work_dir}/results/${name}/${now:%m%d_%H%M%S} 16 | 17 | epochs: 200 18 | batch_size: 16 19 | gpus: [1] 20 | seed: 63 21 | name: default 22 | debug: false 23 | 24 | model: 25 | _target_: pvd.models.vae.PointNetVAE 26 | encoder: pointnet 27 | batch_size: ${batch_size} 28 | save_dir: ${save_dir} 29 | kl_weight: 1e-4 30 | 31 | callbacks: 32 | model_summary: 33 | _target_: pytorch_lightning.callbacks.RichModelSummary 34 | max_depth: 1 35 | 36 | rich_progress_bar: 37 | _target_: pytorch_lightning.callbacks.RichProgressBar 38 | 39 | model_checkpoint: 40 | _target_: pytorch_lightning.callbacks.ModelCheckpoint 41 | monitor: "train/loss" 42 | mode: "min" 43 | save_top_k: 1 44 | save_last: true 45 | verbose: true 46 | dirpath: "checkpoints/" 47 | filename: "epoch={epoch:02d}-val_loss={val/loss:.4f}" 48 | auto_insert_metric_name: false 49 | 50 | lr_monitor: 51 | _target_: pytorch_lightning.callbacks.LearningRateMonitor 52 | logging_interval: "step" 53 | 54 | logger: 55 | wandb: 56 | _target_: pytorch_lightning.loggers.wandb.WandbLogger 57 | project: "point-vae" 58 | name: ${name} 59 | save_dir: "." 
60 | entity: "geometry" 61 | 62 | tensorboard: 63 | _target_: pytorch_lightning.loggers.tensorboard.TensorBoardLogger 64 | save_dir: "tensorboard/" 65 | name: null 66 | version: ${name} 67 | log_graph: false 68 | default_hp_metric: true 69 | prefix: "" 70 | -------------------------------------------------------------------------------- /notebooks/uncleaned_demo/pvd/configs/paths/default.yaml: -------------------------------------------------------------------------------- 1 | work_dir: ${hydra:runtime.cwd} 2 | save_dir: ${work_dir}/results/${name}/${now:%m%d_%H%M%S} 3 | -------------------------------------------------------------------------------- /notebooks/uncleaned_demo/pvd/configs/simclr.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | defaults: 4 | - _self_ 5 | - data: spaghetti_author.yaml 6 | - model: simclr.yaml 7 | - callbacks: default.yaml 8 | - logger: multi_loggers.yaml 9 | - trainer: default.yaml 10 | - paths: default.yaml 11 | - hydra: default.yaml 12 | 13 | - experiment: simclr-chair.yaml 14 | - debug: null 15 | 16 | name: default 17 | seed: 63 18 | 19 | gpus: [0] 20 | epochs: 100 21 | batch_size: 1024 22 | -------------------------------------------------------------------------------- /notebooks/uncleaned_demo/pvd/configs/train.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | defaults: 4 | - _self_ 5 | - data: spaghetti_author.yaml 6 | - model: ldm_default.yaml 7 | - callbacks: default.yaml 8 | - logger: multi_loggers.yaml 9 | - trainer: default.yaml 10 | - paths: default.yaml 11 | - hydra: default.yaml 12 | 13 | - experiment: phase2.yaml 14 | - debug: null 15 | 16 | name: default 17 | seed: 63 18 | 19 | gpus: [0] 20 | epochs: 5000 21 | batch_size: 32 22 | 23 | resume_from_checkpoint: null 24 | -------------------------------------------------------------------------------- /notebooks/uncleaned_demo/pvd/configs/trainer/default.yaml: -------------------------------------------------------------------------------- 1 | _target_: pytorch_lightning.Trainer 2 | accelerator: gpu 3 | devices: ${gpus} 4 | max_epochs: ${epochs} 5 | -------------------------------------------------------------------------------- /notebooks/uncleaned_demo/pvd/diffusion/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KAIST-Visual-AI-Group/SALAD/db66b1b2ce57d8a4beb5bd610bc006ccaff43b1e/notebooks/uncleaned_demo/pvd/diffusion/__init__.py -------------------------------------------------------------------------------- /notebooks/uncleaned_demo/pvd/diffusion/phase1_sym.py: -------------------------------------------------------------------------------- 1 | import io 2 | from pathlib import Path 3 | from PIL import Image 4 | from pvd.dataset import batch_split_eigens 5 | import torch 6 | import torch.nn as nn 7 | import torch.nn.functional as F 8 | import numpy as np 9 | from pvd.diffusion.ldm import SpaghettiSALDM 10 | import jutils 11 | from pvd.utils.spaghetti_utils import * 12 | 13 | 14 | class GaussianSymSALDM(SpaghettiSALDM): 15 | def __init__(self, network, variance_schedule, **kwargs): 16 | super().__init__(network, variance_schedule, **kwargs) 17 | 18 | def forward(self, x): 19 | assert x.shape[1] == 8, f"x.shape: {x.shape}" 20 | return self.get_loss(x) 21 | 22 | @torch.no_grad() 23 | def sample( 24 | self, 25 | batch_size=0, 26 | return_traj=False, 27 | ): 28 | if 
self.hparams.get("use_scaled_eigenvectors"): 29 | x_T = torch.randn([batch_size, 8, 13]).to(self.device) 30 | else: 31 | x_T = torch.randn([batch_size, 8, 16]).to(self.device) 32 | 33 | traj = {self.var_sched.num_steps: x_T} 34 | for t in range(self.var_sched.num_steps, 0, -1): 35 | z = torch.randn_like(x_T) if t > 1 else torch.zeros_like(x_T) 36 | alpha = self.var_sched.alphas[t] 37 | alpha_bar = self.var_sched.alpha_bars[t] 38 | sigma = self.var_sched.get_sigmas(t, flexibility=0) 39 | 40 | c0 = 1.0 / torch.sqrt(alpha) 41 | c1 = (1 - alpha) / torch.sqrt(1 - alpha_bar) 42 | 43 | x_t = traj[t] 44 | 45 | beta = self.var_sched.betas[[t] * batch_size] 46 | e_theta = self.net(x_t, beta=beta) 47 | 48 | x_next = c0 * (x_t - c1 * e_theta) + sigma * z 49 | traj[t - 1] = x_next.detach() 50 | 51 | traj[t] = traj[t].cpu() 52 | 53 | if not return_traj: 54 | del traj[t] 55 | 56 | if self.hparams.get("eigen_value_clipping"): 57 | traj[0][:, :, 13:16] = torch.clamp_min( 58 | traj[0][:, :, 13:16], min=self.hparams["eigen_value_clip_val"] 59 | ) 60 | 61 | #### 62 | if return_traj: 63 | return traj 64 | else: 65 | return traj[0] 66 | 67 | def sampling_gaussians(self, num_shapes): 68 | """ 69 | Return: 70 | ldm_gaus: np.ndarray 71 | """ 72 | ldm_gaus = self.sample(num_shapes) 73 | 74 | if self.hparams.get("use_scaled_eigenvectors"): 75 | ldm_gaus = jutils.thutil.th2np(batch_split_eigens(jutils.nputil.np2th(ldm_gaus))) 76 | 77 | half = ldm_gaus 78 | full = reflect_and_concat_gmms(half) 79 | ldm_gaus = full 80 | if self.hparams.get("global_normalization"): 81 | if not hasattr(self, "data_val"): 82 | self._build_dataset("val") 83 | print(f"[!] Unnormalize samples as {self.hparams.get('global_normalization')}") 84 | if self.hparams.get("global_normalization") == "partial": 85 | ldm_gaus = self.data_val.unnormalize_global_static(ldm_gaus, slice(12,None)) 86 | elif self.hparams.get("global_normalization") == "all": 87 | ldm_gaus = self.data_val.unnormalize_global_static(ldm_gaus, slice(None)) 88 | 89 | return ldm_gaus 90 | 91 | def validation(self): 92 | vis_num_shapes = 16 93 | ldm_gaus = self.sampling_gaussians(vis_num_shapes) 94 | 95 | pred_images = [] 96 | for i in range(vis_num_shapes): 97 | 98 | def draw_gaus(gaus): 99 | gaus = clip_eigenvalues(gaus) 100 | img = jutils.visutil.render_gaussians(gaus, resolution=(256, 256)) 101 | return img 102 | 103 | try: 104 | pred_img = draw_gaus(ldm_gaus[i]) 105 | pred_images.append(pred_img) 106 | except: 107 | pass 108 | 109 | wandb_logger = self.get_wandb_logger() 110 | 111 | try: 112 | pred_images = jutils.imageutil.merge_images(pred_images) 113 | wandb_logger.log_image("Pred", [pred_images]) 114 | except: 115 | pass 116 | -------------------------------------------------------------------------------- /notebooks/uncleaned_demo/pvd/diffusion/reverse_model.py: -------------------------------------------------------------------------------- 1 | import io 2 | from typing import Optional 3 | from pathlib import Path 4 | from PIL import Image 5 | from pvd.utils.sde_utils import add_noise, gaus_denoising_tweedie, tweedie_approach 6 | import pytorch_lightning as pl 7 | import torch 8 | import torch.nn as nn 9 | import torch.nn.functional as F 10 | from torch.nn import Module, Parameter, ModuleList 11 | import numpy as np 12 | from torch.utils.data.dataset import Subset 13 | from sklearn.manifold import TSNE 14 | import matplotlib.pyplot as plt 15 | 16 | 17 | from pvd.dataset import LatentDataset, SpaghettiLatentDataset, split_eigens 18 | from pvd.utils.train_utils import 
PolyDecayScheduler, get_dropout_mask 19 | from pvd.utils.spaghetti_utils import * 20 | from pvd.vae.vae import PointNetVAE 21 | from pvd.diffusion.network import * 22 | from eval3d import Evaluator 23 | from dotmap import DotMap 24 | from pvd.diffusion.common import * 25 | from pvd.diffusion.ldm import * 26 | import jutils 27 | 28 | 29 | class ReverseConditionSALDM(SpaghettiConditionSALDM): 30 | def __init__(self, network, variance_schedule, **kwargs): 31 | super().__init__(network, variance_schedule, **kwargs) 32 | 33 | def forward(self, x, cond): 34 | return self.get_loss(x, cond) 35 | 36 | def sample( 37 | self, 38 | num_samples_or_sjs, 39 | return_traj=False, 40 | return_cond=False 41 | ): 42 | if isinstance(num_samples_or_sjs, int): 43 | batch_size = num_samples_or_sjs 44 | ds = self._build_dataset("val") 45 | cond = torch.stack([ds[i][1] for i in range(batch_size)], 0) 46 | elif isinstance(num_samples_or_sjs, np.ndarray) or isinstance(num_samples_or_sjs, torch.Tensor): 47 | cond = jutils.nputil.np2th(num_samples_or_sjs) 48 | if cond.dim() == 2: 49 | cond = cond[None] 50 | batch_size = len(cond) 51 | x_T = torch.randn([batch_size, 16, 16]).to(self.device) 52 | cond = cond.to(self.device) 53 | 54 | traj = {self.var_sched.num_steps: x_T} 55 | for t in range(self.var_sched.num_steps, 0, -1): 56 | z = torch.randn_like(x_T) if t > 1 else torch.zeros_like(x_T) 57 | alpha = self.var_sched.alphas[t] 58 | alpha_bar = self.var_sched.alpha_bars[t] 59 | sigma = self.var_sched.get_sigmas(t, flexibility=0) 60 | 61 | c0 = 1.0 / torch.sqrt(alpha) 62 | c1 = (1 - alpha) / torch.sqrt(1 - alpha_bar) 63 | 64 | x_t = traj[t] 65 | 66 | beta = self.var_sched.betas[[t] * batch_size] 67 | e_theta = self.net(x_t, beta=beta, context=cond) 68 | 69 | x_next = c0 * (x_t - c1 * e_theta) + sigma * z 70 | traj[t - 1] = x_next.detach() 71 | 72 | traj[t] = traj[t].cpu() 73 | 74 | if not return_traj: 75 | del traj[t] 76 | 77 | if return_traj: 78 | if return_cond: 79 | return traj, cond 80 | return traj 81 | else: 82 | if return_cond: 83 | return traj[0], cond 84 | return traj[0] 85 | 86 | def validation(self): 87 | latent_ds = self._build_dataset("val") 88 | vis_num_shapes = 3 89 | num_variations = 3 90 | jutils.sysutil.clean_gpu() 91 | 92 | if not hasattr(self, "spaghetti"): 93 | spaghetti = load_spaghetti(self.device, self.hparams.spaghetti_tag if self.hparams.get("spaghetti_tag") else "chairs_large") 94 | self.spaghetti = spaghetti 95 | else: 96 | spaghetti = self.spaghetti 97 | 98 | if not hasattr(self, "mesher"): 99 | mesher = load_mesher(self.device) 100 | self.mesher = mesher 101 | else: 102 | mesher = self.mesher 103 | 104 | 105 | gt_gaus, gt_zs = zip(*[latent_ds[i+3] for i in range(vis_num_shapes)]) 106 | gt_gaus, gt_zs = list(map(lambda x : torch.stack(x), [gt_gaus, gt_zs])) 107 | 108 | gt_zs_repeated = gt_zs.repeat_interleave(num_variations, 0) 109 | ldm_gaus, ldm_zs = self.sample(gt_zs_repeated, return_cond=True) 110 | ldm_gaus = clip_eigenvalues(ldm_gaus) 111 | ldm_zcs = generate_zc_from_sj_gaus(spaghetti, ldm_zs, ldm_gaus) 112 | gt_zcs = generate_zc_from_sj_gaus(spaghetti, gt_zs, gt_gaus) 113 | jutils.sysutil.clean_gpu() 114 | 115 | wandb_logger = self.get_wandb_logger() 116 | resolution = (256,256) 117 | for i in range(vis_num_shapes): 118 | img_per_shape = [] 119 | gaus_img = jutils.visutil.render_gaussians(gt_gaus[i], resolution=resolution) 120 | vert, face = get_mesh_from_spaghetti(spaghetti, mesher, gt_zcs[i], res=128) 121 | gt_mesh_img = jutils.visutil.render_mesh(vert, face, resolution=resolution) 
122 | gt_img = jutils.imageutil.merge_images([gaus_img, gt_mesh_img]) 123 | gt_img = jutils.imageutil.draw_text(gt_img, "GT", font_size=24) 124 | img_per_shape.append(gt_img) 125 | for j in range(num_variations): 126 | try: 127 | gaus_img = jutils.visutil.render_gaussians(ldm_gaus[i * num_variations + j], resolution=resolution) 128 | vert, face = get_mesh_from_spaghetti(spaghetti, mesher, ldm_zcs[i * num_variations + j], res=128) 129 | mesh_img = jutils.visutil.render_mesh(vert, face, resolution=resolution) 130 | pred_img = jutils.imageutil.merge_images([gaus_img, mesh_img]) 131 | pred_img = jutils.imageutil.draw_text(pred_img, f"{j}-th predicted gaus", font_size=24) 132 | img_per_shape.append(pred_img) 133 | except Exception as e: 134 | print(e) 135 | 136 | try: 137 | image = jutils.imageutil.merge_images(img_per_shape) 138 | wandb_logger.log_image("visualization", [image]) 139 | except Exception as e: 140 | print(e) 141 | 142 | 143 | 144 | -------------------------------------------------------------------------------- /notebooks/uncleaned_demo/pvd/diffusion/single_phase.py: -------------------------------------------------------------------------------- 1 | import io 2 | from pathlib import Path 3 | from PIL import Image 4 | from pvd.dataset import batch_split_eigens 5 | import torch 6 | import torch.nn as nn 7 | import torch.nn.functional as F 8 | import numpy as np 9 | from pvd.diffusion.ldm import * 10 | 11 | 12 | class SingleSALDM(SpaghettiSALDM): 13 | def __init__(self, network, variance_schedule, **kwargs): 14 | super().__init__(network, variance_schedule, **kwargs) 15 | 16 | @torch.no_grad() 17 | def sample( 18 | self, 19 | batch_size=0, 20 | return_traj=False, 21 | ): 22 | x_T = torch.randn([batch_size, 16, 528]).to(self.device) 23 | 24 | traj = {self.var_sched.num_steps: x_T} 25 | for t in range(self.var_sched.num_steps, 0, -1): 26 | z = torch.randn_like(x_T) if t > 1 else torch.zeros_like(x_T) 27 | alpha = self.var_sched.alphas[t] 28 | alpha_bar = self.var_sched.alpha_bars[t] 29 | sigma = self.var_sched.get_sigmas(t, flexibility=0) 30 | 31 | c0 = 1.0 / torch.sqrt(alpha) 32 | c1 = (1 - alpha) / torch.sqrt(1 - alpha_bar) 33 | 34 | x_t = traj[t] 35 | 36 | beta = self.var_sched.betas[[t] * batch_size] 37 | e_theta = self.net(x_t, beta=beta) 38 | 39 | x_next = c0 * (x_t - c1 * e_theta) + sigma * z 40 | traj[t - 1] = x_next.detach() 41 | 42 | traj[t] = traj[t].cpu() 43 | 44 | if not return_traj: 45 | del traj[t] 46 | 47 | if self.hparams.get("eigen_value_clipping"): 48 | raise NotImplementedError("Not implemented eig clip for single phase.") 49 | # traj[0][:, :, 13:16] = torch.clamp_min( 50 | # traj[0][:, :, 13:16], min=self.hparams["eigen_value_clip_val"] 51 | # ) 52 | if return_traj: 53 | return traj 54 | else: 55 | return traj[0] 56 | 57 | def validation(self): 58 | vis_num_shapes = 4 59 | # ldm_gaus, gt_gaus = self.sampling_gaussians(vis_num_shapes) 60 | ldm_feats = self.sample(vis_num_shapes) 61 | ldm_sj, ldm_gaus = ldm_feats.split(split_size=[512,16], dim=-1) 62 | if self.hparams.get("global_normalization") == "partial": 63 | ldm_gaus = self._build_dataset("val").unnormalize_global_static(ldm_gaus, slice(12,None)) 64 | elif self.hparams.get("global_normalization") == "all": 65 | ldm_gaus = self._build_dataset("val").unnormalize_global_static(ldm_gaus, slice(None)) 66 | 67 | jutils.sysutil.clean_gpu() 68 | """ Spaghetti Decoding """ 69 | if not hasattr(self, "spaghetti"): 70 | spaghetti = load_spaghetti(self.device, self.hparams.spaghetti_tag if 
self.hparams.get("spaghetti_tag") else "chairs_large") 71 | self.spaghetti = spaghetti 72 | else: 73 | spaghetti = self.spaghetti 74 | 75 | if not hasattr(self, "mesher"): 76 | mesher = load_mesher(self.device) 77 | self.mesher = mesher 78 | else: 79 | mesher = self.mesher 80 | 81 | ldm_zcs = generate_zc_from_sj_gaus(spaghetti, ldm_sj, ldm_gaus) 82 | 83 | camera_kwargs = dict( 84 | camPos = np.array([-2,2,-2]), 85 | camLookat=np.array([0,0,0]), 86 | camUp=np.array([0,1,0]), 87 | resolution=(256,256), 88 | samples=16, 89 | ) 90 | 91 | wandb_logger = self.get_wandb_logger() 92 | for i in range(vis_num_shapes): 93 | try: 94 | vert, face = get_mesh_from_spaghetti(spaghetti, mesher, ldm_zcs[i]) 95 | mesh_img = jutils.visutil.render_mesh(vert, face, **camera_kwargs) 96 | gaus_img = jutils.visutil.render_gaussians(clip_eigenvalues(ldm_gaus[i]), resolution=camera_kwargs["resolution"]) 97 | img = jutils.imageutil.merge_images([gaus_img, mesh_img]) 98 | wandb_logger.log_image("Pred", [img]) 99 | except Exception as e: 100 | print(f"Error occured at visualizing, {e}") 101 | 102 | class SingleZbSALDM(SpaghettiSALDM): 103 | def validation(self): 104 | vis_num_shapes = 4 105 | ldm_zbs = self.sample(vis_num_shapes) 106 | 107 | if not hasattr(self, "spaghetti"): 108 | spaghetti = load_spaghetti(self.device, self.hparams.spaghetti_tag if self.hparams.get("spaghetti_tag") else "chairs_large") 109 | self.spaghetti = spaghetti 110 | else: 111 | spaghetti = self.spaghetti 112 | 113 | if not hasattr(self, "mesher"): 114 | mesher = load_mesher(self.device) 115 | self.mesher = mesher 116 | else: 117 | mesher = self.mesher 118 | 119 | ldm_sjs, ldm_gmms = spaghetti.decomposition_control.forward_mid(ldm_zbs) 120 | ldm_gaus = batch_gmms_to_gaus(ldm_gmms) 121 | ldm_zb_hat = spaghetti.merge_zh_step_a(ldm_sjs, ldm_gmms) 122 | ldm_zcs, _ = spaghetti.mixing_network.forward_with_attention(ldm_zb_hat) 123 | 124 | camera_kwargs = dict( 125 | camPos = np.array([-2,2,-2]), 126 | camLookat=np.array([0,0,0]), 127 | camUp=np.array([0,1,0]), 128 | resolution=(256,256), 129 | samples=16, 130 | ) 131 | 132 | wandb_logger = self.get_wandb_logger() 133 | for i in range(vis_num_shapes): 134 | try: 135 | vert, face = get_mesh_from_spaghetti(spaghetti, mesher, ldm_zcs[i]) 136 | mesh_img = jutils.visutil.render_mesh(vert, face, **camera_kwargs) 137 | gaus_img = jutils.visutil.render_gaussians(clip_eigenvalues(ldm_gaus[i]), resolution=camera_kwargs["resolution"]) 138 | img = jutils.imageutil.merge_images([gaus_img, mesh_img]) 139 | wandb_logger.log_image("Pred", [img]) 140 | except Exception as e: 141 | print(f"Error occured at visualizing, {e}") 142 | -------------------------------------------------------------------------------- /notebooks/uncleaned_demo/pvd/embedding/model.py: -------------------------------------------------------------------------------- 1 | from pvd.utils.train_utils import PolyDecayScheduler 2 | from tqdm import tqdm 3 | from rich.progress import track 4 | from pvd.vae.vae import conv_block 5 | import torch 6 | import torch.nn as nn 7 | import torch.nn.functional as F 8 | import pytorch_lightning as pl 9 | import numpy as np 10 | from pathlib import Path 11 | from dotmap import DotMap 12 | 13 | import jutils 14 | 15 | from pvd.dataset import * 16 | from pvd.diffusion.network import * 17 | 18 | class ContrastiveLearningDistWrapper: 19 | def __init__(self,**kwargs): 20 | self.__dict__.update(kwargs) 21 | self.hparams = DotMap(self.__dict__) 22 | 23 | def __call__(self, feat_a, feat_b): 24 | """ 25 | Input: 26 | 
feat_a: [B,D_e] 27 | feat_b: [B,D_e] 28 | Output: 29 | dist_mat: [B,B] 30 | """ 31 | assert feat_a.dim() == feat_b.dim() 32 | assert feat_a.shape[0] == feat_b.shape[0] 33 | if feat_a.dim() == 1: 34 | feat_a = feat_a.unsqueeze(0) 35 | feat_b = feat_b.unsqueeze(0) 36 | 37 | 38 | class TFEmbedder(nn.Module): 39 | def __init__(self, **kwargs): 40 | super().__init__() 41 | self.__dict__.update(kwargs) 42 | self.hparams = DotMap(self.__dict__) 43 | 44 | self.embedding = nn.Linear(self.hparams.dim_in, self.hparams.dim_h) 45 | dim_h = self.hparams.dim_h 46 | self.encoder = TimeTransformerEncoder(dim_h, dim_ctx=None, num_heads=self.hparams.num_heads, use_time=False, num_layers=self.hparams.num_layers, last_fc=True, last_fc_dim_out=self.hparams.dim_out) 47 | 48 | def forward(self, x): 49 | """ 50 | Input: 51 | x: [B,G,D_in] 52 | Output: 53 | out: [B,D_out] 54 | """ 55 | B, G = x.shape[:2] 56 | # res = x 57 | 58 | x = self.embedding(x) 59 | out = self.encoder(x) 60 | out = out.max(1)[0] 61 | return out 62 | 63 | 64 | 65 | class CLIPEmbedder(pl.LightningModule): 66 | def __init__(self, ext_embedder, int_embedder, **kwargs): 67 | super().__init__() 68 | self.save_hyperparameters(logger=False) 69 | self.ext_embedder = ext_embedder 70 | self.int_embedder = int_embedder 71 | 72 | def forward(self, intrinsics, extrinsics): 73 | """ 74 | Input: 75 | intrinsics: [Bi,16,512] 76 | extrinsics: [Be,16,16] 77 | Output: 78 | cossim_mat: [Bi,Be] 79 | """ 80 | 81 | int_f = self.int_embedder(intrinsics) #[B,D_emb] 82 | ext_f = self.ext_embedder(extrinsics) #[B,D_emb] 83 | 84 | int_e = int_f / int_f.norm(dim=-1, keepdim=True) 85 | ext_e = ext_f / ext_f.norm(dim=-1, keepdim=True) 86 | 87 | sim_mat = torch.einsum("nd,md->nm", int_e, ext_e) 88 | 89 | return sim_mat, int_e, ext_e 90 | 91 | def info_nce_loss(self, sim_mat): 92 | """ 93 | Input: 94 | sim_mat: [N,N] 95 | """ 96 | assert sim_mat.dim() == 2 and sim_mat.shape[0] == sim_mat.shape[1] 97 | N = sim_mat.shape[0] 98 | logits = sim_mat * np.exp(self.hparams.temperature) 99 | labels = torch.arange(N).to(logits).long() 100 | loss = F.cross_entropy(logits, labels) 101 | return loss 102 | 103 | def step(self, batch, batch_idx, stage: str): 104 | intrinsics, extrinsics = batch 105 | 106 | sim_mat, int_e, ext_e = self(intrinsics, extrinsics) 107 | loss = self.info_nce_loss(sim_mat) 108 | self.log(f"{stage}/loss", loss, on_step=stage == "train", prog_bar=True) 109 | return loss 110 | 111 | def training_step(self, batch, batch_idx): 112 | return self.step(batch, batch_idx, "train") 113 | 114 | def validation_step(self, batch, batch_idx): 115 | return self.step(batch, batch_idx, "val") 116 | 117 | def get_intrinsic_extrinsic_distance_matrix(self): 118 | """ 119 | [num_i, num_e] 120 | """ 121 | ds = self._build_dataset("val") 122 | intr = ds.data["s_j_affine"] 123 | extr = ds.data["g_js_affine"] 124 | intr, extr = list(map(lambda x: jutils.nputil.np2th(x).to(self.device), [intr, extr])) 125 | dist_mat = self(intr, extr)[0].cpu().numpy() 126 | 127 | return dist_mat 128 | 129 | 130 | 131 | def _build_dataset(self, stage): 132 | if stage == "train": 133 | ds = SpaghettiLatentDataset(**self.hparams.dataset_kwargs) 134 | else: 135 | dataset_kwargs = self.hparams.dataset_kwargs.copy() 136 | dataset_kwargs["repeat"] = 1 137 | ds = SpaghettiLatentDataset(**dataset_kwargs) 138 | setattr(self, f"data_{stage}", ds) 139 | return ds 140 | 141 | def _build_dataloader(self, stage, batch_size=None): 142 | try: 143 | ds = getattr(self, f"data_{stage}") 144 | except: 145 | ds = 
self._build_dataset(stage) 146 | 147 | return torch.utils.data.DataLoader( 148 | ds, 149 | batch_size=batch_size if batch_size is not None else self.hparams.batch_size, 150 | shuffle=stage == "train", 151 | drop_last=stage == "train", 152 | num_workers=4, 153 | ) 154 | def train_dataloader(self): 155 | return self._build_dataloader("train") 156 | 157 | def val_dataloader(self): 158 | return self._build_dataloader("val") 159 | 160 | def configure_optimizers(self): 161 | optimizer = torch.optim.Adam(self.parameters(), lr=self.hparams.lr) 162 | scheduler = PolyDecayScheduler(optimizer, self.hparams.lr, power=0.999) 163 | return [optimizer], [scheduler] 164 | 165 | -------------------------------------------------------------------------------- /notebooks/uncleaned_demo/pvd/misc/extract_gmms.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | from pvd.dataset import * 3 | from pvd.diffusion.ldm import * 4 | from pvd.diffusion.phase1 import * 5 | import torch 6 | import numpy as np 7 | import jutils 8 | import matplotlib.pyplot as plt 9 | from sklearn.manifold import TSNE 10 | from PIL import Image 11 | from eval3d import Evaluator 12 | import io 13 | import os 14 | from PIL import Image 15 | from datetime import datetime 16 | 17 | def load_ldm(path, strict=True): 18 | path = Path(path) 19 | assert path.exists() 20 | 21 | ldm = GaussianSALDM.load_from_checkpoint(path, strict=strict).cuda(0) 22 | ldm.eval() 23 | for p in ldm.parameters(): p.requires_grad_(False) 24 | dataset = ldm._build_dataset("val") 25 | return ldm, dataset 26 | 27 | def sample_gmms(ldm, num_shapes): 28 | ldm_gaus = ldm.sample(num_shapes).to("cpu") 29 | return ldm_gaus 30 | 31 | def main(args): 32 | assert Path(args.save_dir).exists() 33 | now = datetime.now().strftime("%m-%d-%H%M%S") 34 | ldm, dataset = load_ldm(args.ldm_path, strict=False) 35 | 36 | ldm_gaus = sample_gmms(ldm, args.num_shapes) 37 | ldm_gaus = jutils.thutil.th2np(ldm_gaus) 38 | 39 | with h5py.File(f"{args.save_dir}/{len(ldm_gaus)}-{now}.hdf5", "w") as f: 40 | f["data"] = ldm_gaus.astype(np.float32) 41 | 42 | if __name__ =="__main__": 43 | parser = argparse.ArgumentParser() 44 | parser.add_argument("--ldm_path", type=str) 45 | parser.add_argument("--num_shapes", type=int) 46 | parser.add_argument("--save_dir", type=str) 47 | 48 | args = parser.parse_args() 49 | main(args) 50 | -------------------------------------------------------------------------------- /notebooks/uncleaned_demo/pvd/modules/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KAIST-Visual-AI-Group/SALAD/db66b1b2ce57d8a4beb5bd610bc006ccaff43b1e/notebooks/uncleaned_demo/pvd/modules/__init__.py -------------------------------------------------------------------------------- /notebooks/uncleaned_demo/pvd/modules/pointnet2_ops_lib/MANIFEST.in: -------------------------------------------------------------------------------- 1 | graft pointnet2_ops/_ext-src 2 | -------------------------------------------------------------------------------- /notebooks/uncleaned_demo/pvd/modules/pointnet2_ops_lib/pointnet2_ops/__init__.py: -------------------------------------------------------------------------------- 1 | import pointnet2_ops.pointnet2_modules 2 | import pointnet2_ops.pointnet2_utils 3 | from pointnet2_ops._version import __version__ 4 | -------------------------------------------------------------------------------- 
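The `pointnet2_ops/__init__.py` above exposes Python wrappers around the compiled CUDA kernels defined in the `_ext-src` sources that follow. As a quick orientation, here is a minimal usage sketch, assuming the standard upstream `pointnet2_ops.pointnet2_utils` wrapper names (`furthest_point_sample`, `gather_operation`), which are not themselves shown in this dump:

    # Usage sketch (assumed API): run farthest-point sampling on a point cloud
    # and gather the sampled centroids via the compiled extension.
    import torch
    from pointnet2_ops import pointnet2_utils

    xyz = torch.rand(8, 1024, 3, device="cuda")           # (B, N, 3) point clouds
    idx = pointnet2_utils.furthest_point_sample(xyz, 64)  # (B, 64) sampled indices
    # gather_operation takes channel-first features: (B, 3, N) -> (B, 3, 64)
    centroids = pointnet2_utils.gather_operation(
        xyz.transpose(1, 2).contiguous(), idx
    )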
/notebooks/uncleaned_demo/pvd/modules/pointnet2_ops_lib/pointnet2_ops/_ext-src/include/ball_query.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | #include <torch/extension.h> 3 | 4 | at::Tensor ball_query(at::Tensor new_xyz, at::Tensor xyz, const float radius, 5 | const int nsample); 6 | --------------------------------------------------------------------------------
/notebooks/uncleaned_demo/pvd/modules/pointnet2_ops_lib/pointnet2_ops/_ext-src/include/cuda_utils.h: -------------------------------------------------------------------------------- 1 | #ifndef _CUDA_UTILS_H 2 | #define _CUDA_UTILS_H 3 | 4 | #include <ATen/ATen.h> 5 | #include <ATen/cuda/CUDAContext.h> 6 | #include <cmath> 7 | 8 | #include <cuda.h> 9 | #include <cuda_runtime.h> 10 | 11 | #include <vector> 12 | 13 | #define TOTAL_THREADS 512 14 | 15 | inline int opt_n_threads(int work_size) { 16 | const int pow_2 = std::log(static_cast<double>(work_size)) / std::log(2.0); 17 | 18 | return max(min(1 << pow_2, TOTAL_THREADS), 1); 19 | } 20 | 21 | inline dim3 opt_block_config(int x, int y) { 22 | const int x_threads = opt_n_threads(x); 23 | const int y_threads = 24 | max(min(opt_n_threads(y), TOTAL_THREADS / x_threads), 1); 25 | dim3 block_config(x_threads, y_threads, 1); 26 | 27 | return block_config; 28 | } 29 | 30 | #define CUDA_CHECK_ERRORS() \ 31 | do { \ 32 | cudaError_t err = cudaGetLastError(); \ 33 | if (cudaSuccess != err) { \ 34 | fprintf(stderr, "CUDA kernel failed : %s\n%s at L:%d in %s\n", \ 35 | cudaGetErrorString(err), __PRETTY_FUNCTION__, __LINE__, \ 36 | __FILE__); \ 37 | exit(-1); \ 38 | } \ 39 | } while (0) 40 | 41 | #endif 42 | --------------------------------------------------------------------------------
/notebooks/uncleaned_demo/pvd/modules/pointnet2_ops_lib/pointnet2_ops/_ext-src/include/group_points.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | #include <torch/extension.h> 3 | 4 | at::Tensor group_points(at::Tensor points, at::Tensor idx); 5 | at::Tensor group_points_grad(at::Tensor grad_out, at::Tensor idx, const int n); 6 | --------------------------------------------------------------------------------
/notebooks/uncleaned_demo/pvd/modules/pointnet2_ops_lib/pointnet2_ops/_ext-src/include/interpolate.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include <torch/extension.h> 4 | #include <vector> 5 | 6 | std::vector<at::Tensor> three_nn(at::Tensor unknowns, at::Tensor knows); 7 | at::Tensor three_interpolate(at::Tensor points, at::Tensor idx, 8 | at::Tensor weight); 9 | at::Tensor three_interpolate_grad(at::Tensor grad_out, at::Tensor idx, 10 | at::Tensor weight, const int m); 11 | --------------------------------------------------------------------------------
/notebooks/uncleaned_demo/pvd/modules/pointnet2_ops_lib/pointnet2_ops/_ext-src/include/sampling.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | #include <torch/extension.h> 3 | 4 | at::Tensor gather_points(at::Tensor points, at::Tensor idx); 5 | at::Tensor gather_points_grad(at::Tensor grad_out, at::Tensor idx, const int n); 6 | at::Tensor furthest_point_sampling(at::Tensor points, const int nsamples); 7 | --------------------------------------------------------------------------------
/notebooks/uncleaned_demo/pvd/modules/pointnet2_ops_lib/pointnet2_ops/_ext-src/include/utils.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | #include <ATen/cuda/CUDAContext.h> 3 | #include <torch/extension.h> 4 | 5 | #define CHECK_CUDA(x) \ 6 | do { \ 7 | AT_ASSERT(x.is_cuda(), #x " must be a CUDA tensor"); \ 8 | } while (0) 9 | 10 | #define CHECK_CONTIGUOUS(x) \ 11 | do { \ 12 | AT_ASSERT(x.is_contiguous(), #x " must be a contiguous tensor"); \ 13 | } while (0) 14 | 15 | #define CHECK_IS_INT(x) \ 16 | do { \ 17 | AT_ASSERT(x.scalar_type() == at::ScalarType::Int, \ 18 | #x " must be an int tensor"); \ 19 | } while (0) 20 | 21 | #define CHECK_IS_FLOAT(x) \ 22 | do { \ 23 | AT_ASSERT(x.scalar_type() == at::ScalarType::Float, \ 24 | #x " must be a float tensor"); \ 25 | } while (0) 26 | --------------------------------------------------------------------------------
/notebooks/uncleaned_demo/pvd/modules/pointnet2_ops_lib/pointnet2_ops/_ext-src/src/ball_query.cpp: -------------------------------------------------------------------------------- 1 | #include "ball_query.h" 2 | #include "utils.h" 3 | 4 | void query_ball_point_kernel_wrapper(int b, int n, int m, float radius, 5 | int nsample, const float *new_xyz, 6 | const float *xyz, int *idx); 7 | 8 | at::Tensor ball_query(at::Tensor new_xyz, at::Tensor xyz, const float radius, 9 | const int nsample) { 10 | CHECK_CONTIGUOUS(new_xyz); 11 | CHECK_CONTIGUOUS(xyz); 12 | CHECK_IS_FLOAT(new_xyz); 13 | CHECK_IS_FLOAT(xyz); 14 | 15 | if (new_xyz.is_cuda()) { 16 | CHECK_CUDA(xyz); 17 | } 18 | 19 | at::Tensor idx = 20 | torch::zeros({new_xyz.size(0), new_xyz.size(1), nsample}, 21 | at::device(new_xyz.device()).dtype(at::ScalarType::Int)); 22 | 23 | if (new_xyz.is_cuda()) { 24 | query_ball_point_kernel_wrapper(xyz.size(0), xyz.size(1), new_xyz.size(1), 25 | radius, nsample, new_xyz.data_ptr<float>(), 26 | xyz.data_ptr<float>(), idx.data_ptr<int>()); 27 | } else { 28 | AT_ASSERT(false, "CPU not supported"); 29 | } 30 | 31 | return idx; 32 | } 33 | --------------------------------------------------------------------------------
/notebooks/uncleaned_demo/pvd/modules/pointnet2_ops_lib/pointnet2_ops/_ext-src/src/ball_query_gpu.cu: -------------------------------------------------------------------------------- 1 | #include <math.h> 2 | #include <stdio.h> 3 | #include <stdlib.h> 4 | 5 | #include "cuda_utils.h" 6 | 7 | // input: new_xyz(b, m, 3) xyz(b, n, 3) 8 | // output: idx(b, m, nsample) 9 | __global__ void query_ball_point_kernel(int b, int n, int m, float radius, 10 | int nsample, 11 | const float *__restrict__ new_xyz, 12 | const float *__restrict__ xyz, 13 | int *__restrict__ idx) { 14 | int batch_index = blockIdx.x; 15 | xyz += batch_index * n * 3; 16 | new_xyz += batch_index * m * 3; 17 | idx += m * nsample * batch_index; 18 | 19 | int index = threadIdx.x; 20 | int stride = blockDim.x; 21 | 22 | float radius2 = radius * radius; 23 | for (int j = index; j < m; j += stride) { 24 | float new_x = new_xyz[j * 3 + 0]; 25 | float new_y = new_xyz[j * 3 + 1]; 26 | float new_z = new_xyz[j * 3 + 2]; 27 | for (int k = 0, cnt = 0; k < n && cnt < nsample; ++k) { 28 | float x = xyz[k * 3 + 0]; 29 | float y = xyz[k * 3 + 1]; 30 | float z = xyz[k * 3 + 2]; 31 | float d2 = (new_x - x) * (new_x - x) + (new_y - y) * (new_y - y) + 32 | (new_z - z) * (new_z - z); 33 | if (d2 < radius2) { 34 | if (cnt == 0) { 35 | for (int l = 0; l < nsample; ++l) { 36 | idx[j * nsample + l] = k; 37 | } 38 | } 39 | idx[j * nsample + cnt] = k; 40 | ++cnt; 41 | } 42 | } 43 | } 44 | } 45 | 46 | void query_ball_point_kernel_wrapper(int b, int n, int m, float radius, 47 | int nsample, const float *new_xyz, 48 | const float *xyz, int *idx) { 49 | cudaStream_t stream = at::cuda::getCurrentCUDAStream(); 50 | query_ball_point_kernel<<<b, opt_n_threads(m), 0, stream>>>( 51 | b, n, m, radius, nsample, new_xyz, xyz, idx); 52 | 53 | CUDA_CHECK_ERRORS(); 54 | } 55 |
-------------------------------------------------------------------------------- /notebooks/uncleaned_demo/pvd/modules/pointnet2_ops_lib/pointnet2_ops/_ext-src/src/bindings.cpp: -------------------------------------------------------------------------------- 1 | #include "ball_query.h" 2 | #include "group_points.h" 3 | #include "interpolate.h" 4 | #include "sampling.h" 5 | 6 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 7 | m.def("gather_points", &gather_points); 8 | m.def("gather_points_grad", &gather_points_grad); 9 | m.def("furthest_point_sampling", &furthest_point_sampling); 10 | 11 | m.def("three_nn", &three_nn); 12 | m.def("three_interpolate", &three_interpolate); 13 | m.def("three_interpolate_grad", &three_interpolate_grad); 14 | 15 | m.def("ball_query", &ball_query); 16 | 17 | m.def("group_points", &group_points); 18 | m.def("group_points_grad", &group_points_grad); 19 | } 20 | --------------------------------------------------------------------------------
/notebooks/uncleaned_demo/pvd/modules/pointnet2_ops_lib/pointnet2_ops/_ext-src/src/group_points.cpp: -------------------------------------------------------------------------------- 1 | #include "group_points.h" 2 | #include "utils.h" 3 | 4 | void group_points_kernel_wrapper(int b, int c, int n, int npoints, int nsample, 5 | const float *points, const int *idx, 6 | float *out); 7 | 8 | void group_points_grad_kernel_wrapper(int b, int c, int n, int npoints, 9 | int nsample, const float *grad_out, 10 | const int *idx, float *grad_points); 11 | 12 | at::Tensor group_points(at::Tensor points, at::Tensor idx) { 13 | CHECK_CONTIGUOUS(points); 14 | CHECK_CONTIGUOUS(idx); 15 | CHECK_IS_FLOAT(points); 16 | CHECK_IS_INT(idx); 17 | 18 | if (points.is_cuda()) { 19 | CHECK_CUDA(idx); 20 | } 21 | 22 | at::Tensor output = 23 | torch::zeros({points.size(0), points.size(1), idx.size(1), idx.size(2)}, 24 | at::device(points.device()).dtype(at::ScalarType::Float)); 25 | 26 | if (points.is_cuda()) { 27 | group_points_kernel_wrapper(points.size(0), points.size(1), points.size(2), 28 | idx.size(1), idx.size(2), 29 | points.data_ptr<float>(), idx.data_ptr<int>(), 30 | output.data_ptr<float>()); 31 | } else { 32 | AT_ASSERT(false, "CPU not supported"); 33 | } 34 | 35 | return output; 36 | } 37 | 38 | at::Tensor group_points_grad(at::Tensor grad_out, at::Tensor idx, const int n) { 39 | CHECK_CONTIGUOUS(grad_out); 40 | CHECK_CONTIGUOUS(idx); 41 | CHECK_IS_FLOAT(grad_out); 42 | CHECK_IS_INT(idx); 43 | 44 | if (grad_out.is_cuda()) { 45 | CHECK_CUDA(idx); 46 | } 47 | 48 | at::Tensor output = 49 | torch::zeros({grad_out.size(0), grad_out.size(1), n}, 50 | at::device(grad_out.device()).dtype(at::ScalarType::Float)); 51 | 52 | if (grad_out.is_cuda()) { 53 | group_points_grad_kernel_wrapper( 54 | grad_out.size(0), grad_out.size(1), n, idx.size(1), idx.size(2), 55 | grad_out.data_ptr<float>(), idx.data_ptr<int>(), 56 | output.data_ptr<float>()); 57 | } else { 58 | AT_ASSERT(false, "CPU not supported"); 59 | } 60 | 61 | return output; 62 | } 63 | --------------------------------------------------------------------------------
/notebooks/uncleaned_demo/pvd/modules/pointnet2_ops_lib/pointnet2_ops/_ext-src/src/group_points_gpu.cu: -------------------------------------------------------------------------------- 1 | #include <stdio.h> 2 | #include <stdlib.h> 3 | 4 | #include "cuda_utils.h" 5 | 6 | // input: points(b, c, n) idx(b, npoints, nsample) 7 | // output: out(b, c, npoints, nsample) 8 | __global__ void group_points_kernel(int b, int c, int n, int npoints, 9 | int nsample, 10 | const float *__restrict__ points, 11 | const int *__restrict__ idx, 12 | float *__restrict__ out) { 13 | int batch_index = blockIdx.x; 14 | points += batch_index * n * c; 15 | idx += batch_index * npoints * nsample; 16 | out += batch_index * npoints * nsample * c; 17 | 18 | const int index = threadIdx.y * blockDim.x + threadIdx.x; 19 | const int stride = blockDim.y * blockDim.x; 20 | for (int i = index; i < c * npoints; i += stride) { 21 | const int l = i / npoints; 22 | const int j = i % npoints; 23 | for (int k = 0; k < nsample; ++k) { 24 | int ii = idx[j * nsample + k]; 25 | out[(l * npoints + j) * nsample + k] = points[l * n + ii]; 26 | } 27 | } 28 | } 29 | 30 | void group_points_kernel_wrapper(int b, int c, int n, int npoints, int nsample, 31 | const float *points, const int *idx, 32 | float *out) { 33 | cudaStream_t stream = at::cuda::getCurrentCUDAStream(); 34 | 35 | group_points_kernel<<<b, opt_block_config(npoints, c), 0, stream>>>( 36 | b, c, n, npoints, nsample, points, idx, out); 37 | 38 | CUDA_CHECK_ERRORS(); 39 | } 40 | 41 | // input: grad_out(b, c, npoints, nsample), idx(b, npoints, nsample) 42 | // output: grad_points(b, c, n) 43 | __global__ void group_points_grad_kernel(int b, int c, int n, int npoints, 44 | int nsample, 45 | const float *__restrict__ grad_out, 46 | const int *__restrict__ idx, 47 | float *__restrict__ grad_points) { 48 | int batch_index = blockIdx.x; 49 | grad_out += batch_index * npoints * nsample * c; 50 | idx += batch_index * npoints * nsample; 51 | grad_points += batch_index * n * c; 52 | 53 | const int index = threadIdx.y * blockDim.x + threadIdx.x; 54 | const int stride = blockDim.y * blockDim.x; 55 | for (int i = index; i < c * npoints; i += stride) { 56 | const int l = i / npoints; 57 | const int j = i % npoints; 58 | for (int k = 0; k < nsample; ++k) { 59 | int ii = idx[j * nsample + k]; 60 | atomicAdd(grad_points + l * n + ii, 61 | grad_out[(l * npoints + j) * nsample + k]); 62 | } 63 | } 64 | } 65 | 66 | void group_points_grad_kernel_wrapper(int b, int c, int n, int npoints, 67 | int nsample, const float *grad_out, 68 | const int *idx, float *grad_points) { 69 | cudaStream_t stream = at::cuda::getCurrentCUDAStream(); 70 | 71 | group_points_grad_kernel<<<b, opt_block_config(npoints, c), 0, stream>>>( 72 | b, c, n, npoints, nsample, grad_out, idx, grad_points); 73 | 74 | CUDA_CHECK_ERRORS(); 75 | } 76 | --------------------------------------------------------------------------------
/notebooks/uncleaned_demo/pvd/modules/pointnet2_ops_lib/pointnet2_ops/_ext-src/src/interpolate.cpp: -------------------------------------------------------------------------------- 1 | #include "interpolate.h" 2 | #include "utils.h" 3 | 4 | void three_nn_kernel_wrapper(int b, int n, int m, const float *unknown, 5 | const float *known, float *dist2, int *idx); 6 | void three_interpolate_kernel_wrapper(int b, int c, int m, int n, 7 | const float *points, const int *idx, 8 | const float *weight, float *out); 9 | void three_interpolate_grad_kernel_wrapper(int b, int c, int n, int m, 10 | const float *grad_out, 11 | const int *idx, const float *weight, 12 | float *grad_points); 13 | 14 | std::vector<at::Tensor> three_nn(at::Tensor unknowns, at::Tensor knows) { 15 | CHECK_CONTIGUOUS(unknowns); 16 | CHECK_CONTIGUOUS(knows); 17 | CHECK_IS_FLOAT(unknowns); 18 | CHECK_IS_FLOAT(knows); 19 | 20 | if (unknowns.is_cuda()) { 21 | CHECK_CUDA(knows); 22 | } 23 | 24 | at::Tensor idx = 25 | torch::zeros({unknowns.size(0), unknowns.size(1), 3}, 26 | at::device(unknowns.device()).dtype(at::ScalarType::Int)); 27 | at::Tensor dist2 = 28 | torch::zeros({unknowns.size(0), unknowns.size(1), 3}, 29 |
at::device(unknowns.device()).dtype(at::ScalarType::Float)); 30 | 31 | if (unknowns.is_cuda()) { 32 | three_nn_kernel_wrapper(unknowns.size(0), unknowns.size(1), knows.size(1), 33 | unknowns.data_ptr(), knows.data_ptr(), 34 | dist2.data_ptr(), idx.data_ptr()); 35 | } else { 36 | AT_ASSERT(false, "CPU not supported"); 37 | } 38 | 39 | return {dist2, idx}; 40 | } 41 | 42 | at::Tensor three_interpolate(at::Tensor points, at::Tensor idx, 43 | at::Tensor weight) { 44 | CHECK_CONTIGUOUS(points); 45 | CHECK_CONTIGUOUS(idx); 46 | CHECK_CONTIGUOUS(weight); 47 | CHECK_IS_FLOAT(points); 48 | CHECK_IS_INT(idx); 49 | CHECK_IS_FLOAT(weight); 50 | 51 | if (points.is_cuda()) { 52 | CHECK_CUDA(idx); 53 | CHECK_CUDA(weight); 54 | } 55 | 56 | at::Tensor output = 57 | torch::zeros({points.size(0), points.size(1), idx.size(1)}, 58 | at::device(points.device()).dtype(at::ScalarType::Float)); 59 | 60 | if (points.is_cuda()) { 61 | three_interpolate_kernel_wrapper( 62 | points.size(0), points.size(1), points.size(2), idx.size(1), 63 | points.data_ptr(), idx.data_ptr(), weight.data_ptr(), 64 | output.data_ptr()); 65 | } else { 66 | AT_ASSERT(false, "CPU not supported"); 67 | } 68 | 69 | return output; 70 | } 71 | at::Tensor three_interpolate_grad(at::Tensor grad_out, at::Tensor idx, 72 | at::Tensor weight, const int m) { 73 | CHECK_CONTIGUOUS(grad_out); 74 | CHECK_CONTIGUOUS(idx); 75 | CHECK_CONTIGUOUS(weight); 76 | CHECK_IS_FLOAT(grad_out); 77 | CHECK_IS_INT(idx); 78 | CHECK_IS_FLOAT(weight); 79 | 80 | if (grad_out.is_cuda()) { 81 | CHECK_CUDA(idx); 82 | CHECK_CUDA(weight); 83 | } 84 | 85 | at::Tensor output = 86 | torch::zeros({grad_out.size(0), grad_out.size(1), m}, 87 | at::device(grad_out.device()).dtype(at::ScalarType::Float)); 88 | 89 | if (grad_out.is_cuda()) { 90 | three_interpolate_grad_kernel_wrapper( 91 | grad_out.size(0), grad_out.size(1), grad_out.size(2), m, 92 | grad_out.data_ptr(), idx.data_ptr(), 93 | weight.data_ptr(), output.data_ptr()); 94 | } else { 95 | AT_ASSERT(false, "CPU not supported"); 96 | } 97 | 98 | return output; 99 | } 100 | -------------------------------------------------------------------------------- /notebooks/uncleaned_demo/pvd/modules/pointnet2_ops_lib/pointnet2_ops/_ext-src/src/interpolate_gpu.cu: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include 4 | 5 | #include "cuda_utils.h" 6 | 7 | // input: unknown(b, n, 3) known(b, m, 3) 8 | // output: dist2(b, n, 3), idx(b, n, 3) 9 | __global__ void three_nn_kernel(int b, int n, int m, 10 | const float *__restrict__ unknown, 11 | const float *__restrict__ known, 12 | float *__restrict__ dist2, 13 | int *__restrict__ idx) { 14 | int batch_index = blockIdx.x; 15 | unknown += batch_index * n * 3; 16 | known += batch_index * m * 3; 17 | dist2 += batch_index * n * 3; 18 | idx += batch_index * n * 3; 19 | 20 | int index = threadIdx.x; 21 | int stride = blockDim.x; 22 | for (int j = index; j < n; j += stride) { 23 | float ux = unknown[j * 3 + 0]; 24 | float uy = unknown[j * 3 + 1]; 25 | float uz = unknown[j * 3 + 2]; 26 | 27 | double best1 = 1e40, best2 = 1e40, best3 = 1e40; 28 | int besti1 = 0, besti2 = 0, besti3 = 0; 29 | for (int k = 0; k < m; ++k) { 30 | float x = known[k * 3 + 0]; 31 | float y = known[k * 3 + 1]; 32 | float z = known[k * 3 + 2]; 33 | float d = (ux - x) * (ux - x) + (uy - y) * (uy - y) + (uz - z) * (uz - z); 34 | if (d < best1) { 35 | best3 = best2; 36 | besti3 = besti2; 37 | best2 = best1; 38 | besti2 = besti1; 39 | best1 = d; 40 | besti1 = k; 41 | 
} else if (d < best2) { 42 | best3 = best2; 43 | besti3 = besti2; 44 | best2 = d; 45 | besti2 = k; 46 | } else if (d < best3) { 47 | best3 = d; 48 | besti3 = k; 49 | } 50 | } 51 | dist2[j * 3 + 0] = best1; 52 | dist2[j * 3 + 1] = best2; 53 | dist2[j * 3 + 2] = best3; 54 | 55 | idx[j * 3 + 0] = besti1; 56 | idx[j * 3 + 1] = besti2; 57 | idx[j * 3 + 2] = besti3; 58 | } 59 | } 60 | 61 | void three_nn_kernel_wrapper(int b, int n, int m, const float *unknown, 62 | const float *known, float *dist2, int *idx) { 63 | cudaStream_t stream = at::cuda::getCurrentCUDAStream(); 64 | three_nn_kernel<<>>(b, n, m, unknown, known, 65 | dist2, idx); 66 | 67 | CUDA_CHECK_ERRORS(); 68 | } 69 | 70 | // input: points(b, c, m), idx(b, n, 3), weight(b, n, 3) 71 | // output: out(b, c, n) 72 | __global__ void three_interpolate_kernel(int b, int c, int m, int n, 73 | const float *__restrict__ points, 74 | const int *__restrict__ idx, 75 | const float *__restrict__ weight, 76 | float *__restrict__ out) { 77 | int batch_index = blockIdx.x; 78 | points += batch_index * m * c; 79 | 80 | idx += batch_index * n * 3; 81 | weight += batch_index * n * 3; 82 | 83 | out += batch_index * n * c; 84 | 85 | const int index = threadIdx.y * blockDim.x + threadIdx.x; 86 | const int stride = blockDim.y * blockDim.x; 87 | for (int i = index; i < c * n; i += stride) { 88 | const int l = i / n; 89 | const int j = i % n; 90 | float w1 = weight[j * 3 + 0]; 91 | float w2 = weight[j * 3 + 1]; 92 | float w3 = weight[j * 3 + 2]; 93 | 94 | int i1 = idx[j * 3 + 0]; 95 | int i2 = idx[j * 3 + 1]; 96 | int i3 = idx[j * 3 + 2]; 97 | 98 | out[i] = points[l * m + i1] * w1 + points[l * m + i2] * w2 + 99 | points[l * m + i3] * w3; 100 | } 101 | } 102 | 103 | void three_interpolate_kernel_wrapper(int b, int c, int m, int n, 104 | const float *points, const int *idx, 105 | const float *weight, float *out) { 106 | cudaStream_t stream = at::cuda::getCurrentCUDAStream(); 107 | three_interpolate_kernel<<>>( 108 | b, c, m, n, points, idx, weight, out); 109 | 110 | CUDA_CHECK_ERRORS(); 111 | } 112 | 113 | // input: grad_out(b, c, n), idx(b, n, 3), weight(b, n, 3) 114 | // output: grad_points(b, c, m) 115 | 116 | __global__ void three_interpolate_grad_kernel( 117 | int b, int c, int n, int m, const float *__restrict__ grad_out, 118 | const int *__restrict__ idx, const float *__restrict__ weight, 119 | float *__restrict__ grad_points) { 120 | int batch_index = blockIdx.x; 121 | grad_out += batch_index * n * c; 122 | idx += batch_index * n * 3; 123 | weight += batch_index * n * 3; 124 | grad_points += batch_index * m * c; 125 | 126 | const int index = threadIdx.y * blockDim.x + threadIdx.x; 127 | const int stride = blockDim.y * blockDim.x; 128 | for (int i = index; i < c * n; i += stride) { 129 | const int l = i / n; 130 | const int j = i % n; 131 | float w1 = weight[j * 3 + 0]; 132 | float w2 = weight[j * 3 + 1]; 133 | float w3 = weight[j * 3 + 2]; 134 | 135 | int i1 = idx[j * 3 + 0]; 136 | int i2 = idx[j * 3 + 1]; 137 | int i3 = idx[j * 3 + 2]; 138 | 139 | atomicAdd(grad_points + l * m + i1, grad_out[i] * w1); 140 | atomicAdd(grad_points + l * m + i2, grad_out[i] * w2); 141 | atomicAdd(grad_points + l * m + i3, grad_out[i] * w3); 142 | } 143 | } 144 | 145 | void three_interpolate_grad_kernel_wrapper(int b, int c, int n, int m, 146 | const float *grad_out, 147 | const int *idx, const float *weight, 148 | float *grad_points) { 149 | cudaStream_t stream = at::cuda::getCurrentCUDAStream(); 150 | three_interpolate_grad_kernel<<>>( 151 | b, c, n, m, grad_out, idx, 
weight, grad_points); 152 | 153 | CUDA_CHECK_ERRORS(); 154 | } 155 | -------------------------------------------------------------------------------- /notebooks/uncleaned_demo/pvd/modules/pointnet2_ops_lib/pointnet2_ops/_ext-src/src/sampling.cpp: -------------------------------------------------------------------------------- 1 | #include "sampling.h" 2 | #include "utils.h" 3 | 4 | void gather_points_kernel_wrapper(int b, int c, int n, int npoints, 5 | const float *points, const int *idx, 6 | float *out); 7 | void gather_points_grad_kernel_wrapper(int b, int c, int n, int npoints, 8 | const float *grad_out, const int *idx, 9 | float *grad_points); 10 | 11 | void furthest_point_sampling_kernel_wrapper(int b, int n, int m, 12 | const float *dataset, float *temp, 13 | int *idxs); 14 | 15 | at::Tensor gather_points(at::Tensor points, at::Tensor idx) { 16 | CHECK_CONTIGUOUS(points); 17 | CHECK_CONTIGUOUS(idx); 18 | CHECK_IS_FLOAT(points); 19 | CHECK_IS_INT(idx); 20 | 21 | if (points.is_cuda()) { 22 | CHECK_CUDA(idx); 23 | } 24 | 25 | at::Tensor output = 26 | torch::zeros({points.size(0), points.size(1), idx.size(1)}, 27 | at::device(points.device()).dtype(at::ScalarType::Float)); 28 | 29 | if (points.is_cuda()) { 30 | gather_points_kernel_wrapper(points.size(0), points.size(1), points.size(2), 31 | idx.size(1), points.data_ptr<float>(), 32 | idx.data_ptr<int>(), output.data_ptr<float>()); 33 | } else { 34 | AT_ASSERT(false, "CPU not supported"); 35 | } 36 | 37 | return output; 38 | } 39 | 40 | at::Tensor gather_points_grad(at::Tensor grad_out, at::Tensor idx, 41 | const int n) { 42 | CHECK_CONTIGUOUS(grad_out); 43 | CHECK_CONTIGUOUS(idx); 44 | CHECK_IS_FLOAT(grad_out); 45 | CHECK_IS_INT(idx); 46 | 47 | if (grad_out.is_cuda()) { 48 | CHECK_CUDA(idx); 49 | } 50 | 51 | at::Tensor output = 52 | torch::zeros({grad_out.size(0), grad_out.size(1), n}, 53 | at::device(grad_out.device()).dtype(at::ScalarType::Float)); 54 | 55 | if (grad_out.is_cuda()) { 56 | gather_points_grad_kernel_wrapper(grad_out.size(0), grad_out.size(1), n, 57 | idx.size(1), grad_out.data_ptr<float>(), 58 | idx.data_ptr<int>(), 59 | output.data_ptr<float>()); 60 | } else { 61 | AT_ASSERT(false, "CPU not supported"); 62 | } 63 | 64 | return output; 65 | } 66 | at::Tensor furthest_point_sampling(at::Tensor points, const int nsamples) { 67 | CHECK_CONTIGUOUS(points); 68 | CHECK_IS_FLOAT(points); 69 | 70 | at::Tensor output = 71 | torch::zeros({points.size(0), nsamples}, 72 | at::device(points.device()).dtype(at::ScalarType::Int)); 73 | 74 | at::Tensor tmp = 75 | torch::full({points.size(0), points.size(1)}, 1e10, 76 | at::device(points.device()).dtype(at::ScalarType::Float)); 77 | 78 | if (points.is_cuda()) { 79 | furthest_point_sampling_kernel_wrapper( 80 | points.size(0), points.size(1), nsamples, points.data_ptr<float>(), 81 | tmp.data_ptr<float>(), output.data_ptr<int>()); 82 | } else { 83 | AT_ASSERT(false, "CPU not supported"); 84 | } 85 | 86 | return output; 87 | } 88 | -------------------------------------------------------------------------------- /notebooks/uncleaned_demo/pvd/modules/pointnet2_ops_lib/pointnet2_ops/_version.py: -------------------------------------------------------------------------------- 1 | __version__ = "3.0.0" 2 | -------------------------------------------------------------------------------- /notebooks/uncleaned_demo/pvd/modules/pointnet2_ops_lib/setup.py: -------------------------------------------------------------------------------- 1 | import glob 2 | import os 3 | import os.path as osp 4 | 5 | from setuptools import find_packages, setup 6 | from
torch.utils.cpp_extension import BuildExtension, CUDAExtension 7 | 8 | this_dir = osp.dirname(osp.abspath(__file__)) 9 | _ext_src_root = osp.join("pointnet2_ops", "_ext-src") 10 | _ext_sources = glob.glob(osp.join(_ext_src_root, "src", "*.cpp")) + glob.glob( 11 | osp.join(_ext_src_root, "src", "*.cu") 12 | ) 13 | _ext_headers = glob.glob(osp.join(_ext_src_root, "include", "*")) 14 | 15 | requirements = ["torch>=1.4"] 16 | 17 | exec(open(osp.join("pointnet2_ops", "_version.py")).read()) 18 | 19 | os.environ["TORCH_CUDA_ARCH_LIST"] = "3.7+PTX;5.0;6.0;6.1;6.2;7.0;7.5" 20 | setup( 21 | name="pointnet2_ops", 22 | version=__version__, 23 | author="Erik Wijmans", 24 | packages=find_packages(), 25 | install_requires=requirements, 26 | ext_modules=[ 27 | CUDAExtension( 28 | name="pointnet2_ops._ext", 29 | sources=_ext_sources, 30 | extra_compile_args={ 31 | "cxx": ["-O3"], 32 | "nvcc": ["-O3", "-Xfatbin", "-compress-all"], 33 | }, 34 | include_dirs=[osp.join(this_dir, _ext_src_root, "include")], 35 | ) 36 | ], 37 | cmdclass={"build_ext": BuildExtension}, 38 | include_package_data=True, 39 | ) 40 | -------------------------------------------------------------------------------- /notebooks/uncleaned_demo/pvd/modules/simple_module.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from pvd.modules.transformer import TimeMLP 5 | import math 6 | 7 | 8 | class ConcatDecoder(nn.Module): 9 | def __init__(self, dim_self, dim_ref, dim_ctx, num_layers=3): 10 | super().__init__() 11 | dim_total = dim_self + dim_ref 12 | 13 | self.layers = nn.ModuleList( 14 | [ 15 | TimeMLP(dim_total, dim_total, dim_total, dim_ctx, use_time=True) 16 | for _ in range(num_layers) 17 | ] 18 | ) 19 | self.fc = nn.Linear(dim_total, dim_self) 20 | 21 | def forward(self, x, y, ctx=None): 22 | out = torch.cat([x, y], -1) 23 | for i, layer in enumerate(self.layers): 24 | out = layer(out, ctx=ctx) 25 | out = self.fc(out) 26 | return out 27 | 28 | 29 | class TimestepEmbedder(nn.Module): 30 | """ 31 | Embeds scalar timesteps into vector representations. 32 | """ 33 | 34 | def __init__(self, hidden_size, frequency_embedding_size=256): 35 | super().__init__() 36 | self.mlp = nn.Sequential( 37 | nn.Linear(frequency_embedding_size, hidden_size, bias=True), 38 | nn.SiLU(), 39 | nn.Linear(hidden_size, hidden_size, bias=True), 40 | ) 41 | self.frequency_embedding_size = frequency_embedding_size 42 | 43 | @staticmethod 44 | def timestep_embedding(t, dim, max_period=10000): 45 | """ 46 | Create sinusoidal timestep embeddings. 47 | :param t: a 1-D Tensor of N indices, one per batch element. 48 | These may be fractional. 49 | :param dim: the dimension of the output. 50 | :param max_period: controls the minimum frequency of the embeddings. 51 | :return: an (N, D) Tensor of positional embeddings. 
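Concretely, frequency j equals max_period^(-j/half), and the output concatenates all cosine terms first, then the matching sine terms.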
52 | """ 53 | # https://github.com/openai/glide-text2im/blob/main/glide_text2im/nn.py 54 | half = dim // 2 55 | freqs = torch.exp( 56 | -math.log(max_period) 57 | * torch.arange(start=0, end=half, dtype=torch.float32) 58 | / half 59 | ).to(device=t.device) 60 | args = t[:, None].float() * freqs[None] 61 | embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) 62 | if dim % 2: 63 | embedding = torch.cat( 64 | [embedding, torch.zeros_like(embedding[:, :1])], dim=-1 65 | ) 66 | return embedding 67 | 68 | def forward(self, t): 69 | t_freq = self.timestep_embedding(t, self.frequency_embedding_size) 70 | t_emb = self.mlp(t_freq) 71 | return t_emb 72 | 73 | 74 | class TimePointwiseLayer(nn.Module): 75 | def __init__( 76 | self, 77 | dim_in, 78 | dim_ctx, 79 | mlp_ratio=2, 80 | act=F.leaky_relu, 81 | dropout=0.0, 82 | use_time=False, 83 | ): 84 | super().__init__() 85 | self.use_time = use_time 86 | self.act = act 87 | self.mlp1 = TimeMLP( 88 | dim_in, dim_in * mlp_ratio, dim_in, dim_ctx, use_time=use_time 89 | ) 90 | self.norm1 = nn.LayerNorm(dim_in) 91 | 92 | self.mlp2 = TimeMLP( 93 | dim_in, dim_in * mlp_ratio, dim_in, dim_ctx, use_time=use_time 94 | ) 95 | self.norm2 = nn.LayerNorm(dim_in) 96 | self.dropout = nn.Dropout(dropout) 97 | 98 | def forward(self, x, ctx=None): 99 | res = x 100 | x = self.mlp1(x, ctx=ctx) 101 | x = self.norm1(x + res) 102 | 103 | res = x 104 | x = self.mlp2(x, ctx=ctx) 105 | x = self.norm2(x + res) 106 | return x 107 | 108 | 109 | class TimePointWiseEncoder(nn.Module): 110 | def __init__( 111 | self, 112 | dim_in, 113 | dim_ctx=None, 114 | mlp_ratio=2, 115 | act=F.leaky_relu, 116 | dropout=0.0, 117 | use_time=True, 118 | num_layers=6, 119 | last_fc=False, 120 | last_fc_dim_out=None, 121 | ): 122 | super().__init__() 123 | self.last_fc = last_fc 124 | if last_fc: 125 | self.fc = nn.Linear(dim_in, last_fc_dim_out) 126 | self.layers = nn.ModuleList( 127 | [ 128 | TimePointwiseLayer( 129 | dim_in, 130 | dim_ctx=dim_ctx, 131 | mlp_ratio=mlp_ratio, 132 | act=act, 133 | dropout=dropout, 134 | use_time=use_time, 135 | ) 136 | for _ in range(num_layers) 137 | ] 138 | ) 139 | 140 | def forward(self, x, ctx=None): 141 | for i, layer in enumerate(self.layers): 142 | x = layer(x, ctx=ctx) 143 | if self.last_fc: 144 | x = self.fc(x) 145 | return x 146 | -------------------------------------------------------------------------------- /notebooks/uncleaned_demo/pvd/spaghetti_github/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KAIST-Visual-AI-Group/SALAD/db66b1b2ce57d8a4beb5bd610bc006ccaff43b1e/notebooks/uncleaned_demo/pvd/spaghetti_github/__init__.py -------------------------------------------------------------------------------- /notebooks/uncleaned_demo/pvd/spaghetti_github/constants.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | 4 | IS_WINDOWS = sys.platform == 'win32' 5 | get_trace = getattr(sys, 'gettrace', None) 6 | DEBUG = get_trace is not None and get_trace() is not None 7 | EPSILON = 1e-4 8 | DIM = 3 9 | 10 | #### < Seungwoo > 11 | # Added 'docker_home' directory in the middle. 
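# NOTE: PROJECT_ROOT below is a machine-specific absolute path; point it at your own spaghetti_github checkout before running.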
12 | PROJECT_ROOT = "/home/juil/docker_home/projects/3D_CRISPR/spaghetti_github" 13 | # PROJECT_ROOT = "/home/juil/projects/3D_CRISPR/spaghetti_github" 14 | #### 15 | 16 | DATA_ROOT = f'{PROJECT_ROOT}/assets/' 17 | CHECKPOINTS_ROOT = f'{DATA_ROOT}checkpoints/' 18 | CACHE_ROOT = f'{DATA_ROOT}cache/' 19 | UI_OUT = f'{DATA_ROOT}ui_export/' 20 | UI_RESOURCES = f'{DATA_ROOT}/ui_resources/' 21 | Shapenet_WT = f'{DATA_ROOT}/ShapeNetCore_wt/' 22 | Shapenet = f'{DATA_ROOT}/ShapeNetCore.v2/' 23 | MAX_VS = 100000 24 | MAX_GAUSIANS = 32 25 | 26 | -------------------------------------------------------------------------------- /notebooks/uncleaned_demo/pvd/spaghetti_github/custom_types.py: -------------------------------------------------------------------------------- 1 | # import open3d 2 | import enum 3 | import numpy as np 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as nnf 7 | from .constants import DEBUG 8 | from typing import Tuple, List, Union, Callable, Type, Iterator, Dict, Set, Optional, Any, Sized, Iterable 9 | from types import DynamicClassAttribute 10 | from enum import Enum, unique 11 | import torch.optim.optimizer 12 | import torch.utils.data 13 | 14 | if DEBUG: 15 | seed = 99 16 | torch.manual_seed(seed) 17 | np.random.seed(seed) 18 | 19 | N = type(None) 20 | V = np.array 21 | ARRAY = np.ndarray 22 | ARRAYS = Union[Tuple[ARRAY, ...], List[ARRAY]] 23 | VS = Union[Tuple[V, ...], List[V]] 24 | VN = Union[V, N] 25 | VNS = Union[VS, N] 26 | T = torch.Tensor 27 | TS = Union[Tuple[T, ...], List[T]] 28 | TN = Optional[T] 29 | TNS = Union[Tuple[TN, ...], List[TN]] 30 | TSN = Optional[TS] 31 | TA = Union[T, ARRAY] 32 | 33 | V_Mesh = Tuple[ARRAY, ARRAY] 34 | T_Mesh = Tuple[T, Optional[T]] 35 | T_Mesh_T = Union[T_Mesh, T] 36 | COLORS = Union[T, ARRAY, Tuple[int, int, int]] 37 | 38 | D = torch.device 39 | CPU = torch.device('cpu') 40 | 41 | 42 | def get_device(device_id: int) -> D: 43 | if not torch.cuda.is_available(): 44 | return CPU 45 | device_id = min(torch.cuda.device_count() - 1, device_id) 46 | return torch.device(f'cuda:{device_id}') 47 | 48 | 49 | CUDA = get_device 50 | Optimizer = torch.optim.Adam 51 | Dataset = torch.utils.data.Dataset 52 | DataLoader = torch.utils.data.DataLoader 53 | Subset = torch.utils.data.Subset 54 | -------------------------------------------------------------------------------- /notebooks/uncleaned_demo/pvd/spaghetti_github/models/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KAIST-Visual-AI-Group/SALAD/db66b1b2ce57d8a4beb5bd610bc006ccaff43b1e/notebooks/uncleaned_demo/pvd/spaghetti_github/models/__init__.py -------------------------------------------------------------------------------- /notebooks/uncleaned_demo/pvd/spaghetti_github/models/gm_utils.py: -------------------------------------------------------------------------------- 1 | from ..custom_types import * 2 | 3 | 4 | def get_gm_support(gm, x): 5 | dim = x.shape[-1] 6 | mu, p, phi, eigen = gm 7 | sigma_det = eigen.prod(-1) 8 | eigen_inv = 1 / eigen 9 | sigma_inverse = torch.matmul(p.transpose(3, 4), p * eigen_inv[:, :, :, :, None]).squeeze(1) 10 | phi = torch.softmax(phi, dim=2) 11 | const_1 = phi / torch.sqrt((2 * np.pi) ** dim * sigma_det) 12 | distance = x[:, :, None, :] - mu 13 | mahalanobis_distance = - .5 * torch.einsum('bngd,bgdc,bngc->bng', distance, sigma_inverse, distance) 14 | const_2, _ = mahalanobis_distance.max(dim=2) # for numeric stability 15 | mahalanobis_distance -= const_2[:, 
:, None] 16 | support = const_1 * torch.exp(mahalanobis_distance) 17 | return support, const_2 18 | 19 | 20 | def gm_log_likelihood_loss(gms: TS, x: T, get_supports: bool = False, 21 | mask: Optional[T] = None, reduction: str = "mean") -> Union[T, Tuple[T, TS]]: 22 | 23 | batch_size, num_points, dim = x.shape 24 | support, const = get_gm_support(gms, x) 25 | probs = torch.log(support.sum(dim=2)) + const 26 | if mask is not None: 27 | probs = probs.masked_select(mask=mask.flatten()) 28 | if reduction == 'none': 29 | likelihood = probs.sum(-1) 30 | loss = - likelihood / num_points 31 | else: 32 | likelihood = probs.sum() 33 | loss = - likelihood / (probs.shape[0] * probs.shape[1]) 34 | if get_supports: 35 | return loss, support 36 | return loss 37 | 38 | 39 | def split_mesh_by_gmm(mesh: T_Mesh, gmm) -> Dict[int, T]: 40 | faces_split = {} 41 | vs, faces = mesh 42 | vs_mid_faces = vs[faces].mean(1) 43 | _, supports = gm_log_likelihood_loss(gmm, vs_mid_faces.unsqueeze(0), get_supports=True) 44 | supports = supports[0] 45 | label = supports.argmax(1) 46 | for i in range(gmm[1].shape[2]): 47 | select = label.eq(i) 48 | if select.any(): 49 | faces_split[i] = faces[select] 50 | else: 51 | faces_split[i] = None 52 | return faces_split 53 | 54 | 55 | def flatten_gmm(gmm: TS) -> T: 56 | b, gp, g, _ = gmm[0].shape 57 | mu, p, phi, eigen = [item.view(b, gp * g, *item.shape[3:]) for item in gmm] 58 | p = p.reshape(*p.shape[:2], -1) 59 | z_gmm = torch.cat((mu, p, phi.unsqueeze(-1), eigen), dim=2) 60 | return z_gmm 61 | -------------------------------------------------------------------------------- /notebooks/uncleaned_demo/pvd/spaghetti_github/models/mlp_models.py: -------------------------------------------------------------------------------- 1 | 2 | 3 | -------------------------------------------------------------------------------- /notebooks/uncleaned_demo/pvd/spaghetti_github/models/models_utils.py: -------------------------------------------------------------------------------- 1 | from ..custom_types import * 2 | from abc import ABC 3 | import math 4 | 5 | 6 | def torch_no_grad(func): 7 | def wrapper(*args, **kwargs): 8 | with torch.no_grad(): 9 | result = func(*args, **kwargs) 10 | return result 11 | 12 | return wrapper 13 | 14 | 15 | class Model(nn.Module, ABC): 16 | 17 | def __init__(self): 18 | super(Model, self).__init__() 19 | self.save_model: Union[None, Callable[[nn.Module], None]] = None 20 | 21 | def save(self, **kwargs): 22 | self.save_model(self, **kwargs) 23 | 24 | 25 | class Concatenate(nn.Module): 26 | def __init__(self, dim): 27 | super(Concatenate, self).__init__() 28 | self.dim = dim 29 | 30 | def forward(self, x): 31 | return torch.cat(x, dim=self.dim) 32 | 33 | 34 | class View(nn.Module): 35 | 36 | def __init__(self, *shape): 37 | super(View, self).__init__() 38 | self.shape = shape 39 | 40 | def forward(self, x): 41 | return x.view(*self.shape) 42 | 43 | 44 | class Transpose(nn.Module): 45 | 46 | def __init__(self, dim0, dim1): 47 | super(Transpose, self).__init__() 48 | self.dim0, self.dim1 = dim0, dim1 49 | 50 | def forward(self, x): 51 | return x.transpose(self.dim0, self.dim1) 52 | 53 | 54 | class Dummy(nn.Module): 55 | 56 | def __init__(self, *args): 57 | super(Dummy, self).__init__() 58 | 59 | def forward(self, *args): 60 | return args[0] 61 | 62 | 63 | class SineLayer(nn.Module): 64 | """ 65 | From the siren repository 66 | https://colab.research.google.com/github/vsitzmann/siren/blob/master/explore_siren.ipynb 67 | """ 68 | def __init__(self, in_features, out_features,
bias=True, 69 | is_first=False, omega_0=30): 70 | super().__init__() 71 | self.omega_0 = omega_0 72 | self.is_first = is_first 73 | 74 | self.in_features = in_features 75 | self.linear = nn.Linear(in_features, out_features, bias=bias) 76 | self.output_channels = out_features 77 | self.init_weights() 78 | 79 | def init_weights(self): 80 | with torch.no_grad(): 81 | if self.is_first: 82 | self.linear.weight.uniform_(-1 / self.in_features, 83 | 1 / self.in_features) 84 | else: 85 | self.linear.weight.uniform_(-np.sqrt(6 / self.in_features) / self.omega_0, 86 | np.sqrt(6 / self.in_features) / self.omega_0) 87 | 88 | def forward(self, input): 89 | return torch.sin(self.omega_0 * self.linear(input)) 90 | 91 | 92 | class MLP(nn.Module): 93 | 94 | def forward(self, x, *_): 95 | return self.net(x) 96 | 97 | def __init__(self, ch: Union[List[int], Tuple[int, ...]], act: nn.Module = nn.ReLU, 98 | weight_norm=False): 99 | super(MLP, self).__init__() 100 | layers = [] 101 | for i in range(len(ch) - 1): 102 | layers.append(nn.Linear(ch[i], ch[i + 1])) 103 | if weight_norm: 104 | layers[-1] = nn.utils.weight_norm(layers[-1]) 105 | if i < len(ch) - 2: 106 | layers.append(act(True)) 107 | self.net = nn.Sequential(*layers) 108 | 109 | 110 | class GMAttend(nn.Module): 111 | 112 | def __init__(self, hidden_dim: int): 113 | super(GMAttend, self).__init__() 114 | self.key_dim = hidden_dim // 8 115 | self.query_w = nn.Linear(hidden_dim, self.key_dim) 116 | self.key_w = nn.Linear(hidden_dim, self.key_dim) 117 | self.value_w = nn.Linear(hidden_dim, hidden_dim) 118 | self.softmax = nn.Softmax(dim=3) 119 | self.gamma = nn.Parameter(torch.zeros(1)) 120 | self.scale = 1 / torch.sqrt(torch.tensor(self.key_dim, dtype=torch.float32)) 121 | 122 | def forward(self, x): 123 | queries = self.query_w(x) 124 | keys = self.key_w(x) 125 | vals = self.value_w(x) 126 | attention = self.softmax(torch.einsum('bgqf,bgkf->bgqk', queries, keys)) 127 | out = torch.einsum('bgvf,bgqv->bgqf', vals, attention) 128 | out = self.gamma * out + x 129 | return out 130 | 131 | 132 | def recursive_to(item, device): 133 | if type(item) is T: 134 | return item.to(device) 135 | elif type(item) is tuple or type(item) is list: 136 | return [recursive_to(item[i], device) for i in range(len(item))] 137 | return item 138 | 139 | -------------------------------------------------------------------------------- /notebooks/uncleaned_demo/pvd/spaghetti_github/models/occ_loss.py: -------------------------------------------------------------------------------- 1 | from ..custom_types import * 2 | 3 | 4 | def occupancy_bce(predict: T, winding_gt: T, ignore: Optional[T] = None, *args) -> T: 5 | if winding_gt.dtype is not torch.bool: 6 | winding_gt = winding_gt.gt(0) 7 | labels = winding_gt.flatten().float() 8 | predict = predict.flatten() 9 | if ignore is not None: 10 | ignore = (~ignore).flatten().float() 11 | loss = nnf.binary_cross_entropy_with_logits(predict, labels, weight=ignore) 12 | return loss 13 | 14 | 15 | def reg_z_loss(z: T) -> T: 16 | norms = z.norm(2, 1) 17 | loss = norms.mean() 18 | return loss 19 | -------------------------------------------------------------------------------- /notebooks/uncleaned_demo/pvd/spaghetti_github/models/transformer.py: -------------------------------------------------------------------------------- 1 | from ..custom_types import * 2 | 3 | 4 | class FeedForward(nn.Module): 5 | def __init__(self, in_dim, h_dim, out_d: Optional[int] = None, act=nnf.relu, dropout=0.): 6 | super().__init__() 7 | out_d = out_d if out_d is not 
None else in_dim 8 | self.fc1 = nn.Linear(in_dim, h_dim) 9 | self.act = act 10 | self.fc2 = nn.Linear(h_dim, out_d) 11 | self.dropout = nn.Dropout(dropout) 12 | 13 | def forward(self, x): 14 | x = self.fc1(x) 15 | x = self.act(x) 16 | x = self.dropout(x) 17 | x = self.fc2(x) 18 | x = self.dropout(x) 19 | return x 20 | 21 | 22 | class MultiHeadAttention(nn.Module): 23 | 24 | def __init__(self, dim_self, dim_ref, num_heads, bias=True, dropout=0.): 25 | super().__init__() 26 | self.num_heads = num_heads 27 | head_dim = dim_self // num_heads 28 | self.scale = head_dim ** -0.5 29 | self.to_queries = nn.Linear(dim_self, dim_self, bias=bias) 30 | self.to_keys_values = nn.Linear(dim_ref, dim_self * 2, bias=bias) 31 | self.project = nn.Linear(dim_self, dim_self) 32 | self.dropout = nn.Dropout(dropout) 33 | 34 | def forward_interpolation(self, queries: T, keys: T, values: T, alpha: T, mask: TN = None) -> TNS: 35 | attention = torch.einsum('nhd,bmhd->bnmh', queries[0], keys) * self.scale 36 | if mask is not None: 37 | if mask.dim() == 2: 38 | mask = mask.unsqueeze(1) 39 | attention = attention.masked_fill(mask.unsqueeze(3), float("-inf")) 40 | attention = attention.softmax(dim=2) 41 | attention = attention * alpha[:, None, None, None] 42 | out = torch.einsum('bnmh,bmhd->nhd', attention, values).reshape(1, attention.shape[1], -1) 43 | return out, attention 44 | 45 | def forward(self, x, y: Optional[T] = None, mask: Optional[T] = None, alpha: TN = None): 46 | y = y if y is not None else x 47 | b_a, n, c = x.shape 48 | b, m, d = y.shape 49 | # b n h dh 50 | queries = self.to_queries(x).reshape(b_a, n, self.num_heads, c // self.num_heads) 51 | # b m 2 h dh 52 | keys_values = self.to_keys_values(y).reshape(b, m, 2, self.num_heads, c // self.num_heads) 53 | keys, values = keys_values[:, :, 0], keys_values[:, :, 1] 54 | if alpha is not None: 55 | out, attention = self.forward_interpolation(queries, keys, values, alpha, mask) 56 | else: 57 | attention = torch.einsum('bnhd,bmhd->bnmh', queries, keys) * self.scale 58 | if mask is not None: 59 | if mask.dim() == 2: 60 | mask = mask.unsqueeze(1) 61 | attention = attention.masked_fill(mask.unsqueeze(3), float("-inf")) 62 | attention = attention.softmax(dim=2) 63 | out = torch.einsum('bnmh,bmhd->bnhd', attention, values).reshape(b, n, c) 64 | out = self.project(out) 65 | return out, attention 66 | 67 | 68 | class TransformerLayer(nn.Module): 69 | 70 | def forward_with_attention(self, x, y: Optional[T] = None, mask: Optional[T] = None, alpha: TN = None): 71 | x_, attention = self.attn(self.norm1(x), y, mask, alpha) 72 | x = x + x_ 73 | x = x + self.mlp(self.norm2(x)) 74 | return x, attention 75 | 76 | def forward(self, x, y: Optional[T] = None, mask: Optional[T] = None, alpha: TN = None): 77 | x = x + self.attn(self.norm1(x), y, mask, alpha)[0] 78 | x = x + self.mlp(self.norm2(x)) 79 | return x 80 | 81 | def __init__(self, dim_self, dim_ref, num_heads, mlp_ratio=4., bias=False, dropout=0., act=nnf.relu, 82 | norm_layer: nn.Module = nn.LayerNorm): 83 | super().__init__() 84 | self.norm1 = norm_layer(dim_self) 85 | self.attn = MultiHeadAttention(dim_self, dim_ref, num_heads, bias=bias, dropout=dropout) 86 | self.norm2 = norm_layer(dim_self) 87 | self.mlp = FeedForward(dim_self, int(dim_self * mlp_ratio), act=act, dropout=dropout) 88 | 89 | 90 | class DummyTransformer: 91 | 92 | @staticmethod 93 | def forward_with_attention(x, *_, **__): 94 | return x, [] 95 | 96 | @staticmethod 97 | def forward(x, *_, **__): 98 | return x 99 | 100 | 101 | class 
Transformer(nn.Module): 102 | 103 | def forward_with_attention(self, x, y: Optional[T] = None, mask: Optional[T] = None, alpha: TN = None): 104 | attentions = [] 105 | for layer in self.layers: 106 | x, att = layer.forward_with_attention(x, y, mask, alpha) 107 | attentions.append(att) 108 | return x, attentions 109 | 110 | def forward(self, x, y: TN = None, mask: TN = None, alpha: TN = None): 111 | for layer in self.layers: 112 | x = layer(x, y, mask, alpha) 113 | return x 114 | 115 | def __init__(self, dim_self: int, num_heads: int, num_layers: int, dim_ref: Optional[int] = None, 116 | mlp_ratio: float = 2., act=nnf.relu, norm_layer: nn.Module = nn.LayerNorm): 117 | super(Transformer, self).__init__() 118 | dim_ref = dim_ref if dim_ref is not None else dim_self 119 | self.layers = nn.ModuleList([TransformerLayer(dim_self, dim_ref, num_heads, mlp_ratio, act=act, 120 | norm_layer=norm_layer) for _ in range(num_layers)]) 121 | -------------------------------------------------------------------------------- /notebooks/uncleaned_demo/pvd/spaghetti_github/options.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | import os 3 | import pickle 4 | from . import constants as const 5 | from .custom_types import * 6 | 7 | 8 | class Options: 9 | 10 | def load(self): 11 | device = self.device 12 | if os.path.isfile(self.save_path): 13 | print(f'loading options from {self.save_path}') 14 | with open(self.save_path, 'rb') as f: 15 | options = pickle.load(f) 16 | options.device = device 17 | return options 18 | return self 19 | 20 | def save(self): 21 | if os.path.isdir(self.cp_folder): 22 | # self.already_saved = True 23 | with open(self.save_path, 'wb') as f: 24 | pickle.dump(self, f, pickle.HIGHEST_PROTOCOL) 25 | 26 | @property 27 | def info(self) -> str: 28 | return f'{self.model_name}_{self.tag}' 29 | 30 | @property 31 | def cp_folder(self): 32 | return f'{const.CHECKPOINTS_ROOT}{self.info}' 33 | 34 | @property 35 | def save_path(self): 36 | return f'{const.CHECKPOINTS_ROOT}{self.info}/options.pkl' 37 | 38 | def fill_args(self, args): 39 | for arg in args: 40 | if hasattr(self, arg): 41 | setattr(self, arg, args[arg]) 42 | 43 | def __init__(self, **kwargs): 44 | self.device = CUDA(0) 45 | self.tag = 'airplanes' 46 | self.dataset_name = 'shapenet_airplanes_wm_sphere_sym_train' 47 | self.epochs = 2000 48 | self.model_name = 'spaghetti' 49 | self.dim_z = 256 50 | self.pos_dim = 256 - 3 51 | self.dim_h = 512 52 | self.dim_zh = 512 53 | self.num_gaussians = 16 54 | self.min_split = 4 55 | self.max_split = 12 56 | self.gmm_weight = 1 57 | self.decomposition_network = 'transformer' 58 | self.decomposition_num_layers = 4 59 | self.num_layers = 4 60 | self.num_heads = 4 61 | self.num_layers_head = 6 62 | self.num_heads_head = 8 63 | self.head_occ_size = 5 64 | self.head_occ_type = 'skip' 65 | self.batch_size = 18 66 | self.num_samples = 2000 67 | self.dataset_size = -1 68 | self.symmetric = (True, False, False) 69 | self.data_symmetric = (True, False, False) 70 | self.lr_decay = .9 71 | self.lr_decay_every = 500 72 | self.warm_up = 2000 73 | self.reg_weight = 1e-4 74 | self.disentanglement = True 75 | self.use_encoder = True 76 | self.disentanglement_weight = 1 77 | self.augmentation_rotation = 0.3 78 | self.augmentation_scale = .2 79 | self.augmentation_translation = .3 80 | self.fill_args(kwargs) 81 | --------------------------------------------------------------------------------
/notebooks/uncleaned_demo/pvd/spaghetti_github/ui/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KAIST-Visual-AI-Group/SALAD/db66b1b2ce57d8a4beb5bd610bc006ccaff43b1e/notebooks/uncleaned_demo/pvd/spaghetti_github/ui/__init__.py -------------------------------------------------------------------------------- /notebooks/uncleaned_demo/pvd/spaghetti_github/ui/mock_keyboard.py: -------------------------------------------------------------------------------- 1 | 2 | class Key: 3 | ctrl_l = 'control_l' 4 | 5 | 6 | class Controller: 7 | 8 | @staticmethod 9 | def press(key: str) -> str: 10 | return key 11 | 12 | @staticmethod 13 | def release(key: str) -> str: 14 | return key 15 | -------------------------------------------------------------------------------- /notebooks/uncleaned_demo/pvd/spaghetti_github/utils/myparse.py: -------------------------------------------------------------------------------- 1 | from ..custom_types import * 2 | import sys 3 | 4 | 5 | # minimal argparse 6 | def parse(parse_dict: Dict[str, Dict[str, Any]]): 7 | args = sys.argv[1:] 8 | args_dict = {args[i]: args[i + 1] for i in range(0, len(args), 2)} 9 | out_dict = {} 10 | for item in parse_dict: 11 | hint = parse_dict[item]['type'] 12 | if item in args_dict: 13 | out_dict[item[2:]] = hint(args_dict[item]) 14 | else: 15 | out_dict[item[2:]] = parse_dict[item]['default'] 16 | if 'options' in parse_dict[item] and out_dict[item[2:]] not in parse_dict[item]['options']: 17 | raise ValueError(f'Expected {item} to be in {str(parse_dict[item]["options"])}') 18 | return out_dict 19 | -------------------------------------------------------------------------------- /notebooks/uncleaned_demo/pvd/spaghetti_github/utils/rotation_utils.py: -------------------------------------------------------------------------------- 1 | from .. 
import constants 2 | import functools 3 | from scipy.spatial.transform.rotation import Rotation 4 | from ..custom_types import * 5 | 6 | 7 | def quat_to_rot(q): 8 | shape = q.shape 9 | q = q.view(-1, 4) 10 | q_sq = 2 * q[:, :, None] * q[:, None, :] 11 | m00 = 1 - q_sq[:, 1, 1] - q_sq[:, 2, 2] 12 | m01 = q_sq[:, 0, 1] - q_sq[:, 2, 3] 13 | m02 = q_sq[:, 0, 2] + q_sq[:, 1, 3] 14 | 15 | m10 = q_sq[:, 0, 1] + q_sq[:, 2, 3] 16 | m11 = 1 - q_sq[:, 0, 0] - q_sq[:, 2, 2] 17 | m12 = q_sq[:, 1, 2] - q_sq[:, 0, 3] 18 | 19 | m20 = q_sq[:, 0, 2] - q_sq[:, 1, 3] 20 | m21 = q_sq[:, 1, 2] + q_sq[:, 0, 3] 21 | m22 = 1 - q_sq[:, 0, 0] - q_sq[:, 1, 1] 22 | r = torch.stack((m00, m01, m02, m10, m11, m12, m20, m21, m22), dim=1) 23 | r = r.view(*shape[:-1], 3, 3) 24 | return r 25 | 26 | 27 | def rot_to_quat(r): 28 | shape = r.shape 29 | r = r.view(-1, 3, 3) 30 | qw = .5 * (1 + r[:, 0, 0] + r[:, 1, 1] + r[:, 2, 2]).sqrt() 31 | qx = (r[:, 2, 1] - r[:, 1, 2]) / (4 * qw) 32 | qy = (r[:, 0, 2] - r[:, 2, 0]) / (4 * qw) 33 | qz = (r[:, 1, 0] - r[:, 0, 1]) / (4 * qw) 34 | q = torch.stack((qx, qy, qz, qw), -1) 35 | q = q.view(*shape[:-2], 4) 36 | return q 37 | 38 | 39 | @functools.lru_cache(10) 40 | def get_rotation_matrix(theta: float, axis: float, degree: bool = False) -> ARRAY: 41 | if degree: 42 | theta = theta * np.pi / 180 43 | rotate_mat = np.eye(3) 44 | rotate_mat[axis, axis] = 1 45 | cos_theta, sin_theta = np.cos(theta), np.sin(theta) 46 | rotate_mat[(axis + 1) % 3, (axis + 1) % 3] = cos_theta 47 | rotate_mat[(axis + 2) % 3, (axis + 2) % 3] = cos_theta 48 | rotate_mat[(axis + 1) % 3, (axis + 2) % 3] = sin_theta 49 | rotate_mat[(axis + 2) % 3, (axis + 1) % 3] = -sin_theta 50 | return rotate_mat 51 | 52 | 53 | def get_random_rotation(batch_size: int) -> T: 54 | r = Rotation.random(batch_size).as_matrix().astype(np.float32) 55 | Rotation.random() 56 | return torch.from_numpy(r) 57 | 58 | 59 | def rand_bounded_rotation_matrix(cache_size: int, theta_range: float = .1): 60 | 61 | def create_cache(): 62 | # from http://www.realtimerendering.com/resources/GraphicsGems/gemsiii/rand_rotation.c 63 | with torch.no_grad(): 64 | theta, phi, z = torch.rand(cache_size, 3).split((1, 1, 1), dim=1) 65 | theta = (2 * theta - 1) * theta_range + 1 66 | theta = np.pi * theta # Rotation about the pole (Z). 67 | phi = phi * 2 * np.pi # For direction of pole deflection. 68 | z = 2 * z * theta_range # For magnitude of pole deflection. 
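# z is uniform in [0, 2 * theta_range], so a small theta_range keeps the sampled rotations near the identity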
69 | r = z.sqrt() 70 | v = torch.cat((torch.sin(phi) * r, torch.cos(phi) * r, torch.sqrt(2.0 - z)), dim=1) 71 | st = torch.sin(theta).squeeze(1) 72 | ct = torch.cos(theta).squeeze(1) 73 | rot_ = torch.zeros(cache_size, 3, 3) 74 | rot_[:, 0, 0] = ct 75 | rot_[:, 1, 1] = ct 76 | rot_[:, 0, 1] = st 77 | rot_[:, 1, 0] = -st 78 | rot_[:, 2, 2] = 1 79 | rot = (torch.einsum('ba,bd->bad', v, v) - torch.eye(3)[None, :, :]).bmm(rot_) 80 | det = rot.det() 81 | assert (det.gt(0.99) * det.lt(1.0001)).all().item() 82 | return rot 83 | 84 | def get_batch_rot(batch_size): 85 | nonlocal cache 86 | select = torch.randint(cache_size, size=(batch_size,)) 87 | return cache[select] 88 | 89 | cache = create_cache() 90 | 91 | return get_batch_rot 92 | 93 | 94 | def transform_rotation(points: T, one_axis=False, max_angle=-1): 95 | r = get_random_rotation(one_axis, max_angle) 96 | transformed = torch.einsum('nd,rd->nr', points, r) 97 | return transformed 98 | 99 | 100 | def tb_to_rot(abc: T) -> T: 101 | c, s = torch.cos(abc), torch.sin(abc) 102 | aa = c[:, 0] * c[:, 1] 103 | ab = c[:, 0] * s[:, 1] * s[:, 2] - c[:, 2] * s[:, 0] 104 | ac = s[:, 0] * s[:, 2] + c[:, 0] * c[:, 2] * s[:, 1] 105 | 106 | ba = c[:, 1] * s[:, 0] 107 | bb = c[:, 0] * c[:, 2] + s.prod(-1) 108 | bc = c[:, 2] * s[:, 0] * s[:, 1] - c[:, 0] * s[:, 2] 109 | 110 | ca = -s[:, 1] 111 | cb = c[:, 1] * s[:, 2] 112 | cc = c[:, 1] * c[:, 2] 113 | return torch.stack((aa, ab, ac, ba, bb, bc, ca, cb, cc), 1).view(-1, 3, 3) 114 | 115 | 116 | def rot_to_tb(rot: T) -> T: 117 | sy = torch.sqrt(rot[:, 0, 0] * rot[:, 0, 0] + rot[:, 1, 0] * rot[:, 1, 0]) 118 | out = torch.zeros(rot.shape[0], 3, device = rot.device) 119 | mask = sy.gt(1e-6) 120 | z = torch.atan2(rot[mask, 2, 1], rot[mask, 2, 2]) 121 | y = torch.atan2(-rot[mask, 2, 0], sy[mask]) 122 | x = torch.atan2(rot[mask, 1, 0], rot[mask, 0, 0]) 123 | out[mask] = torch.stack((x, y, z), dim=1) 124 | if not mask.all(): 125 | mask = ~mask 126 | z = torch.atan2(-rot[mask, 1, 2], rot[mask, 1, 1]) 127 | y = torch.atan2(-rot[mask, 2, 0], sy[mask]) 128 | x = torch.zeros_like(z) # gimbal lock: roll is unrecoverable here, so fix it to zero 129 | out[mask] = torch.stack((x, y, z), dim=1) 130 | return out 131 | 132 | 133 | def apply_gmm_affine(gmms: TS, affine: T): 134 | mu, p, phi, eigen = gmms 135 | if affine.dim() == 2: 136 | affine = affine.unsqueeze(0).expand(mu.shape[0], *affine.shape) 137 | mu_r = torch.einsum('bad, bpnd->bpna', affine, mu) 138 | p_r = torch.einsum('bad, bpncd->bpnca', affine, p) 139 | return mu_r, p_r, phi, eigen 140 | 141 | 142 | def get_reflection(reflect_axes: Tuple[bool, ...]) -> T: 143 | reflect = torch.eye(constants.DIM) 144 | for i in range(constants.DIM): 145 | if reflect_axes[i]: 146 | reflect[i, i] = -1 147 | return reflect 148 | 149 | 150 | def get_tait_bryan_from_p(p: T) -> T: 151 | # p = p.squeeze(1) 152 | shape = p.shape 153 | rot = p.reshape(-1, 3, 3).permute(0, 2, 1) 154 | angles = rot_to_tb(rot) 155 | angles = angles / np.pi 156 | angles[:, 1] = angles[:, 1] * 2 157 | angles = angles.view(*shape[:2], 3) 158 | return angles 159 | -------------------------------------------------------------------------------- /notebooks/uncleaned_demo/pvd/train_clr.py: -------------------------------------------------------------------------------- 1 | import hydra 2 | import pytorch_lightning as pl 3 | import jutils 4 | @hydra.main(config_path="configs", config_name="simclr.yaml") 5 | def main(config): 6 | pl.seed_everything(63) 7 | jutils.sysutil.print_config(config, ("callbacks", "logger", "paths", "experiment", "debug", "data", "trainer", "model")) 8 |
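# instantiate the model (the SimCLR LightningModule) from the Hydra config; _recursive_=True also builds nested _target_ entries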
model = hydra.utils.instantiate(config.model, _recursive_=True) 9 | 10 | callbacks = [] 11 | if config.get("callbacks"): 12 | for cb_name, cb_conf in config.callbacks.items(): 13 | if config.get("debug") and cb_name == "model_checkpoint": 14 | continue 15 | callbacks.append(hydra.utils.instantiate(cb_conf)) 16 | 17 | logger = [] 18 | if config.get("logger"): 19 | for lg_name, lg_conf in config.logger.items(): 20 | if config.get("debug") and lg_name == "wandb": 21 | continue 22 | logger.append(hydra.utils.instantiate(lg_conf)) 23 | 24 | trainer = hydra.utils.instantiate( 25 | config.trainer, 26 | callbacks=callbacks, 27 | logger=logger, 28 | _convert_="partial", 29 | log_every_n_steps=200, 30 | ) 31 | 32 | trainer.fit(model) 33 | 34 | if __name__ == "__main__": 35 | main() 36 | -------------------------------------------------------------------------------- /notebooks/uncleaned_demo/pvd/train_ldm.py: -------------------------------------------------------------------------------- 1 | # < Seungwoo > 2 | # Temporarily re-locate files to avoid errors during import. 3 | #### 4 | import jutils 5 | 6 | import hydra 7 | import pytorch_lightning as pl 8 | # import jutils 9 | #### 10 | 11 | @hydra.main(config_path="configs", config_name="train.yaml") 12 | def main(config): 13 | pl.seed_everything(63) 14 | 15 | #### < Seungwoo > 16 | # To make W&B sweep work, need to explicitly initialize W&B 17 | # Remove this line after finishing hyperparameter search 18 | # import wandb 19 | # wandb.init(settings=wandb.Settings(start_method="fork") ) 20 | #### 21 | 22 | jutils.sysutil.print_config(config, ("callbacks", "logger", "paths", "experiment", "debug", "data", "trainer", "model")) 23 | model = hydra.utils.instantiate(config.model, _recursive_=True) 24 | 25 | callbacks = [] 26 | if config.get("callbacks"): 27 | for cb_name, cb_conf in config.callbacks.items(): 28 | if config.get("debug") and cb_name == "model_checkpoint": 29 | continue 30 | # if cb_name == "lr_monitor": 31 | # continue 32 | callbacks.append(hydra.utils.instantiate(cb_conf)) 33 | 34 | logger = [] 35 | if config.get("logger"): 36 | for lg_name, lg_conf in config.logger.items(): 37 | if config.get("debug") and lg_name == "wandb": 38 | continue 39 | logger.append(hydra.utils.instantiate(lg_conf)) 40 | trainer = hydra.utils.instantiate( 41 | config.trainer, 42 | callbacks=callbacks, 43 | logger=logger if len(logger) != 0 else False, 44 | # logger=False, 45 | _convert_="partial", 46 | log_every_n_steps=200, 47 | resume_from_checkpoint=config.resume_from_checkpoint 48 | ) 49 | 50 | trainer.fit(model) 51 | 52 | if __name__ == "__main__": 53 | main() 54 | -------------------------------------------------------------------------------- /notebooks/uncleaned_demo/pvd/train_vae.py: -------------------------------------------------------------------------------- 1 | import hydra 2 | import pytorch_lightning as pl 3 | @hydra.main(config_path="configs", config_name="default.yaml") 4 | def main(config): 5 | pl.seed_everything(63) 6 | 7 | model = hydra.utils.instantiate(config.model) 8 | 9 | callbacks = [] 10 | if config.get("callbacks"): 11 | for cb_name, cb_conf in config.callbacks.items(): 12 | if config.get("debug") and cb_name == "model_checkpoint": 13 | continue 14 | callbacks.append(hydra.utils.instantiate(cb_conf)) 15 | 16 | logger = [] 17 | if config.get("logger"): 18 | for lg_name, lg_conf in config.logger.items(): 19 | if config.get("debug") and lg_name == "wandb": 20 | continue 21 | logger.append(hydra.utils.instantiate(lg_conf)) 22 | 23 | 
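# assemble the Trainer from config, wiring in the callbacks and loggers collected above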
trainer = hydra.utils.instantiate( 24 | config.trainer, 25 | callbacks=callbacks, 26 | logger=logger, 27 | _convert_="partial", 28 | log_every_n_steps=100, 29 | ) 30 | 31 | trainer.fit(model) 32 | model.save_latents() 33 | 34 | if __name__ == "__main__": 35 | main() 36 | -------------------------------------------------------------------------------- /notebooks/uncleaned_demo/pvd/utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KAIST-Visual-AI-Group/SALAD/db66b1b2ce57d8a4beb5bd610bc006ccaff43b1e/notebooks/uncleaned_demo/pvd/utils/__init__.py -------------------------------------------------------------------------------- /notebooks/uncleaned_demo/pvd/utils/contrastive_utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | import jutils 4 | 5 | def make_distance_matrix(feat_a, feat_b): 6 | """ 7 | Input: 8 | feat_a: [B,D] 9 | feat_b: [B,D] 10 | """ 11 | return torch.cdist(feat_a, feat_b) # assumed completion (body was empty in the source): pairwise L2 distance matrix, shape [B,B] 12 | -------------------------------------------------------------------------------- /notebooks/uncleaned_demo/pvd/utils/ldm_utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | from taming.modules.vqvae.quantize import VectorQuantizer2 as VectorQuantizer 4 | 5 | 6 | class DiagonalGaussianDistribution(object): 7 | def __init__(self, parameters, deterministic=False, fixed_std=-1): 8 | self.parameters = parameters 9 | self.mean, self.logvar = torch.chunk(parameters, 2, dim=-1) 10 | self.logvar = torch.clamp(self.logvar, -30.0, 20.0) 11 | self.deterministic = deterministic 12 | self.std = torch.exp(0.5 * self.logvar) 13 | self.var = torch.exp(self.logvar) 14 | if self.deterministic: 15 | self.var = self.std = torch.zeros_like(self.mean).to( 16 | device=self.parameters.device 17 | ) 18 | 19 | if fixed_std >= 0: 20 | self.std = torch.ones_like(self.std) * fixed_std 21 | self.var = self.std ** 2 22 | self.logvar = torch.log(self.var) 23 | 24 | def sample(self): 25 | x = self.mean + self.std * torch.randn(self.mean.shape).to( 26 | device=self.parameters.device 27 | ) 28 | return x 29 | 30 | def kl(self, other=None): 31 | num_dim = len(self.mean.shape) 32 | 33 | if self.deterministic: 34 | return torch.Tensor([0.0]) 35 | else: 36 | if other is None: 37 | return 0.5 * torch.sum( 38 | torch.pow(self.mean, 2) + self.var - 1.0 - self.logvar, 39 | # dim=[1, 2, 3] 40 | dim=list(np.arange(1, num_dim)), 41 | ) 42 | else: 43 | return 0.5 * torch.sum( 44 | torch.pow(self.mean - other.mean, 2) / other.var 45 | + self.var / other.var 46 | - 1.0 47 | - self.logvar 48 | + other.logvar, 49 | # dim=[1, 2, 3], 50 | dim=list(np.arange(1, num_dim)) 51 | ) 52 | 53 | def nll(self, sample, dims=[1, 2, 3]): 54 | if self.deterministic: 55 | return torch.Tensor([0.0]) 56 | logtwopi = np.log(2.0 * np.pi) 57 | return 0.5 * torch.sum( 58 | logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var, 59 | dim=dims, 60 | ) 61 | 62 | def mode(self): 63 | return self.mean 64 | -------------------------------------------------------------------------------- /notebooks/uncleaned_demo/pvd/utils/train_utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import math 3 | 4 | 5 | class PolyDecayScheduler(torch.optim.lr_scheduler.LambdaLR): 6 | def __init__(self, optimizer, init_lr, power=0.99, lr_end=1e-7, last_epoch=-1): 7 | def lr_lambda(step): 8 | lr = max(power**step,
lr_end / init_lr) 9 | return lr 10 | 11 | super().__init__(optimizer, lr_lambda, last_epoch=last_epoch) 12 | 13 | 14 | def get_dropout_mask(shape, dropout: float, device): 15 | if dropout == 1: 16 | return torch.zeros(shape, device=device, dtype=torch.bool) 17 | elif dropout == 0: 18 | return torch.ones(shape, device=device, dtype=torch.bool) 19 | else: 20 | return torch.zeros(shape, device=device).float().uniform_(0, 1) > dropout 21 | -------------------------------------------------------------------------------- /notebooks/uncleaned_demo/pvd/utils/transform.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | def apply_gmm_affine(gmms, affine): 5 | """ 6 | Applies the given Affine transform to Gaussians in the GMM. 7 | 8 | Args: 9 | gmms: (B, M, 16). A batch of GMMs 10 | affine: (3, 3) or (B, 3, 3). A batch of Affine matrices 11 | """ 12 | mu, p, phi, eigen = torch.split(gmms, [3, 9, 1, 3], dim=2) 13 | if affine.ndim == 2: 14 | affine = affine.unsqueeze(0).expand(mu.size(0), *affine.shape) 15 | 16 | bs, n_part, _ = mu.shape 17 | p = p.reshape(bs, n_part, 3, 3) 18 | 19 | mu_r = torch.einsum("bad, bnd -> bna", affine, mu) 20 | p_r = torch.einsum("bad, bncd -> bnca", affine, p) 21 | p_r = p_r.reshape(bs, n_part, -1) 22 | gmms_t = torch.cat([mu_r, p_r, phi, eigen], dim=2) 23 | assert gmms.shape == gmms_t.shape, "Input and output shapes must be the same" 24 | 25 | return gmms_t -------------------------------------------------------------------------------- /notebooks/uncleaned_demo/pvd/vae/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KAIST-Visual-AI-Group/SALAD/db66b1b2ce57d8a4beb5bd610bc006ccaff43b1e/notebooks/uncleaned_demo/pvd/vae/__init__.py -------------------------------------------------------------------------------- /notebooks/uncleaned_demo/pvd/vae/ae.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from pointnet2_ops.pointnet2_modules import PointnetSAModule, PointnetSAModuleMSG 4 | 5 | 6 | -------------------------------------------------------------------------------- /notebooks/uncleaned_demo/utils/__init__.py: -------------------------------------------------------------------------------- 1 | from .common import * 2 | from .gmm_edit import * 3 | from .io_utils import * 4 | from .vis import make_img_grid -------------------------------------------------------------------------------- /notebooks/uncleaned_demo/utils/common.py: -------------------------------------------------------------------------------- 1 | """ 2 | common.py 3 | 4 | A collection of commonly used utility functions. 5 | """ 6 | 7 | def detach_cpu_list(list_of_tensor): 8 | result = [item.detach().cpu() for item in list_of_tensor] 9 | return result 10 | 11 | 12 | def compute_batch_sections(n_data: int, batch_size: int): 13 | """ 14 | Computes the sections that divide the given data into batches.
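For example, n_data=10 with batch_size=4 yields batch_start=[0, 4, 8] and batch_end=[4, 8, 10].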
15 | 16 | Args: 17 | n_data: int, number of data 18 | batch_size: int, size of a single batch 19 | """ 20 | assert isinstance(n_data, int) and n_data > 0 21 | assert isinstance(batch_size, int) and batch_size > 0 22 | 23 | batch_indices = list(range(0, n_data, batch_size)) 24 | if batch_indices[-1] != n_data: 25 | batch_indices.append(n_data) 26 | 27 | batch_start = batch_indices[:-1] 28 | batch_end = batch_indices[1:] 29 | assert len(batch_start) == len(batch_end), f"Lengths are different {len(batch_start)} {len(batch_end)}" 30 | 31 | return batch_start, batch_end -------------------------------------------------------------------------------- /notebooks/uncleaned_demo/utils/vis.py: -------------------------------------------------------------------------------- 1 | from PIL import Image 2 | 3 | import jutils 4 | import numpy as np 5 | import torch 6 | 7 | from pvd.utils.spaghetti_utils import batch_gaus_to_gmms, get_occ_func 8 | 9 | 10 | def make_img_grid(img_lists, nrow=5): 11 | """ 12 | Creates an image grid using the provided images. 13 | 14 | Args: 15 | img_lists: Lists of list of images 16 | nrow: Number of samples in a row. 17 | """ 18 | for img_list in img_lists: 19 | assert isinstance(img_list, list) 20 | assert len(img_lists[0]) == len(img_list) 21 | assert len(img_list) > 0 22 | assert isinstance(img_list[0], Image.Image) 23 | 24 | img_pairs = list(zip(*img_lists)) 25 | 26 | indices = list(range(0, len(img_lists[0]), nrow)) 27 | if indices[-1] != len(img_lists[0]): 28 | indices.append(len(img_lists[0])) 29 | starts, ends = indices[0:-1], indices[1:] 30 | 31 | image_arr = [] 32 | for start, end in zip(starts, ends): 33 | row_images = [] 34 | for i in range(start, end): 35 | for img in img_pairs[i]: 36 | row_images.append(img) 37 | image_arr.append(row_images) 38 | 39 | return jutils.imageutil.merge_images(image_arr) 40 | 41 | 42 | def decode_and_render(gmm, zh, spaghetti, mesher, camera_kwargs): 43 | if gmm.ndim == 2: 44 | gmm = gmm[None] 45 | if zh.ndim == 2: 46 | zh = zh[None] 47 | 48 | vert, face = decode_gmm_and_intrinsic( 49 | spaghetti, mesher, gmm, zh 50 | ) 51 | img = render_mesh(vert, face, camera_kwargs) 52 | return img 53 | 54 | 55 | def decode_gmm_and_intrinsic(spaghetti, mesher, gmm, intrinsic, verbose=False): 56 | assert gmm.ndim == 3 and intrinsic.ndim == 3 57 | assert gmm.shape[0] == intrinsic.shape[0] 58 | assert gmm.shape[1] == intrinsic.shape[1] 59 | 60 | zc, _ = spaghetti.merge_zh(intrinsic, batch_gaus_to_gmms(gmm, gmm.device)) 61 | 62 | mesh = mesher.occ_meshing( 63 | decoder=get_occ_func(spaghetti, zc), 64 | res=256, 65 | get_time=False, 66 | verbose=False, 67 | ) 68 | assert mesh is not None, "Marching cube failed" 69 | vert, face = list(map(lambda x: jutils.thutil.th2np(x), mesh)) 70 | assert isinstance(vert, np.ndarray) and isinstance(face, np.ndarray) 71 | 72 | if verbose: 73 | print(f"Vert: {vert.shape} / Face: {face.shape}") 74 | 75 | # Free GPU memory after computing mesh 76 | _ = zc.cpu() 77 | jutils.sysutil.clean_gpu() 78 | if verbose: 79 | print("Freed GPU memory") 80 | 81 | return vert, face 82 | 83 | 84 | def render_mesh(vert, face, camera_kwargs): 85 | img = jutils.fresnelvis.renderMeshCloud( 86 | mesh={"vert": vert / 2, "face": face}, **camera_kwargs 87 | ) 88 | if isinstance(img, np.ndarray): 89 | img = Image.fromarray(img) 90 | assert isinstance(img, Image.Image) 91 | return img 92 | 93 | -------------------------------------------------------------------------------- /salad/model_components/lstm.py: 
-------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence 5 | 6 | 7 | class LSTM(nn.Module): 8 | def __init__(self, text_dim, embedding_dim, vocab_size, padding_idx=0): 9 | super().__init__() 10 | self.padding_idx = padding_idx 11 | self.word_embedding = nn.Embedding( 12 | vocab_size, embedding_dim, padding_idx=padding_idx 13 | ) 14 | self.rnn = nn.LSTM(embedding_dim, text_dim, batch_first=True) 15 | self.w_attn = nn.Parameter(torch.Tensor(1, text_dim)) 16 | nn.init.xavier_uniform_(self.w_attn) 17 | 18 | def forward(self, padded_tokens, dropout=0.5): 19 | w_emb = self.word_embedding(padded_tokens) 20 | w_emb = F.dropout(w_emb, dropout, self.training) 21 | len_seq = (padded_tokens != self.padding_idx).sum(dim=1).cpu() 22 | x_packed = pack_padded_sequence( 23 | w_emb, len_seq, enforce_sorted=False, batch_first=True 24 | ) 25 | B = padded_tokens.shape[0] 26 | rnn_out, _ = self.rnn(x_packed) 27 | rnn_out, dummy = pad_packed_sequence(rnn_out, batch_first=True) 28 | h = rnn_out[torch.arange(B), len_seq - 1] 29 | final_feat, attn = self.word_attention(rnn_out, h, len_seq) 30 | return final_feat, attn 31 | 32 | def word_attention(self, R, h, len_seq): 33 | """ 34 | Input: 35 | R: hidden states of the entire words 36 | h: the final hidden state after processing the entire words 37 | len_seq: the length of the sequence 38 | Output: 39 | final_feat: the final feature after the bilinear attention 40 | attn: word attention weights 41 | """ 42 | B, N, D = R.shape 43 | device = R.device 44 | len_seq = len_seq.to(device) 45 | 46 | W_attn = (self.w_attn * torch.eye(D).to(device))[None].repeat(B, 1, 1) 47 | score = torch.bmm(torch.bmm(R, W_attn), h.unsqueeze(-1)) 48 | 49 | mask = torch.arange(N).reshape(1, N, 1).repeat(B, 1, 1).to(device) 50 | mask = mask < len_seq.reshape(B, 1, 1) 51 | 52 | score = score.masked_fill(mask == 0, -1e9) 53 | attn = F.softmax(score, 1) 54 | final_feat = torch.bmm(R.transpose(1, 2), attn).squeeze(-1) 55 | 56 | return final_feat, attn.squeeze(-1) 57 | -------------------------------------------------------------------------------- /salad/model_components/simple_module.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import math 5 | 6 | from salad.model_components.transformer import TimeMLP 7 | 8 | 9 | class TimePointwiseLayer(nn.Module): 10 | def __init__( 11 | self, 12 | dim_in, 13 | dim_ctx, 14 | mlp_ratio=2, 15 | act=F.leaky_relu, 16 | dropout=0.0, 17 | use_time=False, 18 | ): 19 | super().__init__() 20 | self.use_time = use_time 21 | self.act = act 22 | self.mlp1 = TimeMLP( 23 | dim_in, dim_in * mlp_ratio, dim_in, dim_ctx, use_time=use_time 24 | ) 25 | self.norm1 = nn.LayerNorm(dim_in) 26 | 27 | self.mlp2 = TimeMLP( 28 | dim_in, dim_in * mlp_ratio, dim_in, dim_ctx, use_time=use_time 29 | ) 30 | self.norm2 = nn.LayerNorm(dim_in) 31 | self.dropout = nn.Dropout(dropout) 32 | 33 | def forward(self, x, ctx=None): 34 | res = x 35 | x = self.mlp1(x, ctx=ctx) 36 | x = self.norm1(x + res) 37 | 38 | res = x 39 | x = self.mlp2(x, ctx=ctx) 40 | x = self.norm2(x + res) 41 | return x 42 | 43 | 44 | class TimePointWiseEncoder(nn.Module): 45 | def __init__( 46 | self, 47 | dim_in, 48 | dim_ctx=None, 49 | mlp_ratio=2, 50 | act=F.leaky_relu, 51 | dropout=0.0, 52 | use_time=True, 53 | num_layers=6, 54 
| last_fc=False, 55 | last_fc_dim_out=None, 56 | ): 57 | super().__init__() 58 | self.last_fc = last_fc 59 | if last_fc: 60 | self.fc = nn.Linear(dim_in, last_fc_dim_out) 61 | self.layers = nn.ModuleList( 62 | [ 63 | TimePointwiseLayer( 64 | dim_in, 65 | dim_ctx=dim_ctx, 66 | mlp_ratio=mlp_ratio, 67 | act=act, 68 | dropout=dropout, 69 | use_time=use_time, 70 | ) 71 | for _ in range(num_layers) 72 | ] 73 | ) 74 | 75 | def forward(self, x, ctx=None): 76 | for i, layer in enumerate(self.layers): 77 | x = layer(x, ctx=ctx) 78 | if self.last_fc: 79 | x = self.fc(x) 80 | return x 81 | 82 | 83 | class TimestepEmbedder(nn.Module): 84 | """ 85 | Embeds scalar timesteps into vector representations. 86 | """ 87 | 88 | def __init__(self, hidden_size, frequency_embedding_size=256): 89 | super().__init__() 90 | self.mlp = nn.Sequential( 91 | nn.Linear(frequency_embedding_size, hidden_size, bias=True), 92 | nn.SiLU(), 93 | nn.Linear(hidden_size, hidden_size, bias=True), 94 | ) 95 | self.frequency_embedding_size = frequency_embedding_size 96 | 97 | @staticmethod 98 | def timestep_embedding(t, dim, max_period=10000): 99 | """ 100 | Create sinusoidal timestep embeddings. 101 | :param t: a 1-D Tensor of N indices, one per batch element. 102 | These may be fractional. 103 | :param dim: the dimension of the output. 104 | :param max_period: controls the minimum frequency of the embeddings. 105 | :return: an (N, D) Tensor of positional embeddings. 106 | """ 107 | # https://github.com/openai/glide-text2im/blob/main/glide_text2im/nn.py 108 | half = dim // 2 109 | freqs = torch.exp( 110 | -math.log(max_period) 111 | * torch.arange(start=0, end=half, dtype=torch.float32) 112 | / half 113 | ).to(device=t.device) 114 | args = t[:, None].float() * freqs[None] 115 | embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) 116 | if dim % 2: 117 | embedding = torch.cat( 118 | [embedding, torch.zeros_like(embedding[:, :1])], dim=-1 119 | ) 120 | return embedding 121 | 122 | def forward(self, t): 123 | t_freq = self.timestep_embedding(t, self.frequency_embedding_size) 124 | t_emb = self.mlp(t_freq) 125 | return t_emb 126 | -------------------------------------------------------------------------------- /salad/model_components/variance_schedule.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | from torch.nn import Linear, Module 4 | 5 | class VarianceSchedule(Module): 6 | def __init__(self, num_steps, beta_1, beta_T, mode="linear"): 7 | super().__init__() 8 | # assert mode in ("linear",) 9 | self.num_steps = num_steps 10 | self.beta_1 = beta_1 11 | self.beta_T = beta_T 12 | self.mode = mode 13 | 14 | if mode == "linear": 15 | betas = torch.linspace(beta_1, beta_T, steps=num_steps) 16 | elif mode == "quad": 17 | betas = torch.linspace(beta_1 ** 0.5, beta_T ** 0.5, num_steps) ** 2 18 | elif mode == "cosine": 19 | cosine_s = 8e-3 20 | timesteps = torch.arange(num_steps + 1) / num_steps + cosine_s 21 | alphas = timesteps / (1 + cosine_s) * np.pi / 2 22 | alphas = torch.cos(alphas).pow(2) 23 | betas = 1 - alphas[1:] / alphas[:-1] 24 | betas = betas.clamp(max=0.999) 25 | 26 | betas = torch.cat([torch.zeros([1]), betas], dim=0) # Padding 27 | 28 | alphas = 1 - betas 29 | log_alphas = torch.log(alphas) 30 | for i in range(1, log_alphas.size(0)): # 1 to T 31 | log_alphas[i] += log_alphas[i - 1] 32 | alpha_bars = log_alphas.exp() 33 | 34 | sigmas_flex = torch.sqrt(betas) 35 | sigmas_inflex = torch.zeros_like(sigmas_flex) 36 | for i in range(1, 
sigmas_flex.size(0)): 37 | sigmas_inflex[i] = ((1 - alpha_bars[i - 1]) / (1 - alpha_bars[i])) * betas[ 38 | i 39 | ] 40 | sigmas_inflex = torch.sqrt(sigmas_inflex) 41 | 42 | self.register_buffer("betas", betas) 43 | self.register_buffer("alphas", alphas) 44 | self.register_buffer("alpha_bars", alpha_bars) 45 | self.register_buffer("sigmas_flex", sigmas_flex) 46 | self.register_buffer("sigmas_inflex", sigmas_inflex) 47 | 48 | def uniform_sample_t(self, batch_size): 49 | ts = np.random.choice(np.arange(1, self.num_steps + 1), batch_size) 50 | return ts.tolist() 51 | 52 | def get_sigmas(self, t, flexibility): 53 | assert 0 <= flexibility and flexibility <= 1 54 | sigmas = self.sigmas_flex[t] * flexibility + self.sigmas_inflex[t] * ( 55 | 1 - flexibility 56 | ) 57 | return sigmas 58 | -------------------------------------------------------------------------------- /salad/models/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KAIST-Visual-AI-Group/SALAD/db66b1b2ce57d8a4beb5bd610bc006ccaff43b1e/salad/models/__init__.py -------------------------------------------------------------------------------- /salad/models/base_model.py: -------------------------------------------------------------------------------- 1 | import pytorch_lightning as pl 2 | import torch 3 | import torch.nn.functional as F 4 | from salad.data.dataset import SALADDataset 5 | from salad.utils.train_util import PolyDecayScheduler 6 | 7 | 8 | class BaseModel(pl.LightningModule): 9 | def __init__( 10 | self, 11 | network, 12 | variance_schedule, 13 | **kwargs, 14 | ): 15 | super().__init__() 16 | self.save_hyperparameters(logger=False) 17 | self.net = network 18 | self.var_sched = variance_schedule 19 | 20 | def forward(self, x): 21 | return self.get_loss(x) 22 | 23 | def step(self, x, stage: str): 24 | loss = self(x) 25 | self.log( 26 | f"{stage}/loss", 27 | loss, 28 | on_step=stage == "train", 29 | prog_bar=True, 30 | ) 31 | return loss 32 | 33 | def training_step(self, batch, batch_idx): 34 | x = batch 35 | return self.step(x, "train") 36 | 37 | def add_noise(self, x, t): 38 | """ 39 | Input: 40 | x: [B,D] or [B,G,D] 41 | t: list of size B 42 | Output: 43 | x_noisy: [B,D] 44 | beta: [B] 45 | e_rand: [B,D] 46 | """ 47 | alpha_bar = self.var_sched.alpha_bars[t] 48 | beta = self.var_sched.betas[t] 49 | 50 | c0 = torch.sqrt(alpha_bar).view(-1, 1) # [B,1] 51 | c1 = torch.sqrt(1 - alpha_bar).view(-1, 1) 52 | 53 | e_rand = torch.randn_like(x) 54 | if e_rand.dim() == 3: 55 | c0 = c0.unsqueeze(1) 56 | c1 = c1.unsqueeze(1) 57 | 58 | x_noisy = c0 * x + c1 * e_rand 59 | 60 | return x_noisy, beta, e_rand 61 | 62 | def get_loss( 63 | self, 64 | x0, 65 | t=None, 66 | noisy_in=False, 67 | beta_in=None, 68 | e_rand_in=None, 69 | ): 70 | if x0.dim() == 2: 71 | B, D = x0.shape 72 | else: 73 | B, G, D = x0.shape 74 | if not noisy_in: 75 | if t is None: 76 | t = self.var_sched.uniform_sample_t(B) 77 | x_noisy, beta, e_rand = self.add_noise(x0, t) 78 | else: 79 | x_noisy = x0 80 | beta = beta_in 81 | e_rand = e_rand_in 82 | 83 | e_theta = self.net(x_noisy, beta=beta) 84 | loss = F.mse_loss(e_theta.flatten(), e_rand.flatten(), reduction="mean") 85 | return loss 86 | 87 | @torch.no_grad() 88 | def sample( 89 | self, 90 | batch_size=0, 91 | return_traj=False, 92 | ): 93 | raise NotImplementedError 94 | 95 | def validation_epoch_end(self, outputs): 96 | if self.hparams.no_run_validation: 97 | return 98 | if not self.trainer.sanity_checking: 99 | if (self.current_epoch) % 
self.hparams.validation_step == 0: 100 | self.validation() 101 | 102 | def _build_dataset(self, stage): 103 | if hasattr(self, f"data_{stage}"): 104 | return getattr(self, f"data_{stage}") 105 | if stage == "train": 106 | ds = SALADDataset(**self.hparams.dataset_kwargs) 107 | else: 108 | dataset_kwargs = self.hparams.dataset_kwargs.copy() 109 | dataset_kwargs["repeat"] = 1 110 | ds = SALADDataset(**dataset_kwargs) 111 | setattr(self, f"data_{stage}", ds) 112 | return ds 113 | 114 | def _build_dataloader(self, stage): 115 | try: 116 | ds = getattr(self, f"data_{stage}") 117 | except AttributeError: 118 | ds = self._build_dataset(stage) 119 | 120 | return torch.utils.data.DataLoader( 121 | ds, 122 | batch_size=self.hparams.batch_size, 123 | shuffle=stage == "train", 124 | drop_last=stage == "train", 125 | num_workers=4, 126 | ) 127 | 128 | def train_dataloader(self): 129 | return self._build_dataloader("train") 130 | 131 | def val_dataloader(self): 132 | return self._build_dataloader("val") 133 | 134 | def test_dataloader(self): 135 | return self._build_dataloader("test") 136 | 137 | def configure_optimizers(self): 138 | optimizer = torch.optim.Adam(self.parameters(), lr=self.hparams.lr) 139 | scheduler = PolyDecayScheduler(optimizer, self.hparams.lr, power=0.999) 140 | return [optimizer], [scheduler] 141 | 142 | # TODO: move get_wandb_logger to logutil.py 143 | def get_wandb_logger(self): 144 | for logger in self.logger: 145 | if isinstance(logger, pl.loggers.wandb.WandbLogger): 146 | return logger 147 | return None 148 | -------------------------------------------------------------------------------- /salad/models/phase1.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | from salad.models.base_model import BaseModel 4 | from salad.utils import nputil, thutil 5 | from salad.utils.spaghetti_util import clip_eigenvalues, project_eigenvectors 6 | 7 | class Phase1Model(BaseModel): 8 | def __init__(self, network, variance_schedule, **kwargs): 9 | super().__init__(network, variance_schedule, **kwargs) 10 | 11 | @torch.no_grad() 12 | def sample( 13 | self, 14 | batch_size=0, 15 | return_traj=False, 16 | ): 17 | x_T = torch.randn([batch_size, 16, 16]).to(self.device) 18 | 19 | traj = {self.var_sched.num_steps: x_T} 20 | for t in range(self.var_sched.num_steps, 0, -1): 21 | z = torch.randn_like(x_T) if t > 1 else torch.zeros_like(x_T) 22 | alpha = self.var_sched.alphas[t] 23 | alpha_bar = self.var_sched.alpha_bars[t] 24 | sigma = self.var_sched.get_sigmas(t, flexibility=0) 25 | 26 | c0 = 1.0 / torch.sqrt(alpha) 27 | c1 = (1 - alpha) / torch.sqrt(1 - alpha_bar) 28 | 29 | x_t = traj[t] 30 | 31 | beta = self.var_sched.betas[[t] * batch_size] 32 | e_theta = self.net(x_t, beta=beta) 33 | # print(e_theta.norm(-1).mean()) 34 | 35 | x_next = c0 * (x_t - c1 * e_theta) + sigma * z 36 | traj[t - 1] = x_next.detach() 37 | 38 | traj[t] = traj[t].cpu() 39 | 40 | if not return_traj: 41 | del traj[t] 42 | if return_traj: 43 | return traj 44 | else: 45 | return traj[0] 46 | 47 | def sampling_gaussians(self, num_shapes): 48 | """ 49 | Return: 50 | ldm_gaus: np.ndarray 51 | gt_gaus: np.ndarray 52 | """ 53 | ldm_gaus = self.sample(num_shapes) 54 | 55 | if self.hparams.get("global_normalization"): 56 | if not hasattr(self, "data_val"): 57 | self._build_dataset("val") 58 | if self.hparams.get("global_normalization") == "partial": 59 | ldm_gaus = self.data_val.unnormalize_global_static(ldm_gaus, slice(12, None)) 60 | elif self.hparams.get("global_normalization") 
== "all": 61 | ldm_gaus = self.data_val.unnormalize_global_static(ldm_gaus, slice(None)) 62 | 63 | ldm_gaus = clip_eigenvalues(ldm_gaus) 64 | ldm_gaus = project_eigenvectors(ldm_gaus) 65 | return ldm_gaus 66 | -------------------------------------------------------------------------------- /salad/spaghetti/.gitignore: -------------------------------------------------------------------------------- 1 | /assets/* 2 | !/assets/readme_resources/ 3 | !/assets/ui_resources/ 4 | !/assets/splits/ 5 | !/assets/mesh/ 6 | *.vtk 7 | .idea/ 8 | __pycache__/ 9 | **_ig_** 10 | -------------------------------------------------------------------------------- /salad/spaghetti/LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2022 Amir Hertz 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /salad/spaghetti/README.md: -------------------------------------------------------------------------------- 1 | 2 | ### SPAGHETTI: Editing Implicit Shapes through Part Aware Generation 3 | 4 | ![](./assets/readme_resources/teaser-01.png) 5 | 6 | 7 | ### Installation 8 | 9 | ``` 10 | git clone https://github.com/amirhertz/spaghetti && cd spaghetti 11 | conda env create -f environment.yml 12 | conda activate spaghetti 13 | ``` 14 | 15 | Install [Pytorch](https://pytorch.org/). The installation during development and testing was pytorch==1.9.0 cudatoolkit=11.1 16 | 17 | 18 | ### Demo 19 | - Download pre-trained models 20 | ``` 21 | python download_weights.py 22 | ``` 23 | - Run demo 24 | ``` 25 | python demo.py --model_name chairs_large --shape_dir samples 26 | ``` 27 | or 28 | ``` 29 | python demo.py --model_name airplanes --shape_dir samples 30 | ``` 31 | 32 | - User controls 33 | - Select shapes from the collection on bottom. 34 | - right click to select / deselect parts 35 | - Click the pencil button will toggle between selection / deselection. 36 | - Transform selected parts is similar to Blender short-keys. 37 | Pressing 'G' / 'R', will start translation / rotation mode. Toggle axis by pressing 'X' / 'Y' / 'Z'. Press 'Esc' to cancel transform. 38 | - Click the broom to reset. 
39 | 40 | 41 | ### Adding shapes to demo 42 | - From training data 43 | ``` 44 | python shape_inversion.py --model_name <model_name> --source training --mesh_path <mesh_path> --num_samples <num_samples> 45 | ``` 46 | - Random generation 47 | ``` 48 | python shape_inversion.py --model_name <model_name> --source random --mesh_path <mesh_path> --num_samples <num_samples> 49 | ``` 50 | - From existing watertight mesh: 51 | ``` 52 | python shape_inversion.py --model_name <model_name> --mesh_path <mesh_path> 53 | ``` 54 | For example, to add the provided sample chair to the existing chairs in the demo: 55 | ``` 56 | python shape_inversion.py --model_name chairs_large --mesh_path ./assets/mesh/example.obj 57 | ``` 58 | 59 | ### Training 60 | Coming soon. 61 | -------------------------------------------------------------------------------- /salad/spaghetti/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KAIST-Visual-AI-Group/SALAD/db66b1b2ce57d8a4beb5bd610bc006ccaff43b1e/salad/spaghetti/__init__.py -------------------------------------------------------------------------------- /salad/spaghetti/assets/checkpoints/.gitkeep: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KAIST-Visual-AI-Group/SALAD/db66b1b2ce57d8a4beb5bd610bc006ccaff43b1e/salad/spaghetti/assets/checkpoints/.gitkeep -------------------------------------------------------------------------------- /salad/spaghetti/assets/checkpoints/README.md: -------------------------------------------------------------------------------- 1 | Put SPAGHETTI checkpoint directories under this directory. For instance, the full path to the `spaghetti_chairs_large` directory would be `salad/salad/spaghetti/assets/checkpoints/spaghetti_chairs_large`. 2 | -------------------------------------------------------------------------------- /salad/spaghetti/assets/readme_resources/teaser-01.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KAIST-Visual-AI-Group/SALAD/db66b1b2ce57d8a4beb5bd610bc006ccaff43b1e/salad/spaghetti/assets/readme_resources/teaser-01.png -------------------------------------------------------------------------------- /salad/spaghetti/assets/ui_resources/icons-03.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KAIST-Visual-AI-Group/SALAD/db66b1b2ce57d8a4beb5bd610bc006ccaff43b1e/salad/spaghetti/assets/ui_resources/icons-03.png -------------------------------------------------------------------------------- /salad/spaghetti/assets/ui_resources/icons-04.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KAIST-Visual-AI-Group/SALAD/db66b1b2ce57d8a4beb5bd610bc006ccaff43b1e/salad/spaghetti/assets/ui_resources/icons-04.png -------------------------------------------------------------------------------- /salad/spaghetti/assets/ui_resources/icons-05.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KAIST-Visual-AI-Group/SALAD/db66b1b2ce57d8a4beb5bd610bc006ccaff43b1e/salad/spaghetti/assets/ui_resources/icons-05.png -------------------------------------------------------------------------------- /salad/spaghetti/assets/ui_resources/icons-06.png: --------------------------------------------------------------------------------
https://raw.githubusercontent.com/KAIST-Visual-AI-Group/SALAD/db66b1b2ce57d8a4beb5bd610bc006ccaff43b1e/salad/spaghetti/assets/ui_resources/icons-06.png -------------------------------------------------------------------------------- /salad/spaghetti/assets/ui_resources/icons-10.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KAIST-Visual-AI-Group/SALAD/db66b1b2ce57d8a4beb5bd610bc006ccaff43b1e/salad/spaghetti/assets/ui_resources/icons-10.png -------------------------------------------------------------------------------- /salad/spaghetti/assets/ui_resources/icons-11.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KAIST-Visual-AI-Group/SALAD/db66b1b2ce57d8a4beb5bd610bc006ccaff43b1e/salad/spaghetti/assets/ui_resources/icons-11.png -------------------------------------------------------------------------------- /salad/spaghetti/assets/ui_resources/icons-12.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KAIST-Visual-AI-Group/SALAD/db66b1b2ce57d8a4beb5bd610bc006ccaff43b1e/salad/spaghetti/assets/ui_resources/icons-12.png -------------------------------------------------------------------------------- /salad/spaghetti/assets/ui_resources/icons-13.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KAIST-Visual-AI-Group/SALAD/db66b1b2ce57d8a4beb5bd610bc006ccaff43b1e/salad/spaghetti/assets/ui_resources/icons-13.png -------------------------------------------------------------------------------- /salad/spaghetti/assets/ui_resources/icons-14.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KAIST-Visual-AI-Group/SALAD/db66b1b2ce57d8a4beb5bd610bc006ccaff43b1e/salad/spaghetti/assets/ui_resources/icons-14.png -------------------------------------------------------------------------------- /salad/spaghetti/constants.py: -------------------------------------------------------------------------------- 1 | import os 2 | from pathlib import Path 3 | import sys 4 | from salad.utils.paths import SPAGHETTI_DIR 5 | 6 | IS_WINDOWS = sys.platform == 'win32' 7 | get_trace = getattr(sys, 'gettrace', None) 8 | DEBUG = get_trace is not None and get_trace() is not None 9 | EPSILON = 1e-4 10 | DIM = 3 11 | # PROJECT_ROOT = str(Path(os.path.realpath(__file__)).parents[0]) 12 | PROJECT_ROOT = str(SPAGHETTI_DIR) 13 | DATA_ROOT = f'{PROJECT_ROOT}/assets/' 14 | CHECKPOINTS_ROOT = f'{DATA_ROOT}checkpoints/' 15 | CACHE_ROOT = f'{DATA_ROOT}cache/' 16 | UI_OUT = f'{DATA_ROOT}ui_export/' 17 | UI_RESOURCES = f'{DATA_ROOT}/ui_resources/' 18 | Shapenet_WT = f'{DATA_ROOT}/ShapeNetCore_wt/' 19 | Shapenet = f'{DATA_ROOT}/ShapeNetCore.v2/' 20 | MAX_VS = 100000 21 | MAX_GAUSIANS = 32 22 | 23 | -------------------------------------------------------------------------------- /salad/spaghetti/custom_types.py: -------------------------------------------------------------------------------- 1 | # import open3d 2 | import enum 3 | import numpy as np 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as nnf 7 | # from .constants import DEBUG 8 | from typing import Tuple, List, Union, Callable, Type, Iterator, Dict, Set, Optional, Any, Sized, Iterable 9 | from types import DynamicClassAttribute 10 | from enum import Enum, unique 11 | import torch.optim.optimizer 12 | import 
torch.utils.data 13 | 14 | # if DEBUG: 15 | # seed = 99 16 | # torch.manual_seed(seed) 17 | # np.random.seed(seed) 18 | 19 | N = type(None) 20 | V = np.array 21 | ARRAY = np.ndarray 22 | ARRAYS = Union[Tuple[ARRAY, ...], List[ARRAY]] 23 | VS = Union[Tuple[V, ...], List[V]] 24 | VN = Union[V, N] 25 | VNS = Union[VS, N] 26 | T = torch.Tensor 27 | TS = Union[Tuple[T, ...], List[T]] 28 | TN = Optional[T] 29 | TNS = Union[Tuple[TN, ...], List[TN]] 30 | TSN = Optional[TS] 31 | TA = Union[T, ARRAY] 32 | 33 | V_Mesh = Tuple[ARRAY, ARRAY] 34 | T_Mesh = Tuple[T, Optional[T]] 35 | T_Mesh_T = Union[T_Mesh, T] 36 | COLORS = Union[T, ARRAY, Tuple[int, int, int]] 37 | 38 | D = torch.device 39 | CPU = torch.device('cpu') 40 | 41 | 42 | def get_device(device_id: int) -> D: 43 | if not torch.cuda.is_available(): 44 | return CPU 45 | device_id = min(torch.cuda.device_count() - 1, device_id) 46 | return torch.device(f'cuda:{device_id}') 47 | 48 | 49 | CUDA = get_device 50 | Optimizer = torch.optim.Adam 51 | Dataset = torch.utils.data.Dataset 52 | DataLoader = torch.utils.data.DataLoader 53 | Subset = torch.utils.data.Subset 54 | -------------------------------------------------------------------------------- /salad/spaghetti/models/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KAIST-Visual-AI-Group/SALAD/db66b1b2ce57d8a4beb5bd610bc006ccaff43b1e/salad/spaghetti/models/__init__.py -------------------------------------------------------------------------------- /salad/spaghetti/models/gm_utils.py: -------------------------------------------------------------------------------- 1 | from ..custom_types import * 2 | 3 | 4 | def get_gm_support(gm, x): 5 | dim = x.shape[-1] 6 | mu, p, phi, eigen = gm 7 | sigma_det = eigen.prod(-1) 8 | eigen_inv = 1 / eigen 9 | sigma_inverse = torch.matmul(p.transpose(3, 4), p * eigen_inv[:, :, :, :, None]).squeeze(1) 10 | phi = torch.softmax(phi, dim=2) 11 | const_1 = phi / torch.sqrt((2 * np.pi) ** dim * sigma_det) 12 | distance = x[:, :, None, :] - mu 13 | mahalanobis_distance = - .5 * torch.einsum('bngd,bgdc,bngc->bng', distance, sigma_inverse, distance) 14 | const_2, _ = mahalanobis_distance.max(dim=2) # for numeric stability 15 | mahalanobis_distance -= const_2[:, :, None] 16 | support = const_1 * torch.exp(mahalanobis_distance) 17 | return support, const_2 18 | 19 | 20 | def gm_log_likelihood_loss(gms: TS, x: T, get_supports: bool = False, 21 | mask: Optional[T] = None, reduction: str = "mean") -> Union[T, Tuple[T, TS]]: 22 | 23 | batch_size, num_points, dim = x.shape 24 | support, const = get_gm_support(gms, x) 25 | probs = torch.log(support.sum(dim=2)) + const 26 | if mask is not None: 27 | probs = probs.masked_select(mask=mask.flatten()) 28 | if reduction == 'none': 29 | likelihood = probs.sum(-1) 30 | loss = - likelihood / num_points 31 | else: 32 | likelihood = probs.sum() 33 | loss = - likelihood / (probs.shape[0] * probs.shape[1]) 34 | if get_supports: 35 | return loss, support 36 | return loss 37 | 38 | 39 | def split_mesh_by_gmm(mesh: T_Mesh, gmm) -> Dict[int, T]: 40 | faces_split = {} 41 | vs, faces = mesh 42 | vs_mid_faces = vs[faces].mean(1) 43 | _, supports = gm_log_likelihood_loss(gmm, vs_mid_faces.unsqueeze(0), get_supports=True) 44 | supports = supports[0] 45 | label = supports.argmax(1) 46 | for i in range(gmm[1].shape[2]): 47 | select = label.eq(i) 48 | if select.any(): 49 | faces_split[i] = faces[select] 50 | else: 51 | faces_split[i] = None 52 | return faces_split 53 | 54 | 55 
| def flatten_gmm(gmm: TS) -> T: 56 | b, gp, g, _ = gmm[0].shape 57 | mu, p, phi, eigen = [item.view(b, gp * g, *item.shape[3:]) for item in gmm] 58 | p = p.reshape(*p.shape[:2], -1) 59 | z_gmm = torch.cat((mu, p, phi.unsqueeze(-1), eigen), dim=2) 60 | return z_gmm 61 | -------------------------------------------------------------------------------- /salad/spaghetti/models/mlp_models.py: -------------------------------------------------------------------------------- 1 | 2 | 3 | -------------------------------------------------------------------------------- /salad/spaghetti/models/models_utils.py: -------------------------------------------------------------------------------- 1 | from ..custom_types import * 2 | from abc import ABC 3 | import math 4 | 5 | 6 | def torch_no_grad(func): 7 | def wrapper(*args, **kwargs): 8 | with torch.no_grad(): 9 | result = func(*args, **kwargs) 10 | return result 11 | 12 | return wrapper 13 | 14 | 15 | class Model(nn.Module, ABC): 16 | 17 | def __init__(self): 18 | super(Model, self).__init__() 19 | self.save_model: Union[None, Callable[[nn.Module], None]] = None 20 | 21 | def save(self, **kwargs): 22 | self.save_model(self, **kwargs) 23 | 24 | 25 | class Concatenate(nn.Module): 26 | def __init__(self, dim): 27 | super(Concatenate, self).__init__() 28 | self.dim = dim 29 | 30 | def forward(self, x): 31 | return torch.cat(x, dim=self.dim) 32 | 33 | 34 | class View(nn.Module): 35 | 36 | def __init__(self, *shape): 37 | super(View, self).__init__() 38 | self.shape = shape 39 | 40 | def forward(self, x): 41 | return x.view(*self.shape) 42 | 43 | 44 | class Transpose(nn.Module): 45 | 46 | def __init__(self, dim0, dim1): 47 | super(Transpose, self).__init__() 48 | self.dim0, self.dim1 = dim0, dim1 49 | 50 | def forward(self, x): 51 | return x.transpose(self.dim0, self.dim1) 52 | 53 | 54 | class Dummy(nn.Module): 55 | 56 | def __init__(self, *args): 57 | super(Dummy, self).__init__() 58 | 59 | def forward(self, *args): 60 | return args[0] 61 | 62 | 63 | class SineLayer(nn.Module): 64 | """ 65 | From the siren repository 66 | https://colab.research.google.com/github/vsitzmann/siren/blob/master/explore_siren.ipynb 67 | """ 68 | def __init__(self, in_features, out_features, bias=True, 69 | is_first=False, omega_0=30): 70 | super().__init__() 71 | self.omega_0 = omega_0 72 | self.is_first = is_first 73 | 74 | self.in_features = in_features 75 | self.linear = nn.Linear(in_features, out_features, bias=bias) 76 | self.output_channels = out_features 77 | self.init_weights() 78 | 79 | def init_weights(self): 80 | with torch.no_grad(): 81 | if self.is_first: 82 | self.linear.weight.uniform_(-1 / self.in_features, 83 | 1 / self.in_features) 84 | else: 85 | self.linear.weight.uniform_(-np.sqrt(6 / self.in_features) / self.omega_0, 86 | np.sqrt(6 / self.in_features) / self.omega_0) 87 | 88 | def forward(self, input): 89 | return torch.sin(self.omega_0 * self.linear(input)) 90 | 91 | 92 | class MLP(nn.Module): 93 | 94 | def forward(self, x, *_): 95 | return self.net(x) 96 | 97 | def __init__(self, ch: Union[List[int], Tuple[int, ...]], act: nn.Module = nn.ReLU, 98 | weight_norm=False): 99 | super(MLP, self).__init__() 100 | layers = [] 101 | for i in range(len(ch) - 1): 102 | layers.append(nn.Linear(ch[i], ch[i + 1])) 103 | if weight_norm: 104 | layers[-1] = nn.utils.weight_norm(layers[-1]) 105 | if i < len(ch) - 2: 106 | layers.append(act(True)) 107 | self.net = nn.Sequential(*layers) 108 | 109 | 110 | class GMAttend(nn.Module): 111 | 112 | def __init__(self, hidden_dim: int): 113 
| super(GMAttend, self).__init__() 114 | self.key_dim = hidden_dim // 8 115 | self.query_w = nn.Linear(hidden_dim, self.key_dim) 116 | self.key_w = nn.Linear(hidden_dim, self.key_dim) 117 | self.value_w = nn.Linear(hidden_dim, hidden_dim) 118 | self.softmax = nn.Softmax(dim=3) 119 | self.gamma = nn.Parameter(torch.zeros(1)) 120 | self.scale = 1 / torch.sqrt(torch.tensor(self.key_dim, dtype=torch.float32)) 121 | 122 | def forward(self, x): 123 | queries = self.query_w(x) 124 | keys = self.key_w(x) 125 | vals = self.value_w(x) 126 | attention = self.softmax(torch.einsum('bgqf,bgkf->bgqk', queries, keys)) 127 | out = torch.einsum('bgvf,bgqv->bgqf', vals, attention) 128 | out = self.gamma * out + x 129 | return out 130 | 131 | 132 | def recursive_to(item, device): 133 | if type(item) is T: 134 | return item.to(device) 135 | elif type(item) is tuple or type(item) is list: 136 | return [recursive_to(item[i], device) for i in range(len(item))] 137 | return item 138 | 139 | -------------------------------------------------------------------------------- /salad/spaghetti/models/occ_loss.py: -------------------------------------------------------------------------------- 1 | from ..custom_types import * 2 | 3 | 4 | def occupancy_bce(predict: T, winding_gt: T, ignore: Optional[T] = None, *args) -> T: 5 | if winding_gt.dtype is not torch.bool: 6 | winding_gt = winding_gt.gt(0) 7 | labels = winding_gt.flatten().float() 8 | predict = predict.flatten() 9 | if ignore is not None: 10 | ignore = (~ignore).flatten().float() 11 | loss = nnf.binary_cross_entropy_with_logits(predict, labels, weight=ignore) 12 | return loss 13 | 14 | 15 | def reg_z_loss(z: T) -> T: 16 | norms = z.norm(2, 1) 17 | loss = norms.mean() 18 | return loss 19 | -------------------------------------------------------------------------------- /salad/spaghetti/models/transformer.py: -------------------------------------------------------------------------------- 1 | from ..custom_types import * 2 | 3 | 4 | class FeedForward(nn.Module): 5 | def __init__(self, in_dim, h_dim, out_d: Optional[int] = None, act=nnf.relu, dropout=0.): 6 | super().__init__() 7 | out_d = out_d if out_d is not None else in_dim 8 | self.fc1 = nn.Linear(in_dim, h_dim) 9 | self.act = act 10 | self.fc2 = nn.Linear(h_dim, out_d) 11 | self.dropout = nn.Dropout(dropout) 12 | 13 | def forward(self, x): 14 | x = self.fc1(x) 15 | x = self.act(x) 16 | x = self.dropout(x) 17 | x = self.fc2(x) 18 | x = self.dropout(x) 19 | return x 20 | 21 | 22 | class MultiHeadAttention(nn.Module): 23 | 24 | def __init__(self, dim_self, dim_ref, num_heads, bias=True, dropout=0.): 25 | super().__init__() 26 | self.num_heads = num_heads 27 | head_dim = dim_self // num_heads 28 | self.scale = head_dim ** -0.5 29 | self.to_queries = nn.Linear(dim_self, dim_self, bias=bias) 30 | self.to_keys_values = nn.Linear(dim_ref, dim_self * 2, bias=bias) 31 | self.project = nn.Linear(dim_self, dim_self) 32 | self.dropout = nn.Dropout(dropout) 33 | 34 | def forward_interpolation(self, queries: T, keys: T, values: T, alpha: T, mask: TN = None) -> TNS: 35 | attention = torch.einsum('nhd,bmhd->bnmh', queries[0], keys) * self.scale 36 | if mask is not None: 37 | if mask.dim() == 2: 38 | mask = mask.unsqueeze(1) 39 | attention = attention.masked_fill(mask.unsqueeze(3), float("-inf")) 40 | attention = attention.softmax(dim=2) 41 | attention = attention * alpha[:, None, None, None] 42 | out = torch.einsum('bnmh,bmhd->nhd', attention, values).reshape(1, attention.shape[1], -1) 43 | return out, attention 44 | 45 | def 
forward(self, x, y: Optional[T] = None, mask: Optional[T] = None, alpha: TN = None): 46 | y = y if y is not None else x 47 | b_a, n, c = x.shape 48 | b, m, d = y.shape 49 | # b n h dh 50 | queries = self.to_queries(x).reshape(b_a, n, self.num_heads, c // self.num_heads) 51 | # b m 2 h dh 52 | keys_values = self.to_keys_values(y).reshape(b, m, 2, self.num_heads, c // self.num_heads) 53 | keys, values = keys_values[:, :, 0], keys_values[:, :, 1] 54 | if alpha is not None: 55 | out, attention = self.forward_interpolation(queries, keys, values, alpha, mask) 56 | else: 57 | attention = torch.einsum('bnhd,bmhd->bnmh', queries, keys) * self.scale 58 | if mask is not None: 59 | if mask.dim() == 2: 60 | mask = mask.unsqueeze(1) 61 | attention = attention.masked_fill(mask.unsqueeze(3), float("-inf")) 62 | attention = attention.softmax(dim=2) 63 | out = torch.einsum('bnmh,bmhd->bnhd', attention, values).reshape(b, n, c) 64 | out = self.project(out) 65 | return out, attention 66 | 67 | 68 | class TransformerLayer(nn.Module): 69 | 70 | def forward_with_attention(self, x, y: Optional[T] = None, mask: Optional[T] = None, alpha: TN = None): 71 | x_, attention = self.attn(self.norm1(x), y, mask, alpha) 72 | x = x + x_ 73 | x = x + self.mlp(self.norm2(x)) 74 | return x, attention 75 | 76 | def forward(self, x, y: Optional[T] = None, mask: Optional[T] = None, alpha: TN = None): 77 | x = x + self.attn(self.norm1(x), y, mask, alpha)[0] 78 | x = x + self.mlp(self.norm2(x)) 79 | return x 80 | 81 | def __init__(self, dim_self, dim_ref, num_heads, mlp_ratio=4., bias=False, dropout=0., act=nnf.relu, 82 | norm_layer: nn.Module = nn.LayerNorm): 83 | super().__init__() 84 | self.norm1 = norm_layer(dim_self) 85 | self.attn = MultiHeadAttention(dim_self, dim_ref, num_heads, bias=bias, dropout=dropout) 86 | self.norm2 = norm_layer(dim_self) 87 | self.mlp = FeedForward(dim_self, int(dim_self * mlp_ratio), act=act, dropout=dropout) 88 | 89 | 90 | class DummyTransformer: 91 | 92 | @staticmethod 93 | def forward_with_attention(x, *_, **__): 94 | return x, [] 95 | 96 | @staticmethod 97 | def forward(x, *_, **__): 98 | return x 99 | 100 | 101 | class Transformer(nn.Module): 102 | 103 | def forward_with_attention(self, x, y: Optional[T] = None, mask: Optional[T] = None, alpha: TN = None): 104 | attentions = [] 105 | for layer in self.layers: 106 | x, att = layer.forward_with_attention(x, y, mask, alpha) 107 | attentions.append(att) 108 | return x, attentions 109 | 110 | def forward(self, x, y: TN = None, mask: TN = None, alpha: TN = None): 111 | for layer in self.layers: 112 | x = layer(x, y, mask, alpha) 113 | return x 114 | 115 | def __init__(self, dim_self: int, num_heads: int, num_layers: int, dim_ref: Optional[int] = None, 116 | mlp_ratio: float = 2., act=nnf.relu, norm_layer: nn.Module = nn.LayerNorm): 117 | super(Transformer, self).__init__() 118 | dim_ref = dim_ref if dim_ref is not None else dim_self 119 | self.layers = nn.ModuleList([TransformerLayer(dim_self, dim_ref, num_heads, mlp_ratio, act=act, 120 | norm_layer=norm_layer) for _ in range(num_layers)]) 121 | -------------------------------------------------------------------------------- /salad/spaghetti/options.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | import os 3 | import pickle 4 | import salad.spaghetti.constants as const 5 | from salad.spaghetti.custom_types import * 6 | 7 | 8 | class Options: 9 | 10 | def load(self): 11 | device = self.device 12 | if 
os.path.isfile(self.save_path): 13 | print(f'loading options from {self.save_path}') 14 | with open(self.save_path, 'rb') as f: 15 | options = pickle.load(f) 16 | options.device = device 17 | return options 18 | return self 19 | 20 | def save(self): 21 | if os.path.isdir(self.cp_folder): 22 | # self.already_saved = True 23 | with open(self.save_path, 'wb') as f: 24 | pickle.dump(self, f, pickle.HIGHEST_PROTOCOL) 25 | 26 | @property 27 | def info(self) -> str: 28 | return f'{self.model_name}_{self.tag}' 29 | 30 | @property 31 | def cp_folder(self): 32 | return f'{const.CHECKPOINTS_ROOT}{self.info}' 33 | 34 | @property 35 | def save_path(self): 36 | return f'{const.CHECKPOINTS_ROOT}{self.info}/options.pkl' 37 | 38 | def fill_args(self, args): 39 | for arg in args: 40 | if hasattr(self, arg): 41 | setattr(self, arg, args[arg]) 42 | 43 | def __init__(self, **kwargs): 44 | self.device = CUDA(0) 45 | self.tag = 'airplanes' 46 | self.dataset_name = 'shapenet_airplanes_wm_sphere_sym_train' 47 | self.epochs = 2000 48 | self.model_name = 'spaghetti' 49 | self.dim_z = 256 50 | self.pos_dim = 256 - 3 51 | self.dim_h = 512 52 | self.dim_zh = 512 53 | self.num_gaussians = 16 54 | self.min_split = 4 55 | self.max_split = 12 56 | self.gmm_weight = 1 57 | self.decomposition_network = 'transformer' 58 | self.decomposition_num_layers = 4 59 | self.num_layers = 4 60 | self.num_heads = 4 61 | self.num_layers_head = 6 62 | self.num_heads_head = 8 63 | self.head_occ_size = 5 64 | self.head_occ_type = 'skip' 65 | self.batch_size = 18 66 | self.num_samples = 2000 67 | self.dataset_size = -1 68 | self.symmetric = (True, False, False) 69 | self.data_symmetric = (True, False, False) 70 | self.lr_decay = .9 71 | self.lr_decay_every = 500 72 | self.warm_up = 2000 73 | self.reg_weight = 1e-4 74 | self.disentanglement = True 75 | self.use_encoder = True 76 | self.disentanglement_weight = 1 77 | self.augmentation_rotation = 0.3 78 | self.augmentation_scale = .2 79 | self.augmentation_translation = .3 80 | self.fill_args(kwargs) 81 | -------------------------------------------------------------------------------- /salad/spaghetti/ui/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KAIST-Visual-AI-Group/SALAD/db66b1b2ce57d8a4beb5bd610bc006ccaff43b1e/salad/spaghetti/ui/__init__.py -------------------------------------------------------------------------------- /salad/spaghetti/ui/mock_keyboard.py: -------------------------------------------------------------------------------- 1 | 2 | class Key: 3 | ctrl_l = 'control_l' 4 | 5 | 6 | class Controller: 7 | 8 | @staticmethod 9 | def press(key: str) -> str: 10 | return key 11 | 12 | @staticmethod 13 | def release(key: str) -> str: 14 | return key 15 | -------------------------------------------------------------------------------- /salad/spaghetti/utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KAIST-Visual-AI-Group/SALAD/db66b1b2ce57d8a4beb5bd610bc006ccaff43b1e/salad/spaghetti/utils/__init__.py -------------------------------------------------------------------------------- /salad/spaghetti/utils/myparse.py: -------------------------------------------------------------------------------- 1 | from ..custom_types import * 2 | import sys 3 | 4 | 5 | # minimal argparse 6 | def parse(parse_dict: Dict[str, Dict[str, Any]]): 7 | args = sys.argv[1:] 8 | args_dict = {args[i]: args[i + 1] for i in range(0, len(args), 2)} 9 | 
out_dict = {} 10 | for item in parse_dict: 11 | hint = parse_dict[item]['type'] 12 | if item in args_dict: 13 | out_dict[item[2:]] = hint(args_dict[item]) 14 | else: 15 | out_dict[item[2:]] = parse_dict[item]['default'] 16 | if 'options' in parse_dict[item] and out_dict[item[2:]] not in parse_dict[item]['options']: 17 | raise ValueError(f'Expected {item} to be in {str(parse_dict[item]["options"])}') 18 | return out_dict 19 | -------------------------------------------------------------------------------- /salad/spaghetti/utils/rotation_utils.py: -------------------------------------------------------------------------------- 1 | from .. import constants 2 | import functools 3 | from scipy.spatial.transform import Rotation 4 | from ..custom_types import * 5 | 6 | 7 | def quat_to_rot(q): 8 | shape = q.shape 9 | q = q.view(-1, 4) 10 | q_sq = 2 * q[:, :, None] * q[:, None, :] 11 | m00 = 1 - q_sq[:, 1, 1] - q_sq[:, 2, 2] 12 | m01 = q_sq[:, 0, 1] - q_sq[:, 2, 3] 13 | m02 = q_sq[:, 0, 2] + q_sq[:, 1, 3] 14 | 15 | m10 = q_sq[:, 0, 1] + q_sq[:, 2, 3] 16 | m11 = 1 - q_sq[:, 0, 0] - q_sq[:, 2, 2] 17 | m12 = q_sq[:, 1, 2] - q_sq[:, 0, 3] 18 | 19 | m20 = q_sq[:, 0, 2] - q_sq[:, 1, 3] 20 | m21 = q_sq[:, 1, 2] + q_sq[:, 0, 3] 21 | m22 = 1 - q_sq[:, 0, 0] - q_sq[:, 1, 1] 22 | r = torch.stack((m00, m01, m02, m10, m11, m12, m20, m21, m22), dim=1) 23 | r = r.view(*shape[:-1], 3, 3) 24 | return r 25 | 26 | 27 | def rot_to_quat(r): 28 | shape = r.shape 29 | r = r.view(-1, 3, 3) 30 | qw = .5 * (1 + r[:, 0, 0] + r[:, 1, 1] + r[:, 2, 2]).sqrt() 31 | qx = (r[:, 2, 1] - r[:, 1, 2]) / (4 * qw) 32 | qy = (r[:, 0, 2] - r[:, 2, 0]) / (4 * qw) 33 | qz = (r[:, 1, 0] - r[:, 0, 1]) / (4 * qw) 34 | q = torch.stack((qx, qy, qz, qw), -1) 35 | q = q.view(*shape[:-2], 4) 36 | return q 37 | 38 | 39 | @functools.lru_cache(10) 40 | def get_rotation_matrix(theta: float, axis: float, degree: bool = False) -> ARRAY: 41 | if degree: 42 | theta = theta * np.pi / 180 43 | rotate_mat = np.eye(3) 44 | rotate_mat[axis, axis] = 1 45 | cos_theta, sin_theta = np.cos(theta), np.sin(theta) 46 | rotate_mat[(axis + 1) % 3, (axis + 1) % 3] = cos_theta 47 | rotate_mat[(axis + 2) % 3, (axis + 2) % 3] = cos_theta 48 | rotate_mat[(axis + 1) % 3, (axis + 2) % 3] = sin_theta 49 | rotate_mat[(axis + 2) % 3, (axis + 1) % 3] = -sin_theta 50 | return rotate_mat 51 | 52 | 53 | def get_random_rotation(batch_size: int) -> T: 54 | r = Rotation.random(batch_size).as_matrix().astype(np.float32) 55 | 56 | return torch.from_numpy(r) 57 | 58 | 59 | def rand_bounded_rotation_matrix(cache_size: int, theta_range: float = .1): 60 | 61 | def create_cache(): 62 | # from http://www.realtimerendering.com/resources/GraphicsGems/gemsiii/rand_rotation.c 63 | with torch.no_grad(): 64 | theta, phi, z = torch.rand(cache_size, 3).split((1, 1, 1), dim=1) 65 | theta = (2 * theta - 1) * theta_range + 1 66 | theta = np.pi * theta # Rotation about the pole (Z). 67 | phi = phi * 2 * np.pi # For direction of pole deflection. 68 | z = 2 * z * theta_range # For magnitude of pole deflection. 
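# The lines below follow Arvo's "Fast Random Rotation Matrices"
# (Graphics Gems III, the C source linked above): the vector
# v = (sin(phi)*sqrt(z), cos(phi)*sqrt(z), sqrt(2 - z)) has squared norm 2,
# so (v v^T - I) is the negative of a Householder reflection and has
# determinant +1; composing it with the z-axis rotation rot_ built next
# yields a proper rotation, confined near the identity by theta_range
# (which the det assert below double-checks).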
69 | r = z.sqrt() 70 | v = torch.cat((torch.sin(phi) * r, torch.cos(phi) * r, torch.sqrt(2.0 - z)), dim=1) 71 | st = torch.sin(theta).squeeze(1) 72 | ct = torch.cos(theta).squeeze(1) 73 | rot_ = torch.zeros(cache_size, 3, 3) 74 | rot_[:, 0, 0] = ct 75 | rot_[:, 1, 1] = ct 76 | rot_[:, 0, 1] = st 77 | rot_[:, 1, 0] = -st 78 | rot_[:, 2, 2] = 1 79 | rot = (torch.einsum('ba,bd->bad', v, v) - torch.eye(3)[None, :, :]).bmm(rot_) 80 | det = rot.det() 81 | assert (det.gt(0.99) * det.lt(1.0001)).all().item() 82 | return rot 83 | 84 | def get_batch_rot(batch_size): 85 | nonlocal cache 86 | select = torch.randint(cache_size, size=(batch_size,)) 87 | return cache[select] 88 | 89 | cache = create_cache() 90 | 91 | return get_batch_rot 92 | 93 | 94 | def transform_rotation(points: T, one_axis=False, max_angle=-1): 95 | r = get_random_rotation(one_axis, max_angle) 96 | transformed = torch.einsum('nd,rd->nr', points, r) 97 | return transformed 98 | 99 | 100 | def tb_to_rot(abc: T) -> T: 101 | c, s = torch.cos(abc), torch.sin(abc) 102 | aa = c[:, 0] * c[:, 1] 103 | ab = c[:, 0] * s[:, 1] * s[:, 2] - c[:, 2] * s[:, 0] 104 | ac = s[:, 0] * s[:, 2] + c[:, 0] * c[:, 2] * s[:, 1] 105 | 106 | ba = c[:, 1] * s[:, 0] 107 | bb = c[:, 0] * c[:, 2] + s.prod(-1) 108 | bc = c[:, 2] * s[:, 0] * s[:, 1] - c[:, 0] * s[:, 2] 109 | 110 | ca = -s[:, 1] 111 | cb = c[:, 1] * s[:, 2] 112 | cc = c[:, 1] * c[:, 2] 113 | return torch.stack((aa, ab, ac, ba, bb, bc, ca, cb, cc), 1).view(-1, 3, 3) 114 | 115 | 116 | def rot_to_tb(rot: T) -> T: 117 | sy = torch.sqrt(rot[:, 0, 0] * rot[:, 0, 0] + rot[:, 1, 0] * rot[:, 1, 0]) 118 | out = torch.zeros(rot.shape[0], 3, device=rot.device) 119 | mask = sy.gt(1e-6) 120 | z = torch.atan2(rot[mask, 2, 1], rot[mask, 2, 2]) 121 | y = torch.atan2(-rot[mask, 2, 0], sy[mask]) 122 | x = torch.atan2(rot[mask, 1, 0], rot[mask, 0, 0]) 123 | out[mask] = torch.stack((x, y, z), dim=1) 124 | if not mask.all(): 125 | mask = ~mask 126 | z = torch.atan2(-rot[mask, 1, 2], rot[mask, 1, 1]) 127 | y = torch.atan2(-rot[mask, 2, 0], sy[mask]) 128 | x = torch.zeros_like(z) 129 | out[mask] = torch.stack((x, y, z), dim=1) 130 | return out 131 | 132 | 133 | def apply_gmm_affine(gmms: TS, affine: T): 134 | mu, p, phi, eigen = gmms 135 | if affine.dim() == 2: 136 | affine = affine.unsqueeze(0).expand(mu.shape[0], *affine.shape) 137 | mu_r = torch.einsum('bad, bpnd->bpna', affine, mu) 138 | p_r = torch.einsum('bad, bpncd->bpnca', affine, p) 139 | return mu_r, p_r, phi, eigen 140 | 141 | 142 | def get_reflection(reflect_axes: Tuple[bool, ...]) -> T: 143 | reflect = torch.eye(constants.DIM) 144 | for i in range(constants.DIM): 145 | if reflect_axes[i]: 146 | reflect[i, i] = -1 147 | return reflect 148 | 149 | 150 | def get_tait_bryan_from_p(p: T) -> T: 151 | # p = p.squeeze(1) 152 | shape = p.shape 153 | rot = p.reshape(-1, 3, 3).permute(0, 2, 1) 154 | angles = rot_to_tb(rot) 155 | angles = angles / np.pi 156 | angles[:, 1] = angles[:, 1] * 2 157 | angles = angles.view(*shape[:2], 3) 158 | return angles 159 | -------------------------------------------------------------------------------- /salad/utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KAIST-Visual-AI-Group/SALAD/db66b1b2ce57d8a4beb5bd610bc006ccaff43b1e/salad/utils/__init__.py -------------------------------------------------------------------------------- /salad/utils/logutil.py: --------------------------------------------------------------------------------
https://raw.githubusercontent.com/KAIST-Visual-AI-Group/SALAD/db66b1b2ce57d8a4beb5bd610bc006ccaff43b1e/salad/utils/logutil.py -------------------------------------------------------------------------------- /salad/utils/nputil.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | 4 | def np2th(ndarray): 5 | if isinstance(ndarray, torch.Tensor): 6 | return ndarray.detach().cpu() 7 | elif isinstance(ndarray, np.ndarray): 8 | return torch.tensor(ndarray).float() 9 | else: 10 | raise ValueError("Input should be either torch.Tensor or np.ndarray") 11 | -------------------------------------------------------------------------------- /salad/utils/paths.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | import os 3 | 4 | """ project top dir: salad """ 5 | PROJECT_DIR = Path(os.path.realpath(__file__)).parents[2] 6 | SALAD_DIR = PROJECT_DIR / "salad" 7 | SPAGHETTI_DIR = SALAD_DIR / "spaghetti" 8 | DATA_DIR = PROJECT_DIR / "data" 9 | 10 | if __name__ == "__main__": 11 | print(PROJECT_DIR) 12 | -------------------------------------------------------------------------------- /salad/utils/shapeglot_util.py: -------------------------------------------------------------------------------- 1 | # References: https://github.com/optas/shapeglot 2 | # https://github.com/63days/PartGlot. 3 | from typing import Dict 4 | from six.moves import cPickle 5 | import numpy as np 6 | from pandas import DataFrame 7 | def unpickle_data(file_name, python2_to_3=False): 8 | """Restore data previously saved with pickle_data(). 9 | :param file_name: file holding the pickled data. 10 | :param python2_to_3: (boolean), if True, pickle happened under python2x, unpickling under python3x. 11 | :return: a generator over the un-pickled items. 12 | Note, about implementing the python2_to_3 see 13 | https://stackoverflow.com/questions/28218466/unpickling-a-python-2-object-with-python-3 14 | """ 15 | 16 | in_file = open(file_name, "rb") 17 | if python2_to_3: 18 | size = cPickle.load(in_file, encoding="latin1") 19 | else: 20 | size = cPickle.load(in_file) 21 | 22 | for _ in range(size): 23 | if python2_to_3: 24 | yield cPickle.load(in_file, encoding="latin1") 25 | else: 26 | yield cPickle.load(in_file) 27 | in_file.close() 28 | 29 | 30 | def get_mask_of_game_data( 31 | game_data: DataFrame, 32 | word2int: Dict, 33 | only_correct: bool, 34 | only_easy_context: bool, 35 | max_seq_len: int, 36 | only_one_part_name: bool, 37 | ): 38 | """ 39 | only_correct (if True): mask will be 1 in location iff human listener predicted correctly. 40 | only_easy (if True): uses only easy context examples (more dissimilar triplet chairs) 41 | max_seq_len: drops examples with len(utterance) > max_seq_len 42 | only_one_part_name (if True): uses only utterances describing only one part in the given set. 
43 | """ 44 | mask = np.array(game_data.correct) 45 | if not only_correct: 46 | mask = np.ones_like(mask, dtype=np.bool) 47 | 48 | if only_easy_context: 49 | context_mask = np.array(game_data.context_condition == "easy", dtype=np.bool) 50 | mask = np.logical_and(mask, context_mask) 51 | 52 | short_mask = np.array( 53 | game_data.text.apply(lambda x: len(x)) <= max_seq_len, dtype=np.bool 54 | ) 55 | mask = np.logical_and(mask, short_mask) 56 | 57 | part_indicator, part_mask = get_part_indicator(game_data.text, word2int) 58 | if only_one_part_name: 59 | mask = np.logical_and(mask, part_mask) 60 | 61 | return mask, part_indicator 62 | -------------------------------------------------------------------------------- /salad/utils/sysutil.py: -------------------------------------------------------------------------------- 1 | import gc 2 | 3 | import rich 4 | from rich.tree import Tree 5 | from rich.syntax import Syntax 6 | import torch 7 | from omegaconf import DictConfig, OmegaConf 8 | 9 | 10 | def print_config( 11 | config, 12 | fields=( 13 | "trainer", 14 | "model", 15 | "callbacks", 16 | "logger", 17 | "seed", 18 | "name", 19 | ), 20 | resolve: bool = True, 21 | save_config: bool = False, 22 | ) -> None: 23 | """Prints content of DictConfig using Rich library and its tree structure. 24 | Args: 25 | config (DictConfig): Configuration composed by Hydra. 26 | fields (Sequence[str], optional): Determines which main fields from config will 27 | be printed and in what order. 28 | resolve (bool, optional): Whether to resolve reference fields of DictConfig. 29 | """ 30 | 31 | style = "dim" 32 | tree = Tree("CONFIG", style=style, guide_style=style) 33 | 34 | for field in fields: 35 | branch = tree.add(field, style=style, guide_style=style) 36 | 37 | config_section = config.get(field) 38 | branch_content = str(config_section) 39 | if isinstance(config_section, DictConfig): 40 | branch_content = OmegaConf.to_yaml(config_section, resolve=resolve) 41 | 42 | branch.add(Syntax(branch_content, "yaml")) 43 | 44 | rich.print(tree) 45 | 46 | if save_config: 47 | with open("config_tree.log", "w") as fp: 48 | rich.print(tree, file=fp) 49 | 50 | 51 | def clean_gpu(): 52 | gc.collect() 53 | torch.cuda.empty_cache() 54 | -------------------------------------------------------------------------------- /salad/utils/thutil.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | 4 | def th2np(tensor): 5 | if isinstance(tensor, np.ndarray): 6 | return tensor 7 | if isinstance(tensor, torch.Tensor): 8 | return tensor.detach().cpu().numpy() 9 | -------------------------------------------------------------------------------- /salad/utils/train_util.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | class PolyDecayScheduler(torch.optim.lr_scheduler.LambdaLR): 5 | def __init__(self, optimizer, init_lr, power=0.99, lr_end=1e-7, last_epoch=-1): 6 | def lr_lambda(step): 7 | lr = max(power**step, lr_end / init_lr) 8 | return lr 9 | 10 | super().__init__(optimizer, lr_lambda, last_epoch=last_epoch) 11 | 12 | 13 | def get_dropout_mask(shape, dropout: float, device): 14 | if dropout == 1: 15 | return torch.zeros(shape, device=device, dtype=torch.bool) 16 | elif dropout == 0: 17 | return torch.ones_like(shape, device=device, dtype=torch.bool) 18 | else: 19 | return torch.zeros(shape, device=device).float().uniform_(0, 1) > dropout 20 | 
-------------------------------------------------------------------------------- /salad/utils/visutil.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import trimesh 3 | import matplotlib.pyplot as plt 4 | from salad.utils import nputil, thutil, fresnelvis 5 | from PIL import Image 6 | 7 | 8 | def render_pointcloud( 9 | pointcloud, 10 | camPos=np.array([-2, 2, -2]), 11 | camLookat=np.array([0.0, 0.0, 0.0]), 12 | camUp=np.array([0, 1, 0]), 13 | camHeight=2, 14 | resolution=(512, 512), 15 | samples=16, 16 | cloudR=0.006, 17 | ): 18 | pointcloud = thutil.th2np(pointcloud) 19 | img = fresnelvis.renderMeshCloud( 20 | cloud=pointcloud, 21 | camPos=camPos, 22 | camLookat=camLookat, 23 | camUp=camUp, 24 | camHeight=camHeight, 25 | resolution=resolution, 26 | samples=samples, 27 | cloudR=cloudR, 28 | ) 29 | return Image.fromarray(img) 30 | 31 | 32 | def render_mesh( 33 | vert, 34 | face, 35 | camPos=np.array([-2, 2, -2]), 36 | camLookat=np.array([0, 0, 0.0]), 37 | camUp=np.array([0, 1, 0]), 38 | camHeight=2, 39 | resolution=(512, 512), 40 | samples=16, 41 | ): 42 | vert, face = list(map(lambda x: thutil.th2np(x), [vert, face])) 43 | mesh = {"vert": vert, "face": face} 44 | img = fresnelvis.renderMeshCloud( 45 | mesh=mesh, 46 | camPos=camPos, 47 | camLookat=camLookat, 48 | camUp=camUp, 49 | camHeight=camHeight, 50 | resolution=resolution, 51 | samples=samples, 52 | ) 53 | return Image.fromarray(img) 54 | 55 | 56 | def render_gaussians( 57 | gaussians, 58 | is_bspnet=False, 59 | multiplier=1.0, 60 | gaussians_colors=None, 61 | attn_map=None, 62 | camPos=np.array([-2, 2, -2]), 63 | camLookat=np.array([0.0, 0, 0]), 64 | camUp=np.array([0, 1, 0]), 65 | camHeight=2, 66 | resolution=(512, 512), 67 | samples=16, 68 | ): 69 | gaussians = thutil.th2np(gaussians) 70 | N = gaussians.shape[0] 71 | cmap = plt.get_cmap("jet") 72 | 73 | if attn_map is not None: 74 | assert N == attn_map.shape[0] 75 | vmin, vmax = attn_map.min(), attn_map.max() 76 | if vmin == vmax: 77 | normalized_attn_map = np.zeros_like(attn_map) 78 | else: 79 | normalized_attn_map = (attn_map - vmin) / (vmax - vmin) 80 | 81 | cmap = plt.get_cmap("viridis") 82 | 83 | lights = "rembrandt" 84 | camera_kwargs = dict( 85 | camPos=camPos, 86 | camLookat=camLookat, 87 | camUp=camUp, 88 | camHeight=camHeight, 89 | resolution=resolution, 90 | samples=samples, 91 | ) 92 | renderer = fresnelvis.FresnelRenderer(lights=lights, camera_kwargs=camera_kwargs) 93 | for i, g in enumerate(gaussians): 94 | if is_bspnet: 95 | mu, eival, eivec = g[:3], g[3:6], g[6:15] 96 | else: 97 | mu, eivec, eival = g[:3], g[3:12], g[13:] 98 | R = eivec.reshape(3, 3).T 99 | scale = multiplier * np.sqrt(eival) 100 | scale_transform = np.diag((*scale, 1)) 101 | rigid_transform = np.hstack((R, mu.reshape(3, 1))) 102 | rigid_transform = np.vstack((rigid_transform, [0, 0, 0, 1])) 103 | sphere = trimesh.creation.icosphere() 104 | sphere.apply_transform(scale_transform) 105 | sphere.apply_transform(rigid_transform) 106 | if attn_map is None and gaussians_colors is None: 107 | color = np.array(cmap(i / N)[:3]) 108 | elif attn_map is not None: 109 | color = np.array(cmap(normalized_attn_map[i])[:3]) 110 | else: 111 | color = gaussians_colors[i] 112 | 113 | renderer.add_mesh( 114 | sphere.vertices, sphere.faces, color=color, outline_width=None 115 | ) 116 | image = renderer.render() 117 | return Image.fromarray(image) 118 | -------------------------------------------------------------------------------- /setup.py: 
-------------------------------------------------------------------------------- 1 | from setuptools import setup 2 | 3 | setup( 4 | name="salad", 5 | version=0.1, 6 | description="SALAD: Part-Level Latent Diffusion for 3D Shape Generation and Manipulation", 7 | url="https://github.com/63days/SALAD", 8 | packages=["salad"], 9 | install_requires=[ 10 | # "pandas", 11 | # "torch", 12 | # "h5py", 13 | # "Pillow", 14 | # "numpy", 15 | # "matplotlib", 16 | # "six", 17 | # "nltk", 18 | # "pytorch-lightning==1.5.5", 19 | # "hydra-core==1.1.0", 20 | # "omegaconf==2.1.1" 21 | ], 22 | zip_safe=False, 23 | ) 24 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import hydra 2 | import pytorch_lightning as pl 3 | from salad.utils import sysutil 4 | @hydra.main(config_path="configs", config_name="train.yaml") 5 | def main(config): 6 | pl.seed_everything(63) 7 | sysutil.print_config(config, ("callbacks", "logger", "paths", "debug", "data", "trainer", "model")) 8 | model = hydra.utils.instantiate(config.model, _recursive_=True) 9 | 10 | callbacks = [] 11 | if config.get("callbacks"): 12 | for cb_name, cb_conf in config.callbacks.items(): 13 | if config.get("debug") and cb_name == "model_checkpoint": 14 | continue 15 | callbacks.append(hydra.utils.instantiate(cb_conf)) 16 | 17 | logger = [] 18 | if config.get("logger"): 19 | for lg_name, lg_conf in config.logger.items(): 20 | if config.get("debug") and lg_name == "wandb": 21 | continue 22 | logger.append(hydra.utils.instantiate(lg_conf)) 23 | trainer = hydra.utils.instantiate( 24 | config.trainer, 25 | callbacks=callbacks, 26 | logger=logger if len(logger) != 0 else False, 27 | _convert_="partial", 28 | log_every_n_steps=200, 29 | resume_from_checkpoint=config.resume_from_checkpoint 30 | ) 31 | 32 | trainer.fit(model) 33 | 34 | if __name__ == "__main__": 35 | main() 36 | --------------------------------------------------------------------------------
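A minimal sketch of composing the same Hydra config stack programmatically (e.g., from a notebook) instead of going through the `train.py` CLI above. It assumes the working directory is the repository root; `initialize`/`compose` are standard Hydra 1.1 APIs, and the composed tree may still expect the datasets and checkpoints that training itself requires:

```python
import hydra
from hydra import compose, initialize
import pytorch_lightning as pl

# Compose the same config tree that train.py's @hydra.main decorator builds.
with initialize(config_path="configs"):
    config = compose(config_name="train.yaml")

pl.seed_everything(63)
# Instantiate the model exactly as train.py does.
model = hydra.utils.instantiate(config.model, _recursive_=True)
```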