├── .gitignore
├── README.md
├── configs
│   ├── callbacks
│   │   └── default.yaml
│   ├── category
│   │   ├── airplane.yaml
│   │   ├── chair.yaml
│   │   └── table.yaml
│   ├── debug
│   │   └── default.yaml
│   ├── hydra
│   │   └── default.yaml
│   ├── logger
│   │   ├── multi_loggers.yaml
│   │   └── tensorboard.yaml
│   ├── model
│   │   ├── default.yaml
│   │   ├── lang_default.yaml
│   │   ├── lang_phase1.yaml
│   │   ├── lang_phase2.yaml
│   │   ├── phase1.yaml
│   │   └── phase2.yaml
│   ├── paths
│   │   └── default.yaml
│   ├── train.yaml
│   └── trainer
│       └── default.yaml
├── docs
│   └── images
│       └── salad_teaser.png
├── environment.yml
├── notebooks
│   ├── part_completion_demo.ipynb
│   ├── text_generation_demo.ipynb
│   ├── uncleaned_demo
│   │   ├── part_mixing.py
│   │   ├── pvd
│   │   │   ├── __init__.py
│   │   │   ├── configs
│   │   │   │   ├── callbacks
│   │   │   │   │   ├── default.yaml
│   │   │   │   │   └── fast.yaml
│   │   │   │   ├── debug
│   │   │   │   │   ├── default.yaml
│   │   │   │   │   ├── fast.yaml
│   │   │   │   │   └── super_fast.yaml
│   │   │   │   ├── experiment
│   │   │   │   │   ├── airplane_phase1.yaml
│   │   │   │   │   ├── airplane_phase1_sym.yaml
│   │   │   │   │   ├── airplane_phase2.yaml
│   │   │   │   │   ├── airplane_single_phase.yaml
│   │   │   │   │   ├── airplane_za.yaml
│   │   │   │   │   ├── free_guidance.yaml
│   │   │   │   │   ├── lang_phase1.yaml
│   │   │   │   │   ├── lang_phase1_zc.yaml
│   │   │   │   │   ├── lang_phase2.yaml
│   │   │   │   │   ├── lang_single_phase.yaml
│   │   │   │   │   ├── phase1-airplane.yaml
│   │   │   │   │   ├── phase1-pos_enc-airplane.yaml
│   │   │   │   │   ├── phase1-pos_enc.yaml
│   │   │   │   │   ├── phase1-scaled-vectors-sym-airplane.yaml
│   │   │   │   │   ├── phase1-scaled-vectors-sym.yaml
│   │   │   │   │   ├── phase1-scaled-vectors.yaml
│   │   │   │   │   ├── phase1-sym-airplane.yaml
│   │   │   │   │   ├── phase1-sym.yaml
│   │   │   │   │   ├── phase1-table.yaml
│   │   │   │   │   ├── phase1.yaml
│   │   │   │   │   ├── phase1_sym.yaml
│   │   │   │   │   ├── phase2-airplane.yaml
│   │   │   │   │   ├── phase2-scaled-vectors.yaml
│   │   │   │   │   ├── phase2-table.yaml
│   │   │   │   │   ├── phase2.yaml
│   │   │   │   │   ├── reverse_phase.yaml
│   │   │   │   │   ├── simclr-chair.yaml
│   │   │   │   │   ├── single_phase-airplane.yaml
│   │   │   │   │   ├── single_phase-pos_enc-airplane.yaml
│   │   │   │   │   ├── single_phase-pos_enc.yaml
│   │   │   │   │   ├── single_phase-table.yaml
│   │   │   │   │   ├── single_phase.yaml
│   │   │   │   │   ├── single_phase_zb-airplane.yaml
│   │   │   │   │   ├── single_phase_zb.yaml
│   │   │   │   │   ├── spaghetti_condition_self_attention.yaml
│   │   │   │   │   ├── za-airplane.yaml
│   │   │   │   │   └── za.yaml
│   │   │   │   ├── hydra
│   │   │   │   │   └── default.yaml
│   │   │   │   ├── logger
│   │   │   │   │   ├── multi_loggers.yaml
│   │   │   │   │   └── tensorboard.yaml
│   │   │   │   ├── model
│   │   │   │   │   ├── ldm_big_phase1.yaml
│   │   │   │   │   ├── ldm_big_phase1_language.yaml
│   │   │   │   │   ├── ldm_big_phase1_pos_enc.yaml
│   │   │   │   │   ├── ldm_big_phase2.yaml
│   │   │   │   │   ├── ldm_big_phase2_language.yaml
│   │   │   │   │   ├── ldm_big_single_phase-pos_enc.yaml
│   │   │   │   │   ├── ldm_big_single_phase.yaml
│   │   │   │   │   ├── ldm_big_single_phase_language.yaml
│   │   │   │   │   ├── ldm_big_single_phase_zb.yaml
│   │   │   │   │   ├── ldm_big_za.yaml
│   │   │   │   │   ├── ldm_default.yaml
│   │   │   │   │   ├── ldm_lang_default.yaml
│   │   │   │   │   ├── simclr.yaml
│   │   │   │   │   └── vae.yaml
│   │   │   │   ├── paths
│   │   │   │   │   └── default.yaml
│   │   │   │   ├── simclr.yaml
│   │   │   │   ├── train.yaml
│   │   │   │   └── trainer
│   │   │   │       └── default.yaml
│   │   │   ├── dataset.py
│   │   │   ├── diffusion
│   │   │   │   ├── __init__.py
│   │   │   │   ├── baselines
│   │   │   │   │   └── single_phase.py
│   │   │   │   ├── common.py
│   │   │   │   ├── language_cond.py
│   │   │   │   ├── ldm.py
│   │   │   │   ├── network.py
│   │   │   │   ├── phase1.py
│   │   │   │   ├── phase1_sym.py
│   │   │   │   ├── reverse_model.py
│   │   │   │   └── single_phase.py
│   │   │   ├── embedding
│   │   │   │   └── model.py
│   │   │   ├── misc
│   │   │   │   └── extract_gmms.py
│   │   │   ├── modules
│   │   │   │   ├── __init__.py
│   │   │   │   ├── pointnet2_ops_lib
│   │   │   │   │   ├── MANIFEST.in
│   │   │   │   │   ├── pointnet2_ops
│   │   │   │   │   │   ├── __init__.py
│   │   │   │   │   │   ├── _ext-src
│   │   │   │   │   │   │   ├── include
│   │   │   │   │   │   │   │   ├── ball_query.h
│   │   │   │   │   │   │   │   ├── cuda_utils.h
│   │   │   │   │   │   │   │   ├── group_points.h
│   │   │   │   │   │   │   │   ├── interpolate.h
│   │   │   │   │   │   │   │   ├── sampling.h
│   │   │   │   │   │   │   │   └── utils.h
│   │   │   │   │   │   │   └── src
│   │   │   │   │   │   │       ├── ball_query.cpp
│   │   │   │   │   │   │       ├── ball_query_gpu.cu
│   │   │   │   │   │   │       ├── bindings.cpp
│   │   │   │   │   │   │       ├── group_points.cpp
│   │   │   │   │   │   │       ├── group_points_gpu.cu
│   │   │   │   │   │   │       ├── interpolate.cpp
│   │   │   │   │   │   │       ├── interpolate_gpu.cu
│   │   │   │   │   │   │       ├── sampling.cpp
│   │   │   │   │   │   │       └── sampling_gpu.cu
│   │   │   │   │   │   ├── _version.py
│   │   │   │   │   │   ├── pointnet2_modules.py
│   │   │   │   │   │   └── pointnet2_utils.py
│   │   │   │   │   └── setup.py
│   │   │   │   ├── simple_module.py
│   │   │   │   └── transformer.py
│   │   │   ├── spaghetti_github
│   │   │   │   ├── __init__.py
│   │   │   │   ├── constants.py
│   │   │   │   ├── custom_types.py
│   │   │   │   ├── models
│   │   │   │   │   ├── __init__.py
│   │   │   │   │   ├── gm_utils.py
│   │   │   │   │   ├── mlp_models.py
│   │   │   │   │   ├── models_utils.py
│   │   │   │   │   ├── occ_gmm.py
│   │   │   │   │   ├── occ_loss.py
│   │   │   │   │   └── transformer.py
│   │   │   │   ├── options.py
│   │   │   │   ├── ui
│   │   │   │   │   ├── __init__.py
│   │   │   │   │   ├── gaussian_status.py
│   │   │   │   │   ├── inference_processing.py
│   │   │   │   │   ├── interactors.py
│   │   │   │   │   ├── mock_keyboard.py
│   │   │   │   │   ├── occ_inference.py
│   │   │   │   │   ├── ui_controllers.py
│   │   │   │   │   ├── ui_main.py
│   │   │   │   │   └── ui_utils.py
│   │   │   │   └── utils
│   │   │   │       ├── files_utils.py
│   │   │   │       ├── mcubes_meshing.py
│   │   │   │       ├── mesh_utils.py
│   │   │   │       ├── myparse.py
│   │   │   │       ├── rotation_utils.py
│   │   │   │       └── train_utils.py
│   │   │   ├── train_clr.py
│   │   │   ├── train_ldm.py
│   │   │   ├── train_vae.py
│   │   │   ├── utils
│   │   │   │   ├── __init__.py
│   │   │   │   ├── contrastive_utils.py
│   │   │   │   ├── ldm_utils.py
│   │   │   │   ├── sde_utils.py
│   │   │   │   ├── spaghetti_utils.py
│   │   │   │   ├── spaghetti_v2_utils.py
│   │   │   │   ├── train_utils.py
│   │   │   │   └── transform.py
│   │   │   └── vae
│   │   │       ├── __init__.py
│   │   │       ├── ae.py
│   │   │       ├── pointnet2_utils.py
│   │   │       └── vae.py
│   │   ├── sdedit.py
│   │   ├── shape_editing.py
│   │   └── utils
│   │       ├── __init__.py
│   │       ├── common.py
│   │       ├── gmm_edit.py
│   │       ├── io_utils.py
│   │       ├── mixing.py
│   │       └── vis.py
│   └── unconditional_generation_demo.ipynb
├── salad
│   ├── data
│   │   └── dataset.py
│   ├── model_components
│   │   ├── lstm.py
│   │   ├── network.py
│   │   ├── simple_module.py
│   │   ├── transformer.py
│   │   └── variance_schedule.py
│   ├── models
│   │   ├── __init__.py
│   │   ├── base_model.py
│   │   ├── language_phase1.py
│   │   ├── language_phase2.py
│   │   ├── phase1.py
│   │   └── phase2.py
│   ├── spaghetti
│   │   ├── .gitignore
│   │   ├── LICENSE
│   │   ├── README.md
│   │   ├── __init__.py
│   │   ├── assets
│   │   │   ├── checkpoints
│   │   │   │   ├── .gitkeep
│   │   │   │   └── README.md
│   │   │   ├── mesh
│   │   │   │   └── example.obj
│   │   │   ├── readme_resources
│   │   │   │   └── teaser-01.png
│   │   │   └── ui_resources
│   │   │       ├── icons-03.png
│   │   │       ├── icons-04.png
│   │   │       ├── icons-05.png
│   │   │       ├── icons-06.png
│   │   │       ├── icons-10.png
│   │   │       ├── icons-11.png
│   │   │       ├── icons-12.png
│   │   │       ├── icons-13.png
│   │   │       └── icons-14.png
│   │   ├── constants.py
│   │   ├── custom_types.py
│   │   ├── models
│   │   │   ├── __init__.py
│   │   │   ├── gm_utils.py
│   │   │   ├── mlp_models.py
│   │   │   ├── models_utils.py
│   │   │   ├── occ_gmm.py
│   │   │   ├── occ_loss.py
│   │   │   └── transformer.py
│   │   ├── options.py
│   │   ├── ui
│   │   │   ├── __init__.py
│   │   │   ├── gaussian_status.py
│   │   │   ├── inference_processing.py
│   │   │   ├── interactors.py
│   │   │   ├── mock_keyboard.py
│   │   │   ├── occ_inference.py
│   │   │   ├── ui_controllers.py
│   │   │   ├── ui_main.py
│   │   │   └── ui_utils.py
│   │   └── utils
│   │       ├── __init__.py
│   │       ├── files_utils.py
│   │       ├── mcubes_meshing.py
│   │       ├── mesh_utils.py
│   │       ├── myparse.py
│   │       ├── rotation_utils.py
│   │       └── train_utils.py
│   └── utils
│       ├── __init__.py
│       ├── fresnelvis.py
│       ├── imageutil.py
│       ├── logutil.py
│       ├── meshutil.py
│       ├── nputil.py
│       ├── paths.py
│       ├── shapeglot_util.py
│       ├── spaghetti_util.py
│       ├── sysutil.py
│       ├── thutil.py
│       ├── train_util.py
│       └── visutil.py
├── setup.py
└── train.py
/.gitignore:
--------------------------------------------------------------------------------
salad.egg-info/
checkpoints/
results/
data/
**/__pycache__/
.DS_Store
**/*.pyc
**/.ipynb_checkpoints
!salad/data/
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# 🥗 SALAD: Part-Level Latent Diffusion for 3D Shape Generation and Manipulation, ICCV 2023



[**Arxiv**](https://arxiv.org/abs/2303.12236) | [**Project Page**](https://salad3d.github.io/)

[Juil Koo\*](https://63days.github.io/), [Seungwoo Yoo\*](https://dvelopery0115.github.io/), [Minh Hieu Nguyen\*](https://min-hieu.github.io/), [Minhyuk Sung](https://mhsung.github.io/)
\* denotes equal contribution.

### 🎉 This paper was accepted to ICCV 2023!

# Introduction
This repository contains the official implementation of 🥗 **SALAD: Part-Level Latent Diffusion for 3D Shape Generation and Manipulation**.
**SALAD** is a 3D shape diffusion model that generates high-quality shapes and shows zero-shot capability in various shape manipulation applications. More results can be viewed with a 3D viewer on the [project webpage](https://salad3d.github.io).

[//]: # (### Abstract)
> We present a cascaded diffusion model based on a part-level implicit 3D representation. Our model achieves state-of-the-art generation quality and also enables part-level shape editing and manipulation without any additional training in a conditional setup. Diffusion models have demonstrated impressive capabilities in data generation as well as zero-shot completion and editing via a guided reverse process. Recent research on 3D diffusion models has focused on improving their generation capabilities with various data representations, while the absence of structural information has limited their capability in completion and editing tasks. We thus propose our novel diffusion model using a part-level implicit representation. To effectively learn diffusion with high-dimensional embedding vectors of parts, we propose a cascaded framework, learning diffusion first on a low-dimensional subspace encoding extrinsic parameters of parts and then on the other high-dimensional subspace encoding intrinsic attributes. In the experiments, we demonstrate that our method outperforms previous ones in both generation and part-level completion and manipulation tasks.

# UPDATES

- [x] Training code of SALAD with the chair class. (May 8th, 2023)
- [x] Training code of *text-conditioned* SALAD. (May 8th, 2023)
- [x] Demo of *text-guided* shape generation. (May 8th, 2023)
- [x] Code for generation of more classes. (Sep. 17th, 2023)
- [x] Demo of part completion.
- [ ] Demo of part mixing and refinement.
- [ ] Demo of text-guided part completion.


# Get Started

## Installation

We have tested the code with Python 3.9, CUDA 11.3, and PyTorch 1.12.1+cu113.

```
git clone https://github.com/63days/salad/
cd salad
conda env create -f environment.yml
conda activate salad
pip install -e .
```

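Note that `pip install -e .` installs the repository as an editable package via `setup.py`, so `import salad` resolves to this checkout.
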
## Data and Model Checkpoint
We provide ready-to-use data and pre-trained SALAD and SPAGHETTI checkpoints [here](http://onedrive.live.com/?cid=0ce615b143fc4bdc&id=0CE615B143FC4BDC%212507&resid=0CE615B143FC4BDC%212507&ithint=folder&migratedtospo=true&redeem=aHR0cHM6Ly8xZHJ2Lm1zL2YvYy8wY2U2MTViMTQzZmM0YmRjL1F0eExfRU94RmVZZ2dBekxDUUFBQUFBQU42dzI0U1RrV2VpRXd3&v=validatepermission).
Unzip the files and put them under the corresponding directories. SPAGHETTI checkpoint directories should be placed under `SALAD/salad/spaghetti/assets/checkpoints/`.


# Training
## Unconditional Generation

You can train models for unconditional generation of airplanes, chairs, or tables.
```
python train.py model={phase1, phase2} category={airplane, chair, table}
```

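The braces denote Hydra option groups: pick one value from each, e.g. `python train.py model=phase1 category=chair` trains the first-phase model on chairs. The phase 1 and phase 2 models are trained separately.
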
## Text-Guided Generation

For text-guided generation, train a first-phase model and a second-phase model:
```
python train.py model={lang_phase1, lang_phase2}
```

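For example, run the command once with `model=lang_phase1` and once with `model=lang_phase2` to obtain both models.
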
# Demo
## Unconditional Shape Generation

We provide demo code for unconditional generation of airplanes, chairs, and tables at `notebooks/unconditional_generation_demo.ipynb`.

## Text-Guided Shape Generation

We also provide a text-guided shape generation demo at `notebooks/text_generation_demo.ipynb`.


# Citation
If you find our work useful, please consider citing:

```bibtex
@article{koo2023salad,
  title={{SALAD}: Part-Level Latent Diffusion for 3D Shape Generation and Manipulation},
  author={Koo, Juil and Yoo, Seungwoo and Nguyen, Minh Hieu and Sung, Minhyuk},
  year={2023},
  journal={arXiv preprint arXiv:2303.12236},
}
```

# Reference
We borrow the [SPAGHETTI](https://amirhertz.github.io/spaghetti/) code from its [official implementation](https://github.com/amirhertz/spaghetti).
--------------------------------------------------------------------------------
/configs/callbacks/default.yaml:
--------------------------------------------------------------------------------
model_summary:
  _target_: pytorch_lightning.callbacks.RichModelSummary
  max_depth: 1

rich_progress_bar:
  _target_: pytorch_lightning.callbacks.RichProgressBar

model_checkpoint:
  _target_: pytorch_lightning.callbacks.ModelCheckpoint
  monitor: "train/loss"
  mode: "min"
  save_top_k: 1
  save_last: true
  verbose: true
  dirpath: "checkpoints/"
  filename: "epoch={epoch:02d}-train_loss={train/loss:.4f}"
  #every_n_epochs: 200
  auto_insert_metric_name: false

#lr_monitor:
#  _target_: pytorch_lightning.callbacks.LearningRateMonitor
#  logging_interval: "step"
--------------------------------------------------------------------------------
/configs/category/airplane.yaml:
--------------------------------------------------------------------------------
latent_path: spaghetti_airplane_latents.hdf5 # file path from the salad.utils.paths.DATA_DIR directory.
spaghetti_tag: "airplanes"
--------------------------------------------------------------------------------
/configs/category/chair.yaml:
--------------------------------------------------------------------------------
latent_path: spaghetti_chair_latents.hdf5 # file path from the salad.utils.paths.DATA_DIR directory.
spaghetti_tag: "chairs_large"
--------------------------------------------------------------------------------
/configs/category/table.yaml:
--------------------------------------------------------------------------------
latent_path: spaghetti_table_latents.hdf5 # file path from the salad.utils.paths.DATA_DIR directory.
spaghetti_tag: "tables"
--------------------------------------------------------------------------------
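The `latent_path` entries above are resolved against `salad.utils.paths.DATA_DIR`, as their comments note. A minimal sketch of the lookup, assuming `DATA_DIR` is an exported path constant in `salad/utils/paths.py`:

```python
from pathlib import Path

# Assumption: DATA_DIR is the constant referenced by the config comments above.
from salad.utils.paths import DATA_DIR

latent_file = Path(DATA_DIR) / "spaghetti_chair_latents.hdf5"
print(latent_file.exists())  # True once the released data has been unzipped here
```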
/configs/debug/default.yaml:
--------------------------------------------------------------------------------
# @package _global_
#
callbacks: null
logger: null

hydra:
  job_logging:
    root:
      level: DEBUG

trainer:
  max_epochs: 1
--------------------------------------------------------------------------------
/configs/hydra/default.yaml:
--------------------------------------------------------------------------------
run:
  dir: ${paths.save_dir}

sweep:
  dir: ${paths.save_dir}
  subdir: ${hydra.job.num}
--------------------------------------------------------------------------------
/configs/logger/multi_loggers.yaml:
--------------------------------------------------------------------------------
wandb:
  _target_: pytorch_lightning.loggers.wandb.WandbLogger
  project: "spaghetti-ldm"
  name: ${name}
  save_dir: "."
  entity: "geometry"

tensorboard:
  _target_: pytorch_lightning.loggers.tensorboard.TensorBoardLogger
  save_dir: "tensorboard/"
  name: null
  version: ${name}
  log_graph: false
  default_hp_metric: true
  prefix: ""
--------------------------------------------------------------------------------
/configs/logger/tensorboard.yaml:
--------------------------------------------------------------------------------
tensorboard:
  _target_: pytorch_lightning.loggers.tensorboard.TensorBoardLogger
  save_dir: "tensorboard/"
  name: null
  #version: ${name}
  log_graph: true
  default_hp_metric: false
  prefix: ""
--------------------------------------------------------------------------------
/configs/model/default.yaml:
--------------------------------------------------------------------------------
network: null

variance_schedule:
  _target_: salad.model_components.variance_schedule.VarianceSchedule
  num_steps: &time_steps 1000
  beta_1: 1e-4
  beta_T: 0.05
  mode: linear

# optimizer
lr: 1e-4
batch_size: ${batch_size}

# dataset
dataset_kwargs:
  data_path: ${category.latent_path}
  repeat: 3
  data_keys: ["s_j_affine"]

# model
num_timesteps: *time_steps
faster: true
validation_step: 10
no_run_validation: false
spaghetti_tag: ${category.spaghetti_tag}
--------------------------------------------------------------------------------
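For reference, the linear schedule configured above behaves as follows. This is a minimal standalone sketch of a standard DDPM linear beta schedule using the config's values, not the repo's `VarianceSchedule` class itself:

```python
import torch

# Linear beta schedule with the values from the config above:
# num_steps=1000, beta_1=1e-4, beta_T=0.05.
num_steps, beta_1, beta_T = 1000, 1e-4, 0.05
betas = torch.linspace(beta_1, beta_T, num_steps)
alphas = 1.0 - betas
alpha_bars = torch.cumprod(alphas, dim=0)  # cumulative product of (1 - beta_t)

# By the final step the signal is almost fully destroyed, so x_T is
# essentially pure noise, as DDPM training assumes.
print(alpha_bars[-1].item())  # on the order of 1e-11
```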
/configs/model/lang_default.yaml:
--------------------------------------------------------------------------------
defaults:
  - default.yaml

# dataset
dataset_kwargs:
  data_path: ${category.latent_path}
  repeat: 1
  only_easy_context: false
  global_normalization: ${model.global_normalization}

# model
global_normalization: false
text_encoder_freeze: false
use_lstm: true
classifier_free_guidance: true
conditioning_dropout_prob: 0.2
--------------------------------------------------------------------------------
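Here `classifier_free_guidance: true` with `conditioning_dropout_prob: 0.2` means the text condition is dropped for a random 20% of training samples, so a single network learns both conditional and unconditional denoising. A minimal sketch of that mechanism (an illustration of the technique, not the repo's exact code; the tensor layout is assumed):

```python
import torch

def drop_condition(context: torch.Tensor, p: float = 0.2) -> torch.Tensor:
    """Zero the conditioning features for a random subset of the batch.

    context: (B, ..., D) text features; p: per-sample drop probability.
    """
    keep = torch.rand(context.shape[0], device=context.device) > p
    # Broadcast the per-sample keep mask over all remaining dimensions.
    mask = keep.view(-1, *([1] * (context.dim() - 1))).to(context.dtype)
    return context * mask
```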
/configs/model/lang_phase1.yaml:
--------------------------------------------------------------------------------
defaults:
  - lang_default.yaml

_target_: salad.models.language_phase1.LangPhase1Model
network:
  _target_: salad.model_components.network.CondDiffNetwork
  input_dim: 16
  residual: true
  context_dim: 768
  context_embedding_dim: 1024
  embedding_dim: 512
  encoder_use_time: false
  encoder_type: pointwise
  decoder_type: transformer_encoder
  enc_num_layers: 2
  dec_num_layers: 6
  use_timestep_embedder: true
  timestep_embedder_dim: 128

global_normalization: partial

dataset_kwargs:
  data_keys: ["g_js_affine"]
--------------------------------------------------------------------------------
/configs/model/lang_phase2.yaml:
--------------------------------------------------------------------------------
defaults:
  - lang_default.yaml

_target_: salad.models.language_phase2.LangPhase2Model
network:
  _target_: salad.model_components.network.CondDiffNetwork
  input_dim: 512
  residual: true
  context_dim: 784 # concat of 768-dim language features and 16-dim per-part Gaussians.
  context_embedding_dim: 1024
  embedding_dim: 512
  encoder_use_time: false
  encoder_type: transformer
  decoder_type: transformer_encoder
  enc_num_layers: 6
  dec_num_layers: 6
  use_timestep_embedder: true
  timestep_embedder_dim: 128

dataset_kwargs:
  data_keys: ["s_j_affine", "g_js_affine"]
--------------------------------------------------------------------------------
/configs/model/phase1.yaml:
--------------------------------------------------------------------------------
defaults:
  - default.yaml

_target_: salad.models.phase1.Phase1Model

network:
  _target_: salad.model_components.network.UnCondDiffNetwork
  input_dim: 16
  embedding_dim: 512
  num_heads: 4
  use_timestep_embedder: true
  timestep_embedder_dim: 128
  enc_num_layers: 6
  residual: true
  encoder_type: transformer
  attn_dropout: 0.0

global_normalization: partial # normalize pi, eigenvalues.

dataset_kwargs:
  data_keys: ["g_js_affine"]
  global_normalization: ${model.global_normalization}
--------------------------------------------------------------------------------
/configs/model/phase2.yaml:
--------------------------------------------------------------------------------
defaults:
  - default.yaml

_target_: salad.models.phase2.Phase2Model
network:
  _target_: salad.model_components.network.CondDiffNetwork
  input_dim: 512
  residual: true
  context_dim: 16 # gaussian condition dim.
  context_embedding_dim: 512
  embedding_dim: 512
  encoder_use_time: false
  encoder_type: transformer
  decoder_type: transformer_encoder # we don't use cross-attention.
  enc_num_layers: 6
  dec_num_layers: 6
  use_timestep_embedder: true
  timestep_embedder_dim: 128

dataset_kwargs:
  data_keys: ["s_j_affine", "g_js_affine"]
--------------------------------------------------------------------------------
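Taken together, `phase1.yaml` and `phase2.yaml` mirror the cascade described in the README: the phase 1 network denoises the 16-dimensional per-part Gaussians (extrinsics), and the phase 2 network denoises the 512-dimensional intrinsic latents conditioned on those Gaussians (`context_dim: 16`).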
/configs/paths/default.yaml:
--------------------------------------------------------------------------------
work_dir: ${hydra:runtime.cwd}
save_dir: ${hydra:runtime.cwd}/results/${name}/${now:%m%d_%H%M%S}
--------------------------------------------------------------------------------
/configs/train.yaml:
--------------------------------------------------------------------------------
# @package _global_

defaults:
  - _self_
  - category: chair.yaml
  - model: phase1.yaml
  - callbacks: default.yaml
  - logger: tensorboard.yaml # multi_loggers.yaml
  - trainer: default.yaml
  - paths: default.yaml
  - hydra: default.yaml
  - debug: null

name: default
seed: 63

gpus: [0]
epochs: 5000
batch_size: 32

resume_from_checkpoint: null
--------------------------------------------------------------------------------
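`train.py` composes this file with the groups listed under `defaults` and forwards the result to PyTorch Lightning. A minimal sketch of how such a config is typically consumed (an approximation for orientation, not the verbatim contents of `train.py`):

```python
import hydra
import pytorch_lightning as pl
from omegaconf import DictConfig

@hydra.main(config_path="configs", config_name="train")
def main(cfg: DictConfig) -> None:
    pl.seed_everything(cfg.seed)
    # The _target_ fields in the composed config tell Hydra which classes to build.
    model = hydra.utils.instantiate(cfg.model)      # e.g. salad.models.phase1.Phase1Model
    trainer = hydra.utils.instantiate(cfg.trainer)  # pytorch_lightning.Trainer
    trainer.fit(model, ckpt_path=cfg.resume_from_checkpoint)

if __name__ == "__main__":
    main()
```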
/configs/trainer/default.yaml:
--------------------------------------------------------------------------------
_target_: pytorch_lightning.Trainer
accelerator: gpu
devices: ${gpus}
max_epochs: ${epochs}
--------------------------------------------------------------------------------
/docs/images/salad_teaser.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/KAIST-Visual-AI-Group/SALAD/db66b1b2ce57d8a4beb5bd610bc006ccaff43b1e/docs/images/salad_teaser.png
--------------------------------------------------------------------------------
/notebooks/uncleaned_demo/pvd/__init__.py:
--------------------------------------------------------------------------------
from pathlib import Path

pvd_chair_path = Path("/home/juil/projects/3D_CRISPR/baselines/PVD/output/test_generation/chair/syn/samples.pth")
pvd_airplane_path = Path("/home/juil/projects/3D_CRISPR/baselines/PVD/output/test_generation/airplane/syn/samples.pth")
dpm_chair_path = Path("/home/hieuristics/docker_home/project/baselines/evaluation_wavelet/metrics/data/chair/dpm/o3d_all_pointclouds.hdf5")
dpm_airplane_path = Path("/home/hieuristics/docker_home/project/baselines/evaluation_wavelet/metrics/data/airplane/dpm/o3d_all_pointclouds.hdf5")
lion_chair_path = Path("/home/juil/projects/3D_CRISPR/baselines/LION/data_pretrained/LION_chair.pt")
lion_airplane_path = Path("/home/juil/projects/3D_CRISPR/baselines/LION/data_pretrained/LION_airplane.pt")
voxel_gan_chair_path = Path("/home/juil/projects/3D_CRISPR/baselines/shapegan/generated_objects/chair")
voxel_gan_airplane_path = Path("/home/juil/projects/3D_CRISPR/baselines/shapegan/generated_objects/airplane")
wavelet_chair_path = Path("/home/hieuristics/docker_home/project/baselines/Wavelet-Generation/debug/2023-02-14_14-47-22_Wavelet-New-chair")
wavelet_airplane_path = Path("/home/hieuristics/docker_home/project/baselines/Wavelet-Generation/debug/2023-02-06_23-49-06_Network-Marching-Cubes-Diffusion-Gen-Airplane")
spaghetti_chair_path = Path("/scratch/juil/iccv2023/data/spaghetti_chair_za_random_sample/meshes")
spaghetti_airplane_path = Path("/home/juil/projects/3D_CRISPR/data/spaghetti_airplane_random_generation/meshes")

our_chair_path = Path("/home/juil/projects/3D_CRISPR/pointnet-vae-diffusion/pvd/results/phase2/augment_final_0214/0214_202607/eval-02-23-143029/meshes")
our_airplane_path = Path("/home/juil/projects/3D_CRISPR/pointnet-vae-diffusion/pvd/results/phase2-airplane/augment_final_0214/0214_202550/eval-02-27-013138/meshes")
--------------------------------------------------------------------------------
/notebooks/uncleaned_demo/pvd/configs/callbacks/default.yaml:
--------------------------------------------------------------------------------
model_summary:
  _target_: pytorch_lightning.callbacks.RichModelSummary
  max_depth: 1

rich_progress_bar:
  _target_: pytorch_lightning.callbacks.RichProgressBar

model_checkpoint:
  _target_: pytorch_lightning.callbacks.ModelCheckpoint
  monitor: "train/loss"
  mode: "min"
  save_top_k: -1
  save_last: true
  verbose: true
  dirpath: "checkpoints/"
  filename: "epoch={epoch:02d}-val_loss={val/loss:.4f}"
  every_n_epochs: 200
  auto_insert_metric_name: false

lr_monitor:
  _target_: pytorch_lightning.callbacks.LearningRateMonitor
  logging_interval: "step"
--------------------------------------------------------------------------------
/notebooks/uncleaned_demo/pvd/configs/callbacks/fast.yaml:
--------------------------------------------------------------------------------
model_summary:
  _target_: pytorch_lightning.callbacks.RichModelSummary
  max_depth: 2

rich_progress_bar:
  _target_: pytorch_lightning.callbacks.RichProgressBar

model_checkpoint:
  _target_: pytorch_lightning.callbacks.ModelCheckpoint
  monitor: "train/loss"
  mode: "min"
  save_top_k: -1
  save_last: true
  verbose: true
  dirpath: "checkpoints/"
  filename: "epoch={epoch:02d}-val_loss={val/loss:.4f}"
  every_n_epochs: 200
  auto_insert_metric_name: false
--------------------------------------------------------------------------------
/notebooks/uncleaned_demo/pvd/configs/debug/default.yaml:
--------------------------------------------------------------------------------
# @package _global_
#
callbacks: null
logger: null

hydra:
  job_logging:
    root:
      level: DEBUG

trainer:
  max_epochs: 1
--------------------------------------------------------------------------------
/notebooks/uncleaned_demo/pvd/configs/debug/fast.yaml:
--------------------------------------------------------------------------------
# @package _global_
#
model:
  no_run_validation: true
--------------------------------------------------------------------------------
/notebooks/uncleaned_demo/pvd/configs/debug/super_fast.yaml:
--------------------------------------------------------------------------------
# @package _global_
#
defaults:
  - override /logger: tensorboard.yaml
  - override /callbacks: fast.yaml

logger: null
model:
  no_run_validation: true
--------------------------------------------------------------------------------
/notebooks/uncleaned_demo/pvd/configs/experiment/airplane_phase1.yaml:
--------------------------------------------------------------------------------
# @package _global_
defaults:
  - override /model: ldm_big_phase1.yaml
  - override /paths: default.yaml

paths:
  ####
  # save_dir: ${paths.work_dir}/results/phase1/${name}/${now:%m%d_%H%M%S}
  save_dir: ${paths.work_dir}/results/airplane-phase1/${name}/${now:%m%d_%H%M%S}
  ####

####
data:
  latent_path: "/home/dreamy1534/Projects/iccv2023-spaghetti/part_ldm/pvd/data/airplane_1775.hdf5"
  spaghetti_tag: "airplanes"
####

logger:
  wandb:
    ####
    # < Seungwoo >
    # Modified project name to prevent overwrite.
    # project: "spaghetti-gaus"
    # < Seungwoo >
    # Modified project name for airplane data
    # project: "sw_spaghetti-gaus"
    project: "sw_spaghetti-airplane-gaus"
    ####
--------------------------------------------------------------------------------
/notebooks/uncleaned_demo/pvd/configs/experiment/airplane_phase1_sym.yaml:
--------------------------------------------------------------------------------
# @package _global_
defaults:
  - override /model: ldm_big_phase1.yaml
  - override /paths: default.yaml


#### < Seungwoo >
# Change log directory
paths:
  # save_dir: ${paths.work_dir}/results/phase1/${name}/${now:%m%d_%H%M%S}
  # save_dir: ${paths.work_dir}/results/phase1_sym/${name}/${now:%m%d_%H%M%S}
  save_dir: ${paths.work_dir}/results/airplane-phase1_sym/${name}/${now:%m%d_%H%M%S}
####


#### < Seungwoo >
# Override config in 'data/spaghetti_author' to load the proper data
data:
  latent_path: "/home/dreamy1534/Projects/iccv2023-spaghetti/part_ldm/pvd/data/airplane_1775_half.hdf5"
  spaghetti_tag: "airplanes"
####


#### < Seungwoo >
# Override configs in 'ldm_big_phase1' to use the symmetric model.
model:
  _target_: pvd.diffusion.phase1_sym.GaussianSALDM
  eigen_value_clipping: true
####

logger:
  wandb:
    # < Seungwoo >
    # Modified project name to prevent overwrite.
    # project: "spaghetti-gaus"
    project: "sw_spaghetti-gaus-sym"
--------------------------------------------------------------------------------
/notebooks/uncleaned_demo/pvd/configs/experiment/airplane_phase2.yaml:
--------------------------------------------------------------------------------
# @package _global_
defaults:
  - override /model: ldm_big_phase2.yaml
  - override /paths: default.yaml

paths:
  ####
  # save_dir: ${paths.work_dir}/results/phase2/${name}/${now:%m%d_%H%M%S}
  save_dir: ${paths.work_dir}/results/airplane-phase2/${name}/${now:%m%d_%H%M%S}
  ####

####
data:
  latent_path: "/home/dreamy1534/Projects/iccv2023-spaghetti/part_ldm/pvd/data/airplane_1775.hdf5"
  spaghetti_tag: "airplanes"
####

logger:
  wandb:
    # < Seungwoo >
    # Modified project name to prevent overwrite.
    # project: "spaghetti-phase2"
    # < Seungwoo >
    # Modified project name for airplane data
    # project: "sw_spaghetti-phase2"
    project: "sw_spaghetti-airplane-phase2"
--------------------------------------------------------------------------------
/notebooks/uncleaned_demo/pvd/configs/experiment/airplane_single_phase.yaml:
--------------------------------------------------------------------------------
# @package _global_
defaults:
  - override /model: ldm_big_phase1.yaml
  - override /paths: default.yaml

paths:
  save_dir: ${paths.work_dir}/results/airplane-single_phase/${name}/${now:%m%d_%H%M%S}

model:
  _target_: pvd.diffusion.baselines.single_phase.SinglePhaseSALDM
  network:
    _target_: pvd.diffusion.baselines.single_phase.LatentSelfAttentionNetwork
    input_dim: 528

  #### < Seungwoo >
  # Adjust the number of validation steps
  validation_step: 5
  ####

  dataset_kwargs:
    data_keys: ["concat"]

data:
  ####
  # latent_path: "/home/dreamy1534/Projects/iccv2023-spaghetti/part_ldm/pvd/data/chair_6755_concat.hdf5"
  latent_path: "/home/dreamy1534/Projects/iccv2023-spaghetti/part_ldm/pvd/data/airplane_1775_concat.hdf5"
  spaghetti_tag: "airplanes"
  ####

logger:
  wandb:
    ####
    # project: "spaghetti-single_phase"
    project: "spaghetti-airplane-single_phase"
    ####
--------------------------------------------------------------------------------
/notebooks/uncleaned_demo/pvd/configs/experiment/airplane_za.yaml:
--------------------------------------------------------------------------------
# @package _global_
defaults:
  - override /model: ldm_big_za.yaml
  - override /paths: default.yaml

paths:
  ####
  # save_dir: ${paths.work_dir}/results/za/${name}/${now:%m%d_%H%M%S}
  save_dir: ${paths.work_dir}/results/airplane-za/${name}/${now:%m%d_%H%M%S}
  ####

####
data:
  latent_path: "/home/dreamy1534/Projects/iccv2023-spaghetti/part_ldm/pvd/data/airplane_1775.hdf5"
  spaghetti_tag: "airplanes"
####

logger:
  wandb:
    ####
    # project: "spaghetti-za"
    project: "sw_spaghetti-airplane-za"
    ####
--------------------------------------------------------------------------------
/notebooks/uncleaned_demo/pvd/configs/experiment/free_guidance.yaml:
--------------------------------------------------------------------------------
# @package _global_
defaults:
  - override /model: ldm_big_phase2.yaml
  - override /paths: default.yaml

paths:
  save_dir: ${paths.work_dir}/results/free_guidance/${name}/${now:%m%d_%H%M%S}

model:
  classifier_free_guidance: true
  draw_tsne: false
--------------------------------------------------------------------------------
/notebooks/uncleaned_demo/pvd/configs/experiment/lang_phase1.yaml:
--------------------------------------------------------------------------------
# @package _global_
defaults:
  - override /model: ldm_big_phase1_language.yaml
  - override /paths: default.yaml
  #- override /data: spaghetti_half_chair.yaml
  - override /data: spaghetti_author.yaml
  - override /callbacks: default.yaml

paths:
  save_dir: ${paths.work_dir}/results/phase1-lang/${name}/${now:%m%d_%H%M%S}

logger:
  wandb:
    project: "spaghetti-phase1-lang"


callbacks:
  model_checkpoint:
    every_n_epochs: 5
--------------------------------------------------------------------------------
/notebooks/uncleaned_demo/pvd/configs/experiment/lang_phase1_zc.yaml:
--------------------------------------------------------------------------------
# @package _global_
defaults:
  - override /model: ldm_big_phase1_language.yaml
  - override /paths: default.yaml
  #- override /data: spaghetti_half_chair.yaml
  - override /data: spaghetti_author.yaml
  - override /callbacks: default.yaml

paths:
  save_dir: ${paths.work_dir}/results/phase1-lang/${name}/${now:%m%d_%H%M%S}

logger:
  wandb:
    project: "spaghetti-phase1-lang"

model:
  dataset_kwargs:
    data_keys: ["Zc"]
  use_zc: true
  network:
    input_dim: 512


callbacks:
  model_checkpoint:
    every_n_epochs: 5
--------------------------------------------------------------------------------
/notebooks/uncleaned_demo/pvd/configs/experiment/lang_phase2.yaml:
--------------------------------------------------------------------------------
# @package _global_
defaults:
  - override /model: ldm_big_phase2_language.yaml
  - override /paths: default.yaml
  - override /data: spaghetti_author.yaml
  - override /callbacks: default.yaml

paths:
  save_dir: ${paths.work_dir}/results/phase2-lang/${name}/${now:%m%d_%H%M%S}

logger:
  wandb:
    project: "spaghetti-phase2-lang"

callbacks:
  model_checkpoint:
    every_n_epochs: 5
--------------------------------------------------------------------------------
/notebooks/uncleaned_demo/pvd/configs/experiment/lang_single_phase.yaml:
--------------------------------------------------------------------------------
# @package _global_
defaults:
  - override /model: ldm_big_single_phase_language.yaml
  - override /paths: default.yaml
  - override /callbacks: default.yaml

paths:
  save_dir: ${paths.work_dir}/results/single_phase-lang/${name}/${now:%m%d_%H%M%S}

logger:
  wandb:
    project: "spaghetti-single_phase-lang"

callbacks:
  model_checkpoint:
    every_n_epochs: 5
--------------------------------------------------------------------------------
/notebooks/uncleaned_demo/pvd/configs/experiment/phase1-airplane.yaml:
--------------------------------------------------------------------------------
# @package _global_
defaults:
  - override /model: ldm_big_phase1.yaml
  - override /paths: default.yaml
  - override /data: spaghetti_author_airplane.yaml

paths:
  save_dir: ${paths.work_dir}/results/phase1-airplane/${name}/${now:%m%d_%H%M%S}

logger:
  wandb:
    project: "spaghetti-gaus-airplane"

model:
  spaghetti_tag: "airplanes"
--------------------------------------------------------------------------------
/notebooks/uncleaned_demo/pvd/configs/experiment/phase1-pos_enc-airplane.yaml:
--------------------------------------------------------------------------------
# @package _global_
defaults:
  - override /model: ldm_big_phase1_pos_enc.yaml
  - override /paths: default.yaml
  - override /data: spaghetti_author_airplane.yaml

paths:
  save_dir: ${paths.work_dir}/results/phase1-pos_enc-airplane/${name}/${now:%m%d_%H%M%S}

logger:
  wandb:
    project: "spaghetti-gaus-pos_enc-airplane"

model:
  spaghetti_tag: "airplanes"
--------------------------------------------------------------------------------
/notebooks/uncleaned_demo/pvd/configs/experiment/phase1-pos_enc.yaml:
--------------------------------------------------------------------------------
# @package _global_
defaults:
  - override /model: ldm_big_phase1_pos_enc.yaml
  - override /paths: default.yaml

paths:
  save_dir: ${paths.work_dir}/results/phase1-pos_enc/${name}/${now:%m%d_%H%M%S}

logger:
  wandb:
    project: "spaghetti-gaus"
--------------------------------------------------------------------------------
/notebooks/uncleaned_demo/pvd/configs/experiment/phase1-scaled-vectors-sym-airplane.yaml:
--------------------------------------------------------------------------------
# @package _global_
defaults:
  - override /model: ldm_big_phase1.yaml
  - override /paths: default.yaml
  - override /data: spaghetti_half_airplane.yaml

paths:
  save_dir: ${paths.work_dir}/results/phase1-sym-scaled-vectors-airplane/${name}/${now:%m%d_%H%M%S}

logger:
  wandb:
    project: "spaghetti-gaus-sym-scaled-vectors-airplane"

model:
  _target_: pvd.diffusion.phase1_sym.GaussianSymSALDM
  use_symmetric: &sym true
  network:
    input_dim: 13
    scale_eigenvectors: &scale_vectors true
  dataset_kwargs:
    scale_eigenvectors: *scale_vectors
  use_scaled_eigenvectors: *scale_vectors
  spaghetti_tag: "airplanes"
--------------------------------------------------------------------------------
/notebooks/uncleaned_demo/pvd/configs/experiment/phase1-scaled-vectors-sym.yaml:
--------------------------------------------------------------------------------
# @package _global_
defaults:
  - override /model: ldm_big_phase1.yaml
  - override /paths: default.yaml
  - override /data: spaghetti_half_chair.yaml

paths:
  save_dir: ${paths.work_dir}/results/phase1-sym-scaled-vectors/${name}/${now:%m%d_%H%M%S}

logger:
  wandb:
    project: "spaghetti-gaus-sym-scaled-vectors"

model:
  _target_: pvd.diffusion.phase1_sym.GaussianSymSALDM
  use_symmetric: &sym true
  network:
    input_dim: 13
    scale_eigenvectors: &scale_vectors true
  dataset_kwargs:
    scale_eigenvectors: *scale_vectors
  use_scaled_eigenvectors: *scale_vectors
--------------------------------------------------------------------------------
/notebooks/uncleaned_demo/pvd/configs/experiment/phase1-scaled-vectors.yaml:
--------------------------------------------------------------------------------
# @package _global_
defaults:
  - override /model: ldm_big_phase1.yaml
  - override /paths: default.yaml

paths:
  save_dir: ${paths.work_dir}/results/phase1-scaled-vectors/${name}/${now:%m%d_%H%M%S}

logger:
  wandb:
    project: "spaghetti-phase1-scaled-vectors"

model:
  #variance_schedule:
  #  num_steps: 2000
  #  beta_1: 1e-6
  #  beta_T: 0.02
  #  mode: linear

  network:
    input_dim: 13
    scale_eigenvectors: &scale_vectors true
  dataset_kwargs:
    scale_eigenvectors: *scale_vectors

  use_scaled_eigenvectors: *scale_vectors
--------------------------------------------------------------------------------
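Note: `&scale_vectors` / `*scale_vectors` in the configs above are plain YAML anchors and aliases; the flag is defined once under `network.scale_eigenvectors` and referenced elsewhere so the network, dataset, and model-level settings cannot drift apart.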
/notebooks/uncleaned_demo/pvd/configs/experiment/phase1-sym-airplane.yaml:
--------------------------------------------------------------------------------
# @package _global_
defaults:
  - override /model: ldm_big_phase1.yaml
  - override /paths: default.yaml
  - override /data: spaghetti_half_airplane.yaml

paths:
  save_dir: ${paths.work_dir}/results/phase1-sym-airplane/${name}/${now:%m%d_%H%M%S}

logger:
  wandb:
    project: "spaghetti-gaus-sym-airplane"

model:
  _target_: pvd.diffusion.phase1_sym.GaussianSymSALDM
  use_symmetric: &sym true
  spaghetti_tag: "airplanes"
--------------------------------------------------------------------------------
/notebooks/uncleaned_demo/pvd/configs/experiment/phase1-sym.yaml:
--------------------------------------------------------------------------------
# @package _global_
defaults:
  - override /model: ldm_big_phase1.yaml
  - override /paths: default.yaml
  - override /data: spaghetti_half_chair.yaml

paths:
  save_dir: ${paths.work_dir}/results/phase1-sym/${name}/${now:%m%d_%H%M%S}

logger:
  wandb:
    project: "spaghetti-gaus-sym"

model:
  _target_: pvd.diffusion.phase1_sym.GaussianSymSALDM
  use_symmetric: &sym true
--------------------------------------------------------------------------------
/notebooks/uncleaned_demo/pvd/configs/experiment/phase1-table.yaml:
--------------------------------------------------------------------------------
# @package _global_
defaults:
  - override /model: ldm_big_phase1.yaml
  - override /paths: default.yaml
  - override /data: spaghetti_table.yaml

paths:
  save_dir: ${paths.work_dir}/results/phase1-table/${name}/${now:%m%d_%H%M%S}

logger:
  wandb:
    project: "spaghetti-gaus-table"

model:
  spaghetti_tag: "tables"
--------------------------------------------------------------------------------
/notebooks/uncleaned_demo/pvd/configs/experiment/phase1.yaml:
--------------------------------------------------------------------------------
# @package _global_
defaults:
  - override /model: ldm_big_phase1.yaml
  - override /paths: default.yaml

paths:
  save_dir: ${paths.work_dir}/results/phase1/${name}/${now:%m%d_%H%M%S}

logger:
  wandb:
    # < Seungwoo >
    # Modified project name to prevent overwrite.
    # project: "spaghetti-gaus"
    project: "sw_spaghetti-gaus"
--------------------------------------------------------------------------------
/notebooks/uncleaned_demo/pvd/configs/experiment/phase1_sym.yaml:
--------------------------------------------------------------------------------
# @package _global_
defaults:
  - override /model: ldm_big_phase1.yaml
  - override /paths: default.yaml

#### < Seungwoo >
paths:
  # save_dir: ${paths.work_dir}/results/phase1/${name}/${now:%m%d_%H%M%S}
  save_dir: ${paths.work_dir}/results/phase1_sym/${name}/${now:%m%d_%H%M%S}
####


#### < Seungwoo >
# Override config in 'data/spaghetti_author' to load the proper data
data:
  latent_path: "/home/dreamy1534/Projects/iccv2023-spaghetti/part_ldm/pvd/data/6755_2023-01-19-15-02-36_half.hdf5"
  spaghetti_tag: "chairs_large"
####


#### < Seungwoo >
# Override configs in 'ldm_big_phase1' to use the symmetric model.
model:
  _target_: pvd.diffusion.phase1_sym.GaussianSALDM
  eigen_value_clipping: true
####

logger:
  wandb:
    # < Seungwoo >
    # Modified project name to prevent overwrite.
    # project: "spaghetti-gaus"
    project: "sw_spaghetti-gaus-sym"
--------------------------------------------------------------------------------
/notebooks/uncleaned_demo/pvd/configs/experiment/phase2-airplane.yaml:
--------------------------------------------------------------------------------
# @package _global_
defaults:
  - override /model: ldm_big_phase2.yaml
  - override /paths: default.yaml
  - override /data: spaghetti_author_airplane.yaml

paths:
  save_dir: ${paths.work_dir}/results/phase2-airplane/${name}/${now:%m%d_%H%M%S}

logger:
  wandb:
    project: "spaghetti-phase2-airplane"

model:
  spaghetti_tag: "airplanes"
--------------------------------------------------------------------------------
/notebooks/uncleaned_demo/pvd/configs/experiment/phase2-scaled-vectors.yaml:
--------------------------------------------------------------------------------
# @package _global_
defaults:
  - override /model: ldm_big_phase2.yaml
  - override /paths: default.yaml

paths:
  save_dir: ${paths.work_dir}/results/phase2-scaled-vectors/${name}/${now:%m%d_%H%M%S}

logger:
  wandb:
    project: "spaghetti-phase2-scaled-vectors"

model:
  condition_var_sched:
    num_steps: 2000
    beta_1: 1e-6
    beta_T: 0.02
  max_cond_t: 4

  network:
    context_dim: 13
    scale_eigenvectors: &scale_vectors true
  dataset_kwargs:
    scale_eigenvectors: *scale_vectors

  use_scaled_eigenvectors: *scale_vectors
  augment_condition: true
--------------------------------------------------------------------------------
/notebooks/uncleaned_demo/pvd/configs/experiment/phase2-table.yaml:
--------------------------------------------------------------------------------
# @package _global_
defaults:
  - override /model: ldm_big_phase2.yaml
  - override /paths: default.yaml
  - override /data: spaghetti_table.yaml

paths:
  save_dir: ${paths.work_dir}/results/phase2-table/${name}/${now:%m%d_%H%M%S}

logger:
  wandb:
    project: "spaghetti-phase2-table"

model:
  spaghetti_tag: "tables"
--------------------------------------------------------------------------------
/notebooks/uncleaned_demo/pvd/configs/experiment/phase2.yaml:
--------------------------------------------------------------------------------
# @package _global_
defaults:
  - override /model: ldm_big_phase2.yaml
  - override /paths: default.yaml

paths:
  save_dir: ${paths.work_dir}/results/phase2/${name}/${now:%m%d_%H%M%S}

logger:
  wandb:
    ####
    # < Seungwoo >
    # Modified project name to prevent overwrite.
    # project: "spaghetti-phase2"
    project: "sw_spaghetti-phase2"
    ####
--------------------------------------------------------------------------------
/notebooks/uncleaned_demo/pvd/configs/experiment/reverse_phase.yaml:
--------------------------------------------------------------------------------
# @package _global_
defaults:
  - override /model: ldm_big_phase2.yaml
  - override /paths: default.yaml

paths:
  save_dir: ${paths.work_dir}/results/reverse_phase/${name}/${now:%m%d_%H%M%S}

logger:
  wandb:
    project: "spaghetti-reverse_phase"


model:
  _target_: pvd.diffusion.reverse_model.ReverseConditionSALDM
  network:
    input_dim: 16
    residual: true
    context_dim: 512

  dataset_kwargs:
    data_keys: ["g_js_affine", "s_j_affine"]
--------------------------------------------------------------------------------
/notebooks/uncleaned_demo/pvd/configs/experiment/simclr-chair.yaml:
--------------------------------------------------------------------------------
# @package _global_
defaults:
  - override /model: simclr.yaml
  - override /paths: default.yaml
  - override /data: spaghetti_author.yaml

paths:
  save_dir: ${paths.work_dir}/results/simclr-chair/${name}/${now:%m%d_%H%M%S}

logger:
  wandb:
    project: "simclr-chair"

callbacks:
  model_checkpoint:
    every_n_epochs: null
    save_top_k: 1
    monitor: "val/loss"
--------------------------------------------------------------------------------
/notebooks/uncleaned_demo/pvd/configs/experiment/single_phase-airplane.yaml:
--------------------------------------------------------------------------------
# @package _global_
defaults:
  - override /model: ldm_big_single_phase.yaml
  - override /paths: default.yaml
  - override /data: spaghetti_author_airplane.yaml

paths:
  save_dir: ${paths.work_dir}/results/single_phase-airplane/${name}/${now:%m%d_%H%M%S}

logger:
  wandb:
    project: "spaghetti-single-phase-airplane"

model:
  spaghetti_tag: "airplanes"
--------------------------------------------------------------------------------
/notebooks/uncleaned_demo/pvd/configs/experiment/single_phase-pos_enc-airplane.yaml:
--------------------------------------------------------------------------------
# @package _global_
defaults:
  - override /model: ldm_big_single_phase-pos_enc.yaml
  - override /paths: default.yaml
  - override /data: spaghetti_author_airplane.yaml

paths:
  save_dir: ${paths.work_dir}/results/single_phase-airplane-pos_enc/${name}/${now:%m%d_%H%M%S}

logger:
  wandb:
    project: "spaghetti-single-phase-airplane-pos_enc"

model:
  spaghetti_tag: "airplanes"
--------------------------------------------------------------------------------
/notebooks/uncleaned_demo/pvd/configs/experiment/single_phase-pos_enc.yaml:
--------------------------------------------------------------------------------
# @package _global_
defaults:
  - override /model: ldm_big_single_phase-pos_enc.yaml
  - override /paths: default.yaml

paths:
  save_dir: ${paths.work_dir}/results/single_phase-pos_enc/${name}/${now:%m%d_%H%M%S}

logger:
  wandb:
    project: "spaghetti-single-phase-pos_enc"
--------------------------------------------------------------------------------
/notebooks/uncleaned_demo/pvd/configs/experiment/single_phase-table.yaml:
--------------------------------------------------------------------------------
# @package _global_
defaults:
  - override /model: ldm_big_single_phase.yaml
  - override /paths: default.yaml
  - override /data: spaghetti_table.yaml

paths:
  save_dir: ${paths.work_dir}/results/single_phase-table/${name}/${now:%m%d_%H%M%S}

logger:
  wandb:
    project: "spaghetti-single-phase-table"

model:
  spaghetti_tag: "tables"
--------------------------------------------------------------------------------
/notebooks/uncleaned_demo/pvd/configs/experiment/single_phase.yaml:
--------------------------------------------------------------------------------
# @package _global_
defaults:
  - override /model: ldm_big_single_phase.yaml
  - override /paths: default.yaml

paths:
  save_dir: ${paths.work_dir}/results/single_phase/${name}/${now:%m%d_%H%M%S}

logger:
  wandb:
    project: "spaghetti-single-phase"
--------------------------------------------------------------------------------
/notebooks/uncleaned_demo/pvd/configs/experiment/single_phase_zb-airplane.yaml:
--------------------------------------------------------------------------------
# @package _global_
defaults:
  - override /model: ldm_big_single_phase_zb.yaml
  - override /paths: default.yaml
  - override /data: spaghetti_author_airplane.yaml

paths:
  save_dir: ${paths.work_dir}/results/single_phase-zb-airplane/${name}/${now:%m%d_%H%M%S}

logger:
  wandb:
    project: "spaghetti-single-phase-zb-airplane"

model:
  spaghetti_tag: "airplanes"

data:
  latent_path: /home/juil/projects/3D_CRISPR/data/spaghetti_airplane_inversion/3236-03-02-112852-inversion_with_replacement_with_zb.hdf5
--------------------------------------------------------------------------------
/notebooks/uncleaned_demo/pvd/configs/experiment/single_phase_zb.yaml:
--------------------------------------------------------------------------------
# @package _global_
defaults:
  - override /model: ldm_big_single_phase_zb.yaml
  - override /paths: default.yaml

paths:
  save_dir: ${paths.work_dir}/results/single_phase-zb/${name}/${now:%m%d_%H%M%S}

logger:
  wandb:
    project: "spaghetti-single-phase-zb"
--------------------------------------------------------------------------------
/notebooks/uncleaned_demo/pvd/configs/experiment/spaghetti_condition_self_attention.yaml:
--------------------------------------------------------------------------------
# @package _global_

defaults:
  - override /model: ldm_big_phase2.yaml
  - override /paths: default.yaml

paths:
  save_dir: ${paths.work_dir}/results/spaghetti_condition_sa/${name}/${now:%m%d_%H%M%S}
--------------------------------------------------------------------------------
/notebooks/uncleaned_demo/pvd/configs/experiment/za-airplane.yaml:
--------------------------------------------------------------------------------
# @package _global_
defaults:
  - override /model: ldm_big_za.yaml
  - override /paths: default.yaml
  - override /data: spaghetti_author_airplane.yaml

paths:
  save_dir: ${paths.work_dir}/results/za-airplane/${name}/${now:%m%d_%H%M%S}

logger:
  wandb:
    project: "spaghetti-za-airplane"

model:
  spaghetti_tag: "airplanes"
--------------------------------------------------------------------------------
/notebooks/uncleaned_demo/pvd/configs/experiment/za.yaml:
--------------------------------------------------------------------------------
# @package _global_
defaults:
  - override /model: ldm_big_za.yaml
  - override /paths: default.yaml

paths:
  save_dir: ${paths.work_dir}/results/za/${name}/${now:%m%d_%H%M%S}

logger:
  wandb:
    project: "spaghetti-za"
--------------------------------------------------------------------------------
/notebooks/uncleaned_demo/pvd/configs/hydra/default.yaml:
--------------------------------------------------------------------------------
#defaults:
#  - override hydra_logging: colorlog
#  - override job_logging: colorlog

run:
  dir: ${paths.save_dir}

sweep:
  dir: ${paths.save_dir}
  subdir: ${hydra.job.num}
--------------------------------------------------------------------------------
/notebooks/uncleaned_demo/pvd/configs/logger/multi_loggers.yaml:
--------------------------------------------------------------------------------
wandb:
  _target_: pytorch_lightning.loggers.wandb.WandbLogger
  project: "spaghetti-ldm"
  name: ${name}
  save_dir: "."
  entity: "geometry"

tensorboard:
  _target_: pytorch_lightning.loggers.tensorboard.TensorBoardLogger
  save_dir: "tensorboard/"
  name: null
  version: ${name}
  log_graph: false
  default_hp_metric: true
  prefix: ""
--------------------------------------------------------------------------------
/notebooks/uncleaned_demo/pvd/configs/logger/tensorboard.yaml:
--------------------------------------------------------------------------------
tensorboard:
  _target_: pytorch_lightning.loggers.tensorboard.TensorBoardLogger
  save_dir: "tensorboard/"
  name: null
  #version: ${name}
  log_graph: true
  default_hp_metric: false
  prefix: ""
--------------------------------------------------------------------------------
/notebooks/uncleaned_demo/pvd/configs/model/ldm_big_phase1.yaml:
--------------------------------------------------------------------------------
defaults:
  - ldm_default.yaml

_target_: pvd.diffusion.phase1.GaussianSALDM
use_scaled_eigenvectors: false
network:
  _target_: pvd.diffusion.network.LatentSelfAttentionNetwork
  input_dim: 16
  embedding_dim: 512
  num_heads: 4
  use_timestep_embedder: true
  timestep_embedder_dim: 128
  enc_num_layers: 6
  residual: true
  scale_eigenvectors: ${model.use_scaled_eigenvectors} #*scale_vectors
  encoder_type: transformer

validation_step: 1
eigen_value_clipping: false
eigen_value_clip_val: 1e-4

global_normalization: false

dataset_kwargs:
  data_keys: ["g_js_affine"]
  scale_eigenvectors: ${model.use_scaled_eigenvectors} #*scale_vectors
  global_normalization: ${model.global_normalization}
--------------------------------------------------------------------------------
/notebooks/uncleaned_demo/pvd/configs/model/ldm_big_phase1_language.yaml:
--------------------------------------------------------------------------------
defaults:
  - ldm_lang_default.yaml

#_target_: pvd.diffusion.language_cond.LanguageGaussianSymSALDM
_target_: pvd.diffusion.language_cond.ShapeglotPhase1
network:
  _target_: pvd.diffusion.network.LatentConditionSelfAttentionNetwork
  input_dim: 16
  residual: true
  context_dim: 768
  context_embedding_dim: 1024
  embedding_dim: 512
  encoder_use_time: false
  encoder_type: pointwise
  decoder_type: transformer_encoder
  enc_num_layers: 2
  dec_num_layers: 6
  use_timestep_embedder: true
  timestep_embedder_dim: 128
  scale_eigenvectors: &scale_vectors false
  use_pos_encoding: false

classifier_free_guidance: false
conditioning_dropout_prob: 0.2

dataset_kwargs:
  data_keys: ["g_js_affine"]
  #lang_data_path: /home/juil/projects/3D_CRISPR/data/text2shape/T5embed2spaghetti_aug_ids.h5
  #lang_data_path: /home/juil/projects/3D_CRISPR/data/shapeglot/spaghetti_game_data.csv
--------------------------------------------------------------------------------
/notebooks/uncleaned_demo/pvd/configs/model/ldm_big_phase1_pos_enc.yaml:
--------------------------------------------------------------------------------
1 | defaults:
2 | - ldm_default.yaml
3 |
4 | _target_: pvd.diffusion.phase1.GaussianSALDM
5 | use_scaled_eigenvectors: false
6 | network:
7 | _target_: pvd.diffusion.network.PosEncSelfAttentionNetwork
8 | input_dim: 16
9 | embedding_dim: 512
10 | num_heads: 4
11 | use_timestep_embedder: true
12 | timestep_embedder_dim: 128
13 | enc_num_layers: 6
14 | residual: true
15 | scale_eigenvectors: ${model.use_scaled_eigenvectors} #*scale_vectors
16 |
17 | validation_step: 5
18 | eigen_value_clipping: false
19 | eigen_value_clip_val: 1e-4
20 |
21 | global_normalization: false
22 |
23 | dataset_kwargs:
24 | data_keys: ["g_js_affine"]
25 | scale_eigenvectors: ${model.use_scaled_eigenvectors} #*scale_vectors
26 | global_normalization: ${model.global_normalization}
27 |
--------------------------------------------------------------------------------
/notebooks/uncleaned_demo/pvd/configs/model/ldm_big_phase2.yaml:
--------------------------------------------------------------------------------
1 | defaults:
2 | - ldm_default.yaml
3 |
4 | _target_: pvd.diffusion.ldm.SpaghettiConditionSALDM
5 | network:
6 | _target_: pvd.diffusion.network.LatentConditionSelfAttentionNetwork
7 | input_dim: 512
8 | residual: true
9 | context_dim: 16
10 | context_embedding_dim: 512
11 | embedding_dim: 512
12 | encoder_use_time: false
13 | encoder_type: transformer
14 | decoder_type: transformer_encoder
15 | enc_num_layers: 6
16 | dec_num_layers: 6
17 | use_timestep_embedder: true
18 | timestep_embedder_dim: 128
19 | scale_eigenvectors: &scale_vectors false
20 |
21 | classifier_free_guidance: false
22 | conditioning_dropout_prob: 0.0
23 | conditioning_dropout_level: shape
24 |
25 | augment_condition: false
26 | phase1_ckpt_path: null
27 | augment_timestep: 100
28 | mu_noise_scale: 1
29 | eigenvectors_noise_scale: 1
30 | pi_noise_scale: 1
31 | eigenvalues_noise_scale: 1
32 |
33 | use_scaled_eigenvectors: *scale_vectors
34 | sj_global_normalization: false
35 |
36 | dataset_kwargs:
37 | data_keys: ["s_j_affine", "g_js_affine"]
38 | scale_eigenvectors: *scale_vectors
39 | sj_global_normalization: ${model.sj_global_normalization}
40 |
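Shape-wise, this config wires a conditional denoiser over per-part intrinsics: `input_dim: 512` is the `s_j_affine` latent being denoised, and `context_dim: 16` is the per-part Gaussian it is conditioned on. A toy shape check, assuming the 16-part layout used by the samplers in this dump:

import torch

B, G = 4, 16                   # batch, parts per shape (16 in the samplers below)
s_j  = torch.randn(B, G, 512)  # noisy intrinsics being denoised (input_dim: 512)
gaus = torch.randn(B, G, 16)   # conditioning Gaussian params (context_dim: 16)
# the network maps (noisy s_j, timestep, gaus) -> predicted noise with s_j's shape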
--------------------------------------------------------------------------------
/notebooks/uncleaned_demo/pvd/configs/model/ldm_big_phase2_language.yaml:
--------------------------------------------------------------------------------
1 | defaults:
2 | - ldm_lang_default.yaml
3 |
4 | #_target_: pvd.diffusion.language_cond.LanguageConditionIntrinsicLDM
5 | _target_: pvd.diffusion.language_cond.ShapeglotPhase2
6 | global_normalization: false
7 | network:
8 | _target_: pvd.diffusion.network.LatentConditionSelfAttentionNetwork
9 | input_dim: 512
10 | residual: true
11 | context_dim: 784
12 | context_embedding_dim: 1024
13 | embedding_dim: 512
14 | encoder_use_time: false
15 | encoder_type: transformer
16 | decoder_type: transformer_encoder
17 | enc_num_layers: 6
18 | dec_num_layers: 6
19 | use_timestep_embedder: true
20 | timestep_embedder_dim: 128
21 | scale_eigenvectors: &scale_vectors false
22 |
23 | classifier_free_guidance: false
24 | conditioning_dropout_prob: 0.2
25 |
26 | dataset_kwargs:
27 | data_keys: ["s_j_affine", "g_js_affine"]
28 | #lang_data_path: /home/juil/projects/3D_CRISPR/data/text2shape/T5embed2spaghetti_aug_ids.h5
29 | #lang_data_path: /home/juil/projects/3D_CRISPR/data/shapeglot/spaghetti_game_data.csv
30 |
--------------------------------------------------------------------------------
/notebooks/uncleaned_demo/pvd/configs/model/ldm_big_single_phase-pos_enc.yaml:
--------------------------------------------------------------------------------
1 | defaults:
2 | - ldm_default.yaml
3 |
4 | _target_: pvd.diffusion.single_phase.SingleSALDM
5 | network:
6 | _target_: pvd.diffusion.network.PosEncSelfAttentionNetwork
7 | input_dim: 528
8 | embedding_dim: 512
9 | num_heads: 4
10 | use_timestep_embedder: true
11 | timestep_embedder_dim: 128
12 | enc_num_layers: 6
13 | residual: true
14 | scale_eigenvectors: &scale_vectors false
15 |
16 | validation_step: 10
17 | eigen_value_clipping: false
18 | eigen_value_clip_val: 1e-4
19 | use_scaled_eigenvectors: *scale_vectors
20 | global_normalization: false
21 |
22 | dataset_kwargs:
23 | data_keys: ["s_j_affine", "g_js_affine"]
24 | scale_eigenvectors: *scale_vectors
25 | concat_data: true
26 | global_normalization: ${model.global_normalization}
27 |
--------------------------------------------------------------------------------
/notebooks/uncleaned_demo/pvd/configs/model/ldm_big_single_phase.yaml:
--------------------------------------------------------------------------------
1 | defaults:
2 | - ldm_default.yaml
3 |
4 | _target_: pvd.diffusion.single_phase.SingleSALDM
5 | network:
6 | _target_: pvd.diffusion.network.LatentSelfAttentionNetwork
7 | input_dim: 528
8 | embedding_dim: 512
9 | num_heads: 4
10 | use_timestep_embedder: true
11 | timestep_embedder_dim: 128
12 | enc_num_layers: 6
13 | residual: true
14 | scale_eigenvectors: &scale_vectors false
15 |
16 | validation_step: 10
17 | eigen_value_clipping: false
18 | eigen_value_clip_val: 1e-4
19 | use_scaled_eigenvectors: *scale_vectors
20 | global_normalization: false
21 |
22 | dataset_kwargs:
23 | data_keys: ["s_j_affine", "g_js_affine"]
24 | scale_eigenvectors: *scale_vectors
25 | concat_data: true
26 | global_normalization: ${model.global_normalization}
27 |
--------------------------------------------------------------------------------
/notebooks/uncleaned_demo/pvd/configs/model/ldm_big_single_phase_language.yaml:
--------------------------------------------------------------------------------
1 | defaults:
2 | - ldm_lang_default.yaml
3 |
4 | _target_: pvd.diffusion.language_cond.ShapeglotSinglePhase
5 | network:
6 | _target_: pvd.diffusion.network.LatentConditionSelfAttentionNetwork
7 | input_dim: 528
8 | residual: true
9 | context_dim: 768
10 | context_embedding_dim: 1024
11 | embedding_dim: 512
12 | encoder_use_time: false
13 | encoder_type: pointwise
14 | decoder_type: transformer_encoder
15 | enc_num_layers: 2
16 | dec_num_layers: 6
17 | use_timestep_embedder: true
18 | timestep_embedder_dim: 128
19 | use_pos_encoding: false
20 |
21 | classifier_free_guidance: false
22 | conditioning_dropout_prob: 0.2
23 | global_normalization: false
24 |
25 | dataset_kwargs:
26 | data_keys: ["s_j_affine", "g_js_affine"]
27 | #lang_data_path: /home/juil/projects/3D_CRISPR/data/shapeglot/spaghetti_game_data.csv
28 | concat_data: true
29 | global_normalization: ${model.global_normalization}
30 |
--------------------------------------------------------------------------------
/notebooks/uncleaned_demo/pvd/configs/model/ldm_big_single_phase_zb.yaml:
--------------------------------------------------------------------------------
1 | defaults:
2 | - ldm_default.yaml
3 |
4 | _target_: pvd.diffusion.single_phase.SingleZbSALDM
5 | network:
6 | _target_: pvd.diffusion.network.LatentSelfAttentionNetwork
7 | input_dim: 512
8 | embedding_dim: 512
9 | num_heads: 4
10 | use_timestep_embedder: true
11 | timestep_embedder_dim: 128
12 | enc_num_layers: 6
13 | residual: true
14 | scale_eigenvectors: &scale_vectors false
15 | use_pos_encoding: false
16 |
17 | validation_step: 10
18 | eigen_value_clipping: false
19 | eigen_value_clip_val: 1e-4
20 | use_scaled_eigenvectors: *scale_vectors
21 | global_normalization: false
22 |
23 | dataset_kwargs:
24 | data_keys: ["Zb"]
25 | scale_eigenvectors: *scale_vectors
26 | global_normalization: ${model.global_normalization}
27 |
--------------------------------------------------------------------------------
/notebooks/uncleaned_demo/pvd/configs/model/ldm_big_za.yaml:
--------------------------------------------------------------------------------
1 | defaults:
2 | - ldm_default.yaml
3 |
4 | _target_: pvd.diffusion.phase1.ZaSALDM
5 | network:
6 | _target_: pvd.diffusion.network.LatentLinearNetwork
7 | input_dim: 256
8 | embedding_dim: 512
9 | use_timestep_embedder: true
10 | timestep_embedder_dim: 128
11 | enc_num_layers: 6
12 | residual: true
13 |
14 | validation_step: 1
15 |
16 | dataset_kwargs:
17 | data_keys: ["za"]
18 |
--------------------------------------------------------------------------------
/notebooks/uncleaned_demo/pvd/configs/model/ldm_default.yaml:
--------------------------------------------------------------------------------
1 | _target_: pvd.diffusion.ldm.SpaghettiSALDM
2 | network: null
3 |
4 | variance_schedule:
5 | _target_: pvd.diffusion.common.VarianceSchedule
6 | num_steps: &time_steps 1000
7 | beta_1: 1e-4
8 | beta_T: 0.05
9 | mode: linear
10 |
11 | # optimizer
12 | lr: 1e-4
13 | batch_size: ${batch_size}
14 |
15 | # dataset
16 | dataset_kwargs:
17 | data_path: ${data.latent_path}
18 | repeat: 3
19 | data_keys: ["s_j_affine"]
20 |
21 | # model
22 | num_timesteps: *time_steps
23 | faster: true
24 | validation_step: 10
25 | no_run_validation: false
26 | spaghetti_tag: "chairs_large" # or airplanes
27 |
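The `&time_steps` anchor keeps the schedule length and the model's `num_timesteps` in lockstep. `VarianceSchedule` itself is not shown in this dump, but the sampling loops below index `betas` / `alphas` / `alpha_bars` with t in [1, num_steps], which matches a 1-indexed linear DDPM schedule; a minimal sketch under that assumption:

import torch

def linear_variance_schedule(num_steps=1000, beta_1=1e-4, beta_T=0.05):
    # sketch of a linear schedule padded at index 0 so that betas[t] aligns with
    # timesteps t = 1..num_steps, matching how the samplers below index it
    betas = torch.cat([torch.zeros(1), torch.linspace(beta_1, beta_T, num_steps)])
    alphas = 1.0 - betas
    alpha_bars = torch.cumprod(alphas, dim=0)
    return betas, alphas, alpha_bars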
--------------------------------------------------------------------------------
/notebooks/uncleaned_demo/pvd/configs/model/ldm_lang_default.yaml:
--------------------------------------------------------------------------------
1 | _target_: pvd.diffusion.ldm.SpaghettiSALDM
2 | network: null
3 |
4 | variance_schedule:
5 | _target_: pvd.diffusion.common.VarianceSchedule
6 | num_steps: &time_steps 1000
7 | beta_1: 1e-4
8 | beta_T: 0.05
9 | mode: linear
10 |
11 | # optimizer
12 | lr: 1e-4
13 | batch_size: ${batch_size}
14 |
15 | global_normalization: partial
16 | # dataset
17 | dataset_kwargs:
18 | data_path: ${data.latent_path}
19 | repeat: 1
20 | data_keys: ["s_j_affine"]
21 | lang_data_path: /home/juil/projects/3D_CRISPR/data/shapeglot/autosdf_spaghetti_game_data_v2.csv
22 | only_correct: false
23 | only_easy_context: false
24 | max_dataset_size: 10000000
25 | global_normalization: ${model.global_normalization}
26 |
27 | # model
28 | num_timesteps: *time_steps
29 | faster: true
30 | validation_step: 5
31 | no_run_validation: false
32 | spaghetti_tag: "chairs_large" # or airplanes
33 | text_encoder_freeze: true
34 | text_encoder_return_seq: false
35 | use_lstm: false
36 | use_partglot_data: false
37 |
--------------------------------------------------------------------------------
/notebooks/uncleaned_demo/pvd/configs/model/simclr.yaml:
--------------------------------------------------------------------------------
1 | _target_: pvd.embedding.model.CLIPEmbedder
2 | ext_embedder:
3 | _target_: pvd.embedding.model.TFEmbedder
4 | dim_in: 16
5 | dim_h: 512
6 | dim_out: 512
7 | num_layers: 6
8 | num_heads: 4
9 |
10 | int_embedder:
11 | _target_: pvd.embedding.model.TFEmbedder
12 | dim_in: 512
13 | dim_h: 512
14 | dim_out: 512
15 | num_layers: 6
16 | num_heads: 4
17 |
18 | temperature: 3
19 | lr: 1e-4
20 | batch_size: ${batch_size}
21 |
22 | dataset_kwargs:
23 | data_path: ${data.latent_path}
24 | repeat: 3
25 | data_keys: ["s_j_affine", "g_js_affine"]
26 |
--------------------------------------------------------------------------------
/notebooks/uncleaned_demo/pvd/configs/model/vae.yaml:
--------------------------------------------------------------------------------
1 | hydra:
2 | run:
3 | dir: ${save_dir}
4 |
5 | defaults:
6 | - _self_
7 |
8 | trainer:
9 | _target_: pytorch_lightning.Trainer
10 | accelerator: "gpu"
11 | devices: ${gpus}
12 | max_epochs: ${epochs}
13 |
14 | work_dir: ${hydra:runtime.cwd}
15 | save_dir: ${work_dir}/results/${name}/${now:%m%d_%H%M%S}
16 |
17 | epochs: 200
18 | batch_size: 16
19 | gpus: [1]
20 | seed: 63
21 | name: default
22 | debug: false
23 |
24 | model:
25 | _target_: pvd.models.vae.PointNetVAE
26 | encoder: pointnet
27 | batch_size: ${batch_size}
28 | save_dir: ${save_dir}
29 | kl_weight: 1e-4
30 |
31 | callbacks:
32 | model_summary:
33 | _target_: pytorch_lightning.callbacks.RichModelSummary
34 | max_depth: 1
35 |
36 | rich_progress_bar:
37 | _target_: pytorch_lightning.callbacks.RichProgressBar
38 |
39 | model_checkpoint:
40 | _target_: pytorch_lightning.callbacks.ModelCheckpoint
41 | monitor: "train/loss"
42 | mode: "min"
43 | save_top_k: 1
44 | save_last: true
45 | verbose: true
46 | dirpath: "checkpoints/"
47 | filename: "epoch={epoch:02d}-val_loss={val/loss:.4f}"
48 | auto_insert_metric_name: false
49 |
50 | lr_monitor:
51 | _target_: pytorch_lightning.callbacks.LearningRateMonitor
52 | logging_interval: "step"
53 |
54 | logger:
55 | wandb:
56 | _target_: pytorch_lightning.loggers.wandb.WandbLogger
57 | project: "point-vae"
58 | name: ${name}
59 | save_dir: "."
60 | entity: "geometry"
61 |
62 | tensorboard:
63 | _target_: pytorch_lightning.loggers.tensorboard.TensorBoardLogger
64 | save_dir: "tensorboard/"
65 | name: null
66 | version: ${name}
67 | log_graph: false
68 | default_hp_metric: true
69 | prefix: ""
70 |
--------------------------------------------------------------------------------
/notebooks/uncleaned_demo/pvd/configs/paths/default.yaml:
--------------------------------------------------------------------------------
1 | work_dir: ${hydra:runtime.cwd}
2 | save_dir: ${work_dir}/results/${name}/${now:%m%d_%H%M%S}
3 |
--------------------------------------------------------------------------------
/notebooks/uncleaned_demo/pvd/configs/simclr.yaml:
--------------------------------------------------------------------------------
1 | # @package _global_
2 |
3 | defaults:
4 | - _self_
5 | - data: spaghetti_author.yaml
6 | - model: simclr.yaml
7 | - callbacks: default.yaml
8 | - logger: multi_loggers.yaml
9 | - trainer: default.yaml
10 | - paths: default.yaml
11 | - hydra: default.yaml
12 |
13 | - experiment: simclr-chair.yaml
14 | - debug: null
15 |
16 | name: default
17 | seed: 63
18 |
19 | gpus: [0]
20 | epochs: 100
21 | batch_size: 1024
22 |
--------------------------------------------------------------------------------
/notebooks/uncleaned_demo/pvd/configs/train.yaml:
--------------------------------------------------------------------------------
1 | # @package _global_
2 |
3 | defaults:
4 | - _self_
5 | - data: spaghetti_author.yaml
6 | - model: ldm_default.yaml
7 | - callbacks: default.yaml
8 | - logger: multi_loggers.yaml
9 | - trainer: default.yaml
10 | - paths: default.yaml
11 | - hydra: default.yaml
12 |
13 | - experiment: phase2.yaml
14 | - debug: null
15 |
16 | name: default
17 | seed: 63
18 |
19 | gpus: [0]
20 | epochs: 5000
21 | batch_size: 32
22 |
23 | resume_from_checkpoint: null
24 |
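This is the root config; the `# @package _global_` header in the experiment files rebases their keys to this level, so an experiment override rewrites `model`, `paths`, `logger`, etc. in place. A run can therefore be reproduced programmatically as well as from the CLI (hypothetically, `python train.py experiment=phase1`); a sketch assuming hydra-core >= 1.2, executed next to `configs/`:

from hydra import compose, initialize

with initialize(config_path="configs", version_base=None):
    cfg = compose(config_name="train", overrides=["experiment=phase1", "name=repro"])
print(cfg.model._target_)  # resolved model class for the chosen experiment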
--------------------------------------------------------------------------------
/notebooks/uncleaned_demo/pvd/configs/trainer/default.yaml:
--------------------------------------------------------------------------------
1 | _target_: pytorch_lightning.Trainer
2 | accelerator: gpu
3 | devices: ${gpus}
4 | max_epochs: ${epochs}
5 |
--------------------------------------------------------------------------------
/notebooks/uncleaned_demo/pvd/diffusion/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/KAIST-Visual-AI-Group/SALAD/db66b1b2ce57d8a4beb5bd610bc006ccaff43b1e/notebooks/uncleaned_demo/pvd/diffusion/__init__.py
--------------------------------------------------------------------------------
/notebooks/uncleaned_demo/pvd/diffusion/phase1_sym.py:
--------------------------------------------------------------------------------
1 | import io
2 | from pathlib import Path
3 | from PIL import Image
4 | from pvd.dataset import batch_split_eigens
5 | import torch
6 | import torch.nn as nn
7 | import torch.nn.functional as F
8 | import numpy as np
9 | from pvd.diffusion.ldm import SpaghettiSALDM
10 | import jutils
11 | from pvd.utils.spaghetti_utils import *
12 |
13 |
14 | class GaussianSymSALDM(SpaghettiSALDM):
15 | def __init__(self, network, variance_schedule, **kwargs):
16 | super().__init__(network, variance_schedule, **kwargs)
17 |
18 | def forward(self, x):
19 | assert x.shape[1] == 8, f"x.shape: {x.shape}"
20 | return self.get_loss(x)
21 |
22 | @torch.no_grad()
23 | def sample(
24 | self,
25 | batch_size=0,
26 | return_traj=False,
27 | ):
28 | if self.hparams.get("use_scaled_eigenvectors"):
29 | x_T = torch.randn([batch_size, 8, 13]).to(self.device)
30 | else:
31 | x_T = torch.randn([batch_size, 8, 16]).to(self.device)
32 |
33 | traj = {self.var_sched.num_steps: x_T}
34 | for t in range(self.var_sched.num_steps, 0, -1):
35 | z = torch.randn_like(x_T) if t > 1 else torch.zeros_like(x_T)
36 | alpha = self.var_sched.alphas[t]
37 | alpha_bar = self.var_sched.alpha_bars[t]
38 | sigma = self.var_sched.get_sigmas(t, flexibility=0)
39 |
40 | c0 = 1.0 / torch.sqrt(alpha)
41 | c1 = (1 - alpha) / torch.sqrt(1 - alpha_bar)
42 |
43 | x_t = traj[t]
44 |
45 | beta = self.var_sched.betas[[t] * batch_size]
46 | e_theta = self.net(x_t, beta=beta)
47 |
48 | x_next = c0 * (x_t - c1 * e_theta) + sigma * z
49 | traj[t - 1] = x_next.detach()
50 |
51 | traj[t] = traj[t].cpu()
52 |
53 | if not return_traj:
54 | del traj[t]
55 |
56 | if self.hparams.get("eigen_value_clipping"):
57 | traj[0][:, :, 13:16] = torch.clamp_min(
58 | traj[0][:, :, 13:16], min=self.hparams["eigen_value_clip_val"]
59 | )
60 |
61 | ####
62 | if return_traj:
63 | return traj
64 | else:
65 | return traj[0]
66 |
67 | def sampling_gaussians(self, num_shapes):
68 | """
69 | Return:
70 | ldm_gaus: np.ndarray
71 | """
72 | ldm_gaus = self.sample(num_shapes)
73 |
74 | if self.hparams.get("use_scaled_eigenvectors"):
75 | ldm_gaus = jutils.thutil.th2np(batch_split_eigens(jutils.nputil.np2th(ldm_gaus)))
76 |
77 | half = ldm_gaus
78 | full = reflect_and_concat_gmms(half)
79 | ldm_gaus = full
80 | if self.hparams.get("global_normalization"):
81 | if not hasattr(self, "data_val"):
82 | self._build_dataset("val")
83 | print(f"[!] Unnormalizing samples (mode: {self.hparams.get('global_normalization')})")
84 | if self.hparams.get("global_normalization") == "partial":
85 | ldm_gaus = self.data_val.unnormalize_global_static(ldm_gaus, slice(12,None))
86 | elif self.hparams.get("global_normalization") == "all":
87 | ldm_gaus = self.data_val.unnormalize_global_static(ldm_gaus, slice(None))
88 |
89 | return ldm_gaus
90 |
91 | def validation(self):
92 | vis_num_shapes = 16
93 | ldm_gaus = self.sampling_gaussians(vis_num_shapes)
94 |
95 | pred_images = []
96 | for i in range(vis_num_shapes):
97 |
98 | def draw_gaus(gaus):
99 | gaus = clip_eigenvalues(gaus)
100 | img = jutils.visutil.render_gaussians(gaus, resolution=(256, 256))
101 | return img
102 |
103 | try:
104 | pred_img = draw_gaus(ldm_gaus[i])
105 | pred_images.append(pred_img)
106 | except Exception:
107 | pass
108 |
109 | wandb_logger = self.get_wandb_logger()
110 |
111 | try:
112 | pred_images = jutils.imageutil.merge_images(pred_images)
113 | wandb_logger.log_image("Pred", [pred_images])
114 | except Exception:
115 | pass
116 |
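The loop in `sample` above is the standard DDPM ancestral update; in the code's notation (c0, c1), each step computes

    x_{t-1} = (1 / sqrt(alpha_t)) * (x_t - ((1 - alpha_t) / sqrt(1 - alpha_bar_t)) * e_theta(x_t, t)) + sigma_t * z,   z ~ N(0, I),

with z = 0 at the final step (t = 1). The only symmetry-specific pieces are the 8-part half shape asserted in `forward` and `reflect_and_concat_gmms`, which mirrors the sampled half into a full GMM before decoding.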
--------------------------------------------------------------------------------
/notebooks/uncleaned_demo/pvd/diffusion/reverse_model.py:
--------------------------------------------------------------------------------
1 | import io
2 | from typing import Optional
3 | from pathlib import Path
4 | from PIL import Image
5 | from pvd.utils.sde_utils import add_noise, gaus_denoising_tweedie, tweedie_approach
6 | import pytorch_lightning as pl
7 | import torch
8 | import torch.nn as nn
9 | import torch.nn.functional as F
10 | from torch.nn import Module, Parameter, ModuleList
11 | import numpy as np
12 | from torch.utils.data.dataset import Subset
13 | from sklearn.manifold import TSNE
14 | import matplotlib.pyplot as plt
15 |
16 |
17 | from pvd.dataset import LatentDataset, SpaghettiLatentDataset, split_eigens
18 | from pvd.utils.train_utils import PolyDecayScheduler, get_dropout_mask
19 | from pvd.utils.spaghetti_utils import *
20 | from pvd.vae.vae import PointNetVAE
21 | from pvd.diffusion.network import *
22 | from eval3d import Evaluator
23 | from dotmap import DotMap
24 | from pvd.diffusion.common import *
25 | from pvd.diffusion.ldm import *
26 | import jutils
27 |
28 |
29 | class ReverseConditionSALDM(SpaghettiConditionSALDM):
30 | def __init__(self, network, variance_schedule, **kwargs):
31 | super().__init__(network, variance_schedule, **kwargs)
32 |
33 | def forward(self, x, cond):
34 | return self.get_loss(x, cond)
35 |
36 | def sample(
37 | self,
38 | num_samples_or_sjs,
39 | return_traj=False,
40 | return_cond=False
41 | ):
42 | if isinstance(num_samples_or_sjs, int):
43 | batch_size = num_samples_or_sjs
44 | ds = self._build_dataset("val")
45 | cond = torch.stack([ds[i][1] for i in range(batch_size)], 0)
46 | elif isinstance(num_samples_or_sjs, np.ndarray) or isinstance(num_samples_or_sjs, torch.Tensor):
47 | cond = jutils.nputil.np2th(num_samples_or_sjs)
48 | if cond.dim() == 2:
49 | cond = cond[None]
50 | batch_size = len(cond)
51 | x_T = torch.randn([batch_size, 16, 16]).to(self.device)
52 | cond = cond.to(self.device)
53 |
54 | traj = {self.var_sched.num_steps: x_T}
55 | for t in range(self.var_sched.num_steps, 0, -1):
56 | z = torch.randn_like(x_T) if t > 1 else torch.zeros_like(x_T)
57 | alpha = self.var_sched.alphas[t]
58 | alpha_bar = self.var_sched.alpha_bars[t]
59 | sigma = self.var_sched.get_sigmas(t, flexibility=0)
60 |
61 | c0 = 1.0 / torch.sqrt(alpha)
62 | c1 = (1 - alpha) / torch.sqrt(1 - alpha_bar)
63 |
64 | x_t = traj[t]
65 |
66 | beta = self.var_sched.betas[[t] * batch_size]
67 | e_theta = self.net(x_t, beta=beta, context=cond)
68 |
69 | x_next = c0 * (x_t - c1 * e_theta) + sigma * z
70 | traj[t - 1] = x_next.detach()
71 |
72 | traj[t] = traj[t].cpu()
73 |
74 | if not return_traj:
75 | del traj[t]
76 |
77 | if return_traj:
78 | if return_cond:
79 | return traj, cond
80 | return traj
81 | else:
82 | if return_cond:
83 | return traj[0], cond
84 | return traj[0]
85 |
86 | def validation(self):
87 | latent_ds = self._build_dataset("val")
88 | vis_num_shapes = 3
89 | num_variations = 3
90 | jutils.sysutil.clean_gpu()
91 |
92 | if not hasattr(self, "spaghetti"):
93 | spaghetti = load_spaghetti(self.device, self.hparams.spaghetti_tag if self.hparams.get("spaghetti_tag") else "chairs_large")
94 | self.spaghetti = spaghetti
95 | else:
96 | spaghetti = self.spaghetti
97 |
98 | if not hasattr(self, "mesher"):
99 | mesher = load_mesher(self.device)
100 | self.mesher = mesher
101 | else:
102 | mesher = self.mesher
103 |
104 |
105 | gt_gaus, gt_zs = zip(*[latent_ds[i+3] for i in range(vis_num_shapes)])
106 | gt_gaus, gt_zs = list(map(lambda x : torch.stack(x), [gt_gaus, gt_zs]))
107 |
108 | gt_zs_repeated = gt_zs.repeat_interleave(num_variations, 0)
109 | ldm_gaus, ldm_zs = self.sample(gt_zs_repeated, return_cond=True)
110 | ldm_gaus = clip_eigenvalues(ldm_gaus)
111 | ldm_zcs = generate_zc_from_sj_gaus(spaghetti, ldm_zs, ldm_gaus)
112 | gt_zcs = generate_zc_from_sj_gaus(spaghetti, gt_zs, gt_gaus)
113 | jutils.sysutil.clean_gpu()
114 |
115 | wandb_logger = self.get_wandb_logger()
116 | resolution = (256,256)
117 | for i in range(vis_num_shapes):
118 | img_per_shape = []
119 | gaus_img = jutils.visutil.render_gaussians(gt_gaus[i], resolution=resolution)
120 | vert, face = get_mesh_from_spaghetti(spaghetti, mesher, gt_zcs[i], res=128)
121 | gt_mesh_img = jutils.visutil.render_mesh(vert, face, resolution=resolution)
122 | gt_img = jutils.imageutil.merge_images([gaus_img, gt_mesh_img])
123 | gt_img = jutils.imageutil.draw_text(gt_img, "GT", font_size=24)
124 | img_per_shape.append(gt_img)
125 | for j in range(num_variations):
126 | try:
127 | gaus_img = jutils.visutil.render_gaussians(ldm_gaus[i * num_variations + j], resolution=resolution)
128 | vert, face = get_mesh_from_spaghetti(spaghetti, mesher, ldm_zcs[i * num_variations + j], res=128)
129 | mesh_img = jutils.visutil.render_mesh(vert, face, resolution=resolution)
130 | pred_img = jutils.imageutil.merge_images([gaus_img, mesh_img])
131 | pred_img = jutils.imageutil.draw_text(pred_img, f"{j}-th predicted gaus", font_size=24)
132 | img_per_shape.append(pred_img)
133 | except Exception as e:
134 | print(e)
135 |
136 | try:
137 | image = jutils.imageutil.merge_images(img_per_shape)
138 | wandb_logger.log_image("visualization", [image])
139 | except Exception as e:
140 | print(e)
141 |
142 |
143 |
144 |
--------------------------------------------------------------------------------
/notebooks/uncleaned_demo/pvd/diffusion/single_phase.py:
--------------------------------------------------------------------------------
1 | import io
2 | from pathlib import Path
3 | from PIL import Image
4 | from pvd.dataset import batch_split_eigens
5 | import torch
6 | import torch.nn as nn
7 | import torch.nn.functional as F
8 | import numpy as np
9 | from pvd.diffusion.ldm import *
10 |
11 |
12 | class SingleSALDM(SpaghettiSALDM):
13 | def __init__(self, network, variance_schedule, **kwargs):
14 | super().__init__(network, variance_schedule, **kwargs)
15 |
16 | @torch.no_grad()
17 | def sample(
18 | self,
19 | batch_size=0,
20 | return_traj=False,
21 | ):
22 | x_T = torch.randn([batch_size, 16, 528]).to(self.device)
23 |
24 | traj = {self.var_sched.num_steps: x_T}
25 | for t in range(self.var_sched.num_steps, 0, -1):
26 | z = torch.randn_like(x_T) if t > 1 else torch.zeros_like(x_T)
27 | alpha = self.var_sched.alphas[t]
28 | alpha_bar = self.var_sched.alpha_bars[t]
29 | sigma = self.var_sched.get_sigmas(t, flexibility=0)
30 |
31 | c0 = 1.0 / torch.sqrt(alpha)
32 | c1 = (1 - alpha) / torch.sqrt(1 - alpha_bar)
33 |
34 | x_t = traj[t]
35 |
36 | beta = self.var_sched.betas[[t] * batch_size]
37 | e_theta = self.net(x_t, beta=beta)
38 |
39 | x_next = c0 * (x_t - c1 * e_theta) + sigma * z
40 | traj[t - 1] = x_next.detach()
41 |
42 | traj[t] = traj[t].cpu()
43 |
44 | if not return_traj:
45 | del traj[t]
46 |
47 | if self.hparams.get("eigen_value_clipping"):
48 | raise NotImplementedError("Not implemented eig clip for single phase.")
49 | # traj[0][:, :, 13:16] = torch.clamp_min(
50 | # traj[0][:, :, 13:16], min=self.hparams["eigen_value_clip_val"]
51 | # )
52 | if return_traj:
53 | return traj
54 | else:
55 | return traj[0]
56 |
57 | def validation(self):
58 | vis_num_shapes = 4
59 | # ldm_gaus, gt_gaus = self.sampling_gaussians(vis_num_shapes)
60 | ldm_feats = self.sample(vis_num_shapes)
61 | ldm_sj, ldm_gaus = ldm_feats.split(split_size=[512,16], dim=-1)
62 | if self.hparams.get("global_normalization") == "partial":
63 | ldm_gaus = self._build_dataset("val").unnormalize_global_static(ldm_gaus, slice(12,None))
64 | elif self.hparams.get("global_normalization") == "all":
65 | ldm_gaus = self._build_dataset("val").unnormalize_global_static(ldm_gaus, slice(None))
66 |
67 | jutils.sysutil.clean_gpu()
68 | """ Spaghetti Decoding """
69 | if not hasattr(self, "spaghetti"):
70 | spaghetti = load_spaghetti(self.device, self.hparams.spaghetti_tag if self.hparams.get("spaghetti_tag") else "chairs_large")
71 | self.spaghetti = spaghetti
72 | else:
73 | spaghetti = self.spaghetti
74 |
75 | if not hasattr(self, "mesher"):
76 | mesher = load_mesher(self.device)
77 | self.mesher = mesher
78 | else:
79 | mesher = self.mesher
80 |
81 | ldm_zcs = generate_zc_from_sj_gaus(spaghetti, ldm_sj, ldm_gaus)
82 |
83 | camera_kwargs = dict(
84 | camPos = np.array([-2,2,-2]),
85 | camLookat=np.array([0,0,0]),
86 | camUp=np.array([0,1,0]),
87 | resolution=(256,256),
88 | samples=16,
89 | )
90 |
91 | wandb_logger = self.get_wandb_logger()
92 | for i in range(vis_num_shapes):
93 | try:
94 | vert, face = get_mesh_from_spaghetti(spaghetti, mesher, ldm_zcs[i])
95 | mesh_img = jutils.visutil.render_mesh(vert, face, **camera_kwargs)
96 | gaus_img = jutils.visutil.render_gaussians(clip_eigenvalues(ldm_gaus[i]), resolution=camera_kwargs["resolution"])
97 | img = jutils.imageutil.merge_images([gaus_img, mesh_img])
98 | wandb_logger.log_image("Pred", [img])
99 | except Exception as e:
100 | print(f"Error occurred during visualization: {e}")
101 |
102 | class SingleZbSALDM(SpaghettiSALDM):
103 | def validation(self):
104 | vis_num_shapes = 4
105 | ldm_zbs = self.sample(vis_num_shapes)
106 |
107 | if not hasattr(self, "spaghetti"):
108 | spaghetti = load_spaghetti(self.device, self.hparams.spaghetti_tag if self.hparams.get("spaghetti_tag") else "chairs_large")
109 | self.spaghetti = spaghetti
110 | else:
111 | spaghetti = self.spaghetti
112 |
113 | if not hasattr(self, "mesher"):
114 | mesher = load_mesher(self.device)
115 | self.mesher = mesher
116 | else:
117 | mesher = self.mesher
118 |
119 | ldm_sjs, ldm_gmms = spaghetti.decomposition_control.forward_mid(ldm_zbs)
120 | ldm_gaus = batch_gmms_to_gaus(ldm_gmms)
121 | ldm_zb_hat = spaghetti.merge_zh_step_a(ldm_sjs, ldm_gmms)
122 | ldm_zcs, _ = spaghetti.mixing_network.forward_with_attention(ldm_zb_hat)
123 |
124 | camera_kwargs = dict(
125 | camPos = np.array([-2,2,-2]),
126 | camLookat=np.array([0,0,0]),
127 | camUp=np.array([0,1,0]),
128 | resolution=(256,256),
129 | samples=16,
130 | )
131 |
132 | wandb_logger = self.get_wandb_logger()
133 | for i in range(vis_num_shapes):
134 | try:
135 | vert, face = get_mesh_from_spaghetti(spaghetti, mesher, ldm_zcs[i])
136 | mesh_img = jutils.visutil.render_mesh(vert, face, **camera_kwargs)
137 | gaus_img = jutils.visutil.render_gaussians(clip_eigenvalues(ldm_gaus[i]), resolution=camera_kwargs["resolution"])
138 | img = jutils.imageutil.merge_images([gaus_img, mesh_img])
139 | wandb_logger.log_image("Pred", [img])
140 | except Exception as e:
141 | print(f"Error occurred during visualization: {e}")
142 |
--------------------------------------------------------------------------------
/notebooks/uncleaned_demo/pvd/embedding/model.py:
--------------------------------------------------------------------------------
1 | from pvd.utils.train_utils import PolyDecayScheduler
2 | from tqdm import tqdm
3 | from rich.progress import track
4 | from pvd.vae.vae import conv_block
5 | import torch
6 | import torch.nn as nn
7 | import torch.nn.functional as F
8 | import pytorch_lightning as pl
9 | import numpy as np
10 | from pathlib import Path
11 | from dotmap import DotMap
12 |
13 | import jutils
14 |
15 | from pvd.dataset import *
16 | from pvd.diffusion.network import *
17 |
18 | class ContrastiveLearningDistWrapper:
19 | def __init__(self,**kwargs):
20 | self.__dict__.update(kwargs)
21 | self.hparams = DotMap(self.__dict__)
22 |
23 | def __call__(self, feat_a, feat_b):
24 | """
25 | Input:
26 | feat_a: [B,D_e]
27 | feat_b: [B,D_e]
28 | Output:
29 | dist_mat: [B,B]
30 | """
31 | assert feat_a.dim() == feat_b.dim()
32 | assert feat_a.shape[0] == feat_b.shape[0]
33 | if feat_a.dim() == 1:
34 | feat_a = feat_a.unsqueeze(0)
35 | feat_b = feat_b.unsqueeze(0)
36 | # the distance computation was left unfinished upstream; pairwise cosine distance is one plausible completion:
37 | return 1.0 - F.normalize(feat_a, dim=-1) @ F.normalize(feat_b, dim=-1).t()  # [B,B]
38 | class TFEmbedder(nn.Module):
39 | def __init__(self, **kwargs):
40 | super().__init__()
41 | self.__dict__.update(kwargs)
42 | self.hparams = DotMap(self.__dict__)
43 |
44 | self.embedding = nn.Linear(self.hparams.dim_in, self.hparams.dim_h)
45 | dim_h = self.hparams.dim_h
46 | self.encoder = TimeTransformerEncoder(dim_h, dim_ctx=None, num_heads=self.hparams.num_heads, use_time=False, num_layers=self.hparams.num_layers, last_fc=True, last_fc_dim_out=self.hparams.dim_out)
47 |
48 | def forward(self, x):
49 | """
50 | Input:
51 | x: [B,G,D_in]
52 | Output:
53 | out: [B,D_out]
54 | """
55 | B, G = x.shape[:2]
56 | # res = x
57 |
58 | x = self.embedding(x)
59 | out = self.encoder(x)
60 | out = out.max(1)[0]
61 | return out
62 |
63 |
64 |
65 | class CLIPEmbedder(pl.LightningModule):
66 | def __init__(self, ext_embedder, int_embedder, **kwargs):
67 | super().__init__()
68 | self.save_hyperparameters(logger=False)
69 | self.ext_embedder = ext_embedder
70 | self.int_embedder = int_embedder
71 |
72 | def forward(self, intrinsics, extrinsics):
73 | """
74 | Input:
75 | intrinsics: [Bi,16,512]
76 | extrinsics: [Be,16,16]
77 | Output:
78 | cossim_mat: [Bi,Be]
79 | """
80 |
81 | int_f = self.int_embedder(intrinsics) #[B,D_emb]
82 | ext_f = self.ext_embedder(extrinsics) #[B,D_emb]
83 |
84 | int_e = int_f / int_f.norm(dim=-1, keepdim=True)
85 | ext_e = ext_f / ext_f.norm(dim=-1, keepdim=True)
86 |
87 | sim_mat = torch.einsum("nd,md->nm", int_e, ext_e)
88 |
89 | return sim_mat, int_e, ext_e
90 |
91 | def info_nce_loss(self, sim_mat):
92 | """
93 | Input:
94 | sim_mat: [N,N]
95 | """
96 | assert sim_mat.dim() == 2 and sim_mat.shape[0] == sim_mat.shape[1]
97 | N = sim_mat.shape[0]
98 | logits = sim_mat * np.exp(self.hparams.temperature)
99 | labels = torch.arange(N).to(logits).long()
100 | loss = F.cross_entropy(logits, labels)
101 | return loss
102 |
103 | def step(self, batch, batch_idx, stage: str):
104 | intrinsics, extrinsics = batch
105 |
106 | sim_mat, int_e, ext_e = self(intrinsics, extrinsics)
107 | loss = self.info_nce_loss(sim_mat)
108 | self.log(f"{stage}/loss", loss, on_step=stage == "train", prog_bar=True)
109 | return loss
110 |
111 | def training_step(self, batch, batch_idx):
112 | return self.step(batch, batch_idx, "train")
113 |
114 | def validation_step(self, batch, batch_idx):
115 | return self.step(batch, batch_idx, "val")
116 |
117 | def get_intrinsic_extrinsic_distance_matrix(self):
118 | """
119 | [num_i, num_e]
120 | """
121 | ds = self._build_dataset("val")
122 | intr = ds.data["s_j_affine"]
123 | extr = ds.data["g_js_affine"]
124 | intr, extr = list(map(lambda x: jutils.nputil.np2th(x).to(self.device), [intr, extr]))
125 | dist_mat = self(intr, extr)[0].cpu().numpy()
126 |
127 | return dist_mat
128 |
129 |
130 |
131 | def _build_dataset(self, stage):
132 | if stage == "train":
133 | ds = SpaghettiLatentDataset(**self.hparams.dataset_kwargs)
134 | else:
135 | dataset_kwargs = self.hparams.dataset_kwargs.copy()
136 | dataset_kwargs["repeat"] = 1
137 | ds = SpaghettiLatentDataset(**dataset_kwargs)
138 | setattr(self, f"data_{stage}", ds)
139 | return ds
140 |
141 | def _build_dataloader(self, stage, batch_size=None):
142 | try:
143 | ds = getattr(self, f"data_{stage}")
144 | except AttributeError:
145 | ds = self._build_dataset(stage)
146 |
147 | return torch.utils.data.DataLoader(
148 | ds,
149 | batch_size=batch_size if batch_size is not None else self.hparams.batch_size,
150 | shuffle=stage == "train",
151 | drop_last=stage == "train",
152 | num_workers=4,
153 | )
154 | def train_dataloader(self):
155 | return self._build_dataloader("train")
156 |
157 | def val_dataloader(self):
158 | return self._build_dataloader("val")
159 |
160 | def configure_optimizers(self):
161 | optimizer = torch.optim.Adam(self.parameters(), lr=self.hparams.lr)
162 | scheduler = PolyDecayScheduler(optimizer, self.hparams.lr, power=0.999)
163 | return [optimizer], [scheduler]
164 |
165 |
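As a sanity check of `info_nce_loss` above: the config's `temperature: 3` enters as a fixed logit scale exp(3) ≈ 20.1 (CLIP-style), so a perfectly diagonal similarity matrix drives the loss to near zero. A minimal sketch:

import torch
import torch.nn.functional as F

N = 4
sim = torch.eye(N)                            # pretend int/ext embeddings match 1:1
logits = sim * torch.exp(torch.tensor(3.0))   # temperature: 3 from the config
loss = F.cross_entropy(logits, torch.arange(N))
print(loss.item())                            # ~0: each row's own column dominates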
--------------------------------------------------------------------------------
/notebooks/uncleaned_demo/pvd/misc/extract_gmms.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | from pvd.dataset import *
3 | from pvd.diffusion.ldm import *
4 | from pvd.diffusion.phase1 import *
5 | import torch
6 | import numpy as np
7 | import jutils
8 | import matplotlib.pyplot as plt
9 | from sklearn.manifold import TSNE
10 | from PIL import Image
11 | from eval3d import Evaluator
12 | import h5py  # needed for the HDF5 dump in main() (io was unused)
13 | import os
14 | from pathlib import Path  # used in load_ldm; made explicit (duplicate PIL import dropped)
15 | from datetime import datetime
16 |
17 | def load_ldm(path, strict=True):
18 | path = Path(path)
19 | assert path.exists()
20 |
21 | ldm = GaussianSALDM.load_from_checkpoint(path, strict=strict).cuda(0)
22 | ldm.eval()
23 | for p in ldm.parameters(): p.requires_grad_(False)
24 | dataset = ldm._build_dataset("val")
25 | return ldm, dataset
26 |
27 | def sample_gmms(ldm, num_shapes):
28 | ldm_gaus = ldm.sample(num_shapes).to("cpu")
29 | return ldm_gaus
30 |
31 | def main(args):
32 | assert Path(args.save_dir).exists()
33 | now = datetime.now().strftime("%m-%d-%H%M%S")
34 | ldm, dataset = load_ldm(args.ldm_path, strict=False)
35 |
36 | ldm_gaus = sample_gmms(ldm, args.num_shapes)
37 | ldm_gaus = jutils.thutil.th2np(ldm_gaus)
38 |
39 | with h5py.File(f"{args.save_dir}/{len(ldm_gaus)}-{now}.hdf5", "w") as f:
40 | f["data"] = ldm_gaus.astype(np.float32)
41 |
42 | if __name__ == "__main__":
43 | parser = argparse.ArgumentParser()
44 | parser.add_argument("--ldm_path", type=str)
45 | parser.add_argument("--num_shapes", type=int)
46 | parser.add_argument("--save_dir", type=str)
47 |
48 | args = parser.parse_args()
49 | main(args)
50 |
--------------------------------------------------------------------------------
/notebooks/uncleaned_demo/pvd/modules/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/KAIST-Visual-AI-Group/SALAD/db66b1b2ce57d8a4beb5bd610bc006ccaff43b1e/notebooks/uncleaned_demo/pvd/modules/__init__.py
--------------------------------------------------------------------------------
/notebooks/uncleaned_demo/pvd/modules/pointnet2_ops_lib/MANIFEST.in:
--------------------------------------------------------------------------------
1 | graft pointnet2_ops/_ext-src
2 |
--------------------------------------------------------------------------------
/notebooks/uncleaned_demo/pvd/modules/pointnet2_ops_lib/pointnet2_ops/__init__.py:
--------------------------------------------------------------------------------
1 | import pointnet2_ops.pointnet2_modules
2 | import pointnet2_ops.pointnet2_utils
3 | from pointnet2_ops._version import __version__
4 |
--------------------------------------------------------------------------------
/notebooks/uncleaned_demo/pvd/modules/pointnet2_ops_lib/pointnet2_ops/_ext-src/include/ball_query.h:
--------------------------------------------------------------------------------
1 | #pragma once
2 | #include <torch/extension.h>
3 |
4 | at::Tensor ball_query(at::Tensor new_xyz, at::Tensor xyz, const float radius,
5 | const int nsample);
6 |
--------------------------------------------------------------------------------
/notebooks/uncleaned_demo/pvd/modules/pointnet2_ops_lib/pointnet2_ops/_ext-src/include/cuda_utils.h:
--------------------------------------------------------------------------------
1 | #ifndef _CUDA_UTILS_H
2 | #define _CUDA_UTILS_H
3 |
4 | #include <ATen/ATen.h>
5 | #include <ATen/cuda/CUDAContext.h>
6 | #include <cmath>
7 |
8 | #include <cuda.h>
9 | #include <cuda_runtime.h>
10 |
11 | #include <vector>
12 |
13 | #define TOTAL_THREADS 512
14 |
15 | inline int opt_n_threads(int work_size) {
16 |   const int pow_2 = std::log(static_cast<double>(work_size)) / std::log(2.0);
17 |
18 | return max(min(1 << pow_2, TOTAL_THREADS), 1);
19 | }
20 |
21 | inline dim3 opt_block_config(int x, int y) {
22 | const int x_threads = opt_n_threads(x);
23 | const int y_threads =
24 | max(min(opt_n_threads(y), TOTAL_THREADS / x_threads), 1);
25 | dim3 block_config(x_threads, y_threads, 1);
26 |
27 | return block_config;
28 | }
29 |
30 | #define CUDA_CHECK_ERRORS() \
31 | do { \
32 | cudaError_t err = cudaGetLastError(); \
33 | if (cudaSuccess != err) { \
34 | fprintf(stderr, "CUDA kernel failed : %s\n%s at L:%d in %s\n", \
35 | cudaGetErrorString(err), __PRETTY_FUNCTION__, __LINE__, \
36 | __FILE__); \
37 | exit(-1); \
38 | } \
39 | } while (0)
40 |
41 | #endif
42 |
--------------------------------------------------------------------------------
/notebooks/uncleaned_demo/pvd/modules/pointnet2_ops_lib/pointnet2_ops/_ext-src/include/group_points.h:
--------------------------------------------------------------------------------
1 | #pragma once
2 | #include <torch/extension.h>
3 |
4 | at::Tensor group_points(at::Tensor points, at::Tensor idx);
5 | at::Tensor group_points_grad(at::Tensor grad_out, at::Tensor idx, const int n);
6 |
--------------------------------------------------------------------------------
/notebooks/uncleaned_demo/pvd/modules/pointnet2_ops_lib/pointnet2_ops/_ext-src/include/interpolate.h:
--------------------------------------------------------------------------------
1 | #pragma once
2 |
3 | #include <torch/extension.h>
4 | #include <vector>
5 |
6 | std::vector<at::Tensor> three_nn(at::Tensor unknowns, at::Tensor knows);
7 | at::Tensor three_interpolate(at::Tensor points, at::Tensor idx,
8 | at::Tensor weight);
9 | at::Tensor three_interpolate_grad(at::Tensor grad_out, at::Tensor idx,
10 | at::Tensor weight, const int m);
11 |
--------------------------------------------------------------------------------
/notebooks/uncleaned_demo/pvd/modules/pointnet2_ops_lib/pointnet2_ops/_ext-src/include/sampling.h:
--------------------------------------------------------------------------------
1 | #pragma once
2 | #include <torch/extension.h>
3 |
4 | at::Tensor gather_points(at::Tensor points, at::Tensor idx);
5 | at::Tensor gather_points_grad(at::Tensor grad_out, at::Tensor idx, const int n);
6 | at::Tensor furthest_point_sampling(at::Tensor points, const int nsamples);
7 |
--------------------------------------------------------------------------------
/notebooks/uncleaned_demo/pvd/modules/pointnet2_ops_lib/pointnet2_ops/_ext-src/include/utils.h:
--------------------------------------------------------------------------------
1 | #pragma once
2 | #include <ATen/cuda/CUDAContext.h>
3 | #include <torch/extension.h>
4 |
5 | #define CHECK_CUDA(x) \
6 | do { \
7 | AT_ASSERT(x.is_cuda(), #x " must be a CUDA tensor"); \
8 | } while (0)
9 |
10 | #define CHECK_CONTIGUOUS(x) \
11 | do { \
12 | AT_ASSERT(x.is_contiguous(), #x " must be a contiguous tensor"); \
13 | } while (0)
14 |
15 | #define CHECK_IS_INT(x) \
16 | do { \
17 | AT_ASSERT(x.scalar_type() == at::ScalarType::Int, \
18 | #x " must be an int tensor"); \
19 | } while (0)
20 |
21 | #define CHECK_IS_FLOAT(x) \
22 | do { \
23 | AT_ASSERT(x.scalar_type() == at::ScalarType::Float, \
24 | #x " must be a float tensor"); \
25 | } while (0)
26 |
--------------------------------------------------------------------------------
/notebooks/uncleaned_demo/pvd/modules/pointnet2_ops_lib/pointnet2_ops/_ext-src/src/ball_query.cpp:
--------------------------------------------------------------------------------
1 | #include "ball_query.h"
2 | #include "utils.h"
3 |
4 | void query_ball_point_kernel_wrapper(int b, int n, int m, float radius,
5 | int nsample, const float *new_xyz,
6 | const float *xyz, int *idx);
7 |
8 | at::Tensor ball_query(at::Tensor new_xyz, at::Tensor xyz, const float radius,
9 | const int nsample) {
10 | CHECK_CONTIGUOUS(new_xyz);
11 | CHECK_CONTIGUOUS(xyz);
12 | CHECK_IS_FLOAT(new_xyz);
13 | CHECK_IS_FLOAT(xyz);
14 |
15 | if (new_xyz.is_cuda()) {
16 | CHECK_CUDA(xyz);
17 | }
18 |
19 | at::Tensor idx =
20 | torch::zeros({new_xyz.size(0), new_xyz.size(1), nsample},
21 | at::device(new_xyz.device()).dtype(at::ScalarType::Int));
22 |
23 | if (new_xyz.is_cuda()) {
24 | query_ball_point_kernel_wrapper(xyz.size(0), xyz.size(1), new_xyz.size(1),
25 |                                     radius, nsample, new_xyz.data_ptr<float>(),
26 |                                     xyz.data_ptr<float>(), idx.data_ptr<int>());
27 | } else {
28 | AT_ASSERT(false, "CPU not supported");
29 | }
30 |
31 | return idx;
32 | }
33 |
--------------------------------------------------------------------------------
/notebooks/uncleaned_demo/pvd/modules/pointnet2_ops_lib/pointnet2_ops/_ext-src/src/ball_query_gpu.cu:
--------------------------------------------------------------------------------
1 | #include <math.h>
2 | #include <stdio.h>
3 | #include <stdlib.h>
4 |
5 | #include "cuda_utils.h"
6 |
7 | // input: new_xyz(b, m, 3) xyz(b, n, 3)
8 | // output: idx(b, m, nsample)
9 | __global__ void query_ball_point_kernel(int b, int n, int m, float radius,
10 | int nsample,
11 | const float *__restrict__ new_xyz,
12 | const float *__restrict__ xyz,
13 | int *__restrict__ idx) {
14 | int batch_index = blockIdx.x;
15 | xyz += batch_index * n * 3;
16 | new_xyz += batch_index * m * 3;
17 | idx += m * nsample * batch_index;
18 |
19 | int index = threadIdx.x;
20 | int stride = blockDim.x;
21 |
22 | float radius2 = radius * radius;
23 | for (int j = index; j < m; j += stride) {
24 | float new_x = new_xyz[j * 3 + 0];
25 | float new_y = new_xyz[j * 3 + 1];
26 | float new_z = new_xyz[j * 3 + 2];
27 | for (int k = 0, cnt = 0; k < n && cnt < nsample; ++k) {
28 | float x = xyz[k * 3 + 0];
29 | float y = xyz[k * 3 + 1];
30 | float z = xyz[k * 3 + 2];
31 | float d2 = (new_x - x) * (new_x - x) + (new_y - y) * (new_y - y) +
32 | (new_z - z) * (new_z - z);
33 | if (d2 < radius2) {
34 | if (cnt == 0) {
35 | for (int l = 0; l < nsample; ++l) {
36 | idx[j * nsample + l] = k;
37 | }
38 | }
39 | idx[j * nsample + cnt] = k;
40 | ++cnt;
41 | }
42 | }
43 | }
44 | }
45 |
46 | void query_ball_point_kernel_wrapper(int b, int n, int m, float radius,
47 | int nsample, const float *new_xyz,
48 | const float *xyz, int *idx) {
49 | cudaStream_t stream = at::cuda::getCurrentCUDAStream();
50 |   query_ball_point_kernel<<<b, opt_n_threads(m), 0, stream>>>(
51 | b, n, m, radius, nsample, new_xyz, xyz, idx);
52 |
53 | CUDA_CHECK_ERRORS();
54 | }
55 |
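For readers without a CUDA toolchain, the kernel's semantics are easy to state in plain PyTorch: each query point collects the first `nsample` neighbors within `radius`, and every slot is pre-filled with the first hit so short neighborhoods are padded rather than left at index 0. A CPU reference sketch (not part of the library):

import torch

def ball_query_reference(new_xyz, xyz, radius, nsample):
    # new_xyz: (B, M, 3), xyz: (B, N, 3) -> idx: (B, M, nsample), mirroring the kernel
    B, M, _ = new_xyz.shape
    idx = torch.zeros(B, M, nsample, dtype=torch.int64)
    d2 = torch.cdist(new_xyz, xyz).pow(2)      # squared distances, (B, M, N)
    for b in range(B):
        for j in range(M):
            hits = (d2[b, j] < radius * radius).nonzero().flatten()[:nsample]
            if hits.numel() > 0:
                idx[b, j] = hits[0]            # pad all slots with the first hit
                idx[b, j, : hits.numel()] = hits  # then overwrite with the real hits
    return idx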
--------------------------------------------------------------------------------
/notebooks/uncleaned_demo/pvd/modules/pointnet2_ops_lib/pointnet2_ops/_ext-src/src/bindings.cpp:
--------------------------------------------------------------------------------
1 | #include "ball_query.h"
2 | #include "group_points.h"
3 | #include "interpolate.h"
4 | #include "sampling.h"
5 |
6 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
7 | m.def("gather_points", &gather_points);
8 | m.def("gather_points_grad", &gather_points_grad);
9 | m.def("furthest_point_sampling", &furthest_point_sampling);
10 |
11 | m.def("three_nn", &three_nn);
12 | m.def("three_interpolate", &three_interpolate);
13 | m.def("three_interpolate_grad", &three_interpolate_grad);
14 |
15 | m.def("ball_query", &ball_query);
16 |
17 | m.def("group_points", &group_points);
18 | m.def("group_points_grad", &group_points_grad);
19 | }
20 |
--------------------------------------------------------------------------------
/notebooks/uncleaned_demo/pvd/modules/pointnet2_ops_lib/pointnet2_ops/_ext-src/src/group_points.cpp:
--------------------------------------------------------------------------------
1 | #include "group_points.h"
2 | #include "utils.h"
3 |
4 | void group_points_kernel_wrapper(int b, int c, int n, int npoints, int nsample,
5 | const float *points, const int *idx,
6 | float *out);
7 |
8 | void group_points_grad_kernel_wrapper(int b, int c, int n, int npoints,
9 | int nsample, const float *grad_out,
10 | const int *idx, float *grad_points);
11 |
12 | at::Tensor group_points(at::Tensor points, at::Tensor idx) {
13 | CHECK_CONTIGUOUS(points);
14 | CHECK_CONTIGUOUS(idx);
15 | CHECK_IS_FLOAT(points);
16 | CHECK_IS_INT(idx);
17 |
18 | if (points.is_cuda()) {
19 | CHECK_CUDA(idx);
20 | }
21 |
22 | at::Tensor output =
23 | torch::zeros({points.size(0), points.size(1), idx.size(1), idx.size(2)},
24 | at::device(points.device()).dtype(at::ScalarType::Float));
25 |
26 | if (points.is_cuda()) {
27 | group_points_kernel_wrapper(points.size(0), points.size(1), points.size(2),
28 | idx.size(1), idx.size(2),
29 |                                 points.data_ptr<float>(), idx.data_ptr<int>(),
30 |                                 output.data_ptr<float>());
31 | } else {
32 | AT_ASSERT(false, "CPU not supported");
33 | }
34 |
35 | return output;
36 | }
37 |
38 | at::Tensor group_points_grad(at::Tensor grad_out, at::Tensor idx, const int n) {
39 | CHECK_CONTIGUOUS(grad_out);
40 | CHECK_CONTIGUOUS(idx);
41 | CHECK_IS_FLOAT(grad_out);
42 | CHECK_IS_INT(idx);
43 |
44 | if (grad_out.is_cuda()) {
45 | CHECK_CUDA(idx);
46 | }
47 |
48 | at::Tensor output =
49 | torch::zeros({grad_out.size(0), grad_out.size(1), n},
50 | at::device(grad_out.device()).dtype(at::ScalarType::Float));
51 |
52 | if (grad_out.is_cuda()) {
53 | group_points_grad_kernel_wrapper(
54 | grad_out.size(0), grad_out.size(1), n, idx.size(1), idx.size(2),
55 |         grad_out.data_ptr<float>(), idx.data_ptr<int>(),
56 |         output.data_ptr<float>());
57 | } else {
58 | AT_ASSERT(false, "CPU not supported");
59 | }
60 |
61 | return output;
62 | }
63 |
--------------------------------------------------------------------------------
/notebooks/uncleaned_demo/pvd/modules/pointnet2_ops_lib/pointnet2_ops/_ext-src/src/group_points_gpu.cu:
--------------------------------------------------------------------------------
1 | #include <stdio.h>
2 | #include <stdlib.h>
3 |
4 | #include "cuda_utils.h"
5 |
6 | // input: points(b, c, n) idx(b, npoints, nsample)
7 | // output: out(b, c, npoints, nsample)
8 | __global__ void group_points_kernel(int b, int c, int n, int npoints,
9 | int nsample,
10 | const float *__restrict__ points,
11 | const int *__restrict__ idx,
12 | float *__restrict__ out) {
13 | int batch_index = blockIdx.x;
14 | points += batch_index * n * c;
15 | idx += batch_index * npoints * nsample;
16 | out += batch_index * npoints * nsample * c;
17 |
18 | const int index = threadIdx.y * blockDim.x + threadIdx.x;
19 | const int stride = blockDim.y * blockDim.x;
20 | for (int i = index; i < c * npoints; i += stride) {
21 | const int l = i / npoints;
22 | const int j = i % npoints;
23 | for (int k = 0; k < nsample; ++k) {
24 | int ii = idx[j * nsample + k];
25 | out[(l * npoints + j) * nsample + k] = points[l * n + ii];
26 | }
27 | }
28 | }
29 |
30 | void group_points_kernel_wrapper(int b, int c, int n, int npoints, int nsample,
31 | const float *points, const int *idx,
32 | float *out) {
33 | cudaStream_t stream = at::cuda::getCurrentCUDAStream();
34 |
35 |   group_points_kernel<<<b, opt_block_config(npoints, c), 0, stream>>>(
36 | b, c, n, npoints, nsample, points, idx, out);
37 |
38 | CUDA_CHECK_ERRORS();
39 | }
40 |
41 | // input: grad_out(b, c, npoints, nsample), idx(b, npoints, nsample)
42 | // output: grad_points(b, c, n)
43 | __global__ void group_points_grad_kernel(int b, int c, int n, int npoints,
44 | int nsample,
45 | const float *__restrict__ grad_out,
46 | const int *__restrict__ idx,
47 | float *__restrict__ grad_points) {
48 | int batch_index = blockIdx.x;
49 | grad_out += batch_index * npoints * nsample * c;
50 | idx += batch_index * npoints * nsample;
51 | grad_points += batch_index * n * c;
52 |
53 | const int index = threadIdx.y * blockDim.x + threadIdx.x;
54 | const int stride = blockDim.y * blockDim.x;
55 | for (int i = index; i < c * npoints; i += stride) {
56 | const int l = i / npoints;
57 | const int j = i % npoints;
58 | for (int k = 0; k < nsample; ++k) {
59 | int ii = idx[j * nsample + k];
60 | atomicAdd(grad_points + l * n + ii,
61 | grad_out[(l * npoints + j) * nsample + k]);
62 | }
63 | }
64 | }
65 |
66 | void group_points_grad_kernel_wrapper(int b, int c, int n, int npoints,
67 | int nsample, const float *grad_out,
68 | const int *idx, float *grad_points) {
69 | cudaStream_t stream = at::cuda::getCurrentCUDAStream();
70 |
71 |   group_points_grad_kernel<<<b, opt_block_config(npoints, c), 0, stream>>>(
72 | b, c, n, npoints, nsample, grad_out, idx, grad_points);
73 |
74 | CUDA_CHECK_ERRORS();
75 | }
76 |
--------------------------------------------------------------------------------
/notebooks/uncleaned_demo/pvd/modules/pointnet2_ops_lib/pointnet2_ops/_ext-src/src/interpolate.cpp:
--------------------------------------------------------------------------------
1 | #include "interpolate.h"
2 | #include "utils.h"
3 |
4 | void three_nn_kernel_wrapper(int b, int n, int m, const float *unknown,
5 | const float *known, float *dist2, int *idx);
6 | void three_interpolate_kernel_wrapper(int b, int c, int m, int n,
7 | const float *points, const int *idx,
8 | const float *weight, float *out);
9 | void three_interpolate_grad_kernel_wrapper(int b, int c, int n, int m,
10 | const float *grad_out,
11 | const int *idx, const float *weight,
12 | float *grad_points);
13 |
14 | std::vector<at::Tensor> three_nn(at::Tensor unknowns, at::Tensor knows) {
15 | CHECK_CONTIGUOUS(unknowns);
16 | CHECK_CONTIGUOUS(knows);
17 | CHECK_IS_FLOAT(unknowns);
18 | CHECK_IS_FLOAT(knows);
19 |
20 | if (unknowns.is_cuda()) {
21 | CHECK_CUDA(knows);
22 | }
23 |
24 | at::Tensor idx =
25 | torch::zeros({unknowns.size(0), unknowns.size(1), 3},
26 | at::device(unknowns.device()).dtype(at::ScalarType::Int));
27 | at::Tensor dist2 =
28 | torch::zeros({unknowns.size(0), unknowns.size(1), 3},
29 | at::device(unknowns.device()).dtype(at::ScalarType::Float));
30 |
31 | if (unknowns.is_cuda()) {
32 | three_nn_kernel_wrapper(unknowns.size(0), unknowns.size(1), knows.size(1),
33 |                             unknowns.data_ptr<float>(), knows.data_ptr<float>(),
34 |                             dist2.data_ptr<float>(), idx.data_ptr<int>());
35 | } else {
36 | AT_ASSERT(false, "CPU not supported");
37 | }
38 |
39 | return {dist2, idx};
40 | }
41 |
42 | at::Tensor three_interpolate(at::Tensor points, at::Tensor idx,
43 | at::Tensor weight) {
44 | CHECK_CONTIGUOUS(points);
45 | CHECK_CONTIGUOUS(idx);
46 | CHECK_CONTIGUOUS(weight);
47 | CHECK_IS_FLOAT(points);
48 | CHECK_IS_INT(idx);
49 | CHECK_IS_FLOAT(weight);
50 |
51 | if (points.is_cuda()) {
52 | CHECK_CUDA(idx);
53 | CHECK_CUDA(weight);
54 | }
55 |
56 | at::Tensor output =
57 | torch::zeros({points.size(0), points.size(1), idx.size(1)},
58 | at::device(points.device()).dtype(at::ScalarType::Float));
59 |
60 | if (points.is_cuda()) {
61 | three_interpolate_kernel_wrapper(
62 | points.size(0), points.size(1), points.size(2), idx.size(1),
63 |         points.data_ptr<float>(), idx.data_ptr<int>(), weight.data_ptr<float>(),
64 |         output.data_ptr<float>());
65 | } else {
66 | AT_ASSERT(false, "CPU not supported");
67 | }
68 |
69 | return output;
70 | }
71 | at::Tensor three_interpolate_grad(at::Tensor grad_out, at::Tensor idx,
72 | at::Tensor weight, const int m) {
73 | CHECK_CONTIGUOUS(grad_out);
74 | CHECK_CONTIGUOUS(idx);
75 | CHECK_CONTIGUOUS(weight);
76 | CHECK_IS_FLOAT(grad_out);
77 | CHECK_IS_INT(idx);
78 | CHECK_IS_FLOAT(weight);
79 |
80 | if (grad_out.is_cuda()) {
81 | CHECK_CUDA(idx);
82 | CHECK_CUDA(weight);
83 | }
84 |
85 | at::Tensor output =
86 | torch::zeros({grad_out.size(0), grad_out.size(1), m},
87 | at::device(grad_out.device()).dtype(at::ScalarType::Float));
88 |
89 | if (grad_out.is_cuda()) {
90 | three_interpolate_grad_kernel_wrapper(
91 | grad_out.size(0), grad_out.size(1), grad_out.size(2), m,
92 |         grad_out.data_ptr<float>(), idx.data_ptr<int>(),
93 |         weight.data_ptr<float>(), output.data_ptr<float>());
94 | } else {
95 | AT_ASSERT(false, "CPU not supported");
96 | }
97 |
98 | return output;
99 | }
100 |
--------------------------------------------------------------------------------
/notebooks/uncleaned_demo/pvd/modules/pointnet2_ops_lib/pointnet2_ops/_ext-src/src/interpolate_gpu.cu:
--------------------------------------------------------------------------------
1 | #include <math.h>
2 | #include <stdio.h>
3 | #include <stdlib.h>
4 |
5 | #include "cuda_utils.h"
6 |
7 | // input: unknown(b, n, 3) known(b, m, 3)
8 | // output: dist2(b, n, 3), idx(b, n, 3)
9 | __global__ void three_nn_kernel(int b, int n, int m,
10 | const float *__restrict__ unknown,
11 | const float *__restrict__ known,
12 | float *__restrict__ dist2,
13 | int *__restrict__ idx) {
14 | int batch_index = blockIdx.x;
15 | unknown += batch_index * n * 3;
16 | known += batch_index * m * 3;
17 | dist2 += batch_index * n * 3;
18 | idx += batch_index * n * 3;
19 |
20 | int index = threadIdx.x;
21 | int stride = blockDim.x;
22 | for (int j = index; j < n; j += stride) {
23 | float ux = unknown[j * 3 + 0];
24 | float uy = unknown[j * 3 + 1];
25 | float uz = unknown[j * 3 + 2];
26 |
27 | double best1 = 1e40, best2 = 1e40, best3 = 1e40;
28 | int besti1 = 0, besti2 = 0, besti3 = 0;
29 | for (int k = 0; k < m; ++k) {
30 | float x = known[k * 3 + 0];
31 | float y = known[k * 3 + 1];
32 | float z = known[k * 3 + 2];
33 | float d = (ux - x) * (ux - x) + (uy - y) * (uy - y) + (uz - z) * (uz - z);
34 | if (d < best1) {
35 | best3 = best2;
36 | besti3 = besti2;
37 | best2 = best1;
38 | besti2 = besti1;
39 | best1 = d;
40 | besti1 = k;
41 | } else if (d < best2) {
42 | best3 = best2;
43 | besti3 = besti2;
44 | best2 = d;
45 | besti2 = k;
46 | } else if (d < best3) {
47 | best3 = d;
48 | besti3 = k;
49 | }
50 | }
51 | dist2[j * 3 + 0] = best1;
52 | dist2[j * 3 + 1] = best2;
53 | dist2[j * 3 + 2] = best3;
54 |
55 | idx[j * 3 + 0] = besti1;
56 | idx[j * 3 + 1] = besti2;
57 | idx[j * 3 + 2] = besti3;
58 | }
59 | }
60 |
61 | void three_nn_kernel_wrapper(int b, int n, int m, const float *unknown,
62 | const float *known, float *dist2, int *idx) {
63 | cudaStream_t stream = at::cuda::getCurrentCUDAStream();
64 |   three_nn_kernel<<<b, opt_n_threads(n), 0, stream>>>(b, n, m, unknown, known,
65 | dist2, idx);
66 |
67 | CUDA_CHECK_ERRORS();
68 | }
69 |
70 | // input: points(b, c, m), idx(b, n, 3), weight(b, n, 3)
71 | // output: out(b, c, n)
72 | __global__ void three_interpolate_kernel(int b, int c, int m, int n,
73 | const float *__restrict__ points,
74 | const int *__restrict__ idx,
75 | const float *__restrict__ weight,
76 | float *__restrict__ out) {
77 | int batch_index = blockIdx.x;
78 | points += batch_index * m * c;
79 |
80 | idx += batch_index * n * 3;
81 | weight += batch_index * n * 3;
82 |
83 | out += batch_index * n * c;
84 |
85 | const int index = threadIdx.y * blockDim.x + threadIdx.x;
86 | const int stride = blockDim.y * blockDim.x;
87 | for (int i = index; i < c * n; i += stride) {
88 | const int l = i / n;
89 | const int j = i % n;
90 | float w1 = weight[j * 3 + 0];
91 | float w2 = weight[j * 3 + 1];
92 | float w3 = weight[j * 3 + 2];
93 |
94 | int i1 = idx[j * 3 + 0];
95 | int i2 = idx[j * 3 + 1];
96 | int i3 = idx[j * 3 + 2];
97 |
98 | out[i] = points[l * m + i1] * w1 + points[l * m + i2] * w2 +
99 | points[l * m + i3] * w3;
100 | }
101 | }
102 |
103 | void three_interpolate_kernel_wrapper(int b, int c, int m, int n,
104 | const float *points, const int *idx,
105 | const float *weight, float *out) {
106 | cudaStream_t stream = at::cuda::getCurrentCUDAStream();
107 | three_interpolate_kernel<<<b, opt_block_config(n, c), 0, stream>>>(
108 | b, c, m, n, points, idx, weight, out);
109 |
110 | CUDA_CHECK_ERRORS();
111 | }
112 |
113 | // input: grad_out(b, c, n), idx(b, n, 3), weight(b, n, 3)
114 | // output: grad_points(b, c, m)
115 |
116 | __global__ void three_interpolate_grad_kernel(
117 | int b, int c, int n, int m, const float *__restrict__ grad_out,
118 | const int *__restrict__ idx, const float *__restrict__ weight,
119 | float *__restrict__ grad_points) {
120 | int batch_index = blockIdx.x;
121 | grad_out += batch_index * n * c;
122 | idx += batch_index * n * 3;
123 | weight += batch_index * n * 3;
124 | grad_points += batch_index * m * c;
125 |
126 | const int index = threadIdx.y * blockDim.x + threadIdx.x;
127 | const int stride = blockDim.y * blockDim.x;
128 | for (int i = index; i < c * n; i += stride) {
129 | const int l = i / n;
130 | const int j = i % n;
131 | float w1 = weight[j * 3 + 0];
132 | float w2 = weight[j * 3 + 1];
133 | float w3 = weight[j * 3 + 2];
134 |
135 | int i1 = idx[j * 3 + 0];
136 | int i2 = idx[j * 3 + 1];
137 | int i3 = idx[j * 3 + 2];
138 |
139 | atomicAdd(grad_points + l * m + i1, grad_out[i] * w1);
140 | atomicAdd(grad_points + l * m + i2, grad_out[i] * w2);
141 | atomicAdd(grad_points + l * m + i3, grad_out[i] * w3);
142 | }
143 | }
144 |
145 | void three_interpolate_grad_kernel_wrapper(int b, int c, int n, int m,
146 | const float *grad_out,
147 | const int *idx, const float *weight,
148 | float *grad_points) {
149 | cudaStream_t stream = at::cuda::getCurrentCUDAStream();
150 | three_interpolate_grad_kernel<<<b, opt_block_config(n, c), 0, stream>>>(
151 | b, c, n, m, grad_out, idx, weight, grad_points);
152 |
153 | CUDA_CHECK_ERRORS();
154 | }
155 |
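A pure-PyTorch reference sketch for what the two forward kernels above compute, handy for sanity-checking the CUDA extension on small inputs. Not part of the repository; `three_nn_ref` and `three_interpolate_ref` are illustrative names.

import torch

def three_nn_ref(unknown, known):
    # unknown: (b, n, 3), known: (b, m, 3) -> squared distances and indices
    # of the 3 nearest known points per unknown point, both (b, n, 3).
    d2 = torch.cdist(unknown, known).pow(2)
    dist2, idx = torch.topk(d2, k=3, dim=2, largest=False)  # ascending, as in the kernel
    return dist2, idx.int()

def three_interpolate_ref(points, idx, weight):
    # points: (b, c, m), idx/weight: (b, n, 3) -> out: (b, c, n),
    # a weighted blend of the three neighbor features per output point.
    b, c, m = points.shape
    n = idx.shape[1]
    gathered = torch.gather(
        points.unsqueeze(2).expand(b, c, n, m), 3,
        idx.long().unsqueeze(1).expand(b, c, n, 3))  # (b, c, n, 3)
    return (gathered * weight.unsqueeze(1)).sum(-1)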
--------------------------------------------------------------------------------
/notebooks/uncleaned_demo/pvd/modules/pointnet2_ops_lib/pointnet2_ops/_ext-src/src/sampling.cpp:
--------------------------------------------------------------------------------
1 | #include "sampling.h"
2 | #include "utils.h"
3 |
4 | void gather_points_kernel_wrapper(int b, int c, int n, int npoints,
5 | const float *points, const int *idx,
6 | float *out);
7 | void gather_points_grad_kernel_wrapper(int b, int c, int n, int npoints,
8 | const float *grad_out, const int *idx,
9 | float *grad_points);
10 |
11 | void furthest_point_sampling_kernel_wrapper(int b, int n, int m,
12 | const float *dataset, float *temp,
13 | int *idxs);
14 |
15 | at::Tensor gather_points(at::Tensor points, at::Tensor idx) {
16 | CHECK_CONTIGUOUS(points);
17 | CHECK_CONTIGUOUS(idx);
18 | CHECK_IS_FLOAT(points);
19 | CHECK_IS_INT(idx);
20 |
21 | if (points.is_cuda()) {
22 | CHECK_CUDA(idx);
23 | }
24 |
25 | at::Tensor output =
26 | torch::zeros({points.size(0), points.size(1), idx.size(1)},
27 | at::device(points.device()).dtype(at::ScalarType::Float));
28 |
29 | if (points.is_cuda()) {
30 | gather_points_kernel_wrapper(points.size(0), points.size(1), points.size(2),
31 | idx.size(1), points.data_ptr<float>(),
32 | idx.data_ptr<int>(), output.data_ptr<float>());
33 | } else {
34 | AT_ASSERT(false, "CPU not supported");
35 | }
36 |
37 | return output;
38 | }
39 |
40 | at::Tensor gather_points_grad(at::Tensor grad_out, at::Tensor idx,
41 | const int n) {
42 | CHECK_CONTIGUOUS(grad_out);
43 | CHECK_CONTIGUOUS(idx);
44 | CHECK_IS_FLOAT(grad_out);
45 | CHECK_IS_INT(idx);
46 |
47 | if (grad_out.is_cuda()) {
48 | CHECK_CUDA(idx);
49 | }
50 |
51 | at::Tensor output =
52 | torch::zeros({grad_out.size(0), grad_out.size(1), n},
53 | at::device(grad_out.device()).dtype(at::ScalarType::Float));
54 |
55 | if (grad_out.is_cuda()) {
56 | gather_points_grad_kernel_wrapper(grad_out.size(0), grad_out.size(1), n,
57 | idx.size(1), grad_out.data_ptr<float>(),
58 | idx.data_ptr<int>(),
59 | output.data_ptr<float>());
60 | } else {
61 | AT_ASSERT(false, "CPU not supported");
62 | }
63 |
64 | return output;
65 | }
66 | at::Tensor furthest_point_sampling(at::Tensor points, const int nsamples) {
67 | CHECK_CONTIGUOUS(points);
68 | CHECK_IS_FLOAT(points);
69 |
70 | at::Tensor output =
71 | torch::zeros({points.size(0), nsamples},
72 | at::device(points.device()).dtype(at::ScalarType::Int));
73 |
74 | at::Tensor tmp =
75 | torch::full({points.size(0), points.size(1)}, 1e10,
76 | at::device(points.device()).dtype(at::ScalarType::Float));
77 |
78 | if (points.is_cuda()) {
79 | furthest_point_sampling_kernel_wrapper(
80 | points.size(0), points.size(1), nsamples, points.data_ptr<float>(),
81 | tmp.data_ptr<float>(), output.data_ptr<int>());
82 | } else {
83 | AT_ASSERT(false, "CPU not supported");
84 | }
85 |
86 | return output;
87 | }
88 |
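For reference, a pure-PyTorch sketch of the greedy farthest-point sampling that `furthest_point_sampling` above implements; slow, but useful for verifying the CUDA op on small inputs (illustrative, not repository code):

import torch

def furthest_point_sampling_ref(points, nsamples):
    # points: (b, n, 3) -> (b, nsamples) int32 indices, mirroring the wrapper.
    b, n, _ = points.shape
    idx = torch.zeros(b, nsamples, dtype=torch.long, device=points.device)
    # Same role as the `tmp` buffer above: running min squared distance
    # from every point to the set selected so far.
    dist = torch.full((b, n), 1e10, device=points.device)
    farthest = torch.zeros(b, dtype=torch.long, device=points.device)
    batch = torch.arange(b, device=points.device)
    for i in range(nsamples):
        idx[:, i] = farthest
        centroid = points[batch, farthest].unsqueeze(1)  # (b, 1, 3)
        dist = torch.minimum(dist, (points - centroid).pow(2).sum(-1))
        farthest = dist.argmax(-1)
    return idx.int()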
--------------------------------------------------------------------------------
/notebooks/uncleaned_demo/pvd/modules/pointnet2_ops_lib/pointnet2_ops/_version.py:
--------------------------------------------------------------------------------
1 | __version__ = "3.0.0"
2 |
--------------------------------------------------------------------------------
/notebooks/uncleaned_demo/pvd/modules/pointnet2_ops_lib/setup.py:
--------------------------------------------------------------------------------
1 | import glob
2 | import os
3 | import os.path as osp
4 |
5 | from setuptools import find_packages, setup
6 | from torch.utils.cpp_extension import BuildExtension, CUDAExtension
7 |
8 | this_dir = osp.dirname(osp.abspath(__file__))
9 | _ext_src_root = osp.join("pointnet2_ops", "_ext-src")
10 | _ext_sources = glob.glob(osp.join(_ext_src_root, "src", "*.cpp")) + glob.glob(
11 | osp.join(_ext_src_root, "src", "*.cu")
12 | )
13 | _ext_headers = glob.glob(osp.join(_ext_src_root, "include", "*"))
14 |
15 | requirements = ["torch>=1.4"]
16 |
17 | exec(open(osp.join("pointnet2_ops", "_version.py")).read())
18 |
19 | os.environ["TORCH_CUDA_ARCH_LIST"] = "3.7+PTX;5.0;6.0;6.1;6.2;7.0;7.5"
20 | setup(
21 | name="pointnet2_ops",
22 | version=__version__,
23 | author="Erik Wijmans",
24 | packages=find_packages(),
25 | install_requires=requirements,
26 | ext_modules=[
27 | CUDAExtension(
28 | name="pointnet2_ops._ext",
29 | sources=_ext_sources,
30 | extra_compile_args={
31 | "cxx": ["-O3"],
32 | "nvcc": ["-O3", "-Xfatbin", "-compress-all"],
33 | },
34 | include_dirs=[osp.join(this_dir, _ext_src_root, "include")],
35 | )
36 | ],
37 | cmdclass={"build_ext": BuildExtension},
38 | include_package_data=True,
39 | )
40 |
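The extension is typically built by installing this package in place (e.g. `pip install ./pointnet2_ops_lib` from `pvd/modules`). A minimal smoke test, assuming the usual `pointnet2_ops` layout; a CUDA device is required since the CPU branches above abort:

import torch
from pointnet2_ops import pointnet2_utils

xyz = torch.rand(2, 1024, 3, device="cuda")
idx = pointnet2_utils.furthest_point_sample(xyz, 128)  # (2, 128) int32 indices
print(idx.shape, idx.dtype)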
--------------------------------------------------------------------------------
/notebooks/uncleaned_demo/pvd/modules/simple_module.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | from pvd.modules.transformer import TimeMLP
5 | import math
6 |
7 |
8 | class ConcatDecoder(nn.Module):
9 | def __init__(self, dim_self, dim_ref, dim_ctx, num_layers=3):
10 | super().__init__()
11 | dim_total = dim_self + dim_ref
12 |
13 | self.layers = nn.ModuleList(
14 | [
15 | TimeMLP(dim_total, dim_total, dim_total, dim_ctx, use_time=True)
16 | for _ in range(num_layers)
17 | ]
18 | )
19 | self.fc = nn.Linear(dim_total, dim_self)
20 |
21 | def forward(self, x, y, ctx=None):
22 | out = torch.cat([x, y], -1)
23 | for i, layer in enumerate(self.layers):
24 | out = layer(out, ctx=ctx)
25 | out = self.fc(out)
26 | return out
27 |
28 |
29 | class TimestepEmbedder(nn.Module):
30 | """
31 | Embeds scalar timesteps into vector representations.
32 | """
33 |
34 | def __init__(self, hidden_size, frequency_embedding_size=256):
35 | super().__init__()
36 | self.mlp = nn.Sequential(
37 | nn.Linear(frequency_embedding_size, hidden_size, bias=True),
38 | nn.SiLU(),
39 | nn.Linear(hidden_size, hidden_size, bias=True),
40 | )
41 | self.frequency_embedding_size = frequency_embedding_size
42 |
43 | @staticmethod
44 | def timestep_embedding(t, dim, max_period=10000):
45 | """
46 | Create sinusoidal timestep embeddings.
47 | :param t: a 1-D Tensor of N indices, one per batch element.
48 | These may be fractional.
49 | :param dim: the dimension of the output.
50 | :param max_period: controls the minimum frequency of the embeddings.
51 | :return: an (N, D) Tensor of positional embeddings.
52 | """
53 | # https://github.com/openai/glide-text2im/blob/main/glide_text2im/nn.py
54 | half = dim // 2
55 | freqs = torch.exp(
56 | -math.log(max_period)
57 | * torch.arange(start=0, end=half, dtype=torch.float32)
58 | / half
59 | ).to(device=t.device)
60 | args = t[:, None].float() * freqs[None]
61 | embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
62 | if dim % 2:
63 | embedding = torch.cat(
64 | [embedding, torch.zeros_like(embedding[:, :1])], dim=-1
65 | )
66 | return embedding
67 |
68 | def forward(self, t):
69 | t_freq = self.timestep_embedding(t, self.frequency_embedding_size)
70 | t_emb = self.mlp(t_freq)
71 | return t_emb
72 |
73 |
74 | class TimePointwiseLayer(nn.Module):
75 | def __init__(
76 | self,
77 | dim_in,
78 | dim_ctx,
79 | mlp_ratio=2,
80 | act=F.leaky_relu,
81 | dropout=0.0,
82 | use_time=False,
83 | ):
84 | super().__init__()
85 | self.use_time = use_time
86 | self.act = act
87 | self.mlp1 = TimeMLP(
88 | dim_in, dim_in * mlp_ratio, dim_in, dim_ctx, use_time=use_time
89 | )
90 | self.norm1 = nn.LayerNorm(dim_in)
91 |
92 | self.mlp2 = TimeMLP(
93 | dim_in, dim_in * mlp_ratio, dim_in, dim_ctx, use_time=use_time
94 | )
95 | self.norm2 = nn.LayerNorm(dim_in)
96 | self.dropout = nn.Dropout(dropout)
97 |
98 | def forward(self, x, ctx=None):
99 | res = x
100 | x = self.mlp1(x, ctx=ctx)
101 | x = self.norm1(x + res)
102 |
103 | res = x
104 | x = self.mlp2(x, ctx=ctx)
105 | x = self.norm2(x + res)
106 | return x
107 |
108 |
109 | class TimePointWiseEncoder(nn.Module):
110 | def __init__(
111 | self,
112 | dim_in,
113 | dim_ctx=None,
114 | mlp_ratio=2,
115 | act=F.leaky_relu,
116 | dropout=0.0,
117 | use_time=True,
118 | num_layers=6,
119 | last_fc=False,
120 | last_fc_dim_out=None,
121 | ):
122 | super().__init__()
123 | self.last_fc = last_fc
124 | if last_fc:
125 | self.fc = nn.Linear(dim_in, last_fc_dim_out)
126 | self.layers = nn.ModuleList(
127 | [
128 | TimePointwiseLayer(
129 | dim_in,
130 | dim_ctx=dim_ctx,
131 | mlp_ratio=mlp_ratio,
132 | act=act,
133 | dropout=dropout,
134 | use_time=use_time,
135 | )
136 | for _ in range(num_layers)
137 | ]
138 | )
139 |
140 | def forward(self, x, ctx=None):
141 | for i, layer in enumerate(self.layers):
142 | x = layer(x, ctx=ctx)
143 | if self.last_fc:
144 | x = self.fc(x)
145 | return x
146 |
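A small worked example of `TimestepEmbedder.timestep_embedding` above: for dim=4 the two frequencies are 1 and 10000^(-1/2), and the output packs all cosines before all sines, so t=0 maps to [1, 1, 0, 0] (illustrative, assuming the module imports as `pvd.modules.simple_module`):

import torch
from pvd.modules.simple_module import TimestepEmbedder

emb = TimestepEmbedder.timestep_embedding(torch.tensor([0.0, 1.0]), dim=4)
print(emb[0])  # tensor([1., 1., 0., 0.])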
--------------------------------------------------------------------------------
/notebooks/uncleaned_demo/pvd/spaghetti_github/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/KAIST-Visual-AI-Group/SALAD/db66b1b2ce57d8a4beb5bd610bc006ccaff43b1e/notebooks/uncleaned_demo/pvd/spaghetti_github/__init__.py
--------------------------------------------------------------------------------
/notebooks/uncleaned_demo/pvd/spaghetti_github/constants.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 |
4 | IS_WINDOWS = sys.platform == 'win32'
5 | get_trace = getattr(sys, 'gettrace', None)
6 | DEBUG = get_trace is not None and get_trace() is not None
7 | EPSILON = 1e-4
8 | DIM = 3
9 |
10 | #### < Seungwoo >
11 | # Added 'docker_home' directory in the middle.
12 | PROJECT_ROOT = "/home/juil/docker_home/projects/3D_CRISPR/spaghetti_github"
13 | # PROJECT_ROOT = "/home/juil/projects/3D_CRISPR/spaghetti_github"
14 | ####
15 |
16 | DATA_ROOT = f'{PROJECT_ROOT}/assets/'
17 | CHECKPOINTS_ROOT = f'{DATA_ROOT}checkpoints/'
18 | CACHE_ROOT = f'{DATA_ROOT}cache/'
19 | UI_OUT = f'{DATA_ROOT}ui_export/'
20 | UI_RESOURCES = f'{DATA_ROOT}/ui_resources/'
21 | Shapenet_WT = f'{DATA_ROOT}/ShapeNetCore_wt/'
22 | Shapenet = f'{DATA_ROOT}/ShapeNetCore.v2/'
23 | MAX_VS = 100000
24 | MAX_GAUSIANS = 32
25 |
26 |
--------------------------------------------------------------------------------
/notebooks/uncleaned_demo/pvd/spaghetti_github/custom_types.py:
--------------------------------------------------------------------------------
1 | # import open3d
2 | import enum
3 | import numpy as np
4 | import torch
5 | import torch.nn as nn
6 | import torch.nn.functional as nnf
7 | from .constants import DEBUG
8 | from typing import Tuple, List, Union, Callable, Type, Iterator, Dict, Set, Optional, Any, Sized, Iterable
9 | from types import DynamicClassAttribute
10 | from enum import Enum, unique
11 | import torch.optim.optimizer
12 | import torch.utils.data
13 |
14 | if DEBUG:
15 | seed = 99
16 | torch.manual_seed(seed)
17 | np.random.seed(seed)
18 |
19 | N = type(None)
20 | V = np.array
21 | ARRAY = np.ndarray
22 | ARRAYS = Union[Tuple[ARRAY, ...], List[ARRAY]]
23 | VS = Union[Tuple[V, ...], List[V]]
24 | VN = Union[V, N]
25 | VNS = Union[VS, N]
26 | T = torch.Tensor
27 | TS = Union[Tuple[T, ...], List[T]]
28 | TN = Optional[T]
29 | TNS = Union[Tuple[TN, ...], List[TN]]
30 | TSN = Optional[TS]
31 | TA = Union[T, ARRAY]
32 |
33 | V_Mesh = Tuple[ARRAY, ARRAY]
34 | T_Mesh = Tuple[T, Optional[T]]
35 | T_Mesh_T = Union[T_Mesh, T]
36 | COLORS = Union[T, ARRAY, Tuple[int, int, int]]
37 |
38 | D = torch.device
39 | CPU = torch.device('cpu')
40 |
41 |
42 | def get_device(device_id: int) -> D:
43 | if not torch.cuda.is_available():
44 | return CPU
45 | device_id = min(torch.cuda.device_count() - 1, device_id)
46 | return torch.device(f'cuda:{device_id}')
47 |
48 |
49 | CUDA = get_device
50 | Optimizer = torch.optim.Adam
51 | Dataset = torch.utils.data.Dataset
52 | DataLoader = torch.utils.data.DataLoader
53 | Subset = torch.utils.data.Subset
54 |
--------------------------------------------------------------------------------
/notebooks/uncleaned_demo/pvd/spaghetti_github/models/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/KAIST-Visual-AI-Group/SALAD/db66b1b2ce57d8a4beb5bd610bc006ccaff43b1e/notebooks/uncleaned_demo/pvd/spaghetti_github/models/__init__.py
--------------------------------------------------------------------------------
/notebooks/uncleaned_demo/pvd/spaghetti_github/models/gm_utils.py:
--------------------------------------------------------------------------------
1 | from ..custom_types import *
2 |
3 |
4 | def get_gm_support(gm, x):
5 | dim = x.shape[-1]
6 | mu, p, phi, eigen = gm
7 | sigma_det = eigen.prod(-1)
8 | eigen_inv = 1 / eigen
9 | sigma_inverse = torch.matmul(p.transpose(3, 4), p * eigen_inv[:, :, :, :, None]).squeeze(1)
10 | phi = torch.softmax(phi, dim=2)
11 | const_1 = phi / torch.sqrt((2 * np.pi) ** dim * sigma_det)
12 | distance = x[:, :, None, :] - mu
13 | mahalanobis_distance = - .5 * torch.einsum('bngd,bgdc,bngc->bng', distance, sigma_inverse, distance)
14 | const_2, _ = mahalanobis_distance.max(dim=2) # for numeric stability
15 | mahalanobis_distance -= const_2[:, :, None]
16 | support = const_1 * torch.exp(mahalanobis_distance)
17 | return support, const_2
18 |
19 |
20 | def gm_log_likelihood_loss(gms: TS, x: T, get_supports: bool = False,
21 | mask: Optional[T] = None, reduction: str = "mean") -> Union[T, Tuple[T, TS]]:
22 |
23 | batch_size, num_points, dim = x.shape
24 | support, const = get_gm_support(gms, x)
25 | probs = torch.log(support.sum(dim=2)) + const
26 | if mask is not None:
27 | probs = probs.masked_select(mask=mask.flatten())
28 | if reduction == 'none':
29 | likelihood = probs.sum(-1)
30 | loss = - likelihood / num_points
31 | else:
32 | likelihood = probs.sum()
33 | loss = - likelihood / (probs.shape[0] * probs.shape[1])
34 | if get_supports:
35 | return loss, support
36 | return loss
37 |
38 |
39 | def split_mesh_by_gmm(mesh: T_Mesh, gmm) -> Dict[int, T]:
40 | faces_split = {}
41 | vs, faces = mesh
42 | vs_mid_faces = vs[faces].mean(1)
43 | _, supports = gm_log_likelihood_loss(gmm, vs_mid_faces.unsqueeze(0), get_supports=True)
44 | supports = supports[0]
45 | label = supports.argmax(1)
46 | for i in range(gmm[1].shape[2]):
47 | select = label.eq(i)
48 | if select.any():
49 | faces_split[i] = faces[select]
50 | else:
51 | faces_split[i] = None
52 | return faces_split
53 |
54 |
55 | def flatten_gmm(gmm: TS) -> T:
56 | b, gp, g, _ = gmm[0].shape
57 | mu, p, phi, eigen = [item.view(b, gp * g, *item.shape[3:]) for item in gmm]
58 | p = p.reshape(*p.shape[:2], -1)
59 | z_gmm = torch.cat((mu, p, phi.unsqueeze(-1), eigen), dim=2)
60 | return z_gmm
61 |
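`flatten_gmm` above packs each Gaussian into a 16-dim vector: 3 (mu) + 9 (rotation P, row-major) + 1 (phi) + 3 (eigenvalues), the same [3, 9, 1, 3] split that the demo utilities later undo with `torch.split`. A shape sketch (illustrative; assumes the module imports as `pvd.spaghetti_github.models.gm_utils`):

import torch
from pvd.spaghetti_github.models.gm_utils import flatten_gmm

b, gp, g = 2, 1, 16
mu    = torch.randn(b, gp, g, 3)
p     = torch.randn(b, gp, g, 3, 3)
phi   = torch.randn(b, gp, g)
eigen = torch.rand(b, gp, g, 3)
z_gmm = flatten_gmm((mu, p, phi, eigen))
print(z_gmm.shape)  # torch.Size([2, 16, 16]): 3 + 9 + 1 + 3 per Gaussian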
--------------------------------------------------------------------------------
/notebooks/uncleaned_demo/pvd/spaghetti_github/models/mlp_models.py:
--------------------------------------------------------------------------------
1 |
2 |
3 |
--------------------------------------------------------------------------------
/notebooks/uncleaned_demo/pvd/spaghetti_github/models/models_utils.py:
--------------------------------------------------------------------------------
1 | from ..custom_types import *
2 | from abc import ABC
3 | import math
4 |
5 |
6 | def torch_no_grad(func):
7 | def wrapper(*args, **kwargs):
8 | with torch.no_grad():
9 | result = func(*args, **kwargs)
10 | return result
11 |
12 | return wrapper
13 |
14 |
15 | class Model(nn.Module, ABC):
16 |
17 | def __init__(self):
18 | super(Model, self).__init__()
19 | self.save_model: Union[None, Callable[[nn.Module], None]] = None
20 |
21 | def save(self, **kwargs):
22 | self.save_model(self, **kwargs)
23 |
24 |
25 | class Concatenate(nn.Module):
26 | def __init__(self, dim):
27 | super(Concatenate, self).__init__()
28 | self.dim = dim
29 |
30 | def forward(self, x):
31 | return torch.cat(x, dim=self.dim)
32 |
33 |
34 | class View(nn.Module):
35 |
36 | def __init__(self, *shape):
37 | super(View, self).__init__()
38 | self.shape = shape
39 |
40 | def forward(self, x):
41 | return x.view(*self.shape)
42 |
43 |
44 | class Transpose(nn.Module):
45 |
46 | def __init__(self, dim0, dim1):
47 | super(Transpose, self).__init__()
48 | self.dim0, self.dim1 = dim0, dim1
49 |
50 | def forward(self, x):
51 | return x.transpose(self.dim0, self.dim1)
52 |
53 |
54 | class Dummy(nn.Module):
55 |
56 | def __init__(self, *args):
57 | super(Dummy, self).__init__()
58 |
59 | def forward(self, *args):
60 | return args[0]
61 |
62 |
63 | class SineLayer(nn.Module):
64 | """
65 | From the siren repository
66 | https://colab.research.google.com/github/vsitzmann/siren/blob/master/explore_siren.ipynb
67 | """
68 | def __init__(self, in_features, out_features, bias=True,
69 | is_first=False, omega_0=30):
70 | super().__init__()
71 | self.omega_0 = omega_0
72 | self.is_first = is_first
73 |
74 | self.in_features = in_features
75 | self.linear = nn.Linear(in_features, out_features, bias=bias)
76 | self.output_channels = out_features
77 | self.init_weights()
78 |
79 | def init_weights(self):
80 | with torch.no_grad():
81 | if self.is_first:
82 | self.linear.weight.uniform_(-1 / self.in_features,
83 | 1 / self.in_features)
84 | else:
85 | self.linear.weight.uniform_(-np.sqrt(6 / self.in_features) / self.omega_0,
86 | np.sqrt(6 / self.in_features) / self.omega_0)
87 |
88 | def forward(self, input):
89 | return torch.sin(self.omega_0 * self.linear(input))
90 |
91 |
92 | class MLP(nn.Module):
93 |
94 | def forward(self, x, *_):
95 | return self.net(x)
96 |
97 | def __init__(self, ch: Union[List[int], Tuple[int, ...]], act: nn.Module = nn.ReLU,
98 | weight_norm=False):
99 | super(MLP, self).__init__()
100 | layers = []
101 | for i in range(len(ch) - 1):
102 | layers.append(nn.Linear(ch[i], ch[i + 1]))
103 | if weight_norm:
104 | layers[-1] = nn.utils.weight_norm(layers[-1])
105 | if i < len(ch) - 2:
106 | layers.append(act(True))
107 | self.net = nn.Sequential(*layers)
108 |
109 |
110 | class GMAttend(nn.Module):
111 |
112 | def __init__(self, hidden_dim: int):
113 | super(GMAttend, self).__init__()
114 | self.key_dim = hidden_dim // 8
115 | self.query_w = nn.Linear(hidden_dim, self.key_dim)
116 | self.key_w = nn.Linear(hidden_dim, self.key_dim)
117 | self.value_w = nn.Linear(hidden_dim, hidden_dim)
118 | self.softmax = nn.Softmax(dim=3)
119 | self.gamma = nn.Parameter(torch.zeros(1))
120 | self.scale = 1 / torch.sqrt(torch.tensor(self.key_dim, dtype=torch.float32))
121 |
122 | def forward(self, x):
123 | queries = self.query_w(x)
124 | keys = self.key_w(x)
125 | vals = self.value_w(x)
126 | attention = self.softmax(torch.einsum('bgqf,bgkf->bgqk', queries, keys))
127 | out = torch.einsum('bgvf,bgqv->bgqf', vals, attention)
128 | out = self.gamma * out + x
129 | return out
130 |
131 |
132 | def recursive_to(item, device):
133 | if type(item) is T:
134 | return item.to(device)
135 | elif type(item) is tuple or type(item) is list:
136 | return [recursive_to(item[i], device) for i in range(len(item))]
137 | return item
138 |
139 |
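The `MLP` builder above activates every layer except the last, leaving the output linear. For instance (illustrative):

from pvd.spaghetti_github.models.models_utils import MLP

net = MLP([3, 64, 64, 1])
print(net.net)  # Linear(3,64) -> ReLU -> Linear(64,64) -> ReLU -> Linear(64,1)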
--------------------------------------------------------------------------------
/notebooks/uncleaned_demo/pvd/spaghetti_github/models/occ_loss.py:
--------------------------------------------------------------------------------
1 | from ..custom_types import *
2 |
3 |
4 | def occupancy_bce(predict: T, winding_gt: T, ignore: Optional[T] = None, *args) -> T:
5 | if winding_gt.dtype is not torch.bool:
6 | winding_gt = winding_gt.gt(0)
7 | labels = winding_gt.flatten().float()
8 | predict = predict.flatten()
9 | if ignore is not None:
10 | ignore = (~ignore).flatten().float()
11 | loss = nnf.binary_cross_entropy_with_logits(predict, labels, weight=ignore)
12 | return loss
13 |
14 |
15 | def reg_z_loss(z: T) -> T:
16 | norms = z.norm(2, 1)
17 | loss = norms.mean()
18 | return loss
19 |
--------------------------------------------------------------------------------
/notebooks/uncleaned_demo/pvd/spaghetti_github/models/transformer.py:
--------------------------------------------------------------------------------
1 | from ..custom_types import *
2 |
3 |
4 | class FeedForward(nn.Module):
5 | def __init__(self, in_dim, h_dim, out_d: Optional[int] = None, act=nnf.relu, dropout=0.):
6 | super().__init__()
7 | out_d = out_d if out_d is not None else in_dim
8 | self.fc1 = nn.Linear(in_dim, h_dim)
9 | self.act = act
10 | self.fc2 = nn.Linear(h_dim, out_d)
11 | self.dropout = nn.Dropout(dropout)
12 |
13 | def forward(self, x):
14 | x = self.fc1(x)
15 | x = self.act(x)
16 | x = self.dropout(x)
17 | x = self.fc2(x)
18 | x = self.dropout(x)
19 | return x
20 |
21 |
22 | class MultiHeadAttention(nn.Module):
23 |
24 | def __init__(self, dim_self, dim_ref, num_heads, bias=True, dropout=0.):
25 | super().__init__()
26 | self.num_heads = num_heads
27 | head_dim = dim_self // num_heads
28 | self.scale = head_dim ** -0.5
29 | self.to_queries = nn.Linear(dim_self, dim_self, bias=bias)
30 | self.to_keys_values = nn.Linear(dim_ref, dim_self * 2, bias=bias)
31 | self.project = nn.Linear(dim_self, dim_self)
32 | self.dropout = nn.Dropout(dropout)
33 |
34 | def forward_interpolation(self, queries: T, keys: T, values: T, alpha: T, mask: TN = None) -> TNS:
35 | attention = torch.einsum('nhd,bmhd->bnmh', queries[0], keys) * self.scale
36 | if mask is not None:
37 | if mask.dim() == 2:
38 | mask = mask.unsqueeze(1)
39 | attention = attention.masked_fill(mask.unsqueeze(3), float("-inf"))
40 | attention = attention.softmax(dim=2)
41 | attention = attention * alpha[:, None, None, None]
42 | out = torch.einsum('bnmh,bmhd->nhd', attention, values).reshape(1, attention.shape[1], -1)
43 | return out, attention
44 |
45 | def forward(self, x, y: Optional[T] = None, mask: Optional[T] = None, alpha: TN = None):
46 | y = y if y is not None else x
47 | b_a, n, c = x.shape
48 | b, m, d = y.shape
49 | # b n h dh
50 | queries = self.to_queries(x).reshape(b_a, n, self.num_heads, c // self.num_heads)
51 | # b m 2 h dh
52 | keys_values = self.to_keys_values(y).reshape(b, m, 2, self.num_heads, c // self.num_heads)
53 | keys, values = keys_values[:, :, 0], keys_values[:, :, 1]
54 | if alpha is not None:
55 | out, attention = self.forward_interpolation(queries, keys, values, alpha, mask)
56 | else:
57 | attention = torch.einsum('bnhd,bmhd->bnmh', queries, keys) * self.scale
58 | if mask is not None:
59 | if mask.dim() == 2:
60 | mask = mask.unsqueeze(1)
61 | attention = attention.masked_fill(mask.unsqueeze(3), float("-inf"))
62 | attention = attention.softmax(dim=2)
63 | out = torch.einsum('bnmh,bmhd->bnhd', attention, values).reshape(b, n, c)
64 | out = self.project(out)
65 | return out, attention
66 |
67 |
68 | class TransformerLayer(nn.Module):
69 |
70 | def forward_with_attention(self, x, y: Optional[T] = None, mask: Optional[T] = None, alpha: TN = None):
71 | x_, attention = self.attn(self.norm1(x), y, mask, alpha)
72 | x = x + x_
73 | x = x + self.mlp(self.norm2(x))
74 | return x, attention
75 |
76 | def forward(self, x, y: Optional[T] = None, mask: Optional[T] = None, alpha: TN = None):
77 | x = x + self.attn(self.norm1(x), y, mask, alpha)[0]
78 | x = x + self.mlp(self.norm2(x))
79 | return x
80 |
81 | def __init__(self, dim_self, dim_ref, num_heads, mlp_ratio=4., bias=False, dropout=0., act=nnf.relu,
82 | norm_layer: nn.Module = nn.LayerNorm):
83 | super().__init__()
84 | self.norm1 = norm_layer(dim_self)
85 | self.attn = MultiHeadAttention(dim_self, dim_ref, num_heads, bias=bias, dropout=dropout)
86 | self.norm2 = norm_layer(dim_self)
87 | self.mlp = FeedForward(dim_self, int(dim_self * mlp_ratio), act=act, dropout=dropout)
88 |
89 |
90 | class DummyTransformer:
91 |
92 | @staticmethod
93 | def forward_with_attention(x, *_, **__):
94 | return x, []
95 |
96 | @staticmethod
97 | def forward(x, *_, **__):
98 | return x
99 |
100 |
101 | class Transformer(nn.Module):
102 |
103 | def forward_with_attention(self, x, y: Optional[T] = None, mask: Optional[T] = None, alpha: TN = None):
104 | attentions = []
105 | for layer in self.layers:
106 | x, att = layer.forward_with_attention(x, y, mask, alpha)
107 | attentions.append(att)
108 | return x, attentions
109 |
110 | def forward(self, x, y: TN = None, mask: TN = None, alpha: TN = None):
111 | for layer in self.layers:
112 | x = layer(x, y, mask, alpha)
113 | return x
114 |
115 | def __init__(self, dim_self: int, num_heads: int, num_layers: int, dim_ref: Optional[int] = None,
116 | mlp_ratio: float = 2., act=nnf.relu, norm_layer: nn.Module = nn.LayerNorm):
117 | super(Transformer, self).__init__()
118 | dim_ref = dim_ref if dim_ref is not None else dim_self
119 | self.layers = nn.ModuleList([TransformerLayer(dim_self, dim_ref, num_heads, mlp_ratio, act=act,
120 | norm_layer=norm_layer) for _ in range(num_layers)])
121 |
--------------------------------------------------------------------------------
/notebooks/uncleaned_demo/pvd/spaghetti_github/options.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 | import os
3 | import pickle
4 | from . import constants as const
5 | from .custom_types import *
6 |
7 |
8 | class Options:
9 |
10 | def load(self):
11 | device = self.device
12 | if os.path.isfile(self.save_path):
13 | print(f'loading options from {self.save_path}')
14 | with open(self.save_path, 'rb') as f:
15 | options = pickle.load(f)
16 | options.device = device
17 | return options
18 | return self
19 |
20 | def save(self):
21 | if os.path.isdir(self.cp_folder):
22 | # self.already_saved = True
23 | with open(self.save_path, 'wb') as f:
24 | pickle.dump(self, f, pickle.HIGHEST_PROTOCOL)
25 |
26 | @property
27 | def info(self) -> str:
28 | return f'{self.model_name}_{self.tag}'
29 |
30 | @property
31 | def cp_folder(self):
32 | return f'{const.CHECKPOINTS_ROOT}{self.info}'
33 |
34 | @property
35 | def save_path(self):
36 | return f'{const.CHECKPOINTS_ROOT}{self.info}/options.pkl'
37 |
38 | def fill_args(self, args):
39 | for arg in args:
40 | if hasattr(self, arg):
41 | setattr(self, arg, args[arg])
42 |
43 | def __init__(self, **kwargs):
44 | self.device = CUDA(0)
45 | self.tag = 'airplanes'
46 | self.dataset_name = 'shapenet_airplanes_wm_sphere_sym_train'
47 | self.epochs = 2000
48 | self.model_name = 'spaghetti'
49 | self.dim_z = 256
50 | self.pos_dim = 256 - 3
51 | self.dim_h = 512
52 | self.dim_zh = 512
53 | self.num_gaussians = 16
54 | self.min_split = 4
55 | self.max_split = 12
56 | self.gmm_weight = 1
57 | self.decomposition_network = 'transformer'
58 | self.decomposition_num_layers = 4
59 | self.num_layers = 4
60 | self.num_heads = 4
61 | self.num_layers_head = 6
62 | self.num_heads_head = 8
63 | self.head_occ_size = 5
64 | self.head_occ_type = 'skip'
65 | self.batch_size = 18
66 | self.num_samples = 2000
67 | self.dataset_size = -1
68 | self.symmetric = (True, False, False)
69 | self.data_symmetric = (True, False, False)
70 | self.lr_decay = .9
71 | self.lr_decay_every = 500
72 | self.warm_up = 2000
73 | self.reg_weight = 1e-4
74 | self.disentanglement = True
75 | self.use_encoder = True
76 | self.disentanglement_weight = 1
77 | self.augmentation_rotation = 0.3
78 | self.augmentation_scale = .2
79 | self.augmentation_translation = .3
80 | self.fill_args(kwargs)
81 |
--------------------------------------------------------------------------------
/notebooks/uncleaned_demo/pvd/spaghetti_github/ui/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/KAIST-Visual-AI-Group/SALAD/db66b1b2ce57d8a4beb5bd610bc006ccaff43b1e/notebooks/uncleaned_demo/pvd/spaghetti_github/ui/__init__.py
--------------------------------------------------------------------------------
/notebooks/uncleaned_demo/pvd/spaghetti_github/ui/mock_keyboard.py:
--------------------------------------------------------------------------------
1 |
2 | class Key:
3 | ctrl_l = 'control_l'
4 |
5 |
6 | class Controller:
7 |
8 | @staticmethod
9 | def press(key: str) -> str:
10 | return key
11 |
12 | @staticmethod
13 | def release(key: str) -> str:
14 | return key
15 |
--------------------------------------------------------------------------------
/notebooks/uncleaned_demo/pvd/spaghetti_github/utils/myparse.py:
--------------------------------------------------------------------------------
1 | from ..custom_types import *
2 | import sys
3 |
4 |
5 | # minimal argparse
6 | def parse(parse_dict: Dict[str, Dict[str, Any]]):
7 | args = sys.argv[1:]
8 | args_dict = {args[i]: args[i + 1] for i in range(0, len(args), 2)}
9 | out_dict = {}
10 | for item in parse_dict:
11 | hint = parse_dict[item]['type']
12 | if item in args_dict:
13 | out_dict[item[2:]] = hint(args_dict[item])
14 | else:
15 | out_dict[item[2:]] = parse_dict[item]['default']
16 | if 'options' in parse_dict[item] and out_dict[item[2:]] not in parse_dict[item]['options']:
17 | raise ValueError(f'Expected {item} to be in {str(parse_dict[item]["options"])}')
18 | return out_dict
19 |
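`parse` above reads flag/value pairs from `sys.argv` and strips the leading `--` from each key, falling back to defaults for omitted flags. Illustrative usage:

import sys
from pvd.spaghetti_github.utils.myparse import parse

sys.argv = ["run.py", "--tag", "chairs", "--epochs", "100"]  # simulated CLI
out = parse({"--tag": {"type": str, "default": "airplanes"},
             "--epochs": {"type": int, "default": 2000}})
print(out)  # {'tag': 'chairs', 'epochs': 100}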
--------------------------------------------------------------------------------
/notebooks/uncleaned_demo/pvd/spaghetti_github/utils/rotation_utils.py:
--------------------------------------------------------------------------------
1 | from .. import constants
2 | import functools
3 | from scipy.spatial.transform.rotation import Rotation
4 | from ..custom_types import *
5 |
6 |
7 | def quat_to_rot(q):
8 | shape = q.shape
9 | q = q.view(-1, 4)
10 | q_sq = 2 * q[:, :, None] * q[:, None, :]
11 | m00 = 1 - q_sq[:, 1, 1] - q_sq[:, 2, 2]
12 | m01 = q_sq[:, 0, 1] - q_sq[:, 2, 3]
13 | m02 = q_sq[:, 0, 2] + q_sq[:, 1, 3]
14 |
15 | m10 = q_sq[:, 0, 1] + q_sq[:, 2, 3]
16 | m11 = 1 - q_sq[:, 0, 0] - q_sq[:, 2, 2]
17 | m12 = q_sq[:, 1, 2] - q_sq[:, 0, 3]
18 |
19 | m20 = q_sq[:, 0, 2] - q_sq[:, 1, 3]
20 | m21 = q_sq[:, 1, 2] + q_sq[:, 0, 3]
21 | m22 = 1 - q_sq[:, 0, 0] - q_sq[:, 1, 1]
22 | r = torch.stack((m00, m01, m02, m10, m11, m12, m20, m21, m22), dim=1)
23 | r = r.view(*shape[:-1], 3, 3)
24 | return r
25 |
26 |
27 | def rot_to_quat(r):
28 | shape = r.shape
29 | r = r.view(-1, 3, 3)
30 | qw = .5 * (1 + r[:, 0, 0] + r[:, 1, 1] + r[:, 2, 2]).sqrt()
31 | qx = (r[:, 2, 1] - r[:, 1, 2]) / (4 * qw)
32 | qy = (r[:, 0, 2] - r[:, 2, 0]) / (4 * qw)
33 | qz = (r[:, 1, 0] - r[:, 0, 1]) / (4 * qw)
34 | q = torch.stack((qx, qy, qz, qw), -1)
35 | q = q.view(*shape[:-2], 4)
36 | return q
37 |
38 |
39 | @functools.lru_cache(10)
40 | def get_rotation_matrix(theta: float, axis: int, degree: bool = False) -> ARRAY:
41 | if degree:
42 | theta = theta * np.pi / 180
43 | rotate_mat = np.eye(3)
44 | rotate_mat[axis, axis] = 1
45 | cos_theta, sin_theta = np.cos(theta), np.sin(theta)
46 | rotate_mat[(axis + 1) % 3, (axis + 1) % 3] = cos_theta
47 | rotate_mat[(axis + 2) % 3, (axis + 2) % 3] = cos_theta
48 | rotate_mat[(axis + 1) % 3, (axis + 2) % 3] = sin_theta
49 | rotate_mat[(axis + 2) % 3, (axis + 1) % 3] = -sin_theta
50 | return rotate_mat
51 |
52 |
53 | def get_random_rotation(batch_size: int) -> T:
54 | r = Rotation.random(batch_size).as_matrix().astype(np.float32)
55 | Rotation.random()
56 | return torch.from_numpy(r)
57 |
58 |
59 | def rand_bounded_rotation_matrix(cache_size: int, theta_range: float = .1):
60 |
61 | def create_cache():
62 | # from http://www.realtimerendering.com/resources/GraphicsGems/gemsiii/rand_rotation.c
63 | with torch.no_grad():
64 | theta, phi, z = torch.rand(cache_size, 3).split((1, 1, 1), dim=1)
65 | theta = (2 * theta - 1) * theta_range + 1
66 | theta = np.pi * theta # Rotation about the pole (Z).
67 | phi = phi * 2 * np.pi # For direction of pole deflection.
68 | z = 2 * z * theta_range # For magnitude of pole deflection.
69 | r = z.sqrt()
70 | v = torch.cat((torch.sin(phi) * r, torch.cos(phi) * r, torch.sqrt(2.0 - z)), dim=1)
71 | st = torch.sin(theta).squeeze(1)
72 | ct = torch.cos(theta).squeeze(1)
73 | rot_ = torch.zeros(cache_size, 3, 3)
74 | rot_[:, 0, 0] = ct
75 | rot_[:, 1, 1] = ct
76 | rot_[:, 0, 1] = st
77 | rot_[:, 1, 0] = -st
78 | rot_[:, 2, 2] = 1
79 | rot = (torch.einsum('ba,bd->bad', v, v) - torch.eye(3)[None, :, :]).bmm(rot_)
80 | det = rot.det()
81 | assert (det.gt(0.99) * det.lt(1.0001)).all().item()
82 | return rot
83 |
84 | def get_batch_rot(batch_size):
85 | nonlocal cache
86 | select = torch.randint(cache_size, size=(batch_size,))
87 | return cache[select]
88 |
89 | cache = create_cache()
90 |
91 | return get_batch_rot
92 |
93 |
94 | def transform_rotation(points: T, one_axis=False, max_angle=-1):
95 | r = get_random_rotation(one_axis, max_angle)
96 | transformed = torch.einsum('nd,rd->nr', points, r)
97 | return transformed
98 |
99 |
100 | def tb_to_rot(abc: T) -> T:
101 | c, s = torch.cos(abc), torch.sin(abc)
102 | aa = c[:, 0] * c[:, 1]
103 | ab = c[:, 0] * s[:, 1] * s[:, 2] - c[:, 2] * s[:, 0]
104 | ac = s[:, 0] * s[:, 2] + c[:, 0] * c[:, 2] * s[:, 1]
105 |
106 | ba = c[:, 1] * s[:, 0]
107 | bb = c[:, 0] * c[:, 2] + s.prod(-1)
108 | bc = c[:, 2] * s[:, 0] * s[:, 1] - c[:, 0] * s[:, 2]
109 |
110 | ca = -s[:, 1]
111 | cb = c[:, 1] * s[:, 2]
112 | cc = c[:, 1] * c[:, 2]
113 | return torch.stack((aa, ab, ac, ba, bb, bc, ca, cb, cc), 1).view(-1, 3, 3)
114 |
115 |
116 | def rot_to_tb(rot: T) -> T:
117 | sy = torch.sqrt(rot[:, 0, 0] * rot[:, 0, 0] + rot[:, 1, 0] * rot[:, 1, 0])
118 | out = torch.zeros(rot.shape[0], 3, device = rot.device)
119 | mask = sy.gt(1e-6)
120 | z = torch.atan2(rot[mask, 2, 1], rot[mask, 2, 2])
121 | y = torch.atan2(-rot[mask, 2, 0], sy[mask])
122 | x = torch.atan2(rot[mask, 1, 0], rot[mask, 0, 0])
123 | out[mask] = torch.stack((x, y, z), dim=1)
124 | if not mask.all():
125 | mask = ~mask
126 | z = torch.atan2(-rot[mask, 1, 2], rot[mask, 1, 1])
127 | y = torch.atan2(-rot[mask, 2, 0], sy[mask])
128 | x = torch.zeros_like(z)  # degenerate (gimbal-lock) branch: this angle is undefined
129 | out[mask] = torch.stack((x, y, z), dim=1)
130 | return out
131 |
132 |
133 | def apply_gmm_affine(gmms: TS, affine: T):
134 | mu, p, phi, eigen = gmms
135 | if affine.dim() == 2:
136 | affine = affine.unsqueeze(0).expand(mu.shape[0], *affine.shape)
137 | mu_r = torch.einsum('bad, bpnd->bpna', affine, mu)
138 | p_r = torch.einsum('bad, bpncd->bpnca', affine, p)
139 | return mu_r, p_r, phi, eigen
140 |
141 |
142 | def get_reflection(reflect_axes: Tuple[bool, ...]) -> T:
143 | reflect = torch.eye(constants.DIM)
144 | for i in range(constants.DIM):
145 | if reflect_axes[i]:
146 | reflect[i, i] = -1
147 | return reflect
148 |
149 |
150 | def get_tait_bryan_from_p(p: T) -> T:
151 | # p = p.squeeze(1)
152 | shape = p.shape
153 | rot = p.reshape(-1, 3, 3).permute(0, 2, 1)
154 | angles = rot_to_tb(rot)
155 | angles = angles / np.pi
156 | angles[:, 1] = angles[:, 1] * 2
157 | angles = angles.view(*shape[:2], 3)
158 | return angles
159 |
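`quat_to_rot` / `rot_to_quat` above use (x, y, z, w) component order with the scalar last, and `rot_to_quat` recovers w as a non-negative square root, so a round trip holds once quaternions are canonicalized to w >= 0 (illustrative check):

import torch
from pvd.spaghetti_github.utils.rotation_utils import quat_to_rot, rot_to_quat

q = torch.randn(8, 4)
q = q / q.norm(dim=-1, keepdim=True)
q = torch.where(q[:, 3:4] < 0, -q, q)  # q and -q encode the same rotation
q2 = rot_to_quat(quat_to_rot(q))
print(torch.allclose(q, q2, atol=1e-4))  # True (up to precision for small w)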
--------------------------------------------------------------------------------
/notebooks/uncleaned_demo/pvd/train_clr.py:
--------------------------------------------------------------------------------
1 | import hydra
2 | import pytorch_lightning as pl
3 | import jutils
4 | @hydra.main(config_path="configs", config_name="simclr.yaml")
5 | def main(config):
6 | pl.seed_everything(63)
7 | jutils.sysutil.print_config(config, ("callbacks", "logger", "paths", "experiment", "debug", "data", "trainer", "model"))
8 | model = hydra.utils.instantiate(config.model, _recursive_=True)
9 |
10 | callbacks = []
11 | if config.get("callbacks"):
12 | for cb_name, cb_conf in config.callbacks.items():
13 | if config.get("debug") and cb_name == "model_checkpoint":
14 | continue
15 | callbacks.append(hydra.utils.instantiate(cb_conf))
16 |
17 | logger = []
18 | if config.get("logger"):
19 | for lg_name, lg_conf in config.logger.items():
20 | if config.get("debug") and lg_name == "wandb":
21 | continue
22 | logger.append(hydra.utils.instantiate(lg_conf))
23 |
24 | trainer = hydra.utils.instantiate(
25 | config.trainer,
26 | callbacks=callbacks,
27 | logger=logger,
28 | _convert_="partial",
29 | log_every_n_steps=200,
30 | )
31 |
32 | trainer.fit(model)
33 |
34 | if __name__ == "__main__":
35 | main()
36 |
--------------------------------------------------------------------------------
/notebooks/uncleaned_demo/pvd/train_ldm.py:
--------------------------------------------------------------------------------
1 | # < Seungwoo >
2 | # Temporarily re-locate files to avoid errors during import.
3 | ####
4 | import jutils
5 |
6 | import hydra
7 | import pytorch_lightning as pl
8 | # import jutils
9 | ####
10 |
11 | @hydra.main(config_path="configs", config_name="train.yaml")
12 | def main(config):
13 | pl.seed_everything(63)
14 |
15 | #### < Seungwoo >
16 | # To make W&B sweep work, need to explicitly initialize W&B
17 | # Remove this line after finishing hyperparameter search
18 | # import wandb
19 | # wandb.init(settings=wandb.Settings(start_method="fork") )
20 | ####
21 |
22 | jutils.sysutil.print_config(config, ("callbacks", "logger", "paths", "experiment", "debug", "data", "trainer", "model"))
23 | model = hydra.utils.instantiate(config.model, _recursive_=True)
24 |
25 | callbacks = []
26 | if config.get("callbacks"):
27 | for cb_name, cb_conf in config.callbacks.items():
28 | if config.get("debug") and cb_name == "model_checkpoint":
29 | continue
30 | # if cb_name == "lr_monitor":
31 | # continue
32 | callbacks.append(hydra.utils.instantiate(cb_conf))
33 |
34 | logger = []
35 | if config.get("logger"):
36 | for lg_name, lg_conf in config.logger.items():
37 | if config.get("debug") and lg_name == "wandb":
38 | continue
39 | logger.append(hydra.utils.instantiate(lg_conf))
40 | trainer = hydra.utils.instantiate(
41 | config.trainer,
42 | callbacks=callbacks,
43 | logger=logger if len(logger) != 0 else False,
44 | # logger=False,
45 | _convert_="partial",
46 | log_every_n_steps=200,
47 | resume_from_checkpoint=config.resume_from_checkpoint
48 | )
49 |
50 | trainer.fit(model)
51 |
52 | if __name__ == "__main__":
53 | main()
54 |
--------------------------------------------------------------------------------
/notebooks/uncleaned_demo/pvd/train_vae.py:
--------------------------------------------------------------------------------
1 | import hydra
2 | import pytorch_lightning as pl
3 | @hydra.main(config_path="configs", config_name="default.yaml")
4 | def main(config):
5 | pl.seed_everything(63)
6 |
7 | model = hydra.utils.instantiate(config.model)
8 |
9 | callbacks = []
10 | if config.get("callbacks"):
11 | for cb_name, cb_conf in config.callbacks.items():
12 | if config.get("debug") and cb_name == "model_checkpoint":
13 | continue
14 | callbacks.append(hydra.utils.instantiate(cb_conf))
15 |
16 | logger = []
17 | if config.get("logger"):
18 | for lg_name, lg_conf in config.logger.items():
19 | if config.get("debug") and lg_name == "wandb":
20 | continue
21 | logger.append(hydra.utils.instantiate(lg_conf))
22 |
23 | trainer = hydra.utils.instantiate(
24 | config.trainer,
25 | callbacks=callbacks,
26 | logger=logger,
27 | _convert_="partial",
28 | log_every_n_steps=100,
29 | )
30 |
31 | trainer.fit(model)
32 | model.save_latents()
33 |
34 | if __name__ == "__main__":
35 | main()
36 |
--------------------------------------------------------------------------------
/notebooks/uncleaned_demo/pvd/utils/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/KAIST-Visual-AI-Group/SALAD/db66b1b2ce57d8a4beb5bd610bc006ccaff43b1e/notebooks/uncleaned_demo/pvd/utils/__init__.py
--------------------------------------------------------------------------------
/notebooks/uncleaned_demo/pvd/utils/contrastive_utils.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 | import jutils
4 |
5 | def make_distance_matrix(feat_a, feat_b):
6 | """
7 | Input:
8 | feat_a: [B,D]
9 | feat_b: [B,D]
10 | """
11 | # Left unimplemented in the source; a plausible sketch (assumption):
12 | # the [B, B] matrix of pairwise Euclidean distances between rows.
13 | return torch.cdist(feat_a, feat_b, p=2)
14 |
--------------------------------------------------------------------------------
/notebooks/uncleaned_demo/pvd/utils/ldm_utils.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 | from taming.modules.vqvae.quantize import VectorQuantizer2 as VectorQuantizer
4 |
5 |
6 | class DiagonalGaussianDistribution(object):
7 | def __init__(self, parameters, deterministic=False, fixed_std=-1):
8 | self.parameters = parameters
9 | self.mean, self.logvar = torch.chunk(parameters, 2, dim=-1)
10 | self.logvar = torch.clamp(self.logvar, -30.0, 20.0)
11 | self.deterministic = deterministic
12 | self.std = torch.exp(0.5 * self.logvar)
13 | self.var = torch.exp(self.logvar)
14 | if self.deterministic:
15 | self.var = self.std = torch.zeros_like(self.mean).to(
16 | device=self.parameters.device
17 | )
18 |
19 | if fixed_std >= 0:
20 | self.std = torch.ones_like(self.std) * fixed_std
21 | self.var = self.std ** 2
22 | self.logvar = torch.log(self.var)
23 |
24 | def sample(self):
25 | x = self.mean + self.std * torch.randn(self.mean.shape).to(
26 | device=self.parameters.device
27 | )
28 | return x
29 |
30 | def kl(self, other=None):
31 | num_dim = len(self.mean.shape)
32 |
33 | if self.deterministic:
34 | return torch.Tensor([0.0])
35 | else:
36 | if other is None:
37 | return 0.5 * torch.sum(
38 | torch.pow(self.mean, 2) + self.var - 1.0 - self.logvar,
39 | # dim=[1, 2, 3]
40 | dim=list(np.arange(1, num_dim)),
41 | )
42 | else:
43 | return 0.5 * torch.sum(
44 | torch.pow(self.mean - other.mean, 2) / other.var
45 | + self.var / other.var
46 | - 1.0
47 | - self.logvar
48 | + other.logvar,
49 | # dim=[1, 2, 3],
50 | dim=list(np.arange(1, num_dim))
51 | )
52 |
53 | def nll(self, sample, dims=[1, 2, 3]):
54 | if self.deterministic:
55 | return torch.Tensor([0.0])
56 | logtwopi = np.log(2.0 * np.pi)
57 | return 0.5 * torch.sum(
58 | logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var,
59 | dim=dims,
60 | )
61 |
62 | def mode(self):
63 | return self.mean
64 |
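`DiagonalGaussianDistribution` above expects mean and log-variance concatenated along the last axis; `kl()` with no argument is the closed-form KL to a standard normal, summed over all non-batch dimensions. Illustrative:

import torch
from pvd.utils.ldm_utils import DiagonalGaussianDistribution

mean, logvar = torch.zeros(4, 16, 8), torch.zeros(4, 16, 8)
dist = DiagonalGaussianDistribution(torch.cat([mean, logvar], dim=-1))
x = dist.sample()   # mean + std * eps, shape (4, 16, 8)
print(dist.kl())    # tensor([0., 0., 0., 0.]): N(0, I) has zero KL to itself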
--------------------------------------------------------------------------------
/notebooks/uncleaned_demo/pvd/utils/train_utils.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import math
3 |
4 |
5 | class PolyDecayScheduler(torch.optim.lr_scheduler.LambdaLR):
6 | def __init__(self, optimizer, init_lr, power=0.99, lr_end=1e-7, last_epoch=-1):
7 | def lr_lambda(step):
8 | lr = max(power**step, lr_end / init_lr)
9 | return lr
10 |
11 | super().__init__(optimizer, lr_lambda, last_epoch=last_epoch)
12 |
13 |
14 | def get_dropout_mask(shape, dropout: float, device):
15 | if dropout == 1:
16 | return torch.zeros(shape, device=device, dtype=torch.bool)
17 | elif dropout == 0:
18 | return torch.ones(shape, device=device, dtype=torch.bool)
19 | else:
20 | return torch.zeros(shape, device=device).float().uniform_(0, 1) > dropout
21 |
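`get_dropout_mask` above returns a boolean keep-mask that is True with probability 1 - dropout, with the two edge cases short-circuited (illustrative):

from pvd.utils.train_utils import get_dropout_mask

print(get_dropout_mask((4,), 0.0, "cpu"))  # all True: keep everything
print(get_dropout_mask((4,), 1.0, "cpu"))  # all False: drop everything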
--------------------------------------------------------------------------------
/notebooks/uncleaned_demo/pvd/utils/transform.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 |
4 | def apply_gmm_affine(gmms, affine):
5 | """
6 | Applies the given Affine transform to Gaussians in the GMM.
7 |
8 | Args:
9 | gmms: (B, M, 16). A batch of GMMs
10 | affine: (3, 3) or (B, 3, 3). A batch of Affine matrices
11 | """
12 | mu, p, phi, eigen = torch.split(gmms, [3, 9, 1, 3], dim=2)
13 | if affine.ndim == 2:
14 | affine = affine.unsqueeze(0).expand(mu.size(0), *affine.shape)
15 |
16 | bs, n_part, _ = mu.shape
17 | p = p.reshape(bs, n_part, 3, 3)
18 |
19 | mu_r = torch.einsum("bad, bnd -> bna", affine, mu)
20 | p_r = torch.einsum("bad, bncd -> bnca", affine, p)
21 | p_r = p_r.reshape(bs, n_part, -1)
22 | gmms_t = torch.cat([mu_r, p_r, phi, eigen], dim=2)
23 | assert gmms.shape == gmms_t.shape, "Input and output shapes must be the same"
24 |
25 | return gmms_t
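A typical use of `apply_gmm_affine` above is mirroring a shape, e.g. reflecting every Gaussian of a flattened (B, M, 16) GMM batch across the x-axis (illustrative):

import torch
from pvd.utils.transform import apply_gmm_affine

gmms = torch.randn(2, 16, 16)  # B=2 shapes, M=16 Gaussians
reflect_x = torch.diag(torch.tensor([-1.0, 1.0, 1.0]))
mirrored = apply_gmm_affine(gmms, reflect_x)
print(mirrored.shape)  # torch.Size([2, 16, 16])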
--------------------------------------------------------------------------------
/notebooks/uncleaned_demo/pvd/vae/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/KAIST-Visual-AI-Group/SALAD/db66b1b2ce57d8a4beb5bd610bc006ccaff43b1e/notebooks/uncleaned_demo/pvd/vae/__init__.py
--------------------------------------------------------------------------------
/notebooks/uncleaned_demo/pvd/vae/ae.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from pointnet2_ops.pointnet2_modules import PointnetSAModule, PointnetSAModuleMSG
4 |
5 |
6 |
--------------------------------------------------------------------------------
/notebooks/uncleaned_demo/utils/__init__.py:
--------------------------------------------------------------------------------
1 | from .common import *
2 | from .gmm_edit import *
3 | from .io_utils import *
4 | from .vis import make_img_grid
--------------------------------------------------------------------------------
/notebooks/uncleaned_demo/utils/common.py:
--------------------------------------------------------------------------------
1 | """
2 | common.py
3 |
4 | A collection of commonly used utility functions.
5 | """
6 |
7 | def detach_cpu_list(list_of_tensor):
8 | result = [item.detach().cpu() for item in list_of_tensor]
9 | return result
10 |
11 |
12 | def compute_batch_sections(n_data: int, batch_size: int):
13 | """
14 | Computes the sections that divide the given data into batches.
15 |
16 | Args:
17 | n_data: int, number of data
18 | batch_size: int, size of a single batch
19 | """
20 | assert isinstance(n_data, int) and n_data > 0
21 | assert isinstance(batch_size, int) and batch_size > 0
22 |
23 | batch_indices = list(range(0, n_data, batch_size))
24 | if batch_indices[-1] != n_data:
25 | batch_indices.append(n_data)
26 |
27 | batch_start = batch_indices[:-1]
28 | batch_end = batch_indices[1:]
29 | assert len(batch_start) == len(batch_end), f"Lengths are different {len(batch_start)} {len(batch_end)}"
30 |
31 | return batch_start, batch_end
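A worked example of `compute_batch_sections` above; the final batch simply holds the remainder (illustrative, run from the demo root so the module imports as `utils.common`):

from utils.common import compute_batch_sections

starts, ends = compute_batch_sections(n_data=10, batch_size=4)
print(starts, ends)  # [0, 4, 8] [4, 8, 10]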
--------------------------------------------------------------------------------
/notebooks/uncleaned_demo/utils/vis.py:
--------------------------------------------------------------------------------
1 | from PIL import Image
2 |
3 | import jutils
4 | import numpy as np
5 | import torch
6 |
7 | from pvd.utils.spaghetti_utils import batch_gaus_to_gmms, get_occ_func
8 |
9 |
10 | def make_img_grid(img_lists, nrow=5):
11 | """
12 | Creates an image grid using the provided images.
13 |
14 | Args:
15 | img_lists: Lists of list of images
16 | nrow: Number of samples in a row.
17 | """
18 | for img_list in img_lists:
19 | assert isinstance(img_list, list)
20 | assert len(img_lists[0]) == len(img_list)
21 | assert len(img_list) > 0
22 | assert isinstance(img_list[0], Image.Image)
23 |
24 | img_pairs = list(zip(*img_lists))
25 |
26 | indices = list(range(0, len(img_lists[0]), nrow))
27 | if indices[-1] != len(img_lists[0]):
28 | indices.append(len(img_lists[0]))
29 | starts, ends = indices[0:-1], indices[1:]
30 |
31 | image_arr = []
32 | for start, end in zip(starts, ends):
33 | row_images = []
34 | for i in range(start, end):
35 | for img in img_pairs[i]:
36 | row_images.append(img)
37 | image_arr.append(row_images)
38 |
39 | return jutils.imageutil.merge_images(image_arr)
40 |
41 |
42 | def decode_and_render(gmm, zh, spaghetti, mesher, camera_kwargs):
43 | if gmm.ndim == 2:
44 | gmm = gmm[None]
45 | if zh.ndim == 2:
46 | zh = zh[None]
47 |
48 | vert, face = decode_gmm_and_intrinsic(
49 | spaghetti, mesher, gmm, zh
50 | )
51 | img = render_mesh(vert, face, camera_kwargs)
52 | return img
53 |
54 |
55 | def decode_gmm_and_intrinsic(spaghetti, mesher, gmm, intrinsic, verbose=False):
56 | assert gmm.ndim == 3 and intrinsic.ndim == 3
57 | assert gmm.shape[0] == intrinsic.shape[0]
58 | assert gmm.shape[1] == intrinsic.shape[1]
59 |
60 | zc, _ = spaghetti.merge_zh(intrinsic, batch_gaus_to_gmms(gmm, gmm.device))
61 |
62 | mesh = mesher.occ_meshing(
63 | decoder=get_occ_func(spaghetti, zc),
64 | res=256,
65 | get_time=False,
66 | verbose=False,
67 | )
68 | assert mesh is not None, "Marching cube failed"
69 | vert, face = list(map(lambda x: jutils.thutil.th2np(x), mesh))
70 | assert isinstance(vert, np.ndarray) and isinstance(face, np.ndarray)
71 |
72 | if verbose:
73 | print(f"Vert: {vert.shape} / Face: {face.shape}")
74 |
75 | # Free GPU memory after computing mesh
76 | _ = zc.cpu()
77 | jutils.sysutil.clean_gpu()
78 | if verbose:
79 | print("Freed GPU memory")
80 |
81 | return vert, face
82 |
83 |
84 | def render_mesh(vert, face, camera_kwargs):
85 | img = jutils.fresnelvis.renderMeshCloud(
86 | mesh={"vert": vert / 2, "face": face}, **camera_kwargs
87 | )
88 | if isinstance(img, np.ndarray):
89 | img = Image.fromarray(img)
90 | assert isinstance(img, Image.Image)
91 | return img
92 |
93 |
--------------------------------------------------------------------------------
/salad/model_components/lstm.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence
5 |
6 |
7 | class LSTM(nn.Module):
8 | def __init__(self, text_dim, embedding_dim, vocab_size, padding_idx=0):
9 | super().__init__()
10 | self.padding_idx = padding_idx
11 | self.word_embedding = nn.Embedding(
12 | vocab_size, embedding_dim, padding_idx=padding_idx
13 | )
14 | self.rnn = nn.LSTM(embedding_dim, text_dim, batch_first=True)
15 | self.w_attn = nn.Parameter(torch.Tensor(1, text_dim))
16 | nn.init.xavier_uniform_(self.w_attn)
17 |
18 | def forward(self, padded_tokens, dropout=0.5):
19 | w_emb = self.word_embedding(padded_tokens)
20 | w_emb = F.dropout(w_emb, dropout, self.training)
21 | len_seq = (padded_tokens != self.padding_idx).sum(dim=1).cpu()
22 | x_packed = pack_padded_sequence(
23 | w_emb, len_seq, enforce_sorted=False, batch_first=True
24 | )
25 | B = padded_tokens.shape[0]
26 | rnn_out, _ = self.rnn(x_packed)
27 | rnn_out, dummy = pad_packed_sequence(rnn_out, batch_first=True)
28 | h = rnn_out[torch.arange(B), len_seq - 1]
29 | final_feat, attn = self.word_attention(rnn_out, h, len_seq)
30 | return final_feat, attn
31 |
32 | def word_attention(self, R, h, len_seq):
33 | """
34 | Input:
35 | R: hidden states of the entire words
36 | h: the final hidden state after processing the entire words
37 | len_seq: the length of the sequence
38 | Output:
39 | final_feat: the final feature after the bilinear attention
40 | attn: word attention weights
41 | """
42 | B, N, D = R.shape
43 | device = R.device
44 | len_seq = len_seq.to(device)
45 |
46 | W_attn = (self.w_attn * torch.eye(D).to(device))[None].repeat(B, 1, 1)
47 | score = torch.bmm(torch.bmm(R, W_attn), h.unsqueeze(-1))
48 |
49 | mask = torch.arange(N).reshape(1, N, 1).repeat(B, 1, 1).to(device)
50 | mask = mask < len_seq.reshape(B, 1, 1)
51 |
52 | score = score.masked_fill(mask == 0, -1e9)
53 | attn = F.softmax(score, 1)
54 | final_feat = torch.bmm(R.transpose(1, 2), attn).squeeze(-1)
55 |
56 | return final_feat, attn.squeeze(-1)
57 |
--------------------------------------------------------------------------------
/salad/model_components/simple_module.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | import math
5 |
6 | from salad.model_components.transformer import TimeMLP
7 |
8 |
9 | class TimePointwiseLayer(nn.Module):
10 | def __init__(
11 | self,
12 | dim_in,
13 | dim_ctx,
14 | mlp_ratio=2,
15 | act=F.leaky_relu,
16 | dropout=0.0,
17 | use_time=False,
18 | ):
19 | super().__init__()
20 | self.use_time = use_time
21 | self.act = act
22 | self.mlp1 = TimeMLP(
23 | dim_in, dim_in * mlp_ratio, dim_in, dim_ctx, use_time=use_time
24 | )
25 | self.norm1 = nn.LayerNorm(dim_in)
26 |
27 | self.mlp2 = TimeMLP(
28 | dim_in, dim_in * mlp_ratio, dim_in, dim_ctx, use_time=use_time
29 | )
30 | self.norm2 = nn.LayerNorm(dim_in)
31 | self.dropout = nn.Dropout(dropout)
32 |
33 | def forward(self, x, ctx=None):
34 | res = x
35 | x = self.mlp1(x, ctx=ctx)
36 | x = self.norm1(x + res)
37 |
38 | res = x
39 | x = self.mlp2(x, ctx=ctx)
40 | x = self.norm2(x + res)
41 | return x
42 |
43 |
44 | class TimePointWiseEncoder(nn.Module):
45 | def __init__(
46 | self,
47 | dim_in,
48 | dim_ctx=None,
49 | mlp_ratio=2,
50 | act=F.leaky_relu,
51 | dropout=0.0,
52 | use_time=True,
53 | num_layers=6,
54 | last_fc=False,
55 | last_fc_dim_out=None,
56 | ):
57 | super().__init__()
58 | self.last_fc = last_fc
59 | if last_fc:
60 | self.fc = nn.Linear(dim_in, last_fc_dim_out)
61 | self.layers = nn.ModuleList(
62 | [
63 | TimePointwiseLayer(
64 | dim_in,
65 | dim_ctx=dim_ctx,
66 | mlp_ratio=mlp_ratio,
67 | act=act,
68 | dropout=dropout,
69 | use_time=use_time,
70 | )
71 | for _ in range(num_layers)
72 | ]
73 | )
74 |
75 | def forward(self, x, ctx=None):
76 | for i, layer in enumerate(self.layers):
77 | x = layer(x, ctx=ctx)
78 | if self.last_fc:
79 | x = self.fc(x)
80 | return x
81 |
82 |
83 | class TimestepEmbedder(nn.Module):
84 | """
85 | Embeds scalar timesteps into vector representations.
86 | """
87 |
88 | def __init__(self, hidden_size, frequency_embedding_size=256):
89 | super().__init__()
90 | self.mlp = nn.Sequential(
91 | nn.Linear(frequency_embedding_size, hidden_size, bias=True),
92 | nn.SiLU(),
93 | nn.Linear(hidden_size, hidden_size, bias=True),
94 | )
95 | self.frequency_embedding_size = frequency_embedding_size
96 |
97 | @staticmethod
98 | def timestep_embedding(t, dim, max_period=10000):
99 | """
100 | Create sinusoidal timestep embeddings.
101 | :param t: a 1-D Tensor of N indices, one per batch element.
102 | These may be fractional.
103 | :param dim: the dimension of the output.
104 | :param max_period: controls the minimum frequency of the embeddings.
105 | :return: an (N, D) Tensor of positional embeddings.
106 | """
107 | # https://github.com/openai/glide-text2im/blob/main/glide_text2im/nn.py
108 | half = dim // 2
109 | freqs = torch.exp(
110 | -math.log(max_period)
111 | * torch.arange(start=0, end=half, dtype=torch.float32)
112 | / half
113 | ).to(device=t.device)
114 | args = t[:, None].float() * freqs[None]
115 | embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
116 | if dim % 2:
117 | embedding = torch.cat(
118 | [embedding, torch.zeros_like(embedding[:, :1])], dim=-1
119 | )
120 | return embedding
121 |
122 | def forward(self, t):
123 | t_freq = self.timestep_embedding(t, self.frequency_embedding_size)
124 | t_emb = self.mlp(t_freq)
125 | return t_emb
126 |
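A minimal usage sketch of `TimestepEmbedder` (the hidden size and batch size below are illustrative, not taken from the repo's configs): integer diffusion timesteps are lifted to sinusoidal features of width `frequency_embedding_size` and then projected by the two-layer MLP.

```python
import torch
from salad.model_components.simple_module import TimestepEmbedder

embedder = TimestepEmbedder(hidden_size=128)  # frequency_embedding_size defaults to 256
t = torch.randint(0, 1000, (8,))              # one scalar timestep per batch element
t_emb = embedder(t)                           # sinusoidal features -> MLP
assert t_emb.shape == (8, 128)
```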
--------------------------------------------------------------------------------
/salad/model_components/variance_schedule.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 | from torch.nn import Module
4 |
5 | class VarianceSchedule(Module):
6 | def __init__(self, num_steps, beta_1, beta_T, mode="linear"):
7 | super().__init__()
8 |         assert mode in ("linear", "quad", "cosine")
9 | self.num_steps = num_steps
10 | self.beta_1 = beta_1
11 | self.beta_T = beta_T
12 | self.mode = mode
13 |
14 | if mode == "linear":
15 | betas = torch.linspace(beta_1, beta_T, steps=num_steps)
16 | elif mode == "quad":
17 | betas = torch.linspace(beta_1 ** 0.5, beta_T ** 0.5, num_steps) ** 2
18 | elif mode == "cosine":
19 | cosine_s = 8e-3
20 | timesteps = torch.arange(num_steps + 1) / num_steps + cosine_s
21 | alphas = timesteps / (1 + cosine_s) * np.pi / 2
22 | alphas = torch.cos(alphas).pow(2)
23 | betas = 1 - alphas[1:] / alphas[:-1]
24 | betas = betas.clamp(max=0.999)
25 |
26 | betas = torch.cat([torch.zeros([1]), betas], dim=0) # Padding
27 |
28 | alphas = 1 - betas
29 | log_alphas = torch.log(alphas)
30 | for i in range(1, log_alphas.size(0)): # 1 to T
31 | log_alphas[i] += log_alphas[i - 1]
32 | alpha_bars = log_alphas.exp()
33 |
34 | sigmas_flex = torch.sqrt(betas)
35 | sigmas_inflex = torch.zeros_like(sigmas_flex)
36 | for i in range(1, sigmas_flex.size(0)):
37 | sigmas_inflex[i] = ((1 - alpha_bars[i - 1]) / (1 - alpha_bars[i])) * betas[
38 | i
39 | ]
40 | sigmas_inflex = torch.sqrt(sigmas_inflex)
41 |
42 | self.register_buffer("betas", betas)
43 | self.register_buffer("alphas", alphas)
44 | self.register_buffer("alpha_bars", alpha_bars)
45 | self.register_buffer("sigmas_flex", sigmas_flex)
46 | self.register_buffer("sigmas_inflex", sigmas_inflex)
47 |
48 | def uniform_sample_t(self, batch_size):
49 | ts = np.random.choice(np.arange(1, self.num_steps + 1), batch_size)
50 | return ts.tolist()
51 |
52 | def get_sigmas(self, t, flexibility):
53 | assert 0 <= flexibility and flexibility <= 1
54 | sigmas = self.sigmas_flex[t] * flexibility + self.sigmas_inflex[t] * (
55 | 1 - flexibility
56 | )
57 | return sigmas
58 |
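A brief usage sketch (hyperparameter values illustrative). Note that the buffers are padded with a leading zero, so they are indexed directly by timesteps t in [1, num_steps]:

```python
import torch
from salad.model_components.variance_schedule import VarianceSchedule

sched = VarianceSchedule(num_steps=1000, beta_1=1e-4, beta_T=0.05, mode="linear")

t = sched.uniform_sample_t(batch_size=4)         # list of 4 ints in [1, 1000]
alpha_bar = sched.alpha_bars[t]                  # cumulative alpha products, shape [4]
sigma = sched.get_sigmas(t[0], flexibility=0.0)  # blends sigmas_flex / sigmas_inflex
print(alpha_bar.shape, float(sigma))
```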
--------------------------------------------------------------------------------
/salad/models/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/KAIST-Visual-AI-Group/SALAD/db66b1b2ce57d8a4beb5bd610bc006ccaff43b1e/salad/models/__init__.py
--------------------------------------------------------------------------------
/salad/models/base_model.py:
--------------------------------------------------------------------------------
1 | import pytorch_lightning as pl
2 | import torch
3 | import torch.nn.functional as F
4 | from salad.data.dataset import SALADDataset
5 | from salad.utils.train_util import PolyDecayScheduler
6 |
7 |
8 | class BaseModel(pl.LightningModule):
9 | def __init__(
10 | self,
11 | network,
12 | variance_schedule,
13 | **kwargs,
14 | ):
15 | super().__init__()
16 | self.save_hyperparameters(logger=False)
17 | self.net = network
18 | self.var_sched = variance_schedule
19 |
20 | def forward(self, x):
21 | return self.get_loss(x)
22 |
23 | def step(self, x, stage: str):
24 | loss = self(x)
25 | self.log(
26 | f"{stage}/loss",
27 | loss,
28 | on_step=stage == "train",
29 | prog_bar=True,
30 | )
31 | return loss
32 |
33 | def training_step(self, batch, batch_idx):
34 | x = batch
35 | return self.step(x, "train")
36 |
37 | def add_noise(self, x, t):
38 | """
39 | Input:
40 | x: [B,D] or [B,G,D]
41 | t: list of size B
42 | Output:
43 | x_noisy: [B,D]
44 | beta: [B]
45 | e_rand: [B,D]
46 | """
47 | alpha_bar = self.var_sched.alpha_bars[t]
48 | beta = self.var_sched.betas[t]
49 |
50 | c0 = torch.sqrt(alpha_bar).view(-1, 1) # [B,1]
51 | c1 = torch.sqrt(1 - alpha_bar).view(-1, 1)
52 |
53 | e_rand = torch.randn_like(x)
54 | if e_rand.dim() == 3:
55 | c0 = c0.unsqueeze(1)
56 | c1 = c1.unsqueeze(1)
57 |
58 | x_noisy = c0 * x + c1 * e_rand
59 |
60 | return x_noisy, beta, e_rand
61 |
62 | def get_loss(
63 | self,
64 | x0,
65 | t=None,
66 | noisy_in=False,
67 | beta_in=None,
68 | e_rand_in=None,
69 | ):
70 | if x0.dim() == 2:
71 | B, D = x0.shape
72 | else:
73 | B, G, D = x0.shape
74 | if not noisy_in:
75 | if t is None:
76 | t = self.var_sched.uniform_sample_t(B)
77 | x_noisy, beta, e_rand = self.add_noise(x0, t)
78 | else:
79 | x_noisy = x0
80 | beta = beta_in
81 | e_rand = e_rand_in
82 |
83 | e_theta = self.net(x_noisy, beta=beta)
84 | loss = F.mse_loss(e_theta.flatten(), e_rand.flatten(), reduction="mean")
85 | return loss
86 |
87 | @torch.no_grad()
88 | def sample(
89 | self,
90 | batch_size=0,
91 | return_traj=False,
92 | ):
93 | raise NotImplementedError
94 |
95 | def validation_epoch_end(self, outputs):
96 | if self.hparams.no_run_validation:
97 | return
98 | if not self.trainer.sanity_checking:
99 |             if self.current_epoch % self.hparams.validation_step == 0:
100 | self.validation()
101 |
102 | def _build_dataset(self, stage):
103 | if hasattr(self, f"data_{stage}"):
104 | return getattr(self, f"data_{stage}")
105 | if stage == "train":
106 | ds = SALADDataset(**self.hparams.dataset_kwargs)
107 | else:
108 | dataset_kwargs = self.hparams.dataset_kwargs.copy()
109 | dataset_kwargs["repeat"] = 1
110 | ds = SALADDataset(**dataset_kwargs)
111 | setattr(self, f"data_{stage}", ds)
112 | return ds
113 |
114 | def _build_dataloader(self, stage):
115 | try:
116 | ds = getattr(self, f"data_{stage}")
117 |         except AttributeError:
118 | ds = self._build_dataset(stage)
119 |
120 | return torch.utils.data.DataLoader(
121 | ds,
122 | batch_size=self.hparams.batch_size,
123 | shuffle=stage == "train",
124 | drop_last=stage == "train",
125 | num_workers=4,
126 | )
127 |
128 | def train_dataloader(self):
129 | return self._build_dataloader("train")
130 |
131 | def val_dataloader(self):
132 | return self._build_dataloader("val")
133 |
134 | def test_dataloader(self):
135 | return self._build_dataloader("test")
136 |
137 | def configure_optimizers(self):
138 | optimizer = torch.optim.Adam(self.parameters(), lr=self.hparams.lr)
139 | scheduler = PolyDecayScheduler(optimizer, self.hparams.lr, power=0.999)
140 | return [optimizer], [scheduler]
141 |
142 | #TODO move get_wandb_logger to logutil.py
143 | def get_wandb_logger(self):
144 | for logger in self.logger:
145 | if isinstance(logger, pl.loggers.wandb.WandbLogger):
146 | return logger
147 | return None
148 |
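In equation form, `add_noise` and `get_loss` above implement the standard DDPM forward process and noise-prediction objective, with $\bar{\alpha}_t$ = `alpha_bars[t]` and the network conditioned on $\beta_t$:

```latex
x_t = \sqrt{\bar{\alpha}_t}\, x_0 + \sqrt{1-\bar{\alpha}_t}\, \epsilon,
\qquad \epsilon \sim \mathcal{N}(0, I),
\qquad
\mathcal{L} = \mathbb{E}_{x_0,\, t,\, \epsilon}\big[\, \| \epsilon_\theta(x_t, \beta_t) - \epsilon \|_2^2 \,\big]
```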
--------------------------------------------------------------------------------
/salad/models/phase1.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 | from salad.models.base_model import BaseModel
4 | from salad.utils import nputil, thutil
5 | from salad.utils.spaghetti_util import clip_eigenvalues, project_eigenvectors
6 |
7 | class Phase1Model(BaseModel):
8 | def __init__(self, network, variance_schedule, **kwargs):
9 | super().__init__(network, variance_schedule, **kwargs)
10 |
11 | @torch.no_grad()
12 | def sample(
13 | self,
14 | batch_size=0,
15 | return_traj=False,
16 | ):
17 | x_T = torch.randn([batch_size, 16, 16]).to(self.device)
18 |
19 | traj = {self.var_sched.num_steps: x_T}
20 | for t in range(self.var_sched.num_steps, 0, -1):
21 | z = torch.randn_like(x_T) if t > 1 else torch.zeros_like(x_T)
22 | alpha = self.var_sched.alphas[t]
23 | alpha_bar = self.var_sched.alpha_bars[t]
24 | sigma = self.var_sched.get_sigmas(t, flexibility=0)
25 |
26 | c0 = 1.0 / torch.sqrt(alpha)
27 | c1 = (1 - alpha) / torch.sqrt(1 - alpha_bar)
28 |
29 | x_t = traj[t]
30 |
31 | beta = self.var_sched.betas[[t] * batch_size]
32 | e_theta = self.net(x_t, beta=beta)
33 | # print(e_theta.norm(-1).mean())
34 |
35 | x_next = c0 * (x_t - c1 * e_theta) + sigma * z
36 | traj[t - 1] = x_next.detach()
37 |
38 | traj[t] = traj[t].cpu()
39 |
40 | if not return_traj:
41 | del traj[t]
42 | if return_traj:
43 | return traj
44 | else:
45 | return traj[0]
46 |
47 | def sampling_gaussians(self, num_shapes):
48 | """
49 | Return:
50 |             ldm_gaus: sampled Gaussians after eigenvalue clipping and
51 |                 eigenvector projection, shape [num_shapes, 16, 16]
52 | """
53 | ldm_gaus = self.sample(num_shapes)
54 |
55 | if self.hparams.get("global_normalization"):
56 | if not hasattr(self, "data_val"):
57 | self._build_dataset("val")
58 | if self.hparams.get("global_normalization") == "partial":
59 | ldm_gaus = self.data_val.unnormalize_global_static(ldm_gaus, slice(12,None))
60 | elif self.hparams.get("global_normalization") == "all":
61 | ldm_gaus = self.data_val.unnormalize_global_static(ldm_gaus, slice(None))
62 |
63 | ldm_gaus = clip_eigenvalues(ldm_gaus)
64 | ldm_gaus = project_eigenvectors(ldm_gaus)
65 | return ldm_gaus
66 |
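The sampling loop above is the standard DDPM ancestral update; with `c0` and `c1` as defined in the code (and $\sigma_t$ from `get_sigmas(t, flexibility=0)`), each step computes:

```latex
x_{t-1} = \underbrace{\frac{1}{\sqrt{\alpha_t}}}_{c_0}
\Big( x_t - \underbrace{\frac{1-\alpha_t}{\sqrt{1-\bar{\alpha}_t}}}_{c_1}\, \epsilon_\theta(x_t, \beta_t) \Big)
+ \sigma_t z,
\qquad z \sim \mathcal{N}(0, I) \ \text{for } t>1, \quad z = 0 \ \text{for } t=1
```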
--------------------------------------------------------------------------------
/salad/spaghetti/.gitignore:
--------------------------------------------------------------------------------
1 | /assets/*
2 | !/assets/readme_resources/
3 | !/assets/ui_resources/
4 | !/assets/splits/
5 | !/assets/mesh/
6 | *.vtk
7 | .idea/
8 | __pycache__/
9 | **_ig_**
10 |
--------------------------------------------------------------------------------
/salad/spaghetti/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2022 Amir Hertz
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/salad/spaghetti/README.md:
--------------------------------------------------------------------------------
1 |
2 | ### SPAGHETTI: Editing Implicit Shapes through Part Aware Generation
3 |
4 | 
5 |
6 |
7 | ### Installation
8 |
9 | ```
10 | git clone https://github.com/amirhertz/spaghetti && cd spaghetti
11 | conda env create -f environment.yml
12 | conda activate spaghetti
13 | ```
14 |
15 | Install [PyTorch](https://pytorch.org/). Development and testing used pytorch==1.9.0 with cudatoolkit=11.1.
16 |
17 |
18 | ### Demo
19 | - Download pre-trained models
20 | ```
21 | python download_weights.py
22 | ```
23 | - Run demo
24 | ```
25 | python demo.py --model_name chairs_large --shape_dir samples
26 | ```
27 | or
28 | ```
29 | python demo.py --model_name airplanes --shape_dir samples
30 | ```
31 |
32 | - User controls
33 |   - Select shapes from the collection at the bottom.
34 |   - Right-click to select / deselect parts.
35 |   - Clicking the pencil button toggles between selection and deselection.
36 |   - Transforming selected parts works like Blender shortcut keys:
37 |     pressing 'G' / 'R' starts translation / rotation mode. Toggle the axis by pressing 'X' / 'Y' / 'Z'. Press 'Esc' to cancel the transform.
38 |   - Click the broom to reset.
39 |
40 |
41 | ### Adding shapes to demo
42 | - From training data
43 | ```
44 | python shape_inversion.py --model_name <model_name> --source training --mesh_path <mesh_path> --num_samples <num_samples>
45 | ```
46 | - Random generation
47 | ```
48 | python shape_inversion.py --model_name <model_name> --source random --mesh_path <mesh_path> --num_samples <num_samples>
49 | ```
50 | - From existing watertight mesh:
51 | ```
52 | python shape_inversion.py --model_name <model_name> --mesh_path <mesh_path>
53 | ```
54 | For example, to add the provided sample chair to the existing chairs in the demo:
55 | ```
56 | python shape_inversion.py --model_name chairs_large --mesh_path ./assets/mesh/example.obj
57 | ```
58 |
59 | ### Training
60 | Coming soon.
61 |
62 |
--------------------------------------------------------------------------------
/salad/spaghetti/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/KAIST-Visual-AI-Group/SALAD/db66b1b2ce57d8a4beb5bd610bc006ccaff43b1e/salad/spaghetti/__init__.py
--------------------------------------------------------------------------------
/salad/spaghetti/assets/checkpoints/.gitkeep:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/KAIST-Visual-AI-Group/SALAD/db66b1b2ce57d8a4beb5bd610bc006ccaff43b1e/salad/spaghetti/assets/checkpoints/.gitkeep
--------------------------------------------------------------------------------
/salad/spaghetti/assets/checkpoints/README.md:
--------------------------------------------------------------------------------
1 | Put SPAGHETTI checkpoint directories under this directory. For instance, the full path to the `spaghetti_chairs_large` directory would be `salad/salad/spaghetti/assets/checkpoints/spaghetti_chairs_large`.
2 |
--------------------------------------------------------------------------------
/salad/spaghetti/assets/readme_resources/teaser-01.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/KAIST-Visual-AI-Group/SALAD/db66b1b2ce57d8a4beb5bd610bc006ccaff43b1e/salad/spaghetti/assets/readme_resources/teaser-01.png
--------------------------------------------------------------------------------
/salad/spaghetti/assets/ui_resources/icons-03.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/KAIST-Visual-AI-Group/SALAD/db66b1b2ce57d8a4beb5bd610bc006ccaff43b1e/salad/spaghetti/assets/ui_resources/icons-03.png
--------------------------------------------------------------------------------
/salad/spaghetti/assets/ui_resources/icons-04.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/KAIST-Visual-AI-Group/SALAD/db66b1b2ce57d8a4beb5bd610bc006ccaff43b1e/salad/spaghetti/assets/ui_resources/icons-04.png
--------------------------------------------------------------------------------
/salad/spaghetti/assets/ui_resources/icons-05.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/KAIST-Visual-AI-Group/SALAD/db66b1b2ce57d8a4beb5bd610bc006ccaff43b1e/salad/spaghetti/assets/ui_resources/icons-05.png
--------------------------------------------------------------------------------
/salad/spaghetti/assets/ui_resources/icons-06.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/KAIST-Visual-AI-Group/SALAD/db66b1b2ce57d8a4beb5bd610bc006ccaff43b1e/salad/spaghetti/assets/ui_resources/icons-06.png
--------------------------------------------------------------------------------
/salad/spaghetti/assets/ui_resources/icons-10.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/KAIST-Visual-AI-Group/SALAD/db66b1b2ce57d8a4beb5bd610bc006ccaff43b1e/salad/spaghetti/assets/ui_resources/icons-10.png
--------------------------------------------------------------------------------
/salad/spaghetti/assets/ui_resources/icons-11.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/KAIST-Visual-AI-Group/SALAD/db66b1b2ce57d8a4beb5bd610bc006ccaff43b1e/salad/spaghetti/assets/ui_resources/icons-11.png
--------------------------------------------------------------------------------
/salad/spaghetti/assets/ui_resources/icons-12.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/KAIST-Visual-AI-Group/SALAD/db66b1b2ce57d8a4beb5bd610bc006ccaff43b1e/salad/spaghetti/assets/ui_resources/icons-12.png
--------------------------------------------------------------------------------
/salad/spaghetti/assets/ui_resources/icons-13.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/KAIST-Visual-AI-Group/SALAD/db66b1b2ce57d8a4beb5bd610bc006ccaff43b1e/salad/spaghetti/assets/ui_resources/icons-13.png
--------------------------------------------------------------------------------
/salad/spaghetti/assets/ui_resources/icons-14.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/KAIST-Visual-AI-Group/SALAD/db66b1b2ce57d8a4beb5bd610bc006ccaff43b1e/salad/spaghetti/assets/ui_resources/icons-14.png
--------------------------------------------------------------------------------
/salad/spaghetti/constants.py:
--------------------------------------------------------------------------------
1 | import os
2 | from pathlib import Path
3 | import sys
4 | from salad.utils.paths import SPAGHETTI_DIR
5 |
6 | IS_WINDOWS = sys.platform == 'win32'
7 | get_trace = getattr(sys, 'gettrace', None)
8 | DEBUG = get_trace is not None and get_trace() is not None
9 | EPSILON = 1e-4
10 | DIM = 3
11 | # PROJECT_ROOT = str(Path(os.path.realpath(__file__)).parents[0])
12 | PROJECT_ROOT = str(SPAGHETTI_DIR)
13 | DATA_ROOT = f'{PROJECT_ROOT}/assets/'
14 | CHECKPOINTS_ROOT = f'{DATA_ROOT}checkpoints/'
15 | CACHE_ROOT = f'{DATA_ROOT}cache/'
16 | UI_OUT = f'{DATA_ROOT}ui_export/'
17 | UI_RESOURCES = f'{DATA_ROOT}ui_resources/'  # DATA_ROOT already ends with '/'
18 | Shapenet_WT = f'{DATA_ROOT}ShapeNetCore_wt/'
19 | Shapenet = f'{DATA_ROOT}ShapeNetCore.v2/'
20 | MAX_VS = 100000
21 | MAX_GAUSIANS = 32
22 |
23 |
--------------------------------------------------------------------------------
/salad/spaghetti/custom_types.py:
--------------------------------------------------------------------------------
1 | # import open3d
2 | import enum
3 | import numpy as np
4 | import torch
5 | import torch.nn as nn
6 | import torch.nn.functional as nnf
7 | # from .constants import DEBUG
8 | from typing import Tuple, List, Union, Callable, Type, Iterator, Dict, Set, Optional, Any, Sized, Iterable
9 | from types import DynamicClassAttribute
10 | from enum import Enum, unique
11 | import torch.optim.optimizer
12 | import torch.utils.data
13 |
14 | # if DEBUG:
15 | # seed = 99
16 | # torch.manual_seed(seed)
17 | # np.random.seed(seed)
18 |
19 | N = type(None)
20 | V = np.array
21 | ARRAY = np.ndarray
22 | ARRAYS = Union[Tuple[ARRAY, ...], List[ARRAY]]
23 | VS = Union[Tuple[V, ...], List[V]]
24 | VN = Union[V, N]
25 | VNS = Union[VS, N]
26 | T = torch.Tensor
27 | TS = Union[Tuple[T, ...], List[T]]
28 | TN = Optional[T]
29 | TNS = Union[Tuple[TN, ...], List[TN]]
30 | TSN = Optional[TS]
31 | TA = Union[T, ARRAY]
32 |
33 | V_Mesh = Tuple[ARRAY, ARRAY]
34 | T_Mesh = Tuple[T, Optional[T]]
35 | T_Mesh_T = Union[T_Mesh, T]
36 | COLORS = Union[T, ARRAY, Tuple[int, int, int]]
37 |
38 | D = torch.device
39 | CPU = torch.device('cpu')
40 |
41 |
42 | def get_device(device_id: int) -> D:
43 | if not torch.cuda.is_available():
44 | return CPU
45 | device_id = min(torch.cuda.device_count() - 1, device_id)
46 | return torch.device(f'cuda:{device_id}')
47 |
48 |
49 | CUDA = get_device
50 | Optimizer = torch.optim.Adam
51 | Dataset = torch.utils.data.Dataset
52 | DataLoader = torch.utils.data.DataLoader
53 | Subset = torch.utils.data.Subset
54 |
--------------------------------------------------------------------------------
/salad/spaghetti/models/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/KAIST-Visual-AI-Group/SALAD/db66b1b2ce57d8a4beb5bd610bc006ccaff43b1e/salad/spaghetti/models/__init__.py
--------------------------------------------------------------------------------
/salad/spaghetti/models/gm_utils.py:
--------------------------------------------------------------------------------
1 | from ..custom_types import *
2 |
3 |
4 | def get_gm_support(gm, x):
5 | dim = x.shape[-1]
6 | mu, p, phi, eigen = gm
7 | sigma_det = eigen.prod(-1)
8 | eigen_inv = 1 / eigen
9 | sigma_inverse = torch.matmul(p.transpose(3, 4), p * eigen_inv[:, :, :, :, None]).squeeze(1)
10 | phi = torch.softmax(phi, dim=2)
11 | const_1 = phi / torch.sqrt((2 * np.pi) ** dim * sigma_det)
12 | distance = x[:, :, None, :] - mu
13 | mahalanobis_distance = - .5 * torch.einsum('bngd,bgdc,bngc->bng', distance, sigma_inverse, distance)
14 | const_2, _ = mahalanobis_distance.max(dim=2) # for numeric stability
15 | mahalanobis_distance -= const_2[:, :, None]
16 | support = const_1 * torch.exp(mahalanobis_distance)
17 | return support, const_2
18 |
19 |
20 | def gm_log_likelihood_loss(gms: TS, x: T, get_supports: bool = False,
21 | mask: Optional[T] = None, reduction: str = "mean") -> Union[T, Tuple[T, TS]]:
22 |
23 | batch_size, num_points, dim = x.shape
24 | support, const = get_gm_support(gms, x)
25 | probs = torch.log(support.sum(dim=2)) + const
26 | if mask is not None:
27 | probs = probs.masked_select(mask=mask.flatten())
28 | if reduction == 'none':
29 | likelihood = probs.sum(-1)
30 | loss = - likelihood / num_points
31 | else:
32 | likelihood = probs.sum()
33 | loss = - likelihood / (probs.shape[0] * probs.shape[1])
34 | if get_supports:
35 | return loss, support
36 | return loss
37 |
38 |
39 | def split_mesh_by_gmm(mesh: T_Mesh, gmm) -> Dict[int, T]:
40 | faces_split = {}
41 | vs, faces = mesh
42 | vs_mid_faces = vs[faces].mean(1)
43 | _, supports = gm_log_likelihood_loss(gmm, vs_mid_faces.unsqueeze(0), get_supports=True)
44 | supports = supports[0]
45 | label = supports.argmax(1)
46 | for i in range(gmm[1].shape[2]):
47 | select = label.eq(i)
48 | if select.any():
49 | faces_split[i] = faces[select]
50 | else:
51 | faces_split[i] = None
52 | return faces_split
53 |
54 |
55 | def flatten_gmm(gmm: TS) -> T:
56 | b, gp, g, _ = gmm[0].shape
57 | mu, p, phi, eigen = [item.view(b, gp * g, *item.shape[3:]) for item in gmm]
58 | p = p.reshape(*p.shape[:2], -1)
59 | z_gmm = torch.cat((mu, p, phi.unsqueeze(-1), eigen), dim=2)
60 | return z_gmm
61 |
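A shape-level sketch of `flatten_gmm` with random tensors (batch and group sizes illustrative): each Gaussian flattens to 3 (mu) + 9 (rotation P) + 1 (phi) + 3 (eigenvalues) = 16 values, which matches the 16-dimensional per-Gaussian rows sampled in `Phase1Model`.

```python
import torch
from salad.spaghetti.models.gm_utils import flatten_gmm

b, gp, g = 2, 1, 16                    # batch, parts per group, gaussians
mu    = torch.randn(b, gp, g, 3)       # centers
p     = torch.randn(b, gp, g, 3, 3)    # per-Gaussian rotation matrices
phi   = torch.randn(b, gp, g)          # mixture logits
eigen = torch.rand(b, gp, g, 3)        # eigenvalues

z_gmm = flatten_gmm((mu, p, phi, eigen))
assert z_gmm.shape == (b, gp * g, 3 + 9 + 1 + 3)   # [2, 16, 16]
```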
--------------------------------------------------------------------------------
/salad/spaghetti/models/mlp_models.py:
--------------------------------------------------------------------------------
1 |
2 |
3 |
--------------------------------------------------------------------------------
/salad/spaghetti/models/models_utils.py:
--------------------------------------------------------------------------------
1 | from ..custom_types import *
2 | from abc import ABC
3 | import math
4 |
5 |
6 | def torch_no_grad(func):
7 | def wrapper(*args, **kwargs):
8 | with torch.no_grad():
9 | result = func(*args, **kwargs)
10 | return result
11 |
12 | return wrapper
13 |
14 |
15 | class Model(nn.Module, ABC):
16 |
17 | def __init__(self):
18 | super(Model, self).__init__()
19 |         self.save_model: Union[None, Callable[[nn.Module], None]] = None
20 |
21 | def save(self, **kwargs):
22 | self.save_model(self, **kwargs)
23 |
24 |
25 | class Concatenate(nn.Module):
26 | def __init__(self, dim):
27 | super(Concatenate, self).__init__()
28 | self.dim = dim
29 |
30 | def forward(self, x):
31 | return torch.cat(x, dim=self.dim)
32 |
33 |
34 | class View(nn.Module):
35 |
36 | def __init__(self, *shape):
37 | super(View, self).__init__()
38 | self.shape = shape
39 |
40 | def forward(self, x):
41 | return x.view(*self.shape)
42 |
43 |
44 | class Transpose(nn.Module):
45 |
46 | def __init__(self, dim0, dim1):
47 | super(Transpose, self).__init__()
48 | self.dim0, self.dim1 = dim0, dim1
49 |
50 | def forward(self, x):
51 | return x.transpose(self.dim0, self.dim1)
52 |
53 |
54 | class Dummy(nn.Module):
55 |
56 | def __init__(self, *args):
57 | super(Dummy, self).__init__()
58 |
59 | def forward(self, *args):
60 | return args[0]
61 |
62 |
63 | class SineLayer(nn.Module):
64 | """
65 | From the siren repository
66 | https://colab.research.google.com/github/vsitzmann/siren/blob/master/explore_siren.ipynb
67 | """
68 | def __init__(self, in_features, out_features, bias=True,
69 | is_first=False, omega_0=30):
70 | super().__init__()
71 | self.omega_0 = omega_0
72 | self.is_first = is_first
73 |
74 | self.in_features = in_features
75 | self.linear = nn.Linear(in_features, out_features, bias=bias)
76 | self.output_channels = out_features
77 | self.init_weights()
78 |
79 | def init_weights(self):
80 | with torch.no_grad():
81 | if self.is_first:
82 | self.linear.weight.uniform_(-1 / self.in_features,
83 | 1 / self.in_features)
84 | else:
85 | self.linear.weight.uniform_(-np.sqrt(6 / self.in_features) / self.omega_0,
86 | np.sqrt(6 / self.in_features) / self.omega_0)
87 |
88 | def forward(self, input):
89 | return torch.sin(self.omega_0 * self.linear(input))
90 |
91 |
92 | class MLP(nn.Module):
93 |
94 | def forward(self, x, *_):
95 | return self.net(x)
96 |
97 | def __init__(self, ch: Union[List[int], Tuple[int, ...]], act: nn.Module = nn.ReLU,
98 | weight_norm=False):
99 | super(MLP, self).__init__()
100 | layers = []
101 | for i in range(len(ch) - 1):
102 | layers.append(nn.Linear(ch[i], ch[i + 1]))
103 | if weight_norm:
104 | layers[-1] = nn.utils.weight_norm(layers[-1])
105 | if i < len(ch) - 2:
106 | layers.append(act(True))
107 | self.net = nn.Sequential(*layers)
108 |
109 |
110 | class GMAttend(nn.Module):
111 |
112 | def __init__(self, hidden_dim: int):
113 | super(GMAttend, self).__init__()
114 | self.key_dim = hidden_dim // 8
115 | self.query_w = nn.Linear(hidden_dim, self.key_dim)
116 | self.key_w = nn.Linear(hidden_dim, self.key_dim)
117 | self.value_w = nn.Linear(hidden_dim, hidden_dim)
118 | self.softmax = nn.Softmax(dim=3)
119 | self.gamma = nn.Parameter(torch.zeros(1))
120 | self.scale = 1 / torch.sqrt(torch.tensor(self.key_dim, dtype=torch.float32))
121 |
122 | def forward(self, x):
123 | queries = self.query_w(x)
124 | keys = self.key_w(x)
125 | vals = self.value_w(x)
126 | attention = self.softmax(torch.einsum('bgqf,bgkf->bgqk', queries, keys))
127 | out = torch.einsum('bgvf,bgqv->bgqf', vals, attention)
128 | out = self.gamma * out + x
129 | return out
130 |
131 |
132 | def recursive_to(item, device):
133 | if type(item) is T:
134 | return item.to(device)
135 | elif type(item) is tuple or type(item) is list:
136 | return [recursive_to(item[i], device) for i in range(len(item))]
137 | return item
138 |
139 |
--------------------------------------------------------------------------------
/salad/spaghetti/models/occ_loss.py:
--------------------------------------------------------------------------------
1 | from ..custom_types import *
2 |
3 |
4 | def occupancy_bce(predict: T, winding_gt: T, ignore: Optional[T] = None, *args) -> T:
5 | if winding_gt.dtype is not torch.bool:
6 | winding_gt = winding_gt.gt(0)
7 | labels = winding_gt.flatten().float()
8 | predict = predict.flatten()
9 | if ignore is not None:
10 | ignore = (~ignore).flatten().float()
11 | loss = nnf.binary_cross_entropy_with_logits(predict, labels, weight=ignore)
12 | return loss
13 |
14 |
15 | def reg_z_loss(z: T) -> T:
16 | norms = z.norm(2, 1)
17 | loss = norms.mean()
18 | return loss
19 |
--------------------------------------------------------------------------------
/salad/spaghetti/models/transformer.py:
--------------------------------------------------------------------------------
1 | from ..custom_types import *
2 |
3 |
4 | class FeedForward(nn.Module):
5 | def __init__(self, in_dim, h_dim, out_d: Optional[int] = None, act=nnf.relu, dropout=0.):
6 | super().__init__()
7 | out_d = out_d if out_d is not None else in_dim
8 | self.fc1 = nn.Linear(in_dim, h_dim)
9 | self.act = act
10 | self.fc2 = nn.Linear(h_dim, out_d)
11 | self.dropout = nn.Dropout(dropout)
12 |
13 | def forward(self, x):
14 | x = self.fc1(x)
15 | x = self.act(x)
16 | x = self.dropout(x)
17 | x = self.fc2(x)
18 | x = self.dropout(x)
19 | return x
20 |
21 |
22 | class MultiHeadAttention(nn.Module):
23 |
24 | def __init__(self, dim_self, dim_ref, num_heads, bias=True, dropout=0.):
25 | super().__init__()
26 | self.num_heads = num_heads
27 | head_dim = dim_self // num_heads
28 | self.scale = head_dim ** -0.5
29 | self.to_queries = nn.Linear(dim_self, dim_self, bias=bias)
30 | self.to_keys_values = nn.Linear(dim_ref, dim_self * 2, bias=bias)
31 | self.project = nn.Linear(dim_self, dim_self)
32 | self.dropout = nn.Dropout(dropout)
33 |
34 | def forward_interpolation(self, queries: T, keys: T, values: T, alpha: T, mask: TN = None) -> TNS:
35 | attention = torch.einsum('nhd,bmhd->bnmh', queries[0], keys) * self.scale
36 | if mask is not None:
37 | if mask.dim() == 2:
38 | mask = mask.unsqueeze(1)
39 | attention = attention.masked_fill(mask.unsqueeze(3), float("-inf"))
40 | attention = attention.softmax(dim=2)
41 | attention = attention * alpha[:, None, None, None]
42 | out = torch.einsum('bnmh,bmhd->nhd', attention, values).reshape(1, attention.shape[1], -1)
43 | return out, attention
44 |
45 | def forward(self, x, y: Optional[T] = None, mask: Optional[T] = None, alpha: TN = None):
46 | y = y if y is not None else x
47 | b_a, n, c = x.shape
48 | b, m, d = y.shape
49 | # b n h dh
50 | queries = self.to_queries(x).reshape(b_a, n, self.num_heads, c // self.num_heads)
51 | # b m 2 h dh
52 | keys_values = self.to_keys_values(y).reshape(b, m, 2, self.num_heads, c // self.num_heads)
53 | keys, values = keys_values[:, :, 0], keys_values[:, :, 1]
54 | if alpha is not None:
55 | out, attention = self.forward_interpolation(queries, keys, values, alpha, mask)
56 | else:
57 | attention = torch.einsum('bnhd,bmhd->bnmh', queries, keys) * self.scale
58 | if mask is not None:
59 | if mask.dim() == 2:
60 | mask = mask.unsqueeze(1)
61 | attention = attention.masked_fill(mask.unsqueeze(3), float("-inf"))
62 | attention = attention.softmax(dim=2)
63 | out = torch.einsum('bnmh,bmhd->bnhd', attention, values).reshape(b, n, c)
64 | out = self.project(out)
65 | return out, attention
66 |
67 |
68 | class TransformerLayer(nn.Module):
69 |
70 | def forward_with_attention(self, x, y: Optional[T] = None, mask: Optional[T] = None, alpha: TN = None):
71 | x_, attention = self.attn(self.norm1(x), y, mask, alpha)
72 | x = x + x_
73 | x = x + self.mlp(self.norm2(x))
74 | return x, attention
75 |
76 | def forward(self, x, y: Optional[T] = None, mask: Optional[T] = None, alpha: TN = None):
77 | x = x + self.attn(self.norm1(x), y, mask, alpha)[0]
78 | x = x + self.mlp(self.norm2(x))
79 | return x
80 |
81 | def __init__(self, dim_self, dim_ref, num_heads, mlp_ratio=4., bias=False, dropout=0., act=nnf.relu,
82 | norm_layer: nn.Module = nn.LayerNorm):
83 | super().__init__()
84 | self.norm1 = norm_layer(dim_self)
85 | self.attn = MultiHeadAttention(dim_self, dim_ref, num_heads, bias=bias, dropout=dropout)
86 | self.norm2 = norm_layer(dim_self)
87 | self.mlp = FeedForward(dim_self, int(dim_self * mlp_ratio), act=act, dropout=dropout)
88 |
89 |
90 | class DummyTransformer:
91 |
92 | @staticmethod
93 | def forward_with_attention(x, *_, **__):
94 | return x, []
95 |
96 | @staticmethod
97 | def forward(x, *_, **__):
98 | return x
99 |
100 |
101 | class Transformer(nn.Module):
102 |
103 | def forward_with_attention(self, x, y: Optional[T] = None, mask: Optional[T] = None, alpha: TN = None):
104 | attentions = []
105 | for layer in self.layers:
106 | x, att = layer.forward_with_attention(x, y, mask, alpha)
107 | attentions.append(att)
108 | return x, attentions
109 |
110 | def forward(self, x, y: TN = None, mask: TN = None, alpha: TN = None):
111 | for layer in self.layers:
112 | x = layer(x, y, mask, alpha)
113 | return x
114 |
115 | def __init__(self, dim_self: int, num_heads: int, num_layers: int, dim_ref: Optional[int] = None,
116 | mlp_ratio: float = 2., act=nnf.relu, norm_layer: nn.Module = nn.LayerNorm):
117 | super(Transformer, self).__init__()
118 | dim_ref = dim_ref if dim_ref is not None else dim_self
119 | self.layers = nn.ModuleList([TransformerLayer(dim_self, dim_ref, num_heads, mlp_ratio, act=act,
120 | norm_layer=norm_layer) for _ in range(num_layers)])
121 |
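A minimal smoke test for the `Transformer` above (dimensions illustrative); when no reference tensor `y` is given, the layers run plain self-attention:

```python
import torch
from salad.spaghetti.models.transformer import Transformer

model = Transformer(dim_self=512, num_heads=8, num_layers=4)  # dim_ref defaults to dim_self
x = torch.randn(2, 16, 512)                                   # [batch, tokens, dim]
out = model(x)                                                # y defaults to x (self-attention)
out, attentions = model.forward_with_attention(x)             # per-layer attention maps
assert out.shape == x.shape and len(attentions) == 4
```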
--------------------------------------------------------------------------------
/salad/spaghetti/options.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 | import os
3 | import pickle
4 | import salad.spaghetti.constants as const
5 | from salad.spaghetti.custom_types import *
6 |
7 |
8 | class Options:
9 |
10 | def load(self):
11 | device = self.device
12 | if os.path.isfile(self.save_path):
13 |             print(f'loading options from {self.save_path}')
14 | with open(self.save_path, 'rb') as f:
15 | options = pickle.load(f)
16 | options.device = device
17 | return options
18 | return self
19 |
20 | def save(self):
21 | if os.path.isdir(self.cp_folder):
22 | # self.already_saved = True
23 | with open(self.save_path, 'wb') as f:
24 | pickle.dump(self, f, pickle.HIGHEST_PROTOCOL)
25 |
26 | @property
27 | def info(self) -> str:
28 | return f'{self.model_name}_{self.tag}'
29 |
30 | @property
31 | def cp_folder(self):
32 | return f'{const.CHECKPOINTS_ROOT}{self.info}'
33 |
34 | @property
35 | def save_path(self):
36 | return f'{const.CHECKPOINTS_ROOT}{self.info}/options.pkl'
37 |
38 | def fill_args(self, args):
39 | for arg in args:
40 | if hasattr(self, arg):
41 | setattr(self, arg, args[arg])
42 |
43 | def __init__(self, **kwargs):
44 | self.device = CUDA(0)
45 | self.tag = 'airplanes'
46 | self.dataset_name = 'shapenet_airplanes_wm_sphere_sym_train'
47 | self.epochs = 2000
48 | self.model_name = 'spaghetti'
49 | self.dim_z = 256
50 | self.pos_dim = 256 - 3
51 | self.dim_h = 512
52 | self.dim_zh = 512
53 | self.num_gaussians = 16
54 | self.min_split = 4
55 | self.max_split = 12
56 | self.gmm_weight = 1
57 | self.decomposition_network = 'transformer'
58 | self.decomposition_num_layers = 4
59 | self.num_layers = 4
60 | self.num_heads = 4
61 | self.num_layers_head = 6
62 | self.num_heads_head = 8
63 | self.head_occ_size = 5
64 | self.head_occ_type = 'skip'
65 | self.batch_size = 18
66 | self.num_samples = 2000
67 | self.dataset_size = -1
68 | self.symmetric = (True, False, False)
69 | self.data_symmetric = (True, False, False)
70 | self.lr_decay = .9
71 | self.lr_decay_every = 500
72 | self.warm_up = 2000
73 | self.reg_weight = 1e-4
74 | self.disentanglement = True
75 | self.use_encoder = True
76 | self.disentanglement_weight = 1
77 | self.augmentation_rotation = 0.3
78 | self.augmentation_scale = .2
79 | self.augmentation_translation = .3
80 | self.fill_args(kwargs)
81 |
--------------------------------------------------------------------------------
/salad/spaghetti/ui/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/KAIST-Visual-AI-Group/SALAD/db66b1b2ce57d8a4beb5bd610bc006ccaff43b1e/salad/spaghetti/ui/__init__.py
--------------------------------------------------------------------------------
/salad/spaghetti/ui/mock_keyboard.py:
--------------------------------------------------------------------------------
1 |
2 | class Key:
3 | ctrl_l = 'control_l'
4 |
5 |
6 | class Controller:
7 |
8 | @staticmethod
9 | def press(key: str) -> str:
10 | return key
11 |
12 | @staticmethod
13 | def release(key: str) -> str:
14 | return key
15 |
--------------------------------------------------------------------------------
/salad/spaghetti/utils/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/KAIST-Visual-AI-Group/SALAD/db66b1b2ce57d8a4beb5bd610bc006ccaff43b1e/salad/spaghetti/utils/__init__.py
--------------------------------------------------------------------------------
/salad/spaghetti/utils/myparse.py:
--------------------------------------------------------------------------------
1 | from ..custom_types import *
2 | import sys
3 |
4 |
5 | # minimal argparse
6 | def parse(parse_dict: Dict[str, Dict[str, Any]]):
7 | args = sys.argv[1:]
8 | args_dict = {args[i]: args[i + 1] for i in range(0, len(args), 2)}
9 | out_dict = {}
10 | for item in parse_dict:
11 | hint = parse_dict[item]['type']
12 | if item in args_dict:
13 | out_dict[item[2:]] = hint(args_dict[item])
14 | else:
15 | out_dict[item[2:]] = parse_dict[item]['default']
16 | if 'options' in parse_dict[item] and out_dict[item[2:]] not in parse_dict[item]['options']:
17 | raise ValueError(f'Expected {item} to be in {str(parse_dict[item]["options"])}')
18 | return out_dict
19 |
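For illustration, a sketch of the expected `parse_dict` layout (the flag names here are examples, not a fixed schema): keys carry the leading `--`, which `parse` strips when building the output dict.

```python
import sys
from salad.spaghetti.utils.myparse import parse

sys.argv = ["demo.py", "--model_name", "chairs_large"]  # simulate a command line
opts = parse({
    "--model_name": {"type": str, "default": "airplanes"},
    "--num_samples": {"type": int, "default": 1},
})
print(opts)  # {'model_name': 'chairs_large', 'num_samples': 1}
```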
--------------------------------------------------------------------------------
/salad/spaghetti/utils/rotation_utils.py:
--------------------------------------------------------------------------------
1 | from .. import constants
2 | import functools
3 | from scipy.spatial.transform.rotation import Rotation
4 | from ..custom_types import *
5 |
6 |
7 | def quat_to_rot(q):
8 | shape = q.shape
9 | q = q.view(-1, 4)
10 | q_sq = 2 * q[:, :, None] * q[:, None, :]
11 | m00 = 1 - q_sq[:, 1, 1] - q_sq[:, 2, 2]
12 | m01 = q_sq[:, 0, 1] - q_sq[:, 2, 3]
13 | m02 = q_sq[:, 0, 2] + q_sq[:, 1, 3]
14 |
15 | m10 = q_sq[:, 0, 1] + q_sq[:, 2, 3]
16 | m11 = 1 - q_sq[:, 0, 0] - q_sq[:, 2, 2]
17 | m12 = q_sq[:, 1, 2] - q_sq[:, 0, 3]
18 |
19 | m20 = q_sq[:, 0, 2] - q_sq[:, 1, 3]
20 | m21 = q_sq[:, 1, 2] + q_sq[:, 0, 3]
21 | m22 = 1 - q_sq[:, 0, 0] - q_sq[:, 1, 1]
22 | r = torch.stack((m00, m01, m02, m10, m11, m12, m20, m21, m22), dim=1)
23 | r = r.view(*shape[:-1], 3, 3)
24 | return r
25 |
26 |
27 | def rot_to_quat(r):
28 | shape = r.shape
29 | r = r.view(-1, 3, 3)
30 | qw = .5 * (1 + r[:, 0, 0] + r[:, 1, 1] + r[:, 2, 2]).sqrt()
31 | qx = (r[:, 2, 1] - r[:, 1, 2]) / (4 * qw)
32 | qy = (r[:, 0, 2] - r[:, 2, 0]) / (4 * qw)
33 | qz = (r[:, 1, 0] - r[:, 0, 1]) / (4 * qw)
34 | q = torch.stack((qx, qy, qz, qw), -1)
35 | q = q.view(*shape[:-2], 4)
36 | return q
37 |
38 |
39 | @functools.lru_cache(10)
40 | def get_rotation_matrix(theta: float, axis: int, degree: bool = False) -> ARRAY:
41 | if degree:
42 | theta = theta * np.pi / 180
43 | rotate_mat = np.eye(3)
44 | rotate_mat[axis, axis] = 1
45 | cos_theta, sin_theta = np.cos(theta), np.sin(theta)
46 | rotate_mat[(axis + 1) % 3, (axis + 1) % 3] = cos_theta
47 | rotate_mat[(axis + 2) % 3, (axis + 2) % 3] = cos_theta
48 | rotate_mat[(axis + 1) % 3, (axis + 2) % 3] = sin_theta
49 | rotate_mat[(axis + 2) % 3, (axis + 1) % 3] = -sin_theta
50 | return rotate_mat
51 |
52 |
53 | def get_random_rotation(batch_size: int) -> T:
54 | r = Rotation.random(batch_size).as_matrix().astype(np.float32)
55 |     # scipy's Rotation.random draws uniformly over SO(3)
56 | return torch.from_numpy(r)
57 |
58 |
59 | def rand_bounded_rotation_matrix(cache_size: int, theta_range: float = .1):
60 |
61 | def create_cache():
62 | # from http://www.realtimerendering.com/resources/GraphicsGems/gemsiii/rand_rotation.c
63 | with torch.no_grad():
64 | theta, phi, z = torch.rand(cache_size, 3).split((1, 1, 1), dim=1)
65 | theta = (2 * theta - 1) * theta_range + 1
66 | theta = np.pi * theta # Rotation about the pole (Z).
67 | phi = phi * 2 * np.pi # For direction of pole deflection.
68 | z = 2 * z * theta_range # For magnitude of pole deflection.
69 | r = z.sqrt()
70 | v = torch.cat((torch.sin(phi) * r, torch.cos(phi) * r, torch.sqrt(2.0 - z)), dim=1)
71 | st = torch.sin(theta).squeeze(1)
72 | ct = torch.cos(theta).squeeze(1)
73 | rot_ = torch.zeros(cache_size, 3, 3)
74 | rot_[:, 0, 0] = ct
75 | rot_[:, 1, 1] = ct
76 | rot_[:, 0, 1] = st
77 | rot_[:, 1, 0] = -st
78 | rot_[:, 2, 2] = 1
79 | rot = (torch.einsum('ba,bd->bad', v, v) - torch.eye(3)[None, :, :]).bmm(rot_)
80 | det = rot.det()
81 | assert (det.gt(0.99) * det.lt(1.0001)).all().item()
82 | return rot
83 |
84 | def get_batch_rot(batch_size):
85 | nonlocal cache
86 | select = torch.randint(cache_size, size=(batch_size,))
87 | return cache[select]
88 |
89 | cache = create_cache()
90 |
91 | return get_batch_rot
92 |
93 |
94 | def transform_rotation(points: T, one_axis=False, max_angle=-1):
95 |     r = get_random_rotation(1)[0]  # NOTE: one_axis / max_angle are not handled here
96 | transformed = torch.einsum('nd,rd->nr', points, r)
97 | return transformed
98 |
99 |
100 | def tb_to_rot(abc: T) -> T:
101 | c, s = torch.cos(abc), torch.sin(abc)
102 | aa = c[:, 0] * c[:, 1]
103 | ab = c[:, 0] * s[:, 1] * s[:, 2] - c[:, 2] * s[:, 0]
104 | ac = s[:, 0] * s[:, 2] + c[:, 0] * c[:, 2] * s[:, 1]
105 |
106 | ba = c[:, 1] * s[:, 0]
107 | bb = c[:, 0] * c[:, 2] + s.prod(-1)
108 | bc = c[:, 2] * s[:, 0] * s[:, 1] - c[:, 0] * s[:, 2]
109 |
110 | ca = -s[:, 1]
111 | cb = c[:, 1] * s[:, 2]
112 | cc = c[:, 1] * c[:, 2]
113 | return torch.stack((aa, ab, ac, ba, bb, bc, ca, cb, cc), 1).view(-1, 3, 3)
114 |
115 |
116 | def rot_to_tb(rot: T) -> T:
117 | sy = torch.sqrt(rot[:, 0, 0] * rot[:, 0, 0] + rot[:, 1, 0] * rot[:, 1, 0])
118 | out = torch.zeros(rot.shape[0], 3, device = rot.device)
119 | mask = sy.gt(1e-6)
120 | z = torch.atan2(rot[mask, 2, 1], rot[mask, 2, 2])
121 | y = torch.atan2(-rot[mask, 2, 0], sy[mask])
122 | x = torch.atan2(rot[mask, 1, 0], rot[mask, 0, 0])
123 | out[mask] = torch.stack((x, y, z), dim=1)
124 | if not mask.all():
125 | mask = ~mask
126 | z = torch.atan2(-rot[mask, 1, 2], rot[mask, 1, 1])
127 | y = torch.atan2(-rot[mask, 2, 0], sy[mask])
128 |         x = torch.zeros_like(z)
129 | out[mask] = torch.stack((x, y, z), dim=1)
130 | return out
131 |
132 |
133 | def apply_gmm_affine(gmms: TS, affine: T):
134 | mu, p, phi, eigen = gmms
135 | if affine.dim() == 2:
136 | affine = affine.unsqueeze(0).expand(mu.shape[0], *affine.shape)
137 | mu_r = torch.einsum('bad, bpnd->bpna', affine, mu)
138 | p_r = torch.einsum('bad, bpncd->bpnca', affine, p)
139 | return mu_r, p_r, phi, eigen
140 |
141 |
142 | def get_reflection(reflect_axes: Tuple[bool, ...]) -> T:
143 | reflect = torch.eye(constants.DIM)
144 | for i in range(constants.DIM):
145 | if reflect_axes[i]:
146 | reflect[i, i] = -1
147 | return reflect
148 |
149 |
150 | def get_tait_bryan_from_p(p: T) -> T:
151 | # p = p.squeeze(1)
152 | shape = p.shape
153 | rot = p.reshape(-1, 3, 3).permute(0, 2, 1)
154 | angles = rot_to_tb(rot)
155 | angles = angles / np.pi
156 | angles[:, 1] = angles[:, 1] * 2
157 | angles = angles.view(*shape[:2], 3)
158 | return angles
159 |
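A quick round-trip check of `rot_to_quat` / `quat_to_rot` (note that `rot_to_quat` becomes numerically unstable for rotations with angle near pi, so the tolerance is kept loose):

```python
import torch
from salad.spaghetti.utils.rotation_utils import get_random_rotation, quat_to_rot, rot_to_quat

r = get_random_rotation(8)   # [8, 3, 3] uniform random rotation matrices
q = rot_to_quat(r)           # [8, 4] quaternions ordered (x, y, z, w)
r2 = quat_to_rot(q)
assert torch.allclose(r, r2, atol=1e-4)
```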
--------------------------------------------------------------------------------
/salad/utils/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/KAIST-Visual-AI-Group/SALAD/db66b1b2ce57d8a4beb5bd610bc006ccaff43b1e/salad/utils/__init__.py
--------------------------------------------------------------------------------
/salad/utils/logutil.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/KAIST-Visual-AI-Group/SALAD/db66b1b2ce57d8a4beb5bd610bc006ccaff43b1e/salad/utils/logutil.py
--------------------------------------------------------------------------------
/salad/utils/nputil.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 |
4 | def np2th(ndarray):
5 | if isinstance(ndarray, torch.Tensor):
6 | return ndarray.detach().cpu()
7 | elif isinstance(ndarray, np.ndarray):
8 | return torch.tensor(ndarray).float()
9 | else:
10 | raise ValueError("Input should be either torch.Tensor or np.ndarray")
11 |
--------------------------------------------------------------------------------
/salad/utils/paths.py:
--------------------------------------------------------------------------------
1 | from pathlib import Path
2 | import os
3 |
4 | """ project top dir: salad """
5 | PROJECT_DIR = Path(os.path.realpath(__file__)).parents[2]
6 | SALAD_DIR = PROJECT_DIR / "salad"
7 | SPAGHETTI_DIR = SALAD_DIR / "spaghetti"
8 | DATA_DIR = PROJECT_DIR / "data"
9 |
10 | if __name__ == "__main__":
11 | print(PROJECT_DIR)
12 |
--------------------------------------------------------------------------------
/salad/utils/shapeglot_util.py:
--------------------------------------------------------------------------------
1 | # References: https://github.com/optas/shapeglot
2 | # https://github.com/63days/PartGlot.
3 |
4 | from typing import Dict
5 |
6 | import numpy as np
7 | from pandas import DataFrame
8 | from six.moves import cPickle
9 |
10 |
11 | def unpickle_data(file_name, python2_to_3=False):
12 |     """Restore data previously saved with pickle_data().
13 |     :param file_name: file holding the pickled data.
14 |     :param python2_to_3: (boolean), if True, pickling happened under python2x, unpickling under python3x.
15 |     :return: a generator over the un-pickled items.
16 |     Note, about implementing the python2_to_3 see
17 |         https://stackoverflow.com/questions/28218466/unpickling-a-python-2-object-with-python-3
18 |     """
19 |
20 |     in_file = open(file_name, "rb")
21 |     if python2_to_3:
22 |         size = cPickle.load(in_file, encoding="latin1")
23 |     else:
24 |         size = cPickle.load(in_file)
25 |
26 |     for _ in range(size):
27 |         if python2_to_3:
28 |             yield cPickle.load(in_file, encoding="latin1")
29 |         else:
30 |             yield cPickle.load(in_file)
31 |     in_file.close()
32 |
33 |
34 | def get_mask_of_game_data(
35 |     game_data: DataFrame,
36 |     word2int: Dict,
37 |     only_correct: bool,
38 |     only_easy_context: bool,
39 |     max_seq_len: int,
40 |     only_one_part_name: bool,
41 | ):
42 |     """
43 |     only_correct (if True): mask will be 1 in a location iff the human listener predicted correctly.
44 |     only_easy_context (if True): uses only easy-context examples (more dissimilar triplet chairs).
45 |     max_seq_len: drops examples with len(utterance) > max_seq_len.
46 |     only_one_part_name (if True): uses only utterances describing exactly one part in the given set.
47 |     """
48 |     mask = np.array(game_data.correct)
49 |     if not only_correct:
50 |         mask = np.ones_like(mask, dtype=bool)
51 |
52 |     if only_easy_context:
53 |         context_mask = np.array(game_data.context_condition == "easy", dtype=bool)
54 |         mask = np.logical_and(mask, context_mask)
55 |
56 |     short_mask = np.array(
57 |         game_data.text.apply(lambda x: len(x)) <= max_seq_len, dtype=bool
58 |     )
59 |     mask = np.logical_and(mask, short_mask)
60 |
61 |     part_indicator, part_mask = get_part_indicator(game_data.text, word2int)
62 |     if only_one_part_name:
63 |         mask = np.logical_and(mask, part_mask)
64 |
65 |     return mask, part_indicator
66 |
--------------------------------------------------------------------------------
/salad/utils/sysutil.py:
--------------------------------------------------------------------------------
1 | import gc
2 |
3 | import rich
4 | from rich.tree import Tree
5 | from rich.syntax import Syntax
6 | import torch
7 | from omegaconf import DictConfig, OmegaConf
8 |
9 |
10 | def print_config(
11 | config,
12 | fields=(
13 | "trainer",
14 | "model",
15 | "callbacks",
16 | "logger",
17 | "seed",
18 | "name",
19 | ),
20 | resolve: bool = True,
21 | save_config: bool = False,
22 | ) -> None:
23 | """Prints content of DictConfig using Rich library and its tree structure.
24 | Args:
25 | config (DictConfig): Configuration composed by Hydra.
26 | fields (Sequence[str], optional): Determines which main fields from config will
27 | be printed and in what order.
28 | resolve (bool, optional): Whether to resolve reference fields of DictConfig.
29 | """
30 |
31 | style = "dim"
32 | tree = Tree("CONFIG", style=style, guide_style=style)
33 |
34 | for field in fields:
35 | branch = tree.add(field, style=style, guide_style=style)
36 |
37 | config_section = config.get(field)
38 | branch_content = str(config_section)
39 | if isinstance(config_section, DictConfig):
40 | branch_content = OmegaConf.to_yaml(config_section, resolve=resolve)
41 |
42 | branch.add(Syntax(branch_content, "yaml"))
43 |
44 | rich.print(tree)
45 |
46 | if save_config:
47 | with open("config_tree.log", "w") as fp:
48 | rich.print(tree, file=fp)
49 |
50 |
51 | def clean_gpu():
52 | gc.collect()
53 | torch.cuda.empty_cache()
54 |
--------------------------------------------------------------------------------
/salad/utils/thutil.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 |
4 | def th2np(tensor):
5 | if isinstance(tensor, np.ndarray):
6 | return tensor
7 | if isinstance(tensor, torch.Tensor):
8 | return tensor.detach().cpu().numpy()
9 |     raise ValueError("Input should be either torch.Tensor or np.ndarray")
10 |
--------------------------------------------------------------------------------
/salad/utils/train_util.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 |
4 | class PolyDecayScheduler(torch.optim.lr_scheduler.LambdaLR):
5 | def __init__(self, optimizer, init_lr, power=0.99, lr_end=1e-7, last_epoch=-1):
6 | def lr_lambda(step):
7 | lr = max(power**step, lr_end / init_lr)
8 | return lr
9 |
10 | super().__init__(optimizer, lr_lambda, last_epoch=last_epoch)
11 |
12 |
13 | def get_dropout_mask(shape, dropout: float, device):
14 | if dropout == 1:
15 | return torch.zeros(shape, device=device, dtype=torch.bool)
16 | elif dropout == 0:
17 |         return torch.ones(shape, device=device, dtype=torch.bool)
18 | else:
19 | return torch.zeros(shape, device=device).float().uniform_(0, 1) > dropout
20 |
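A minimal sketch of the scheduler in isolation, mirroring how `BaseModel.configure_optimizers` wires it up (learning rate illustrative): the LambdaLR multiplier is `power**step`, floored at `lr_end / init_lr`.

```python
import torch
from salad.utils.train_util import PolyDecayScheduler

params = [torch.nn.Parameter(torch.zeros(1))]
optimizer = torch.optim.Adam(params, lr=1e-4)
scheduler = PolyDecayScheduler(optimizer, init_lr=1e-4, power=0.999)

for _ in range(3):
    optimizer.step()
    scheduler.step()
print(optimizer.param_groups[0]["lr"])  # 1e-4 * 0.999**3
```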
--------------------------------------------------------------------------------
/salad/utils/visutil.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import trimesh
3 | import matplotlib.pyplot as plt
4 | from salad.utils import nputil, thutil, fresnelvis
5 | from PIL import Image
6 |
7 |
8 | def render_pointcloud(
9 | pointcloud,
10 | camPos=np.array([-2, 2, -2]),
11 | camLookat=np.array([0.0, 0.0, 0.0]),
12 | camUp=np.array([0, 1, 0]),
13 | camHeight=2,
14 | resolution=(512, 512),
15 | samples=16,
16 | cloudR=0.006,
17 | ):
18 | pointcloud = thutil.th2np(pointcloud)
19 | img = fresnelvis.renderMeshCloud(
20 | cloud=pointcloud,
21 | camPos=camPos,
22 | camLookat=camLookat,
23 | camUp=camUp,
24 | camHeight=camHeight,
25 | resolution=resolution,
26 | samples=samples,
27 | cloudR=cloudR,
28 | )
29 | return Image.fromarray(img)
30 |
31 |
32 | def render_mesh(
33 | vert,
34 | face,
35 | camPos=np.array([-2, 2, -2]),
36 | camLookat=np.array([0, 0, 0.0]),
37 | camUp=np.array([0, 1, 0]),
38 | camHeight=2,
39 | resolution=(512, 512),
40 | samples=16,
41 | ):
42 | vert, face = list(map(lambda x: thutil.th2np(x), [vert, face]))
43 | mesh = {"vert": vert, "face": face}
44 | img = fresnelvis.renderMeshCloud(
45 | mesh=mesh,
46 | camPos=camPos,
47 | camLookat=camLookat,
48 | camUp=camUp,
49 | camHeight=camHeight,
50 | resolution=resolution,
51 | samples=samples,
52 | )
53 | return Image.fromarray(img)
54 |
55 |
56 | def render_gaussians(
57 | gaussians,
58 | is_bspnet=False,
59 | multiplier=1.0,
60 | gaussians_colors=None,
61 | attn_map=None,
62 | camPos=np.array([-2, 2, -2]),
63 | camLookat=np.array([0.0, 0, 0]),
64 | camUp=np.array([0, 1, 0]),
65 | camHeight=2,
66 | resolution=(512, 512),
67 | samples=16,
68 | ):
69 | gaussians = thutil.th2np(gaussians)
70 | N = gaussians.shape[0]
71 | cmap = plt.get_cmap("jet")
72 |
73 | if attn_map is not None:
74 | assert N == attn_map.shape[0]
75 | vmin, vmax = attn_map.min(), attn_map.max()
76 | if vmin == vmax:
77 | normalized_attn_map = np.zeros_like(attn_map)
78 | else:
79 | normalized_attn_map = (attn_map - vmin) / (vmax - vmin)
80 |
81 | cmap = plt.get_cmap("viridis")
82 |
83 | lights = "rembrandt"
84 | camera_kwargs = dict(
85 | camPos=camPos,
86 | camLookat=camLookat,
87 | camUp=camUp,
88 | camHeight=camHeight,
89 | resolution=resolution,
90 | samples=samples,
91 | )
92 | renderer = fresnelvis.FresnelRenderer(lights=lights, camera_kwargs=camera_kwargs)
93 | for i, g in enumerate(gaussians):
94 | if is_bspnet:
95 | mu, eival, eivec = g[:3], g[3:6], g[6:15]
96 | else:
97 | mu, eivec, eival = g[:3], g[3:12], g[13:]
98 | R = eivec.reshape(3, 3).T
99 | scale = multiplier * np.sqrt(eival)
100 | scale_transform = np.diag((*scale, 1))
101 | rigid_transform = np.hstack((R, mu.reshape(3, 1)))
102 | rigid_transform = np.vstack((rigid_transform, [0, 0, 0, 1]))
103 | sphere = trimesh.creation.icosphere()
104 | sphere.apply_transform(scale_transform)
105 | sphere.apply_transform(rigid_transform)
106 | if attn_map is None and gaussians_colors is None:
107 | color = np.array(cmap(i / N)[:3])
108 | elif attn_map is not None:
109 | color = np.array(cmap(normalized_attn_map[i])[:3])
110 | else:
111 | color = gaussians_colors[i]
112 |
113 | renderer.add_mesh(
114 | sphere.vertices, sphere.faces, color=color, outline_width=None
115 | )
116 | image = renderer.render()
117 | return Image.fromarray(image)
118 |
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | from setuptools import setup
2 |
3 | setup(
4 | name="salad",
5 |     version="0.1",
6 | description="SALAD: Part-Level Latent Diffusion for 3D Shape Generation and Manipulation",
7 | url="https://github.com/63days/SALAD",
8 | packages=["salad"],
9 | install_requires=[
10 | # "pandas",
11 | # "torch",
12 | # "h5py",
13 | # "Pillow",
14 | # "numpy",
15 | # "matplotlib",
16 | # "six",
17 | # "nltk",
18 | # "pytorch-lightning==1.5.5",
19 | # "hydra-core==1.1.0",
20 | # "omegaconf==2.1.1"
21 | ],
22 | zip_safe=False,
23 | )
24 |
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
1 | import hydra
2 | import pytorch_lightning as pl
3 | from salad.utils import sysutil
4 | @hydra.main(config_path="configs", config_name="train.yaml")
5 | def main(config):
6 | pl.seed_everything(63)
7 | sysutil.print_config(config, ("callbacks", "logger", "paths", "debug", "data", "trainer", "model"))
8 | model = hydra.utils.instantiate(config.model, _recursive_=True)
9 |
10 | callbacks = []
11 | if config.get("callbacks"):
12 | for cb_name, cb_conf in config.callbacks.items():
13 | if config.get("debug") and cb_name == "model_checkpoint":
14 | continue
15 | callbacks.append(hydra.utils.instantiate(cb_conf))
16 |
17 | logger = []
18 | if config.get("logger"):
19 | for lg_name, lg_conf in config.logger.items():
20 | if config.get("debug") and lg_name == "wandb":
21 | continue
22 | logger.append(hydra.utils.instantiate(lg_conf))
23 | trainer = hydra.utils.instantiate(
24 | config.trainer,
25 | callbacks=callbacks,
26 | logger=logger if len(logger) != 0 else False,
27 | _convert_="partial",
28 | log_every_n_steps=200,
29 | resume_from_checkpoint=config.resume_from_checkpoint
30 | )
31 |
32 | trainer.fit(model)
33 |
34 | if __name__ == "__main__":
35 | main()
36 |
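Since the entry point composes its configuration from `configs/train.yaml`, training runs are launched through Hydra; a hypothetical invocation (the override names must match the YAML group files shipped under `configs/`):

```
python train.py model=phase1 trainer.max_epochs=100
```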
--------------------------------------------------------------------------------