├── .gitignore ├── LICENSE ├── README.md ├── configs ├── binary_model.yaml ├── data │ ├── circle_dirs.yaml │ ├── dinner_dirs.yaml │ ├── line_dirs.yaml │ ├── test_tower_dirs.yaml │ └── tower_dirs.yaml ├── object_selection_network.yaml ├── structformer.yaml ├── structformer_no_encoder.yaml └── structformer_no_structure.yaml ├── doc └── rearrange_mugs.gif ├── pyproject.toml ├── requirements.txt ├── scripts └── run_full_pipeline.py ├── setup.cfg ├── setup.py └── src └── structformer ├── __init__.py ├── data ├── __init__.py ├── binary_dataset.py ├── object_set_refer_dataset.py ├── sequence_dataset.py └── tokenizer.py ├── evaluation ├── __init__.py ├── inference.py ├── test_binary_model.py ├── test_object_selection_network.py ├── test_structformer.py ├── test_structformer_no_encoder.py └── test_structformer_no_structure.py ├── models ├── __init__.py ├── object_selection_network.py ├── point_transformer.py └── pose_generation_network.py ├── training ├── __init__.py ├── train_binary_model.py ├── train_object_selection_network.py ├── train_structformer.py ├── train_structformer_no_encoder.py └── train_structformer_no_structure.py └── utils ├── __init__.py ├── brain2 ├── __init__.py ├── camera.py ├── image.py └── pose.py ├── pointnet.py ├── rearrangement.py ├── rotation_continuity.py └── transformations.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .env 106 | .venv 107 | env/ 108 | venv/ 109 | ENV/ 110 | env.bak/ 111 | venv.bak/ 112 | 113 | # Vim temporary files 114 | *.swp 115 | *.swo 116 | 117 | # Spyder project settings 118 | .spyderproject 119 | .spyproject 120 | 121 | # Rope project settings 122 | .ropeproject 123 | 124 | # mkdocs documentation 125 | /site 126 | 127 | # mypy 128 | .mypy_cache/ 129 | .dmypy.json 130 | dmypy.json 131 | 132 | # Pyre type checker 133 | .pyre/ 134 | 135 | # Custom 136 | /experiments 137 | /models 138 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | NVIDIA Source Code License for StructFormer 2 | 1. Definitions 3 | “Licensor” means any person or entity that distributes its Work. 4 | “Software” means the original work of authorship made available under this License. 5 | “Work” means the Software and any additions to or derivative works of the Software that are made available under this License. 6 | The terms “reproduce,” “reproduction,” “derivative works,” and “distribution” have the meaning as provided under U.S. copyright law; provided, however, that for the purposes of this License, derivative works shall not include works that remain separable from, or merely link (or bind by name) to the interfaces of, the Work. 7 | Works, including the Software, are “made available” under this License by including in or with the Work either (a) a copyright notice referencing the applicability of this License to the Work, or (b) a copy of this License. 8 | 2. License Grant 9 | 2.1 Copyright Grant. Subject to the terms and conditions of this License, each Licensor grants to you a perpetual, worldwide, non-exclusive, royalty-free, copyright license to reproduce, prepare derivative works of, publicly display, publicly perform, sublicense and distribute its Work and any resulting derivative works in any form. 10 | 3. Limitations 11 | 3.1 Redistribution. You may reproduce or distribute the Work only if (a) you do so under this License, (b) you include a complete copy of this License with your distribution, and (c) you retain without modification any copyright, patent, trademark, or attribution notices that are present in the Work. 12 | 3.2 Derivative Works. You may specify that additional or different terms apply to the use, reproduction, and distribution of your derivative works of the Work (“Your Terms”) only if (a) Your Terms provide that the use limitation in Section 3.3 applies to your derivative works, and (b) you identify the specific derivative works that are subject to Your Terms. Notwithstanding Your Terms, this License (including the redistribution requirements in Section 3.1) will continue to apply to the Work itself. 13 | 3.3 Use Limitation. The Work and any derivative works thereof only may be used or intended for use non-commercially. Notwithstanding the foregoing, NVIDIA and its affiliates may use the Work and any derivative works commercially. As used herein, “non-commercially” means for research or evaluation purposes only. 14 | 3.4 Patent Claims. 
If you bring or threaten to bring a patent claim against any Licensor (including any claim, cross-claim or counterclaim in a lawsuit) to enforce any patents that you allege are infringed by any Work, then your rights under this License from such Licensor (including the grant in Section 2.1) will terminate immediately. 15 | 3.5 Trademarks. This License does not grant any rights to use any Licensor’s or its affiliates’ names, logos, or trademarks, except as necessary to reproduce the notices described in this License. 16 | 3.6 Termination. If you violate any term of this License, then your rights under this License (including the grant in Section 2.1) will terminate immediately. 17 | 4. Disclaimer of Warranty. 18 | THE WORK IS PROVIDED “AS IS” WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WARRANTIES OR CONDITIONS OF 19 | MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE, TITLE OR NON-INFRINGEMENT. YOU BEAR THE RISK OF UNDERTAKING ANY ACTIVITIES UNDER THIS LICENSE. 20 | 5. Limitation of Liability. 21 | EXCEPT AS PROHIBITED BY APPLICABLE LAW, IN NO EVENT AND UNDER NO LEGAL THEORY, WHETHER IN TORT (INCLUDING NEGLIGENCE), CONTRACT, OR OTHERWISE SHALL ANY LICENSOR BE LIABLE TO YOU FOR DAMAGES, INCLUDING ANY DIRECT, INDIRECT, SPECIAL, INCIDENTAL, OR CONSEQUENTIAL DAMAGES ARISING OUT OF OR RELATED TO THIS LICENSE, THE USE OR INABILITY TO USE THE WORK (INCLUDING BUT NOT LIMITED TO LOSS OF GOODWILL, BUSINESS INTERRUPTION, LOST PROFITS OR DATA, COMPUTER FAILURE OR MALFUNCTION, OR ANY OTHER COMMERCIAL DAMAGES OR LOSSES), EVEN IF THE LICENSOR HAS BEEN ADVISED OF THE POSSIBILITY OF SUCH DAMAGES. -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # StructFormer 2 | 3 | PyTorch implementation for the ICRA 2022 paper _StructFormer: Learning Spatial Structure for Language-Guided Semantic Rearrangement of Novel Objects_. [[PDF]](https://arxiv.org/abs/2110.10189) [[Video]](https://youtu.be/6NPdpAtMawM) [[Website]](https://sites.google.com/view/structformer) 4 | 5 | StructFormer rearranges unknown objects into semantically meaningful spatial structures based on high-level language instructions and partial-view 6 | point cloud observations of the scene. The model uses multi-modal transformers to predict both which objects to manipulate and where to place them. 7 | 8 |

9 | <img src="doc/rearrange_mugs.gif" alt="drawing"/> 10 |

11 | 12 | ## License 13 | The source code is released under the [NVIDIA Source Code License](LICENSE). The dataset is released under [CC BY-NC 4.0](https://creativecommons.org/licenses/by-nc/4.0/). 14 | 15 | 16 | ## Installation 17 | 18 | ``` 19 | pip install -r requirements.txt 20 | pip install -e . 21 | ``` 22 | 23 | ### Notes on Dependencies 24 | - `h5py==2.10`: this specific version is needed. 25 | - `omegaconf==2.1`: this newer version is required because some functions used in this repo are not available in earlier releases. 26 | 27 | ### Environments 28 | The code has been tested on Ubuntu 18.04 with NVIDIA driver 460.91, CUDA 11.0, Python 3.6, and PyTorch 1.7. 29 | 30 | ## Organization 31 | Source code in the StructFormer package is mainly organized as follows: 32 | - data loaders `data` 33 | - models `models` 34 | - training scripts `training` 35 | - inference scripts `evaluation` 36 | 37 | Parameters for data loaders and models are defined in `OmegaConf` yaml files stored in `configs`. 38 | 39 | Trained models are stored in `/experiments`. 40 | 41 | ## Quick Start with Pretrained Models 42 | - Set the package root dir: `export STRUCTFORMER=/path/to/StructFormer` 43 | - Download pretrained models from [this link](https://drive.google.com/file/d/1EsptihJv_lPND902P6CYbe00QW-y-rA4/view?usp=sharing) and unzip to the `$STRUCTFORMER/models` folder 44 | - Download the test split of the dataset from [this link](https://drive.google.com/file/d/1e76qJbBJ2bKYq0JzDSRWZjswySX1ftq_/view?usp=sharing) and unzip to the `$STRUCTFORMER/data_new_objects_test_split` folder 45 | 46 | ### Run StructFormer 47 | ```bash 48 | cd $STRUCTFORMER/scripts/ 49 | python run_full_pipeline.py \ 50 | --dataset_base_dir $STRUCTFORMER/data_new_objects_test_split \ 51 | --object_selection_model_dir $STRUCTFORMER/models/object_selection_network/best_model \ 52 | --pose_generation_model_dir $STRUCTFORMER/models/structformer_circle/best_model \ 53 | --dirs_config $STRUCTFORMER/configs/data/circle_dirs.yaml 54 | ``` 55 | 56 | ### Evaluate Pose Generation Networks 57 | 58 | Where `{model_name}` is one of `structformer_no_encoder`, `structformer_no_structure`, `object_selection_network`, `structformer`, and `{structure}` is one of `circle`, `line`, `tower`, or `dinner`: 59 | 60 | ```bash 61 | cd $STRUCTFORMER/src/structformer/evaluation/ 62 | python test_{model_name}.py \ 63 | --dataset_base_dir $STRUCTFORMER/data_new_objects_test_split \ 64 | --model_dir $STRUCTFORMER/models/{model_name}_{structure}/best_model \ 65 | --dirs_config $STRUCTFORMER/configs/data/{structure}_dirs.yaml 66 | ``` 67 | 68 | ### Evaluate Object Selection Network 69 | 70 | Where `{structure}` is as above: 71 | 72 | ```bash 73 | cd $STRUCTFORMER/src/structformer/evaluation/ 74 | python test_object_selection_network.py \ 75 | --dataset_base_dir $STRUCTFORMER/data_new_objects_test_split \ 76 | --model_dir $STRUCTFORMER/models/object_selection_network/best_model \ 77 | --dirs_config $STRUCTFORMER/configs/data/{structure}_dirs.yaml 78 | ``` 79 | 80 | ## Training 81 | 82 | - Download the vocabulary list `type_vocabs_coarse.json` from [this link](https://drive.google.com/file/d/1topawwqMSvwE8Ac-8OiwMApEqwYeR5rc/view?usp=sharing) and unzip to the `$STRUCTFORMER/data_new_objects` folder. 83 | - Download all data for [circle](https://drive.google.com/file/d/1PTGFcAWBrQmlglygNiJz6p7s0rqe2rtP/view?usp=sharing) and unzip to the `$STRUCTFORMER/data_new_objects` folder.
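
All of the training and evaluation commands in this README pass a `--dirs_config` file from `configs/data/`. As a minimal sketch of how such a file is consumed (mirroring the argument handling in `scripts/run_full_pipeline.py`; the config and dataset paths below are placeholders), the configs can be loaded and resolved with `OmegaConf`:

```python
import os
import time

from omegaconf import OmegaConf

# The dirs configs interpolate ${env:DATETIME} for the experiment directory,
# so the scripts set this environment variable before resolving the config.
os.environ["DATETIME"] = time.strftime("%Y%m%d-%H%M%S")

# Load a dirs config and fill in dataset_base_dir, which is left as ??? in the yaml.
dirs_cfg = OmegaConf.load("configs/data/circle_dirs.yaml")  # placeholder path
dirs_cfg.dataset_base_dir = "/path/to/data_new_objects"     # placeholder path
OmegaConf.resolve(dirs_cfg)

print(dirs_cfg.dataset.dirs)       # e.g. [".../examples_circle_new_objects/result"]
print(dirs_cfg.dataset.vocab_dir)  # e.g. ".../type_vocabs_coarse.json"
```
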
84 | 85 | ### Pose Generation Networks 86 | 87 | Where `{model_name}` is one of `structformer_no_encoder`, `structformer_no_structure`, `object_selection_network`, `structformer`, and `{structure}` is one of `circle`, `line`, `tower`, or `dinner`: 88 | 89 | ```bash 90 | cd $STRUCTFORMER/src/structformer/training/ 91 | python train_{model_name}.py \ 92 | --dataset_base_dir $STRUCTFORMER/data_new_objects \ 93 | --main_config $STRUCTFORMER/configs/{model_name}.yaml \ 94 | --dirs_config $STRUCTFORMER/configs/data/{structure}_dirs.yaml 95 | ``` 96 | 97 | ### Object Selection Network 98 | ```bash 99 | cd $STRUCTFORMER/src/structformer/training/ 100 | python train_object_selection_network.py \ 101 | --dataset_base_dir $STRUCTFORMER/data_new_objects \ 102 | --main_config $STRUCTFORMER/configs/object_selection_network.yaml \ 103 | --dirs_config $STRUCTFORMER/configs/data/circle_dirs.yaml 104 | ``` 105 | 106 | ## Citation 107 | If you find our work useful in your research, please cite: 108 | ``` 109 | @inproceedings{structformer2022, 110 | title = {StructFormer: Learning Spatial Structure for Language-Guided Semantic Rearrangement of Novel Objects}, 111 | author = {Liu, Weiyu and Paxton, Chris and Hermans, Tucker and Fox, Dieter}, 112 | year = {2022}, 113 | booktitle = {ICRA 2022} 114 | } 115 | ``` -------------------------------------------------------------------------------- /configs/binary_model.yaml: -------------------------------------------------------------------------------- 1 | random_seed: 1 2 | device: 0 3 | 4 | obj_xytheta_relative: False 5 | save_model: True 6 | save_best_model: True 7 | 8 | dataset: 9 | batch_size: 32 10 | max_num_shape_parameters: 5 11 | max_num_objects: 7 12 | max_num_other_objects: 5 13 | max_num_rearrange_features: 0 14 | max_num_anchor_features: 0 15 | num_pts: 1024 16 | num_workers: 4 17 | pin_memory: True 18 | 19 | model: 20 | name: binary_model 21 | num_attention_heads: 8 22 | encoder_hidden_dim: 512 23 | encoder_dropout: 0.0 24 | encoder_activation: relu 25 | encoder_num_layers: 8 26 | object_dropout: 0.1 27 | theta_loss_divide: 3 28 | ignore_rgb: True 29 | 30 | training: 31 | learning_rate: 0.0001 32 | max_epochs: 200 33 | l2: 0.0001 34 | lr_restart: 3000 35 | warmup: 10 -------------------------------------------------------------------------------- /configs/data/circle_dirs.yaml: -------------------------------------------------------------------------------- 1 | experiment_dir: ../../experiments/${env:DATETIME} 2 | dataset_base_dir: ??? 3 | dataset: 4 | dirs: 5 | - ${dataset_base_dir}/examples_circle_new_objects/result 6 | index_dirs: 7 | - index_34k 8 | vocab_dir: ${dataset_base_dir}/type_vocabs_coarse.json -------------------------------------------------------------------------------- /configs/data/dinner_dirs.yaml: -------------------------------------------------------------------------------- 1 | experiment_dir: ../../experiments/${env:DATETIME} 2 | dataset_base_dir: ??? 3 | dataset: 4 | dirs: 5 | - ${dataset_base_dir}/examples_dinner_new_objects/result 6 | index_dirs: 7 | - index_24k 8 | vocab_dir: ${dataset_base_dir}/type_vocabs_coarse.json -------------------------------------------------------------------------------- /configs/data/line_dirs.yaml: -------------------------------------------------------------------------------- 1 | experiment_dir: ../../experiments/${env:DATETIME} 2 | dataset_base_dir: ???
3 | dataset: 4 | dirs: 5 | - ${dataset_base_dir}/examples_line_new_objects/result 6 | index_dirs: 7 | - index_42k 8 | vocab_dir: ${dataset_base_dir}/type_vocabs_coarse.json -------------------------------------------------------------------------------- /configs/data/test_tower_dirs.yaml: -------------------------------------------------------------------------------- 1 | experiment_dir: ../../experiments/${env:DATETIME} 2 | dataset_base_dir: ??? 3 | dataset: 4 | dirs: 5 | - ${dataset_base_dir}/examples_tower_new_objects/result 6 | index_dirs: 7 | - index_13k 8 | vocab_dir: ${dataset_base_dir}/type_vocabs_coarse.json -------------------------------------------------------------------------------- /configs/data/tower_dirs.yaml: -------------------------------------------------------------------------------- 1 | experiment_dir: ../../experiments/${env:DATETIME} 2 | dataset_base_dir: ??? 3 | dataset: 4 | dirs: 5 | - ${dataset_base_dir}/examples_tower_new_objects/result 6 | index_dirs: 7 | - index_25k 8 | vocab_dir: ${dataset_base_dir}/type_vocabs_coarse.json -------------------------------------------------------------------------------- /configs/object_selection_network.yaml: -------------------------------------------------------------------------------- 1 | random_seed: 1 2 | device: 0 3 | 4 | save_model: True 5 | save_best_model: True 6 | 7 | dataset: 8 | batch_size: 64 9 | max_num_all_objects: 11 10 | max_num_shape_parameters: 5 11 | max_num_rearrange_features: 1 12 | max_num_anchor_features: 3 13 | num_pts: 1024 14 | num_workers: 4 15 | pin_memory: True 16 | 17 | model: 18 | name: object_selection_network 19 | num_attention_heads: 8 20 | encoder_hidden_dim: 32 21 | encoder_dropout: 0.2 22 | encoder_activation: relu 23 | encoder_num_layers: 8 24 | use_focal_loss: True 25 | focal_loss_gamma: 2 26 | 27 | training: 28 | learning_rate: 0.0001 29 | max_epochs: 200 30 | l2: 0.0001 31 | lr_restart: 3000 32 | warmup: 10 -------------------------------------------------------------------------------- /configs/structformer.yaml: -------------------------------------------------------------------------------- 1 | random_seed: 1 2 | device: 0 3 | 4 | obj_xytheta_relative: False 5 | save_model: True 6 | save_best_model: True 7 | 8 | dataset: 9 | batch_size: 32 10 | max_num_shape_parameters: 5 11 | max_num_objects: 7 12 | max_num_other_objects: 5 13 | max_num_rearrange_features: 0 14 | max_num_anchor_features: 0 15 | num_pts: 1024 16 | num_workers: 4 17 | pin_memory: True 18 | use_structure_frame: True 19 | 20 | model: 21 | name: structformer 22 | num_attention_heads: 8 23 | encoder_hidden_dim: 512 24 | encoder_dropout: 0.0 25 | encoder_activation: relu 26 | encoder_num_layers: 8 27 | structure_dropout: 0.5 28 | object_dropout: 0.1 29 | theta_loss_divide: 3 30 | ignore_rgb: True 31 | 32 | training: 33 | learning_rate: 0.0001 34 | max_epochs: 200 35 | l2: 0.0001 36 | lr_restart: 3000 37 | warmup: 10 -------------------------------------------------------------------------------- /configs/structformer_no_encoder.yaml: -------------------------------------------------------------------------------- 1 | random_seed: 1 2 | device: 0 3 | 4 | obj_xytheta_relative: False 5 | save_model: True 6 | save_best_model: True 7 | 8 | dataset: 9 | batch_size: 32 10 | max_num_shape_parameters: 5 11 | max_num_objects: 7 12 | max_num_other_objects: 5 13 | max_num_rearrange_features: 0 14 | max_num_anchor_features: 0 15 | num_pts: 1024 16 | num_workers: 4 17 | pin_memory: True 18 | use_structure_frame: True 19 | 20 | 
model: 21 | name: structformer_no_encoder 22 | num_attention_heads: 8 23 | encoder_hidden_dim: 512 24 | encoder_dropout: 0.0 25 | encoder_activation: relu 26 | encoder_num_layers: 8 27 | structure_dropout: 0.5 28 | object_dropout: 0.1 29 | theta_loss_divide: 3 30 | ignore_rgb: True 31 | 32 | training: 33 | learning_rate: 0.0001 34 | max_epochs: 200 35 | l2: 0.0001 36 | lr_restart: 3000 37 | warmup: 10 -------------------------------------------------------------------------------- /configs/structformer_no_structure.yaml: -------------------------------------------------------------------------------- 1 | random_seed: 1 2 | device: 0 3 | 4 | obj_xytheta_relative: False 5 | save_model: True 6 | save_best_model: True 7 | 8 | dataset: 9 | batch_size: 32 10 | max_num_shape_parameters: 5 11 | max_num_objects: 7 12 | max_num_other_objects: 5 13 | max_num_rearrange_features: 0 14 | max_num_anchor_features: 0 15 | num_pts: 1024 16 | num_workers: 4 17 | pin_memory: True 18 | use_structure_frame: False 19 | 20 | model: 21 | name: structformer_no_structure 22 | num_attention_heads: 8 23 | encoder_hidden_dim: 512 24 | encoder_dropout: 0.0 25 | encoder_activation: relu 26 | encoder_num_layers: 8 27 | object_dropout: 0.1 28 | theta_loss_divide: 3 29 | ignore_rgb: True 30 | 31 | training: 32 | learning_rate: 0.0001 33 | max_epochs: 200 34 | l2: 0.0001 35 | lr_restart: 3000 36 | warmup: 10 -------------------------------------------------------------------------------- /doc/rearrange_mugs.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wliu88/StructFormer/a81d8c83313f1f24a75ca83aa9dd9f8cfcf27419/doc/rearrange_mugs.gif -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [build-system] 2 | requires = [ 3 | "setuptools>=42", 4 | "wheel" 5 | ] 6 | build-backend = "setuptools.build_meta" 7 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | warmup-scheduler==0.3 2 | h5py==2.10.0 3 | omegaconf==2.1.1 4 | tqdm==4.63.0 5 | opencv-python==4.5.5.62 6 | open3d==0.15.2 7 | trimesh==3.10.2 8 | torch # cuda 11 requires a different installation, see https://pytorch.org/get-started/locally/ -------------------------------------------------------------------------------- /scripts/run_full_pipeline.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import time 3 | import torch 4 | import numpy as np 5 | import os 6 | import argparse 7 | from omegaconf import OmegaConf 8 | from torch.utils.data import DataLoader 9 | 10 | from structformer.data.tokenizer import Tokenizer 11 | from structformer.evaluation.test_object_selection_network import ObjectSelectionInference 12 | from structformer.evaluation.test_structformer import PriorInference 13 | from structformer.utils.rearrangement import show_pcs_with_predictions, get_initial_scene_idxs, evaluate_target_object_predictions, save_img, show_pcs_with_labels, test_new_vis, show_pcs 14 | from structformer.evaluation.inference import PointCloudRearrangement 15 | 16 | 17 | def run_demo(object_selection_model_dir, pose_generation_model_dir, dirs_config, beam_size=3): 18 | """ 19 | Run a simple demo. Creates the object selection inference model, pose generation model, 20 | and so on. 
Requires paths to the model + config directories. 21 | """ 22 | object_selection_inference = ObjectSelectionInference(object_selection_model_dir, dirs_cfg) 23 | pose_generation_inference = PriorInference(pose_generation_model_dir, dirs_cfg) 24 | 25 | test_dataset = object_selection_inference.dataset 26 | initial_scene_idxs = get_initial_scene_idxs(test_dataset) 27 | 28 | for idx in range(len(test_dataset)): 29 | if idx not in initial_scene_idxs: 30 | continue 31 | 32 | if idx == 4: 33 | continue 34 | 35 | filename, _ = test_dataset.get_data_index(idx) 36 | scene_id = os.path.split(filename)[1][4:-3] 37 | print("-"*50) 38 | print("Scene No.{}".format(scene_id)) 39 | 40 | # retrieve data 41 | init_datum = test_dataset.get_raw_data(idx) 42 | goal_specification = init_datum["goal_specification"] 43 | object_selection_structured_sentence = init_datum["sentence"][5:] 44 | structure_specification_structured_sentence = init_datum["sentence"][:5] 45 | object_selection_natural_sentence = object_selection_inference.tokenizer.convert_to_natural_sentence( 46 | object_selection_structured_sentence) 47 | structure_specification_natural_sentence = object_selection_inference.tokenizer.convert_structure_params_to_natural_language(structure_specification_structured_sentence) 48 | 49 | # object selection 50 | predictions, gts = object_selection_inference.predict_target_objects(init_datum) 51 | 52 | all_obj_xyzs = init_datum["xyzs"][:len(predictions)] 53 | all_obj_rgbs = init_datum["rgbs"][:len(predictions)] 54 | obj_idxs = [i for i, l in enumerate(predictions) if l == 1.0] 55 | if len(obj_idxs) == 0: 56 | continue 57 | other_obj_idxs = [i for i, l in enumerate(predictions) if l == 0.0] 58 | obj_xyzs = [all_obj_xyzs[i] for i in obj_idxs] 59 | obj_rgbs = [all_obj_rgbs[i] for i in obj_idxs] 60 | other_obj_xyzs = [all_obj_xyzs[i] for i in other_obj_idxs] 61 | other_obj_rgbs = [all_obj_rgbs[i] for i in other_obj_idxs] 62 | 63 | print("\nSelect objects to rearrange...") 64 | print("Instruction:", object_selection_natural_sentence) 65 | print("Visualize groundtruth (dot color) and prediction (ring color)") 66 | show_pcs_with_predictions(init_datum["xyzs"][:len(predictions)], init_datum["rgbs"][:len(predictions)], 67 | gts, predictions, add_table=True, side_view=True) 68 | print("Visualize object to rearrange") 69 | show_pcs(obj_xyzs, obj_rgbs, side_view=True, add_table=True) 70 | 71 | # pose generation 72 | max_num_objects = pose_generation_inference.cfg.dataset.max_num_objects 73 | max_num_other_objects = pose_generation_inference.cfg.dataset.max_num_other_objects 74 | if len(obj_xyzs) > max_num_objects: 75 | print("WARNING: reducing the number of \"query\" objects because this model is trained with a maximum of {} \"query\" objects. Train a new model if a larger number is needed.".format(max_num_objects)) 76 | obj_xyzs = obj_xyzs[:max_num_objects] 77 | obj_rgbs = obj_rgbs[:max_num_objects] 78 | if len(other_obj_xyzs) > max_num_other_objects: 79 | print("WARNING: reducing the number of \"distractor\" objects because this model is trained with a maximum of {} \"distractor\" objects. 
Train a new model if a larger number is needed.".format(max_num_other_objects)) 80 | other_obj_xyzs = other_obj_xyzs[:max_num_other_objects] 81 | other_obj_rgbs = other_obj_rgbs[:max_num_other_objects] 82 | 83 | pose_generation_datum = pose_generation_inference.dataset.prepare_test_data(obj_xyzs, obj_rgbs, 84 | other_obj_xyzs, other_obj_rgbs, 85 | goal_specification["shape"]) 86 | beam_data = [] 87 | beam_pc_rearrangements = [] 88 | for b in range(beam_size): 89 | datum_copy = copy.deepcopy(pose_generation_datum) 90 | beam_data.append(datum_copy) 91 | beam_pc_rearrangements.append(PointCloudRearrangement(datum_copy)) 92 | 93 | # autoregressive decoding 94 | num_target_objects = beam_pc_rearrangements[0].num_target_objects 95 | 96 | # first predict structure pose 97 | beam_goal_struct_pose, target_object_preds = pose_generation_inference.limited_batch_inference(beam_data) 98 | for b in range(beam_size): 99 | datum = beam_data[b] 100 | datum["struct_x_inputs"] = [beam_goal_struct_pose[b][0]] 101 | datum["struct_y_inputs"] = [beam_goal_struct_pose[b][1]] 102 | datum["struct_z_inputs"] = [beam_goal_struct_pose[b][2]] 103 | datum["struct_theta_inputs"] = [beam_goal_struct_pose[b][3:]] 104 | 105 | # then iteratively predict pose of each object 106 | beam_goal_obj_poses = [] 107 | for obj_idx in range(num_target_objects): 108 | struct_preds, target_object_preds = pose_generation_inference.limited_batch_inference(beam_data) 109 | beam_goal_obj_poses.append(target_object_preds[:, obj_idx]) 110 | for b in range(beam_size): 111 | datum = beam_data[b] 112 | datum["obj_x_inputs"][obj_idx] = target_object_preds[b][obj_idx][0] 113 | datum["obj_y_inputs"][obj_idx] = target_object_preds[b][obj_idx][1] 114 | datum["obj_z_inputs"][obj_idx] = target_object_preds[b][obj_idx][2] 115 | datum["obj_theta_inputs"][obj_idx] = target_object_preds[b][obj_idx][3:] 116 | # concat in the object dim 117 | beam_goal_obj_poses = np.stack(beam_goal_obj_poses, axis=0) 118 | # swap axis 119 | beam_goal_obj_poses = np.swapaxes(beam_goal_obj_poses, 1, 0) # batch size, number of target objects, pose dim 120 | 121 | # move pc 122 | for bi in range(beam_size): 123 | beam_pc_rearrangements[bi].set_goal_poses(beam_goal_struct_pose[bi], beam_goal_obj_poses[bi]) 124 | beam_pc_rearrangements[bi].rearrange() 125 | 126 | print("\nRearrange \"query\" objects...") 127 | print("Instruction:", structure_specification_natural_sentence) 128 | for pi, pc_rearrangement in enumerate(beam_pc_rearrangements): 129 | print("Visualize rearranged scene sample {}".format(pi)) 130 | pc_rearrangement.visualize("goal", add_other_objects=True, add_table=True, side_view=True) 131 | 132 | 133 | if __name__ == "__main__": 134 | parser = argparse.ArgumentParser(description="Run a simple model") 135 | parser.add_argument("--dataset_base_dir", help='location of the dataset', type=str) 136 | parser.add_argument("--object_selection_model_dir", help='location for the saved object selection model', type=str) 137 | parser.add_argument("--pose_generation_model_dir", help='location for the saved pose generation model', type=str) 138 | parser.add_argument("--dirs_config", help='config yaml file for directories', type=str) 139 | args = parser.parse_args() 140 | 141 | os.environ["DATETIME"] = time.strftime("%Y%m%d-%H%M%S") 142 | 143 | # # debug only 144 | # args.dataset_base_dir = "/home/weiyu/data_drive/data_new_objects_test_split" 145 | # args.object_selection_model_dir = "/home/weiyu/Research/intern/StructFormer/models/object_selection_network/best_model" 146 | # 
args.pose_generation_model_dir = "/home/weiyu/Research/intern/StructFormer/models/structformer_circle/best_model" 147 | # args.dirs_config = "/home/weiyu/Research/intern/StructFormer/configs/data/circle_dirs.yaml" 148 | 149 | if args.dirs_config: 150 | assert os.path.exists(args.dirs_config), "Cannot find config yaml file at {}".format(args.dirs_config) 151 | dirs_cfg = OmegaConf.load(args.dirs_config) 152 | dirs_cfg.dataset_base_dir = args.dataset_base_dir 153 | OmegaConf.resolve(dirs_cfg) 154 | else: 155 | dirs_cfg = None 156 | 157 | run_demo(args.object_selection_model_dir, args.pose_generation_model_dir, dirs_cfg) 158 | -------------------------------------------------------------------------------- /setup.cfg: -------------------------------------------------------------------------------- 1 | [metadata] 2 | name = StructFormer-wliu88 3 | version = 0.0.1 4 | author = Weiyu Liu 5 | author_email = wliu88@gatech.edu 6 | description = Source code for StructFormer 7 | long_description = file: README.md 8 | long_description_content_type = text/markdown 9 | url = https://github.com/wliu88/StructFormer 10 | project_urls = 11 | Bug Tracker = https://github.com/wliu88/StructFormer/issues 12 | classifiers = 13 | Programming Language :: Python :: 3 14 | License :: OSI Approved :: NVIDIA License 15 | Operating System :: OS Independent 16 | 17 | [options] 18 | package_dir = 19 | = src 20 | packages = find: 21 | python_requires = >=3.6 22 | 23 | [options.packages.find] 24 | where = src 25 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | import setuptools 2 | 3 | with open("README.md", "r", encoding="utf-8") as fh: 4 | long_description = fh.read() 5 | 6 | setuptools.setup( 7 | name="Structformer-wliu88", 8 | version="0.0.1", 9 | author="Weiyu Liu", 10 | author_email="wliu88@gatech.edu", 11 | description="Source code for StructFormer", 12 | long_description=long_description, 13 | long_description_content_type="text/markdown", 14 | url="https://github.com/wliu88/StructFormer", 15 | project_urls={ 16 | "Bug Tracker": "https://github.com/wliu88/StructFormer/issues", 17 | }, 18 | classifiers=[ 19 | "Programming Language :: Python :: 3", 20 | "License :: OSI Approved :: NVIDIA License", 21 | "Operating System :: OS Independent", 22 | ], 23 | package_dir={"": "src"}, 24 | packages=setuptools.find_packages(where="src"), 25 | python_requires=">=3.6", 26 | ) 27 | 28 | -------------------------------------------------------------------------------- /src/structformer/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wliu88/StructFormer/a81d8c83313f1f24a75ca83aa9dd9f8cfcf27419/src/structformer/__init__.py -------------------------------------------------------------------------------- /src/structformer/data/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wliu88/StructFormer/a81d8c83313f1f24a75ca83aa9dd9f8cfcf27419/src/structformer/data/__init__.py -------------------------------------------------------------------------------- /src/structformer/evaluation/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wliu88/StructFormer/a81d8c83313f1f24a75ca83aa9dd9f8cfcf27419/src/structformer/evaluation/__init__.py 
-------------------------------------------------------------------------------- /src/structformer/evaluation/inference.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import torch 4 | import copy 5 | 6 | import structformer.utils.transformations as tra 7 | from structformer.utils.rearrangement import show_pcs, save_pcs, move_one_object_pc, make_gifs, modify_language, sample_gaussians, fit_gaussians, show_pcs_color_order 8 | 9 | 10 | class PointCloudRearrangement: 11 | 12 | """ 13 | helps to keep track of point clouds and predicted object poses for inference 14 | 15 | ToDo: make the whole thing live on pytorch tensor 16 | ToDo: support binary format 17 | """ 18 | 19 | def __init__(self, initial_datum, pose_format="xyz+3x3", use_structure_frame=True): 20 | 21 | assert pose_format == "xyz+3x3", "{} pose format not supported".format(pose_format) 22 | 23 | self.use_structure_frame = use_structure_frame 24 | 25 | self.num_target_objects = None 26 | self.num_other_objects = None 27 | 28 | # important: we do not store any padding pcs and poses 29 | self.initial_xyzs = {"xyzs": [], "rgbs": [], "other_xyzs": [], "other_rgbs": []} 30 | self.goal_xyzs = {"xyzs": [], "rgbs": []} 31 | self.goal_poses = {"obj_poses": []} 32 | if self.use_structure_frame: 33 | self.goal_poses["struct_pose"] = [] 34 | 35 | self.set_initial_pc(initial_datum) 36 | 37 | def set_initial_pc(self, datum): 38 | 39 | self.num_target_objects = np.sum(np.array(datum["object_pad_mask"]) == 0) 40 | self.num_other_objects = np.sum(np.array(datum["other_object_pad_mask"]) == 0) 41 | 42 | self.initial_xyzs["xyzs"] = datum["xyzs"][:self.num_target_objects] 43 | self.initial_xyzs["rgbs"] = datum["rgbs"][:self.num_target_objects] 44 | self.initial_xyzs["other_xyzs"] = datum["other_xyzs"][:self.num_other_objects] 45 | self.initial_xyzs["other_rgbs"] = datum["other_rgbs"][:self.num_other_objects] 46 | 47 | def set_goal_poses(self, goal_struct_pose, goal_obj_poses, input_pose_format="xyz+3x3", 48 | n_obj_idxs=None, skip_update_struct=False): 49 | """ 50 | 51 | :param goal_struct_pose: 52 | :param goal_obj_poses: 53 | :param input_pose_format: 54 | :param n_obj_idxs: only set the goal poses for the indexed objects 55 | :param skip_update_struct: if set to true, do not update struct pose 56 | :return: 57 | """ 58 | 59 | # in these cases, we need to ensure the goal poses have already been set so that we can only update some of it 60 | if n_obj_idxs is not None or skip_update_struct: 61 | assert len(self.goal_poses["obj_poses"]) != 0 62 | if self.use_structure_frame: 63 | assert len(self.goal_poses["struct_pose"]) != 0 64 | 65 | if input_pose_format == "xyz+3x3": 66 | # check input 67 | if not skip_update_struct and self.use_structure_frame: 68 | assert len(goal_struct_pose) == 12 69 | if n_obj_idxs is None: 70 | # in case the input contains padding poses 71 | if len(goal_obj_poses) != self.num_target_objects: 72 | goal_obj_poses = goal_obj_poses[:self.num_target_objects] 73 | else: 74 | assert len(goal_obj_poses) == len(n_obj_idxs) 75 | assert all(len(gop) == 12 for gop in goal_obj_poses) 76 | 77 | # convert to standard form 78 | if not skip_update_struct and self.use_structure_frame: 79 | if type(goal_struct_pose) != list: 80 | goal_struct_pose = goal_struct_pose.tolist() 81 | if type(goal_obj_poses) != list: 82 | goal_obj_poses = goal_obj_poses.tolist() 83 | for i in range(len(goal_obj_poses)): 84 | if type(goal_obj_poses[i]) != list: 85 | goal_obj_poses[i] = 
goal_obj_poses.tolist() 86 | 87 | elif input_pose_format == "flat:xyz+3x3": 88 | # check input 89 | if not skip_update_struct and self.use_structure_frame: 90 | assert len(goal_struct_pose) == 12 91 | if n_obj_idxs is None: 92 | # flat means that object poses are in one list instead of a list of lists 93 | assert len(goal_obj_poses) == self.num_target_objects * 12 94 | else: 95 | assert len(goal_obj_poses) == len(n_obj_idxs) * 12 96 | 97 | # convert to standard form 98 | if not skip_update_struct and self.use_structure_frame: 99 | if type(goal_struct_pose) != list: 100 | goal_struct_pose = goal_struct_pose.tolist() 101 | if type(goal_obj_poses) != list: 102 | goal_obj_poses = goal_obj_poses.tolist() 103 | 104 | goal_obj_poses = np.array(goal_obj_poses).reshape(-1, 12).tolist() 105 | 106 | elif input_pose_format == "flat:xyz+rpy": 107 | # check input 108 | if not skip_update_struct and self.use_structure_frame: 109 | assert len(goal_struct_pose) == 6 110 | if n_obj_idxs is None: 111 | assert len(goal_obj_poses) == self.num_target_objects * 6 112 | else: 113 | assert len(goal_obj_poses) == len(n_obj_idxs) * 6 114 | 115 | # convert to standard form 116 | if not skip_update_struct and self.use_structure_frame: 117 | if type(goal_struct_pose) != list: 118 | goal_struct_pose = goal_struct_pose.tolist() 119 | if type(goal_obj_poses) != list: 120 | goal_obj_poses = np.array(goal_obj_poses).reshape(-1, 6).tolist() 121 | 122 | if not skip_update_struct and self.use_structure_frame: 123 | goal_struct_pose = goal_struct_pose[:3] + tra.euler_matrix(goal_struct_pose[3], goal_struct_pose[4], goal_struct_pose[5])[:3, :3].flatten().tolist() 124 | converted_goal_obj_poses = [] 125 | for gop in goal_obj_poses: 126 | converted_goal_obj_poses.append( 127 | gop[:3] + tra.euler_matrix(gop[3], gop[4], gop[5])[:3, :3].flatten().tolist()) 128 | goal_obj_poses = converted_goal_obj_poses 129 | 130 | else: 131 | raise KeyError 132 | 133 | # update 134 | if not skip_update_struct and self.use_structure_frame: 135 | self.goal_poses["struct_pose"] = goal_struct_pose 136 | if n_obj_idxs is None: 137 | self.goal_poses["obj_poses"] = goal_obj_poses 138 | else: 139 | for count, oi in enumerate(n_obj_idxs): 140 | self.goal_poses["obj_poses"][oi] = goal_obj_poses[count] 141 | 142 | def get_goal_poses(self, output_pose_format="xyz+3x3", 143 | n_obj_idxs=None, skip_update_struct=False, combine_struct_objs=False): 144 | """ 145 | 146 | :param output_pose_format: 147 | :param n_obj_idxs: only retrieve the goal poses for the indexed objects 148 | :param skip_update_struct: if set to true, do not retrieve struct pose 149 | :param combine_struct_objs: one output, return a list of lists, where the first list if for the structure pose 150 | and remainings are for object poses 151 | :return: 152 | """ 153 | if output_pose_format == "xyz+3x3": 154 | if self.use_structure_frame: 155 | goal_struct_pose = self.goal_poses["struct_pose"] 156 | goal_obj_poses = self.goal_poses["obj_poses"] 157 | 158 | if n_obj_idxs is not None: 159 | goal_obj_poses = [goal_obj_poses[i] for i in n_obj_idxs] 160 | 161 | elif output_pose_format == "flat:xyz+3x3": 162 | if self.use_structure_frame: 163 | goal_struct_pose = self.goal_poses["struct_pose"] 164 | 165 | if n_obj_idxs is None: 166 | goal_obj_poses = np.array(self.goal_poses["obj_poses"]).flatten().tolist() 167 | else: 168 | goal_obj_poses = np.array([self.goal_poses["obj_poses"][i] for i in n_obj_idxs]).flatten().tolist() 169 | 170 | elif output_pose_format == "flat:xyz+rpy": 171 | if 
self.use_structure_frame: 172 | ax, ay, az = tra.euler_from_matrix(np.asarray(self.goal_poses["struct_pose"][3:]).reshape(3, 3)) 173 | goal_struct_pose = self.goal_poses["struct_pose"][:3] + [ax, ay, az] 174 | 175 | goal_obj_poses = [] 176 | for gop in self.goal_poses["obj_poses"]: 177 | ax, ay, az = tra.euler_from_matrix(np.asarray(gop[3:]).reshape(3, 3)) 178 | goal_obj_poses.append(gop[:3] + [ax, ay, az]) 179 | 180 | if n_obj_idxs is None: 181 | goal_obj_poses = np.array(goal_obj_poses).flatten().tolist() 182 | else: 183 | goal_obj_poses = np.array([goal_obj_poses[i] for i in n_obj_idxs]).flatten().tolist() 184 | 185 | else: 186 | raise KeyError 187 | 188 | if not skip_update_struct and self.use_structure_frame: 189 | if not combine_struct_objs: 190 | return goal_struct_pose, goal_obj_poses 191 | else: 192 | return [goal_struct_pose] + goal_obj_poses 193 | else: 194 | return None, goal_obj_poses 195 | 196 | def rearrange(self, n_obj_idxs=None): 197 | """ 198 | use stored object point clouds of the initial scene and goal poses to 199 | compute object point clouds of the goal scene. 200 | 201 | :param n_obj_idxs: only update the goal point clouds of indexed objects 202 | :return: 203 | """ 204 | 205 | # initial scene and goal poses have to be set first 206 | assert all(len(self.initial_xyzs[k]) != 0 for k in ["xyzs", "rgbs"]) 207 | assert all(len(self.goal_poses[k]) != 0 for k in self.goal_poses) 208 | 209 | # whether we are initializing or updating 210 | no_goal_xyzs_yet = True 211 | if len(self.goal_xyzs["xyzs"]): 212 | no_goal_xyzs_yet = False 213 | 214 | if n_obj_idxs is not None: 215 | assert no_goal_xyzs_yet is False 216 | 217 | if n_obj_idxs is not None: 218 | update_obj_idxs = n_obj_idxs 219 | else: 220 | update_obj_idxs = list(range(self.num_target_objects)) 221 | 222 | if self.use_structure_frame: 223 | goal_struct_pose = self.goal_poses["struct_pose"] 224 | else: 225 | goal_struct_pose = None 226 | for obj_idx in update_obj_idxs: 227 | imagined_obj_xyz, imagined_obj_rgb = move_one_object_pc(self.initial_xyzs["xyzs"][obj_idx], 228 | self.initial_xyzs["rgbs"][obj_idx], 229 | self.goal_poses["obj_poses"][obj_idx], 230 | goal_struct_pose) 231 | 232 | if no_goal_xyzs_yet: 233 | self.goal_xyzs["xyzs"].append(imagined_obj_xyz) 234 | self.goal_xyzs["rgbs"].append(imagined_obj_rgb) 235 | else: 236 | self.goal_xyzs["xyzs"][obj_idx] = imagined_obj_xyz 237 | self.goal_xyzs["rgbs"][obj_idx] = imagined_obj_rgb 238 | 239 | def visualize(self, time_step, add_other_objects=False, 240 | add_coordinate_frame=False, side_view=False, add_table=False, 241 | show_vis=True, save_vis=False, save_filename=None, order_color=False): 242 | 243 | if time_step == "initial": 244 | xyzs = self.initial_xyzs["xyzs"] 245 | rgbs = self.initial_xyzs["rgbs"] 246 | elif time_step == "goal": 247 | xyzs = self.goal_xyzs["xyzs"] 248 | rgbs = self.goal_xyzs["rgbs"] 249 | else: 250 | raise KeyError() 251 | 252 | if add_other_objects: 253 | xyzs += self.initial_xyzs["other_xyzs"] 254 | rgbs += self.initial_xyzs["other_rgbs"] 255 | 256 | if show_vis: 257 | if not order_color: 258 | show_pcs(xyzs, rgbs, add_coordinate_frame=add_coordinate_frame, side_view=side_view, add_table=add_table) 259 | else: 260 | show_pcs_color_order(xyzs, rgbs, add_coordinate_frame=add_coordinate_frame, side_view=side_view, add_table=add_table) 261 | 262 | if save_vis and save_filename is not None: 263 | save_pcs(xyzs, rgbs, save_path=save_filename, add_coordinate_frame=add_coordinate_frame, side_view=side_view, add_table=add_table) 
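
For reference, the following is a minimal usage sketch of `PointCloudRearrangement`, mirroring how it is driven in `scripts/run_full_pipeline.py`; `datum`, `goal_struct_pose`, and `goal_obj_poses` are placeholders assumed to come from the sequence dataset and a trained pose generation model:

```python
from structformer.evaluation.inference import PointCloudRearrangement

# `datum` is a sample prepared by the sequence dataset (e.g. prepare_test_data);
# `goal_struct_pose` and `goal_obj_poses` are predicted poses in the default
# "xyz+3x3" format: [x, y, z] followed by a flattened 3x3 rotation matrix.
pc_rearrangement = PointCloudRearrangement(datum)
pc_rearrangement.set_goal_poses(goal_struct_pose, goal_obj_poses)

# Transform the stored initial object point clouds into the predicted goal scene.
pc_rearrangement.rearrange()

# Visualize the imagined goal arrangement together with the objects that stay put.
pc_rearrangement.visualize("goal", add_other_objects=True, add_table=True, side_view=True)
```
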
-------------------------------------------------------------------------------- /src/structformer/evaluation/test_binary_model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | import os 4 | import copy 5 | import tqdm 6 | import argparse 7 | from omegaconf import OmegaConf 8 | import trimesh 9 | import time 10 | from torch.utils.data import DataLoader 11 | 12 | from structformer.data.tokenizer import Tokenizer 13 | import structformer.data.binary_dataset as prior_dataset 14 | import structformer.training.train_binary_model as prior_model 15 | from structformer.utils.rearrangement import move_one_object_pc, make_gifs, \ 16 | modify_language, sample_gaussians, fit_gaussians, get_initial_scene_idxs, show_pcs, save_pcs 17 | 18 | 19 | def test_model(model_dir, dirs_cfg): 20 | prior_inference = PriorInference(model_dir, dirs_cfg, data_split="test") 21 | prior_inference.validate() 22 | 23 | 24 | class PriorInference: 25 | 26 | def __init__(self, model_dir, dirs_cfg, data_split="test"): 27 | # load prior 28 | cfg, tokenizer, model, optimizer, scheduler, epoch = prior_model.load_model(model_dir, dirs_cfg) 29 | 30 | data_cfg = cfg.dataset 31 | 32 | dataset = prior_dataset.BinaryDataset(data_cfg.dirs, data_cfg.index_dirs, data_split, tokenizer, 33 | data_cfg.max_num_objects, 34 | data_cfg.max_num_other_objects, 35 | data_cfg.max_num_shape_parameters, 36 | data_cfg.max_num_rearrange_features, 37 | data_cfg.max_num_anchor_features, 38 | data_cfg.num_pts) 39 | self.cfg = cfg 40 | self.tokenizer = tokenizer 41 | self.model = model 42 | self.cfg = cfg 43 | self.dataset = dataset 44 | self.epoch = epoch 45 | 46 | def validate(self): 47 | data_cfg = self.cfg.dataset 48 | data_iter = DataLoader(self.dataset, batch_size=data_cfg.batch_size, shuffle=False, 49 | collate_fn=prior_dataset.BinaryDataset.collate_fn, 50 | pin_memory=data_cfg.pin_memory, num_workers=data_cfg.num_workers) 51 | 52 | prior_model.validate(self.cfg, self.model, data_iter, self.epoch, self.cfg.device) 53 | 54 | def limited_batch_inference(self, data, t, verbose=False): 55 | """ 56 | This function makes the assumption that scenes in the batch have the same number of objects that need to be 57 | rearranged 58 | 59 | :param data: 60 | :param model: 61 | :param test_dataset: 62 | :param tokenizer: 63 | :param cfg: 64 | :param num_samples: 65 | :param verbose: 66 | :return: 67 | """ 68 | 69 | data_size = len(data) 70 | batch_size = self.cfg.dataset.batch_size 71 | if verbose: 72 | print("data size:", data_size) 73 | print("batch size:", batch_size) 74 | 75 | num_batches = int(data_size / batch_size) 76 | if data_size % batch_size != 0: 77 | num_batches += 1 78 | 79 | all_obj_preds = [] 80 | all_binary_data = [] 81 | for b in range(num_batches): 82 | if b + 1 == num_batches: 83 | # last batch 84 | batch = data[b * batch_size:] 85 | else: 86 | batch = data[b * batch_size: (b+1) * batch_size] 87 | 88 | binary_data = [self.dataset.convert_sequence_to_binary(d, t) for d in batch] 89 | data_tensors = [self.dataset.convert_to_tensors(d, self.tokenizer) for d in binary_data] 90 | data_tensors = self.dataset.collate_fn(data_tensors) 91 | predictions = prior_model.infer_once(self.cfg, self.model, data_tensors, self.cfg.device) 92 | 93 | obj_x_preds = torch.cat(predictions["obj_x_outputs"], dim=0) 94 | obj_y_preds = torch.cat(predictions["obj_y_outputs"], dim=0) 95 | obj_z_preds = torch.cat(predictions["obj_z_outputs"], dim=0) 96 | obj_theta_preds = 
torch.cat(predictions["obj_theta_outputs"], dim=0) 97 | obj_preds = torch.cat([obj_x_preds, obj_y_preds, obj_z_preds, obj_theta_preds], dim=1) # batch_size * max num objects, output_dim 98 | 99 | all_obj_preds.append(obj_preds) 100 | all_binary_data.extend(binary_data) 101 | 102 | obj_preds = torch.cat(all_obj_preds, dim=0) # data_size * max num objects, output_dim 103 | obj_preds = obj_preds.detach().cpu().numpy() 104 | obj_preds = obj_preds.reshape(data_size, obj_preds.shape[-1]) # batch_size, max num objects, output_dim 105 | 106 | return obj_preds, all_binary_data 107 | 108 | 109 | def inference_beam_decoding(model_dir, dirs_cfg, beam_size=100, max_scene_decodes=30000, 110 | visualize=True, visualize_action_sequence=False, 111 | inference_visualization_dir=None): 112 | """ 113 | This function decodes a scene with a single forward pass 114 | 115 | :param model_dir: 116 | :param discriminator_model_dir: 117 | :param inference_visualization_dir: 118 | :param visualize: 119 | :param num_samples: number of MDN samples drawn, in this case it's also the number of rearrangements 120 | :param keep_steps: 121 | :param initial_scenes_only: 122 | :param verbose: 123 | :return: 124 | """ 125 | 126 | if inference_visualization_dir and not os.path.exists(inference_visualization_dir): 127 | os.makedirs(inference_visualization_dir) 128 | 129 | prior_inference = PriorInference(model_dir, dirs_cfg) 130 | test_dataset = prior_inference.dataset 131 | 132 | initial_scene_idxs = get_initial_scene_idxs(test_dataset) 133 | 134 | decoded_scene_count = 0 135 | with tqdm.tqdm(total=len(initial_scene_idxs)) as pbar: 136 | # for idx in np.random.choice(range(len(test_dataset)), len(test_dataset), replace=False): 137 | for idx in initial_scene_idxs: 138 | 139 | if decoded_scene_count == max_scene_decodes: 140 | break 141 | 142 | filename, step_t = test_dataset.get_data_index(idx) 143 | scene_id = os.path.split(filename)[1][4:-3] 144 | 145 | decoded_scene_count += 1 146 | 147 | ############################################ 148 | # retrieve data 149 | beam_data = [] 150 | num_target_objects = None 151 | for b in range(beam_size): 152 | datum = test_dataset.get_raw_sequence_data(idx) 153 | beam_data.append(datum) 154 | 155 | if num_target_objects is None: 156 | num_target_objects = len(datum["xyzs"]) 157 | 158 | # We can play with different language here 159 | # datum["sentence"] = modify_language(datum["sentence"], radius=0.5) 160 | 161 | if visualize: 162 | datum = beam_data[0] 163 | print("#"*50) 164 | print("sentence", datum["sentence"]) 165 | print("num target objects", num_target_objects) 166 | show_pcs(datum["xyzs"] + datum["other_bg_xyzs"], 167 | datum["rgbs"] + datum["other_bg_rgbs"], 168 | add_coordinate_frame=False, side_view=True, add_table=True) 169 | 170 | ############################################ 171 | 172 | beam_predicted_parameters = [[]] * beam_size 173 | for time_index in range(num_target_objects): 174 | 175 | # iteratively decoding 176 | target_object_preds, binary_data = prior_inference.limited_batch_inference(beam_data, time_index) 177 | 178 | for b in range(beam_size): 179 | # a list of list, where each inside list contains xyz, 3x3 rotation 180 | 181 | datum = beam_data[b] 182 | binary_datum = binary_data[b] 183 | obj_pred = target_object_preds[b] 184 | 185 | #------------ 186 | goal_query_pc_translation_offset = obj_pred[:3] 187 | goal_query_pc_rotation = np.eye(4) 188 | goal_query_pc_rotation[:3, :3] = np.array(obj_pred[3:]).reshape(3, 3) 189 | 190 | query_obj_xyz = binary_datum["query_xyz"] 
191 | anchor_obj_xyz = binary_datum["anchor_xyz"] 192 | 193 | 194 | current_query_pc_center = torch.mean(query_obj_xyz, dim=0).numpy()[:3] 195 | current_anchor_pc_center = torch.mean(anchor_obj_xyz, dim=0).numpy()[:3] 196 | 197 | t = np.eye(4) 198 | t[:3, 3] = current_anchor_pc_center + goal_query_pc_translation_offset - current_query_pc_center 199 | new_query_obj_xyz = trimesh.transform_points(query_obj_xyz, t) 200 | 201 | # rotating in place 202 | # R = tra.euler_matrix(0, 0, obj_pc_rotations[i]) 203 | query_obj_center = np.mean(new_query_obj_xyz, axis=0) 204 | centered_query_obj_xyz = new_query_obj_xyz - query_obj_center 205 | new_centered_query_obj_xyz = trimesh.transform_points(centered_query_obj_xyz, goal_query_pc_rotation, 206 | translate=True) 207 | new_query_obj_xyz = new_centered_query_obj_xyz + query_obj_center 208 | new_query_obj_xyz = torch.tensor(new_query_obj_xyz, dtype=query_obj_xyz.dtype) 209 | 210 | # vis_query_obj_rgb = np.tile(np.array([0, 1, 0], dtype=np.float), (query_obj_xyz.shape[0], 1)) 211 | # vis_anchor_obj_rgb = np.tile(np.array([1, 0, 0], dtype=np.float), (anchor_obj_xyz.shape[0], 1)) 212 | # vis_new_query_obj_rgb = np.tile(np.array([0, 0, 1], dtype=np.float), (new_query_obj_xyz.shape[0], 1)) 213 | # show_pcs([new_query_obj_xyz, query_obj_xyz, anchor_obj_xyz], 214 | # [vis_new_query_obj_rgb, vis_query_obj_rgb, vis_anchor_obj_rgb], 215 | # add_coordinate_frame=True) 216 | 217 | datum["xyzs"][time_index] = new_query_obj_xyz 218 | 219 | current_object_param = t[:3, 3].tolist() + goal_query_pc_rotation[:3, :3].flatten().tolist() 220 | beam_predicted_parameters[b].append(current_object_param) 221 | 222 | for b in range(beam_size): 223 | datum = beam_data[b] 224 | 225 | pc_sizes = [xyz.shape[0] for xyz in datum["other_bg_xyzs"]] 226 | table_idx = np.argmax(pc_sizes) 227 | show_pcs(datum["xyzs"] + [xyz for i, xyz in enumerate(datum["other_bg_xyzs"]) if i != table_idx], 228 | datum["rgbs"] + [rgb for i, rgb in enumerate(datum["other_bg_rgbs"]) if i != table_idx], 229 | add_coordinate_frame=False, side_view=True, add_table=True) 230 | 231 | pbar.update(1) 232 | 233 | 234 | if __name__ == "__main__": 235 | parser = argparse.ArgumentParser(description="Run a simple model") 236 | parser.add_argument("--dataset_base_dir", help='location of the dataset', type=str) 237 | parser.add_argument("--model_dir", help='location for the saved model', type=str) 238 | parser.add_argument("--dirs_config", help='config yaml file for directories', default="", type=str) 239 | args = parser.parse_args() 240 | 241 | os.environ["DATETIME"] = time.strftime("%Y%m%d-%H%M%S") 242 | 243 | # # debug only 244 | # args.dataset_base_dir = "/home/weiyu/data_drive/data_new_objects_test_split" 245 | # args.model_dir = "/home/weiyu/Research/intern/StructFormer/models/binary_model_tower/best_model" 246 | # args.dirs_config = "/home/weiyu/Research/intern/StructFormer/structformer/configs/data/tower_dirs.yaml" 247 | 248 | if args.dirs_config: 249 | assert os.path.exists(args.dirs_config), "Cannot find config yaml file at {}".format(args.dirs_config) 250 | dirs_cfg = OmegaConf.load(args.dirs_config) 251 | dirs_cfg.dataset_base_dir = args.dataset_base_dir 252 | OmegaConf.resolve(dirs_cfg) 253 | else: 254 | dirs_cfg = None 255 | 256 | inference_beam_decoding(args.model_dir, dirs_cfg, beam_size=3, max_scene_decodes=30000, 257 | visualize=True, visualize_action_sequence=False, 258 | inference_visualization_dir=None) 259 | 260 | -------------------------------------------------------------------------------- 
/src/structformer/evaluation/test_object_selection_network.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | import os 4 | import tqdm 5 | import time 6 | import argparse 7 | from omegaconf import OmegaConf 8 | 9 | from torch.utils.data import DataLoader 10 | from structformer.data.object_set_refer_dataset import ObjectSetReferDataset 11 | from structformer.training.train_object_selection_network import load_model, validate, infer_once 12 | from structformer.utils.rearrangement import show_pcs_with_predictions, get_initial_scene_idxs, evaluate_target_object_predictions, save_img, show_pcs_with_labels, test_new_vis 13 | 14 | 15 | class ObjectSelectionInference: 16 | 17 | def __init__(self, model_dir, dirs_cfg, data_split="test"): 18 | # load prior 19 | cfg, tokenizer, model, optimizer, scheduler, epoch = load_model(model_dir, dirs_cfg) 20 | 21 | data_cfg = cfg.dataset 22 | test_dataset = ObjectSetReferDataset(data_cfg.dirs, data_cfg.index_dirs, data_split, tokenizer, 23 | data_cfg.max_num_all_objects, 24 | data_cfg.max_num_shape_parameters, 25 | data_cfg.max_num_rearrange_features, 26 | data_cfg.max_num_anchor_features, 27 | data_cfg.num_pts) 28 | 29 | self.cfg = cfg 30 | self.tokenizer = tokenizer 31 | self.model = model 32 | self.cfg = cfg 33 | self.dataset = test_dataset 34 | self.epoch = epoch 35 | 36 | def prepare_datum(self, obj_xyzs, obj_rgbs, goal_specification, structure_parameters, gt_num_rearrange_objects): 37 | datum = self.dataset.prepare_test_data(obj_xyzs, obj_rgbs, goal_specification, structure_parameters, gt_num_rearrange_objects) 38 | return datum 39 | 40 | def predict_target_objects(self, datum): 41 | 42 | batch = self.dataset.collate_fn([self.dataset.convert_to_tensors(datum, self.tokenizer)]) 43 | 44 | gts, predictions = infer_once(self.model, batch, self.cfg.device) 45 | gts = gts["rearrange_obj_labels"][0].detach().cpu().numpy() 46 | predictions = predictions["rearrange_obj_labels"][0].detach().cpu().numpy() 47 | predictions = predictions > 0.5 48 | # remove paddings 49 | target_mask = gts != -100 50 | gts = gts[target_mask] 51 | predictions = predictions[target_mask] 52 | 53 | return predictions, gts 54 | 55 | def validate(self): 56 | """ 57 | validate the pretrained model on the dataset 58 | 59 | :return: 60 | """ 61 | data_cfg = self.cfg.dataset 62 | data_iter = DataLoader(self.dataset, batch_size=data_cfg.batch_size, shuffle=False, 63 | collate_fn=ObjectSetReferDataset.collate_fn, 64 | pin_memory=data_cfg.pin_memory, num_workers=data_cfg.num_workers) 65 | 66 | validate(self.model, data_iter, self.epoch, self.cfg.device) 67 | 68 | 69 | def inference(model_dir, dirs_cfg, visualize=False, inference_visualization_dir=None): 70 | 71 | # object selection information 72 | # goal = {"rearrange": {"features": [], "objects": [], "combine_features_logic": None, "count": None}, 73 | # "anchor": {"features": [], "objects": [], "combine_features_logic": None}, 74 | # "distract": {"features": [], "objects": [], "combine_features_logic": None}, 75 | # "random_selection": {"varying_features": [], "nonvarying_features": []}, 76 | # "order": {"feature": None} 77 | # } 78 | 79 | if inference_visualization_dir: 80 | if not os.path.exists(inference_visualization_dir): 81 | os.makedirs(inference_visualization_dir) 82 | 83 | object_selection_inference = ObjectSelectionInference(model_dir, dirs_cfg) 84 | test_dataset = object_selection_inference.dataset 85 | 86 | initial_scene_idxs = 
get_initial_scene_idxs(test_dataset) 87 | 88 | all_predictions = [] 89 | all_gts = [] 90 | all_goal_specifications = [] 91 | all_sentences = [] 92 | 93 | count = 0 94 | for idx in tqdm.tqdm(range(len(test_dataset))): 95 | 96 | if idx not in initial_scene_idxs: 97 | continue 98 | 99 | count += 1 100 | 101 | datum = test_dataset.get_raw_data(idx) 102 | goal_specification = datum["goal_specification"] 103 | object_selection_sentence = datum["sentence"][5:] 104 | reference_sentence = object_selection_inference.tokenizer.convert_to_natural_sentence( 105 | object_selection_sentence) 106 | 107 | predictions, gts = object_selection_inference.predict_target_objects(datum) 108 | 109 | if visualize: 110 | print(gts) 111 | print(predictions) 112 | print(object_selection_sentence) 113 | print(reference_sentence) 114 | print("Visualize groundtruth (dot color) and prediction (ring color)") 115 | show_pcs_with_predictions(datum["xyzs"][:len(gts)], datum["rgbs"][:len(gts)], 116 | gts, predictions, add_coordinate_frame=False) 117 | 118 | if inference_visualization_dir: 119 | buffer = show_pcs_with_predictions(datum["xyzs"][:len(gts)], datum["rgbs"][:len(gts)], 120 | gts, predictions, add_coordinate_frame=False, return_buffer=True) 121 | img = np.uint8(np.asarray(buffer) * 255) 122 | save_img(img, os.path.join(inference_visualization_dir, "scene_{}.png".format(idx)), text=reference_sentence) 123 | 124 | # try: 125 | # batch = test_dataset.collate_fn([test_dataset.convert_to_tensors(datum, tokenizer)]) 126 | # except KeyError: 127 | # print("skipping this for now because we are using an outdated model with an old vocab") 128 | # continue 129 | # 130 | # goal_specification = datum["goal_specification"] 131 | # 132 | # gts, predictions = infer_once(model, batch, cfg.device) 133 | # gts = gts["rearrange_obj_labels"][0].detach().cpu().numpy() 134 | # predictions = predictions["rearrange_obj_labels"][0].detach().cpu().numpy() 135 | # predictions = predictions > 0.5 136 | # # remove paddings 137 | # target_mask = gts != -100 138 | # gts = gts[target_mask] 139 | # predictions = predictions[target_mask] 140 | # 141 | # object_selection_sentence = datum["sentence"][5:] 142 | # 143 | # if visualize: 144 | # print(gts) 145 | # print(predictions) 146 | # print(object_selection_sentence) 147 | # print(tokenizer.convert_to_natural_sentence(object_selection_sentence)) 148 | # 149 | # if inference_visualization_dir is None: 150 | # show_pcs_with_predictions(datum["xyzs"][:len(gts)], datum["rgbs"][:len(gts)], 151 | # gts, predictions, add_coordinate_frame=False) 152 | # # show_pcs_with_only_predictions(datum["xyzs"][:len(gts)], datum["rgbs"][:len(gts)], 153 | # # gts, predictions, add_coordinate_frame=False) 154 | # # test_new_vis(datum["xyzs"][:len(gts)], datum["rgbs"][:len(gts)]) 155 | # else: 156 | # save_filename = os.path.join(inference_visualization_dir, "{}.png".format(idx)) 157 | # buffer = show_pcs_with_predictions(datum["xyzs"][:len(gts)], datum["rgbs"][:len(gts)], 158 | # gts, predictions, add_coordinate_frame=False, return_buffer=True) 159 | # img = np.uint8(np.asarray(buffer) * 255) 160 | # save_img(img, save_filename, text=tokenizer.convert_to_natural_sentence(datum["sentence"])) 161 | 162 | all_predictions.append(predictions) 163 | all_gts.append(gts) 164 | all_goal_specifications.append(goal_specification) 165 | all_sentences.append(object_selection_sentence) 166 | 167 | # create a more detailed report 168 | evaluate_target_object_predictions(all_gts, all_predictions, all_sentences, initial_scene_idxs, 169 | 
object_selection_inference.tokenizer) 170 | 171 | 172 | if __name__ == "__main__": 173 | parser = argparse.ArgumentParser(description="Run a simple model") 174 | parser.add_argument("--dataset_base_dir", help='location of the dataset', type=str) 175 | parser.add_argument("--model_dir", help='location for the saved model', type=str) 176 | parser.add_argument("--dirs_config", help='config yaml file for directories', default="", type=str) 177 | parser.add_argument("--inference_visualization_dir", help='location for saving visualizations of inference results', 178 | type=str, default=None) 179 | parser.add_argument("--visualize", default=1, type=int, help='whether to visualize inference results while running') 180 | args = parser.parse_args() 181 | 182 | os.environ["DATETIME"] = time.strftime("%Y%m%d-%H%M%S") 183 | 184 | # # debug only 185 | # args.dataset_base_dir = "/home/weiyu/data_drive/data_new_objects_test_split" 186 | # args.model_dir = "/home/weiyu/Research/intern/StructFormer/models/object_selection_network/best_model" 187 | # args.dirs_config = "/home/weiyu/Research/intern/StructFormer/structformer/configs/data/line_dirs.yaml" 188 | # args.visualize = True 189 | 190 | if args.dirs_config: 191 | assert os.path.exists(args.dirs_config), "Cannot find config yaml file at {}".format(args.dirs_config) 192 | dirs_cfg = OmegaConf.load(args.dirs_config) 193 | dirs_cfg.dataset_base_dir = args.dataset_base_dir 194 | OmegaConf.resolve(dirs_cfg) 195 | else: 196 | dirs_cfg = None 197 | 198 | inference(args.model_dir, dirs_cfg, args.visualize, args.inference_visualization_dir) -------------------------------------------------------------------------------- /src/structformer/evaluation/test_structformer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | import os 4 | import copy 5 | import tqdm 6 | import argparse 7 | from omegaconf import OmegaConf 8 | import time 9 | 10 | from torch.utils.data import DataLoader 11 | 12 | from structformer.data.tokenizer import Tokenizer 13 | import structformer.data.sequence_dataset as prior_dataset 14 | import structformer.training.train_structformer as prior_model 15 | from structformer.utils.rearrangement import show_pcs 16 | from structformer.evaluation.inference import PointCloudRearrangement 17 | 18 | 19 | def test_model(model_dir, dirs_cfg): 20 | prior_inference = PriorInference(model_dir, dirs_cfg, data_split="test") 21 | prior_inference.validate() 22 | 23 | 24 | class PriorInference: 25 | 26 | def __init__(self, model_dir, dirs_cfg, data_split="test"): 27 | 28 | cfg, tokenizer, model, optimizer, scheduler, epoch = prior_model.load_model(model_dir, dirs_cfg) 29 | 30 | data_cfg = cfg.dataset 31 | 32 | dataset = prior_dataset.SequenceDataset(data_cfg.dirs, data_cfg.index_dirs, data_split, tokenizer, 33 | data_cfg.max_num_objects, 34 | data_cfg.max_num_other_objects, 35 | data_cfg.max_num_shape_parameters, 36 | data_cfg.max_num_rearrange_features, 37 | data_cfg.max_num_anchor_features, 38 | data_cfg.num_pts, 39 | data_cfg.use_structure_frame) 40 | 41 | self.cfg = cfg 42 | self.tokenizer = tokenizer 43 | self.model = model 44 | self.cfg = cfg 45 | self.dataset = dataset 46 | self.epoch = epoch 47 | 48 | def validate(self): 49 | """ 50 | validate the pretrained model on the dataset 51 | 52 | :return: 53 | """ 54 | data_cfg = self.cfg.dataset 55 | data_iter = DataLoader(self.dataset, batch_size=data_cfg.batch_size, shuffle=False, 56 | 
collate_fn=prior_dataset.SequenceDataset.collate_fn, 57 | pin_memory=data_cfg.pin_memory, num_workers=data_cfg.num_workers) 58 | 59 | prior_model.validate(self.cfg, self.model, data_iter, self.epoch, self.cfg.device) 60 | 61 | def limited_batch_inference(self, data, verbose=False): 62 | """ 63 | This function makes the assumption that scenes in the batch have the same number of objects that need to be 64 | rearranged 65 | 66 | :param data: 67 | :param model: 68 | :param test_dataset: 69 | :param tokenizer: 70 | :param cfg: 71 | :param num_samples: 72 | :param verbose: 73 | :return: 74 | """ 75 | 76 | data_size = len(data) 77 | batch_size = self.cfg.dataset.batch_size 78 | if verbose: 79 | print("data size:", data_size) 80 | print("batch size:", batch_size) 81 | 82 | num_batches = int(data_size / batch_size) 83 | if data_size % batch_size != 0: 84 | num_batches += 1 85 | 86 | all_obj_preds = [] 87 | all_struct_preds = [] 88 | for b in range(num_batches): 89 | if b + 1 == num_batches: 90 | # last batch 91 | batch = data[b * batch_size:] 92 | else: 93 | batch = data[b * batch_size: (b+1) * batch_size] 94 | data_tensors = [self.dataset.convert_to_tensors(d, self.tokenizer) for d in batch] 95 | data_tensors = self.dataset.collate_fn(data_tensors) 96 | predictions = prior_model.infer_once(self.cfg, self.model, data_tensors, self.cfg.device) 97 | 98 | obj_x_preds = torch.cat(predictions["obj_x_outputs"], dim=0) 99 | obj_y_preds = torch.cat(predictions["obj_y_outputs"], dim=0) 100 | obj_z_preds = torch.cat(predictions["obj_z_outputs"], dim=0) 101 | obj_theta_preds = torch.cat(predictions["obj_theta_outputs"], dim=0) 102 | obj_preds = torch.cat([obj_x_preds, obj_y_preds, obj_z_preds, obj_theta_preds], dim=1) # batch_size * max num objects, output_dim 103 | 104 | struct_x_preds = torch.cat(predictions["struct_x_inputs"], dim=0) 105 | struct_y_preds = torch.cat(predictions["struct_y_inputs"], dim=0) 106 | struct_z_preds = torch.cat(predictions["struct_z_inputs"], dim=0) 107 | struct_theta_preds = torch.cat(predictions["struct_theta_inputs"], dim=0) 108 | struct_preds = torch.cat([struct_x_preds, struct_y_preds, struct_z_preds, struct_theta_preds], dim=1) # batch_size, output_dim 109 | 110 | all_obj_preds.append(obj_preds) 111 | all_struct_preds.append(struct_preds) 112 | 113 | obj_preds = torch.cat(all_obj_preds, dim=0) # data_size * max num objects, output_dim 114 | struct_preds = torch.cat(all_struct_preds, dim=0) # data_size, output_dim 115 | 116 | obj_preds = obj_preds.detach().cpu().numpy() 117 | struct_preds = struct_preds.detach().cpu().numpy() 118 | 119 | obj_preds = obj_preds.reshape(data_size, -1, obj_preds.shape[-1]) # batch_size, max num objects, output_dim 120 | 121 | return struct_preds, obj_preds 122 | 123 | 124 | def inference_beam_decoding(model_dir, dirs_cfg, beam_size=100, max_scene_decodes=30000, 125 | visualize=True, visualize_action_sequence=False, 126 | inference_visualization_dir=None): 127 | """ 128 | 129 | :param model_dir: 130 | :param beam_size: 131 | :param max_scene_decodes: 132 | :param visualize: 133 | :param visualize_action_sequence: 134 | :param inference_visualization_dir: 135 | :param side_view: 136 | :return: 137 | """ 138 | 139 | if inference_visualization_dir and not os.path.exists(inference_visualization_dir): 140 | os.makedirs(inference_visualization_dir) 141 | 142 | prior_inference = PriorInference(model_dir, dirs_cfg) 143 | test_dataset = prior_inference.dataset 144 | 145 | decoded_scene_count = 0 146 | with tqdm.tqdm(total=len(test_dataset)) as pbar: 147 | 
# for idx in np.random.choice(range(len(test_dataset)), len(test_dataset), replace=False): 148 | for idx in range(len(test_dataset)): 149 | 150 | if decoded_scene_count == max_scene_decodes: 151 | break 152 | 153 | filename = test_dataset.get_data_index(idx) 154 | scene_id = os.path.split(filename)[1][4:-3] 155 | 156 | decoded_scene_count += 1 157 | 158 | ############################################ 159 | # retrieve data 160 | beam_data = [] 161 | beam_pc_rearrangements = [] 162 | for b in range(beam_size): 163 | datum = test_dataset.get_raw_data(idx, inference_mode=True, shuffle_object_index=False) 164 | 165 | # not necessary, but just to ensure no test leakage 166 | datum["struct_x_inputs"] = [0] 167 | datum["struct_y_inputs"] = [0] 168 | datum["struct_y_inputs"] = [0] 169 | datum["struct_theta_inputs"] = [[0] * 9] 170 | for obj_idx in range(len(datum["obj_x_inputs"])): 171 | datum["obj_x_inputs"][obj_idx] = 0 172 | datum["obj_y_inputs"][obj_idx] = 0 173 | datum["obj_z_inputs"][obj_idx] = 0 174 | datum["obj_theta_inputs"][obj_idx] = [0] * 9 175 | 176 | # We can play with different language here 177 | # datum["sentence"] = modify_language(datum["sentence"], radius=0.5) 178 | # datum["sentence"] = modify_language(datum["sentence"], position_x=1) 179 | # datum["sentence"] = modify_language(datum["sentence"], position_y=0.5) 180 | 181 | beam_data.append(datum) 182 | beam_pc_rearrangements.append(PointCloudRearrangement(datum)) 183 | 184 | if visualize: 185 | datum = beam_data[0] 186 | print("#"*50) 187 | print("sentence", datum["sentence"]) 188 | show_pcs(datum["xyzs"] + datum["other_xyzs"], datum["rgbs"] + datum["other_rgbs"], 189 | add_coordinate_frame=False, side_view=True, add_table=True) 190 | 191 | ############################################ 192 | # autoregressive decoding 193 | num_target_objects = beam_pc_rearrangements[0].num_target_objects 194 | # first predict structure pose 195 | beam_goal_struct_pose, target_object_preds = prior_inference.limited_batch_inference(beam_data) 196 | for b in range(beam_size): 197 | datum = beam_data[b] 198 | datum["struct_x_inputs"] = [beam_goal_struct_pose[b][0]] 199 | datum["struct_y_inputs"] = [beam_goal_struct_pose[b][1]] 200 | datum["struct_z_inputs"] = [beam_goal_struct_pose[b][2]] 201 | datum["struct_theta_inputs"] = [beam_goal_struct_pose[b][3:]] 202 | 203 | # then iteratively predict pose of each object 204 | beam_goal_obj_poses = [] 205 | for obj_idx in range(num_target_objects): 206 | struct_preds, target_object_preds = prior_inference.limited_batch_inference(beam_data) 207 | beam_goal_obj_poses.append(target_object_preds[:, obj_idx]) 208 | for b in range(beam_size): 209 | datum = beam_data[b] 210 | datum["obj_x_inputs"][obj_idx] = target_object_preds[b][obj_idx][0] 211 | datum["obj_y_inputs"][obj_idx] = target_object_preds[b][obj_idx][1] 212 | datum["obj_z_inputs"][obj_idx] = target_object_preds[b][obj_idx][2] 213 | datum["obj_theta_inputs"][obj_idx] = target_object_preds[b][obj_idx][3:] 214 | # concat in the object dim 215 | beam_goal_obj_poses = np.stack(beam_goal_obj_poses, axis=0) 216 | # swap axis 217 | beam_goal_obj_poses = np.swapaxes(beam_goal_obj_poses, 1, 0) # batch size, number of target objects, pose dim 218 | 219 | ############################################ 220 | # move pc 221 | for bi in range(beam_size): 222 | beam_pc_rearrangements[bi].set_goal_poses(beam_goal_struct_pose[bi], beam_goal_obj_poses[bi]) 223 | beam_pc_rearrangements[bi].rearrange() 224 | 225 | ############################################ 226 | if 
visualize: 227 | for pc_rearrangement in beam_pc_rearrangements: 228 | pc_rearrangement.visualize("goal", add_other_objects=True, 229 | add_coordinate_frame=False, side_view=True, add_table=True) 230 | 231 | if inference_visualization_dir: 232 | for pc_rearrangement in beam_pc_rearrangements: 233 | pc_rearrangement.visualize("goal", add_other_objects=True, 234 | add_coordinate_frame=False, side_view=True, add_table=True, 235 | save_vis=True, 236 | save_filename=os.path.join(inference_visualization_dir, "{}.jpg".format(scene_id))) 237 | 238 | pbar.update(1) 239 | 240 | 241 | if __name__ == "__main__": 242 | parser = argparse.ArgumentParser(description="Run a simple model") 243 | parser.add_argument("--dataset_base_dir", help='location of the dataset', type=str) 244 | parser.add_argument("--model_dir", help='location for the saved model', type=str) 245 | parser.add_argument("--dirs_config", help='config yaml file for directories', default="", type=str) 246 | args = parser.parse_args() 247 | 248 | os.environ["DATETIME"] = time.strftime("%Y%m%d-%H%M%S") 249 | 250 | # # debug only 251 | # args.dataset_base_dir = "/home/weiyu/data_drive/data_new_objects_test_split" 252 | # args.model_dir = "/home/weiyu/Research/intern/StructFormer/models/structformer_line/best_model" 253 | # args.dirs_config = "/home/weiyu/Research/intern/StructFormer/structformer/configs/data/line_dirs.yaml" 254 | 255 | if args.dirs_config: 256 | assert os.path.exists(args.dirs_config), "Cannot find config yaml file at {}".format(args.dirs_config) 257 | dirs_cfg = OmegaConf.load(args.dirs_config) 258 | dirs_cfg.dataset_base_dir = args.dataset_base_dir 259 | OmegaConf.resolve(dirs_cfg) 260 | else: 261 | dirs_cfg = None 262 | 263 | inference_beam_decoding(args.model_dir, dirs_cfg, beam_size=3, max_scene_decodes=30000, 264 | visualize=True, visualize_action_sequence=False, 265 | inference_visualization_dir=None) 266 | -------------------------------------------------------------------------------- /src/structformer/evaluation/test_structformer_no_encoder.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | import os 4 | import copy 5 | import tqdm 6 | import argparse 7 | from omegaconf import OmegaConf 8 | import time 9 | from torch.utils.data import DataLoader 10 | 11 | from structformer.data.tokenizer import Tokenizer 12 | import structformer.data.sequence_dataset as prior_dataset 13 | import structformer.training.train_structformer_no_encoder as prior_model 14 | from structformer.utils.rearrangement import show_pcs 15 | from structformer.evaluation.inference import PointCloudRearrangement 16 | 17 | 18 | def test_model(model_dir, dirs_cfg): 19 | prior_inference = PriorInference(model_dir, dirs_cfg, data_split="test") 20 | prior_inference.validate() 21 | 22 | 23 | class PriorInference: 24 | 25 | def __init__(self, model_dir, dirs_cfg, data_split="test"): 26 | 27 | cfg, tokenizer, model, optimizer, scheduler, epoch = prior_model.load_model(model_dir, dirs_cfg) 28 | 29 | data_cfg = cfg.dataset 30 | 31 | dataset = prior_dataset.SequenceDataset(data_cfg.dirs, data_cfg.index_dirs, data_split, tokenizer, 32 | data_cfg.max_num_objects, 33 | data_cfg.max_num_other_objects, 34 | data_cfg.max_num_shape_parameters, 35 | data_cfg.max_num_rearrange_features, 36 | data_cfg.max_num_anchor_features, 37 | data_cfg.num_pts, 38 | data_cfg.use_structure_frame) 39 | 40 | self.cfg = cfg 41 | self.tokenizer = tokenizer 42 | self.model = model 43 | self.cfg = cfg 44 | 
self.dataset = dataset 45 | self.epoch = epoch 46 | 47 | def validate(self): 48 | """ 49 | validate the pretrained model on the dataset 50 | 51 | :return: 52 | """ 53 | data_cfg = self.cfg.dataset 54 | data_iter = DataLoader(self.dataset, batch_size=data_cfg.batch_size, shuffle=False, 55 | collate_fn=prior_dataset.SequenceDataset.collate_fn, 56 | pin_memory=data_cfg.pin_memory, num_workers=data_cfg.num_workers) 57 | 58 | prior_model.validate(self.cfg, self.model, data_iter, self.epoch, self.cfg.device) 59 | 60 | def limited_batch_inference(self, data, verbose=False): 61 | """ 62 | This function makes the assumption that scenes in the batch have the same number of objects that need to be 63 | rearranged 64 | 65 | :param data: 66 | :param model: 67 | :param test_dataset: 68 | :param tokenizer: 69 | :param cfg: 70 | :param num_samples: 71 | :param verbose: 72 | :return: 73 | """ 74 | 75 | data_size = len(data) 76 | batch_size = self.cfg.dataset.batch_size 77 | if verbose: 78 | print("data size:", data_size) 79 | print("batch size:", batch_size) 80 | 81 | num_batches = int(data_size / batch_size) 82 | if data_size % batch_size != 0: 83 | num_batches += 1 84 | 85 | all_obj_preds = [] 86 | all_struct_preds = [] 87 | for b in range(num_batches): 88 | if b + 1 == num_batches: 89 | # last batch 90 | batch = data[b * batch_size:] 91 | else: 92 | batch = data[b * batch_size: (b+1) * batch_size] 93 | data_tensors = [self.dataset.convert_to_tensors(d, self.tokenizer) for d in batch] 94 | data_tensors = self.dataset.collate_fn(data_tensors) 95 | predictions = prior_model.infer_once(self.cfg, self.model, data_tensors, self.cfg.device) 96 | 97 | obj_x_preds = torch.cat(predictions["obj_x_outputs"], dim=0) 98 | obj_y_preds = torch.cat(predictions["obj_y_outputs"], dim=0) 99 | obj_z_preds = torch.cat(predictions["obj_z_outputs"], dim=0) 100 | obj_theta_preds = torch.cat(predictions["obj_theta_outputs"], dim=0) 101 | obj_preds = torch.cat([obj_x_preds, obj_y_preds, obj_z_preds, obj_theta_preds], dim=1) # batch_size * max num objects, output_dim 102 | 103 | struct_x_preds = torch.cat(predictions["struct_x_inputs"], dim=0) 104 | struct_y_preds = torch.cat(predictions["struct_y_inputs"], dim=0) 105 | struct_z_preds = torch.cat(predictions["struct_z_inputs"], dim=0) 106 | struct_theta_preds = torch.cat(predictions["struct_theta_inputs"], dim=0) 107 | struct_preds = torch.cat([struct_x_preds, struct_y_preds, struct_z_preds, struct_theta_preds], dim=1) # batch_size, output_dim 108 | 109 | all_obj_preds.append(obj_preds) 110 | all_struct_preds.append(struct_preds) 111 | 112 | obj_preds = torch.cat(all_obj_preds, dim=0) # data_size * max num objects, output_dim 113 | struct_preds = torch.cat(all_struct_preds, dim=0) # data_size, output_dim 114 | 115 | obj_preds = obj_preds.detach().cpu().numpy() 116 | struct_preds = struct_preds.detach().cpu().numpy() 117 | 118 | obj_preds = obj_preds.reshape(data_size, -1, obj_preds.shape[-1]) # batch_size, max num objects, output_dim 119 | 120 | return struct_preds, obj_preds 121 | 122 | 123 | def inference_beam_decoding(model_dir, dirs_cfg, beam_size=100, max_scene_decodes=30000, 124 | visualize=True, visualize_action_sequence=False, 125 | inference_visualization_dir=None): 126 | """ 127 | 128 | :param model_dir: 129 | :param beam_size: 130 | :param max_scene_decodes: 131 | :param visualize: 132 | :param visualize_action_sequence: 133 | :param inference_visualization_dir: 134 | :param side_view: 135 | :return: 136 | """ 137 | 138 | if inference_visualization_dir and not 
os.path.exists(inference_visualization_dir): 139 | os.makedirs(inference_visualization_dir) 140 | 141 | prior_inference = PriorInference(model_dir, dirs_cfg) 142 | test_dataset = prior_inference.dataset 143 | 144 | decoded_scene_count = 0 145 | with tqdm.tqdm(total=len(test_dataset)) as pbar: 146 | # for idx in np.random.choice(range(len(test_dataset)), len(test_dataset), replace=False): 147 | for idx in range(len(test_dataset)): 148 | 149 | if decoded_scene_count == max_scene_decodes: 150 | break 151 | 152 | filename = test_dataset.get_data_index(idx) 153 | scene_id = os.path.split(filename)[1][4:-3] 154 | 155 | decoded_scene_count += 1 156 | 157 | ############################################ 158 | # retrieve data 159 | beam_data = [] 160 | beam_pc_rearrangements = [] 161 | for b in range(beam_size): 162 | datum = test_dataset.get_raw_data(idx, inference_mode=True, shuffle_object_index=False) 163 | 164 | # not necessary, but just to ensure no test leakage 165 | datum["struct_x_inputs"] = [0] 166 | datum["struct_y_inputs"] = [0] 167 | datum["struct_y_inputs"] = [0] 168 | datum["struct_theta_inputs"] = [[0] * 9] 169 | for obj_idx in range(len(datum["obj_x_inputs"])): 170 | datum["obj_x_inputs"][obj_idx] = 0 171 | datum["obj_y_inputs"][obj_idx] = 0 172 | datum["obj_z_inputs"][obj_idx] = 0 173 | datum["obj_theta_inputs"][obj_idx] = [0] * 9 174 | 175 | # We can play with different language here 176 | # datum["sentence"] = modify_language(datum["sentence"], radius=0.5) 177 | # datum["sentence"] = modify_language(datum["sentence"], position_x=1) 178 | # datum["sentence"] = modify_language(datum["sentence"], position_y=0.5) 179 | 180 | beam_data.append(datum) 181 | beam_pc_rearrangements.append(PointCloudRearrangement(datum)) 182 | 183 | if visualize: 184 | datum = beam_data[0] 185 | print("#"*50) 186 | print("sentence", datum["sentence"]) 187 | show_pcs(datum["xyzs"] + datum["other_xyzs"], datum["rgbs"] + datum["other_rgbs"], 188 | add_coordinate_frame=False, side_view=True, add_table=True) 189 | 190 | ############################################ 191 | # autoregressive decoding 192 | num_target_objects = beam_pc_rearrangements[0].num_target_objects 193 | # first predict structure pose 194 | beam_goal_struct_pose, target_object_preds = prior_inference.limited_batch_inference(beam_data) 195 | for b in range(beam_size): 196 | datum = beam_data[b] 197 | datum["struct_x_inputs"] = [beam_goal_struct_pose[b][0]] 198 | datum["struct_y_inputs"] = [beam_goal_struct_pose[b][1]] 199 | datum["struct_z_inputs"] = [beam_goal_struct_pose[b][2]] 200 | datum["struct_theta_inputs"] = [beam_goal_struct_pose[b][3:]] 201 | 202 | # then iteratively predict pose of each object 203 | beam_goal_obj_poses = [] 204 | for obj_idx in range(num_target_objects): 205 | struct_preds, target_object_preds = prior_inference.limited_batch_inference(beam_data) 206 | beam_goal_obj_poses.append(target_object_preds[:, obj_idx]) 207 | for b in range(beam_size): 208 | datum = beam_data[b] 209 | datum["obj_x_inputs"][obj_idx] = target_object_preds[b][obj_idx][0] 210 | datum["obj_y_inputs"][obj_idx] = target_object_preds[b][obj_idx][1] 211 | datum["obj_z_inputs"][obj_idx] = target_object_preds[b][obj_idx][2] 212 | datum["obj_theta_inputs"][obj_idx] = target_object_preds[b][obj_idx][3:] 213 | # concat in the object dim 214 | beam_goal_obj_poses = np.stack(beam_goal_obj_poses, axis=0) 215 | # swap axis 216 | beam_goal_obj_poses = np.swapaxes(beam_goal_obj_poses, 1, 0) # batch size, number of target objects, pose dim 217 | 218 | 
############################################ 219 | # move pc 220 | for bi in range(beam_size): 221 | beam_pc_rearrangements[bi].set_goal_poses(beam_goal_struct_pose[bi], beam_goal_obj_poses[bi]) 222 | beam_pc_rearrangements[bi].rearrange() 223 | 224 | ############################################ 225 | if visualize: 226 | for pc_rearrangement in beam_pc_rearrangements: 227 | pc_rearrangement.visualize("goal", add_other_objects=True, 228 | add_coordinate_frame=False, side_view=True, add_table=True) 229 | 230 | if inference_visualization_dir: 231 | for pc_rearrangement in beam_pc_rearrangements: 232 | pc_rearrangement.visualize("goal", add_other_objects=True, 233 | add_coordinate_frame=False, side_view=True, add_table=True, 234 | save_vis=True, 235 | save_filename=os.path.join(inference_visualization_dir, "{}.jpg".format(scene_id))) 236 | 237 | pbar.update(1) 238 | 239 | 240 | if __name__ == "__main__": 241 | parser = argparse.ArgumentParser(description="Run a simple model") 242 | parser.add_argument("--dataset_base_dir", help='location of the dataset', type=str) 243 | parser.add_argument("--model_dir", help='location for the saved model', type=str) 244 | parser.add_argument("--dirs_config", help='config yaml file for directories', default="", type=str) 245 | args = parser.parse_args() 246 | 247 | os.environ["DATETIME"] = time.strftime("%Y%m%d-%H%M%S") 248 | 249 | # # debug only 250 | # args.dataset_base_dir = "/home/weiyu/data_drive/data_new_objects_test_split" 251 | # args.model_dir = "/home/weiyu/Research/intern/StructFormer/models/structformer_no_encoder_tower/best_model" 252 | # args.dirs_config = "/home/weiyu/Research/intern/StructFormer/structformer/configs/data/tower_dirs.yaml" 253 | 254 | if args.dirs_config: 255 | assert os.path.exists(args.dirs_config), "Cannot find config yaml file at {}".format(args.dirs_config) 256 | dirs_cfg = OmegaConf.load(args.dirs_config) 257 | dirs_cfg.dataset_base_dir = args.dataset_base_dir 258 | OmegaConf.resolve(dirs_cfg) 259 | else: 260 | dirs_cfg = None 261 | 262 | inference_beam_decoding(args.model_dir, dirs_cfg, beam_size=3, max_scene_decodes=30000, 263 | visualize=True, visualize_action_sequence=False, 264 | inference_visualization_dir=None) 265 | 266 | -------------------------------------------------------------------------------- /src/structformer/evaluation/test_structformer_no_structure.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | import os 4 | import copy 5 | import tqdm 6 | import argparse 7 | from omegaconf import OmegaConf 8 | import time 9 | from torch.utils.data import DataLoader 10 | 11 | from structformer.data.tokenizer import Tokenizer 12 | import structformer.data.sequence_dataset as prior_dataset 13 | import structformer.training.train_structformer_no_structure as prior_model 14 | from structformer.utils.rearrangement import show_pcs 15 | from structformer.evaluation.inference import PointCloudRearrangement 16 | 17 | 18 | def test_model(model_dir, dirs_cfg): 19 | prior_inference = PriorInference(model_dir, dirs_cfg, data_split="test") 20 | prior_inference.validate() 21 | 22 | 23 | class PriorInference: 24 | 25 | def __init__(self, model_dir, dirs_cfg, data_split="test"): 26 | 27 | cfg, tokenizer, model, optimizer, scheduler, epoch = prior_model.load_model(model_dir, dirs_cfg) 28 | 29 | data_cfg = cfg.dataset 30 | 31 | dataset = prior_dataset.SequenceDataset(data_cfg.dirs, data_cfg.index_dirs, data_split, tokenizer, 32 | data_cfg.max_num_objects, 33 
| data_cfg.max_num_other_objects, 34 | data_cfg.max_num_shape_parameters, 35 | data_cfg.max_num_rearrange_features, 36 | data_cfg.max_num_anchor_features, 37 | data_cfg.num_pts, 38 | data_cfg.use_structure_frame) 39 | 40 | self.cfg = cfg 41 | self.tokenizer = tokenizer 42 | self.model = model 43 | self.cfg = cfg 44 | self.dataset = dataset 45 | self.epoch = epoch 46 | 47 | def validate(self): 48 | """ 49 | validate the pretrained model on the dataset 50 | 51 | :return: 52 | """ 53 | data_cfg = self.cfg.dataset 54 | data_iter = DataLoader(self.dataset, batch_size=data_cfg.batch_size, shuffle=False, 55 | collate_fn=prior_dataset.SequenceDataset.collate_fn, 56 | pin_memory=data_cfg.pin_memory, num_workers=data_cfg.num_workers) 57 | 58 | prior_model.validate(self.cfg, self.model, data_iter, self.epoch, self.cfg.device) 59 | 60 | def limited_batch_inference(self, data, verbose=False): 61 | """ 62 | This function makes the assumption that scenes in the batch have the same number of objects that need to be 63 | rearranged 64 | 65 | :param data: 66 | :param model: 67 | :param test_dataset: 68 | :param tokenizer: 69 | :param cfg: 70 | :param num_samples: 71 | :param verbose: 72 | :return: 73 | """ 74 | 75 | data_size = len(data) 76 | batch_size = self.cfg.dataset.batch_size 77 | if verbose: 78 | print("data size:", data_size) 79 | print("batch size:", batch_size) 80 | 81 | num_batches = int(data_size / batch_size) 82 | if data_size % batch_size != 0: 83 | num_batches += 1 84 | 85 | all_obj_preds = [] 86 | for b in range(num_batches): 87 | if b + 1 == num_batches: 88 | # last batch 89 | batch = data[b * batch_size:] 90 | else: 91 | batch = data[b * batch_size: (b+1) * batch_size] 92 | data_tensors = [self.dataset.convert_to_tensors(d, self.tokenizer) for d in batch] 93 | data_tensors = self.dataset.collate_fn(data_tensors) 94 | predictions = prior_model.infer_once(self.cfg, self.model, data_tensors, self.cfg.device) 95 | 96 | obj_x_preds = torch.cat(predictions["obj_x_outputs"], dim=0) 97 | obj_y_preds = torch.cat(predictions["obj_y_outputs"], dim=0) 98 | obj_z_preds = torch.cat(predictions["obj_z_outputs"], dim=0) 99 | obj_theta_preds = torch.cat(predictions["obj_theta_outputs"], dim=0) 100 | obj_preds = torch.cat([obj_x_preds, obj_y_preds, obj_z_preds, obj_theta_preds], dim=1) # batch_size * max num objects, output_dim 101 | 102 | all_obj_preds.append(obj_preds) 103 | 104 | obj_preds = torch.cat(all_obj_preds, dim=0) # data_size * max num objects, output_dim 105 | 106 | obj_preds = obj_preds.detach().cpu().numpy() 107 | 108 | obj_preds = obj_preds.reshape(data_size, -1, obj_preds.shape[-1]) # batch_size, max num objects, output_dim 109 | 110 | return obj_preds 111 | 112 | 113 | def inference_beam_decoding(model_dir, dirs_cfg, beam_size=100, max_scene_decodes=30000, 114 | visualize=True, visualize_action_sequence=False, 115 | inference_visualization_dir=None): 116 | """ 117 | 118 | :param model_dir: 119 | :param beam_size: 120 | :param max_scene_decodes: 121 | :param visualize: 122 | :param visualize_action_sequence: 123 | :param inference_visualization_dir: 124 | :param side_view: 125 | :return: 126 | """ 127 | 128 | if inference_visualization_dir and not os.path.exists(inference_visualization_dir): 129 | os.makedirs(inference_visualization_dir) 130 | 131 | prior_inference = PriorInference(model_dir, dirs_cfg) 132 | test_dataset = prior_inference.dataset 133 | 134 | decoded_scene_count = 0 135 | with tqdm.tqdm(total=len(test_dataset)) as pbar: 136 | # for idx in 
np.random.choice(range(len(test_dataset)), len(test_dataset), replace=False): 137 | for idx in range(len(test_dataset)): 138 | 139 | if decoded_scene_count == max_scene_decodes: 140 | break 141 | 142 | filename = test_dataset.get_data_index(idx) 143 | scene_id = os.path.split(filename)[1][4:-3] 144 | 145 | decoded_scene_count += 1 146 | 147 | ############################################ 148 | # retrieve data 149 | beam_data = [] 150 | beam_pc_rearrangements = [] 151 | for b in range(beam_size): 152 | datum = test_dataset.get_raw_data(idx, inference_mode=True, shuffle_object_index=False) 153 | 154 | # not necessary, but just to ensure no test leakage 155 | for obj_idx in range(len(datum["obj_x_inputs"])): 156 | datum["obj_x_inputs"][obj_idx] = 0 157 | datum["obj_y_inputs"][obj_idx] = 0 158 | datum["obj_z_inputs"][obj_idx] = 0 159 | datum["obj_theta_inputs"][obj_idx] = [0] * 9 160 | 161 | # We can play with different language here 162 | # datum["sentence"] = modify_language(datum["sentence"], radius=0.5) 163 | # datum["sentence"] = modify_language(datum["sentence"], position_x=1) 164 | # datum["sentence"] = modify_language(datum["sentence"], position_y=0.5) 165 | 166 | beam_data.append(datum) 167 | beam_pc_rearrangements.append(PointCloudRearrangement(datum, use_structure_frame=False)) 168 | 169 | if visualize: 170 | datum = beam_data[0] 171 | print("#"*50) 172 | print("sentence", datum["sentence"]) 173 | show_pcs(datum["xyzs"] + datum["other_xyzs"], datum["rgbs"] + datum["other_rgbs"], 174 | add_coordinate_frame=False, side_view=True, add_table=True) 175 | 176 | ############################################ 177 | # autoregressive decoding 178 | num_target_objects = beam_pc_rearrangements[0].num_target_objects 179 | # iteratively predict pose of each object 180 | beam_goal_obj_poses = [] 181 | for obj_idx in range(num_target_objects): 182 | target_object_preds = prior_inference.limited_batch_inference(beam_data) 183 | beam_goal_obj_poses.append(target_object_preds[:, obj_idx]) 184 | for b in range(beam_size): 185 | datum = beam_data[b] 186 | datum["obj_x_inputs"][obj_idx] = target_object_preds[b][obj_idx][0] 187 | datum["obj_y_inputs"][obj_idx] = target_object_preds[b][obj_idx][1] 188 | datum["obj_z_inputs"][obj_idx] = target_object_preds[b][obj_idx][2] 189 | datum["obj_theta_inputs"][obj_idx] = target_object_preds[b][obj_idx][3:] 190 | # concat in the object dim 191 | beam_goal_obj_poses = np.stack(beam_goal_obj_poses, axis=0) 192 | # swap axis 193 | beam_goal_obj_poses = np.swapaxes(beam_goal_obj_poses, 1, 0) # batch size, number of target objects, pose dim 194 | 195 | ############################################ 196 | # move pc 197 | for bi in range(beam_size): 198 | beam_pc_rearrangements[bi].set_goal_poses(None, beam_goal_obj_poses[bi]) 199 | beam_pc_rearrangements[bi].rearrange() 200 | 201 | ############################################ 202 | if visualize: 203 | for pc_rearrangement in beam_pc_rearrangements: 204 | pc_rearrangement.visualize("goal", add_other_objects=True, 205 | add_coordinate_frame=False, side_view=True, add_table=True) 206 | 207 | if inference_visualization_dir: 208 | for pc_rearrangement in beam_pc_rearrangements: 209 | pc_rearrangement.visualize("goal", add_other_objects=True, 210 | add_coordinate_frame=False, side_view=True, add_table=True, 211 | save_vis=True, 212 | save_filename=os.path.join(inference_visualization_dir, "{}.jpg".format(scene_id))) 213 | 214 | pbar.update(1) 215 | 216 | 217 | if __name__ == "__main__": 218 | parser = 
argparse.ArgumentParser(description="Run a simple model") 219 | parser.add_argument("--dataset_base_dir", help='location of the dataset', type=str) 220 | parser.add_argument("--model_dir", help='location for the saved model', type=str) 221 | parser.add_argument("--dirs_config", help='config yaml file for directories', default="", type=str) 222 | args = parser.parse_args() 223 | 224 | os.environ["DATETIME"] = time.strftime("%Y%m%d-%H%M%S") 225 | 226 | # # debug only 227 | # args.dataset_base_dir = "/home/weiyu/data_drive/data_new_objects_test_split" 228 | # args.model_dir = "/home/weiyu/Research/intern/StructFormer/models/structformer_no_structure_dinner/best_model" 229 | # args.dirs_config = "/home/weiyu/Research/intern/StructFormer/structformer/configs/data/dinner_dirs.yaml" 230 | 231 | if args.dirs_config: 232 | assert os.path.exists(args.dirs_config), "Cannot find config yaml file at {}".format(args.dirs_config) 233 | dirs_cfg = OmegaConf.load(args.dirs_config) 234 | dirs_cfg.dataset_base_dir = args.dataset_base_dir 235 | OmegaConf.resolve(dirs_cfg) 236 | else: 237 | dirs_cfg = None 238 | 239 | inference_beam_decoding(args.model_dir, dirs_cfg, beam_size=3, max_scene_decodes=30000, 240 | visualize=True, visualize_action_sequence=False, 241 | inference_visualization_dir=None) 242 | 243 | 244 | -------------------------------------------------------------------------------- /src/structformer/models/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wliu88/StructFormer/a81d8c83313f1f24a75ca83aa9dd9f8cfcf27419/src/structformer/models/__init__.py -------------------------------------------------------------------------------- /src/structformer/models/object_selection_network.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | from torch.nn import TransformerEncoder, TransformerEncoderLayer 7 | 8 | from structformer.models.point_transformer import PointTransformerEncoderSmall 9 | 10 | 11 | class FocalLoss(nn.Module): 12 | "Focal Loss" 13 | 14 | def __init__(self, gamma=2, alpha=.25): 15 | super(FocalLoss, self).__init__() 16 | # self.alpha = torch.tensor([alpha, 1-alpha]) 17 | self.gamma = gamma 18 | 19 | def forward(self, inputs, targets): 20 | BCE_loss = F.binary_cross_entropy_with_logits(inputs, targets, reduction='none') 21 | pt = torch.exp(-BCE_loss) 22 | # targets = targets.type(torch.long) 23 | # at = self.alpha.gather(0, targets.data.view(-1)) 24 | # F_loss = at*(1-pt)**self.gamma * BCE_loss 25 | F_loss = (1 - pt)**self.gamma * BCE_loss 26 | return F_loss.mean() 27 | 28 | 29 | class EncoderMLP(torch.nn.Module): 30 | def __init__(self, in_dim, out_dim, pt_dim=3, uses_pt=True): 31 | super(EncoderMLP, self).__init__() 32 | self.uses_pt = uses_pt 33 | self.output = out_dim 34 | d5 = int(in_dim) 35 | d6 = int(2 * self.output) 36 | d7 = self.output 37 | self.encode_position = nn.Sequential( 38 | nn.Linear(pt_dim, in_dim), 39 | nn.LayerNorm(in_dim), 40 | nn.ReLU(), 41 | nn.Linear(in_dim, in_dim), 42 | nn.LayerNorm(in_dim), 43 | nn.ReLU(), 44 | ) 45 | d5 = 2 * in_dim if self.uses_pt else in_dim 46 | self.fc_block = nn.Sequential( 47 | nn.Linear(int(d5), d6), 48 | nn.LayerNorm(int(d6)), 49 | nn.ReLU(), 50 | nn.Linear(int(d6), d6), 51 | nn.LayerNorm(int(d6)), 52 | nn.ReLU(), 53 | nn.Linear(d6, d7)) 54 | 55 | def forward(self, x, pt=None): 56 | if self.uses_pt: 57 | if pt is 
None: raise RuntimeError('did not provide pt') 58 | y = self.encode_position(pt) 59 | x = torch.cat([x, y], dim=-1) 60 | return self.fc_block(x) 61 | 62 | 63 | class RearrangeObjectsPredictorPCT(torch.nn.Module): 64 | 65 | def __init__(self, vocab_size, 66 | num_attention_heads=8, encoder_hidden_dim=16, encoder_dropout=0.1, encoder_activation="relu", encoder_num_layers=8, 67 | use_focal_loss=False, focal_loss_gamma=2): 68 | super(RearrangeObjectsPredictorPCT, self).__init__() 69 | 70 | print("Object Selection Network with Point Transformer") 71 | 72 | # object encode will have dim 256 73 | self.object_encoder = PointTransformerEncoderSmall(output_dim=256, input_dim=6, mean_center=True) 74 | 75 | # 256 = 240 (point cloud) + 8 (position idx) + 8 (token type) 76 | self.mlp = EncoderMLP(256, 240, uses_pt=False) 77 | 78 | self.word_embeddings = torch.nn.Embedding(vocab_size, 240, padding_idx=0) 79 | self.token_type_embeddings = torch.nn.Embedding(2, 8) 80 | self.position_embeddings = torch.nn.Embedding(11, 8) 81 | 82 | encoder_layers = TransformerEncoderLayer(256, num_attention_heads, 83 | encoder_hidden_dim, encoder_dropout, encoder_activation) 84 | self.encoder = TransformerEncoder(encoder_layers, encoder_num_layers) 85 | 86 | self.rearrange_object_fier = nn.Sequential(nn.Linear(256, 256), 87 | nn.LayerNorm(256), 88 | nn.ReLU(), 89 | nn.Linear(256, 128), 90 | nn.LayerNorm(128), 91 | nn.ReLU(), 92 | nn.Linear(128, 1)) 93 | 94 | ########################### 95 | if use_focal_loss: 96 | print("use focal loss") 97 | self.loss = FocalLoss(gamma=focal_loss_gamma) 98 | else: 99 | print("use standard BCE logit loss") 100 | self.loss = torch.nn.BCEWithLogitsLoss(reduction="mean") 101 | 102 | def forward(self, xyzs, rgbs, object_pad_mask, sentence, sentence_pad_mask, token_type_index, position_index): 103 | 104 | batch_size = object_pad_mask.shape[0] 105 | num_objects = object_pad_mask.shape[1] 106 | 107 | ######################### 108 | center_xyz, x = self.object_encoder(xyzs, rgbs) 109 | x = self.mlp(x, center_xyz) 110 | x = x.reshape(batch_size, num_objects, -1) 111 | 112 | ######################### 113 | sentence = self.word_embeddings(sentence) 114 | 115 | ######################### 116 | position_embed = self.position_embeddings(position_index) 117 | token_type_embed = self.token_type_embeddings(token_type_index) 118 | pad_mask = torch.cat([sentence_pad_mask, object_pad_mask], dim=1) 119 | 120 | sequence_encode = torch.cat([sentence, x], dim=1) 121 | sequence_encode = torch.cat([sequence_encode, position_embed, token_type_embed], dim=-1) 122 | ######################### 123 | # sequence_encode: [batch size, sequence_length, encoder input dimension] 124 | # input to transformer needs to have dimenion [sequence_length, batch size, encoder input dimension] 125 | sequence_encode = sequence_encode.transpose(1, 0) 126 | 127 | # convert to bool 128 | pad_mask = (pad_mask == 1) 129 | 130 | # encode: [sequence_length, batch_size, embedding_size] 131 | encode = self.encoder(sequence_encode, src_key_padding_mask=pad_mask) 132 | encode = encode.transpose(1, 0) 133 | ######################### 134 | obj_encodes = encode[:, -num_objects:, :] 135 | obj_encodes = obj_encodes.reshape(-1, obj_encodes.shape[-1]) 136 | 137 | rearrange_obj_labels = self.rearrange_object_fier(obj_encodes).squeeze(dim=1) # batch_size * num_objects 138 | 139 | predictions = {"rearrange_obj_labels": rearrange_obj_labels} 140 | 141 | return predictions 142 | 143 | def criterion(self, predictions, labels): 144 | 145 | loss = 0 146 | for key 
in predictions: 147 | 148 | preds = predictions[key] 149 | gts = labels[key] 150 | 151 | mask = gts == -100 152 | preds = preds[~mask] 153 | gts = gts[~mask] 154 | 155 | loss += self.loss(preds, gts) 156 | 157 | return loss 158 | 159 | def convert_logits(self, predictions): 160 | 161 | for key in predictions: 162 | if key == "rearrange_obj_labels": 163 | predictions[key] = torch.sigmoid(predictions[key]) 164 | 165 | return predictions -------------------------------------------------------------------------------- /src/structformer/models/point_transformer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from structformer.utils.pointnet import farthest_point_sample, index_points, square_distance 4 | 5 | # adapted from https://github.com/qq456cvb/Point-Transformers 6 | 7 | 8 | def sample_and_group(npoint, nsample, xyz, points): 9 | B, N, C = xyz.shape 10 | S = npoint 11 | 12 | fps_idx = farthest_point_sample(xyz, npoint) # [B, npoint] 13 | 14 | new_xyz = index_points(xyz, fps_idx) 15 | new_points = index_points(points, fps_idx) 16 | 17 | dists = square_distance(new_xyz, xyz) # B x npoint x N 18 | idx = dists.argsort()[:, :, :nsample] # B x npoint x K 19 | 20 | grouped_points = index_points(points, idx) 21 | grouped_points_norm = grouped_points - new_points.view(B, S, 1, -1) 22 | new_points = torch.cat([grouped_points_norm, new_points.view(B, S, 1, -1).repeat(1, 1, nsample, 1)], dim=-1) 23 | return new_xyz, new_points 24 | 25 | 26 | class Local_op(nn.Module): 27 | def __init__(self, in_channels, out_channels): 28 | super().__init__() 29 | self.conv1 = nn.Conv1d(in_channels, out_channels, kernel_size=1, bias=False) 30 | self.bn1 = nn.BatchNorm1d(out_channels) 31 | self.relu = nn.ReLU() 32 | 33 | def forward(self, x): 34 | b, n, s, d = x.size() # torch.Size([32, 512, 32, 6]) 35 | x = x.permute(0, 1, 3, 2) 36 | x = x.reshape(-1, d, s) 37 | batch_size, _, N = x.size() 38 | x = self.relu(self.bn1(self.conv1(x))) # B, D, N 39 | x = torch.max(x, 2)[0] 40 | x = x.view(batch_size, -1) 41 | x = x.reshape(b, n, -1).permute(0, 2, 1) 42 | return x 43 | 44 | 45 | class SA_Layer(nn.Module): 46 | def __init__(self, channels): 47 | super().__init__() 48 | self.q_conv = nn.Conv1d(channels, channels // 4, 1, bias=False) 49 | self.k_conv = nn.Conv1d(channels, channels // 4, 1, bias=False) 50 | self.q_conv.weight = self.k_conv.weight 51 | self.v_conv = nn.Conv1d(channels, channels, 1) 52 | self.trans_conv = nn.Conv1d(channels, channels, 1) 53 | self.after_norm = nn.BatchNorm1d(channels) 54 | self.act = nn.ReLU() 55 | self.softmax = nn.Softmax(dim=-1) 56 | 57 | def forward(self, x): 58 | x_q = self.q_conv(x).permute(0, 2, 1) # b, n, c 59 | x_k = self.k_conv(x)# b, c, n 60 | x_v = self.v_conv(x) 61 | energy = x_q @ x_k # b, n, n 62 | attention = self.softmax(energy) 63 | attention = attention / (1e-9 + attention.sum(dim=1, keepdims=True)) 64 | x_r = x_v @ attention # b, c, n 65 | x_r = self.act(self.after_norm(self.trans_conv(x - x_r))) 66 | x = x + x_r 67 | return x 68 | 69 | 70 | class StackedAttention(nn.Module): 71 | def __init__(self, channels=64): 72 | super().__init__() 73 | self.conv1 = nn.Conv1d(channels, channels, kernel_size=1, bias=False) 74 | self.conv2 = nn.Conv1d(channels, channels, kernel_size=1, bias=False) 75 | 76 | self.bn1 = nn.BatchNorm1d(channels) 77 | self.bn2 = nn.BatchNorm1d(channels) 78 | 79 | self.sa1 = SA_Layer(channels) 80 | self.sa2 = SA_Layer(channels) 81 | 82 | self.relu = nn.ReLU() 83 | 84 | def 
forward(self, x): 85 | # 86 | # b, 3, npoint, nsample 87 | # conv2d 3 -> 128 channels 1, 1 88 | # b * npoint, c, nsample 89 | # permute reshape 90 | batch_size, _, N = x.size() 91 | 92 | x = self.relu(self.bn1(self.conv1(x))) # B, D, N 93 | x = self.relu(self.bn2(self.conv2(x))) 94 | 95 | x1 = self.sa1(x) 96 | x2 = self.sa2(x1) 97 | 98 | x = torch.cat((x1, x2), dim=1) 99 | 100 | return x 101 | 102 | 103 | class PointTransformerEncoderSmall(nn.Module): 104 | 105 | def __init__(self, output_dim=256, input_dim=6, mean_center=True): 106 | super(PointTransformerEncoderSmall, self).__init__() 107 | 108 | self.mean_center = mean_center 109 | 110 | # map the second dim of the input from input_dim to 64 111 | self.conv1 = nn.Conv1d(input_dim, 64, kernel_size=1, bias=False) 112 | self.bn1 = nn.BatchNorm1d(64) 113 | self.gather_local_0 = Local_op(in_channels=128, out_channels=64) 114 | self.gather_local_1 = Local_op(in_channels=128, out_channels=64) 115 | self.pt_last = StackedAttention(channels=64) 116 | 117 | self.relu = nn.ReLU() 118 | self.conv_fuse = nn.Sequential(nn.Conv1d(192, 256, kernel_size=1, bias=False), 119 | nn.BatchNorm1d(256), 120 | nn.LeakyReLU(negative_slope=0.2)) 121 | 122 | self.linear1 = nn.Linear(256, 256, bias=False) 123 | self.bn6 = nn.BatchNorm1d(256) 124 | self.dp1 = nn.Dropout(p=0.5) 125 | self.linear2 = nn.Linear(256, 256) 126 | 127 | def forward(self, xyz, f=None): 128 | # xyz: B, N, 3 129 | # f: B, N, D 130 | center = torch.mean(xyz, dim=1) 131 | if self.mean_center: 132 | xyz = xyz - center.view(-1, 1, 3).repeat(1, xyz.shape[1], 1) 133 | if f is None: 134 | x = self.pct(xyz) 135 | else: 136 | x = self.pct(torch.cat([xyz, f], dim=2)) # B, output_dim 137 | 138 | return center, x 139 | 140 | def pct(self, x): 141 | 142 | # x: B, N, D 143 | xyz = x[..., :3] 144 | x = x.permute(0, 2, 1) 145 | batch_size, _, _ = x.size() 146 | x = self.relu(self.bn1(self.conv1(x))) # B, D, N 147 | x = x.permute(0, 2, 1) 148 | new_xyz, new_feature = sample_and_group(npoint=128, nsample=32, xyz=xyz, points=x) 149 | feature_0 = self.gather_local_0(new_feature) 150 | feature = feature_0.permute(0, 2, 1) # B, nsamples, D 151 | new_xyz, new_feature = sample_and_group(npoint=32, nsample=16, xyz=new_xyz, points=feature) 152 | feature_1 = self.gather_local_1(new_feature) # B, D, nsamples 153 | 154 | x = self.pt_last(feature_1) # B, D * 2, nsamples 155 | x = torch.cat([x, feature_1], dim=1) # B, D * 3, nsamples 156 | x = self.conv_fuse(x) 157 | x = torch.max(x, 2)[0] 158 | x = x.view(batch_size, -1) 159 | 160 | x = self.relu(self.bn6(self.linear1(x))) 161 | x = self.dp1(x) 162 | x = self.linear2(x) 163 | 164 | return x 165 | 166 | 167 | class SampleAndGroup(nn.Module): 168 | 169 | def __init__(self, output_dim=64, input_dim=6, mean_center=True, npoints=(128, 32), nsamples=(32, 16)): 170 | super(SampleAndGroup, self).__init__() 171 | 172 | self.mean_center = mean_center 173 | self.npoints = npoints 174 | self.nsamples = nsamples 175 | 176 | # map the second dim of the input from input_dim to 64 177 | self.conv1 = nn.Conv1d(input_dim, output_dim, kernel_size=1, bias=False) 178 | self.bn1 = nn.BatchNorm1d(output_dim) 179 | self.gather_local_0 = Local_op(in_channels=output_dim * 2, out_channels=output_dim) 180 | self.gather_local_1 = Local_op(in_channels=output_dim * 2, out_channels=output_dim) 181 | self.relu = nn.ReLU() 182 | 183 | def forward(self, xyz, f): 184 | # xyz: B, N, 3 185 | # f: B, N, D 186 | center = torch.mean(xyz, dim=1) 187 | if self.mean_center: 188 | xyz = xyz - center.view(-1, 1, 
3).repeat(1, xyz.shape[1], 1) 189 | x = self.sg(torch.cat([xyz, f], dim=2)) # B, nsamples, output_dim 190 | 191 | return center, x 192 | 193 | def sg(self, x): 194 | 195 | # x: B, N, D 196 | xyz = x[..., :3] 197 | x = x.permute(0, 2, 1) 198 | batch_size, _, _ = x.size() 199 | x = self.relu(self.bn1(self.conv1(x))) # B, D, N 200 | x = x.permute(0, 2, 1) 201 | new_xyz, new_feature = sample_and_group(npoint=self.npoints[0], nsample=self.nsamples[0], xyz=xyz, points=x) 202 | feature_0 = self.gather_local_0(new_feature) 203 | feature = feature_0.permute(0, 2, 1) # B, nsamples, D 204 | new_xyz, new_feature = sample_and_group(npoint=self.npoints[1], nsample=self.nsamples[1], xyz=new_xyz, points=feature) 205 | feature_1 = self.gather_local_1(new_feature) # B, D, nsamples 206 | x = feature_1.permute(0, 2, 1) # B, nsamples, D 207 | 208 | return x -------------------------------------------------------------------------------- /src/structformer/training/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wliu88/StructFormer/a81d8c83313f1f24a75ca83aa9dd9f8cfcf27419/src/structformer/training/__init__.py -------------------------------------------------------------------------------- /src/structformer/training/train_binary_model.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function, division 2 | 3 | import torch 4 | import torch.nn as nn 5 | import torch.optim as optim 6 | from torch.optim import lr_scheduler 7 | from warmup_scheduler import GradualWarmupScheduler 8 | import numpy as np 9 | import torchvision 10 | from torchvision import datasets, models, transforms 11 | import matplotlib.pyplot as plt 12 | import time 13 | import os 14 | import copy 15 | import tqdm 16 | 17 | import pickle 18 | import argparse 19 | from omegaconf import OmegaConf 20 | from collections import defaultdict 21 | 22 | from torch.utils.data import DataLoader 23 | from structformer.data.binary_dataset import BinaryDataset 24 | from structformer.models.pose_generation_network import PriorContinuousOutBinaryPCT6D 25 | from structformer.data.tokenizer import Tokenizer 26 | from structformer.utils.rearrangement import evaluate_prior_prediction 27 | 28 | 29 | def train_model(cfg, model, data_iter, optimizer, warmup, num_epochs, device, save_best_model, grad_clipping=1.0): 30 | 31 | if save_best_model: 32 | best_model_dir = os.path.join(cfg.experiment_dir, "best_model") 33 | print("best model will be saved to {}".format(best_model_dir)) 34 | if not os.path.exists(best_model_dir): 35 | os.makedirs(best_model_dir) 36 | best_score = -np.inf 37 | 38 | for epoch in range(num_epochs): 39 | print('Epoch {}/{}'.format(epoch, num_epochs - 1)) 40 | print('-' * 10) 41 | 42 | model.train() 43 | epoch_loss = 0 44 | gts = defaultdict(list) 45 | predictions = defaultdict(list) 46 | 47 | with tqdm.tqdm(total=len(data_iter["train"])) as pbar: 48 | for step, batch in enumerate(data_iter["train"]): 49 | optimizer.zero_grad() 50 | # input 51 | query_xyz = batch["query_xyz"].to(device, non_blocking=True) 52 | query_rgb = batch["query_rgb"].to(device, non_blocking=True) 53 | anchor_xyz = batch["anchor_xyz"].to(device, non_blocking=True) 54 | anchor_rgb = batch["anchor_rgb"].to(device, non_blocking=True) 55 | bg_xyz = batch["bg_xyz"].to(device, non_blocking=True) 56 | bg_rgb = batch["bg_rgb"].to(device, non_blocking=True) 57 | sentence = batch["sentence"].to(device, non_blocking=True) 58 | position_index = 
batch["position_index"].to(device, non_blocking=True) 59 | pad_mask = batch["pad_mask"].to(device, non_blocking=True) 60 | 61 | # output 62 | targets = {} 63 | for key in ["obj_x_outputs", "obj_y_outputs", "obj_z_outputs", "obj_theta_outputs"]: 64 | targets[key] = batch[key].to(device, non_blocking=True) 65 | targets[key] = targets[key].reshape(targets[key].shape[0] * targets[key].shape[1], -1) 66 | 67 | preds = model.forward(query_xyz, query_rgb, anchor_xyz, anchor_rgb, bg_xyz, bg_rgb, 68 | sentence, pad_mask, position_index) 69 | 70 | loss = model.criterion(preds, targets) 71 | loss.backward() 72 | 73 | if grad_clipping != 0.0: 74 | torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clipping) 75 | 76 | optimizer.step() 77 | epoch_loss += loss 78 | 79 | for key in ["obj_x_outputs", "obj_y_outputs", "obj_z_outputs", "obj_theta_outputs"]: 80 | gts[key].append(targets[key].detach()) 81 | predictions[key].append(preds[key].detach()) 82 | 83 | pbar.update(1) 84 | pbar.set_postfix({"Batch loss": loss}) 85 | 86 | warmup.step() 87 | 88 | print('[Epoch:{}]: Training Loss:{:.4}'.format(epoch, epoch_loss)) 89 | evaluate_prior_prediction(gts, predictions, ["obj_x_outputs", "obj_y_outputs", "obj_z_outputs", "obj_theta_outputs"]) 90 | 91 | score = validate(cfg, model, data_iter["valid"], epoch, device) 92 | if save_best_model and score > best_score: 93 | print("Saving best model so far...") 94 | best_score = score 95 | save_model(best_model_dir, cfg, epoch, model) 96 | 97 | return model 98 | 99 | 100 | def validate(cfg, model, data_iter, epoch, device): 101 | """ 102 | helper function to evaluate the model 103 | 104 | :param model: 105 | :param data_iter: 106 | :param epoch: 107 | :param device: 108 | :return: 109 | """ 110 | 111 | model.eval() 112 | 113 | epoch_loss = 0 114 | gts = defaultdict(list) 115 | predictions = defaultdict(list) 116 | with torch.no_grad(): 117 | 118 | with tqdm.tqdm(total=len(data_iter)) as pbar: 119 | for step, batch in enumerate(data_iter): 120 | 121 | # input 122 | query_xyz = batch["query_xyz"].to(device, non_blocking=True) 123 | query_rgb = batch["query_rgb"].to(device, non_blocking=True) 124 | anchor_xyz = batch["anchor_xyz"].to(device, non_blocking=True) 125 | anchor_rgb = batch["anchor_rgb"].to(device, non_blocking=True) 126 | bg_xyz = batch["bg_xyz"].to(device, non_blocking=True) 127 | bg_rgb = batch["bg_rgb"].to(device, non_blocking=True) 128 | sentence = batch["sentence"].to(device, non_blocking=True) 129 | position_index = batch["position_index"].to(device, non_blocking=True) 130 | pad_mask = batch["pad_mask"].to(device, non_blocking=True) 131 | 132 | # output 133 | targets = {} 134 | for key in ["obj_x_outputs", "obj_y_outputs", "obj_z_outputs", "obj_theta_outputs"]: 135 | targets[key] = batch[key].to(device, non_blocking=True) 136 | targets[key] = targets[key].reshape(targets[key].shape[0] * targets[key].shape[1], -1) 137 | 138 | preds = model.forward(query_xyz, query_rgb, anchor_xyz, anchor_rgb, bg_xyz, bg_rgb, 139 | sentence, pad_mask, position_index) 140 | 141 | loss = model.criterion(preds, targets) 142 | 143 | for key in ["obj_x_outputs", "obj_y_outputs", "obj_z_outputs", "obj_theta_outputs"]: 144 | gts[key].append(targets[key]) 145 | predictions[key].append(preds[key]) 146 | 147 | epoch_loss += loss 148 | pbar.update(1) 149 | pbar.set_postfix({"Batch loss": loss}) 150 | 151 | print('[Epoch:{}]: Val Loss:{:.4}'.format(epoch, epoch_loss)) 152 | 153 | score = evaluate_prior_prediction(gts, predictions, ["obj_x_outputs", "obj_y_outputs", "obj_z_outputs", 
"obj_theta_outputs"]) 154 | return score 155 | 156 | 157 | def infer_once(cfg, model, batch, device): 158 | 159 | model.eval() 160 | 161 | predictions = defaultdict(list) 162 | with torch.no_grad(): 163 | 164 | # input 165 | query_xyz = batch["query_xyz"].to(device, non_blocking=True) 166 | query_rgb = batch["query_rgb"].to(device, non_blocking=True) 167 | anchor_xyz = batch["anchor_xyz"].to(device, non_blocking=True) 168 | anchor_rgb = batch["anchor_rgb"].to(device, non_blocking=True) 169 | bg_xyz = batch["bg_xyz"].to(device, non_blocking=True) 170 | bg_rgb = batch["bg_rgb"].to(device, non_blocking=True) 171 | sentence = batch["sentence"].to(device, non_blocking=True) 172 | position_index = batch["position_index"].to(device, non_blocking=True) 173 | pad_mask = batch["pad_mask"].to(device, non_blocking=True) 174 | 175 | # output 176 | targets = {} 177 | for key in ["obj_x_outputs", "obj_y_outputs", "obj_z_outputs", "obj_theta_outputs"]: 178 | targets[key] = batch[key].to(device, non_blocking=True) 179 | targets[key] = targets[key].reshape(targets[key].shape[0] * targets[key].shape[1], -1) 180 | 181 | preds = model.forward(query_xyz, query_rgb, anchor_xyz, anchor_rgb, bg_xyz, bg_rgb, 182 | sentence, pad_mask, position_index) 183 | 184 | for key in ["obj_x_outputs", "obj_y_outputs", "obj_z_outputs", "obj_theta_outputs"]: 185 | predictions[key].append(preds[key]) 186 | 187 | return predictions 188 | 189 | 190 | def save_model(model_dir, cfg, epoch, model, optimizer=None, scheduler=None): 191 | state_dict = {'epoch': epoch, 192 | 'model_state_dict': model.state_dict()} 193 | if optimizer is not None: 194 | state_dict["optimizer_state_dict"] = optimizer.state_dict() 195 | if scheduler is not None: 196 | state_dict["scheduler_state_dict"] = scheduler.state_dict() 197 | torch.save(state_dict, os.path.join(model_dir, "model.tar")) 198 | OmegaConf.save(cfg, os.path.join(model_dir, "config.yaml")) 199 | 200 | 201 | def load_model(model_dir, dirs_cfg): 202 | """ 203 | Load transformer model 204 | Important: to use the model, call model.eval() or model.train() 205 | :param model_dir: 206 | :return: 207 | """ 208 | # load dictionaries 209 | cfg = OmegaConf.load(os.path.join(model_dir, "config.yaml")) 210 | if dirs_cfg: 211 | cfg = OmegaConf.merge(cfg, dirs_cfg) 212 | 213 | data_cfg = cfg.dataset 214 | tokenizer = Tokenizer(data_cfg.vocab_dir) 215 | vocab_size = tokenizer.get_vocab_size() 216 | 217 | # initialize model 218 | model_cfg = cfg.model 219 | model = PriorContinuousOutBinaryPCT6D(vocab_size, 220 | num_attention_heads=model_cfg.num_attention_heads, 221 | encoder_hidden_dim=model_cfg.encoder_hidden_dim, 222 | encoder_dropout=model_cfg.encoder_dropout, 223 | encoder_activation=model_cfg.encoder_activation, 224 | encoder_num_layers=model_cfg.encoder_num_layers, 225 | object_dropout=model_cfg.object_dropout, 226 | theta_loss_divide=model_cfg.theta_loss_divide, 227 | ignore_rgb=model_cfg.ignore_rgb) 228 | model.to(cfg.device) 229 | 230 | # load state dicts 231 | checkpoint = torch.load(os.path.join(model_dir, "model.tar")) 232 | model.load_state_dict(checkpoint['model_state_dict']) 233 | 234 | optimizer = None 235 | if "optimizer_state_dict" in checkpoint: 236 | training_cfg = cfg.training 237 | optimizer = optim.Adam(model.parameters(), lr=training_cfg.learning_rate) 238 | optimizer.load_state_dict(checkpoint['optimizer_state_dict']) 239 | 240 | scheduler = None 241 | if "scheduler_state_dict" in checkpoint: 242 | scheduler = None 243 | if scheduler: 244 | 
scheduler.load_state_dict(checkpoint['scheduler_state_dict']) 245 | 246 | epoch = checkpoint['epoch'] 247 | return cfg, tokenizer, model, optimizer, scheduler, epoch 248 | 249 | 250 | def run_model(cfg): 251 | 252 | np.random.seed(cfg.random_seed) 253 | torch.manual_seed(cfg.random_seed) 254 | if torch.cuda.is_available(): 255 | torch.cuda.manual_seed(cfg.random_seed) 256 | torch.cuda.manual_seed_all(cfg.random_seed) 257 | torch.backends.cudnn.deterministic = True 258 | 259 | data_cfg = cfg.dataset 260 | tokenizer = Tokenizer(data_cfg.vocab_dir) 261 | vocab_size = tokenizer.get_vocab_size() 262 | 263 | train_dataset = BinaryDataset(data_cfg.dirs, data_cfg.index_dirs, "train", tokenizer, 264 | data_cfg.max_num_objects, 265 | data_cfg.max_num_other_objects, 266 | data_cfg.max_num_shape_parameters, 267 | data_cfg.max_num_rearrange_features, 268 | data_cfg.max_num_anchor_features, 269 | data_cfg.num_pts) 270 | valid_dataset = BinaryDataset(data_cfg.dirs, data_cfg.index_dirs, "valid", tokenizer, 271 | data_cfg.max_num_objects, 272 | data_cfg.max_num_other_objects, 273 | data_cfg.max_num_shape_parameters, 274 | data_cfg.max_num_rearrange_features, 275 | data_cfg.max_num_anchor_features, 276 | data_cfg.num_pts) 277 | 278 | data_iter = {} 279 | data_iter["train"] = DataLoader(train_dataset, batch_size=data_cfg.batch_size, shuffle=True, 280 | collate_fn=BinaryDataset.collate_fn, 281 | pin_memory=data_cfg.pin_memory, num_workers=data_cfg.num_workers) 282 | data_iter["valid"] = DataLoader(valid_dataset, batch_size=data_cfg.batch_size, shuffle=False, 283 | collate_fn=BinaryDataset.collate_fn, 284 | pin_memory=data_cfg.pin_memory, num_workers=data_cfg.num_workers) 285 | 286 | # load model 287 | model_cfg = cfg.model 288 | model = PriorContinuousOutBinaryPCT6D(vocab_size, 289 | num_attention_heads=model_cfg.num_attention_heads, 290 | encoder_hidden_dim=model_cfg.encoder_hidden_dim, 291 | encoder_dropout=model_cfg.encoder_dropout, 292 | encoder_activation=model_cfg.encoder_activation, 293 | encoder_num_layers=model_cfg.encoder_num_layers, 294 | object_dropout=model_cfg.object_dropout, 295 | theta_loss_divide=model_cfg.theta_loss_divide, 296 | ignore_rgb=model_cfg.ignore_rgb) 297 | model.to(cfg.device) 298 | 299 | training_cfg = cfg.training 300 | optimizer = optim.Adam(model.parameters(), lr=training_cfg.learning_rate, weight_decay=training_cfg.l2) 301 | scheduler = lr_scheduler.CosineAnnealingLR(optimizer, T_max=training_cfg.lr_restart) 302 | warmup = GradualWarmupScheduler(optimizer, multiplier=1, total_epoch=training_cfg.warmup, 303 | after_scheduler=scheduler) 304 | 305 | train_model(cfg, model, data_iter, optimizer, warmup, training_cfg.max_epochs, cfg.device, cfg.save_best_model) 306 | 307 | # save model 308 | if cfg.save_model: 309 | model_dir = os.path.join(cfg.experiment_dir, "model") 310 | print("Saving model to {}".format(model_dir)) 311 | if not os.path.exists(model_dir): 312 | os.makedirs(model_dir) 313 | save_model(model_dir, cfg, cfg.max_epochs, model, optimizer, scheduler) 314 | 315 | 316 | if __name__ == "__main__": 317 | 318 | parser = argparse.ArgumentParser(description="Run a simple model") 319 | parser.add_argument("--dataset_base_dir", help='location of the dataset', type=str) 320 | parser.add_argument("--main_config", help='config yaml file for the model', 321 | default='../configs/binary_model.yaml', 322 | type=str) 323 | parser.add_argument("--dirs_config", help='config yaml file for directories', 324 | default='../configs/data/circle_dirs.yaml', 325 | type=str) 326 | args = 
parser.parse_args() 327 | 328 | # # debug 329 | # args.dataset_base_dir = "/home/weiyu/data_drive/data_new_objects" 330 | 331 | assert os.path.exists(args.main_config), "Cannot find config yaml file at {}".format(args.main_config) 332 | assert os.path.exists(args.dirs_config), "Cannot find config yaml file at {}".format(args.dir_config) 333 | 334 | os.environ["DATETIME"] = time.strftime("%Y%m%d-%H%M%S") 335 | 336 | main_cfg = OmegaConf.load(args.main_config) 337 | dirs_cfg = OmegaConf.load(args.dirs_config) 338 | cfg = OmegaConf.merge(main_cfg, dirs_cfg) 339 | cfg.dataset_base_dir = args.dataset_base_dir 340 | OmegaConf.resolve(cfg) 341 | 342 | if not os.path.exists(cfg.experiment_dir): 343 | os.makedirs(cfg.experiment_dir) 344 | 345 | OmegaConf.save(cfg, os.path.join(cfg.experiment_dir, "config.yaml")) 346 | 347 | run_model(cfg) -------------------------------------------------------------------------------- /src/structformer/training/train_object_selection_network.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function, division 2 | 3 | import torch 4 | import torch.optim as optim 5 | from torch.optim import lr_scheduler 6 | import numpy as np 7 | from warmup_scheduler import GradualWarmupScheduler 8 | import time 9 | import os 10 | import tqdm 11 | 12 | import argparse 13 | from omegaconf import OmegaConf 14 | from collections import defaultdict 15 | from sklearn.metrics import classification_report 16 | 17 | from torch.utils.data import DataLoader 18 | from structformer.data.object_set_refer_dataset import ObjectSetReferDataset 19 | from structformer.models.object_selection_network import RearrangeObjectsPredictorPCT 20 | from structformer.data.tokenizer import Tokenizer 21 | 22 | 23 | def evaluate(gts, predictions, keys, debug=True, return_classification_dict=False): 24 | """ 25 | :param gts: expect a list of tensors 26 | :param predictions: expect a list of tensor 27 | :return: 28 | """ 29 | 30 | total_scores = 0 31 | for key in keys: 32 | predictions_for_key = torch.cat(predictions[key], dim=0) 33 | gts_for_key = torch.cat(gts[key], dim=0) 34 | 35 | predicted_classes = predictions_for_key > 0.5 36 | assert len(gts_for_key) == len(predicted_classes) 37 | 38 | target_indices = gts_for_key != -100 39 | 40 | gts_for_key = gts_for_key[target_indices] 41 | predicted_classes = predicted_classes[target_indices] 42 | num_objects = len(predicted_classes) 43 | 44 | if debug: 45 | print(num_objects) 46 | print(gts_for_key.shape) 47 | print(predicted_classes.shape) 48 | print(target_indices.shape) 49 | print("Groundtruths:") 50 | print(gts_for_key[:100]) 51 | print("Predictions") 52 | print(predicted_classes[:100]) 53 | 54 | accuracy = torch.sum(gts_for_key == predicted_classes) / len(gts_for_key) 55 | print("{} objects -- {} accuracy: {}".format(num_objects, key, accuracy)) 56 | total_scores += accuracy 57 | 58 | report = classification_report(gts_for_key.detach().cpu().numpy(), predicted_classes.detach().cpu().numpy(), 59 | output_dict=True) 60 | print(report) 61 | 62 | if not return_classification_dict: 63 | return total_scores 64 | else: 65 | return report 66 | 67 | 68 | def train_model(cfg, model, data_iter, optimizer, warmup, num_epochs, device, save_best_model, grad_clipping=1.0): 69 | 70 | if save_best_model: 71 | best_model_dir = os.path.join(cfg.experiment_dir, "best_model") 72 | print("best model will be saved to {}".format(best_model_dir)) 73 | if not os.path.exists(best_model_dir): 74 | 
os.makedirs(best_model_dir) 75 | best_score = 0.0 76 | 77 | for epoch in range(num_epochs): 78 | print('Epoch {}/{}'.format(epoch, num_epochs - 1)) 79 | print('-' * 10) 80 | 81 | model.train() 82 | epoch_loss = 0 83 | gts = defaultdict(list) 84 | predictions = defaultdict(list) 85 | 86 | with tqdm.tqdm(total=len(data_iter["train"])) as pbar: 87 | for step, batch in enumerate(data_iter["train"]): 88 | optimizer.zero_grad() 89 | 90 | # input 91 | xyzs = batch["xyzs"].to(device, non_blocking=True) 92 | rgbs = batch["rgbs"].to(device, non_blocking=True) 93 | sentence = batch["sentence"].to(device, non_blocking=True) 94 | object_pad_mask = batch["object_pad_mask"].to(device, non_blocking=True) 95 | sentence_pad_mask = batch["sentence_pad_mask"].to(device, non_blocking=True) 96 | token_type_index = batch["token_type_index"].to(device, non_blocking=True) 97 | position_index = batch["position_index"].to(device, non_blocking=True) 98 | 99 | # output 100 | targets = {} 101 | for key in ["rearrange_obj_labels"]: 102 | targets[key] = batch[key].to(device, non_blocking=True) 103 | 104 | preds = model.forward(xyzs, rgbs, object_pad_mask, sentence, sentence_pad_mask, token_type_index, position_index) 105 | loss = model.criterion(preds, targets) 106 | 107 | loss.backward() 108 | 109 | if grad_clipping != 0.0: 110 | torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clipping) 111 | 112 | optimizer.step() 113 | epoch_loss += loss 114 | 115 | for key in ["rearrange_obj_labels"]: 116 | gts[key].append(targets[key].detach()) 117 | predictions[key].append(preds[key].detach()) 118 | 119 | pbar.update(1) 120 | pbar.set_postfix({"Batch loss": loss}) 121 | 122 | warmup.step() 123 | 124 | print('[Epoch:{}]: Training Loss:{:.4}'.format(epoch, epoch_loss)) 125 | evaluate(gts, predictions, ["rearrange_obj_labels"]) 126 | 127 | score = validate(model, data_iter["valid"], epoch, device) 128 | if save_best_model and score > best_score: 129 | print("Saving best model so far...") 130 | best_score = score 131 | save_model(best_model_dir, cfg, epoch, model) 132 | 133 | return model 134 | 135 | 136 | def validate(model, data_iter, epoch, device): 137 | """ 138 | helper function to evaluate the model 139 | 140 | :param model: 141 | :param data_iter: 142 | :param epoch: 143 | :param device: 144 | :return: 145 | """ 146 | 147 | model.eval() 148 | 149 | epoch_loss = 0 150 | gts = defaultdict(list) 151 | predictions = defaultdict(list) 152 | with torch.no_grad(): 153 | 154 | with tqdm.tqdm(total=len(data_iter)) as pbar: 155 | for step, batch in enumerate(data_iter): 156 | 157 | # input 158 | xyzs = batch["xyzs"].to(device, non_blocking=True) 159 | rgbs = batch["rgbs"].to(device, non_blocking=True) 160 | sentence = batch["sentence"].to(device, non_blocking=True) 161 | object_pad_mask = batch["object_pad_mask"].to(device, non_blocking=True) 162 | sentence_pad_mask = batch["sentence_pad_mask"].to(device, non_blocking=True) 163 | token_type_index = batch["token_type_index"].to(device, non_blocking=True) 164 | position_index = batch["position_index"].to(device, non_blocking=True) 165 | 166 | # output 167 | targets = {} 168 | for key in ["rearrange_obj_labels"]: 169 | targets[key] = batch[key].to(device, non_blocking=True) 170 | 171 | preds = model.forward(xyzs, rgbs, object_pad_mask, sentence, sentence_pad_mask, token_type_index, 172 | position_index) 173 | loss = model.criterion(preds, targets) 174 | 175 | for key in ["rearrange_obj_labels"]: 176 | gts[key].append(targets[key]) 177 | predictions[key].append(preds[key]) 178 | 179 | 
epoch_loss += loss 180 | pbar.update(1) 181 | pbar.set_postfix({"Batch loss": loss}) 182 | 183 | print('[Epoch:{}]: Val Loss:{:.4}'.format(epoch, epoch_loss)) 184 | 185 | score = evaluate(gts, predictions, ["rearrange_obj_labels"]) 186 | return score 187 | 188 | 189 | def infer_once(model, batch, device): 190 | """ 191 | helper function to evaluate the model 192 | 193 | :param model: 194 | :param data_iter: 195 | :param epoch: 196 | :param device: 197 | :return: 198 | """ 199 | 200 | model.eval() 201 | 202 | gts = defaultdict(list) 203 | predictions = defaultdict(list) 204 | with torch.no_grad(): 205 | 206 | # input 207 | xyzs = batch["xyzs"].to(device, non_blocking=True) 208 | rgbs = batch["rgbs"].to(device, non_blocking=True) 209 | sentence = batch["sentence"].to(device, non_blocking=True) 210 | object_pad_mask = batch["object_pad_mask"].to(device, non_blocking=True) 211 | sentence_pad_mask = batch["sentence_pad_mask"].to(device, non_blocking=True) 212 | token_type_index = batch["token_type_index"].to(device, non_blocking=True) 213 | position_index = batch["position_index"].to(device, non_blocking=True) 214 | 215 | # output 216 | targets = {} 217 | for key in ["rearrange_obj_labels"]: 218 | targets[key] = batch[key].to(device, non_blocking=True) 219 | 220 | preds = model.forward(xyzs, rgbs, object_pad_mask, sentence, sentence_pad_mask, token_type_index, 221 | position_index) 222 | 223 | for key in ["rearrange_obj_labels"]: 224 | gts[key].append(targets[key]) 225 | predictions[key].append(preds[key]) 226 | 227 | return gts, predictions 228 | 229 | 230 | def save_model(model_dir, cfg, epoch, model, optimizer=None, scheduler=None): 231 | state_dict = {'epoch': epoch, 232 | 'model_state_dict': model.state_dict()} 233 | if optimizer is not None: 234 | state_dict["optimizer_state_dict"] = optimizer.state_dict() 235 | if scheduler is not None: 236 | state_dict["scheduler_state_dict"] = scheduler.state_dict() 237 | torch.save(state_dict, os.path.join(model_dir, "model.tar")) 238 | OmegaConf.save(cfg, os.path.join(model_dir, "config.yaml")) 239 | 240 | 241 | def load_model(model_dir, dirs_cfg): 242 | """ 243 | Load transformer model 244 | Important: to use the model, call model.eval() or model.train() 245 | :param model_dir: 246 | :return: 247 | """ 248 | # load dictionaries 249 | cfg = OmegaConf.load(os.path.join(model_dir, "config.yaml")) 250 | if dirs_cfg: 251 | cfg = OmegaConf.merge(cfg, dirs_cfg) 252 | 253 | data_cfg = cfg.dataset 254 | tokenizer = Tokenizer(data_cfg.vocab_dir) 255 | vocab_size = tokenizer.get_vocab_size() 256 | 257 | # initialize model 258 | model_cfg = cfg.model 259 | model = RearrangeObjectsPredictorPCT(vocab_size, 260 | num_attention_heads=model_cfg.num_attention_heads, 261 | encoder_hidden_dim=model_cfg.encoder_hidden_dim, 262 | encoder_dropout=model_cfg.encoder_dropout, 263 | encoder_activation=model_cfg.encoder_activation, 264 | encoder_num_layers=model_cfg.encoder_num_layers, 265 | use_focal_loss=model_cfg.use_focal_loss, 266 | focal_loss_gamma=model_cfg.focal_loss_gamma) 267 | model.to(cfg.device) 268 | 269 | # load state dicts 270 | checkpoint = torch.load(os.path.join(model_dir, "model.tar")) 271 | model.load_state_dict(checkpoint['model_state_dict']) 272 | 273 | optimizer = None 274 | if "optimizer_state_dict" in checkpoint: 275 | training_cfg = cfg.training 276 | optimizer = optim.Adam(model.parameters(), lr=training_cfg.learning_rate) 277 | optimizer.load_state_dict(checkpoint['optimizer_state_dict']) 278 | 279 | scheduler = None 280 | if "scheduler_state_dict" 
in checkpoint: 281 | scheduler = None 282 | if scheduler: 283 | scheduler.load_state_dict(checkpoint['scheduler_state_dict']) 284 | 285 | epoch = checkpoint['epoch'] 286 | return cfg, tokenizer, model, optimizer, scheduler, epoch 287 | 288 | 289 | def run_model(cfg): 290 | 291 | np.random.seed(cfg.random_seed) 292 | torch.manual_seed(cfg.random_seed) 293 | if torch.cuda.is_available(): 294 | torch.cuda.manual_seed(cfg.random_seed) 295 | torch.cuda.manual_seed_all(cfg.random_seed) 296 | torch.backends.cudnn.deterministic = True 297 | 298 | data_cfg = cfg.dataset 299 | tokenizer = Tokenizer(data_cfg.vocab_dir) 300 | vocab_size = tokenizer.get_vocab_size() 301 | 302 | train_dataset = ObjectSetReferDataset(data_cfg.dirs, data_cfg.index_dirs, "train", tokenizer, 303 | data_cfg.max_num_all_objects, 304 | data_cfg.max_num_shape_parameters, 305 | data_cfg.max_num_rearrange_features, 306 | data_cfg.max_num_anchor_features, 307 | data_cfg.num_pts) 308 | 309 | valid_dataset = ObjectSetReferDataset(data_cfg.dirs, data_cfg.index_dirs, "valid", tokenizer, 310 | data_cfg.max_num_all_objects, 311 | data_cfg.max_num_shape_parameters, 312 | data_cfg.max_num_rearrange_features, 313 | data_cfg.max_num_anchor_features, 314 | data_cfg.num_pts) 315 | 316 | data_iter = {} 317 | data_iter["train"] = DataLoader(train_dataset, batch_size=data_cfg.batch_size, shuffle=True, 318 | num_workers=data_cfg.num_workers, 319 | collate_fn=ObjectSetReferDataset.collate_fn, 320 | pin_memory=data_cfg.pin_memory) 321 | data_iter["valid"] = DataLoader(valid_dataset, batch_size=data_cfg.batch_size, shuffle=False, 322 | num_workers=data_cfg.num_workers, 323 | collate_fn=ObjectSetReferDataset.collate_fn, 324 | pin_memory=data_cfg.pin_memory) 325 | 326 | # load model 327 | model_cfg = cfg.model 328 | model = RearrangeObjectsPredictorPCT(vocab_size, 329 | num_attention_heads=model_cfg.num_attention_heads, 330 | encoder_hidden_dim=model_cfg.encoder_hidden_dim, 331 | encoder_dropout=model_cfg.encoder_dropout, 332 | encoder_activation=model_cfg.encoder_activation, 333 | encoder_num_layers=model_cfg.encoder_num_layers, 334 | use_focal_loss=model_cfg.use_focal_loss, 335 | focal_loss_gamma=model_cfg.focal_loss_gamma) 336 | model.to(cfg.device) 337 | 338 | training_cfg = cfg.training 339 | optimizer = optim.Adam(model.parameters(), lr=training_cfg.learning_rate, weight_decay=training_cfg.l2) 340 | scheduler = lr_scheduler.CosineAnnealingLR(optimizer, T_max=training_cfg.lr_restart) 341 | warmup = GradualWarmupScheduler(optimizer, multiplier=1, total_epoch=training_cfg.warmup, 342 | after_scheduler=scheduler) 343 | 344 | train_model(cfg, model, data_iter, optimizer, warmup, training_cfg.max_epochs, cfg.device, cfg.save_best_model) 345 | 346 | # save model 347 | if cfg.save_model: 348 | model_dir = os.path.join(cfg.experiment_dir, "model") 349 | print("Saving model to {}".format(model_dir)) 350 | if not os.path.exists(model_dir): 351 | os.makedirs(model_dir) 352 | save_model(model_dir, cfg, cfg.max_epochs, model, optimizer, scheduler) 353 | 354 | 355 | if __name__ == "__main__": 356 | 357 | parser = argparse.ArgumentParser(description="Run a simple model") 358 | parser.add_argument("--dataset_base_dir", help='location of the dataset', type=str) 359 | parser.add_argument("--main_config", help='config yaml file for the model', 360 | default='../configs/object_selection_network.yaml', 361 | type=str) 362 | parser.add_argument("--dirs_config", help='config yaml file for directories', 363 | default='../configs/data/circle_dirs.yaml', 364 | type=str) 
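    # Example invocation (the dataset path is hypothetical; the config defaults above point
    # at the object selection network config and the circle dataset directories):
    #   python train_object_selection_network.py \
    #       --dataset_base_dir /path/to/data_new_objects \
    #       --main_config ../configs/object_selection_network.yaml \
    #       --dirs_config ../configs/data/circle_dirs.yaml
    # Note: the second assert below builds its message from `args.dir_config`, but the parsed
    # attribute is `args.dirs_config`, so the message expression itself would raise an
    # AttributeError if the dirs config file were ever missing.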
365 | args = parser.parse_args() 366 | 367 | # # debug 368 | # args.dataset_base_dir = "/home/weiyu/data_drive/data_new_objects" 369 | 370 | assert os.path.exists(args.main_config), "Cannot find config yaml file at {}".format(args.main_config) 371 | assert os.path.exists(args.dirs_config), "Cannot find config yaml file at {}".format(args.dir_config) 372 | 373 | os.environ["DATETIME"] = time.strftime("%Y%m%d-%H%M%S") 374 | 375 | main_cfg = OmegaConf.load(args.main_config) 376 | dirs_cfg = OmegaConf.load(args.dirs_config) 377 | cfg = OmegaConf.merge(main_cfg, dirs_cfg) 378 | cfg.dataset_base_dir = args.dataset_base_dir 379 | OmegaConf.resolve(cfg) 380 | 381 | if not os.path.exists(cfg.experiment_dir): 382 | os.makedirs(cfg.experiment_dir) 383 | 384 | OmegaConf.save(cfg, os.path.join(cfg.experiment_dir, "config.yaml")) 385 | 386 | run_model(cfg) -------------------------------------------------------------------------------- /src/structformer/training/train_structformer_no_structure.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function, division 2 | 3 | import torch 4 | import torch.nn as nn 5 | import torch.optim as optim 6 | from torch.optim import lr_scheduler 7 | from warmup_scheduler import GradualWarmupScheduler 8 | import numpy as np 9 | import torchvision 10 | from torchvision import datasets, models, transforms 11 | import matplotlib.pyplot as plt 12 | import time 13 | import os 14 | import copy 15 | import tqdm 16 | 17 | import pickle 18 | import argparse 19 | from omegaconf import OmegaConf 20 | from collections import defaultdict 21 | 22 | from torch.utils.data import DataLoader 23 | from structformer.data.sequence_dataset import SequenceDataset 24 | from structformer.models.pose_generation_network import PriorContinuousOutEncoderDecoderPCT6DDropoutAllObjects 25 | from structformer.data.tokenizer import Tokenizer 26 | from structformer.utils.rearrangement import evaluate_prior_prediction, generate_square_subsequent_mask 27 | 28 | 29 | def train_model(cfg, model, data_iter, optimizer, warmup, num_epochs, device, save_best_model, grad_clipping=1.0): 30 | 31 | if save_best_model: 32 | best_model_dir = os.path.join(cfg.experiment_dir, "best_model") 33 | print("best model will be saved to {}".format(best_model_dir)) 34 | if not os.path.exists(best_model_dir): 35 | os.makedirs(best_model_dir) 36 | best_score = -np.inf 37 | 38 | for epoch in range(num_epochs): 39 | print('Epoch {}/{}'.format(epoch, num_epochs - 1)) 40 | print('-' * 10) 41 | 42 | model.train() 43 | epoch_loss = 0 44 | gts = defaultdict(list) 45 | predictions = defaultdict(list) 46 | 47 | with tqdm.tqdm(total=len(data_iter["train"])) as pbar: 48 | for step, batch in enumerate(data_iter["train"]): 49 | optimizer.zero_grad() 50 | # input 51 | xyzs = batch["xyzs"].to(device, non_blocking=True) 52 | rgbs = batch["rgbs"].to(device, non_blocking=True) 53 | object_pad_mask = batch["object_pad_mask"].to(device, non_blocking=True) 54 | other_xyzs = batch["other_xyzs"].to(device, non_blocking=True) 55 | other_rgbs = batch["other_rgbs"].to(device, non_blocking=True) 56 | other_object_pad_mask = batch["other_object_pad_mask"].to(device, non_blocking=True) 57 | sentence = batch["sentence"].to(device, non_blocking=True) 58 | sentence_pad_mask = batch["sentence_pad_mask"].to(device, non_blocking=True) 59 | token_type_index = batch["token_type_index"].to(device, non_blocking=True) 60 | position_index = batch["position_index"].to(device, non_blocking=True) 61 | 
obj_x_inputs = batch["obj_x_inputs"].to(device, non_blocking=True) 62 | obj_y_inputs = batch["obj_y_inputs"].to(device, non_blocking=True) 63 | obj_z_inputs = batch["obj_z_inputs"].to(device, non_blocking=True) 64 | obj_theta_inputs = batch["obj_theta_inputs"].to(device, non_blocking=True) 65 | 66 | tgt_mask = generate_square_subsequent_mask(object_pad_mask.shape[1]).to(device, non_blocking=True) 67 | start_token = torch.zeros((object_pad_mask.shape[0], 1), dtype=torch.long).to(device, non_blocking=True) 68 | 69 | # output 70 | targets = {} 71 | for key in ["obj_x_outputs", "obj_y_outputs", "obj_z_outputs", "obj_theta_outputs"]: 72 | targets[key] = batch[key].to(device, non_blocking=True) 73 | targets[key] = targets[key].reshape(targets[key].shape[0] * targets[key].shape[1], -1) 74 | 75 | preds = model.forward(xyzs, rgbs, object_pad_mask, other_xyzs, other_rgbs, other_object_pad_mask, 76 | sentence, sentence_pad_mask, token_type_index, 77 | obj_x_inputs, obj_y_inputs, obj_z_inputs, obj_theta_inputs, position_index, 78 | tgt_mask, start_token) 79 | 80 | loss = model.criterion(preds, targets) 81 | loss.backward() 82 | 83 | if grad_clipping != 0.0: 84 | torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clipping) 85 | 86 | optimizer.step() 87 | epoch_loss += loss 88 | 89 | for key in ["obj_x_outputs", "obj_y_outputs", "obj_z_outputs", "obj_theta_outputs"]: 90 | gts[key].append(targets[key].detach()) 91 | predictions[key].append(preds[key].detach()) 92 | 93 | pbar.update(1) 94 | pbar.set_postfix({"Batch loss": loss}) 95 | 96 | warmup.step() 97 | 98 | print('[Epoch:{}]: Training Loss:{:.4}'.format(epoch, epoch_loss)) 99 | evaluate_prior_prediction(gts, predictions, ["obj_x_outputs", "obj_y_outputs", "obj_z_outputs", "obj_theta_outputs"]) 100 | 101 | score = validate(cfg, model, data_iter["valid"], epoch, device) 102 | if save_best_model and score > best_score: 103 | print("Saving best model so far...") 104 | best_score = score 105 | save_model(best_model_dir, cfg, epoch, model) 106 | 107 | return model 108 | 109 | 110 | def validate(cfg, model, data_iter, epoch, device): 111 | """ 112 | helper function to evaluate the model 113 | 114 | :param model: 115 | :param data_iter: 116 | :param epoch: 117 | :param device: 118 | :return: 119 | """ 120 | 121 | model.eval() 122 | 123 | epoch_loss = 0 124 | gts = defaultdict(list) 125 | predictions = defaultdict(list) 126 | with torch.no_grad(): 127 | 128 | with tqdm.tqdm(total=len(data_iter)) as pbar: 129 | for step, batch in enumerate(data_iter): 130 | 131 | # input 132 | xyzs = batch["xyzs"].to(device, non_blocking=True) 133 | rgbs = batch["rgbs"].to(device, non_blocking=True) 134 | object_pad_mask = batch["object_pad_mask"].to(device, non_blocking=True) 135 | other_xyzs = batch["other_xyzs"].to(device, non_blocking=True) 136 | other_rgbs = batch["other_rgbs"].to(device, non_blocking=True) 137 | other_object_pad_mask = batch["other_object_pad_mask"].to(device, non_blocking=True) 138 | sentence = batch["sentence"].to(device, non_blocking=True) 139 | sentence_pad_mask = batch["sentence_pad_mask"].to(device, non_blocking=True) 140 | token_type_index = batch["token_type_index"].to(device, non_blocking=True) 141 | position_index = batch["position_index"].to(device, non_blocking=True) 142 | obj_x_inputs = batch["obj_x_inputs"].to(device, non_blocking=True) 143 | obj_y_inputs = batch["obj_y_inputs"].to(device, non_blocking=True) 144 | obj_z_inputs = batch["obj_z_inputs"].to(device, non_blocking=True) 145 | obj_theta_inputs = 
batch["obj_theta_inputs"].to(device, non_blocking=True) 146 | 147 | tgt_mask = generate_square_subsequent_mask(object_pad_mask.shape[1]).to(device, non_blocking=True) 148 | start_token = torch.zeros((object_pad_mask.shape[0], 1), dtype=torch.long).to(device, non_blocking=True) 149 | 150 | # output 151 | targets = {} 152 | for key in ["obj_x_outputs", "obj_y_outputs", "obj_z_outputs", "obj_theta_outputs"]: 153 | targets[key] = batch[key].to(device, non_blocking=True) 154 | targets[key] = targets[key].reshape(targets[key].shape[0] * targets[key].shape[1], -1) 155 | 156 | preds = model.forward(xyzs, rgbs, object_pad_mask, other_xyzs, other_rgbs, other_object_pad_mask, 157 | sentence, sentence_pad_mask, token_type_index, 158 | obj_x_inputs, obj_y_inputs, obj_z_inputs, obj_theta_inputs, position_index, 159 | tgt_mask, start_token) 160 | loss = model.criterion(preds, targets) 161 | 162 | for key in ["obj_x_outputs", "obj_y_outputs", "obj_z_outputs", "obj_theta_outputs"]: 163 | gts[key].append(targets[key]) 164 | predictions[key].append(preds[key]) 165 | 166 | epoch_loss += loss 167 | pbar.update(1) 168 | pbar.set_postfix({"Batch loss": loss}) 169 | 170 | print('[Epoch:{}]: Val Loss:{:.4}'.format(epoch, epoch_loss)) 171 | 172 | score = evaluate_prior_prediction(gts, predictions, 173 | ["obj_x_outputs", "obj_y_outputs", "obj_z_outputs", "obj_theta_outputs"]) 174 | return score 175 | 176 | 177 | def infer_once(cfg, model, batch, device): 178 | 179 | model.eval() 180 | 181 | predictions = defaultdict(list) 182 | with torch.no_grad(): 183 | 184 | # input 185 | xyzs = batch["xyzs"].to(device, non_blocking=True) 186 | rgbs = batch["rgbs"].to(device, non_blocking=True) 187 | object_pad_mask = batch["object_pad_mask"].to(device, non_blocking=True) 188 | other_xyzs = batch["other_xyzs"].to(device, non_blocking=True) 189 | other_rgbs = batch["other_rgbs"].to(device, non_blocking=True) 190 | other_object_pad_mask = batch["other_object_pad_mask"].to(device, non_blocking=True) 191 | sentence = batch["sentence"].to(device, non_blocking=True) 192 | sentence_pad_mask = batch["sentence_pad_mask"].to(device, non_blocking=True) 193 | token_type_index = batch["token_type_index"].to(device, non_blocking=True) 194 | position_index = batch["position_index"].to(device, non_blocking=True) 195 | 196 | obj_x_inputs = batch["obj_x_inputs"].to(device, non_blocking=True) 197 | obj_y_inputs = batch["obj_y_inputs"].to(device, non_blocking=True) 198 | obj_z_inputs = batch["obj_z_inputs"].to(device, non_blocking=True) 199 | obj_theta_inputs = batch["obj_theta_inputs"].to(device, non_blocking=True) 200 | 201 | tgt_mask = generate_square_subsequent_mask(object_pad_mask.shape[1]).to(device, non_blocking=True) 202 | start_token = torch.zeros((object_pad_mask.shape[0], 1), dtype=torch.long).to(device, non_blocking=True) 203 | 204 | preds = model.forward(xyzs, rgbs, object_pad_mask, other_xyzs, other_rgbs, other_object_pad_mask, 205 | sentence, sentence_pad_mask, token_type_index, 206 | obj_x_inputs, obj_y_inputs, obj_z_inputs, obj_theta_inputs, position_index, 207 | tgt_mask, start_token) 208 | 209 | for key in ["obj_x_outputs", "obj_y_outputs", "obj_z_outputs", "obj_theta_outputs"]: 210 | predictions[key].append(preds[key]) 211 | 212 | return predictions 213 | 214 | 215 | def save_model(model_dir, cfg, epoch, model, optimizer=None, scheduler=None): 216 | state_dict = {'epoch': epoch, 217 | 'model_state_dict': model.state_dict()} 218 | if optimizer is not None: 219 | state_dict["optimizer_state_dict"] = optimizer.state_dict() 220 | if 
scheduler is not None: 221 | state_dict["scheduler_state_dict"] = scheduler.state_dict() 222 | torch.save(state_dict, os.path.join(model_dir, "model.tar")) 223 | OmegaConf.save(cfg, os.path.join(model_dir, "config.yaml")) 224 | 225 | 226 | def load_model(model_dir, dirs_cfg): 227 | """ 228 | Load transformer model 229 | Important: to use the model, call model.eval() or model.train() 230 | :param model_dir: 231 | :return: 232 | """ 233 | # load dictionaries 234 | cfg = OmegaConf.load(os.path.join(model_dir, "config.yaml")) 235 | if dirs_cfg: 236 | cfg = OmegaConf.merge(cfg, dirs_cfg) 237 | 238 | data_cfg = cfg.dataset 239 | tokenizer = Tokenizer(data_cfg.vocab_dir) 240 | vocab_size = tokenizer.get_vocab_size() 241 | 242 | # initialize model 243 | model_cfg = cfg.model 244 | model = PriorContinuousOutEncoderDecoderPCT6DDropoutAllObjects(vocab_size, 245 | num_attention_heads=model_cfg.num_attention_heads, 246 | encoder_hidden_dim=model_cfg.encoder_hidden_dim, 247 | encoder_dropout=model_cfg.encoder_dropout, 248 | encoder_activation=model_cfg.encoder_activation, 249 | encoder_num_layers=model_cfg.encoder_num_layers, 250 | object_dropout=model_cfg.object_dropout, 251 | theta_loss_divide=model_cfg.theta_loss_divide, 252 | ignore_rgb=model_cfg.ignore_rgb) 253 | model.to(cfg.device) 254 | 255 | # load state dicts 256 | checkpoint = torch.load(os.path.join(model_dir, "model.tar")) 257 | model.load_state_dict(checkpoint['model_state_dict']) 258 | 259 | optimizer = None 260 | if "optimizer_state_dict" in checkpoint: 261 | training_cfg = cfg.training 262 | optimizer = optim.Adam(model.parameters(), lr=training_cfg.learning_rate) 263 | optimizer.load_state_dict(checkpoint['optimizer_state_dict']) 264 | 265 | scheduler = None 266 | if "scheduler_state_dict" in checkpoint: 267 | scheduler = None 268 | if scheduler: 269 | scheduler.load_state_dict(checkpoint['scheduler_state_dict']) 270 | 271 | epoch = checkpoint['epoch'] 272 | return cfg, tokenizer, model, optimizer, scheduler, epoch 273 | 274 | 275 | def run_model(cfg): 276 | 277 | np.random.seed(cfg.random_seed) 278 | torch.manual_seed(cfg.random_seed) 279 | if torch.cuda.is_available(): 280 | torch.cuda.manual_seed(cfg.random_seed) 281 | torch.cuda.manual_seed_all(cfg.random_seed) 282 | torch.backends.cudnn.deterministic = True 283 | 284 | data_cfg = cfg.dataset 285 | tokenizer = Tokenizer(data_cfg.vocab_dir) 286 | vocab_size = tokenizer.get_vocab_size() 287 | 288 | train_dataset = SequenceDataset(data_cfg.dirs, data_cfg.index_dirs, "train", tokenizer, 289 | data_cfg.max_num_objects, 290 | data_cfg.max_num_other_objects, 291 | data_cfg.max_num_shape_parameters, 292 | data_cfg.max_num_rearrange_features, 293 | data_cfg.max_num_anchor_features, 294 | data_cfg.num_pts, 295 | data_cfg.use_structure_frame) 296 | valid_dataset = SequenceDataset(data_cfg.dirs, data_cfg.index_dirs, "valid", tokenizer, 297 | data_cfg.max_num_objects, 298 | data_cfg.max_num_other_objects, 299 | data_cfg.max_num_shape_parameters, 300 | data_cfg.max_num_rearrange_features, 301 | data_cfg.max_num_anchor_features, 302 | data_cfg.num_pts, 303 | data_cfg.use_structure_frame) 304 | 305 | data_iter = {} 306 | data_iter["train"] = DataLoader(train_dataset, batch_size=data_cfg.batch_size, shuffle=True, 307 | collate_fn=SequenceDataset.collate_fn, 308 | pin_memory=data_cfg.pin_memory, num_workers=data_cfg.num_workers) 309 | data_iter["valid"] = DataLoader(valid_dataset, batch_size=data_cfg.batch_size, shuffle=False, 310 | collate_fn=SequenceDataset.collate_fn, 311 | 
pin_memory=data_cfg.pin_memory, num_workers=data_cfg.num_workers) 312 | 313 | # load model 314 | model_cfg = cfg.model 315 | model = PriorContinuousOutEncoderDecoderPCT6DDropoutAllObjects(vocab_size, 316 | num_attention_heads=model_cfg.num_attention_heads, 317 | encoder_hidden_dim=model_cfg.encoder_hidden_dim, 318 | encoder_dropout=model_cfg.encoder_dropout, 319 | encoder_activation=model_cfg.encoder_activation, 320 | encoder_num_layers=model_cfg.encoder_num_layers, 321 | object_dropout=model_cfg.object_dropout, 322 | theta_loss_divide=model_cfg.theta_loss_divide, 323 | ignore_rgb=model_cfg.ignore_rgb) 324 | model.to(cfg.device) 325 | 326 | training_cfg = cfg.training 327 | optimizer = optim.Adam(model.parameters(), lr=training_cfg.learning_rate, weight_decay=training_cfg.l2) 328 | scheduler = lr_scheduler.CosineAnnealingLR(optimizer, T_max=training_cfg.lr_restart) 329 | warmup = GradualWarmupScheduler(optimizer, multiplier=1, total_epoch=training_cfg.warmup, 330 | after_scheduler=scheduler) 331 | 332 | train_model(cfg, model, data_iter, optimizer, warmup, training_cfg.max_epochs, cfg.device, cfg.save_best_model) 333 | 334 | # save model 335 | if cfg.save_model: 336 | model_dir = os.path.join(cfg.experiment_dir, "model") 337 | print("Saving model to {}".format(model_dir)) 338 | if not os.path.exists(model_dir): 339 | os.makedirs(model_dir) 340 | save_model(model_dir, cfg, cfg.max_epochs, model, optimizer, scheduler) 341 | 342 | 343 | if __name__ == "__main__": 344 | 345 | parser = argparse.ArgumentParser(description="Run a simple model") 346 | parser.add_argument("--dataset_base_dir", help='location of the dataset', type=str) 347 | parser.add_argument("--main_config", help='config yaml file for the model', 348 | default='../configs/structformer_no_structure.yaml', 349 | type=str) 350 | parser.add_argument("--dirs_config", help='config yaml file for directories', 351 | default='../configs/data/circle_dirs.yaml', 352 | type=str) 353 | args = parser.parse_args() 354 | 355 | # # debug 356 | # args.dataset_base_dir = "/home/weiyu/data_drive/data_new_objects" 357 | 358 | assert os.path.exists(args.main_config), "Cannot find config yaml file at {}".format(args.main_config) 359 | assert os.path.exists(args.dirs_config), "Cannot find config yaml file at {}".format(args.dir_config) 360 | 361 | os.environ["DATETIME"] = time.strftime("%Y%m%d-%H%M%S") 362 | 363 | main_cfg = OmegaConf.load(args.main_config) 364 | dirs_cfg = OmegaConf.load(args.dirs_config) 365 | cfg = OmegaConf.merge(main_cfg, dirs_cfg) 366 | cfg.dataset_base_dir = args.dataset_base_dir 367 | OmegaConf.resolve(cfg) 368 | 369 | if not os.path.exists(cfg.experiment_dir): 370 | os.makedirs(cfg.experiment_dir) 371 | 372 | OmegaConf.save(cfg, os.path.join(cfg.experiment_dir, "config.yaml")) 373 | 374 | run_model(cfg) -------------------------------------------------------------------------------- /src/structformer/utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wliu88/StructFormer/a81d8c83313f1f24a75ca83aa9dd9f8cfcf27419/src/structformer/utils/__init__.py -------------------------------------------------------------------------------- /src/structformer/utils/brain2/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wliu88/StructFormer/a81d8c83313f1f24a75ca83aa9dd9f8cfcf27419/src/structformer/utils/brain2/__init__.py 
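# A minimal inference sketch for the pose-generation helpers above (load_model / infer_once
# from train_structformer_no_structure.py). The checkpoint path and the "valid" split are
# assumptions for illustration, not something this repository prescribes:
#
#   from torch.utils.data import DataLoader
#   from structformer.data.sequence_dataset import SequenceDataset
#   from structformer.training.train_structformer_no_structure import load_model, infer_once
#
#   cfg, tokenizer, model, _, _, epoch = load_model("/path/to/experiment/best_model", None)
#   d = cfg.dataset
#   dataset = SequenceDataset(d.dirs, d.index_dirs, "valid", tokenizer,
#                             d.max_num_objects, d.max_num_other_objects,
#                             d.max_num_shape_parameters, d.max_num_rearrange_features,
#                             d.max_num_anchor_features, d.num_pts, d.use_structure_frame)
#   loader = DataLoader(dataset, batch_size=1, collate_fn=SequenceDataset.collate_fn)
#   preds = infer_once(cfg, model, next(iter(loader)), cfg.device)
#   # preds maps obj_{x,y,z,theta}_outputs to predicted placements for each query object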
-------------------------------------------------------------------------------- /src/structformer/utils/brain2/camera.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. 2 | # 3 | # NVIDIA CORPORATION and its licensors retain all intellectual property 4 | # and proprietary rights in and to this software, related documentation 5 | # and any modifications thereto. Any use, reproduction, disclosure or 6 | # distribution of this software and related documentation without an express 7 | # license agreement from NVIDIA CORPORATION is strictly prohibited. 8 | 9 | 10 | from __future__ import print_function 11 | 12 | import numpy as np 13 | import open3d 14 | import trimesh 15 | 16 | import structformer.utils.transformations as tra 17 | from structformer.utils.brain2.pose import make_pose 18 | 19 | 20 | def get_camera_from_h5(h5): 21 | """ Simple reference to help make these """ 22 | proj_near = h5['cam_near'][()] 23 | proj_far = h5['cam_far'][()] 24 | proj_fov = h5['cam_fov'][()] 25 | width = h5['cam_width'][()] 26 | height = h5['cam_height'][()] 27 | return GenericCameraReference(proj_near, proj_far, proj_fov, width, height) 28 | 29 | 30 | class GenericCameraReference(object): 31 | """ Class storing camera information and providing easy image capture """ 32 | 33 | def __init__(self, proj_near=0.01, proj_far=5., proj_fov=60., img_width=640, 34 | img_height=480): 35 | 36 | self.proj_near = proj_near 37 | self.proj_far = proj_far 38 | self.proj_fov = proj_fov 39 | self.img_width = img_width 40 | self.img_height = img_height 41 | self.x_offset = self.img_width / 2. 42 | self.y_offset = self.img_height / 2. 43 | 44 | # Compute focal params 45 | aspect_ratio = self.img_width / self.img_height 46 | e = 1 / (np.tan(np.radians(self.proj_fov/2.))) 47 | t = self.proj_near / e 48 | b = -t 49 | r = t * aspect_ratio 50 | l = -r 51 | # pixels per meter 52 | alpha = self.img_width / (r-l) 53 | self.focal_length = self.proj_near * alpha 54 | self.fx = self.focal_length 55 | self.fy = self.focal_length 56 | self.pose = None 57 | self.inv_pose = None 58 | 59 | def set_pose(self, trans, rot): 60 | self.pose = make_pose(trans, rot) 61 | self.inv_pose = tra.inverse_matrix(self.pose) 62 | 63 | def set_pose_matrix(self, matrix): 64 | self.pose = matrix 65 | self.inv_pose = tra.inverse_matrix(matrix) 66 | 67 | def transform_to_world_coords(self, xyz): 68 | """ transform xyz into world coordinates """ 69 | #cam_pose = tra.inverse_matrix(self.pose).dot(tra.euler_matrix(np.pi, 0, 0)) 70 | #xyz = trimesh.transform_points(xyz, self.inv_pose) 71 | #xyz = trimesh.transform_points(xyz, cam_pose) 72 | #pose = tra.euler_matrix(np.pi, 0, 0) @ self.pose 73 | pose = self.pose 74 | xyz = trimesh.transform_points(xyz, pose) 75 | return xyz 76 | 77 | def get_camera_presets(): 78 | return [ 79 | "n/a", 80 | "azure_depth_nfov", 81 | "realsense", 82 | "azure_720p", 83 | "simple256", 84 | "simple512", 85 | ] 86 | 87 | 88 | def get_camera_preset(name): 89 | 90 | if name == "azure_depth_nfov": 91 | # Setting for depth camera is pretty different from RGB 92 | height, width, fov = 576, 640, 75 93 | if name == "azure_720p": 94 | # This is actually the 720p RGB setting 95 | # Used for our color camera most of the time 96 | #height, width, fov = 720, 1280, 90 97 | height, width, fov = 720, 1280, 60 98 | elif name == "realsense": 99 | height, width, fov = 480, 640, 60 100 | elif name == "simple256": 101 | height, width, fov = 256, 256, 60 102 | elif 
name == "simple512": 103 | height, width, fov = 512, 512, 60 104 | else: 105 | raise RuntimeError(('camera "%s" not supported, choose from: ' + 106 | str(get_camera_presets())) % str(name)) 107 | return height, width, fov 108 | 109 | 110 | def get_generic_camera(name): 111 | h, w, fov = get_camera_preset(name) 112 | return GenericCameraReference(img_height=h, img_width=w, proj_fov=fov) 113 | 114 | 115 | def get_matrix_of_indices(height, width): 116 | """ Get indices """ 117 | return np.indices((height, width), dtype=np.float32).transpose(1,2,0) 118 | 119 | # -------------------------------------------------------- 120 | # NOTE: this code taken from Arsalan and modified 121 | def compute_xyz(depth_img, camera, visualize_xyz=False, 122 | xmap=None, ymap=None, max_clip_depth=5): 123 | """ Compute xyz image from depth for a camera """ 124 | 125 | # We need thes eparameters 126 | height = camera.img_height 127 | width = camera.img_width 128 | assert depth_img.shape[0] == camera.img_height 129 | assert depth_img.shape[1] == camera.img_width 130 | fx = camera.fx 131 | fy = camera.fy 132 | cx = camera.x_offset 133 | cy = camera.y_offset 134 | 135 | """ 136 | # Create the matrix of parameters 137 | indices = np.indices((height, width), dtype=np.float32).transpose(1,2,0) 138 | # pixel indices start at top-left corner. for these equations, it starts at bottom-left 139 | # indices[..., 0] = np.flipud(indices[..., 0]) 140 | z_e = depth_img 141 | x_e = (indices[..., 1] - x_offset) * z_e / fx 142 | y_e = (indices[..., 0] - y_offset) * z_e / fy 143 | xyz_img = np.stack([x_e, y_e, z_e], axis=-1) # Shape: [H x W x 3] 144 | """ 145 | 146 | height = depth_img.shape[0] 147 | width = depth_img.shape[1] 148 | input_x = np.arange(width) 149 | input_y = np.arange(height) 150 | input_x, input_y = np.meshgrid(input_x, input_y) 151 | input_x = input_x.flatten() 152 | input_y = input_y.flatten() 153 | input_z = depth_img.flatten() 154 | # clip points that are farther than max distance 155 | input_z[input_z > max_clip_depth] = 0 156 | output_x = (input_x * input_z - cx * input_z) / fx 157 | output_y = (input_y * input_z - cy * input_z) / fy 158 | raw_pc = np.stack([output_x, output_y, input_z], -1).reshape( 159 | height, width, 3 160 | ) 161 | return raw_pc 162 | 163 | if visualize_xyz: 164 | unordered_pc = xyz_img.reshape(-1, 3) 165 | pcd = open3d.geometry.PointCloud() 166 | pcd.points = open3d.utility.Vector3dVector(unordered_pc) 167 | pcd.transform([[1,0,0,0], [0,1,0,0], [0,0,-1,0], [0,0,0,1]]) # Transform it so it's not upside down 168 | open3d.visualization.draw_geometries([pcd]) 169 | 170 | return xyz_img 171 | 172 | def show_pcs(xyz, rgb): 173 | """ Display point clouds """ 174 | if len(xyz.shape) > 2: 175 | unordered_pc = xyz.reshape(-1, 3) 176 | unordered_rgb = rgb.reshape(-1, 3) / 255. 177 | assert(unordered_rgb.shape[0] == unordered_pc.shape[0]) 178 | assert(unordered_pc.shape[1] == 3) 179 | assert(unordered_rgb.shape[1] == 3) 180 | pcd = open3d.geometry.PointCloud() 181 | pcd.points = open3d.utility.Vector3dVector(unordered_pc) 182 | pcd.colors = open3d.utility.Vector3dVector(unordered_rgb) 183 | pcd.transform([[1,0,0,0],[0,1,0,0],[0,0,-1,0],[0,0,0,1]]) # Transform it so it's not upside down 184 | open3d.visualization.draw_geometries([pcd]) 185 | -------------------------------------------------------------------------------- /src/structformer/utils/brain2/image.py: -------------------------------------------------------------------------------- 1 | """ 2 | By Chris Paxton. 
3 | 4 | Copyright (c) 2018, Johns Hopkins University 5 | All rights reserved. 6 | 7 | Redistribution and use in source and binary forms, with or without 8 | modification, are permitted provided that the following conditions are met: 9 | * Redistributions of source code must retain the above copyright 10 | notice, this list of conditions and the following disclaimer. 11 | * Redistributions in binary form must reproduce the above copyright 12 | notice, this list of conditions and the following disclaimer in the 13 | documentation and/or other materials provided with the distribution. 14 | * Neither the name of the Johns Hopkins University nor the 15 | names of its contributors may be used to endorse or promote products 16 | derived from this software without specific prior written permission. 17 | 18 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND 19 | ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED 20 | WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 21 | DISCLAIMED. IN NO EVENT SHALL JOHNS HOPKINS UNIVERSITY BE LIABLE FOR ANY 22 | DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES 23 | (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; 24 | LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND 25 | ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT 26 | (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS 27 | SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 28 | """ 29 | 30 | import numpy as np 31 | import io 32 | from PIL import Image 33 | 34 | def GetJpeg(img): 35 | ''' 36 | Save a numpy array as a Jpeg, then get it out as a binary blob 37 | ''' 38 | im = Image.fromarray(np.uint8(img)) 39 | output = io.BytesIO() 40 | im.save(output, format="JPEG", quality=80) 41 | return output.getvalue() 42 | 43 | def JpegToNumpy(jpeg): 44 | stream = io.BytesIO(jpeg) 45 | im = Image.open(stream) 46 | return np.asarray(im, dtype=np.uint8) 47 | 48 | def ConvertJpegListToNumpy(data): 49 | length = len(data) 50 | imgs = [] 51 | for raw in data: 52 | imgs.append(JpegToNumpy(raw)) 53 | arr = np.array(imgs) 54 | return arr 55 | 56 | def DepthToZBuffer(img, z_near, z_far): 57 | real_depth = z_near * z_far / (z_far - img * (z_far - z_near)) 58 | return real_depth 59 | 60 | def ZBufferToRGB(img, z_near, z_far): 61 | real_depth = z_near * z_far / (z_far - img * (z_far - z_near)) 62 | depth_m = np.uint8(real_depth) 63 | depth_cm = np.uint8((real_depth-depth_m)*100) 64 | depth_tmm = np.uint8((real_depth-depth_m-0.01*depth_cm)*10000) 65 | return np.dstack([depth_m, depth_cm, depth_tmm]) 66 | 67 | def RGBToDepth(img, min_dist=0., max_dist=2.,): 68 | return (img[:,:,0]+.01*img[:,:,1]+.0001*img[:,:,2]).clip(min_dist, max_dist) 69 | #return img[:,:,0]+.01*img[:,:,1]+.0001*img[:,:,2] 70 | 71 | def MaskToRGBA(img): 72 | buf = img.astype(np.int32) 73 | A = buf.astype(np.uint8) 74 | buf = np.right_shift(buf, 8) 75 | B = buf.astype(np.uint8) 76 | buf = np.right_shift(buf, 8) 77 | G = buf.astype(np.uint8) 78 | buf = np.right_shift(buf, 8) 79 | R = buf.astype(np.uint8) 80 | 81 | dims = [np.expand_dims(d, -1) for d in [R,G,B,A]] 82 | return np.concatenate(dims, axis=-1) 83 | 84 | def RGBAToMask(img): 85 | mask = np.zeros(img.shape[:-1], dtype=np.int32) 86 | buf = img.astype(np.int32) 87 | for i, dim in enumerate([3,2,1,0]): 88 | shift = 8*i 89 | #print(i, dim, shift, buf[0,0,dim], np.left_shift(buf[0,0,dim], 
shift)) 90 | mask += np.left_shift(buf[:,:, dim], shift) 91 | return mask 92 | 93 | def RGBAArrayToMasks(img): 94 | mask = np.zeros(img.shape[:-1], dtype=np.int32) 95 | buf = img.astype(np.int32) 96 | for i, dim in enumerate([3,2,1,0]): 97 | shift = 8*i 98 | mask += np.left_shift(buf[:,:,:, dim], shift) 99 | return mask 100 | 101 | def GetPNG(img): 102 | ''' 103 | Save a numpy array as a PNG, then get it out as a binary blob 104 | ''' 105 | im = Image.fromarray(np.uint8(img)) 106 | output = io.BytesIO() 107 | im.save(output, format="PNG")#, quality=80) 108 | return output.getvalue() 109 | 110 | def PNGToNumpy(png): 111 | stream = io.BytesIO(png) 112 | im = Image.open(stream) 113 | return np.array(im, dtype=np.uint8) 114 | 115 | def ConvertPNGListToNumpy(data): 116 | length = len(data) 117 | imgs = [] 118 | for raw in data: 119 | imgs.append(PNGToNumpy(raw)) 120 | arr = np.array(imgs) 121 | return arr 122 | 123 | def ConvertDepthPNGListToNumpy(data): 124 | length = len(data) 125 | imgs = [] 126 | for raw in data: 127 | imgs.append(RGBToDepth(PNGToNumpy(raw))) 128 | arr = np.array(imgs) 129 | return arr 130 | 131 | import cv2 132 | def Shrink(img, nw=64): 133 | h,w = img.shape[:2] 134 | ratio = float(nw) / w 135 | nh = int(ratio * h) 136 | img2 = cv2.resize(img, dsize=(nw, nh), 137 | interpolation=cv2.INTER_NEAREST) 138 | return img2 139 | 140 | def ShrinkSmooth(img, nw=64): 141 | h,w = img.shape[:2] 142 | ratio = float(nw) / w 143 | nh = int(ratio * h) 144 | img2 = cv2.resize(img, dsize=(nw, nh), 145 | interpolation=cv2.INTER_LINEAR) 146 | return img2 147 | 148 | def CropCenter(img, cropx, cropy): 149 | y = img.shape[0] 150 | x = img.shape[1] 151 | startx = (x // 2) - (cropx // 2) 152 | starty = (y // 2) - (cropy // 2) 153 | return img[starty: starty + cropy, startx : startx + cropx] 154 | 155 | -------------------------------------------------------------------------------- /src/structformer/utils/brain2/pose.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. 2 | # 3 | # NVIDIA CORPORATION and its licensors retain all intellectual property 4 | # and proprietary rights in and to this software, related documentation 5 | # and any modifications thereto. Any use, reproduction, disclosure or 6 | # distribution of this software and related documentation without an express 7 | # license agreement from NVIDIA CORPORATION is strictly prohibited. 
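# Helper for building 4x4 homogeneous transforms: make_pose() below combines a translation
# vector and a quaternion (via transformations.quaternion_matrix) into a single pose matrix.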
8 | 9 | 10 | from __future__ import print_function 11 | 12 | import structformer.utils.transformations as tra 13 | 14 | 15 | def make_pose(trans, rot): 16 | """Make 4x4 matrix from (trans, rot)""" 17 | pose = tra.quaternion_matrix(rot) 18 | pose[:3, 3] = trans 19 | return pose 20 | -------------------------------------------------------------------------------- /src/structformer/utils/pointnet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from time import time 5 | import numpy as np 6 | 7 | 8 | # reference https://github.com/yanx27/Pointnet_Pointnet2_pytorch, modified by Yang You 9 | 10 | 11 | def timeit(tag, t): 12 | print("{}: {}s".format(tag, time() - t)) 13 | return time() 14 | 15 | def pc_normalize(pc): 16 | if type(pc).__module__ == np.__name__: 17 | centroid = np.mean(pc, axis=0) 18 | pc = pc - centroid 19 | m = np.max(np.sqrt(np.sum(pc**2, axis=1))) 20 | pc = pc / m 21 | else: 22 | centroid = torch.mean(pc, dim=0) 23 | pc = pc - centroid 24 | m = torch.max(torch.sqrt(torch.sum(pc ** 2, dim=1))) 25 | pc = pc / m 26 | return pc 27 | 28 | def square_distance(src, dst): 29 | """ 30 | Calculate Euclid distance between each two points. 31 | src^T * dst = xn * xm + yn * ym + zn * zm; 32 | sum(src^2, dim=-1) = xn*xn + yn*yn + zn*zn; 33 | sum(dst^2, dim=-1) = xm*xm + ym*ym + zm*zm; 34 | dist = (xn - xm)^2 + (yn - ym)^2 + (zn - zm)^2 35 | = sum(src**2,dim=-1)+sum(dst**2,dim=-1)-2*src^T*dst 36 | Input: 37 | src: source points, [B, N, C] 38 | dst: target points, [B, M, C] 39 | Output: 40 | dist: per-point square distance, [B, N, M] 41 | """ 42 | return torch.sum((src[:, :, None] - dst[:, None]) ** 2, dim=-1) 43 | 44 | 45 | def index_points(points, idx): 46 | """ 47 | Input: 48 | points: input points data, [B, N, C] 49 | idx: sample index data, [B, S, [K]] 50 | Return: 51 | new_points:, indexed points data, [B, S, [K], C] 52 | """ 53 | raw_size = idx.size() 54 | idx = idx.reshape(raw_size[0], -1) 55 | res = torch.gather(points, 1, idx[..., None].expand(-1, -1, points.size(-1))) 56 | return res.reshape(*raw_size, -1) 57 | 58 | 59 | def farthest_point_sample(xyz, npoint): 60 | """ 61 | Input: 62 | xyz: pointcloud data, [B, N, 3] 63 | npoint: number of samples 64 | Return: 65 | centroids: sampled pointcloud index, [B, npoint] 66 | """ 67 | device = xyz.device 68 | B, N, C = xyz.shape 69 | centroids = torch.zeros(B, npoint, dtype=torch.long).to(device) 70 | distance = torch.ones(B, N).to(device) * 1e10 71 | farthest = torch.randint(0, N, (B,), dtype=torch.long).to(device) 72 | batch_indices = torch.arange(B, dtype=torch.long).to(device) 73 | for i in range(npoint): 74 | centroids[:, i] = farthest 75 | centroid = xyz[batch_indices, farthest, :].view(B, 1, 3) 76 | dist = torch.sum((xyz - centroid) ** 2, -1) 77 | distance = torch.min(distance, dist) 78 | farthest = torch.max(distance, -1)[1] 79 | return centroids 80 | 81 | 82 | def query_ball_point(radius, nsample, xyz, new_xyz): 83 | """ 84 | Input: 85 | radius: local region radius 86 | nsample: max sample number in local region 87 | xyz: all points, [B, N, 3] 88 | new_xyz: query points, [B, S, 3] 89 | Return: 90 | group_idx: grouped points index, [B, S, nsample] 91 | """ 92 | device = xyz.device 93 | B, N, C = xyz.shape 94 | _, S, _ = new_xyz.shape 95 | group_idx = torch.arange(N, dtype=torch.long).to(device).view(1, 1, N).repeat([B, S, 1]) 96 | sqrdists = square_distance(new_xyz, xyz) 97 | group_idx[sqrdists > radius ** 2] = N 
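    # Out-of-radius neighbours were flagged with the sentinel index N above; sorting pushes
    # them past the in-radius indices, and after truncating to nsample any leftover N entries
    # are overwritten with the group's first valid index below.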
98 | group_idx = group_idx.sort(dim=-1)[0][:, :, :nsample] 99 | group_first = group_idx[:, :, 0].view(B, S, 1).repeat([1, 1, nsample]) 100 | mask = group_idx == N 101 | group_idx[mask] = group_first[mask] 102 | return group_idx 103 | 104 | 105 | def sample_and_group(npoint, radius, nsample, xyz, points, returnfps=False, knn=False): 106 | """ 107 | Input: 108 | npoint: 109 | radius: 110 | nsample: 111 | xyz: input points position data, [B, N, 3] 112 | points: input points data, [B, N, D] 113 | Return: 114 | new_xyz: sampled points position data, [B, npoint, nsample, 3] 115 | new_points: sampled points data, [B, npoint, nsample, 3+D] 116 | """ 117 | B, N, C = xyz.shape 118 | S = npoint 119 | fps_idx = farthest_point_sample(xyz, npoint) # [B, npoint] 120 | torch.cuda.empty_cache() 121 | new_xyz = index_points(xyz, fps_idx) 122 | torch.cuda.empty_cache() 123 | if knn: 124 | dists = square_distance(new_xyz, xyz) # B x npoint x N 125 | idx = dists.argsort()[:, :, :nsample] # B x npoint x K 126 | else: 127 | idx = query_ball_point(radius, nsample, xyz, new_xyz) 128 | torch.cuda.empty_cache() 129 | grouped_xyz = index_points(xyz, idx) # [B, npoint, nsample, C] 130 | torch.cuda.empty_cache() 131 | grouped_xyz_norm = grouped_xyz - new_xyz.view(B, S, 1, C) 132 | torch.cuda.empty_cache() 133 | 134 | if points is not None: 135 | grouped_points = index_points(points, idx) 136 | new_points = torch.cat([grouped_xyz_norm, grouped_points], dim=-1) # [B, npoint, nsample, C+D] 137 | else: 138 | new_points = grouped_xyz_norm 139 | if returnfps: 140 | return new_xyz, new_points, grouped_xyz, fps_idx 141 | else: 142 | return new_xyz, new_points 143 | 144 | 145 | def sample_and_group_all(xyz, points): 146 | """ 147 | Input: 148 | xyz: input points position data, [B, N, 3] 149 | points: input points data, [B, N, D] 150 | Return: 151 | new_xyz: sampled points position data, [B, 1, 3] 152 | new_points: sampled points data, [B, 1, N, 3+D] 153 | """ 154 | device = xyz.device 155 | B, N, C = xyz.shape 156 | new_xyz = torch.zeros(B, 1, C).to(device) 157 | grouped_xyz = xyz.view(B, 1, N, C) 158 | if points is not None: 159 | new_points = torch.cat([grouped_xyz, points.view(B, 1, N, -1)], dim=-1) 160 | else: 161 | new_points = grouped_xyz 162 | return new_xyz, new_points 163 | 164 | 165 | class PointNetSetAbstraction(nn.Module): 166 | def __init__(self, npoint, radius, nsample, in_channel, mlp, group_all, knn=False): 167 | super(PointNetSetAbstraction, self).__init__() 168 | self.npoint = npoint 169 | self.radius = radius 170 | self.nsample = nsample 171 | self.knn = knn 172 | self.mlp_convs = nn.ModuleList() 173 | self.mlp_bns = nn.ModuleList() 174 | last_channel = in_channel 175 | for out_channel in mlp: 176 | self.mlp_convs.append(nn.Conv2d(last_channel, out_channel, 1)) 177 | self.mlp_bns.append(nn.BatchNorm2d(out_channel)) 178 | last_channel = out_channel 179 | self.group_all = group_all 180 | 181 | def forward(self, xyz, points): 182 | """ 183 | Input: 184 | xyz: input points position data, [B, N, C] 185 | points: input points data, [B, N, C] 186 | Return: 187 | new_xyz: sampled points position data, [B, S, C] 188 | new_points_concat: sample points feature data, [B, S, D'] 189 | """ 190 | if self.group_all: 191 | new_xyz, new_points = sample_and_group_all(xyz, points) 192 | else: 193 | new_xyz, new_points = sample_and_group(self.npoint, self.radius, self.nsample, xyz, points, knn=self.knn) 194 | # new_xyz: sampled points position data, [B, npoint, C] 195 | # new_points: sampled points data, [B, npoint, nsample, C+D] 
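        # PointNet aggregation: move channels to dim 1 for the shared Conv2d/BatchNorm MLP,
        # then max-pool over the nsample neighbours of each centroid (dim 2 after the permute).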
196 | new_points = new_points.permute(0, 3, 2, 1) # [B, C+D, nsample,npoint] 197 | for i, conv in enumerate(self.mlp_convs): 198 | bn = self.mlp_bns[i] 199 | new_points = F.relu(bn(conv(new_points))) 200 | 201 | new_points = torch.max(new_points, 2)[0].transpose(1, 2) 202 | return new_xyz, new_points 203 | 204 | 205 | class PointNetSetAbstractionMsg(nn.Module): 206 | def __init__(self, npoint, radius_list, nsample_list, in_channel, mlp_list, knn=False): 207 | super(PointNetSetAbstractionMsg, self).__init__() 208 | self.npoint = npoint 209 | self.radius_list = radius_list 210 | self.nsample_list = nsample_list 211 | self.knn = knn 212 | self.conv_blocks = nn.ModuleList() 213 | self.bn_blocks = nn.ModuleList() 214 | for i in range(len(mlp_list)): 215 | convs = nn.ModuleList() 216 | bns = nn.ModuleList() 217 | last_channel = in_channel + 3 218 | for out_channel in mlp_list[i]: 219 | convs.append(nn.Conv2d(last_channel, out_channel, 1)) 220 | bns.append(nn.BatchNorm2d(out_channel)) 221 | last_channel = out_channel 222 | self.conv_blocks.append(convs) 223 | self.bn_blocks.append(bns) 224 | 225 | def forward(self, xyz, points, seed_idx=None): 226 | """ 227 | Input: 228 | xyz: input points position data, [B, C, N] 229 | points: input points data, [B, D, N] 230 | Return: 231 | new_xyz: sampled points position data, [B, C, S] 232 | new_points_concat: sample points feature data, [B, D', S] 233 | """ 234 | 235 | B, N, C = xyz.shape 236 | S = self.npoint 237 | new_xyz = index_points(xyz, farthest_point_sample(xyz, S) if seed_idx is None else seed_idx) 238 | new_points_list = [] 239 | for i, radius in enumerate(self.radius_list): 240 | K = self.nsample_list[i] 241 | if self.knn: 242 | dists = square_distance(new_xyz, xyz) # B x npoint x N 243 | group_idx = dists.argsort()[:, :, :K] # B x npoint x K 244 | else: 245 | group_idx = query_ball_point(radius, K, xyz, new_xyz) 246 | grouped_xyz = index_points(xyz, group_idx) 247 | grouped_xyz -= new_xyz.view(B, S, 1, C) 248 | if points is not None: 249 | grouped_points = index_points(points, group_idx) 250 | grouped_points = torch.cat([grouped_points, grouped_xyz], dim=-1) 251 | else: 252 | grouped_points = grouped_xyz 253 | 254 | grouped_points = grouped_points.permute(0, 3, 2, 1) # [B, D, K, S] 255 | for j in range(len(self.conv_blocks[i])): 256 | conv = self.conv_blocks[i][j] 257 | bn = self.bn_blocks[i][j] 258 | grouped_points = F.relu(bn(conv(grouped_points))) 259 | new_points = torch.max(grouped_points, 2)[0] # [B, D', S] 260 | new_points_list.append(new_points) 261 | 262 | new_points_concat = torch.cat(new_points_list, dim=1).transpose(1, 2) 263 | return new_xyz, new_points_concat 264 | 265 | 266 | # NoteL this function swaps N and C 267 | class PointNetFeaturePropagation(nn.Module): 268 | def __init__(self, in_channel, mlp): 269 | super(PointNetFeaturePropagation, self).__init__() 270 | self.mlp_convs = nn.ModuleList() 271 | self.mlp_bns = nn.ModuleList() 272 | last_channel = in_channel 273 | for out_channel in mlp: 274 | self.mlp_convs.append(nn.Conv1d(last_channel, out_channel, 1)) 275 | self.mlp_bns.append(nn.BatchNorm1d(out_channel)) 276 | last_channel = out_channel 277 | 278 | def forward(self, xyz1, xyz2, points1, points2): 279 | """ 280 | Input: 281 | xyz1: input points position data, [B, C, N] 282 | xyz2: sampled input points position data, [B, C, S] 283 | points1: input points data, [B, D, N] 284 | points2: input points data, [B, D, S] 285 | Return: 286 | new_points: upsampled points data, [B, D', N] 287 | """ 288 | xyz1 = xyz1.permute(0, 2, 
1) 289 | xyz2 = xyz2.permute(0, 2, 1) 290 | 291 | points2 = points2.permute(0, 2, 1) 292 | B, N, C = xyz1.shape 293 | _, S, _ = xyz2.shape 294 | 295 | if S == 1: 296 | interpolated_points = points2.repeat(1, N, 1) 297 | else: 298 | dists = square_distance(xyz1, xyz2) 299 | dists, idx = dists.sort(dim=-1) 300 | dists, idx = dists[:, :, :3], idx[:, :, :3] # [B, N, 3] 301 | 302 | dist_recip = 1.0 / (dists + 1e-8) 303 | norm = torch.sum(dist_recip, dim=2, keepdim=True) 304 | weight = dist_recip / norm 305 | interpolated_points = torch.sum(index_points(points2, idx) * weight.view(B, N, 3, 1), dim=2) 306 | 307 | if points1 is not None: 308 | points1 = points1.permute(0, 2, 1) 309 | new_points = torch.cat([points1, interpolated_points], dim=-1) 310 | else: 311 | new_points = interpolated_points 312 | 313 | new_points = new_points.permute(0, 2, 1) 314 | for i, conv in enumerate(self.mlp_convs): 315 | bn = self.mlp_bns[i] 316 | new_points = F.relu(bn(conv(new_points))) 317 | return new_points -------------------------------------------------------------------------------- /src/structformer/utils/rotation_continuity.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch.autograd import Variable 4 | import numpy as np 5 | 6 | # Code adapted from the rotation continuity repo (https://github.com/papagina/RotationContinuity) 7 | 8 | #T_poses num*3 9 | #r_matrix batch*3*3 10 | def compute_pose_from_rotation_matrix(T_pose, r_matrix): 11 | batch=r_matrix.shape[0] 12 | joint_num = T_pose.shape[0] 13 | r_matrices = r_matrix.view(batch,1, 3,3).expand(batch,joint_num, 3,3).contiguous().view(batch*joint_num,3,3) 14 | src_poses = T_pose.view(1,joint_num,3,1).expand(batch,joint_num,3,1).contiguous().view(batch*joint_num,3,1) 15 | 16 | out_poses = torch.matmul(r_matrices, src_poses) #(batch*joint_num)*3*1 17 | 18 | return out_poses.view(batch, joint_num, 3) 19 | 20 | # batch*n 21 | def normalize_vector( v, return_mag =False): 22 | batch=v.shape[0] 23 | v_mag = torch.sqrt(v.pow(2).sum(1))# batch 24 | v_mag = torch.max(v_mag, torch.autograd.Variable(torch.FloatTensor([1e-8]).cuda())) 25 | v_mag = v_mag.view(batch,1).expand(batch,v.shape[1]) 26 | v = v/v_mag 27 | if(return_mag==True): 28 | return v, v_mag[:,0] 29 | else: 30 | return v 31 | 32 | # u, v batch*n 33 | def cross_product( u, v): 34 | batch = u.shape[0] 35 | #print (u.shape) 36 | #print (v.shape) 37 | i = u[:,1]*v[:,2] - u[:,2]*v[:,1] 38 | j = u[:,2]*v[:,0] - u[:,0]*v[:,2] 39 | k = u[:,0]*v[:,1] - u[:,1]*v[:,0] 40 | 41 | out = torch.cat((i.view(batch,1), j.view(batch,1), k.view(batch,1)),1)#batch*3 42 | 43 | return out 44 | 45 | 46 | #poses batch*6 47 | #poses 48 | def compute_rotation_matrix_from_ortho6d(ortho6d): 49 | x_raw = ortho6d[:,0:3]#batch*3 50 | y_raw = ortho6d[:,3:6]#batch*3 51 | 52 | x = normalize_vector(x_raw) #batch*3 53 | z = cross_product(x,y_raw) #batch*3 54 | z = normalize_vector(z)#batch*3 55 | y = cross_product(z,x)#batch*3 56 | 57 | x = x.view(-1,3,1) 58 | y = y.view(-1,3,1) 59 | z = z.view(-1,3,1) 60 | matrix = torch.cat((x,y,z), 2) #batch*3*3 61 | return matrix 62 | 63 | 64 | #in batch*6 65 | #out batch*5 66 | def stereographic_project(a): 67 | dim = a.shape[1] 68 | a = normalize_vector(a) 69 | out = a[:,0:dim-1]/(1-a[:,dim-1]) 70 | return out 71 | 72 | 73 | 74 | #in a batch*5, axis int 75 | def stereographic_unproject(a, axis=None): 76 | """ 77 | Inverse of stereographic projection: increases dimension by one. 
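    Maps each n-dimensional input onto the unit n-sphere embedded in R^(n+1); `axis` selects
    which output coordinate receives the (s2-1)/(s2+1) component (by default the last one).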
78 | """ 79 | batch=a.shape[0] 80 | if axis is None: 81 | axis = a.shape[1] 82 | s2 = torch.pow(a,2).sum(1) #batch 83 | ans = torch.autograd.Variable(torch.zeros(batch, a.shape[1]+1).cuda()) #batch*6 84 | unproj = 2*a/(s2+1).view(batch,1).repeat(1,a.shape[1]) #batch*5 85 | if(axis>0): 86 | ans[:,:axis] = unproj[:,:axis] #batch*(axis-0) 87 | ans[:,axis] = (s2-1)/(s2+1) #batch 88 | ans[:,axis+1:] = unproj[:,axis:] #batch*(5-axis) # Note that this is a no-op if the default option (last axis) is used 89 | return ans 90 | 91 | 92 | #a batch*5 93 | #out batch*3*3 94 | def compute_rotation_matrix_from_ortho5d(a): 95 | batch = a.shape[0] 96 | proj_scale_np = np.array([np.sqrt(2)+1, np.sqrt(2)+1, np.sqrt(2)]) #3 97 | proj_scale = torch.autograd.Variable(torch.FloatTensor(proj_scale_np).cuda()).view(1,3).repeat(batch,1) #batch,3 98 | 99 | u = stereographic_unproject(a[:, 2:5] * proj_scale, axis=0)#batch*4 100 | norm = torch.sqrt(torch.pow(u[:,1:],2).sum(1)) #batch 101 | u = u/ norm.view(batch,1).repeat(1,u.shape[1]) #batch*4 102 | b = torch.cat((a[:,0:2], u),1)#batch*6 103 | matrix = compute_rotation_matrix_from_ortho6d(b) 104 | return matrix 105 | 106 | 107 | #quaternion batch*4 108 | def compute_rotation_matrix_from_quaternion( quaternion): 109 | batch=quaternion.shape[0] 110 | 111 | 112 | quat = normalize_vector(quaternion).contiguous() 113 | 114 | qw = quat[...,0].contiguous().view(batch, 1) 115 | qx = quat[...,1].contiguous().view(batch, 1) 116 | qy = quat[...,2].contiguous().view(batch, 1) 117 | qz = quat[...,3].contiguous().view(batch, 1) 118 | 119 | # Unit quaternion rotation matrices computatation 120 | xx = qx*qx 121 | yy = qy*qy 122 | zz = qz*qz 123 | xy = qx*qy 124 | xz = qx*qz 125 | yz = qy*qz 126 | xw = qx*qw 127 | yw = qy*qw 128 | zw = qz*qw 129 | 130 | row0 = torch.cat((1-2*yy-2*zz, 2*xy - 2*zw, 2*xz + 2*yw), 1) #batch*3 131 | row1 = torch.cat((2*xy+ 2*zw, 1-2*xx-2*zz, 2*yz-2*xw ), 1) #batch*3 132 | row2 = torch.cat((2*xz-2*yw, 2*yz+2*xw, 1-2*xx-2*yy), 1) #batch*3 133 | 134 | matrix = torch.cat((row0.view(batch, 1, 3), row1.view(batch,1,3), row2.view(batch,1,3)),1) #batch*3*3 135 | 136 | return matrix 137 | 138 | #axisAngle batch*4 angle, x,y,z 139 | def compute_rotation_matrix_from_axisAngle( axisAngle): 140 | batch = axisAngle.shape[0] 141 | 142 | theta = torch.tanh(axisAngle[:,0])*np.pi #[-180, 180] 143 | sin = torch.sin(theta*0.5) 144 | axis = normalize_vector(axisAngle[:,1:4]) #batch*3 145 | qw = torch.cos(theta*0.5) 146 | qx = axis[:,0]*sin 147 | qy = axis[:,1]*sin 148 | qz = axis[:,2]*sin 149 | 150 | # Unit quaternion rotation matrices computatation 151 | xx = (qx*qx).view(batch,1) 152 | yy = (qy*qy).view(batch,1) 153 | zz = (qz*qz).view(batch,1) 154 | xy = (qx*qy).view(batch,1) 155 | xz = (qx*qz).view(batch,1) 156 | yz = (qy*qz).view(batch,1) 157 | xw = (qx*qw).view(batch,1) 158 | yw = (qy*qw).view(batch,1) 159 | zw = (qz*qw).view(batch,1) 160 | 161 | row0 = torch.cat((1-2*yy-2*zz, 2*xy - 2*zw, 2*xz + 2*yw), 1) #batch*3 162 | row1 = torch.cat((2*xy+ 2*zw, 1-2*xx-2*zz, 2*yz-2*xw ), 1) #batch*3 163 | row2 = torch.cat((2*xz-2*yw, 2*yz+2*xw, 1-2*xx-2*yy), 1) #batch*3 164 | 165 | matrix = torch.cat((row0.view(batch, 1, 3), row1.view(batch,1,3), row2.view(batch,1,3)),1) #batch*3*3 166 | 167 | return matrix 168 | 169 | #axisAngle batch*3 (x,y,z)*theta 170 | def compute_rotation_matrix_from_Rodriguez( rod): 171 | batch = rod.shape[0] 172 | 173 | axis, theta = normalize_vector(rod, return_mag=True) 174 | 175 | sin = torch.sin(theta) 176 | 177 | 178 | qw = torch.cos(theta) 179 | qx = 
axis[:,0]*sin 180 | qy = axis[:,1]*sin 181 | qz = axis[:,2]*sin 182 | 183 | # Unit quaternion rotation matrices computatation 184 | xx = (qx*qx).view(batch,1) 185 | yy = (qy*qy).view(batch,1) 186 | zz = (qz*qz).view(batch,1) 187 | xy = (qx*qy).view(batch,1) 188 | xz = (qx*qz).view(batch,1) 189 | yz = (qy*qz).view(batch,1) 190 | xw = (qx*qw).view(batch,1) 191 | yw = (qy*qw).view(batch,1) 192 | zw = (qz*qw).view(batch,1) 193 | 194 | row0 = torch.cat((1-2*yy-2*zz, 2*xy - 2*zw, 2*xz + 2*yw), 1) #batch*3 195 | row1 = torch.cat((2*xy+ 2*zw, 1-2*xx-2*zz, 2*yz-2*xw ), 1) #batch*3 196 | row2 = torch.cat((2*xz-2*yw, 2*yz+2*xw, 1-2*xx-2*yy), 1) #batch*3 197 | 198 | matrix = torch.cat((row0.view(batch, 1, 3), row1.view(batch,1,3), row2.view(batch,1,3)),1) #batch*3*3 199 | 200 | return matrix 201 | 202 | #axisAngle batch*3 a,b,c 203 | def compute_rotation_matrix_from_hopf( hopf): 204 | batch = hopf.shape[0] 205 | 206 | theta = (torch.tanh(hopf[:,0])+1.0)*np.pi/2.0 #[0, pi] 207 | phi = (torch.tanh(hopf[:,1])+1.0)*np.pi #[0,2pi) 208 | tao = (torch.tanh(hopf[:,2])+1.0)*np.pi #[0,2pi) 209 | 210 | qw = torch.cos(theta/2)*torch.cos(tao/2) 211 | qx = torch.cos(theta/2)*torch.sin(tao/2) 212 | qy = torch.sin(theta/2)*torch.cos(phi+tao/2) 213 | qz = torch.sin(theta/2)*torch.sin(phi+tao/2) 214 | 215 | # Unit quaternion rotation matrices computatation 216 | xx = (qx*qx).view(batch,1) 217 | yy = (qy*qy).view(batch,1) 218 | zz = (qz*qz).view(batch,1) 219 | xy = (qx*qy).view(batch,1) 220 | xz = (qx*qz).view(batch,1) 221 | yz = (qy*qz).view(batch,1) 222 | xw = (qx*qw).view(batch,1) 223 | yw = (qy*qw).view(batch,1) 224 | zw = (qz*qw).view(batch,1) 225 | 226 | row0 = torch.cat((1-2*yy-2*zz, 2*xy - 2*zw, 2*xz + 2*yw), 1) #batch*3 227 | row1 = torch.cat((2*xy+ 2*zw, 1-2*xx-2*zz, 2*yz-2*xw ), 1) #batch*3 228 | row2 = torch.cat((2*xz-2*yw, 2*yz+2*xw, 1-2*xx-2*yy), 1) #batch*3 229 | 230 | matrix = torch.cat((row0.view(batch, 1, 3), row1.view(batch,1,3), row2.view(batch,1,3)),1) #batch*3*3 231 | 232 | return matrix 233 | 234 | 235 | #euler batch*4 236 | #output cuda batch*3*3 matrices in the rotation order of XZ'Y'' (intrinsic) or YZX (extrinsic) 237 | def compute_rotation_matrix_from_euler(euler): 238 | batch=euler.shape[0] 239 | 240 | c1=torch.cos(euler[:,0]).view(batch,1)#batch*1 241 | s1=torch.sin(euler[:,0]).view(batch,1)#batch*1 242 | c2=torch.cos(euler[:,2]).view(batch,1)#batch*1 243 | s2=torch.sin(euler[:,2]).view(batch,1)#batch*1 244 | c3=torch.cos(euler[:,1]).view(batch,1)#batch*1 245 | s3=torch.sin(euler[:,1]).view(batch,1)#batch*1 246 | 247 | row1=torch.cat((c2*c3, -s2, c2*s3 ), 1).view(-1,1,3) #batch*1*3 248 | row2=torch.cat((c1*s2*c3+s1*s3, c1*c2, c1*s2*s3-s1*c3), 1).view(-1,1,3) #batch*1*3 249 | row3=torch.cat((s1*s2*c3-c1*s3, s1*c2, s1*s2*s3+c1*c3), 1).view(-1,1,3) #batch*1*3 250 | 251 | matrix = torch.cat((row1, row2, row3), 1) #batch*3*3 252 | 253 | 254 | return matrix 255 | 256 | 257 | #euler_sin_cos batch*6 258 | #output cuda batch*3*3 matrices in the rotation order of XZ'Y'' (intrinsic) or YZX (extrinsic) 259 | def compute_rotation_matrix_from_euler_sin_cos(euler_sin_cos): 260 | batch=euler_sin_cos.shape[0] 261 | 262 | s1 = euler_sin_cos[:,0].view(batch,1) 263 | c1 = euler_sin_cos[:,1].view(batch,1) 264 | s2 = euler_sin_cos[:,2].view(batch,1) 265 | c2 = euler_sin_cos[:,3].view(batch,1) 266 | s3 = euler_sin_cos[:,4].view(batch,1) 267 | c3 = euler_sin_cos[:,5].view(batch,1) 268 | 269 | 270 | row1=torch.cat((c2*c3, -s2, c2*s3 ), 1).view(-1,1,3) #batch*1*3 271 | row2=torch.cat((c1*s2*c3+s1*s3, c1*c2, 
c1*s2*s3-s1*c3), 1).view(-1,1,3) #batch*1*3 272 | row3=torch.cat((s1*s2*c3-c1*s3, s1*c2, s1*s2*s3+c1*c3), 1).view(-1,1,3) #batch*1*3 273 | 274 | matrix = torch.cat((row1, row2, row3), 1) #batch*3*3 275 | 276 | 277 | return matrix 278 | 279 | 280 | #matrices batch*3*3 281 | #both matrix are orthogonal rotation matrices 282 | #out theta between 0 to 180 degree batch 283 | def compute_geodesic_distance_from_two_matrices(m1, m2): 284 | batch=m1.shape[0] 285 | m = torch.bmm(m1, m2.transpose(1,2)) #batch*3*3 286 | 287 | cos = ( m[:,0,0] + m[:,1,1] + m[:,2,2] - 1 )/2 288 | cos = torch.min(cos, torch.autograd.Variable(torch.ones(batch).cuda()) ) 289 | cos = torch.max(cos, torch.autograd.Variable(torch.ones(batch).cuda())*-1 ) 290 | 291 | 292 | theta = torch.acos(cos) 293 | 294 | #theta = torch.min(theta, 2*np.pi - theta) 295 | 296 | 297 | return theta 298 | 299 | 300 | #matrices batch*3*3 301 | #both matrix are orthogonal rotation matrices 302 | #out theta between 0 to 180 degree batch 303 | def compute_angle_from_r_matrices(m): 304 | 305 | batch=m.shape[0] 306 | 307 | cos = ( m[:,0,0] + m[:,1,1] + m[:,2,2] - 1 )/2 308 | cos = torch.min(cos, torch.autograd.Variable(torch.ones(batch).cuda()) ) 309 | cos = torch.max(cos, torch.autograd.Variable(torch.ones(batch).cuda())*-1 ) 310 | 311 | theta = torch.acos(cos) 312 | 313 | return theta 314 | 315 | def get_sampled_rotation_matrices_by_quat(batch): 316 | #quat = torch.autograd.Variable(torch.rand(batch,4).cuda()) 317 | quat = torch.autograd.Variable(torch.randn(batch, 4).cuda()) 318 | matrix = compute_rotation_matrix_from_quaternion(quat) 319 | return matrix 320 | 321 | def get_sampled_rotation_matrices_by_hpof(batch): 322 | 323 | theta = torch.autograd.Variable(torch.FloatTensor(np.random.uniform(0,1, batch)*np.pi).cuda()) #[0, pi] 324 | phi = torch.autograd.Variable(torch.FloatTensor(np.random.uniform(0,2,batch)*np.pi).cuda()) #[0,2pi) 325 | tao = torch.autograd.Variable(torch.FloatTensor(np.random.uniform(0,2,batch)*np.pi).cuda()) #[0,2pi) 326 | 327 | 328 | qw = torch.cos(theta/2)*torch.cos(tao/2) 329 | qx = torch.cos(theta/2)*torch.sin(tao/2) 330 | qy = torch.sin(theta/2)*torch.cos(phi+tao/2) 331 | qz = torch.sin(theta/2)*torch.sin(phi+tao/2) 332 | 333 | # Unit quaternion rotation matrices computatation 334 | xx = (qx*qx).view(batch,1) 335 | yy = (qy*qy).view(batch,1) 336 | zz = (qz*qz).view(batch,1) 337 | xy = (qx*qy).view(batch,1) 338 | xz = (qx*qz).view(batch,1) 339 | yz = (qy*qz).view(batch,1) 340 | xw = (qx*qw).view(batch,1) 341 | yw = (qy*qw).view(batch,1) 342 | zw = (qz*qw).view(batch,1) 343 | 344 | row0 = torch.cat((1-2*yy-2*zz, 2*xy - 2*zw, 2*xz + 2*yw), 1) #batch*3 345 | row1 = torch.cat((2*xy+ 2*zw, 1-2*xx-2*zz, 2*yz-2*xw ), 1) #batch*3 346 | row2 = torch.cat((2*xz-2*yw, 2*yz+2*xw, 1-2*xx-2*yy), 1) #batch*3 347 | 348 | matrix = torch.cat((row0.view(batch, 1, 3), row1.view(batch,1,3), row2.view(batch,1,3)),1) #batch*3*3 349 | 350 | return matrix 351 | 352 | #axisAngle batch*4 angle, x,y,z 353 | def get_sampled_rotation_matrices_by_axisAngle( batch, return_quaternion=False): 354 | 355 | theta = torch.autograd.Variable(torch.FloatTensor(np.random.uniform(-1,1, batch)*np.pi).cuda()) #[0, pi] #[-180, 180] 356 | sin = torch.sin(theta) 357 | axis = torch.autograd.Variable(torch.randn(batch, 3).cuda()) 358 | axis = normalize_vector(axis) #batch*3 359 | qw = torch.cos(theta) 360 | qx = axis[:,0]*sin 361 | qy = axis[:,1]*sin 362 | qz = axis[:,2]*sin 363 | 364 | quaternion = torch.cat((qw.view(batch,1), qx.view(batch,1), qy.view(batch,1), 
qz.view(batch,1)), 1 ) 365 | 366 | # Unit quaternion rotation matrices computation 367 | xx = (qx*qx).view(batch,1) 368 | yy = (qy*qy).view(batch,1) 369 | zz = (qz*qz).view(batch,1) 370 | xy = (qx*qy).view(batch,1) 371 | xz = (qx*qz).view(batch,1) 372 | yz = (qy*qz).view(batch,1) 373 | xw = (qx*qw).view(batch,1) 374 | yw = (qy*qw).view(batch,1) 375 | zw = (qz*qw).view(batch,1) 376 | 377 | row0 = torch.cat((1-2*yy-2*zz, 2*xy - 2*zw, 2*xz + 2*yw), 1) #batch*3 378 | row1 = torch.cat((2*xy+ 2*zw, 1-2*xx-2*zz, 2*yz-2*xw ), 1) #batch*3 379 | row2 = torch.cat((2*xz-2*yw, 2*yz+2*xw, 1-2*xx-2*yy), 1) #batch*3 380 | 381 | matrix = torch.cat((row0.view(batch, 1, 3), row1.view(batch,1,3), row2.view(batch,1,3)),1) #batch*3*3 382 | 383 | if(return_quaternion==True): 384 | return matrix, quaternion 385 | else: 386 | return matrix 387 | 388 | 389 | 390 | 391 | 392 | 393 | 394 | 395 | 396 | --------------------------------------------------------------------------------
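For reference, a minimal usage sketch (not part of the repository) of how these rotation-continuity utilities fit together: a network regresses a 6-vector, compute_rotation_matrix_from_ortho6d orthonormalizes it into a valid rotation matrix, and compute_geodesic_distance_from_two_matrices gives a training loss against a ground-truth rotation. The import path assumes the package layout above, and, as written, the helpers allocate CUDA tensors internally, so a GPU is required.

import torch
from structformer.utils.rotation_continuity import (
    get_sampled_rotation_matrices_by_axisAngle,
    compute_rotation_matrix_from_ortho6d,
    compute_geodesic_distance_from_two_matrices,
)

batch = 8
# Ground-truth rotations sampled over axis-angle, shape batch*3*3 (on GPU).
gt = get_sampled_rotation_matrices_by_axisAngle(batch)

# In training this 6-vector would come from a regression head; here we mimic it
# by perturbing the first two columns of the ground-truth matrices.
ortho6d = torch.cat([gt[:, :, 0], gt[:, :, 1]], dim=1) + 0.05 * torch.randn(batch, 6).cuda()

pred = compute_rotation_matrix_from_ortho6d(ortho6d)  # batch*3*3, always a valid rotation
loss = compute_geodesic_distance_from_two_matrices(pred, gt).mean()  # mean geodesic angle in radians

Because the orthonormalization step always returns a valid rotation matrix, the 6D representation remains continuous, which is the property this module is built around.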