├── .gitignore
├── LICENSE
├── README.md
├── configs
│   ├── binary_model.yaml
│   ├── data
│   │   ├── circle_dirs.yaml
│   │   ├── dinner_dirs.yaml
│   │   ├── line_dirs.yaml
│   │   ├── test_tower_dirs.yaml
│   │   └── tower_dirs.yaml
│   ├── object_selection_network.yaml
│   ├── structformer.yaml
│   ├── structformer_no_encoder.yaml
│   └── structformer_no_structure.yaml
├── doc
│   └── rearrange_mugs.gif
├── pyproject.toml
├── requirements.txt
├── scripts
│   └── run_full_pipeline.py
├── setup.cfg
├── setup.py
└── src
    └── structformer
        ├── __init__.py
        ├── data
        │   ├── __init__.py
        │   ├── binary_dataset.py
        │   ├── object_set_refer_dataset.py
        │   ├── sequence_dataset.py
        │   └── tokenizer.py
        ├── evaluation
        │   ├── __init__.py
        │   ├── inference.py
        │   ├── test_binary_model.py
        │   ├── test_object_selection_network.py
        │   ├── test_structformer.py
        │   ├── test_structformer_no_encoder.py
        │   └── test_structformer_no_structure.py
        ├── models
        │   ├── __init__.py
        │   ├── object_selection_network.py
        │   ├── point_transformer.py
        │   └── pose_generation_network.py
        ├── training
        │   ├── __init__.py
        │   ├── train_binary_model.py
        │   ├── train_object_selection_network.py
        │   ├── train_structformer.py
        │   ├── train_structformer_no_encoder.py
        │   └── train_structformer_no_structure.py
        └── utils
            ├── __init__.py
            ├── brain2
            │   ├── __init__.py
            │   ├── camera.py
            │   ├── image.py
            │   └── pose.py
            ├── pointnet.py
            ├── rearrangement.py
            ├── rotation_continuity.py
            └── transformations.py
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | pip-wheel-metadata/
24 | share/python-wheels/
25 | *.egg-info/
26 | .installed.cfg
27 | *.egg
28 | MANIFEST
29 |
30 | # PyInstaller
31 | # Usually these files are written by a python script from a template
32 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
33 | *.manifest
34 | *.spec
35 |
36 | # Installer logs
37 | pip-log.txt
38 | pip-delete-this-directory.txt
39 |
40 | # Unit test / coverage reports
41 | htmlcov/
42 | .tox/
43 | .nox/
44 | .coverage
45 | .coverage.*
46 | .cache
47 | nosetests.xml
48 | coverage.xml
49 | *.cover
50 | *.py,cover
51 | .hypothesis/
52 | .pytest_cache/
53 |
54 | # Translations
55 | *.mo
56 | *.pot
57 |
58 | # Django stuff:
59 | *.log
60 | local_settings.py
61 | db.sqlite3
62 | db.sqlite3-journal
63 |
64 | # Flask stuff:
65 | instance/
66 | .webassets-cache
67 |
68 | # Scrapy stuff:
69 | .scrapy
70 |
71 | # Sphinx documentation
72 | docs/_build/
73 |
74 | # PyBuilder
75 | target/
76 |
77 | # Jupyter Notebook
78 | .ipynb_checkpoints
79 |
80 | # IPython
81 | profile_default/
82 | ipython_config.py
83 |
84 | # pyenv
85 | .python-version
86 |
87 | # pipenv
88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
91 | # install all needed dependencies.
92 | #Pipfile.lock
93 |
94 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow
95 | __pypackages__/
96 |
97 | # Celery stuff
98 | celerybeat-schedule
99 | celerybeat.pid
100 |
101 | # SageMath parsed files
102 | *.sage.py
103 |
104 | # Environments
105 | .env
106 | .venv
107 | env/
108 | venv/
109 | ENV/
110 | env.bak/
111 | venv.bak/
112 |
113 | # Vim temporary files
114 | *.swp
115 | *.swo
116 |
117 | # Spyder project settings
118 | .spyderproject
119 | .spyproject
120 |
121 | # Rope project settings
122 | .ropeproject
123 |
124 | # mkdocs documentation
125 | /site
126 |
127 | # mypy
128 | .mypy_cache/
129 | .dmypy.json
130 | dmypy.json
131 |
132 | # Pyre type checker
133 | .pyre/
134 |
135 | # Custom
136 | /experiments
137 | /models
138 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | NVIDIA Source Code License for StructFormer
2 | 1. Definitions
3 | “Licensor” means any person or entity that distributes its Work.
4 | “Software” means the original work of authorship made available under this License.
5 | “Work” means the Software and any additions to or derivative works of the Software that are made available under this License.
6 | The terms “reproduce,” “reproduction,” “derivative works,” and “distribution” have the meaning as provided under U.S. copyright law; provided, however, that for the purposes of this License, derivative works shall not include works that remain separable from, or merely link (or bind by name) to the interfaces of, the Work.
7 | Works, including the Software, are “made available” under this License by including in or with the Work either (a) a copyright notice referencing the applicability of this License to the Work, or (b) a copy of this License.
8 | 2. License Grant
9 | 2.1 Copyright Grant. Subject to the terms and conditions of this License, each Licensor grants to you a perpetual, worldwide, non-exclusive, royalty-free, copyright license to reproduce, prepare derivative works of, publicly display, publicly perform, sublicense and distribute its Work and any resulting derivative works in any form.
10 | 3. Limitations
11 | 3.1 Redistribution. You may reproduce or distribute the Work only if (a) you do so under this License, (b) you include a complete copy of this License with your distribution, and (c) you retain without modification any copyright, patent, trademark, or attribution notices that are present in the Work.
12 | 3.2 Derivative Works. You may specify that additional or different terms apply to the use, reproduction, and distribution of your derivative works of the Work (“Your Terms”) only if (a) Your Terms provide that the use limitation in Section 3.3 applies to your derivative works, and (b) you identify the specific derivative works that are subject to Your Terms. Notwithstanding Your Terms, this License (including the redistribution requirements in Section 3.1) will continue to apply to the Work itself.
13 | 3.3 Use Limitation. The Work and any derivative works thereof only may be used or intended for use non-commercially. Notwithstanding the foregoing, NVIDIA and its affiliates may use the Work and any derivative works commercially. As used herein, “non-commercially” means for research or evaluation purposes only.
14 | 3.4 Patent Claims. If you bring or threaten to bring a patent claim against any Licensor (including any claim, cross-claim or counterclaim in a lawsuit) to enforce any patents that you allege are infringed by any Work, then your rights under this License from such Licensor (including the grant in Section 2.1) will terminate immediately.
15 | 3.5 Trademarks. This License does not grant any rights to use any Licensor’s or its affiliates’ names, logos, or trademarks, except as necessary to reproduce the notices described in this License.
16 | 3.6 Termination. If you violate any term of this License, then your rights under this License (including the grant in Section 2.1) will terminate immediately.
17 | 4. Disclaimer of Warranty.
18 | THE WORK IS PROVIDED “AS IS” WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WARRANTIES OR CONDITIONS OF
19 | MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE, TITLE OR NON-INFRINGEMENT. YOU BEAR THE RISK OF UNDERTAKING ANY ACTIVITIES UNDER THIS LICENSE.
20 | 5. Limitation of Liability.
21 | EXCEPT AS PROHIBITED BY APPLICABLE LAW, IN NO EVENT AND UNDER NO LEGAL THEORY, WHETHER IN TORT (INCLUDING NEGLIGENCE), CONTRACT, OR OTHERWISE SHALL ANY LICENSOR BE LIABLE TO YOU FOR DAMAGES, INCLUDING ANY DIRECT, INDIRECT, SPECIAL, INCIDENTAL, OR CONSEQUENTIAL DAMAGES ARISING OUT OF OR RELATED TO THIS LICENSE, THE USE OR INABILITY TO USE THE WORK (INCLUDING BUT NOT LIMITED TO LOSS OF GOODWILL, BUSINESS INTERRUPTION, LOST PROFITS OR DATA, COMPUTER FAILURE OR MALFUNCTION, OR ANY OTHER COMMERCIAL DAMAGES OR LOSSES), EVEN IF THE LICENSOR HAS BEEN ADVISED OF THE POSSIBILITY OF SUCH DAMAGES.
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # StructFormer
2 |
3 | PyTorch implementation for ICRA 2022 paper _StructFormer: Learning Spatial Structure for Language-Guided Semantic Rearrangement of Novel Objects_. [[PDF]](https://arxiv.org/abs/2110.10189) [[Video]](https://youtu.be/6NPdpAtMawM) [[Website]](https://sites.google.com/view/structformer)
4 |
5 | StructFormer rearranges unknown objects into semantically meaningful spatial structures based on high-level language instructions and partial-view
6 | point cloud observations of the scene. The model uses multi-modal transformers to predict both which objects to manipulate and where to place them.
7 |
8 |
9 |
10 |
11 |
12 | ## License
13 | The source code is released under the [NVIDIA Source Code License](LICENSE). The dataset is released under [CC BY-NC 4.0](https://creativecommons.org/licenses/by-nc/4.0/).
14 |
15 |
16 | ## Installation
17 |
18 | ```bash
19 | pip install -r requirements.txt
20 | pip install -e .
21 | ```
22 |
23 | ### Notes on Dependencies
24 | - `h5py==2.10`: this specific version is needed.
25 | - `omegaconf==2.1`: the code uses functions that were only added in newer versions of `omegaconf`, so this version (or later) is required.
26 |
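The pinned versions can be sanity-checked after installation (optional):

```bash
python -c "import h5py, omegaconf; print(h5py.__version__, omegaconf.__version__)"  # expect 2.10.0 and 2.1.x
```
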
27 | ### Environments
28 | The code has been tested on Ubuntu 18.04 with NVIDIA driver 460.91, CUDA 11.0, Python 3.6, and PyTorch 1.7.
29 |
30 | ## Organization
31 | Source code in the StructFormer package is organized into:
32 | - data loaders: `data`
33 | - models: `models`
34 | - training scripts: `training`
35 | - inference scripts: `evaluation`
36 |
37 | Parameters for data loaders and models are defined in `OmegaConf` yaml files stored in `configs`.
38 |
39 | Trained models are stored in `/experiments`
40 |
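For reference, a data config can be loaded and resolved the same way `scripts/run_full_pipeline.py` does it internally (a minimal sketch; the dataset path below is a placeholder you need to replace):

```python
import os
import time

from omegaconf import OmegaConf

# the configs name experiment directories via the ${env:DATETIME} interpolation
os.environ["DATETIME"] = time.strftime("%Y%m%d-%H%M%S")

dirs_cfg = OmegaConf.load("configs/data/circle_dirs.yaml")
dirs_cfg.dataset_base_dir = "/path/to/data_new_objects"  # fills in the ??? placeholder
OmegaConf.resolve(dirs_cfg)
print(dirs_cfg.dataset.dirs)       # [.../examples_circle_new_objects/result]
print(dirs_cfg.dataset.vocab_dir)  # .../type_vocabs_coarse.json
```
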
41 | ## Quick Start with Pretrained Models
42 | - Set the package root dir: `export STRUCTFORMER=/path/to/StructFormer`
43 | - Download pretrained models from [this link](https://drive.google.com/file/d/1EsptihJv_lPND902P6CYbe00QW-y-rA4/view?usp=sharing) and unzip to the `$STRUCTFORMER/models` folder
44 | - Download the test split of the dataset from [this link](https://drive.google.com/file/d/1e76qJbBJ2bKYq0JzDSRWZjswySX1ftq_/view?usp=sharing) and unzip to `$STRUCTFORMER/data_new_objects_test_split`
45 |
46 | ### Run StructFormer
47 | ```bash
48 | cd $STRUCTFORMER/scripts/
49 | python run_full_pipeline.py \
50 | --dataset_base_dir $STRUCTFORMER/data_new_objects_test_split \
51 | --object_selection_model_dir $STRUCTFORMER/models/object_selection_network/best_model \
52 | --pose_generation_model_dir $STRUCTFORMER/models/structformer_circle/best_model \
53 | --dirs_config $STRUCTFORMER/configs/data/circle_dirs.yaml
54 | ```
55 |
56 | ### Evaluate Pose Generation Networks
57 |
58 | Where `{model_name}` is one of `structformer`, `structformer_no_encoder`, or `structformer_no_structure`, and `{structure}` is one of `circle`, `line`, `tower`, or `dinner`:
59 |
60 | ```bash
61 | cd $STRUCTFORMER/src/structformer/evaluation/
62 | python test_{model_name}.py \
63 | --dataset_base_dir $STRUCTFORMER/data_new_objects_test_split \
64 | --model_dir $STRUCTFORMER/models/{model_name}_{structure}/best_model \
65 | --dirs_config $STRUCTFORMER/configs/data/{structure}_dirs.yaml
66 | ```
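For example, to evaluate the pretrained StructFormer model on circle structures (the `structformer_circle` model from the quick start above):

```bash
cd $STRUCTFORMER/src/structformer/evaluation/
python test_structformer.py \
  --dataset_base_dir $STRUCTFORMER/data_new_objects_test_split \
  --model_dir $STRUCTFORMER/models/structformer_circle/best_model \
  --dirs_config $STRUCTFORMER/configs/data/circle_dirs.yaml
```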
67 |
68 | ### Evaluate Object Selection Network
69 |
70 | Where `{structure}` is as above:
71 |
72 | ```bash
73 | cd $STRUCTFORMER/src/structformer/evaluation/
74 | python test_object_selection_network.py \
75 | --dataset_base_dir $STRUCTFORMER/data_new_objects_test_split \
76 | --model_dir $STRUCTFORMER/models/object_selection_network/best_model \
77 | --dirs_config $STRUCTFORMER/configs/data/{structure}_dirs.yaml
78 | ```
79 |
80 | ## Training
81 |
82 | - Download the vocabulary list `type_vocabs_coarse.json` from [this link](https://drive.google.com/file/d/1topawwqMSvwE8Ac-8OiwMApEqwYeR5rc/view?usp=sharing) and unzip to `$STRUCTFORMER/data_new_objects`.
83 | - Download all data for the [circle](https://drive.google.com/file/d/1PTGFcAWBrQmlglygNiJz6p7s0rqe2rtP/view?usp=sharing) structure and unzip to `$STRUCTFORMER/data_new_objects`.
84 |
85 | ### Pose Generation Networks
86 |
87 | Where `{model_name}` is one of `structformer`, `structformer_no_encoder`, or `structformer_no_structure`, and `{structure}` is one of `circle`, `line`, `tower`, or `dinner`:
88 |
89 | ```bash
90 | cd $STRUCTFORMER/src/structformer/training/
91 | python train_{model_name}.py \
92 | --dataset_base_dir $STRUCTFORMER/data_new_objects \
93 | --main_config $STRUCTFORMER/configs/{model_name}.yaml \
94 | --dirs_config $STRUCTFORMER/configs/data/{structure}_dirs.yaml
95 | ```
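For example, to train StructFormer on the circle data downloaded above:

```bash
cd $STRUCTFORMER/src/structformer/training/
python train_structformer.py \
  --dataset_base_dir $STRUCTFORMER/data_new_objects \
  --main_config $STRUCTFORMER/configs/structformer.yaml \
  --dirs_config $STRUCTFORMER/configs/data/circle_dirs.yaml
```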
96 |
97 | ### Object Selection Network
98 | ```bash
99 | cd $STRUCTFORMER/src/structformer/training/
100 | python train_object_selection_network.py \
101 | --dataset_base_dir $STRUCTFORMER/data_new_objects \
102 | --main_config $STRUCTFORMER/configs/object_selection_network.yaml \
103 | --dirs_config $STRUCTFORMER/configs/data/circle_dirs.yaml
104 | ```
105 |
106 | ## Citation
107 | If you find our work useful in your research, please cite:
108 | ```
109 | @inproceedings{structformer2022,
110 | title = {StructFormer: Learning Spatial Structure for Language-Guided Semantic Rearrangement of Novel Objects},
111 | author = {Liu, Weiyu and Paxton, Chris and Hermans, Tucker and Fox, Dieter},
112 | year = {2022},
113 | booktitle = {ICRA 2022}
114 | }
115 | ```
--------------------------------------------------------------------------------
/configs/binary_model.yaml:
--------------------------------------------------------------------------------
1 | random_seed: 1
2 | device: 0
3 |
4 | obj_xytheta_relative: False
5 | save_model: True
6 | save_best_model: True
7 |
8 | dataset:
9 | batch_size: 32
10 | max_num_shape_parameters: 5
11 | max_num_objects: 7
12 | max_num_other_objects: 5
13 | max_num_rearrange_features: 0
14 | max_num_anchor_features: 0
15 | num_pts: 1024
16 | num_workers: 4
17 | pin_memory: True
18 |
19 | model:
20 | name: binary_model
21 | num_attention_heads: 8
22 | encoder_hidden_dim: 512
23 | encoder_dropout: 0.0
24 | encoder_activation: relu
25 | encoder_num_layers: 8
26 | object_dropout: 0.1
27 | theta_loss_divide: 3
28 | ignore_rgb: True
29 |
30 | training:
31 | learning_rate: 0.0001
32 | max_epochs: 200
33 | l2: 0.0001
34 | lr_restart: 3000
35 | warmup: 10
--------------------------------------------------------------------------------
/configs/data/circle_dirs.yaml:
--------------------------------------------------------------------------------
1 | experiment_dir: ../../experiments/${env:DATETIME}
2 | dataset_base_dir: ???
3 | dataset:
4 | dirs:
5 | - ${dataset_base_dir}/examples_circle_new_objects/result
6 | index_dirs:
7 | - index_34k
8 | vocab_dir: ${dataset_base_dir}/type_vocabs_coarse.json
--------------------------------------------------------------------------------
/configs/data/dinner_dirs.yaml:
--------------------------------------------------------------------------------
1 | experiment_dir: ../../experiments/${env:DATETIME}
2 | dataset_base_dir: ???
3 | dataset:
4 | dirs:
5 | - ${dataset_base_dir}/examples_dinner_new_objects/result
6 | index_dirs:
7 | - index_24k
8 | vocab_dir: ${dataset_base_dir}/type_vocabs_coarse.json
--------------------------------------------------------------------------------
/configs/data/line_dirs.yaml:
--------------------------------------------------------------------------------
1 | experiment_dir: ../../experiments/${env:DATETIME}
2 | dataset_base_dir: ???
3 | dataset:
4 | dirs:
5 | - ${dataset_base_dir}/examples_line_new_objects/result
6 | index_dirs:
7 | - index_42k
8 | vocab_dir: ${dataset_base_dir}/type_vocabs_coarse.json
--------------------------------------------------------------------------------
/configs/data/test_tower_dirs.yaml:
--------------------------------------------------------------------------------
1 | experiment_dir: ../../experiments/${env:DATETIME}
2 | dataset_base_dir: ???
3 | dataset:
4 | dirs:
5 | - ${dataset_base_dir}/examples_tower_new_objects/result
6 | index_dirs:
7 | - index_13k
8 | vocab_dir: ${dataset_base_dir}/type_vocabs_coarse.json
--------------------------------------------------------------------------------
/configs/data/tower_dirs.yaml:
--------------------------------------------------------------------------------
1 | experiment_dir: ../../experiments/${env:DATETIME}
2 | dataset_base_dir: ???
3 | dataset:
4 | dirs:
5 | - ${dataset_base_dir}/examples_tower_new_objects/result
6 | index_dirs:
7 | - index_25k
8 | vocab_dir: ${dataset_base_dir}/type_vocabs_coarse.json
--------------------------------------------------------------------------------
/configs/object_selection_network.yaml:
--------------------------------------------------------------------------------
1 | random_seed: 1
2 | device: 0
3 |
4 | save_model: True
5 | save_best_model: True
6 |
7 | dataset:
8 | batch_size: 64
9 | max_num_all_objects: 11
10 | max_num_shape_parameters: 5
11 | max_num_rearrange_features: 1
12 | max_num_anchor_features: 3
13 | num_pts: 1024
14 | num_workers: 4
15 | pin_memory: True
16 |
17 | model:
18 | name: object_selection_network
19 | num_attention_heads: 8
20 | encoder_hidden_dim: 32
21 | encoder_dropout: 0.2
22 | encoder_activation: relu
23 | encoder_num_layers: 8
24 | use_focal_loss: True
25 | focal_loss_gamma: 2
26 |
27 | training:
28 | learning_rate: 0.0001
29 | max_epochs: 200
30 | l2: 0.0001
31 | lr_restart: 3000
32 | warmup: 10
--------------------------------------------------------------------------------
/configs/structformer.yaml:
--------------------------------------------------------------------------------
1 | random_seed: 1
2 | device: 0
3 |
4 | obj_xytheta_relative: False
5 | save_model: True
6 | save_best_model: True
7 |
8 | dataset:
9 | batch_size: 32
10 | max_num_shape_parameters: 5
11 | max_num_objects: 7
12 | max_num_other_objects: 5
13 | max_num_rearrange_features: 0
14 | max_num_anchor_features: 0
15 | num_pts: 1024
16 | num_workers: 4
17 | pin_memory: True
18 | use_structure_frame: True
19 |
20 | model:
21 | name: structformer
22 | num_attention_heads: 8
23 | encoder_hidden_dim: 512
24 | encoder_dropout: 0.0
25 | encoder_activation: relu
26 | encoder_num_layers: 8
27 | structure_dropout: 0.5
28 | object_dropout: 0.1
29 | theta_loss_divide: 3
30 | ignore_rgb: True
31 |
32 | training:
33 | learning_rate: 0.0001
34 | max_epochs: 200
35 | l2: 0.0001
36 | lr_restart: 3000
37 | warmup: 10
--------------------------------------------------------------------------------
/configs/structformer_no_encoder.yaml:
--------------------------------------------------------------------------------
1 | random_seed: 1
2 | device: 0
3 |
4 | obj_xytheta_relative: False
5 | save_model: True
6 | save_best_model: True
7 |
8 | dataset:
9 | batch_size: 32
10 | max_num_shape_parameters: 5
11 | max_num_objects: 7
12 | max_num_other_objects: 5
13 | max_num_rearrange_features: 0
14 | max_num_anchor_features: 0
15 | num_pts: 1024
16 | num_workers: 4
17 | pin_memory: True
18 | use_structure_frame: True
19 |
20 | model:
21 | name: structformer_no_encoder
22 | num_attention_heads: 8
23 | encoder_hidden_dim: 512
24 | encoder_dropout: 0.0
25 | encoder_activation: relu
26 | encoder_num_layers: 8
27 | structure_dropout: 0.5
28 | object_dropout: 0.1
29 | theta_loss_divide: 3
30 | ignore_rgb: True
31 |
32 | training:
33 | learning_rate: 0.0001
34 | max_epochs: 200
35 | l2: 0.0001
36 | lr_restart: 3000
37 | warmup: 10
--------------------------------------------------------------------------------
/configs/structformer_no_structure.yaml:
--------------------------------------------------------------------------------
1 | random_seed: 1
2 | device: 0
3 |
4 | obj_xytheta_relative: False
5 | save_model: True
6 | save_best_model: True
7 |
8 | dataset:
9 | batch_size: 32
10 | max_num_shape_parameters: 5
11 | max_num_objects: 7
12 | max_num_other_objects: 5
13 | max_num_rearrange_features: 0
14 | max_num_anchor_features: 0
15 | num_pts: 1024
16 | num_workers: 4
17 | pin_memory: True
18 | use_structure_frame: False
19 |
20 | model:
21 | name: structformer_no_structure
22 | num_attention_heads: 8
23 | encoder_hidden_dim: 512
24 | encoder_dropout: 0.0
25 | encoder_activation: relu
26 | encoder_num_layers: 8
27 | object_dropout: 0.1
28 | theta_loss_divide: 3
29 | ignore_rgb: True
30 |
31 | training:
32 | learning_rate: 0.0001
33 | max_epochs: 200
34 | l2: 0.0001
35 | lr_restart: 3000
36 | warmup: 10
--------------------------------------------------------------------------------
/doc/rearrange_mugs.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wliu88/StructFormer/a81d8c83313f1f24a75ca83aa9dd9f8cfcf27419/doc/rearrange_mugs.gif
--------------------------------------------------------------------------------
/pyproject.toml:
--------------------------------------------------------------------------------
1 | [build-system]
2 | requires = [
3 | "setuptools>=42",
4 | "wheel"
5 | ]
6 | build-backend = "setuptools.build_meta"
7 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | warmup-scheduler==0.3
2 | h5py==2.10.0
3 | omegaconf==2.1.1
4 | tqdm==4.63.0
5 | opencv-python==4.5.5.62
6 | open3d==0.15.2
7 | trimesh==3.10.2
8 | torch # cuda 11 requires a different installation, see https://pytorch.org/get-started/locally/
--------------------------------------------------------------------------------
/scripts/run_full_pipeline.py:
--------------------------------------------------------------------------------
1 | import copy
2 | import time
3 | import torch
4 | import numpy as np
5 | import os
6 | import argparse
7 | from omegaconf import OmegaConf
8 | from torch.utils.data import DataLoader
9 |
10 | from structformer.data.tokenizer import Tokenizer
11 | from structformer.evaluation.test_object_selection_network import ObjectSelectionInference
12 | from structformer.evaluation.test_structformer import PriorInference
13 | from structformer.utils.rearrangement import show_pcs_with_predictions, get_initial_scene_idxs, evaluate_target_object_predictions, save_img, show_pcs_with_labels, test_new_vis, show_pcs
14 | from structformer.evaluation.inference import PointCloudRearrangement
15 |
16 |
17 | def run_demo(object_selection_model_dir, pose_generation_model_dir, dirs_config, beam_size=3):
18 | """
19 | Run a simple demo. Creates the object selection inference model, pose generation model,
20 | and so on. Requires paths to the model + config directories.
21 | """
22 | object_selection_inference = ObjectSelectionInference(object_selection_model_dir, dirs_config)
23 | pose_generation_inference = PriorInference(pose_generation_model_dir, dirs_config)
24 |
25 | test_dataset = object_selection_inference.dataset
26 | initial_scene_idxs = get_initial_scene_idxs(test_dataset)
27 |
28 | for idx in range(len(test_dataset)):
29 | if idx not in initial_scene_idxs:
30 | continue
31 |
32 | if idx == 4:
33 | continue
34 |
35 | filename, _ = test_dataset.get_data_index(idx)
36 | scene_id = os.path.split(filename)[1][4:-3]
37 | print("-"*50)
38 | print("Scene No.{}".format(scene_id))
39 |
40 | # retrieve data
41 | init_datum = test_dataset.get_raw_data(idx)
42 | goal_specification = init_datum["goal_specification"]
43 | object_selection_structured_sentence = init_datum["sentence"][5:]
44 | structure_specification_structured_sentence = init_datum["sentence"][:5]
45 | object_selection_natural_sentence = object_selection_inference.tokenizer.convert_to_natural_sentence(
46 | object_selection_structured_sentence)
47 | structure_specification_natural_sentence = object_selection_inference.tokenizer.convert_structure_params_to_natural_language(structure_specification_structured_sentence)
48 |
49 | # object selection
50 | predictions, gts = object_selection_inference.predict_target_objects(init_datum)
51 |
52 | all_obj_xyzs = init_datum["xyzs"][:len(predictions)]
53 | all_obj_rgbs = init_datum["rgbs"][:len(predictions)]
54 | obj_idxs = [i for i, l in enumerate(predictions) if l == 1.0]
55 | if len(obj_idxs) == 0:
56 | continue
57 | other_obj_idxs = [i for i, l in enumerate(predictions) if l == 0.0]
58 | obj_xyzs = [all_obj_xyzs[i] for i in obj_idxs]
59 | obj_rgbs = [all_obj_rgbs[i] for i in obj_idxs]
60 | other_obj_xyzs = [all_obj_xyzs[i] for i in other_obj_idxs]
61 | other_obj_rgbs = [all_obj_rgbs[i] for i in other_obj_idxs]
62 |
63 | print("\nSelect objects to rearrange...")
64 | print("Instruction:", object_selection_natural_sentence)
65 | print("Visualize groundtruth (dot color) and prediction (ring color)")
66 | show_pcs_with_predictions(init_datum["xyzs"][:len(predictions)], init_datum["rgbs"][:len(predictions)],
67 | gts, predictions, add_table=True, side_view=True)
68 | print("Visualize object to rearrange")
69 | show_pcs(obj_xyzs, obj_rgbs, side_view=True, add_table=True)
70 |
71 | # pose generation
72 | max_num_objects = pose_generation_inference.cfg.dataset.max_num_objects
73 | max_num_other_objects = pose_generation_inference.cfg.dataset.max_num_other_objects
74 | if len(obj_xyzs) > max_num_objects:
75 | print("WARNING: reducing the number of \"query\" objects because this model is trained with a maximum of {} \"query\" objects. Train a new model if a larger number is needed.".format(max_num_objects))
76 | obj_xyzs = obj_xyzs[:max_num_objects]
77 | obj_rgbs = obj_rgbs[:max_num_objects]
78 | if len(other_obj_xyzs) > max_num_other_objects:
79 | print("WARNING: reducing the number of \"distractor\" objects because this model is trained with a maximum of {} \"distractor\" objects. Train a new model if a larger number is needed.".format(max_num_other_objects))
80 | other_obj_xyzs = other_obj_xyzs[:max_num_other_objects]
81 | other_obj_rgbs = other_obj_rgbs[:max_num_other_objects]
82 |
83 | pose_generation_datum = pose_generation_inference.dataset.prepare_test_data(obj_xyzs, obj_rgbs,
84 | other_obj_xyzs, other_obj_rgbs,
85 | goal_specification["shape"])
86 | beam_data = []
87 | beam_pc_rearrangements = []
88 | for b in range(beam_size):
89 | datum_copy = copy.deepcopy(pose_generation_datum)
90 | beam_data.append(datum_copy)
91 | beam_pc_rearrangements.append(PointCloudRearrangement(datum_copy))
92 |
93 | # autoregressive decoding
94 | num_target_objects = beam_pc_rearrangements[0].num_target_objects
95 |
96 | # first predict structure pose
97 | beam_goal_struct_pose, target_object_preds = pose_generation_inference.limited_batch_inference(beam_data)
98 | for b in range(beam_size):
99 | datum = beam_data[b]
100 | datum["struct_x_inputs"] = [beam_goal_struct_pose[b][0]]
101 | datum["struct_y_inputs"] = [beam_goal_struct_pose[b][1]]
102 | datum["struct_z_inputs"] = [beam_goal_struct_pose[b][2]]
103 | datum["struct_theta_inputs"] = [beam_goal_struct_pose[b][3:]]
104 |
105 | # then iteratively predict pose of each object
106 | beam_goal_obj_poses = []
107 | for obj_idx in range(num_target_objects):
108 | struct_preds, target_object_preds = pose_generation_inference.limited_batch_inference(beam_data)
109 | beam_goal_obj_poses.append(target_object_preds[:, obj_idx])
110 | for b in range(beam_size):
111 | datum = beam_data[b]
112 | datum["obj_x_inputs"][obj_idx] = target_object_preds[b][obj_idx][0]
113 | datum["obj_y_inputs"][obj_idx] = target_object_preds[b][obj_idx][1]
114 | datum["obj_z_inputs"][obj_idx] = target_object_preds[b][obj_idx][2]
115 | datum["obj_theta_inputs"][obj_idx] = target_object_preds[b][obj_idx][3:]
116 | # concat in the object dim
117 | beam_goal_obj_poses = np.stack(beam_goal_obj_poses, axis=0)
118 | # swap axis
119 | beam_goal_obj_poses = np.swapaxes(beam_goal_obj_poses, 1, 0) # batch size, number of target objects, pose dim
120 |
121 | # move pc
122 | for bi in range(beam_size):
123 | beam_pc_rearrangements[bi].set_goal_poses(beam_goal_struct_pose[bi], beam_goal_obj_poses[bi])
124 | beam_pc_rearrangements[bi].rearrange()
125 |
126 | print("\nRearrange \"query\" objects...")
127 | print("Instruction:", structure_specification_natural_sentence)
128 | for pi, pc_rearrangement in enumerate(beam_pc_rearrangements):
129 | print("Visualize rearranged scene sample {}".format(pi))
130 | pc_rearrangement.visualize("goal", add_other_objects=True, add_table=True, side_view=True)
131 |
132 |
133 | if __name__ == "__main__":
134 | parser = argparse.ArgumentParser(description="Run a simple model")
135 | parser.add_argument("--dataset_base_dir", help='location of the dataset', type=str)
136 | parser.add_argument("--object_selection_model_dir", help='location for the saved object selection model', type=str)
137 | parser.add_argument("--pose_generation_model_dir", help='location for the saved pose generation model', type=str)
138 | parser.add_argument("--dirs_config", help='config yaml file for directories', type=str)
139 | args = parser.parse_args()
140 |
141 | os.environ["DATETIME"] = time.strftime("%Y%m%d-%H%M%S")
142 |
143 | # # debug only
144 | # args.dataset_base_dir = "/home/weiyu/data_drive/data_new_objects_test_split"
145 | # args.object_selection_model_dir = "/home/weiyu/Research/intern/StructFormer/models/object_selection_network/best_model"
146 | # args.pose_generation_model_dir = "/home/weiyu/Research/intern/StructFormer/models/structformer_circle/best_model"
147 | # args.dirs_config = "/home/weiyu/Research/intern/StructFormer/configs/data/circle_dirs.yaml"
148 |
149 | if args.dirs_config:
150 | assert os.path.exists(args.dirs_config), "Cannot find config yaml file at {}".format(args.dirs_config)
151 | dirs_cfg = OmegaConf.load(args.dirs_config)
152 | dirs_cfg.dataset_base_dir = args.dataset_base_dir
153 | OmegaConf.resolve(dirs_cfg)
154 | else:
155 | dirs_cfg = None
156 |
157 | run_demo(args.object_selection_model_dir, args.pose_generation_model_dir, dirs_cfg)
158 |
--------------------------------------------------------------------------------
/setup.cfg:
--------------------------------------------------------------------------------
1 | [metadata]
2 | name = StructFormer-wliu88
3 | version = 0.0.1
4 | author = Weiyu Liu
5 | author_email = wliu88@gatech.edu
6 | description = Source code for StructFormer
7 | long_description = file: README.md
8 | long_description_content_type = text/markdown
9 | url = https://github.com/wliu88/StructFormer
10 | project_urls =
11 | Bug Tracker = https://github.com/wliu88/StructFormer/issues
12 | classifiers =
13 | Programming Language :: Python :: 3
14 | License :: OSI Approved :: NVIDIA License
15 | Operating System :: OS Independent
16 |
17 | [options]
18 | package_dir =
19 | = src
20 | packages = find:
21 | python_requires = >=3.6
22 |
23 | [options.packages.find]
24 | where = src
25 |
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | import setuptools
2 |
3 | with open("README.md", "r", encoding="utf-8") as fh:
4 | long_description = fh.read()
5 |
6 | setuptools.setup(
7 | name="Structformer-wliu88",
8 | version="0.0.1",
9 | author="Weiyu Liu",
10 | author_email="wliu88@gatech.edu",
11 | description="Source code for StructFormer",
12 | long_description=long_description,
13 | long_description_content_type="text/markdown",
14 | url="https://github.com/wliu88/StructFormer",
15 | project_urls={
16 | "Bug Tracker": "https://github.com/wliu88/StructFormer/issues",
17 | },
18 | classifiers=[
19 | "Programming Language :: Python :: 3",
20 | "License :: OSI Approved :: NVIDIA License",
21 | "Operating System :: OS Independent",
22 | ],
23 | package_dir={"": "src"},
24 | packages=setuptools.find_packages(where="src"),
25 | python_requires=">=3.6",
26 | )
27 |
28 |
--------------------------------------------------------------------------------
/src/structformer/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wliu88/StructFormer/a81d8c83313f1f24a75ca83aa9dd9f8cfcf27419/src/structformer/__init__.py
--------------------------------------------------------------------------------
/src/structformer/data/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wliu88/StructFormer/a81d8c83313f1f24a75ca83aa9dd9f8cfcf27419/src/structformer/data/__init__.py
--------------------------------------------------------------------------------
/src/structformer/evaluation/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wliu88/StructFormer/a81d8c83313f1f24a75ca83aa9dd9f8cfcf27419/src/structformer/evaluation/__init__.py
--------------------------------------------------------------------------------
/src/structformer/evaluation/inference.py:
--------------------------------------------------------------------------------
1 | import os
2 | import numpy as np
3 | import torch
4 | import copy
5 |
6 | import structformer.utils.transformations as tra
7 | from structformer.utils.rearrangement import show_pcs, save_pcs, move_one_object_pc, make_gifs, modify_language, sample_gaussians, fit_gaussians, show_pcs_color_order
8 |
9 |
10 | class PointCloudRearrangement:
11 |
12 | """
13 | helps to keep track of point clouds and predicted object poses for inference
14 |
15 | ToDo: make the whole thing live on pytorch tensor
16 | ToDo: support binary format
17 | """
18 |
19 | def __init__(self, initial_datum, pose_format="xyz+3x3", use_structure_frame=True):
20 |
21 | assert pose_format == "xyz+3x3", "{} pose format not supported".format(pose_format)
22 |
23 | self.use_structure_frame = use_structure_frame
24 |
25 | self.num_target_objects = None
26 | self.num_other_objects = None
27 |
28 | # important: we do not store any padding pcs and poses
29 | self.initial_xyzs = {"xyzs": [], "rgbs": [], "other_xyzs": [], "other_rgbs": []}
30 | self.goal_xyzs = {"xyzs": [], "rgbs": []}
31 | self.goal_poses = {"obj_poses": []}
32 | if self.use_structure_frame:
33 | self.goal_poses["struct_pose"] = []
34 |
35 | self.set_initial_pc(initial_datum)
36 |
37 | def set_initial_pc(self, datum):
38 |
39 | self.num_target_objects = np.sum(np.array(datum["object_pad_mask"]) == 0)
40 | self.num_other_objects = np.sum(np.array(datum["other_object_pad_mask"]) == 0)
41 |
42 | self.initial_xyzs["xyzs"] = datum["xyzs"][:self.num_target_objects]
43 | self.initial_xyzs["rgbs"] = datum["rgbs"][:self.num_target_objects]
44 | self.initial_xyzs["other_xyzs"] = datum["other_xyzs"][:self.num_other_objects]
45 | self.initial_xyzs["other_rgbs"] = datum["other_rgbs"][:self.num_other_objects]
46 |
47 | def set_goal_poses(self, goal_struct_pose, goal_obj_poses, input_pose_format="xyz+3x3",
48 | n_obj_idxs=None, skip_update_struct=False):
49 | """
50 |
51 | :param goal_struct_pose:
52 | :param goal_obj_poses:
53 | :param input_pose_format:
54 | :param n_obj_idxs: only set the goal poses for the indexed objects
55 | :param skip_update_struct: if set to true, do not update struct pose
56 | :return:
57 | """
58 |
59 | # in these cases, we need to ensure the goal poses have already been set so that we can only update some of it
60 | if n_obj_idxs is not None or skip_update_struct:
61 | assert len(self.goal_poses["obj_poses"]) != 0
62 | if self.use_structure_frame:
63 | assert len(self.goal_poses["struct_pose"]) != 0
64 |
65 | if input_pose_format == "xyz+3x3":
66 | # check input
67 | if not skip_update_struct and self.use_structure_frame:
68 | assert len(goal_struct_pose) == 12
69 | if n_obj_idxs is None:
70 | # in case the input contains padding poses
71 | if len(goal_obj_poses) != self.num_target_objects:
72 | goal_obj_poses = goal_obj_poses[:self.num_target_objects]
73 | else:
74 | assert len(goal_obj_poses) == len(n_obj_idxs)
75 | assert all(len(gop) == 12 for gop in goal_obj_poses)
76 |
77 | # convert to standard form
78 | if not skip_update_struct and self.use_structure_frame:
79 | if type(goal_struct_pose) != list:
80 | goal_struct_pose = goal_struct_pose.tolist()
81 | if type(goal_obj_poses) != list:
82 | goal_obj_poses = goal_obj_poses.tolist()
83 | for i in range(len(goal_obj_poses)):
84 | if type(goal_obj_poses[i]) != list:
85 | goal_obj_poses[i] = goal_obj_poses[i].tolist()
86 |
87 | elif input_pose_format == "flat:xyz+3x3":
88 | # check input
89 | if not skip_update_struct and self.use_structure_frame:
90 | assert len(goal_struct_pose) == 12
91 | if n_obj_idxs is None:
92 | # flat means that object poses are in one list instead of a list of lists
93 | assert len(goal_obj_poses) == self.num_target_objects * 12
94 | else:
95 | assert len(goal_obj_poses) == len(n_obj_idxs) * 12
96 |
97 | # convert to standard form
98 | if not skip_update_struct and self.use_structure_frame:
99 | if type(goal_struct_pose) != list:
100 | goal_struct_pose = goal_struct_pose.tolist()
101 | if type(goal_obj_poses) != list:
102 | goal_obj_poses = goal_obj_poses.tolist()
103 |
104 | goal_obj_poses = np.array(goal_obj_poses).reshape(-1, 12).tolist()
105 |
106 | elif input_pose_format == "flat:xyz+rpy":
107 | # check input
108 | if not skip_update_struct and self.use_structure_frame:
109 | assert len(goal_struct_pose) == 6
110 | if n_obj_idxs is None:
111 | assert len(goal_obj_poses) == self.num_target_objects * 6
112 | else:
113 | assert len(goal_obj_poses) == len(n_obj_idxs) * 6
114 |
115 | # convert to standard form
116 | if not skip_update_struct and self.use_structure_frame:
117 | if type(goal_struct_pose) != list:
118 | goal_struct_pose = goal_struct_pose.tolist()
119 | if type(goal_obj_poses) != list:
120 | goal_obj_poses = np.array(goal_obj_poses).reshape(-1, 6).tolist()
121 |
122 | if not skip_update_struct and self.use_structure_frame:
123 | goal_struct_pose = goal_struct_pose[:3] + tra.euler_matrix(goal_struct_pose[3], goal_struct_pose[4], goal_struct_pose[5])[:3, :3].flatten().tolist()
124 | converted_goal_obj_poses = []
125 | for gop in goal_obj_poses:
126 | converted_goal_obj_poses.append(
127 | gop[:3] + tra.euler_matrix(gop[3], gop[4], gop[5])[:3, :3].flatten().tolist())
128 | goal_obj_poses = converted_goal_obj_poses
129 |
130 | else:
131 | raise KeyError
132 |
133 | # update
134 | if not skip_update_struct and self.use_structure_frame:
135 | self.goal_poses["struct_pose"] = goal_struct_pose
136 | if n_obj_idxs is None:
137 | self.goal_poses["obj_poses"] = goal_obj_poses
138 | else:
139 | for count, oi in enumerate(n_obj_idxs):
140 | self.goal_poses["obj_poses"][oi] = goal_obj_poses[count]
141 |
142 | def get_goal_poses(self, output_pose_format="xyz+3x3",
143 | n_obj_idxs=None, skip_update_struct=False, combine_struct_objs=False):
144 | """
145 |
146 | :param output_pose_format:
147 | :param n_obj_idxs: only retrieve the goal poses for the indexed objects
148 | :param skip_update_struct: if set to true, do not retrieve struct pose
149 | :param combine_struct_objs: if true, return a single list of lists, where the first list is the structure pose
150 | and the remaining ones are the object poses
151 | :return:
152 | """
153 | if output_pose_format == "xyz+3x3":
154 | if self.use_structure_frame:
155 | goal_struct_pose = self.goal_poses["struct_pose"]
156 | goal_obj_poses = self.goal_poses["obj_poses"]
157 |
158 | if n_obj_idxs is not None:
159 | goal_obj_poses = [goal_obj_poses[i] for i in n_obj_idxs]
160 |
161 | elif output_pose_format == "flat:xyz+3x3":
162 | if self.use_structure_frame:
163 | goal_struct_pose = self.goal_poses["struct_pose"]
164 |
165 | if n_obj_idxs is None:
166 | goal_obj_poses = np.array(self.goal_poses["obj_poses"]).flatten().tolist()
167 | else:
168 | goal_obj_poses = np.array([self.goal_poses["obj_poses"][i] for i in n_obj_idxs]).flatten().tolist()
169 |
170 | elif output_pose_format == "flat:xyz+rpy":
171 | if self.use_structure_frame:
172 | ax, ay, az = tra.euler_from_matrix(np.asarray(self.goal_poses["struct_pose"][3:]).reshape(3, 3))
173 | goal_struct_pose = self.goal_poses["struct_pose"][:3] + [ax, ay, az]
174 |
175 | goal_obj_poses = []
176 | for gop in self.goal_poses["obj_poses"]:
177 | ax, ay, az = tra.euler_from_matrix(np.asarray(gop[3:]).reshape(3, 3))
178 | goal_obj_poses.append(gop[:3] + [ax, ay, az])
179 |
180 | if n_obj_idxs is None:
181 | goal_obj_poses = np.array(goal_obj_poses).flatten().tolist()
182 | else:
183 | goal_obj_poses = np.array([goal_obj_poses[i] for i in n_obj_idxs]).flatten().tolist()
184 |
185 | else:
186 | raise KeyError
187 |
188 | if not skip_update_struct and self.use_structure_frame:
189 | if not combine_struct_objs:
190 | return goal_struct_pose, goal_obj_poses
191 | else:
192 | return [goal_struct_pose] + goal_obj_poses
193 | else:
194 | return None, goal_obj_poses
195 |
196 | def rearrange(self, n_obj_idxs=None):
197 | """
198 | use stored object point clouds of the initial scene and goal poses to
199 | compute object point clouds of the goal scene.
200 |
201 | :param n_obj_idxs: only update the goal point clouds of indexed objects
202 | :return:
203 | """
204 |
205 | # initial scene and goal poses have to be set first
206 | assert all(len(self.initial_xyzs[k]) != 0 for k in ["xyzs", "rgbs"])
207 | assert all(len(self.goal_poses[k]) != 0 for k in self.goal_poses)
208 |
209 | # whether we are initializing or updating
210 | no_goal_xyzs_yet = True
211 | if len(self.goal_xyzs["xyzs"]):
212 | no_goal_xyzs_yet = False
213 |
214 | if n_obj_idxs is not None:
215 | assert no_goal_xyzs_yet is False
216 |
217 | if n_obj_idxs is not None:
218 | update_obj_idxs = n_obj_idxs
219 | else:
220 | update_obj_idxs = list(range(self.num_target_objects))
221 |
222 | if self.use_structure_frame:
223 | goal_struct_pose = self.goal_poses["struct_pose"]
224 | else:
225 | goal_struct_pose = None
226 | for obj_idx in update_obj_idxs:
227 | imagined_obj_xyz, imagined_obj_rgb = move_one_object_pc(self.initial_xyzs["xyzs"][obj_idx],
228 | self.initial_xyzs["rgbs"][obj_idx],
229 | self.goal_poses["obj_poses"][obj_idx],
230 | goal_struct_pose)
231 |
232 | if no_goal_xyzs_yet:
233 | self.goal_xyzs["xyzs"].append(imagined_obj_xyz)
234 | self.goal_xyzs["rgbs"].append(imagined_obj_rgb)
235 | else:
236 | self.goal_xyzs["xyzs"][obj_idx] = imagined_obj_xyz
237 | self.goal_xyzs["rgbs"][obj_idx] = imagined_obj_rgb
238 |
239 | def visualize(self, time_step, add_other_objects=False,
240 | add_coordinate_frame=False, side_view=False, add_table=False,
241 | show_vis=True, save_vis=False, save_filename=None, order_color=False):
242 |
243 | if time_step == "initial":
244 | xyzs = self.initial_xyzs["xyzs"]
245 | rgbs = self.initial_xyzs["rgbs"]
246 | elif time_step == "goal":
247 | xyzs = self.goal_xyzs["xyzs"]
248 | rgbs = self.goal_xyzs["rgbs"]
249 | else:
250 | raise KeyError()
251 |
252 | if add_other_objects:
253 | xyzs = xyzs + self.initial_xyzs["other_xyzs"]  # avoid extending the stored lists in place
254 | rgbs = rgbs + self.initial_xyzs["other_rgbs"]
255 |
256 | if show_vis:
257 | if not order_color:
258 | show_pcs(xyzs, rgbs, add_coordinate_frame=add_coordinate_frame, side_view=side_view, add_table=add_table)
259 | else:
260 | show_pcs_color_order(xyzs, rgbs, add_coordinate_frame=add_coordinate_frame, side_view=side_view, add_table=add_table)
261 |
262 | if save_vis and save_filename is not None:
263 | save_pcs(xyzs, rgbs, save_path=save_filename, add_coordinate_frame=add_coordinate_frame, side_view=side_view, add_table=add_table)
--------------------------------------------------------------------------------
/src/structformer/evaluation/test_binary_model.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 | import os
4 | import copy
5 | import tqdm
6 | import argparse
7 | from omegaconf import OmegaConf
8 | import trimesh
9 | import time
10 | from torch.utils.data import DataLoader
11 |
12 | from structformer.data.tokenizer import Tokenizer
13 | import structformer.data.binary_dataset as prior_dataset
14 | import structformer.training.train_binary_model as prior_model
15 | from structformer.utils.rearrangement import move_one_object_pc, make_gifs, \
16 | modify_language, sample_gaussians, fit_gaussians, get_initial_scene_idxs, show_pcs, save_pcs
17 |
18 |
19 | def test_model(model_dir, dirs_cfg):
20 | prior_inference = PriorInference(model_dir, dirs_cfg, data_split="test")
21 | prior_inference.validate()
22 |
23 |
24 | class PriorInference:
25 |
26 | def __init__(self, model_dir, dirs_cfg, data_split="test"):
27 | # load prior
28 | cfg, tokenizer, model, optimizer, scheduler, epoch = prior_model.load_model(model_dir, dirs_cfg)
29 |
30 | data_cfg = cfg.dataset
31 |
32 | dataset = prior_dataset.BinaryDataset(data_cfg.dirs, data_cfg.index_dirs, data_split, tokenizer,
33 | data_cfg.max_num_objects,
34 | data_cfg.max_num_other_objects,
35 | data_cfg.max_num_shape_parameters,
36 | data_cfg.max_num_rearrange_features,
37 | data_cfg.max_num_anchor_features,
38 | data_cfg.num_pts)
39 | self.cfg = cfg
40 | self.tokenizer = tokenizer
41 | self.model = model
42 | self.cfg = cfg
43 | self.dataset = dataset
44 | self.epoch = epoch
45 |
46 | def validate(self):
47 | data_cfg = self.cfg.dataset
48 | data_iter = DataLoader(self.dataset, batch_size=data_cfg.batch_size, shuffle=False,
49 | collate_fn=prior_dataset.BinaryDataset.collate_fn,
50 | pin_memory=data_cfg.pin_memory, num_workers=data_cfg.num_workers)
51 |
52 | prior_model.validate(self.cfg, self.model, data_iter, self.epoch, self.cfg.device)
53 |
54 | def limited_batch_inference(self, data, t, verbose=False):
55 | """
56 | This function makes the assumption that scenes in the batch have the same number of objects that need to be
57 | rearranged
58 |
59 | :param data: list of raw sequence data for the scenes to decode
60 | :param t: time step, i.e., the index of the target object currently being placed; used to
61 | convert each sequence datum into a binary (query object, anchor object) datum
62 | :param verbose: if true, print the data size and batch size
63 | :return: predicted object poses, as an array of shape (data_size, pose dim), and the list of
64 | converted binary data
65 |
66 |
67 | """
68 |
69 | data_size = len(data)
70 | batch_size = self.cfg.dataset.batch_size
71 | if verbose:
72 | print("data size:", data_size)
73 | print("batch size:", batch_size)
74 |
75 | num_batches = int(data_size / batch_size)
76 | if data_size % batch_size != 0:
77 | num_batches += 1
78 |
79 | all_obj_preds = []
80 | all_binary_data = []
81 | for b in range(num_batches):
82 | if b + 1 == num_batches:
83 | # last batch
84 | batch = data[b * batch_size:]
85 | else:
86 | batch = data[b * batch_size: (b+1) * batch_size]
87 |
88 | binary_data = [self.dataset.convert_sequence_to_binary(d, t) for d in batch]
89 | data_tensors = [self.dataset.convert_to_tensors(d, self.tokenizer) for d in binary_data]
90 | data_tensors = self.dataset.collate_fn(data_tensors)
91 | predictions = prior_model.infer_once(self.cfg, self.model, data_tensors, self.cfg.device)
92 |
93 | obj_x_preds = torch.cat(predictions["obj_x_outputs"], dim=0)
94 | obj_y_preds = torch.cat(predictions["obj_y_outputs"], dim=0)
95 | obj_z_preds = torch.cat(predictions["obj_z_outputs"], dim=0)
96 | obj_theta_preds = torch.cat(predictions["obj_theta_outputs"], dim=0)
97 | obj_preds = torch.cat([obj_x_preds, obj_y_preds, obj_z_preds, obj_theta_preds], dim=1) # batch_size * max num objects, output_dim
98 |
99 | all_obj_preds.append(obj_preds)
100 | all_binary_data.extend(binary_data)
101 |
102 | obj_preds = torch.cat(all_obj_preds, dim=0) # data_size * max num objects, output_dim
103 | obj_preds = obj_preds.detach().cpu().numpy()
104 | obj_preds = obj_preds.reshape(data_size, obj_preds.shape[-1])  # data_size, output_dim
105 |
106 | return obj_preds, all_binary_data
107 |
108 |
109 | def inference_beam_decoding(model_dir, dirs_cfg, beam_size=100, max_scene_decodes=30000,
110 | visualize=True, visualize_action_sequence=False,
111 | inference_visualization_dir=None):
112 | """
113 | This function decodes each initial test scene by iteratively placing one target object at a time,
114 | decoding beam_size candidate rearrangements in parallel
115 |
116 | :param model_dir: directory of the saved binary model
117 | :param dirs_cfg: OmegaConf config specifying dataset directories
118 | :param beam_size: number of candidate rearrangements decoded in parallel for each scene
119 | :param max_scene_decodes: maximum number of scenes to decode
120 | :param visualize: if true, visualize the initial scene before decoding
121 | :param visualize_action_sequence: currently unused
122 | :param inference_visualization_dir: directory for saving visualizations
123 | :return:
124 | """
125 |
126 | if inference_visualization_dir and not os.path.exists(inference_visualization_dir):
127 | os.makedirs(inference_visualization_dir)
128 |
129 | prior_inference = PriorInference(model_dir, dirs_cfg)
130 | test_dataset = prior_inference.dataset
131 |
132 | initial_scene_idxs = get_initial_scene_idxs(test_dataset)
133 |
134 | decoded_scene_count = 0
135 | with tqdm.tqdm(total=len(initial_scene_idxs)) as pbar:
136 | # for idx in np.random.choice(range(len(test_dataset)), len(test_dataset), replace=False):
137 | for idx in initial_scene_idxs:
138 |
139 | if decoded_scene_count == max_scene_decodes:
140 | break
141 |
142 | filename, step_t = test_dataset.get_data_index(idx)
143 | scene_id = os.path.split(filename)[1][4:-3]
144 |
145 | decoded_scene_count += 1
146 |
147 | ############################################
148 | # retrieve data
149 | beam_data = []
150 | num_target_objects = None
151 | for b in range(beam_size):
152 | datum = test_dataset.get_raw_sequence_data(idx)
153 | beam_data.append(datum)
154 |
155 | if num_target_objects is None:
156 | num_target_objects = len(datum["xyzs"])
157 |
158 | # We can play with different language here
159 | # datum["sentence"] = modify_language(datum["sentence"], radius=0.5)
160 |
161 | if visualize:
162 | datum = beam_data[0]
163 | print("#"*50)
164 | print("sentence", datum["sentence"])
165 | print("num target objects", num_target_objects)
166 | show_pcs(datum["xyzs"] + datum["other_bg_xyzs"],
167 | datum["rgbs"] + datum["other_bg_rgbs"],
168 | add_coordinate_frame=False, side_view=True, add_table=True)
169 |
170 | ############################################
171 |
172 | beam_predicted_parameters = [[] for _ in range(beam_size)]  # one list per beam; [[]] * beam_size would alias the same list
173 | for time_index in range(num_target_objects):
174 |
175 | # iteratively decoding
176 | target_object_preds, binary_data = prior_inference.limited_batch_inference(beam_data, time_index)
177 |
178 | for b in range(beam_size):
179 | # a list of list, where each inside list contains xyz, 3x3 rotation
180 |
181 | datum = beam_data[b]
182 | binary_datum = binary_data[b]
183 | obj_pred = target_object_preds[b]
184 |
185 | #------------
186 | goal_query_pc_translation_offset = obj_pred[:3]
187 | goal_query_pc_rotation = np.eye(4)
188 | goal_query_pc_rotation[:3, :3] = np.array(obj_pred[3:]).reshape(3, 3)
189 |
190 | query_obj_xyz = binary_datum["query_xyz"]
191 | anchor_obj_xyz = binary_datum["anchor_xyz"]
192 |
193 |
194 | current_query_pc_center = torch.mean(query_obj_xyz, dim=0).numpy()[:3]
195 | current_anchor_pc_center = torch.mean(anchor_obj_xyz, dim=0).numpy()[:3]
196 |
197 | t = np.eye(4)
198 | t[:3, 3] = current_anchor_pc_center + goal_query_pc_translation_offset - current_query_pc_center
199 | new_query_obj_xyz = trimesh.transform_points(query_obj_xyz, t)
200 |
201 | # rotating in place
202 | # R = tra.euler_matrix(0, 0, obj_pc_rotations[i])
203 | query_obj_center = np.mean(new_query_obj_xyz, axis=0)
204 | centered_query_obj_xyz = new_query_obj_xyz - query_obj_center
205 | new_centered_query_obj_xyz = trimesh.transform_points(centered_query_obj_xyz, goal_query_pc_rotation,
206 | translate=True)
207 | new_query_obj_xyz = new_centered_query_obj_xyz + query_obj_center
208 | new_query_obj_xyz = torch.tensor(new_query_obj_xyz, dtype=query_obj_xyz.dtype)
209 |
210 | # vis_query_obj_rgb = np.tile(np.array([0, 1, 0], dtype=np.float), (query_obj_xyz.shape[0], 1))
211 | # vis_anchor_obj_rgb = np.tile(np.array([1, 0, 0], dtype=np.float), (anchor_obj_xyz.shape[0], 1))
212 | # vis_new_query_obj_rgb = np.tile(np.array([0, 0, 1], dtype=np.float), (new_query_obj_xyz.shape[0], 1))
213 | # show_pcs([new_query_obj_xyz, query_obj_xyz, anchor_obj_xyz],
214 | # [vis_new_query_obj_rgb, vis_query_obj_rgb, vis_anchor_obj_rgb],
215 | # add_coordinate_frame=True)
216 |
217 | datum["xyzs"][time_index] = new_query_obj_xyz
218 |
219 | current_object_param = t[:3, 3].tolist() + goal_query_pc_rotation[:3, :3].flatten().tolist()
220 | beam_predicted_parameters[b].append(current_object_param)
221 |
222 | for b in range(beam_size):
223 | datum = beam_data[b]
224 |
225 | pc_sizes = [xyz.shape[0] for xyz in datum["other_bg_xyzs"]]
226 | table_idx = np.argmax(pc_sizes)
227 | show_pcs(datum["xyzs"] + [xyz for i, xyz in enumerate(datum["other_bg_xyzs"]) if i != table_idx],
228 | datum["rgbs"] + [rgb for i, rgb in enumerate(datum["other_bg_rgbs"]) if i != table_idx],
229 | add_coordinate_frame=False, side_view=True, add_table=True)
230 |
231 | pbar.update(1)
232 |
233 |
234 | if __name__ == "__main__":
235 | parser = argparse.ArgumentParser(description="Run a simple model")
236 | parser.add_argument("--dataset_base_dir", help='location of the dataset', type=str)
237 | parser.add_argument("--model_dir", help='location for the saved model', type=str)
238 | parser.add_argument("--dirs_config", help='config yaml file for directories', default="", type=str)
239 | args = parser.parse_args()
240 |
241 | os.environ["DATETIME"] = time.strftime("%Y%m%d-%H%M%S")
242 |
243 | # # debug only
244 | # args.dataset_base_dir = "/home/weiyu/data_drive/data_new_objects_test_split"
245 | # args.model_dir = "/home/weiyu/Research/intern/StructFormer/models/binary_model_tower/best_model"
246 | # args.dirs_config = "/home/weiyu/Research/intern/StructFormer/structformer/configs/data/tower_dirs.yaml"
247 |
248 | if args.dirs_config:
249 | assert os.path.exists(args.dirs_config), "Cannot find config yaml file at {}".format(args.dirs_config)
250 | dirs_cfg = OmegaConf.load(args.dirs_config)
251 | dirs_cfg.dataset_base_dir = args.dataset_base_dir
252 | OmegaConf.resolve(dirs_cfg)
253 | else:
254 | dirs_cfg = None
255 |
256 | inference_beam_decoding(args.model_dir, dirs_cfg, beam_size=3, max_scene_decodes=30000,
257 | visualize=True, visualize_action_sequence=False,
258 | inference_visualization_dir=None)
259 |
260 |
--------------------------------------------------------------------------------
/src/structformer/evaluation/test_object_selection_network.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 | import os
4 | import tqdm
5 | import time
6 | import argparse
7 | from omegaconf import OmegaConf
8 |
9 | from torch.utils.data import DataLoader
10 | from structformer.data.object_set_refer_dataset import ObjectSetReferDataset
11 | from structformer.training.train_object_selection_network import load_model, validate, infer_once
12 | from structformer.utils.rearrangement import show_pcs_with_predictions, get_initial_scene_idxs, evaluate_target_object_predictions, save_img, show_pcs_with_labels, test_new_vis
13 |
14 |
15 | class ObjectSelectionInference:
16 |
17 | def __init__(self, model_dir, dirs_cfg, data_split="test"):
18 | # load prior
19 | cfg, tokenizer, model, optimizer, scheduler, epoch = load_model(model_dir, dirs_cfg)
20 |
21 | data_cfg = cfg.dataset
22 | test_dataset = ObjectSetReferDataset(data_cfg.dirs, data_cfg.index_dirs, data_split, tokenizer,
23 | data_cfg.max_num_all_objects,
24 | data_cfg.max_num_shape_parameters,
25 | data_cfg.max_num_rearrange_features,
26 | data_cfg.max_num_anchor_features,
27 | data_cfg.num_pts)
28 |
29 | self.cfg = cfg
30 | self.tokenizer = tokenizer
31 | self.model = model
32 | self.cfg = cfg
33 | self.dataset = test_dataset
34 | self.epoch = epoch
35 |
36 | def prepare_datum(self, obj_xyzs, obj_rgbs, goal_specification, structure_parameters, gt_num_rearrange_objects):
37 | datum = self.dataset.prepare_test_data(obj_xyzs, obj_rgbs, goal_specification, structure_parameters, gt_num_rearrange_objects)
38 | return datum
39 |
40 | def predict_target_objects(self, datum):
41 |
42 | batch = self.dataset.collate_fn([self.dataset.convert_to_tensors(datum, self.tokenizer)])
43 |
44 | gts, predictions = infer_once(self.model, batch, self.cfg.device)
45 | gts = gts["rearrange_obj_labels"][0].detach().cpu().numpy()
46 | predictions = predictions["rearrange_obj_labels"][0].detach().cpu().numpy()
47 | predictions = predictions > 0.5
48 | # remove paddings
49 | target_mask = gts != -100
50 | gts = gts[target_mask]
51 | predictions = predictions[target_mask]
52 |
53 | return predictions, gts
54 |
55 | def validate(self):
56 | """
57 | validate the pretrained model on the dataset
58 |
59 | :return:
60 | """
61 | data_cfg = self.cfg.dataset
62 | data_iter = DataLoader(self.dataset, batch_size=data_cfg.batch_size, shuffle=False,
63 | collate_fn=ObjectSetReferDataset.collate_fn,
64 | pin_memory=data_cfg.pin_memory, num_workers=data_cfg.num_workers)
65 |
66 | validate(self.model, data_iter, self.epoch, self.cfg.device)
67 |
68 |
69 | def inference(model_dir, dirs_cfg, visualize=False, inference_visualization_dir=None):
70 |
71 | # object selection information
72 | # goal = {"rearrange": {"features": [], "objects": [], "combine_features_logic": None, "count": None},
73 | # "anchor": {"features": [], "objects": [], "combine_features_logic": None},
74 | # "distract": {"features": [], "objects": [], "combine_features_logic": None},
75 | # "random_selection": {"varying_features": [], "nonvarying_features": []},
76 | # "order": {"feature": None}
77 | # }
78 |
79 | if inference_visualization_dir:
80 | if not os.path.exists(inference_visualization_dir):
81 | os.makedirs(inference_visualization_dir)
82 |
83 | object_selection_inference = ObjectSelectionInference(model_dir, dirs_cfg)
84 | test_dataset = object_selection_inference.dataset
85 |
86 | initial_scene_idxs = get_initial_scene_idxs(test_dataset)
87 |
88 | all_predictions = []
89 | all_gts = []
90 | all_goal_specifications = []
91 | all_sentences = []
92 |
93 | count = 0
94 | for idx in tqdm.tqdm(range(len(test_dataset))):
95 |
96 | if idx not in initial_scene_idxs:
97 | continue
98 |
99 | count += 1
100 |
101 | datum = test_dataset.get_raw_data(idx)
102 | goal_specification = datum["goal_specification"]
103 | object_selection_sentence = datum["sentence"][5:]
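        # drop the leading tokens of the sentence (presumably the structure/shape parameters) so that
        # only the object-selection part of the instruction is kept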
104 | reference_sentence = object_selection_inference.tokenizer.convert_to_natural_sentence(
105 | object_selection_sentence)
106 |
107 | predictions, gts = object_selection_inference.predict_target_objects(datum)
108 |
109 | if visualize:
110 | print(gts)
111 | print(predictions)
112 | print(object_selection_sentence)
113 | print(reference_sentence)
114 | print("Visualize groundtruth (dot color) and prediction (ring color)")
115 | show_pcs_with_predictions(datum["xyzs"][:len(gts)], datum["rgbs"][:len(gts)],
116 | gts, predictions, add_coordinate_frame=False)
117 |
118 | if inference_visualization_dir:
119 | buffer = show_pcs_with_predictions(datum["xyzs"][:len(gts)], datum["rgbs"][:len(gts)],
120 | gts, predictions, add_coordinate_frame=False, return_buffer=True)
121 | img = np.uint8(np.asarray(buffer) * 255)
122 | save_img(img, os.path.join(inference_visualization_dir, "scene_{}.png".format(idx)), text=reference_sentence)
123 |
124 | # try:
125 | # batch = test_dataset.collate_fn([test_dataset.convert_to_tensors(datum, tokenizer)])
126 | # except KeyError:
127 | # print("skipping this for now because we are using an outdated model with an old vocab")
128 | # continue
129 | #
130 | # goal_specification = datum["goal_specification"]
131 | #
132 | # gts, predictions = infer_once(model, batch, cfg.device)
133 | # gts = gts["rearrange_obj_labels"][0].detach().cpu().numpy()
134 | # predictions = predictions["rearrange_obj_labels"][0].detach().cpu().numpy()
135 | # predictions = predictions > 0.5
136 | # # remove paddings
137 | # target_mask = gts != -100
138 | # gts = gts[target_mask]
139 | # predictions = predictions[target_mask]
140 | #
141 | # object_selection_sentence = datum["sentence"][5:]
142 | #
143 | # if visualize:
144 | # print(gts)
145 | # print(predictions)
146 | # print(object_selection_sentence)
147 | # print(tokenizer.convert_to_natural_sentence(object_selection_sentence))
148 | #
149 | # if inference_visualization_dir is None:
150 | # show_pcs_with_predictions(datum["xyzs"][:len(gts)], datum["rgbs"][:len(gts)],
151 | # gts, predictions, add_coordinate_frame=False)
152 | # # show_pcs_with_only_predictions(datum["xyzs"][:len(gts)], datum["rgbs"][:len(gts)],
153 | # # gts, predictions, add_coordinate_frame=False)
154 | # # test_new_vis(datum["xyzs"][:len(gts)], datum["rgbs"][:len(gts)])
155 | # else:
156 | # save_filename = os.path.join(inference_visualization_dir, "{}.png".format(idx))
157 | # buffer = show_pcs_with_predictions(datum["xyzs"][:len(gts)], datum["rgbs"][:len(gts)],
158 | # gts, predictions, add_coordinate_frame=False, return_buffer=True)
159 | # img = np.uint8(np.asarray(buffer) * 255)
160 | # save_img(img, save_filename, text=tokenizer.convert_to_natural_sentence(datum["sentence"]))
161 |
162 | all_predictions.append(predictions)
163 | all_gts.append(gts)
164 | all_goal_specifications.append(goal_specification)
165 | all_sentences.append(object_selection_sentence)
166 |
167 | # create a more detailed report
168 | evaluate_target_object_predictions(all_gts, all_predictions, all_sentences, initial_scene_idxs,
169 | object_selection_inference.tokenizer)
170 |
171 |
172 | if __name__ == "__main__":
173 |     parser = argparse.ArgumentParser(description="Evaluate the object selection network")
174 | parser.add_argument("--dataset_base_dir", help='location of the dataset', type=str)
175 | parser.add_argument("--model_dir", help='location for the saved model', type=str)
176 | parser.add_argument("--dirs_config", help='config yaml file for directories', default="", type=str)
177 | parser.add_argument("--inference_visualization_dir", help='location for saving visualizations of inference results',
178 | type=str, default=None)
179 | parser.add_argument("--visualize", default=1, type=int, help='whether to visualize inference results while running')
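    # example invocation (paths are placeholders):
    #   python src/structformer/evaluation/test_object_selection_network.py \
    #       --dataset_base_dir <path_to_dataset> \
    #       --model_dir <path_to_trained_model> \
    #       --dirs_config configs/data/circle_dirs.yaml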
180 | args = parser.parse_args()
181 |
182 | os.environ["DATETIME"] = time.strftime("%Y%m%d-%H%M%S")
183 |
184 | # # debug only
185 | # args.dataset_base_dir = "/home/weiyu/data_drive/data_new_objects_test_split"
186 | # args.model_dir = "/home/weiyu/Research/intern/StructFormer/models/object_selection_network/best_model"
187 | # args.dirs_config = "/home/weiyu/Research/intern/StructFormer/structformer/configs/data/line_dirs.yaml"
188 | # args.visualize = True
189 |
190 | if args.dirs_config:
191 | assert os.path.exists(args.dirs_config), "Cannot find config yaml file at {}".format(args.dirs_config)
192 | dirs_cfg = OmegaConf.load(args.dirs_config)
193 | dirs_cfg.dataset_base_dir = args.dataset_base_dir
194 | OmegaConf.resolve(dirs_cfg)
195 | else:
196 | dirs_cfg = None
197 |
198 | inference(args.model_dir, dirs_cfg, args.visualize, args.inference_visualization_dir)
--------------------------------------------------------------------------------
/src/structformer/evaluation/test_structformer.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 | import os
4 | import copy
5 | import tqdm
6 | import argparse
7 | from omegaconf import OmegaConf
8 | import time
9 |
10 | from torch.utils.data import DataLoader
11 |
12 | from structformer.data.tokenizer import Tokenizer
13 | import structformer.data.sequence_dataset as prior_dataset
14 | import structformer.training.train_structformer as prior_model
15 | from structformer.utils.rearrangement import show_pcs
16 | from structformer.evaluation.inference import PointCloudRearrangement
17 |
18 |
19 | def test_model(model_dir, dirs_cfg):
20 | prior_inference = PriorInference(model_dir, dirs_cfg, data_split="test")
21 | prior_inference.validate()
22 |
23 |
24 | class PriorInference:
25 |
26 | def __init__(self, model_dir, dirs_cfg, data_split="test"):
27 |
28 | cfg, tokenizer, model, optimizer, scheduler, epoch = prior_model.load_model(model_dir, dirs_cfg)
29 |
30 | data_cfg = cfg.dataset
31 |
32 | dataset = prior_dataset.SequenceDataset(data_cfg.dirs, data_cfg.index_dirs, data_split, tokenizer,
33 | data_cfg.max_num_objects,
34 | data_cfg.max_num_other_objects,
35 | data_cfg.max_num_shape_parameters,
36 | data_cfg.max_num_rearrange_features,
37 | data_cfg.max_num_anchor_features,
38 | data_cfg.num_pts,
39 | data_cfg.use_structure_frame)
40 |
41 | self.cfg = cfg
42 | self.tokenizer = tokenizer
43 | self.model = model
44 | self.cfg = cfg
45 | self.dataset = dataset
46 | self.epoch = epoch
47 |
48 | def validate(self):
49 | """
50 | validate the pretrained model on the dataset
51 |
52 | :return:
53 | """
54 | data_cfg = self.cfg.dataset
55 | data_iter = DataLoader(self.dataset, batch_size=data_cfg.batch_size, shuffle=False,
56 | collate_fn=prior_dataset.SequenceDataset.collate_fn,
57 | pin_memory=data_cfg.pin_memory, num_workers=data_cfg.num_workers)
58 |
59 | prior_model.validate(self.cfg, self.model, data_iter, self.epoch, self.cfg.device)
60 |
61 | def limited_batch_inference(self, data, verbose=False):
62 |         """
63 |         Run batched inference over the given scenes. Assumes that every scene in the batch has
64 |         the same number of objects that need to be rearranged.
65 |
66 |         The data is split into chunks of cfg.dataset.batch_size; each chunk is converted to
67 |         tensors, collated, and passed through the model, and the chunk predictions are concatenated.
68 |
69 |         :param data: list of raw data dicts, one per scene
70 |         :param verbose: if True, print the data size and batch size
71 |         :return: struct_preds of shape (data_size, output_dim) and
72 |                  obj_preds of shape (data_size, max num objects, output_dim)
73 |
74 |         """
75 |
76 | data_size = len(data)
77 | batch_size = self.cfg.dataset.batch_size
78 | if verbose:
79 | print("data size:", data_size)
80 | print("batch size:", batch_size)
81 |
82 | num_batches = int(data_size / batch_size)
83 | if data_size % batch_size != 0:
84 | num_batches += 1
85 |
86 | all_obj_preds = []
87 | all_struct_preds = []
88 | for b in range(num_batches):
89 | if b + 1 == num_batches:
90 | # last batch
91 | batch = data[b * batch_size:]
92 | else:
93 | batch = data[b * batch_size: (b+1) * batch_size]
94 | data_tensors = [self.dataset.convert_to_tensors(d, self.tokenizer) for d in batch]
95 | data_tensors = self.dataset.collate_fn(data_tensors)
96 | predictions = prior_model.infer_once(self.cfg, self.model, data_tensors, self.cfg.device)
97 |
98 | obj_x_preds = torch.cat(predictions["obj_x_outputs"], dim=0)
99 | obj_y_preds = torch.cat(predictions["obj_y_outputs"], dim=0)
100 | obj_z_preds = torch.cat(predictions["obj_z_outputs"], dim=0)
101 | obj_theta_preds = torch.cat(predictions["obj_theta_outputs"], dim=0)
102 | obj_preds = torch.cat([obj_x_preds, obj_y_preds, obj_z_preds, obj_theta_preds], dim=1) # batch_size * max num objects, output_dim
103 |
104 | struct_x_preds = torch.cat(predictions["struct_x_inputs"], dim=0)
105 | struct_y_preds = torch.cat(predictions["struct_y_inputs"], dim=0)
106 | struct_z_preds = torch.cat(predictions["struct_z_inputs"], dim=0)
107 | struct_theta_preds = torch.cat(predictions["struct_theta_inputs"], dim=0)
108 | struct_preds = torch.cat([struct_x_preds, struct_y_preds, struct_z_preds, struct_theta_preds], dim=1) # batch_size, output_dim
109 |
110 | all_obj_preds.append(obj_preds)
111 | all_struct_preds.append(struct_preds)
112 |
113 | obj_preds = torch.cat(all_obj_preds, dim=0) # data_size * max num objects, output_dim
114 | struct_preds = torch.cat(all_struct_preds, dim=0) # data_size, output_dim
115 |
116 | obj_preds = obj_preds.detach().cpu().numpy()
117 | struct_preds = struct_preds.detach().cpu().numpy()
118 |
119 |         obj_preds = obj_preds.reshape(data_size, -1, obj_preds.shape[-1])  # data_size, max num objects, output_dim
120 |
121 | return struct_preds, obj_preds
122 |
123 |
124 | def inference_beam_decoding(model_dir, dirs_cfg, beam_size=100, max_scene_decodes=30000,
125 | visualize=True, visualize_action_sequence=False,
126 | inference_visualization_dir=None):
127 |     """
128 |     Decode goal rearrangements for test scenes with beam search and optionally visualize or save them.
129 |
130 |     :param model_dir: directory of the saved model
131 |     :param dirs_cfg: config overriding the dataset directories (may be None)
132 |     :param beam_size: number of rearrangements decoded for each scene
133 |     :param max_scene_decodes: maximum number of scenes to decode
134 |     :param visualize: if True, show the initial scene and the decoded goal scenes
135 |     :param visualize_action_sequence: unused in this script
136 |     :param inference_visualization_dir: if set, save visualizations of decoded goal scenes to this directory
137 |     """
138 |
139 | if inference_visualization_dir and not os.path.exists(inference_visualization_dir):
140 | os.makedirs(inference_visualization_dir)
141 |
142 | prior_inference = PriorInference(model_dir, dirs_cfg)
143 | test_dataset = prior_inference.dataset
144 |
145 | decoded_scene_count = 0
146 | with tqdm.tqdm(total=len(test_dataset)) as pbar:
147 | # for idx in np.random.choice(range(len(test_dataset)), len(test_dataset), replace=False):
148 | for idx in range(len(test_dataset)):
149 |
150 | if decoded_scene_count == max_scene_decodes:
151 | break
152 |
153 | filename = test_dataset.get_data_index(idx)
154 | scene_id = os.path.split(filename)[1][4:-3]
155 |
156 | decoded_scene_count += 1
157 |
158 | ############################################
159 | # retrieve data
160 | beam_data = []
161 | beam_pc_rearrangements = []
162 | for b in range(beam_size):
163 | datum = test_dataset.get_raw_data(idx, inference_mode=True, shuffle_object_index=False)
164 |
165 | # not necessary, but just to ensure no test leakage
166 | datum["struct_x_inputs"] = [0]
167 | datum["struct_y_inputs"] = [0]
168 |                 datum["struct_z_inputs"] = [0]
169 | datum["struct_theta_inputs"] = [[0] * 9]
170 | for obj_idx in range(len(datum["obj_x_inputs"])):
171 | datum["obj_x_inputs"][obj_idx] = 0
172 | datum["obj_y_inputs"][obj_idx] = 0
173 | datum["obj_z_inputs"][obj_idx] = 0
174 | datum["obj_theta_inputs"][obj_idx] = [0] * 9
175 |
176 | # We can play with different language here
177 | # datum["sentence"] = modify_language(datum["sentence"], radius=0.5)
178 | # datum["sentence"] = modify_language(datum["sentence"], position_x=1)
179 | # datum["sentence"] = modify_language(datum["sentence"], position_y=0.5)
180 |
181 | beam_data.append(datum)
182 | beam_pc_rearrangements.append(PointCloudRearrangement(datum))
183 |
184 | if visualize:
185 | datum = beam_data[0]
186 | print("#"*50)
187 | print("sentence", datum["sentence"])
188 | show_pcs(datum["xyzs"] + datum["other_xyzs"], datum["rgbs"] + datum["other_rgbs"],
189 | add_coordinate_frame=False, side_view=True, add_table=True)
190 |
191 | ############################################
192 | # autoregressive decoding
193 | num_target_objects = beam_pc_rearrangements[0].num_target_objects
194 | # first predict structure pose
195 | beam_goal_struct_pose, target_object_preds = prior_inference.limited_batch_inference(beam_data)
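            # write the predicted structure pose back into each beam's datum so that the following
            # autoregressive decoding steps condition on it; the same pattern is used for object poses below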
196 | for b in range(beam_size):
197 | datum = beam_data[b]
198 | datum["struct_x_inputs"] = [beam_goal_struct_pose[b][0]]
199 | datum["struct_y_inputs"] = [beam_goal_struct_pose[b][1]]
200 | datum["struct_z_inputs"] = [beam_goal_struct_pose[b][2]]
201 | datum["struct_theta_inputs"] = [beam_goal_struct_pose[b][3:]]
202 |
203 | # then iteratively predict pose of each object
204 | beam_goal_obj_poses = []
205 | for obj_idx in range(num_target_objects):
206 | struct_preds, target_object_preds = prior_inference.limited_batch_inference(beam_data)
207 | beam_goal_obj_poses.append(target_object_preds[:, obj_idx])
208 | for b in range(beam_size):
209 | datum = beam_data[b]
210 | datum["obj_x_inputs"][obj_idx] = target_object_preds[b][obj_idx][0]
211 | datum["obj_y_inputs"][obj_idx] = target_object_preds[b][obj_idx][1]
212 | datum["obj_z_inputs"][obj_idx] = target_object_preds[b][obj_idx][2]
213 | datum["obj_theta_inputs"][obj_idx] = target_object_preds[b][obj_idx][3:]
214 | # concat in the object dim
215 | beam_goal_obj_poses = np.stack(beam_goal_obj_poses, axis=0)
216 | # swap axis
217 | beam_goal_obj_poses = np.swapaxes(beam_goal_obj_poses, 1, 0) # batch size, number of target objects, pose dim
218 |
219 | ############################################
220 | # move pc
221 | for bi in range(beam_size):
222 | beam_pc_rearrangements[bi].set_goal_poses(beam_goal_struct_pose[bi], beam_goal_obj_poses[bi])
223 | beam_pc_rearrangements[bi].rearrange()
224 |
225 | ############################################
226 | if visualize:
227 | for pc_rearrangement in beam_pc_rearrangements:
228 | pc_rearrangement.visualize("goal", add_other_objects=True,
229 | add_coordinate_frame=False, side_view=True, add_table=True)
230 |
231 | if inference_visualization_dir:
232 | for pc_rearrangement in beam_pc_rearrangements:
233 | pc_rearrangement.visualize("goal", add_other_objects=True,
234 | add_coordinate_frame=False, side_view=True, add_table=True,
235 | save_vis=True,
236 | save_filename=os.path.join(inference_visualization_dir, "{}.jpg".format(scene_id)))
237 |
238 | pbar.update(1)
239 |
240 |
241 | if __name__ == "__main__":
242 |     parser = argparse.ArgumentParser(description="Evaluate StructFormer with beam-search decoding")
243 | parser.add_argument("--dataset_base_dir", help='location of the dataset', type=str)
244 | parser.add_argument("--model_dir", help='location for the saved model', type=str)
245 | parser.add_argument("--dirs_config", help='config yaml file for directories', default="", type=str)
246 | args = parser.parse_args()
247 |
248 | os.environ["DATETIME"] = time.strftime("%Y%m%d-%H%M%S")
249 |
250 | # # debug only
251 | # args.dataset_base_dir = "/home/weiyu/data_drive/data_new_objects_test_split"
252 | # args.model_dir = "/home/weiyu/Research/intern/StructFormer/models/structformer_line/best_model"
253 | # args.dirs_config = "/home/weiyu/Research/intern/StructFormer/structformer/configs/data/line_dirs.yaml"
254 |
255 | if args.dirs_config:
256 | assert os.path.exists(args.dirs_config), "Cannot find config yaml file at {}".format(args.dirs_config)
257 | dirs_cfg = OmegaConf.load(args.dirs_config)
258 | dirs_cfg.dataset_base_dir = args.dataset_base_dir
259 | OmegaConf.resolve(dirs_cfg)
260 | else:
261 | dirs_cfg = None
262 |
263 | inference_beam_decoding(args.model_dir, dirs_cfg, beam_size=3, max_scene_decodes=30000,
264 | visualize=True, visualize_action_sequence=False,
265 | inference_visualization_dir=None)
266 |
--------------------------------------------------------------------------------
/src/structformer/evaluation/test_structformer_no_encoder.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 | import os
4 | import copy
5 | import tqdm
6 | import argparse
7 | from omegaconf import OmegaConf
8 | import time
9 | from torch.utils.data import DataLoader
10 |
11 | from structformer.data.tokenizer import Tokenizer
12 | import structformer.data.sequence_dataset as prior_dataset
13 | import structformer.training.train_structformer_no_encoder as prior_model
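# this script mirrors test_structformer.py; the only difference is that it loads the
# "no encoder" ablation via train_structformer_no_encoder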
14 | from structformer.utils.rearrangement import show_pcs
15 | from structformer.evaluation.inference import PointCloudRearrangement
16 |
17 |
18 | def test_model(model_dir, dirs_cfg):
19 | prior_inference = PriorInference(model_dir, dirs_cfg, data_split="test")
20 | prior_inference.validate()
21 |
22 |
23 | class PriorInference:
24 |
25 | def __init__(self, model_dir, dirs_cfg, data_split="test"):
26 |
27 | cfg, tokenizer, model, optimizer, scheduler, epoch = prior_model.load_model(model_dir, dirs_cfg)
28 |
29 | data_cfg = cfg.dataset
30 |
31 | dataset = prior_dataset.SequenceDataset(data_cfg.dirs, data_cfg.index_dirs, data_split, tokenizer,
32 | data_cfg.max_num_objects,
33 | data_cfg.max_num_other_objects,
34 | data_cfg.max_num_shape_parameters,
35 | data_cfg.max_num_rearrange_features,
36 | data_cfg.max_num_anchor_features,
37 | data_cfg.num_pts,
38 | data_cfg.use_structure_frame)
39 |
40 | self.cfg = cfg
41 | self.tokenizer = tokenizer
42 | self.model = model
43 | self.cfg = cfg
44 | self.dataset = dataset
45 | self.epoch = epoch
46 |
47 | def validate(self):
48 | """
49 | validate the pretrained model on the dataset
50 |
51 | :return:
52 | """
53 | data_cfg = self.cfg.dataset
54 | data_iter = DataLoader(self.dataset, batch_size=data_cfg.batch_size, shuffle=False,
55 | collate_fn=prior_dataset.SequenceDataset.collate_fn,
56 | pin_memory=data_cfg.pin_memory, num_workers=data_cfg.num_workers)
57 |
58 | prior_model.validate(self.cfg, self.model, data_iter, self.epoch, self.cfg.device)
59 |
60 | def limited_batch_inference(self, data, verbose=False):
61 |         """
62 |         Run batched inference over the given scenes. Assumes that every scene in the batch has
63 |         the same number of objects that need to be rearranged.
64 |
65 |         The data is split into chunks of cfg.dataset.batch_size; each chunk is converted to
66 |         tensors, collated, and passed through the model, and the chunk predictions are concatenated.
67 |
68 |         :param data: list of raw data dicts, one per scene
69 |         :param verbose: if True, print the data size and batch size
70 |         :return: struct_preds of shape (data_size, output_dim) and
71 |                  obj_preds of shape (data_size, max num objects, output_dim)
72 |
73 |         """
74 |
75 | data_size = len(data)
76 | batch_size = self.cfg.dataset.batch_size
77 | if verbose:
78 | print("data size:", data_size)
79 | print("batch size:", batch_size)
80 |
81 | num_batches = int(data_size / batch_size)
82 | if data_size % batch_size != 0:
83 | num_batches += 1
84 |
85 | all_obj_preds = []
86 | all_struct_preds = []
87 | for b in range(num_batches):
88 | if b + 1 == num_batches:
89 | # last batch
90 | batch = data[b * batch_size:]
91 | else:
92 | batch = data[b * batch_size: (b+1) * batch_size]
93 | data_tensors = [self.dataset.convert_to_tensors(d, self.tokenizer) for d in batch]
94 | data_tensors = self.dataset.collate_fn(data_tensors)
95 | predictions = prior_model.infer_once(self.cfg, self.model, data_tensors, self.cfg.device)
96 |
97 | obj_x_preds = torch.cat(predictions["obj_x_outputs"], dim=0)
98 | obj_y_preds = torch.cat(predictions["obj_y_outputs"], dim=0)
99 | obj_z_preds = torch.cat(predictions["obj_z_outputs"], dim=0)
100 | obj_theta_preds = torch.cat(predictions["obj_theta_outputs"], dim=0)
101 | obj_preds = torch.cat([obj_x_preds, obj_y_preds, obj_z_preds, obj_theta_preds], dim=1) # batch_size * max num objects, output_dim
102 |
103 | struct_x_preds = torch.cat(predictions["struct_x_inputs"], dim=0)
104 | struct_y_preds = torch.cat(predictions["struct_y_inputs"], dim=0)
105 | struct_z_preds = torch.cat(predictions["struct_z_inputs"], dim=0)
106 | struct_theta_preds = torch.cat(predictions["struct_theta_inputs"], dim=0)
107 | struct_preds = torch.cat([struct_x_preds, struct_y_preds, struct_z_preds, struct_theta_preds], dim=1) # batch_size, output_dim
108 |
109 | all_obj_preds.append(obj_preds)
110 | all_struct_preds.append(struct_preds)
111 |
112 | obj_preds = torch.cat(all_obj_preds, dim=0) # data_size * max num objects, output_dim
113 | struct_preds = torch.cat(all_struct_preds, dim=0) # data_size, output_dim
114 |
115 | obj_preds = obj_preds.detach().cpu().numpy()
116 | struct_preds = struct_preds.detach().cpu().numpy()
117 |
118 |         obj_preds = obj_preds.reshape(data_size, -1, obj_preds.shape[-1])  # data_size, max num objects, output_dim
119 |
120 | return struct_preds, obj_preds
121 |
122 |
123 | def inference_beam_decoding(model_dir, dirs_cfg, beam_size=100, max_scene_decodes=30000,
124 | visualize=True, visualize_action_sequence=False,
125 | inference_visualization_dir=None):
126 |     """
127 |     Decode goal rearrangements for test scenes with beam search and optionally visualize or save them.
128 |
129 |     :param model_dir: directory of the saved model
130 |     :param dirs_cfg: config overriding the dataset directories (may be None)
131 |     :param beam_size: number of rearrangements decoded for each scene
132 |     :param max_scene_decodes: maximum number of scenes to decode
133 |     :param visualize: if True, show the initial scene and the decoded goal scenes
134 |     :param visualize_action_sequence: unused in this script
135 |     :param inference_visualization_dir: if set, save visualizations of decoded goal scenes to this directory
136 |     """
137 |
138 | if inference_visualization_dir and not os.path.exists(inference_visualization_dir):
139 | os.makedirs(inference_visualization_dir)
140 |
141 | prior_inference = PriorInference(model_dir, dirs_cfg)
142 | test_dataset = prior_inference.dataset
143 |
144 | decoded_scene_count = 0
145 | with tqdm.tqdm(total=len(test_dataset)) as pbar:
146 | # for idx in np.random.choice(range(len(test_dataset)), len(test_dataset), replace=False):
147 | for idx in range(len(test_dataset)):
148 |
149 | if decoded_scene_count == max_scene_decodes:
150 | break
151 |
152 | filename = test_dataset.get_data_index(idx)
153 | scene_id = os.path.split(filename)[1][4:-3]
154 |
155 | decoded_scene_count += 1
156 |
157 | ############################################
158 | # retrieve data
159 | beam_data = []
160 | beam_pc_rearrangements = []
161 | for b in range(beam_size):
162 | datum = test_dataset.get_raw_data(idx, inference_mode=True, shuffle_object_index=False)
163 |
164 | # not necessary, but just to ensure no test leakage
165 | datum["struct_x_inputs"] = [0]
166 | datum["struct_y_inputs"] = [0]
167 |                 datum["struct_z_inputs"] = [0]
168 | datum["struct_theta_inputs"] = [[0] * 9]
169 | for obj_idx in range(len(datum["obj_x_inputs"])):
170 | datum["obj_x_inputs"][obj_idx] = 0
171 | datum["obj_y_inputs"][obj_idx] = 0
172 | datum["obj_z_inputs"][obj_idx] = 0
173 | datum["obj_theta_inputs"][obj_idx] = [0] * 9
174 |
175 | # We can play with different language here
176 | # datum["sentence"] = modify_language(datum["sentence"], radius=0.5)
177 | # datum["sentence"] = modify_language(datum["sentence"], position_x=1)
178 | # datum["sentence"] = modify_language(datum["sentence"], position_y=0.5)
179 |
180 | beam_data.append(datum)
181 | beam_pc_rearrangements.append(PointCloudRearrangement(datum))
182 |
183 | if visualize:
184 | datum = beam_data[0]
185 | print("#"*50)
186 | print("sentence", datum["sentence"])
187 | show_pcs(datum["xyzs"] + datum["other_xyzs"], datum["rgbs"] + datum["other_rgbs"],
188 | add_coordinate_frame=False, side_view=True, add_table=True)
189 |
190 | ############################################
191 | # autoregressive decoding
192 | num_target_objects = beam_pc_rearrangements[0].num_target_objects
193 | # first predict structure pose
194 | beam_goal_struct_pose, target_object_preds = prior_inference.limited_batch_inference(beam_data)
195 | for b in range(beam_size):
196 | datum = beam_data[b]
197 | datum["struct_x_inputs"] = [beam_goal_struct_pose[b][0]]
198 | datum["struct_y_inputs"] = [beam_goal_struct_pose[b][1]]
199 | datum["struct_z_inputs"] = [beam_goal_struct_pose[b][2]]
200 | datum["struct_theta_inputs"] = [beam_goal_struct_pose[b][3:]]
201 |
202 | # then iteratively predict pose of each object
203 | beam_goal_obj_poses = []
204 | for obj_idx in range(num_target_objects):
205 | struct_preds, target_object_preds = prior_inference.limited_batch_inference(beam_data)
206 | beam_goal_obj_poses.append(target_object_preds[:, obj_idx])
207 | for b in range(beam_size):
208 | datum = beam_data[b]
209 | datum["obj_x_inputs"][obj_idx] = target_object_preds[b][obj_idx][0]
210 | datum["obj_y_inputs"][obj_idx] = target_object_preds[b][obj_idx][1]
211 | datum["obj_z_inputs"][obj_idx] = target_object_preds[b][obj_idx][2]
212 | datum["obj_theta_inputs"][obj_idx] = target_object_preds[b][obj_idx][3:]
213 | # concat in the object dim
214 | beam_goal_obj_poses = np.stack(beam_goal_obj_poses, axis=0)
215 | # swap axis
216 | beam_goal_obj_poses = np.swapaxes(beam_goal_obj_poses, 1, 0) # batch size, number of target objects, pose dim
217 |
218 | ############################################
219 | # move pc
220 | for bi in range(beam_size):
221 | beam_pc_rearrangements[bi].set_goal_poses(beam_goal_struct_pose[bi], beam_goal_obj_poses[bi])
222 | beam_pc_rearrangements[bi].rearrange()
223 |
224 | ############################################
225 | if visualize:
226 | for pc_rearrangement in beam_pc_rearrangements:
227 | pc_rearrangement.visualize("goal", add_other_objects=True,
228 | add_coordinate_frame=False, side_view=True, add_table=True)
229 |
230 | if inference_visualization_dir:
231 | for pc_rearrangement in beam_pc_rearrangements:
232 | pc_rearrangement.visualize("goal", add_other_objects=True,
233 | add_coordinate_frame=False, side_view=True, add_table=True,
234 | save_vis=True,
235 | save_filename=os.path.join(inference_visualization_dir, "{}.jpg".format(scene_id)))
236 |
237 | pbar.update(1)
238 |
239 |
240 | if __name__ == "__main__":
241 |     parser = argparse.ArgumentParser(description="Evaluate structformer_no_encoder with beam-search decoding")
242 | parser.add_argument("--dataset_base_dir", help='location of the dataset', type=str)
243 | parser.add_argument("--model_dir", help='location for the saved model', type=str)
244 | parser.add_argument("--dirs_config", help='config yaml file for directories', default="", type=str)
245 | args = parser.parse_args()
246 |
247 | os.environ["DATETIME"] = time.strftime("%Y%m%d-%H%M%S")
248 |
249 | # # debug only
250 | # args.dataset_base_dir = "/home/weiyu/data_drive/data_new_objects_test_split"
251 | # args.model_dir = "/home/weiyu/Research/intern/StructFormer/models/structformer_no_encoder_tower/best_model"
252 | # args.dirs_config = "/home/weiyu/Research/intern/StructFormer/structformer/configs/data/tower_dirs.yaml"
253 |
254 | if args.dirs_config:
255 | assert os.path.exists(args.dirs_config), "Cannot find config yaml file at {}".format(args.dirs_config)
256 | dirs_cfg = OmegaConf.load(args.dirs_config)
257 | dirs_cfg.dataset_base_dir = args.dataset_base_dir
258 | OmegaConf.resolve(dirs_cfg)
259 | else:
260 | dirs_cfg = None
261 |
262 | inference_beam_decoding(args.model_dir, dirs_cfg, beam_size=3, max_scene_decodes=30000,
263 | visualize=True, visualize_action_sequence=False,
264 | inference_visualization_dir=None)
265 |
266 |
--------------------------------------------------------------------------------
/src/structformer/evaluation/test_structformer_no_structure.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 | import os
4 | import copy
5 | import tqdm
6 | import argparse
7 | from omegaconf import OmegaConf
8 | import time
9 | from torch.utils.data import DataLoader
10 |
11 | from structformer.data.tokenizer import Tokenizer
12 | import structformer.data.sequence_dataset as prior_dataset
13 | import structformer.training.train_structformer_no_structure as prior_model
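# this ablation does not predict a structure frame: limited_batch_inference below returns only
# object poses, and PointCloudRearrangement is created with use_structure_frame=False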
14 | from structformer.utils.rearrangement import show_pcs
15 | from structformer.evaluation.inference import PointCloudRearrangement
16 |
17 |
18 | def test_model(model_dir, dirs_cfg):
19 | prior_inference = PriorInference(model_dir, dirs_cfg, data_split="test")
20 | prior_inference.validate()
21 |
22 |
23 | class PriorInference:
24 |
25 | def __init__(self, model_dir, dirs_cfg, data_split="test"):
26 |
27 | cfg, tokenizer, model, optimizer, scheduler, epoch = prior_model.load_model(model_dir, dirs_cfg)
28 |
29 | data_cfg = cfg.dataset
30 |
31 | dataset = prior_dataset.SequenceDataset(data_cfg.dirs, data_cfg.index_dirs, data_split, tokenizer,
32 | data_cfg.max_num_objects,
33 | data_cfg.max_num_other_objects,
34 | data_cfg.max_num_shape_parameters,
35 | data_cfg.max_num_rearrange_features,
36 | data_cfg.max_num_anchor_features,
37 | data_cfg.num_pts,
38 | data_cfg.use_structure_frame)
39 |
40 | self.cfg = cfg
41 | self.tokenizer = tokenizer
42 | self.model = model
43 | self.cfg = cfg
44 | self.dataset = dataset
45 | self.epoch = epoch
46 |
47 | def validate(self):
48 | """
49 | validate the pretrained model on the dataset
50 |
51 | :return:
52 | """
53 | data_cfg = self.cfg.dataset
54 | data_iter = DataLoader(self.dataset, batch_size=data_cfg.batch_size, shuffle=False,
55 | collate_fn=prior_dataset.SequenceDataset.collate_fn,
56 | pin_memory=data_cfg.pin_memory, num_workers=data_cfg.num_workers)
57 |
58 | prior_model.validate(self.cfg, self.model, data_iter, self.epoch, self.cfg.device)
59 |
60 | def limited_batch_inference(self, data, verbose=False):
61 |         """
62 |         Run batched inference over the given scenes. Assumes that every scene in the batch has
63 |         the same number of objects that need to be rearranged.
64 |
65 |         The data is split into chunks of cfg.dataset.batch_size; each chunk is converted to
66 |         tensors, collated, and passed through the model, and the chunk predictions are concatenated.
67 |
68 |         :param data: list of raw data dicts, one per scene
69 |         :param verbose: if True, print the data size and batch size
70 |         :return: obj_preds of shape (data_size, max num objects, output_dim); this ablation does
71 |                  not predict a structure pose
72 |
73 |         """
74 |
75 | data_size = len(data)
76 | batch_size = self.cfg.dataset.batch_size
77 | if verbose:
78 | print("data size:", data_size)
79 | print("batch size:", batch_size)
80 |
81 | num_batches = int(data_size / batch_size)
82 | if data_size % batch_size != 0:
83 | num_batches += 1
84 |
85 | all_obj_preds = []
86 | for b in range(num_batches):
87 | if b + 1 == num_batches:
88 | # last batch
89 | batch = data[b * batch_size:]
90 | else:
91 | batch = data[b * batch_size: (b+1) * batch_size]
92 | data_tensors = [self.dataset.convert_to_tensors(d, self.tokenizer) for d in batch]
93 | data_tensors = self.dataset.collate_fn(data_tensors)
94 | predictions = prior_model.infer_once(self.cfg, self.model, data_tensors, self.cfg.device)
95 |
96 | obj_x_preds = torch.cat(predictions["obj_x_outputs"], dim=0)
97 | obj_y_preds = torch.cat(predictions["obj_y_outputs"], dim=0)
98 | obj_z_preds = torch.cat(predictions["obj_z_outputs"], dim=0)
99 | obj_theta_preds = torch.cat(predictions["obj_theta_outputs"], dim=0)
100 | obj_preds = torch.cat([obj_x_preds, obj_y_preds, obj_z_preds, obj_theta_preds], dim=1) # batch_size * max num objects, output_dim
101 |
102 | all_obj_preds.append(obj_preds)
103 |
104 | obj_preds = torch.cat(all_obj_preds, dim=0) # data_size * max num objects, output_dim
105 |
106 | obj_preds = obj_preds.detach().cpu().numpy()
107 |
108 |         obj_preds = obj_preds.reshape(data_size, -1, obj_preds.shape[-1])  # data_size, max num objects, output_dim
109 |
110 | return obj_preds
111 |
112 |
113 | def inference_beam_decoding(model_dir, dirs_cfg, beam_size=100, max_scene_decodes=30000,
114 | visualize=True, visualize_action_sequence=False,
115 | inference_visualization_dir=None):
116 |     """
117 |     Decode goal rearrangements for test scenes with beam search and optionally visualize or save them.
118 |
119 |     :param model_dir: directory of the saved model
120 |     :param dirs_cfg: config overriding the dataset directories (may be None)
121 |     :param beam_size: number of rearrangements decoded for each scene
122 |     :param max_scene_decodes: maximum number of scenes to decode
123 |     :param visualize: if True, show the initial scene and the decoded goal scenes
124 |     :param visualize_action_sequence: unused in this script
125 |     :param inference_visualization_dir: if set, save visualizations of decoded goal scenes to this directory
126 |     """
127 |
128 | if inference_visualization_dir and not os.path.exists(inference_visualization_dir):
129 | os.makedirs(inference_visualization_dir)
130 |
131 | prior_inference = PriorInference(model_dir, dirs_cfg)
132 | test_dataset = prior_inference.dataset
133 |
134 | decoded_scene_count = 0
135 | with tqdm.tqdm(total=len(test_dataset)) as pbar:
136 | # for idx in np.random.choice(range(len(test_dataset)), len(test_dataset), replace=False):
137 | for idx in range(len(test_dataset)):
138 |
139 | if decoded_scene_count == max_scene_decodes:
140 | break
141 |
142 | filename = test_dataset.get_data_index(idx)
143 | scene_id = os.path.split(filename)[1][4:-3]
144 |
145 | decoded_scene_count += 1
146 |
147 | ############################################
148 | # retrieve data
149 | beam_data = []
150 | beam_pc_rearrangements = []
151 | for b in range(beam_size):
152 | datum = test_dataset.get_raw_data(idx, inference_mode=True, shuffle_object_index=False)
153 |
154 | # not necessary, but just to ensure no test leakage
155 | for obj_idx in range(len(datum["obj_x_inputs"])):
156 | datum["obj_x_inputs"][obj_idx] = 0
157 | datum["obj_y_inputs"][obj_idx] = 0
158 | datum["obj_z_inputs"][obj_idx] = 0
159 | datum["obj_theta_inputs"][obj_idx] = [0] * 9
160 |
161 | # We can play with different language here
162 | # datum["sentence"] = modify_language(datum["sentence"], radius=0.5)
163 | # datum["sentence"] = modify_language(datum["sentence"], position_x=1)
164 | # datum["sentence"] = modify_language(datum["sentence"], position_y=0.5)
165 |
166 | beam_data.append(datum)
167 | beam_pc_rearrangements.append(PointCloudRearrangement(datum, use_structure_frame=False))
168 |
169 | if visualize:
170 | datum = beam_data[0]
171 | print("#"*50)
172 | print("sentence", datum["sentence"])
173 | show_pcs(datum["xyzs"] + datum["other_xyzs"], datum["rgbs"] + datum["other_rgbs"],
174 | add_coordinate_frame=False, side_view=True, add_table=True)
175 |
176 | ############################################
177 | # autoregressive decoding
178 | num_target_objects = beam_pc_rearrangements[0].num_target_objects
179 | # iteratively predict pose of each object
180 | beam_goal_obj_poses = []
181 | for obj_idx in range(num_target_objects):
182 | target_object_preds = prior_inference.limited_batch_inference(beam_data)
183 | beam_goal_obj_poses.append(target_object_preds[:, obj_idx])
184 | for b in range(beam_size):
185 | datum = beam_data[b]
186 | datum["obj_x_inputs"][obj_idx] = target_object_preds[b][obj_idx][0]
187 | datum["obj_y_inputs"][obj_idx] = target_object_preds[b][obj_idx][1]
188 | datum["obj_z_inputs"][obj_idx] = target_object_preds[b][obj_idx][2]
189 | datum["obj_theta_inputs"][obj_idx] = target_object_preds[b][obj_idx][3:]
190 | # concat in the object dim
191 | beam_goal_obj_poses = np.stack(beam_goal_obj_poses, axis=0)
192 | # swap axis
193 | beam_goal_obj_poses = np.swapaxes(beam_goal_obj_poses, 1, 0) # batch size, number of target objects, pose dim
194 |
195 | ############################################
196 | # move pc
197 | for bi in range(beam_size):
198 | beam_pc_rearrangements[bi].set_goal_poses(None, beam_goal_obj_poses[bi])
199 | beam_pc_rearrangements[bi].rearrange()
200 |
201 | ############################################
202 | if visualize:
203 | for pc_rearrangement in beam_pc_rearrangements:
204 | pc_rearrangement.visualize("goal", add_other_objects=True,
205 | add_coordinate_frame=False, side_view=True, add_table=True)
206 |
207 | if inference_visualization_dir:
208 | for pc_rearrangement in beam_pc_rearrangements:
209 | pc_rearrangement.visualize("goal", add_other_objects=True,
210 | add_coordinate_frame=False, side_view=True, add_table=True,
211 | save_vis=True,
212 | save_filename=os.path.join(inference_visualization_dir, "{}.jpg".format(scene_id)))
213 |
214 | pbar.update(1)
215 |
216 |
217 | if __name__ == "__main__":
218 |     parser = argparse.ArgumentParser(description="Evaluate structformer_no_structure with beam-search decoding")
219 | parser.add_argument("--dataset_base_dir", help='location of the dataset', type=str)
220 | parser.add_argument("--model_dir", help='location for the saved model', type=str)
221 | parser.add_argument("--dirs_config", help='config yaml file for directories', default="", type=str)
222 | args = parser.parse_args()
223 |
224 | os.environ["DATETIME"] = time.strftime("%Y%m%d-%H%M%S")
225 |
226 | # # debug only
227 | # args.dataset_base_dir = "/home/weiyu/data_drive/data_new_objects_test_split"
228 | # args.model_dir = "/home/weiyu/Research/intern/StructFormer/models/structformer_no_structure_dinner/best_model"
229 | # args.dirs_config = "/home/weiyu/Research/intern/StructFormer/structformer/configs/data/dinner_dirs.yaml"
230 |
231 | if args.dirs_config:
232 | assert os.path.exists(args.dirs_config), "Cannot find config yaml file at {}".format(args.dirs_config)
233 | dirs_cfg = OmegaConf.load(args.dirs_config)
234 | dirs_cfg.dataset_base_dir = args.dataset_base_dir
235 | OmegaConf.resolve(dirs_cfg)
236 | else:
237 | dirs_cfg = None
238 |
239 | inference_beam_decoding(args.model_dir, dirs_cfg, beam_size=3, max_scene_decodes=30000,
240 | visualize=True, visualize_action_sequence=False,
241 | inference_visualization_dir=None)
242 |
243 |
244 |
--------------------------------------------------------------------------------
/src/structformer/models/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wliu88/StructFormer/a81d8c83313f1f24a75ca83aa9dd9f8cfcf27419/src/structformer/models/__init__.py
--------------------------------------------------------------------------------
/src/structformer/models/object_selection_network.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 |
3 | import torch
4 | import torch.nn as nn
5 | import torch.nn.functional as F
6 | from torch.nn import TransformerEncoder, TransformerEncoderLayer
7 |
8 | from structformer.models.point_transformer import PointTransformerEncoderSmall
9 |
10 |
11 | class FocalLoss(nn.Module):
12 | "Focal Loss"
13 |
14 | def __init__(self, gamma=2, alpha=.25):
15 | super(FocalLoss, self).__init__()
16 | # self.alpha = torch.tensor([alpha, 1-alpha])
17 | self.gamma = gamma
18 |
19 | def forward(self, inputs, targets):
20 | BCE_loss = F.binary_cross_entropy_with_logits(inputs, targets, reduction='none')
21 | pt = torch.exp(-BCE_loss)
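        # pt approximates the predicted probability of the true class; the (1 - pt)**gamma factor
        # below down-weights well-classified examples so training focuses on the hard ones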
22 | # targets = targets.type(torch.long)
23 | # at = self.alpha.gather(0, targets.data.view(-1))
24 | # F_loss = at*(1-pt)**self.gamma * BCE_loss
25 | F_loss = (1 - pt)**self.gamma * BCE_loss
26 | return F_loss.mean()
27 |
28 |
29 | class EncoderMLP(torch.nn.Module):
30 | def __init__(self, in_dim, out_dim, pt_dim=3, uses_pt=True):
31 | super(EncoderMLP, self).__init__()
32 | self.uses_pt = uses_pt
33 | self.output = out_dim
34 | d5 = int(in_dim)
35 | d6 = int(2 * self.output)
36 | d7 = self.output
37 | self.encode_position = nn.Sequential(
38 | nn.Linear(pt_dim, in_dim),
39 | nn.LayerNorm(in_dim),
40 | nn.ReLU(),
41 | nn.Linear(in_dim, in_dim),
42 | nn.LayerNorm(in_dim),
43 | nn.ReLU(),
44 | )
45 | d5 = 2 * in_dim if self.uses_pt else in_dim
46 | self.fc_block = nn.Sequential(
47 | nn.Linear(int(d5), d6),
48 | nn.LayerNorm(int(d6)),
49 | nn.ReLU(),
50 | nn.Linear(int(d6), d6),
51 | nn.LayerNorm(int(d6)),
52 | nn.ReLU(),
53 | nn.Linear(d6, d7))
54 |
55 | def forward(self, x, pt=None):
56 | if self.uses_pt:
57 | if pt is None: raise RuntimeError('did not provide pt')
58 | y = self.encode_position(pt)
59 | x = torch.cat([x, y], dim=-1)
60 | return self.fc_block(x)
61 |
62 |
63 | class RearrangeObjectsPredictorPCT(torch.nn.Module):
64 |
65 | def __init__(self, vocab_size,
66 | num_attention_heads=8, encoder_hidden_dim=16, encoder_dropout=0.1, encoder_activation="relu", encoder_num_layers=8,
67 | use_focal_loss=False, focal_loss_gamma=2):
68 | super(RearrangeObjectsPredictorPCT, self).__init__()
69 |
70 | print("Object Selection Network with Point Transformer")
71 |
72 | # object encode will have dim 256
73 | self.object_encoder = PointTransformerEncoderSmall(output_dim=256, input_dim=6, mean_center=True)
74 |
75 | # 256 = 240 (point cloud) + 8 (position idx) + 8 (token type)
76 | self.mlp = EncoderMLP(256, 240, uses_pt=False)
77 |
78 | self.word_embeddings = torch.nn.Embedding(vocab_size, 240, padding_idx=0)
79 | self.token_type_embeddings = torch.nn.Embedding(2, 8)
80 | self.position_embeddings = torch.nn.Embedding(11, 8)
81 |
82 | encoder_layers = TransformerEncoderLayer(256, num_attention_heads,
83 | encoder_hidden_dim, encoder_dropout, encoder_activation)
84 | self.encoder = TransformerEncoder(encoder_layers, encoder_num_layers)
85 |
86 | self.rearrange_object_fier = nn.Sequential(nn.Linear(256, 256),
87 | nn.LayerNorm(256),
88 | nn.ReLU(),
89 | nn.Linear(256, 128),
90 | nn.LayerNorm(128),
91 | nn.ReLU(),
92 | nn.Linear(128, 1))
93 |
94 | ###########################
95 | if use_focal_loss:
96 | print("use focal loss")
97 | self.loss = FocalLoss(gamma=focal_loss_gamma)
98 | else:
99 | print("use standard BCE logit loss")
100 | self.loss = torch.nn.BCEWithLogitsLoss(reduction="mean")
101 |
102 | def forward(self, xyzs, rgbs, object_pad_mask, sentence, sentence_pad_mask, token_type_index, position_index):
103 |
104 | batch_size = object_pad_mask.shape[0]
105 | num_objects = object_pad_mask.shape[1]
106 |
107 | #########################
108 | center_xyz, x = self.object_encoder(xyzs, rgbs)
109 | x = self.mlp(x, center_xyz)
110 | x = x.reshape(batch_size, num_objects, -1)
111 |
112 | #########################
113 | sentence = self.word_embeddings(sentence)
114 |
115 | #########################
116 | position_embed = self.position_embeddings(position_index)
117 | token_type_embed = self.token_type_embeddings(token_type_index)
118 | pad_mask = torch.cat([sentence_pad_mask, object_pad_mask], dim=1)
119 |
120 | sequence_encode = torch.cat([sentence, x], dim=1)
121 | sequence_encode = torch.cat([sequence_encode, position_embed, token_type_embed], dim=-1)
122 | #########################
123 | # sequence_encode: [batch size, sequence_length, encoder input dimension]
124 |         # input to the transformer needs to have dimension [sequence_length, batch size, encoder input dimension]
125 | sequence_encode = sequence_encode.transpose(1, 0)
126 |
127 | # convert to bool
128 | pad_mask = (pad_mask == 1)
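        # for src_key_padding_mask, positions that are True are ignored by the transformer encoder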
129 |
130 | # encode: [sequence_length, batch_size, embedding_size]
131 | encode = self.encoder(sequence_encode, src_key_padding_mask=pad_mask)
132 | encode = encode.transpose(1, 0)
133 | #########################
134 | obj_encodes = encode[:, -num_objects:, :]
135 | obj_encodes = obj_encodes.reshape(-1, obj_encodes.shape[-1])
136 |
137 | rearrange_obj_labels = self.rearrange_object_fier(obj_encodes).squeeze(dim=1) # batch_size * num_objects
138 |
139 | predictions = {"rearrange_obj_labels": rearrange_obj_labels}
140 |
141 | return predictions
142 |
143 | def criterion(self, predictions, labels):
144 |
145 | loss = 0
146 | for key in predictions:
147 |
148 | preds = predictions[key]
149 | gts = labels[key]
150 |
151 | mask = gts == -100
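            # -100 labels correspond to padded object slots and are excluded from the loss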
152 | preds = preds[~mask]
153 | gts = gts[~mask]
154 |
155 | loss += self.loss(preds, gts)
156 |
157 | return loss
158 |
159 | def convert_logits(self, predictions):
160 |
161 | for key in predictions:
162 | if key == "rearrange_obj_labels":
163 | predictions[key] = torch.sigmoid(predictions[key])
164 |
165 | return predictions
--------------------------------------------------------------------------------
/src/structformer/models/point_transformer.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from structformer.utils.pointnet import farthest_point_sample, index_points, square_distance
4 |
5 | # adapted from https://github.com/qq456cvb/Point-Transformers
6 |
7 |
8 | def sample_and_group(npoint, nsample, xyz, points):
9 | B, N, C = xyz.shape
10 | S = npoint
11 |
12 | fps_idx = farthest_point_sample(xyz, npoint) # [B, npoint]
13 |
14 | new_xyz = index_points(xyz, fps_idx)
15 | new_points = index_points(points, fps_idx)
16 |
17 | dists = square_distance(new_xyz, xyz) # B x npoint x N
18 | idx = dists.argsort()[:, :, :nsample] # B x npoint x K
19 |
20 | grouped_points = index_points(points, idx)
21 | grouped_points_norm = grouped_points - new_points.view(B, S, 1, -1)
22 | new_points = torch.cat([grouped_points_norm, new_points.view(B, S, 1, -1).repeat(1, 1, nsample, 1)], dim=-1)
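    # each group stores its neighbors' features relative to the group center, concatenated with the
    # (repeated) center feature, which doubles the feature dimension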
23 | return new_xyz, new_points
24 |
25 |
26 | class Local_op(nn.Module):
27 | def __init__(self, in_channels, out_channels):
28 | super().__init__()
29 | self.conv1 = nn.Conv1d(in_channels, out_channels, kernel_size=1, bias=False)
30 | self.bn1 = nn.BatchNorm1d(out_channels)
31 | self.relu = nn.ReLU()
32 |
33 | def forward(self, x):
34 | b, n, s, d = x.size() # torch.Size([32, 512, 32, 6])
35 | x = x.permute(0, 1, 3, 2)
36 | x = x.reshape(-1, d, s)
37 | batch_size, _, N = x.size()
38 | x = self.relu(self.bn1(self.conv1(x))) # B, D, N
39 | x = torch.max(x, 2)[0]
40 | x = x.view(batch_size, -1)
41 | x = x.reshape(b, n, -1).permute(0, 2, 1)
42 | return x
43 |
44 |
45 | class SA_Layer(nn.Module):
46 | def __init__(self, channels):
47 | super().__init__()
48 | self.q_conv = nn.Conv1d(channels, channels // 4, 1, bias=False)
49 | self.k_conv = nn.Conv1d(channels, channels // 4, 1, bias=False)
50 | self.q_conv.weight = self.k_conv.weight
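        # query and key projections share the same weights, so queries and keys come from a single
        # learned projection (weight tying carried over from the reference implementation)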
51 | self.v_conv = nn.Conv1d(channels, channels, 1)
52 | self.trans_conv = nn.Conv1d(channels, channels, 1)
53 | self.after_norm = nn.BatchNorm1d(channels)
54 | self.act = nn.ReLU()
55 | self.softmax = nn.Softmax(dim=-1)
56 |
57 | def forward(self, x):
58 | x_q = self.q_conv(x).permute(0, 2, 1) # b, n, c
59 | x_k = self.k_conv(x)# b, c, n
60 | x_v = self.v_conv(x)
61 | energy = x_q @ x_k # b, n, n
62 | attention = self.softmax(energy)
63 | attention = attention / (1e-9 + attention.sum(dim=1, keepdims=True))
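        # the softmax over the last dim is followed by an L1-style renormalization over the key
        # dimension, similar to the offset-attention used in PCT-style point transformers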
64 | x_r = x_v @ attention # b, c, n
65 | x_r = self.act(self.after_norm(self.trans_conv(x - x_r)))
66 | x = x + x_r
67 | return x
68 |
69 |
70 | class StackedAttention(nn.Module):
71 | def __init__(self, channels=64):
72 | super().__init__()
73 | self.conv1 = nn.Conv1d(channels, channels, kernel_size=1, bias=False)
74 | self.conv2 = nn.Conv1d(channels, channels, kernel_size=1, bias=False)
75 |
76 | self.bn1 = nn.BatchNorm1d(channels)
77 | self.bn2 = nn.BatchNorm1d(channels)
78 |
79 | self.sa1 = SA_Layer(channels)
80 | self.sa2 = SA_Layer(channels)
81 |
82 | self.relu = nn.ReLU()
83 |
84 | def forward(self, x):
85 | #
86 | # b, 3, npoint, nsample
87 | # conv2d 3 -> 128 channels 1, 1
88 | # b * npoint, c, nsample
89 | # permute reshape
90 | batch_size, _, N = x.size()
91 |
92 | x = self.relu(self.bn1(self.conv1(x))) # B, D, N
93 | x = self.relu(self.bn2(self.conv2(x)))
94 |
95 | x1 = self.sa1(x)
96 | x2 = self.sa2(x1)
97 |
98 | x = torch.cat((x1, x2), dim=1)
99 |
100 | return x
101 |
102 |
103 | class PointTransformerEncoderSmall(nn.Module):
104 |
105 | def __init__(self, output_dim=256, input_dim=6, mean_center=True):
106 | super(PointTransformerEncoderSmall, self).__init__()
107 |
108 | self.mean_center = mean_center
109 |
110 | # map the second dim of the input from input_dim to 64
111 | self.conv1 = nn.Conv1d(input_dim, 64, kernel_size=1, bias=False)
112 | self.bn1 = nn.BatchNorm1d(64)
113 | self.gather_local_0 = Local_op(in_channels=128, out_channels=64)
114 | self.gather_local_1 = Local_op(in_channels=128, out_channels=64)
115 | self.pt_last = StackedAttention(channels=64)
116 |
117 | self.relu = nn.ReLU()
118 | self.conv_fuse = nn.Sequential(nn.Conv1d(192, 256, kernel_size=1, bias=False),
119 | nn.BatchNorm1d(256),
120 | nn.LeakyReLU(negative_slope=0.2))
121 |
122 | self.linear1 = nn.Linear(256, 256, bias=False)
123 | self.bn6 = nn.BatchNorm1d(256)
124 | self.dp1 = nn.Dropout(p=0.5)
125 | self.linear2 = nn.Linear(256, 256)
126 |
127 | def forward(self, xyz, f=None):
128 | # xyz: B, N, 3
129 | # f: B, N, D
130 | center = torch.mean(xyz, dim=1)
131 | if self.mean_center:
132 | xyz = xyz - center.view(-1, 1, 3).repeat(1, xyz.shape[1], 1)
133 | if f is None:
134 | x = self.pct(xyz)
135 | else:
136 | x = self.pct(torch.cat([xyz, f], dim=2)) # B, output_dim
137 |
138 | return center, x
139 |
140 | def pct(self, x):
141 |
142 | # x: B, N, D
143 | xyz = x[..., :3]
144 | x = x.permute(0, 2, 1)
145 | batch_size, _, _ = x.size()
146 | x = self.relu(self.bn1(self.conv1(x))) # B, D, N
147 | x = x.permute(0, 2, 1)
148 | new_xyz, new_feature = sample_and_group(npoint=128, nsample=32, xyz=xyz, points=x)
149 | feature_0 = self.gather_local_0(new_feature)
150 | feature = feature_0.permute(0, 2, 1) # B, nsamples, D
151 | new_xyz, new_feature = sample_and_group(npoint=32, nsample=16, xyz=new_xyz, points=feature)
152 | feature_1 = self.gather_local_1(new_feature) # B, D, nsamples
153 |
154 | x = self.pt_last(feature_1) # B, D * 2, nsamples
155 | x = torch.cat([x, feature_1], dim=1) # B, D * 3, nsamples
156 | x = self.conv_fuse(x)
157 | x = torch.max(x, 2)[0]
158 | x = x.view(batch_size, -1)
159 |
160 | x = self.relu(self.bn6(self.linear1(x)))
161 | x = self.dp1(x)
162 | x = self.linear2(x)
163 |
164 | return x
165 |
166 |
167 | class SampleAndGroup(nn.Module):
168 |
169 | def __init__(self, output_dim=64, input_dim=6, mean_center=True, npoints=(128, 32), nsamples=(32, 16)):
170 | super(SampleAndGroup, self).__init__()
171 |
172 | self.mean_center = mean_center
173 | self.npoints = npoints
174 | self.nsamples = nsamples
175 |
176 | # map the second dim of the input from input_dim to 64
177 | self.conv1 = nn.Conv1d(input_dim, output_dim, kernel_size=1, bias=False)
178 | self.bn1 = nn.BatchNorm1d(output_dim)
179 | self.gather_local_0 = Local_op(in_channels=output_dim * 2, out_channels=output_dim)
180 | self.gather_local_1 = Local_op(in_channels=output_dim * 2, out_channels=output_dim)
181 | self.relu = nn.ReLU()
182 |
183 | def forward(self, xyz, f):
184 | # xyz: B, N, 3
185 | # f: B, N, D
186 | center = torch.mean(xyz, dim=1)
187 | if self.mean_center:
188 | xyz = xyz - center.view(-1, 1, 3).repeat(1, xyz.shape[1], 1)
189 | x = self.sg(torch.cat([xyz, f], dim=2)) # B, nsamples, output_dim
190 |
191 | return center, x
192 |
193 | def sg(self, x):
194 |
195 | # x: B, N, D
196 | xyz = x[..., :3]
197 | x = x.permute(0, 2, 1)
198 | batch_size, _, _ = x.size()
199 | x = self.relu(self.bn1(self.conv1(x))) # B, D, N
200 | x = x.permute(0, 2, 1)
201 | new_xyz, new_feature = sample_and_group(npoint=self.npoints[0], nsample=self.nsamples[0], xyz=xyz, points=x)
202 | feature_0 = self.gather_local_0(new_feature)
203 | feature = feature_0.permute(0, 2, 1) # B, nsamples, D
204 | new_xyz, new_feature = sample_and_group(npoint=self.npoints[1], nsample=self.nsamples[1], xyz=new_xyz, points=feature)
205 | feature_1 = self.gather_local_1(new_feature) # B, D, nsamples
206 | x = feature_1.permute(0, 2, 1) # B, nsamples, D
207 |
208 | return x
--------------------------------------------------------------------------------
/src/structformer/training/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wliu88/StructFormer/a81d8c83313f1f24a75ca83aa9dd9f8cfcf27419/src/structformer/training/__init__.py
--------------------------------------------------------------------------------
/src/structformer/training/train_binary_model.py:
--------------------------------------------------------------------------------
1 | from __future__ import print_function, division
2 |
3 | import torch
4 | import torch.nn as nn
5 | import torch.optim as optim
6 | from torch.optim import lr_scheduler
7 | from warmup_scheduler import GradualWarmupScheduler
8 | import numpy as np
9 | import torchvision
10 | from torchvision import datasets, models, transforms
11 | import matplotlib.pyplot as plt
12 | import time
13 | import os
14 | import copy
15 | import tqdm
16 |
17 | import pickle
18 | import argparse
19 | from omegaconf import OmegaConf
20 | from collections import defaultdict
21 |
22 | from torch.utils.data import DataLoader
23 | from structformer.data.binary_dataset import BinaryDataset
24 | from structformer.models.pose_generation_network import PriorContinuousOutBinaryPCT6D
25 | from structformer.data.tokenizer import Tokenizer
26 | from structformer.utils.rearrangement import evaluate_prior_prediction
27 |
28 |
29 | def train_model(cfg, model, data_iter, optimizer, warmup, num_epochs, device, save_best_model, grad_clipping=1.0):
30 |
31 | if save_best_model:
32 | best_model_dir = os.path.join(cfg.experiment_dir, "best_model")
33 | print("best model will be saved to {}".format(best_model_dir))
34 | if not os.path.exists(best_model_dir):
35 | os.makedirs(best_model_dir)
36 | best_score = -np.inf
37 |
38 | for epoch in range(num_epochs):
39 | print('Epoch {}/{}'.format(epoch, num_epochs - 1))
40 | print('-' * 10)
41 |
42 | model.train()
43 | epoch_loss = 0
44 | gts = defaultdict(list)
45 | predictions = defaultdict(list)
46 |
47 | with tqdm.tqdm(total=len(data_iter["train"])) as pbar:
48 | for step, batch in enumerate(data_iter["train"]):
49 | optimizer.zero_grad()
50 | # input
51 | query_xyz = batch["query_xyz"].to(device, non_blocking=True)
52 | query_rgb = batch["query_rgb"].to(device, non_blocking=True)
53 | anchor_xyz = batch["anchor_xyz"].to(device, non_blocking=True)
54 | anchor_rgb = batch["anchor_rgb"].to(device, non_blocking=True)
55 | bg_xyz = batch["bg_xyz"].to(device, non_blocking=True)
56 | bg_rgb = batch["bg_rgb"].to(device, non_blocking=True)
57 | sentence = batch["sentence"].to(device, non_blocking=True)
58 | position_index = batch["position_index"].to(device, non_blocking=True)
59 | pad_mask = batch["pad_mask"].to(device, non_blocking=True)
60 |
61 | # output
62 | targets = {}
63 | for key in ["obj_x_outputs", "obj_y_outputs", "obj_z_outputs", "obj_theta_outputs"]:
64 | targets[key] = batch[key].to(device, non_blocking=True)
65 | targets[key] = targets[key].reshape(targets[key].shape[0] * targets[key].shape[1], -1)
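                    # flatten (batch, num_objects) into a single leading dimension so each object's
                    # pose target is scored independently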
66 |
67 | preds = model.forward(query_xyz, query_rgb, anchor_xyz, anchor_rgb, bg_xyz, bg_rgb,
68 | sentence, pad_mask, position_index)
69 |
70 | loss = model.criterion(preds, targets)
71 | loss.backward()
72 |
73 | if grad_clipping != 0.0:
74 | torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clipping)
75 |
76 | optimizer.step()
77 |                 epoch_loss += loss.item()  # .item() avoids holding the autograd graph across the epoch
78 |
79 | for key in ["obj_x_outputs", "obj_y_outputs", "obj_z_outputs", "obj_theta_outputs"]:
80 | gts[key].append(targets[key].detach())
81 | predictions[key].append(preds[key].detach())
82 |
83 | pbar.update(1)
84 |                 pbar.set_postfix({"Batch loss": loss.item()})
85 |
86 | warmup.step()
87 |
88 | print('[Epoch:{}]: Training Loss:{:.4}'.format(epoch, epoch_loss))
89 | evaluate_prior_prediction(gts, predictions, ["obj_x_outputs", "obj_y_outputs", "obj_z_outputs", "obj_theta_outputs"])
90 |
91 | score = validate(cfg, model, data_iter["valid"], epoch, device)
92 | if save_best_model and score > best_score:
93 | print("Saving best model so far...")
94 | best_score = score
95 | save_model(best_model_dir, cfg, epoch, model)
96 |
97 | return model
98 |
99 |
100 | def validate(cfg, model, data_iter, epoch, device):
101 | """
102 | helper function to evaluate the model
103 |
104 | :param model:
105 | :param data_iter:
106 | :param epoch:
107 | :param device:
108 | :return:
109 | """
110 |
111 | model.eval()
112 |
113 | epoch_loss = 0
114 | gts = defaultdict(list)
115 | predictions = defaultdict(list)
116 | with torch.no_grad():
117 |
118 | with tqdm.tqdm(total=len(data_iter)) as pbar:
119 | for step, batch in enumerate(data_iter):
120 |
121 | # input
122 | query_xyz = batch["query_xyz"].to(device, non_blocking=True)
123 | query_rgb = batch["query_rgb"].to(device, non_blocking=True)
124 | anchor_xyz = batch["anchor_xyz"].to(device, non_blocking=True)
125 | anchor_rgb = batch["anchor_rgb"].to(device, non_blocking=True)
126 | bg_xyz = batch["bg_xyz"].to(device, non_blocking=True)
127 | bg_rgb = batch["bg_rgb"].to(device, non_blocking=True)
128 | sentence = batch["sentence"].to(device, non_blocking=True)
129 | position_index = batch["position_index"].to(device, non_blocking=True)
130 | pad_mask = batch["pad_mask"].to(device, non_blocking=True)
131 |
132 | # output
133 | targets = {}
134 | for key in ["obj_x_outputs", "obj_y_outputs", "obj_z_outputs", "obj_theta_outputs"]:
135 | targets[key] = batch[key].to(device, non_blocking=True)
136 | targets[key] = targets[key].reshape(targets[key].shape[0] * targets[key].shape[1], -1)
137 |
138 | preds = model.forward(query_xyz, query_rgb, anchor_xyz, anchor_rgb, bg_xyz, bg_rgb,
139 | sentence, pad_mask, position_index)
140 |
141 | loss = model.criterion(preds, targets)
142 |
143 | for key in ["obj_x_outputs", "obj_y_outputs", "obj_z_outputs", "obj_theta_outputs"]:
144 | gts[key].append(targets[key])
145 | predictions[key].append(preds[key])
146 |
147 | epoch_loss += loss
148 | pbar.update(1)
149 | pbar.set_postfix({"Batch loss": loss})
150 |
151 | print('[Epoch:{}]: Val Loss:{:.4}'.format(epoch, epoch_loss))
152 |
153 | score = evaluate_prior_prediction(gts, predictions, ["obj_x_outputs", "obj_y_outputs", "obj_z_outputs", "obj_theta_outputs"])
154 | return score
155 |
156 |
157 | def infer_once(cfg, model, batch, device):
158 |
159 | model.eval()
160 |
161 | predictions = defaultdict(list)
162 | with torch.no_grad():
163 |
164 | # input
165 | query_xyz = batch["query_xyz"].to(device, non_blocking=True)
166 | query_rgb = batch["query_rgb"].to(device, non_blocking=True)
167 | anchor_xyz = batch["anchor_xyz"].to(device, non_blocking=True)
168 | anchor_rgb = batch["anchor_rgb"].to(device, non_blocking=True)
169 | bg_xyz = batch["bg_xyz"].to(device, non_blocking=True)
170 | bg_rgb = batch["bg_rgb"].to(device, non_blocking=True)
171 | sentence = batch["sentence"].to(device, non_blocking=True)
172 | position_index = batch["position_index"].to(device, non_blocking=True)
173 | pad_mask = batch["pad_mask"].to(device, non_blocking=True)
174 |
175 | # output
176 | targets = {}
177 | for key in ["obj_x_outputs", "obj_y_outputs", "obj_z_outputs", "obj_theta_outputs"]:
178 | targets[key] = batch[key].to(device, non_blocking=True)
179 | targets[key] = targets[key].reshape(targets[key].shape[0] * targets[key].shape[1], -1)
180 |
181 | preds = model.forward(query_xyz, query_rgb, anchor_xyz, anchor_rgb, bg_xyz, bg_rgb,
182 | sentence, pad_mask, position_index)
183 |
184 | for key in ["obj_x_outputs", "obj_y_outputs", "obj_z_outputs", "obj_theta_outputs"]:
185 | predictions[key].append(preds[key])
186 |
187 | return predictions
188 |
189 |
190 | def save_model(model_dir, cfg, epoch, model, optimizer=None, scheduler=None):
191 | state_dict = {'epoch': epoch,
192 | 'model_state_dict': model.state_dict()}
193 | if optimizer is not None:
194 | state_dict["optimizer_state_dict"] = optimizer.state_dict()
195 | if scheduler is not None:
196 | state_dict["scheduler_state_dict"] = scheduler.state_dict()
197 | torch.save(state_dict, os.path.join(model_dir, "model.tar"))
198 | OmegaConf.save(cfg, os.path.join(model_dir, "config.yaml"))
199 |
200 |
201 | def load_model(model_dir, dirs_cfg):
202 | """
203 | Load transformer model
204 | Important: to use the model, call model.eval() or model.train()
205 | :param model_dir:
206 | :return:
207 | """
208 | # load dictionaries
209 | cfg = OmegaConf.load(os.path.join(model_dir, "config.yaml"))
210 | if dirs_cfg:
211 | cfg = OmegaConf.merge(cfg, dirs_cfg)
212 |
213 | data_cfg = cfg.dataset
214 | tokenizer = Tokenizer(data_cfg.vocab_dir)
215 | vocab_size = tokenizer.get_vocab_size()
216 |
217 | # initialize model
218 | model_cfg = cfg.model
219 | model = PriorContinuousOutBinaryPCT6D(vocab_size,
220 | num_attention_heads=model_cfg.num_attention_heads,
221 | encoder_hidden_dim=model_cfg.encoder_hidden_dim,
222 | encoder_dropout=model_cfg.encoder_dropout,
223 | encoder_activation=model_cfg.encoder_activation,
224 | encoder_num_layers=model_cfg.encoder_num_layers,
225 | object_dropout=model_cfg.object_dropout,
226 | theta_loss_divide=model_cfg.theta_loss_divide,
227 | ignore_rgb=model_cfg.ignore_rgb)
228 | model.to(cfg.device)
229 |
230 | # load state dicts
231 | checkpoint = torch.load(os.path.join(model_dir, "model.tar"))
232 | model.load_state_dict(checkpoint['model_state_dict'])
233 |
234 | optimizer = None
235 | if "optimizer_state_dict" in checkpoint:
236 | training_cfg = cfg.training
237 | optimizer = optim.Adam(model.parameters(), lr=training_cfg.learning_rate)
238 | optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
239 |
240 |     # NOTE: the scheduler object is not reconstructed here, so any saved
241 |     # scheduler state is currently ignored
242 |     scheduler = None
243 |     if scheduler is not None and "scheduler_state_dict" in checkpoint:
244 |         scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
245 |
246 | epoch = checkpoint['epoch']
247 | return cfg, tokenizer, model, optimizer, scheduler, epoch
248 |
249 |
250 | def run_model(cfg):
251 |
252 | np.random.seed(cfg.random_seed)
253 | torch.manual_seed(cfg.random_seed)
254 | if torch.cuda.is_available():
255 | torch.cuda.manual_seed(cfg.random_seed)
256 | torch.cuda.manual_seed_all(cfg.random_seed)
257 | torch.backends.cudnn.deterministic = True
258 |
259 | data_cfg = cfg.dataset
260 | tokenizer = Tokenizer(data_cfg.vocab_dir)
261 | vocab_size = tokenizer.get_vocab_size()
262 |
263 | train_dataset = BinaryDataset(data_cfg.dirs, data_cfg.index_dirs, "train", tokenizer,
264 | data_cfg.max_num_objects,
265 | data_cfg.max_num_other_objects,
266 | data_cfg.max_num_shape_parameters,
267 | data_cfg.max_num_rearrange_features,
268 | data_cfg.max_num_anchor_features,
269 | data_cfg.num_pts)
270 | valid_dataset = BinaryDataset(data_cfg.dirs, data_cfg.index_dirs, "valid", tokenizer,
271 | data_cfg.max_num_objects,
272 | data_cfg.max_num_other_objects,
273 | data_cfg.max_num_shape_parameters,
274 | data_cfg.max_num_rearrange_features,
275 | data_cfg.max_num_anchor_features,
276 | data_cfg.num_pts)
277 |
278 | data_iter = {}
279 | data_iter["train"] = DataLoader(train_dataset, batch_size=data_cfg.batch_size, shuffle=True,
280 | collate_fn=BinaryDataset.collate_fn,
281 | pin_memory=data_cfg.pin_memory, num_workers=data_cfg.num_workers)
282 | data_iter["valid"] = DataLoader(valid_dataset, batch_size=data_cfg.batch_size, shuffle=False,
283 | collate_fn=BinaryDataset.collate_fn,
284 | pin_memory=data_cfg.pin_memory, num_workers=data_cfg.num_workers)
285 |
286 | # load model
287 | model_cfg = cfg.model
288 | model = PriorContinuousOutBinaryPCT6D(vocab_size,
289 | num_attention_heads=model_cfg.num_attention_heads,
290 | encoder_hidden_dim=model_cfg.encoder_hidden_dim,
291 | encoder_dropout=model_cfg.encoder_dropout,
292 | encoder_activation=model_cfg.encoder_activation,
293 | encoder_num_layers=model_cfg.encoder_num_layers,
294 | object_dropout=model_cfg.object_dropout,
295 | theta_loss_divide=model_cfg.theta_loss_divide,
296 | ignore_rgb=model_cfg.ignore_rgb)
297 | model.to(cfg.device)
298 |
299 | training_cfg = cfg.training
300 | optimizer = optim.Adam(model.parameters(), lr=training_cfg.learning_rate, weight_decay=training_cfg.l2)
301 | scheduler = lr_scheduler.CosineAnnealingLR(optimizer, T_max=training_cfg.lr_restart)
302 | warmup = GradualWarmupScheduler(optimizer, multiplier=1, total_epoch=training_cfg.warmup,
303 | after_scheduler=scheduler)
304 |
305 | train_model(cfg, model, data_iter, optimizer, warmup, training_cfg.max_epochs, cfg.device, cfg.save_best_model)
306 |
307 | # save model
308 | if cfg.save_model:
309 | model_dir = os.path.join(cfg.experiment_dir, "model")
310 | print("Saving model to {}".format(model_dir))
311 | if not os.path.exists(model_dir):
312 | os.makedirs(model_dir)
313 |         save_model(model_dir, cfg, training_cfg.max_epochs, model, optimizer, scheduler)
314 |
315 |
316 | if __name__ == "__main__":
317 |
318 | parser = argparse.ArgumentParser(description="Run a simple model")
319 | parser.add_argument("--dataset_base_dir", help='location of the dataset', type=str)
320 | parser.add_argument("--main_config", help='config yaml file for the model',
321 | default='../configs/binary_model.yaml',
322 | type=str)
323 | parser.add_argument("--dirs_config", help='config yaml file for directories',
324 | default='../configs/data/circle_dirs.yaml',
325 | type=str)
326 | args = parser.parse_args()
327 |
328 | # # debug
329 | # args.dataset_base_dir = "/home/weiyu/data_drive/data_new_objects"
330 |
331 | assert os.path.exists(args.main_config), "Cannot find config yaml file at {}".format(args.main_config)
332 |     assert os.path.exists(args.dirs_config), "Cannot find config yaml file at {}".format(args.dirs_config)
333 |
334 | os.environ["DATETIME"] = time.strftime("%Y%m%d-%H%M%S")
335 |
336 | main_cfg = OmegaConf.load(args.main_config)
337 | dirs_cfg = OmegaConf.load(args.dirs_config)
338 | cfg = OmegaConf.merge(main_cfg, dirs_cfg)
339 | cfg.dataset_base_dir = args.dataset_base_dir
340 | OmegaConf.resolve(cfg)
341 |
342 | if not os.path.exists(cfg.experiment_dir):
343 | os.makedirs(cfg.experiment_dir)
344 |
345 | OmegaConf.save(cfg, os.path.join(cfg.experiment_dir, "config.yaml"))
346 |
347 | run_model(cfg)
--------------------------------------------------------------------------------
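A minimal usage sketch for this training script's helpers: reloading a saved checkpoint with load_model and scoring one validation batch with infer_once. The experiment path is a placeholder, and depending on the config, dataset_base_dir may also need to be set before the dataset paths resolve.

    dirs_cfg = OmegaConf.load("configs/data/circle_dirs.yaml")
    cfg, tokenizer, model, _, _, epoch = load_model("experiments/binary_model/best_model", dirs_cfg)

    data_cfg = cfg.dataset
    valid_dataset = BinaryDataset(data_cfg.dirs, data_cfg.index_dirs, "valid", tokenizer,
                                  data_cfg.max_num_objects, data_cfg.max_num_other_objects,
                                  data_cfg.max_num_shape_parameters, data_cfg.max_num_rearrange_features,
                                  data_cfg.max_num_anchor_features, data_cfg.num_pts)
    loader = DataLoader(valid_dataset, batch_size=1, collate_fn=BinaryDataset.collate_fn)

    batch = next(iter(loader))
    predictions = infer_once(cfg, model, batch, cfg.device)  # keys: obj_{x,y,z,theta}_outputs
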
/src/structformer/training/train_object_selection_network.py:
--------------------------------------------------------------------------------
1 | from __future__ import print_function, division
2 |
3 | import torch
4 | import torch.optim as optim
5 | from torch.optim import lr_scheduler
6 | import numpy as np
7 | from warmup_scheduler import GradualWarmupScheduler
8 | import time
9 | import os
10 | import tqdm
11 |
12 | import argparse
13 | from omegaconf import OmegaConf
14 | from collections import defaultdict
15 | from sklearn.metrics import classification_report
16 |
17 | from torch.utils.data import DataLoader
18 | from structformer.data.object_set_refer_dataset import ObjectSetReferDataset
19 | from structformer.models.object_selection_network import RearrangeObjectsPredictorPCT
20 | from structformer.data.tokenizer import Tokenizer
21 |
22 |
23 | def evaluate(gts, predictions, keys, debug=True, return_classification_dict=False):
24 | """
25 | :param gts: expect a list of tensors
26 | :param predictions: expect a list of tensor
27 | :return:
28 | """
29 |
30 | total_scores = 0
31 | for key in keys:
32 | predictions_for_key = torch.cat(predictions[key], dim=0)
33 | gts_for_key = torch.cat(gts[key], dim=0)
34 |
35 | predicted_classes = predictions_for_key > 0.5
36 | assert len(gts_for_key) == len(predicted_classes)
37 |
38 | target_indices = gts_for_key != -100
39 |
40 | gts_for_key = gts_for_key[target_indices]
41 | predicted_classes = predicted_classes[target_indices]
42 | num_objects = len(predicted_classes)
43 |
44 | if debug:
45 | print(num_objects)
46 | print(gts_for_key.shape)
47 | print(predicted_classes.shape)
48 | print(target_indices.shape)
49 | print("Groundtruths:")
50 | print(gts_for_key[:100])
51 | print("Predictions")
52 | print(predicted_classes[:100])
53 |
54 | accuracy = torch.sum(gts_for_key == predicted_classes) / len(gts_for_key)
55 | print("{} objects -- {} accuracy: {}".format(num_objects, key, accuracy))
56 | total_scores += accuracy
57 |
58 | report = classification_report(gts_for_key.detach().cpu().numpy(), predicted_classes.detach().cpu().numpy(),
59 | output_dict=True)
60 | print(report)
61 |
62 | if not return_classification_dict:
63 | return total_scores
64 | else:
65 | return report
66 |
67 |
68 | def train_model(cfg, model, data_iter, optimizer, warmup, num_epochs, device, save_best_model, grad_clipping=1.0):
69 |
70 | if save_best_model:
71 | best_model_dir = os.path.join(cfg.experiment_dir, "best_model")
72 | print("best model will be saved to {}".format(best_model_dir))
73 | if not os.path.exists(best_model_dir):
74 | os.makedirs(best_model_dir)
75 | best_score = 0.0
76 |
77 | for epoch in range(num_epochs):
78 | print('Epoch {}/{}'.format(epoch, num_epochs - 1))
79 | print('-' * 10)
80 |
81 | model.train()
82 | epoch_loss = 0
83 | gts = defaultdict(list)
84 | predictions = defaultdict(list)
85 |
86 | with tqdm.tqdm(total=len(data_iter["train"])) as pbar:
87 | for step, batch in enumerate(data_iter["train"]):
88 | optimizer.zero_grad()
89 |
90 | # input
91 | xyzs = batch["xyzs"].to(device, non_blocking=True)
92 | rgbs = batch["rgbs"].to(device, non_blocking=True)
93 | sentence = batch["sentence"].to(device, non_blocking=True)
94 | object_pad_mask = batch["object_pad_mask"].to(device, non_blocking=True)
95 | sentence_pad_mask = batch["sentence_pad_mask"].to(device, non_blocking=True)
96 | token_type_index = batch["token_type_index"].to(device, non_blocking=True)
97 | position_index = batch["position_index"].to(device, non_blocking=True)
98 |
99 | # output
100 | targets = {}
101 | for key in ["rearrange_obj_labels"]:
102 | targets[key] = batch[key].to(device, non_blocking=True)
103 |
104 | preds = model.forward(xyzs, rgbs, object_pad_mask, sentence, sentence_pad_mask, token_type_index, position_index)
105 | loss = model.criterion(preds, targets)
106 |
107 | loss.backward()
108 |
109 | if grad_clipping != 0.0:
110 | torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clipping)
111 |
112 | optimizer.step()
113 | epoch_loss += loss
114 |
115 | for key in ["rearrange_obj_labels"]:
116 | gts[key].append(targets[key].detach())
117 | predictions[key].append(preds[key].detach())
118 |
119 | pbar.update(1)
120 | pbar.set_postfix({"Batch loss": loss})
121 |
122 | warmup.step()
123 |
124 | print('[Epoch:{}]: Training Loss:{:.4}'.format(epoch, epoch_loss))
125 | evaluate(gts, predictions, ["rearrange_obj_labels"])
126 |
127 | score = validate(model, data_iter["valid"], epoch, device)
128 | if save_best_model and score > best_score:
129 | print("Saving best model so far...")
130 | best_score = score
131 | save_model(best_model_dir, cfg, epoch, model)
132 |
133 | return model
134 |
135 |
136 | def validate(model, data_iter, epoch, device):
137 | """
138 | helper function to evaluate the model
139 |
140 | :param model:
141 | :param data_iter:
142 | :param epoch:
143 | :param device:
144 | :return:
145 | """
146 |
147 | model.eval()
148 |
149 | epoch_loss = 0
150 | gts = defaultdict(list)
151 | predictions = defaultdict(list)
152 | with torch.no_grad():
153 |
154 | with tqdm.tqdm(total=len(data_iter)) as pbar:
155 | for step, batch in enumerate(data_iter):
156 |
157 | # input
158 | xyzs = batch["xyzs"].to(device, non_blocking=True)
159 | rgbs = batch["rgbs"].to(device, non_blocking=True)
160 | sentence = batch["sentence"].to(device, non_blocking=True)
161 | object_pad_mask = batch["object_pad_mask"].to(device, non_blocking=True)
162 | sentence_pad_mask = batch["sentence_pad_mask"].to(device, non_blocking=True)
163 | token_type_index = batch["token_type_index"].to(device, non_blocking=True)
164 | position_index = batch["position_index"].to(device, non_blocking=True)
165 |
166 | # output
167 | targets = {}
168 | for key in ["rearrange_obj_labels"]:
169 | targets[key] = batch[key].to(device, non_blocking=True)
170 |
171 | preds = model.forward(xyzs, rgbs, object_pad_mask, sentence, sentence_pad_mask, token_type_index,
172 | position_index)
173 | loss = model.criterion(preds, targets)
174 |
175 | for key in ["rearrange_obj_labels"]:
176 | gts[key].append(targets[key])
177 | predictions[key].append(preds[key])
178 |
179 | epoch_loss += loss
180 | pbar.update(1)
181 | pbar.set_postfix({"Batch loss": loss})
182 |
183 | print('[Epoch:{}]: Val Loss:{:.4}'.format(epoch, epoch_loss))
184 |
185 | score = evaluate(gts, predictions, ["rearrange_obj_labels"])
186 | return score
187 |
188 |
189 | def infer_once(model, batch, device):
190 | """
191 |     helper function to run the model on a single batch and collect its predictions
192 |     (unlike validate, this does not loop over a data loader)
193 | 
194 |     :param model:
195 |     :param batch:
196 |     :param device:
197 |     :return:
198 | """
199 |
200 | model.eval()
201 |
202 | gts = defaultdict(list)
203 | predictions = defaultdict(list)
204 | with torch.no_grad():
205 |
206 | # input
207 | xyzs = batch["xyzs"].to(device, non_blocking=True)
208 | rgbs = batch["rgbs"].to(device, non_blocking=True)
209 | sentence = batch["sentence"].to(device, non_blocking=True)
210 | object_pad_mask = batch["object_pad_mask"].to(device, non_blocking=True)
211 | sentence_pad_mask = batch["sentence_pad_mask"].to(device, non_blocking=True)
212 | token_type_index = batch["token_type_index"].to(device, non_blocking=True)
213 | position_index = batch["position_index"].to(device, non_blocking=True)
214 |
215 | # output
216 | targets = {}
217 | for key in ["rearrange_obj_labels"]:
218 | targets[key] = batch[key].to(device, non_blocking=True)
219 |
220 | preds = model.forward(xyzs, rgbs, object_pad_mask, sentence, sentence_pad_mask, token_type_index,
221 | position_index)
222 |
223 | for key in ["rearrange_obj_labels"]:
224 | gts[key].append(targets[key])
225 | predictions[key].append(preds[key])
226 |
227 | return gts, predictions
228 |
229 |
230 | def save_model(model_dir, cfg, epoch, model, optimizer=None, scheduler=None):
231 | state_dict = {'epoch': epoch,
232 | 'model_state_dict': model.state_dict()}
233 | if optimizer is not None:
234 | state_dict["optimizer_state_dict"] = optimizer.state_dict()
235 | if scheduler is not None:
236 | state_dict["scheduler_state_dict"] = scheduler.state_dict()
237 | torch.save(state_dict, os.path.join(model_dir, "model.tar"))
238 | OmegaConf.save(cfg, os.path.join(model_dir, "config.yaml"))
239 |
240 |
241 | def load_model(model_dir, dirs_cfg):
242 | """
243 | Load transformer model
244 | Important: to use the model, call model.eval() or model.train()
245 | :param model_dir:
246 | :return:
247 | """
248 | # load dictionaries
249 | cfg = OmegaConf.load(os.path.join(model_dir, "config.yaml"))
250 | if dirs_cfg:
251 | cfg = OmegaConf.merge(cfg, dirs_cfg)
252 |
253 | data_cfg = cfg.dataset
254 | tokenizer = Tokenizer(data_cfg.vocab_dir)
255 | vocab_size = tokenizer.get_vocab_size()
256 |
257 | # initialize model
258 | model_cfg = cfg.model
259 | model = RearrangeObjectsPredictorPCT(vocab_size,
260 | num_attention_heads=model_cfg.num_attention_heads,
261 | encoder_hidden_dim=model_cfg.encoder_hidden_dim,
262 | encoder_dropout=model_cfg.encoder_dropout,
263 | encoder_activation=model_cfg.encoder_activation,
264 | encoder_num_layers=model_cfg.encoder_num_layers,
265 | use_focal_loss=model_cfg.use_focal_loss,
266 | focal_loss_gamma=model_cfg.focal_loss_gamma)
267 | model.to(cfg.device)
268 |
269 | # load state dicts
270 | checkpoint = torch.load(os.path.join(model_dir, "model.tar"))
271 | model.load_state_dict(checkpoint['model_state_dict'])
272 |
273 | optimizer = None
274 | if "optimizer_state_dict" in checkpoint:
275 | training_cfg = cfg.training
276 | optimizer = optim.Adam(model.parameters(), lr=training_cfg.learning_rate)
277 | optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
278 |
279 |     # NOTE: the scheduler object is not reconstructed here, so any saved
280 |     # scheduler state is currently ignored
281 |     scheduler = None
282 |     if scheduler is not None and "scheduler_state_dict" in checkpoint:
283 |         scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
284 |
285 | epoch = checkpoint['epoch']
286 | return cfg, tokenizer, model, optimizer, scheduler, epoch
287 |
288 |
289 | def run_model(cfg):
290 |
291 | np.random.seed(cfg.random_seed)
292 | torch.manual_seed(cfg.random_seed)
293 | if torch.cuda.is_available():
294 | torch.cuda.manual_seed(cfg.random_seed)
295 | torch.cuda.manual_seed_all(cfg.random_seed)
296 | torch.backends.cudnn.deterministic = True
297 |
298 | data_cfg = cfg.dataset
299 | tokenizer = Tokenizer(data_cfg.vocab_dir)
300 | vocab_size = tokenizer.get_vocab_size()
301 |
302 | train_dataset = ObjectSetReferDataset(data_cfg.dirs, data_cfg.index_dirs, "train", tokenizer,
303 | data_cfg.max_num_all_objects,
304 | data_cfg.max_num_shape_parameters,
305 | data_cfg.max_num_rearrange_features,
306 | data_cfg.max_num_anchor_features,
307 | data_cfg.num_pts)
308 |
309 | valid_dataset = ObjectSetReferDataset(data_cfg.dirs, data_cfg.index_dirs, "valid", tokenizer,
310 | data_cfg.max_num_all_objects,
311 | data_cfg.max_num_shape_parameters,
312 | data_cfg.max_num_rearrange_features,
313 | data_cfg.max_num_anchor_features,
314 | data_cfg.num_pts)
315 |
316 | data_iter = {}
317 | data_iter["train"] = DataLoader(train_dataset, batch_size=data_cfg.batch_size, shuffle=True,
318 | num_workers=data_cfg.num_workers,
319 | collate_fn=ObjectSetReferDataset.collate_fn,
320 | pin_memory=data_cfg.pin_memory)
321 | data_iter["valid"] = DataLoader(valid_dataset, batch_size=data_cfg.batch_size, shuffle=False,
322 | num_workers=data_cfg.num_workers,
323 | collate_fn=ObjectSetReferDataset.collate_fn,
324 | pin_memory=data_cfg.pin_memory)
325 |
326 | # load model
327 | model_cfg = cfg.model
328 | model = RearrangeObjectsPredictorPCT(vocab_size,
329 | num_attention_heads=model_cfg.num_attention_heads,
330 | encoder_hidden_dim=model_cfg.encoder_hidden_dim,
331 | encoder_dropout=model_cfg.encoder_dropout,
332 | encoder_activation=model_cfg.encoder_activation,
333 | encoder_num_layers=model_cfg.encoder_num_layers,
334 | use_focal_loss=model_cfg.use_focal_loss,
335 | focal_loss_gamma=model_cfg.focal_loss_gamma)
336 | model.to(cfg.device)
337 |
338 | training_cfg = cfg.training
339 | optimizer = optim.Adam(model.parameters(), lr=training_cfg.learning_rate, weight_decay=training_cfg.l2)
340 | scheduler = lr_scheduler.CosineAnnealingLR(optimizer, T_max=training_cfg.lr_restart)
341 | warmup = GradualWarmupScheduler(optimizer, multiplier=1, total_epoch=training_cfg.warmup,
342 | after_scheduler=scheduler)
343 |
344 | train_model(cfg, model, data_iter, optimizer, warmup, training_cfg.max_epochs, cfg.device, cfg.save_best_model)
345 |
346 | # save model
347 | if cfg.save_model:
348 | model_dir = os.path.join(cfg.experiment_dir, "model")
349 | print("Saving model to {}".format(model_dir))
350 | if not os.path.exists(model_dir):
351 | os.makedirs(model_dir)
352 |         save_model(model_dir, cfg, training_cfg.max_epochs, model, optimizer, scheduler)
353 |
354 |
355 | if __name__ == "__main__":
356 |
357 | parser = argparse.ArgumentParser(description="Run a simple model")
358 | parser.add_argument("--dataset_base_dir", help='location of the dataset', type=str)
359 | parser.add_argument("--main_config", help='config yaml file for the model',
360 | default='../configs/object_selection_network.yaml',
361 | type=str)
362 | parser.add_argument("--dirs_config", help='config yaml file for directories',
363 | default='../configs/data/circle_dirs.yaml',
364 | type=str)
365 | args = parser.parse_args()
366 |
367 | # # debug
368 | # args.dataset_base_dir = "/home/weiyu/data_drive/data_new_objects"
369 |
370 | assert os.path.exists(args.main_config), "Cannot find config yaml file at {}".format(args.main_config)
371 |     assert os.path.exists(args.dirs_config), "Cannot find config yaml file at {}".format(args.dirs_config)
372 |
373 | os.environ["DATETIME"] = time.strftime("%Y%m%d-%H%M%S")
374 |
375 | main_cfg = OmegaConf.load(args.main_config)
376 | dirs_cfg = OmegaConf.load(args.dirs_config)
377 | cfg = OmegaConf.merge(main_cfg, dirs_cfg)
378 | cfg.dataset_base_dir = args.dataset_base_dir
379 | OmegaConf.resolve(cfg)
380 |
381 | if not os.path.exists(cfg.experiment_dir):
382 | os.makedirs(cfg.experiment_dir)
383 |
384 | OmegaConf.save(cfg, os.path.join(cfg.experiment_dir, "config.yaml"))
385 |
386 | run_model(cfg)
--------------------------------------------------------------------------------
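As a rough sketch (the checkpoint path and batch are placeholders; the batch is assumed to come from ObjectSetReferDataset.collate_fn), the scores returned by infer_once can be turned into per-object decisions with the same 0.5 threshold and -100 padding convention used in evaluate above.

    cfg, tokenizer, model, _, _, _ = load_model("experiments/object_selection/best_model", None)
    gts, preds = infer_once(model, batch, cfg.device)

    scores = preds["rearrange_obj_labels"][0]      # per-object sigmoid scores
    labels = gts["rearrange_obj_labels"][0]
    valid = labels != -100                         # -100 marks padded object slots
    rearrange_mask = (scores > 0.5) & valid        # objects selected for rearrangement
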
/src/structformer/training/train_structformer_no_structure.py:
--------------------------------------------------------------------------------
1 | from __future__ import print_function, division
2 |
3 | import torch
4 | import torch.nn as nn
5 | import torch.optim as optim
6 | from torch.optim import lr_scheduler
7 | from warmup_scheduler import GradualWarmupScheduler
8 | import numpy as np
9 | import torchvision
10 | from torchvision import datasets, models, transforms
11 | import matplotlib.pyplot as plt
12 | import time
13 | import os
14 | import copy
15 | import tqdm
16 |
17 | import pickle
18 | import argparse
19 | from omegaconf import OmegaConf
20 | from collections import defaultdict
21 |
22 | from torch.utils.data import DataLoader
23 | from structformer.data.sequence_dataset import SequenceDataset
24 | from structformer.models.pose_generation_network import PriorContinuousOutEncoderDecoderPCT6DDropoutAllObjects
25 | from structformer.data.tokenizer import Tokenizer
26 | from structformer.utils.rearrangement import evaluate_prior_prediction, generate_square_subsequent_mask
27 |
28 |
29 | def train_model(cfg, model, data_iter, optimizer, warmup, num_epochs, device, save_best_model, grad_clipping=1.0):
30 |
31 | if save_best_model:
32 | best_model_dir = os.path.join(cfg.experiment_dir, "best_model")
33 | print("best model will be saved to {}".format(best_model_dir))
34 | if not os.path.exists(best_model_dir):
35 | os.makedirs(best_model_dir)
36 | best_score = -np.inf
37 |
38 | for epoch in range(num_epochs):
39 | print('Epoch {}/{}'.format(epoch, num_epochs - 1))
40 | print('-' * 10)
41 |
42 | model.train()
43 | epoch_loss = 0
44 | gts = defaultdict(list)
45 | predictions = defaultdict(list)
46 |
47 | with tqdm.tqdm(total=len(data_iter["train"])) as pbar:
48 | for step, batch in enumerate(data_iter["train"]):
49 | optimizer.zero_grad()
50 | # input
51 | xyzs = batch["xyzs"].to(device, non_blocking=True)
52 | rgbs = batch["rgbs"].to(device, non_blocking=True)
53 | object_pad_mask = batch["object_pad_mask"].to(device, non_blocking=True)
54 | other_xyzs = batch["other_xyzs"].to(device, non_blocking=True)
55 | other_rgbs = batch["other_rgbs"].to(device, non_blocking=True)
56 | other_object_pad_mask = batch["other_object_pad_mask"].to(device, non_blocking=True)
57 | sentence = batch["sentence"].to(device, non_blocking=True)
58 | sentence_pad_mask = batch["sentence_pad_mask"].to(device, non_blocking=True)
59 | token_type_index = batch["token_type_index"].to(device, non_blocking=True)
60 | position_index = batch["position_index"].to(device, non_blocking=True)
61 | obj_x_inputs = batch["obj_x_inputs"].to(device, non_blocking=True)
62 | obj_y_inputs = batch["obj_y_inputs"].to(device, non_blocking=True)
63 | obj_z_inputs = batch["obj_z_inputs"].to(device, non_blocking=True)
64 | obj_theta_inputs = batch["obj_theta_inputs"].to(device, non_blocking=True)
65 |
66 | tgt_mask = generate_square_subsequent_mask(object_pad_mask.shape[1]).to(device, non_blocking=True)
67 | start_token = torch.zeros((object_pad_mask.shape[0], 1), dtype=torch.long).to(device, non_blocking=True)
68 |
69 | # output
70 | targets = {}
71 | for key in ["obj_x_outputs", "obj_y_outputs", "obj_z_outputs", "obj_theta_outputs"]:
72 | targets[key] = batch[key].to(device, non_blocking=True)
73 | targets[key] = targets[key].reshape(targets[key].shape[0] * targets[key].shape[1], -1)
74 |
75 | preds = model.forward(xyzs, rgbs, object_pad_mask, other_xyzs, other_rgbs, other_object_pad_mask,
76 | sentence, sentence_pad_mask, token_type_index,
77 | obj_x_inputs, obj_y_inputs, obj_z_inputs, obj_theta_inputs, position_index,
78 | tgt_mask, start_token)
79 |
80 | loss = model.criterion(preds, targets)
81 | loss.backward()
82 |
83 | if grad_clipping != 0.0:
84 | torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clipping)
85 |
86 | optimizer.step()
87 | epoch_loss += loss
88 |
89 | for key in ["obj_x_outputs", "obj_y_outputs", "obj_z_outputs", "obj_theta_outputs"]:
90 | gts[key].append(targets[key].detach())
91 | predictions[key].append(preds[key].detach())
92 |
93 | pbar.update(1)
94 | pbar.set_postfix({"Batch loss": loss})
95 |
96 | warmup.step()
97 |
98 | print('[Epoch:{}]: Training Loss:{:.4}'.format(epoch, epoch_loss))
99 | evaluate_prior_prediction(gts, predictions, ["obj_x_outputs", "obj_y_outputs", "obj_z_outputs", "obj_theta_outputs"])
100 |
101 | score = validate(cfg, model, data_iter["valid"], epoch, device)
102 | if save_best_model and score > best_score:
103 | print("Saving best model so far...")
104 | best_score = score
105 | save_model(best_model_dir, cfg, epoch, model)
106 |
107 | return model
108 |
109 |
110 | def validate(cfg, model, data_iter, epoch, device):
111 | """
112 | helper function to evaluate the model
113 |
114 | :param model:
115 | :param data_iter:
116 | :param epoch:
117 | :param device:
118 | :return:
119 | """
120 |
121 | model.eval()
122 |
123 | epoch_loss = 0
124 | gts = defaultdict(list)
125 | predictions = defaultdict(list)
126 | with torch.no_grad():
127 |
128 | with tqdm.tqdm(total=len(data_iter)) as pbar:
129 | for step, batch in enumerate(data_iter):
130 |
131 | # input
132 | xyzs = batch["xyzs"].to(device, non_blocking=True)
133 | rgbs = batch["rgbs"].to(device, non_blocking=True)
134 | object_pad_mask = batch["object_pad_mask"].to(device, non_blocking=True)
135 | other_xyzs = batch["other_xyzs"].to(device, non_blocking=True)
136 | other_rgbs = batch["other_rgbs"].to(device, non_blocking=True)
137 | other_object_pad_mask = batch["other_object_pad_mask"].to(device, non_blocking=True)
138 | sentence = batch["sentence"].to(device, non_blocking=True)
139 | sentence_pad_mask = batch["sentence_pad_mask"].to(device, non_blocking=True)
140 | token_type_index = batch["token_type_index"].to(device, non_blocking=True)
141 | position_index = batch["position_index"].to(device, non_blocking=True)
142 | obj_x_inputs = batch["obj_x_inputs"].to(device, non_blocking=True)
143 | obj_y_inputs = batch["obj_y_inputs"].to(device, non_blocking=True)
144 | obj_z_inputs = batch["obj_z_inputs"].to(device, non_blocking=True)
145 | obj_theta_inputs = batch["obj_theta_inputs"].to(device, non_blocking=True)
146 |
147 | tgt_mask = generate_square_subsequent_mask(object_pad_mask.shape[1]).to(device, non_blocking=True)
148 | start_token = torch.zeros((object_pad_mask.shape[0], 1), dtype=torch.long).to(device, non_blocking=True)
149 |
150 | # output
151 | targets = {}
152 | for key in ["obj_x_outputs", "obj_y_outputs", "obj_z_outputs", "obj_theta_outputs"]:
153 | targets[key] = batch[key].to(device, non_blocking=True)
154 | targets[key] = targets[key].reshape(targets[key].shape[0] * targets[key].shape[1], -1)
155 |
156 | preds = model.forward(xyzs, rgbs, object_pad_mask, other_xyzs, other_rgbs, other_object_pad_mask,
157 | sentence, sentence_pad_mask, token_type_index,
158 | obj_x_inputs, obj_y_inputs, obj_z_inputs, obj_theta_inputs, position_index,
159 | tgt_mask, start_token)
160 | loss = model.criterion(preds, targets)
161 |
162 | for key in ["obj_x_outputs", "obj_y_outputs", "obj_z_outputs", "obj_theta_outputs"]:
163 | gts[key].append(targets[key])
164 | predictions[key].append(preds[key])
165 |
166 | epoch_loss += loss
167 | pbar.update(1)
168 | pbar.set_postfix({"Batch loss": loss})
169 |
170 | print('[Epoch:{}]: Val Loss:{:.4}'.format(epoch, epoch_loss))
171 |
172 | score = evaluate_prior_prediction(gts, predictions,
173 | ["obj_x_outputs", "obj_y_outputs", "obj_z_outputs", "obj_theta_outputs"])
174 | return score
175 |
176 |
177 | def infer_once(cfg, model, batch, device):
178 |
179 | model.eval()
180 |
181 | predictions = defaultdict(list)
182 | with torch.no_grad():
183 |
184 | # input
185 | xyzs = batch["xyzs"].to(device, non_blocking=True)
186 | rgbs = batch["rgbs"].to(device, non_blocking=True)
187 | object_pad_mask = batch["object_pad_mask"].to(device, non_blocking=True)
188 | other_xyzs = batch["other_xyzs"].to(device, non_blocking=True)
189 | other_rgbs = batch["other_rgbs"].to(device, non_blocking=True)
190 | other_object_pad_mask = batch["other_object_pad_mask"].to(device, non_blocking=True)
191 | sentence = batch["sentence"].to(device, non_blocking=True)
192 | sentence_pad_mask = batch["sentence_pad_mask"].to(device, non_blocking=True)
193 | token_type_index = batch["token_type_index"].to(device, non_blocking=True)
194 | position_index = batch["position_index"].to(device, non_blocking=True)
195 |
196 | obj_x_inputs = batch["obj_x_inputs"].to(device, non_blocking=True)
197 | obj_y_inputs = batch["obj_y_inputs"].to(device, non_blocking=True)
198 | obj_z_inputs = batch["obj_z_inputs"].to(device, non_blocking=True)
199 | obj_theta_inputs = batch["obj_theta_inputs"].to(device, non_blocking=True)
200 |
201 | tgt_mask = generate_square_subsequent_mask(object_pad_mask.shape[1]).to(device, non_blocking=True)
202 | start_token = torch.zeros((object_pad_mask.shape[0], 1), dtype=torch.long).to(device, non_blocking=True)
203 |
204 | preds = model.forward(xyzs, rgbs, object_pad_mask, other_xyzs, other_rgbs, other_object_pad_mask,
205 | sentence, sentence_pad_mask, token_type_index,
206 | obj_x_inputs, obj_y_inputs, obj_z_inputs, obj_theta_inputs, position_index,
207 | tgt_mask, start_token)
208 |
209 | for key in ["obj_x_outputs", "obj_y_outputs", "obj_z_outputs", "obj_theta_outputs"]:
210 | predictions[key].append(preds[key])
211 |
212 | return predictions
213 |
214 |
215 | def save_model(model_dir, cfg, epoch, model, optimizer=None, scheduler=None):
216 | state_dict = {'epoch': epoch,
217 | 'model_state_dict': model.state_dict()}
218 | if optimizer is not None:
219 | state_dict["optimizer_state_dict"] = optimizer.state_dict()
220 | if scheduler is not None:
221 | state_dict["scheduler_state_dict"] = scheduler.state_dict()
222 | torch.save(state_dict, os.path.join(model_dir, "model.tar"))
223 | OmegaConf.save(cfg, os.path.join(model_dir, "config.yaml"))
224 |
225 |
226 | def load_model(model_dir, dirs_cfg):
227 | """
228 | Load transformer model
229 | Important: to use the model, call model.eval() or model.train()
230 | :param model_dir:
231 | :return:
232 | """
233 | # load dictionaries
234 | cfg = OmegaConf.load(os.path.join(model_dir, "config.yaml"))
235 | if dirs_cfg:
236 | cfg = OmegaConf.merge(cfg, dirs_cfg)
237 |
238 | data_cfg = cfg.dataset
239 | tokenizer = Tokenizer(data_cfg.vocab_dir)
240 | vocab_size = tokenizer.get_vocab_size()
241 |
242 | # initialize model
243 | model_cfg = cfg.model
244 | model = PriorContinuousOutEncoderDecoderPCT6DDropoutAllObjects(vocab_size,
245 | num_attention_heads=model_cfg.num_attention_heads,
246 | encoder_hidden_dim=model_cfg.encoder_hidden_dim,
247 | encoder_dropout=model_cfg.encoder_dropout,
248 | encoder_activation=model_cfg.encoder_activation,
249 | encoder_num_layers=model_cfg.encoder_num_layers,
250 | object_dropout=model_cfg.object_dropout,
251 | theta_loss_divide=model_cfg.theta_loss_divide,
252 | ignore_rgb=model_cfg.ignore_rgb)
253 | model.to(cfg.device)
254 |
255 | # load state dicts
256 | checkpoint = torch.load(os.path.join(model_dir, "model.tar"))
257 | model.load_state_dict(checkpoint['model_state_dict'])
258 |
259 | optimizer = None
260 | if "optimizer_state_dict" in checkpoint:
261 | training_cfg = cfg.training
262 | optimizer = optim.Adam(model.parameters(), lr=training_cfg.learning_rate)
263 | optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
264 |
265 |     # NOTE: the scheduler object is not reconstructed here, so any saved
266 |     # scheduler state is currently ignored
267 |     scheduler = None
268 |     if scheduler is not None and "scheduler_state_dict" in checkpoint:
269 |         scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
270 |
271 | epoch = checkpoint['epoch']
272 | return cfg, tokenizer, model, optimizer, scheduler, epoch
273 |
274 |
275 | def run_model(cfg):
276 |
277 | np.random.seed(cfg.random_seed)
278 | torch.manual_seed(cfg.random_seed)
279 | if torch.cuda.is_available():
280 | torch.cuda.manual_seed(cfg.random_seed)
281 | torch.cuda.manual_seed_all(cfg.random_seed)
282 | torch.backends.cudnn.deterministic = True
283 |
284 | data_cfg = cfg.dataset
285 | tokenizer = Tokenizer(data_cfg.vocab_dir)
286 | vocab_size = tokenizer.get_vocab_size()
287 |
288 | train_dataset = SequenceDataset(data_cfg.dirs, data_cfg.index_dirs, "train", tokenizer,
289 | data_cfg.max_num_objects,
290 | data_cfg.max_num_other_objects,
291 | data_cfg.max_num_shape_parameters,
292 | data_cfg.max_num_rearrange_features,
293 | data_cfg.max_num_anchor_features,
294 | data_cfg.num_pts,
295 | data_cfg.use_structure_frame)
296 | valid_dataset = SequenceDataset(data_cfg.dirs, data_cfg.index_dirs, "valid", tokenizer,
297 | data_cfg.max_num_objects,
298 | data_cfg.max_num_other_objects,
299 | data_cfg.max_num_shape_parameters,
300 | data_cfg.max_num_rearrange_features,
301 | data_cfg.max_num_anchor_features,
302 | data_cfg.num_pts,
303 | data_cfg.use_structure_frame)
304 |
305 | data_iter = {}
306 | data_iter["train"] = DataLoader(train_dataset, batch_size=data_cfg.batch_size, shuffle=True,
307 | collate_fn=SequenceDataset.collate_fn,
308 | pin_memory=data_cfg.pin_memory, num_workers=data_cfg.num_workers)
309 | data_iter["valid"] = DataLoader(valid_dataset, batch_size=data_cfg.batch_size, shuffle=False,
310 | collate_fn=SequenceDataset.collate_fn,
311 | pin_memory=data_cfg.pin_memory, num_workers=data_cfg.num_workers)
312 |
313 | # load model
314 | model_cfg = cfg.model
315 | model = PriorContinuousOutEncoderDecoderPCT6DDropoutAllObjects(vocab_size,
316 | num_attention_heads=model_cfg.num_attention_heads,
317 | encoder_hidden_dim=model_cfg.encoder_hidden_dim,
318 | encoder_dropout=model_cfg.encoder_dropout,
319 | encoder_activation=model_cfg.encoder_activation,
320 | encoder_num_layers=model_cfg.encoder_num_layers,
321 | object_dropout=model_cfg.object_dropout,
322 | theta_loss_divide=model_cfg.theta_loss_divide,
323 | ignore_rgb=model_cfg.ignore_rgb)
324 | model.to(cfg.device)
325 |
326 | training_cfg = cfg.training
327 | optimizer = optim.Adam(model.parameters(), lr=training_cfg.learning_rate, weight_decay=training_cfg.l2)
328 | scheduler = lr_scheduler.CosineAnnealingLR(optimizer, T_max=training_cfg.lr_restart)
329 | warmup = GradualWarmupScheduler(optimizer, multiplier=1, total_epoch=training_cfg.warmup,
330 | after_scheduler=scheduler)
331 |
332 | train_model(cfg, model, data_iter, optimizer, warmup, training_cfg.max_epochs, cfg.device, cfg.save_best_model)
333 |
334 | # save model
335 | if cfg.save_model:
336 | model_dir = os.path.join(cfg.experiment_dir, "model")
337 | print("Saving model to {}".format(model_dir))
338 | if not os.path.exists(model_dir):
339 | os.makedirs(model_dir)
340 |         save_model(model_dir, cfg, training_cfg.max_epochs, model, optimizer, scheduler)
341 |
342 |
343 | if __name__ == "__main__":
344 |
345 | parser = argparse.ArgumentParser(description="Run a simple model")
346 | parser.add_argument("--dataset_base_dir", help='location of the dataset', type=str)
347 | parser.add_argument("--main_config", help='config yaml file for the model',
348 | default='../configs/structformer_no_structure.yaml',
349 | type=str)
350 | parser.add_argument("--dirs_config", help='config yaml file for directories',
351 | default='../configs/data/circle_dirs.yaml',
352 | type=str)
353 | args = parser.parse_args()
354 |
355 | # # debug
356 | # args.dataset_base_dir = "/home/weiyu/data_drive/data_new_objects"
357 |
358 | assert os.path.exists(args.main_config), "Cannot find config yaml file at {}".format(args.main_config)
359 |     assert os.path.exists(args.dirs_config), "Cannot find config yaml file at {}".format(args.dirs_config)
360 |
361 | os.environ["DATETIME"] = time.strftime("%Y%m%d-%H%M%S")
362 |
363 | main_cfg = OmegaConf.load(args.main_config)
364 | dirs_cfg = OmegaConf.load(args.dirs_config)
365 | cfg = OmegaConf.merge(main_cfg, dirs_cfg)
366 | cfg.dataset_base_dir = args.dataset_base_dir
367 | OmegaConf.resolve(cfg)
368 |
369 | if not os.path.exists(cfg.experiment_dir):
370 | os.makedirs(cfg.experiment_dir)
371 |
372 | OmegaConf.save(cfg, os.path.join(cfg.experiment_dir, "config.yaml"))
373 |
374 | run_model(cfg)
--------------------------------------------------------------------------------
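A small sketch of the decoder inputs assembled in the training loop above. The mask width equals the padded number of target objects; its exact values come from generate_square_subsequent_mask in structformer.utils.rearrangement (conventionally -inf above the diagonal so each object only attends to earlier ones). The sizes below are illustrative.

    import torch
    from structformer.utils.rearrangement import generate_square_subsequent_mask

    batch_size, max_num_objects = 4, 7
    tgt_mask = generate_square_subsequent_mask(max_num_objects)   # (7, 7) causal mask
    start_token = torch.zeros((batch_size, 1), dtype=torch.long)  # one start token per sequence
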
/src/structformer/utils/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wliu88/StructFormer/a81d8c83313f1f24a75ca83aa9dd9f8cfcf27419/src/structformer/utils/__init__.py
--------------------------------------------------------------------------------
/src/structformer/utils/brain2/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wliu88/StructFormer/a81d8c83313f1f24a75ca83aa9dd9f8cfcf27419/src/structformer/utils/brain2/__init__.py
--------------------------------------------------------------------------------
/src/structformer/utils/brain2/camera.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
2 | #
3 | # NVIDIA CORPORATION and its licensors retain all intellectual property
4 | # and proprietary rights in and to this software, related documentation
5 | # and any modifications thereto. Any use, reproduction, disclosure or
6 | # distribution of this software and related documentation without an express
7 | # license agreement from NVIDIA CORPORATION is strictly prohibited.
8 |
9 |
10 | from __future__ import print_function
11 |
12 | import numpy as np
13 | import open3d
14 | import trimesh
15 |
16 | import structformer.utils.transformations as tra
17 | from structformer.utils.brain2.pose import make_pose
18 |
19 |
20 | def get_camera_from_h5(h5):
21 | """ Simple reference to help make these """
22 | proj_near = h5['cam_near'][()]
23 | proj_far = h5['cam_far'][()]
24 | proj_fov = h5['cam_fov'][()]
25 | width = h5['cam_width'][()]
26 | height = h5['cam_height'][()]
27 | return GenericCameraReference(proj_near, proj_far, proj_fov, width, height)
28 |
29 |
30 | class GenericCameraReference(object):
31 | """ Class storing camera information and providing easy image capture """
32 |
33 | def __init__(self, proj_near=0.01, proj_far=5., proj_fov=60., img_width=640,
34 | img_height=480):
35 |
36 | self.proj_near = proj_near
37 | self.proj_far = proj_far
38 | self.proj_fov = proj_fov
39 | self.img_width = img_width
40 | self.img_height = img_height
41 | self.x_offset = self.img_width / 2.
42 | self.y_offset = self.img_height / 2.
43 |
44 | # Compute focal params
45 | aspect_ratio = self.img_width / self.img_height
46 | e = 1 / (np.tan(np.radians(self.proj_fov/2.)))
47 | t = self.proj_near / e
48 | b = -t
49 | r = t * aspect_ratio
50 | l = -r
51 | # pixels per meter
52 | alpha = self.img_width / (r-l)
53 | self.focal_length = self.proj_near * alpha
54 | self.fx = self.focal_length
55 | self.fy = self.focal_length
56 | self.pose = None
57 | self.inv_pose = None
58 |
59 | def set_pose(self, trans, rot):
60 | self.pose = make_pose(trans, rot)
61 | self.inv_pose = tra.inverse_matrix(self.pose)
62 |
63 | def set_pose_matrix(self, matrix):
64 | self.pose = matrix
65 | self.inv_pose = tra.inverse_matrix(matrix)
66 |
67 | def transform_to_world_coords(self, xyz):
68 | """ transform xyz into world coordinates """
69 | #cam_pose = tra.inverse_matrix(self.pose).dot(tra.euler_matrix(np.pi, 0, 0))
70 | #xyz = trimesh.transform_points(xyz, self.inv_pose)
71 | #xyz = trimesh.transform_points(xyz, cam_pose)
72 | #pose = tra.euler_matrix(np.pi, 0, 0) @ self.pose
73 | pose = self.pose
74 | xyz = trimesh.transform_points(xyz, pose)
75 | return xyz
76 |
77 | def get_camera_presets():
78 | return [
79 | "n/a",
80 | "azure_depth_nfov",
81 | "realsense",
82 | "azure_720p",
83 | "simple256",
84 | "simple512",
85 | ]
86 |
87 |
88 | def get_camera_preset(name):
89 |
90 | if name == "azure_depth_nfov":
91 | # Setting for depth camera is pretty different from RGB
92 | height, width, fov = 576, 640, 75
93 |     elif name == "azure_720p":
94 | # This is actually the 720p RGB setting
95 | # Used for our color camera most of the time
96 | #height, width, fov = 720, 1280, 90
97 | height, width, fov = 720, 1280, 60
98 | elif name == "realsense":
99 | height, width, fov = 480, 640, 60
100 | elif name == "simple256":
101 | height, width, fov = 256, 256, 60
102 | elif name == "simple512":
103 | height, width, fov = 512, 512, 60
104 | else:
105 | raise RuntimeError(('camera "%s" not supported, choose from: ' +
106 | str(get_camera_presets())) % str(name))
107 | return height, width, fov
108 |
109 |
110 | def get_generic_camera(name):
111 | h, w, fov = get_camera_preset(name)
112 | return GenericCameraReference(img_height=h, img_width=w, proj_fov=fov)
113 |
114 |
115 | def get_matrix_of_indices(height, width):
116 | """ Get indices """
117 | return np.indices((height, width), dtype=np.float32).transpose(1,2,0)
118 |
119 | # --------------------------------------------------------
120 | # NOTE: this code taken from Arsalan and modified
121 | def compute_xyz(depth_img, camera, visualize_xyz=False,
122 | xmap=None, ymap=None, max_clip_depth=5):
123 | """ Compute xyz image from depth for a camera """
124 |
125 |     # We need these parameters
126 | height = camera.img_height
127 | width = camera.img_width
128 | assert depth_img.shape[0] == camera.img_height
129 | assert depth_img.shape[1] == camera.img_width
130 | fx = camera.fx
131 | fy = camera.fy
132 | cx = camera.x_offset
133 | cy = camera.y_offset
134 |
135 | """
136 | # Create the matrix of parameters
137 | indices = np.indices((height, width), dtype=np.float32).transpose(1,2,0)
138 | # pixel indices start at top-left corner. for these equations, it starts at bottom-left
139 | # indices[..., 0] = np.flipud(indices[..., 0])
140 | z_e = depth_img
141 | x_e = (indices[..., 1] - x_offset) * z_e / fx
142 | y_e = (indices[..., 0] - y_offset) * z_e / fy
143 | xyz_img = np.stack([x_e, y_e, z_e], axis=-1) # Shape: [H x W x 3]
144 | """
145 |
146 | height = depth_img.shape[0]
147 | width = depth_img.shape[1]
148 | input_x = np.arange(width)
149 | input_y = np.arange(height)
150 | input_x, input_y = np.meshgrid(input_x, input_y)
151 | input_x = input_x.flatten()
152 | input_y = input_y.flatten()
153 | input_z = depth_img.flatten()
154 | # clip points that are farther than max distance
155 | input_z[input_z > max_clip_depth] = 0
156 | output_x = (input_x * input_z - cx * input_z) / fx
157 | output_y = (input_y * input_z - cy * input_z) / fy
158 | raw_pc = np.stack([output_x, output_y, input_z], -1).reshape(
159 | height, width, 3
160 | )
161 | 
162 |     # optionally visualize the reconstructed point cloud before returning it
163 |     if visualize_xyz:
164 |         unordered_pc = raw_pc.reshape(-1, 3)
165 |         pcd = open3d.geometry.PointCloud()
166 |         pcd.points = open3d.utility.Vector3dVector(unordered_pc)
167 |         pcd.transform([[1,0,0,0], [0,1,0,0], [0,0,-1,0], [0,0,0,1]]) # Transform it so it's not upside down
168 |         open3d.visualization.draw_geometries([pcd])
169 | 
170 |     return raw_pc
171 |
172 | def show_pcs(xyz, rgb):
173 | """ Display point clouds """
174 | if len(xyz.shape) > 2:
175 | unordered_pc = xyz.reshape(-1, 3)
176 | unordered_rgb = rgb.reshape(-1, 3) / 255.
177 | assert(unordered_rgb.shape[0] == unordered_pc.shape[0])
178 | assert(unordered_pc.shape[1] == 3)
179 | assert(unordered_rgb.shape[1] == 3)
180 | pcd = open3d.geometry.PointCloud()
181 | pcd.points = open3d.utility.Vector3dVector(unordered_pc)
182 | pcd.colors = open3d.utility.Vector3dVector(unordered_rgb)
183 | pcd.transform([[1,0,0,0],[0,1,0,0],[0,0,-1,0],[0,0,0,1]]) # Transform it so it's not upside down
184 | open3d.visualization.draw_geometries([pcd])
185 |
--------------------------------------------------------------------------------
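A brief sketch of back-projecting a depth image with the helpers above; the depth values are synthetic and the identity camera pose is only for illustration.

    import numpy as np

    cam = GenericCameraReference(proj_near=0.01, proj_far=5., proj_fov=60.,
                                 img_width=640, img_height=480)
    depth = np.full((480, 640), 1.5, dtype=np.float32)   # 1.5 m everywhere
    xyz = compute_xyz(depth, cam)                        # (480, 640, 3) organized point cloud

    cam.set_pose_matrix(np.eye(4))                       # placeholder camera pose
    xyz_world = cam.transform_to_world_coords(xyz.reshape(-1, 3))
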
/src/structformer/utils/brain2/image.py:
--------------------------------------------------------------------------------
1 | """
2 | By Chris Paxton.
3 |
4 | Copyright (c) 2018, Johns Hopkins University
5 | All rights reserved.
6 |
7 | Redistribution and use in source and binary forms, with or without
8 | modification, are permitted provided that the following conditions are met:
9 | * Redistributions of source code must retain the above copyright
10 | notice, this list of conditions and the following disclaimer.
11 | * Redistributions in binary form must reproduce the above copyright
12 | notice, this list of conditions and the following disclaimer in the
13 | documentation and/or other materials provided with the distribution.
14 | * Neither the name of the Johns Hopkins University nor the
15 | names of its contributors may be used to endorse or promote products
16 | derived from this software without specific prior written permission.
17 |
18 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
19 | ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
20 | WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
21 | DISCLAIMED. IN NO EVENT SHALL JOHNS HOPKINS UNIVERSITY BE LIABLE FOR ANY
22 | DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
23 | (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
24 | LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
25 | ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
26 | (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
27 | SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
28 | """
29 |
30 | import numpy as np
31 | import io
32 | from PIL import Image
33 |
34 | def GetJpeg(img):
35 | '''
36 | Save a numpy array as a Jpeg, then get it out as a binary blob
37 | '''
38 | im = Image.fromarray(np.uint8(img))
39 | output = io.BytesIO()
40 | im.save(output, format="JPEG", quality=80)
41 | return output.getvalue()
42 |
43 | def JpegToNumpy(jpeg):
44 | stream = io.BytesIO(jpeg)
45 | im = Image.open(stream)
46 | return np.asarray(im, dtype=np.uint8)
47 |
48 | def ConvertJpegListToNumpy(data):
49 | length = len(data)
50 | imgs = []
51 | for raw in data:
52 | imgs.append(JpegToNumpy(raw))
53 | arr = np.array(imgs)
54 | return arr
55 |
56 | def DepthToZBuffer(img, z_near, z_far):
57 | real_depth = z_near * z_far / (z_far - img * (z_far - z_near))
58 | return real_depth
59 |
60 | def ZBufferToRGB(img, z_near, z_far):
61 | real_depth = z_near * z_far / (z_far - img * (z_far - z_near))
62 | depth_m = np.uint8(real_depth)
63 | depth_cm = np.uint8((real_depth-depth_m)*100)
64 | depth_tmm = np.uint8((real_depth-depth_m-0.01*depth_cm)*10000)
65 | return np.dstack([depth_m, depth_cm, depth_tmm])
66 |
67 | def RGBToDepth(img, min_dist=0., max_dist=2.,):
68 | return (img[:,:,0]+.01*img[:,:,1]+.0001*img[:,:,2]).clip(min_dist, max_dist)
69 | #return img[:,:,0]+.01*img[:,:,1]+.0001*img[:,:,2]
70 |
71 | def MaskToRGBA(img):
72 | buf = img.astype(np.int32)
73 | A = buf.astype(np.uint8)
74 | buf = np.right_shift(buf, 8)
75 | B = buf.astype(np.uint8)
76 | buf = np.right_shift(buf, 8)
77 | G = buf.astype(np.uint8)
78 | buf = np.right_shift(buf, 8)
79 | R = buf.astype(np.uint8)
80 |
81 | dims = [np.expand_dims(d, -1) for d in [R,G,B,A]]
82 | return np.concatenate(dims, axis=-1)
83 |
84 | def RGBAToMask(img):
85 | mask = np.zeros(img.shape[:-1], dtype=np.int32)
86 | buf = img.astype(np.int32)
87 | for i, dim in enumerate([3,2,1,0]):
88 | shift = 8*i
89 | #print(i, dim, shift, buf[0,0,dim], np.left_shift(buf[0,0,dim], shift))
90 | mask += np.left_shift(buf[:,:, dim], shift)
91 | return mask
92 |
93 | def RGBAArrayToMasks(img):
94 | mask = np.zeros(img.shape[:-1], dtype=np.int32)
95 | buf = img.astype(np.int32)
96 | for i, dim in enumerate([3,2,1,0]):
97 | shift = 8*i
98 | mask += np.left_shift(buf[:,:,:, dim], shift)
99 | return mask
100 |
101 | def GetPNG(img):
102 | '''
103 | Save a numpy array as a PNG, then get it out as a binary blob
104 | '''
105 | im = Image.fromarray(np.uint8(img))
106 | output = io.BytesIO()
107 | im.save(output, format="PNG")#, quality=80)
108 | return output.getvalue()
109 |
110 | def PNGToNumpy(png):
111 | stream = io.BytesIO(png)
112 | im = Image.open(stream)
113 | return np.array(im, dtype=np.uint8)
114 |
115 | def ConvertPNGListToNumpy(data):
116 | length = len(data)
117 | imgs = []
118 | for raw in data:
119 | imgs.append(PNGToNumpy(raw))
120 | arr = np.array(imgs)
121 | return arr
122 |
123 | def ConvertDepthPNGListToNumpy(data):
124 | length = len(data)
125 | imgs = []
126 | for raw in data:
127 | imgs.append(RGBToDepth(PNGToNumpy(raw)))
128 | arr = np.array(imgs)
129 | return arr
130 |
131 | import cv2
132 | def Shrink(img, nw=64):
133 | h,w = img.shape[:2]
134 | ratio = float(nw) / w
135 | nh = int(ratio * h)
136 | img2 = cv2.resize(img, dsize=(nw, nh),
137 | interpolation=cv2.INTER_NEAREST)
138 | return img2
139 |
140 | def ShrinkSmooth(img, nw=64):
141 | h,w = img.shape[:2]
142 | ratio = float(nw) / w
143 | nh = int(ratio * h)
144 | img2 = cv2.resize(img, dsize=(nw, nh),
145 | interpolation=cv2.INTER_LINEAR)
146 | return img2
147 |
148 | def CropCenter(img, cropx, cropy):
149 | y = img.shape[0]
150 | x = img.shape[1]
151 | startx = (x // 2) - (cropx // 2)
152 | starty = (y // 2) - (cropy // 2)
153 | return img[starty: starty + cropy, startx : startx + cropx]
154 |
155 |
--------------------------------------------------------------------------------
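A quick sketch of the depth round trip implemented above: ZBufferToRGB packs metric depth into three uint8 channels (metres, centimetres, tenths of a millimetre) and RGBToDepth unpacks it, clipped to [min_dist, max_dist]. The z-buffer values here are made up.

    import numpy as np

    zbuf = np.full((2, 2), 0.5, dtype=np.float32)         # synthetic normalized z-buffer
    encoded = ZBufferToRGB(zbuf, z_near=0.01, z_far=5.0)  # (2, 2, 3) uint8
    metric = RGBToDepth(encoded)                          # metres, clipped to [0., 2.] by default
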
/src/structformer/utils/brain2/pose.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
2 | #
3 | # NVIDIA CORPORATION and its licensors retain all intellectual property
4 | # and proprietary rights in and to this software, related documentation
5 | # and any modifications thereto. Any use, reproduction, disclosure or
6 | # distribution of this software and related documentation without an express
7 | # license agreement from NVIDIA CORPORATION is strictly prohibited.
8 |
9 |
10 | from __future__ import print_function
11 |
12 | import structformer.utils.transformations as tra
13 |
14 |
15 | def make_pose(trans, rot):
16 | """Make 4x4 matrix from (trans, rot)"""
17 | pose = tra.quaternion_matrix(rot)
18 | pose[:3, 3] = trans
19 | return pose
20 |
--------------------------------------------------------------------------------
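For reference, a tiny usage sketch of make_pose; the quaternion comes from the same vendored transformations module so the conventions match.

    import numpy as np
    import structformer.utils.transformations as tra
    from structformer.utils.brain2.pose import make_pose

    rot = tra.quaternion_from_euler(0.0, 0.0, np.pi / 2)  # 90 degree yaw
    pose = make_pose([0.5, 0.0, 1.0], rot)                # 4x4 homogeneous transform
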
/src/structformer/utils/pointnet.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | from time import time
5 | import numpy as np
6 |
7 |
8 | # reference https://github.com/yanx27/Pointnet_Pointnet2_pytorch, modified by Yang You
9 |
10 |
11 | def timeit(tag, t):
12 | print("{}: {}s".format(tag, time() - t))
13 | return time()
14 |
15 | def pc_normalize(pc):
16 | if type(pc).__module__ == np.__name__:
17 | centroid = np.mean(pc, axis=0)
18 | pc = pc - centroid
19 | m = np.max(np.sqrt(np.sum(pc**2, axis=1)))
20 | pc = pc / m
21 | else:
22 | centroid = torch.mean(pc, dim=0)
23 | pc = pc - centroid
24 | m = torch.max(torch.sqrt(torch.sum(pc ** 2, dim=1)))
25 | pc = pc / m
26 | return pc
27 |
28 | def square_distance(src, dst):
29 | """
30 |     Calculate the squared Euclidean distance between each pair of points.
31 | src^T * dst = xn * xm + yn * ym + zn * zm;
32 | sum(src^2, dim=-1) = xn*xn + yn*yn + zn*zn;
33 | sum(dst^2, dim=-1) = xm*xm + ym*ym + zm*zm;
34 | dist = (xn - xm)^2 + (yn - ym)^2 + (zn - zm)^2
35 | = sum(src**2,dim=-1)+sum(dst**2,dim=-1)-2*src^T*dst
36 | Input:
37 | src: source points, [B, N, C]
38 | dst: target points, [B, M, C]
39 | Output:
40 | dist: per-point square distance, [B, N, M]
41 | """
42 | return torch.sum((src[:, :, None] - dst[:, None]) ** 2, dim=-1)
43 |
44 |
45 | def index_points(points, idx):
46 | """
47 | Input:
48 | points: input points data, [B, N, C]
49 | idx: sample index data, [B, S, [K]]
50 | Return:
51 | new_points:, indexed points data, [B, S, [K], C]
52 | """
53 | raw_size = idx.size()
54 | idx = idx.reshape(raw_size[0], -1)
55 | res = torch.gather(points, 1, idx[..., None].expand(-1, -1, points.size(-1)))
56 | return res.reshape(*raw_size, -1)
57 |
58 |
59 | def farthest_point_sample(xyz, npoint):
60 | """
61 | Input:
62 | xyz: pointcloud data, [B, N, 3]
63 | npoint: number of samples
64 | Return:
65 | centroids: sampled pointcloud index, [B, npoint]
66 | """
67 | device = xyz.device
68 | B, N, C = xyz.shape
69 | centroids = torch.zeros(B, npoint, dtype=torch.long).to(device)
70 | distance = torch.ones(B, N).to(device) * 1e10
71 | farthest = torch.randint(0, N, (B,), dtype=torch.long).to(device)
72 | batch_indices = torch.arange(B, dtype=torch.long).to(device)
73 | for i in range(npoint):
74 | centroids[:, i] = farthest
75 | centroid = xyz[batch_indices, farthest, :].view(B, 1, 3)
76 | dist = torch.sum((xyz - centroid) ** 2, -1)
77 | distance = torch.min(distance, dist)
78 | farthest = torch.max(distance, -1)[1]
79 | return centroids
80 |
81 |
82 | def query_ball_point(radius, nsample, xyz, new_xyz):
83 | """
84 | Input:
85 | radius: local region radius
86 | nsample: max sample number in local region
87 | xyz: all points, [B, N, 3]
88 | new_xyz: query points, [B, S, 3]
89 | Return:
90 | group_idx: grouped points index, [B, S, nsample]
91 | """
92 | device = xyz.device
93 | B, N, C = xyz.shape
94 | _, S, _ = new_xyz.shape
95 | group_idx = torch.arange(N, dtype=torch.long).to(device).view(1, 1, N).repeat([B, S, 1])
96 | sqrdists = square_distance(new_xyz, xyz)
97 | group_idx[sqrdists > radius ** 2] = N
98 | group_idx = group_idx.sort(dim=-1)[0][:, :, :nsample]
99 | group_first = group_idx[:, :, 0].view(B, S, 1).repeat([1, 1, nsample])
100 | mask = group_idx == N
101 | group_idx[mask] = group_first[mask]
102 | return group_idx
103 |
104 |
105 | def sample_and_group(npoint, radius, nsample, xyz, points, returnfps=False, knn=False):
106 | """
107 | Input:
108 | npoint:
109 | radius:
110 | nsample:
111 | xyz: input points position data, [B, N, 3]
112 | points: input points data, [B, N, D]
113 | Return:
114 |         new_xyz: sampled points position data, [B, npoint, 3]
115 | new_points: sampled points data, [B, npoint, nsample, 3+D]
116 | """
117 | B, N, C = xyz.shape
118 | S = npoint
119 | fps_idx = farthest_point_sample(xyz, npoint) # [B, npoint]
120 | torch.cuda.empty_cache()
121 | new_xyz = index_points(xyz, fps_idx)
122 | torch.cuda.empty_cache()
123 | if knn:
124 | dists = square_distance(new_xyz, xyz) # B x npoint x N
125 | idx = dists.argsort()[:, :, :nsample] # B x npoint x K
126 | else:
127 | idx = query_ball_point(radius, nsample, xyz, new_xyz)
128 | torch.cuda.empty_cache()
129 | grouped_xyz = index_points(xyz, idx) # [B, npoint, nsample, C]
130 | torch.cuda.empty_cache()
131 | grouped_xyz_norm = grouped_xyz - new_xyz.view(B, S, 1, C)
132 | torch.cuda.empty_cache()
133 |
134 | if points is not None:
135 | grouped_points = index_points(points, idx)
136 | new_points = torch.cat([grouped_xyz_norm, grouped_points], dim=-1) # [B, npoint, nsample, C+D]
137 | else:
138 | new_points = grouped_xyz_norm
139 | if returnfps:
140 | return new_xyz, new_points, grouped_xyz, fps_idx
141 | else:
142 | return new_xyz, new_points
143 |
144 |
145 | def sample_and_group_all(xyz, points):
146 | """
147 | Input:
148 | xyz: input points position data, [B, N, 3]
149 | points: input points data, [B, N, D]
150 | Return:
151 | new_xyz: sampled points position data, [B, 1, 3]
152 | new_points: sampled points data, [B, 1, N, 3+D]
153 | """
154 | device = xyz.device
155 | B, N, C = xyz.shape
156 | new_xyz = torch.zeros(B, 1, C).to(device)
157 | grouped_xyz = xyz.view(B, 1, N, C)
158 | if points is not None:
159 | new_points = torch.cat([grouped_xyz, points.view(B, 1, N, -1)], dim=-1)
160 | else:
161 | new_points = grouped_xyz
162 | return new_xyz, new_points
163 |
164 |
165 | class PointNetSetAbstraction(nn.Module):
166 | def __init__(self, npoint, radius, nsample, in_channel, mlp, group_all, knn=False):
167 | super(PointNetSetAbstraction, self).__init__()
168 | self.npoint = npoint
169 | self.radius = radius
170 | self.nsample = nsample
171 | self.knn = knn
172 | self.mlp_convs = nn.ModuleList()
173 | self.mlp_bns = nn.ModuleList()
174 | last_channel = in_channel
175 | for out_channel in mlp:
176 | self.mlp_convs.append(nn.Conv2d(last_channel, out_channel, 1))
177 | self.mlp_bns.append(nn.BatchNorm2d(out_channel))
178 | last_channel = out_channel
179 | self.group_all = group_all
180 |
181 | def forward(self, xyz, points):
182 | """
183 | Input:
184 | xyz: input points position data, [B, N, C]
185 |             points: input points data, [B, N, D]
186 | Return:
187 | new_xyz: sampled points position data, [B, S, C]
188 | new_points_concat: sample points feature data, [B, S, D']
189 | """
190 | if self.group_all:
191 | new_xyz, new_points = sample_and_group_all(xyz, points)
192 | else:
193 | new_xyz, new_points = sample_and_group(self.npoint, self.radius, self.nsample, xyz, points, knn=self.knn)
194 | # new_xyz: sampled points position data, [B, npoint, C]
195 | # new_points: sampled points data, [B, npoint, nsample, C+D]
196 |         new_points = new_points.permute(0, 3, 2, 1) # [B, C+D, nsample, npoint]
197 | for i, conv in enumerate(self.mlp_convs):
198 | bn = self.mlp_bns[i]
199 | new_points = F.relu(bn(conv(new_points)))
200 |
201 | new_points = torch.max(new_points, 2)[0].transpose(1, 2)
202 | return new_xyz, new_points
203 |
204 |
205 | class PointNetSetAbstractionMsg(nn.Module):
206 | def __init__(self, npoint, radius_list, nsample_list, in_channel, mlp_list, knn=False):
207 | super(PointNetSetAbstractionMsg, self).__init__()
208 | self.npoint = npoint
209 | self.radius_list = radius_list
210 | self.nsample_list = nsample_list
211 | self.knn = knn
212 | self.conv_blocks = nn.ModuleList()
213 | self.bn_blocks = nn.ModuleList()
214 | for i in range(len(mlp_list)):
215 | convs = nn.ModuleList()
216 | bns = nn.ModuleList()
217 | last_channel = in_channel + 3
218 | for out_channel in mlp_list[i]:
219 | convs.append(nn.Conv2d(last_channel, out_channel, 1))
220 | bns.append(nn.BatchNorm2d(out_channel))
221 | last_channel = out_channel
222 | self.conv_blocks.append(convs)
223 | self.bn_blocks.append(bns)
224 |
225 | def forward(self, xyz, points, seed_idx=None):
226 | """
227 | Input:
228 |             xyz: input points position data, [B, N, C]
229 |             points: input points data, [B, N, D]
230 |         Return:
231 |             new_xyz: sampled points position data, [B, S, C]
232 |             new_points_concat: sample points feature data, [B, S, D']
233 | """
234 |
235 | B, N, C = xyz.shape
236 | S = self.npoint
237 | new_xyz = index_points(xyz, farthest_point_sample(xyz, S) if seed_idx is None else seed_idx)
238 | new_points_list = []
239 | for i, radius in enumerate(self.radius_list):
240 | K = self.nsample_list[i]
241 | if self.knn:
242 | dists = square_distance(new_xyz, xyz) # B x npoint x N
243 | group_idx = dists.argsort()[:, :, :K] # B x npoint x K
244 | else:
245 | group_idx = query_ball_point(radius, K, xyz, new_xyz)
246 | grouped_xyz = index_points(xyz, group_idx)
247 | grouped_xyz -= new_xyz.view(B, S, 1, C)
248 | if points is not None:
249 | grouped_points = index_points(points, group_idx)
250 | grouped_points = torch.cat([grouped_points, grouped_xyz], dim=-1)
251 | else:
252 | grouped_points = grouped_xyz
253 |
254 | grouped_points = grouped_points.permute(0, 3, 2, 1) # [B, D, K, S]
255 | for j in range(len(self.conv_blocks[i])):
256 | conv = self.conv_blocks[i][j]
257 | bn = self.bn_blocks[i][j]
258 | grouped_points = F.relu(bn(conv(grouped_points)))
259 | new_points = torch.max(grouped_points, 2)[0] # [B, D', S]
260 | new_points_list.append(new_points)
261 |
262 | new_points_concat = torch.cat(new_points_list, dim=1).transpose(1, 2)
263 | return new_xyz, new_points_concat
264 |
265 |
266 | # Note: this function swaps N and C
267 | class PointNetFeaturePropagation(nn.Module):
268 | def __init__(self, in_channel, mlp):
269 | super(PointNetFeaturePropagation, self).__init__()
270 | self.mlp_convs = nn.ModuleList()
271 | self.mlp_bns = nn.ModuleList()
272 | last_channel = in_channel
273 | for out_channel in mlp:
274 | self.mlp_convs.append(nn.Conv1d(last_channel, out_channel, 1))
275 | self.mlp_bns.append(nn.BatchNorm1d(out_channel))
276 | last_channel = out_channel
277 |
278 | def forward(self, xyz1, xyz2, points1, points2):
279 | """
280 | Input:
281 | xyz1: input points position data, [B, C, N]
282 | xyz2: sampled input points position data, [B, C, S]
283 | points1: input points data, [B, D, N]
284 | points2: input points data, [B, D, S]
285 | Return:
286 | new_points: upsampled points data, [B, D', N]
287 | """
288 | xyz1 = xyz1.permute(0, 2, 1)
289 | xyz2 = xyz2.permute(0, 2, 1)
290 |
291 | points2 = points2.permute(0, 2, 1)
292 | B, N, C = xyz1.shape
293 | _, S, _ = xyz2.shape
294 |
295 | if S == 1:
296 | interpolated_points = points2.repeat(1, N, 1)
297 | else:
298 | dists = square_distance(xyz1, xyz2)
299 | dists, idx = dists.sort(dim=-1)
300 | dists, idx = dists[:, :, :3], idx[:, :, :3] # [B, N, 3]
301 |
302 | dist_recip = 1.0 / (dists + 1e-8)
303 | norm = torch.sum(dist_recip, dim=2, keepdim=True)
304 | weight = dist_recip / norm
305 | interpolated_points = torch.sum(index_points(points2, idx) * weight.view(B, N, 3, 1), dim=2)
306 |
307 | if points1 is not None:
308 | points1 = points1.permute(0, 2, 1)
309 | new_points = torch.cat([points1, interpolated_points], dim=-1)
310 | else:
311 | new_points = interpolated_points
312 |
313 | new_points = new_points.permute(0, 2, 1)
314 | for i, conv in enumerate(self.mlp_convs):
315 | bn = self.mlp_bns[i]
316 | new_points = F.relu(bn(conv(new_points)))
317 | return new_points
--------------------------------------------------------------------------------
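
Note (not part of the repository file above): the set-abstraction layers in pointnet.py consume point clouds shaped [B, N, 3] with optional per-point features [B, N, D]. The following shape-check sketch is illustrative only; it assumes the package is installed from this source tree (so the module is importable as structformer.utils.pointnet), and all tensor sizes are arbitrary.

import torch
from structformer.utils.pointnet import PointNetSetAbstraction

xyz = torch.rand(2, 1024, 3)        # two point clouds with 1024 points each
feats = torch.rand(2, 1024, 16)     # optional per-point features, D = 16

# in_channel is D + 3 because grouped features are concatenated with the local xyz offsets
sa = PointNetSetAbstraction(npoint=128, radius=0.2, nsample=32,
                            in_channel=16 + 3, mlp=[32, 64], group_all=False)

new_xyz, new_feats = sa(xyz, feats)
print(new_xyz.shape)    # torch.Size([2, 128, 3])   -- FPS centroids
print(new_feats.shape)  # torch.Size([2, 128, 64])  -- pooled local features
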
/src/structformer/utils/rotation_continuity.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from torch.autograd import Variable
4 | import numpy as np
5 |
6 | # Code adapted from the rotation continuity repo (https://github.com/papagina/RotationContinuity)
7 |
8 | #T_poses num*3
9 | #r_matrix batch*3*3
10 | def compute_pose_from_rotation_matrix(T_pose, r_matrix):
11 | batch=r_matrix.shape[0]
12 | joint_num = T_pose.shape[0]
13 | r_matrices = r_matrix.view(batch,1, 3,3).expand(batch,joint_num, 3,3).contiguous().view(batch*joint_num,3,3)
14 | src_poses = T_pose.view(1,joint_num,3,1).expand(batch,joint_num,3,1).contiguous().view(batch*joint_num,3,1)
15 |
16 | out_poses = torch.matmul(r_matrices, src_poses) #(batch*joint_num)*3*1
17 |
18 | return out_poses.view(batch, joint_num, 3)
19 |
20 | # batch*n
21 | def normalize_vector( v, return_mag =False):
22 | batch=v.shape[0]
23 | v_mag = torch.sqrt(v.pow(2).sum(1))# batch
24 | v_mag = torch.max(v_mag, torch.autograd.Variable(torch.FloatTensor([1e-8]).cuda()))
25 | v_mag = v_mag.view(batch,1).expand(batch,v.shape[1])
26 | v = v/v_mag
27 | if(return_mag==True):
28 | return v, v_mag[:,0]
29 | else:
30 | return v
31 |
32 | # u, v batch*n
33 | def cross_product( u, v):
34 | batch = u.shape[0]
35 | #print (u.shape)
36 | #print (v.shape)
37 | i = u[:,1]*v[:,2] - u[:,2]*v[:,1]
38 | j = u[:,2]*v[:,0] - u[:,0]*v[:,2]
39 | k = u[:,0]*v[:,1] - u[:,1]*v[:,0]
40 |
41 | out = torch.cat((i.view(batch,1), j.view(batch,1), k.view(batch,1)),1)#batch*3
42 |
43 | return out
44 |
45 |
46 | #in ortho6d batch*6
47 | #out batch*3*3
48 | def compute_rotation_matrix_from_ortho6d(ortho6d):
49 | x_raw = ortho6d[:,0:3]#batch*3
50 | y_raw = ortho6d[:,3:6]#batch*3
51 |
52 | x = normalize_vector(x_raw) #batch*3
53 | z = cross_product(x,y_raw) #batch*3
54 | z = normalize_vector(z)#batch*3
55 | y = cross_product(z,x)#batch*3
56 |
57 | x = x.view(-1,3,1)
58 | y = y.view(-1,3,1)
59 | z = z.view(-1,3,1)
60 | matrix = torch.cat((x,y,z), 2) #batch*3*3
61 | return matrix
62 |
63 |
64 | #in batch*6
65 | #out batch*5
66 | def stereographic_project(a):
67 | dim = a.shape[1]
68 | a = normalize_vector(a)
69 | out = a[:,0:dim-1]/(1-a[:,dim-1])
70 | return out
71 |
72 |
73 |
74 | #in a batch*5, axis int
75 | def stereographic_unproject(a, axis=None):
76 | """
77 | Inverse of stereographic projection: increases dimension by one.
78 | """
79 | batch=a.shape[0]
80 | if axis is None:
81 | axis = a.shape[1]
82 | s2 = torch.pow(a,2).sum(1) #batch
83 | ans = torch.autograd.Variable(torch.zeros(batch, a.shape[1]+1).cuda()) #batch*6
84 | unproj = 2*a/(s2+1).view(batch,1).repeat(1,a.shape[1]) #batch*5
85 | if(axis>0):
86 | ans[:,:axis] = unproj[:,:axis] #batch*(axis-0)
87 | ans[:,axis] = (s2-1)/(s2+1) #batch
88 | ans[:,axis+1:] = unproj[:,axis:] #batch*(5-axis) # Note that this is a no-op if the default option (last axis) is used
89 | return ans
90 |
91 |
92 | #a batch*5
93 | #out batch*3*3
94 | def compute_rotation_matrix_from_ortho5d(a):
95 | batch = a.shape[0]
96 | proj_scale_np = np.array([np.sqrt(2)+1, np.sqrt(2)+1, np.sqrt(2)]) #3
97 | proj_scale = torch.autograd.Variable(torch.FloatTensor(proj_scale_np).cuda()).view(1,3).repeat(batch,1) #batch,3
98 |
99 | u = stereographic_unproject(a[:, 2:5] * proj_scale, axis=0)#batch*4
100 | norm = torch.sqrt(torch.pow(u[:,1:],2).sum(1)) #batch
101 | u = u/ norm.view(batch,1).repeat(1,u.shape[1]) #batch*4
102 | b = torch.cat((a[:,0:2], u),1)#batch*6
103 | matrix = compute_rotation_matrix_from_ortho6d(b)
104 | return matrix
105 |
106 |
107 | #quaternion batch*4
108 | def compute_rotation_matrix_from_quaternion( quaternion):
109 | batch=quaternion.shape[0]
110 |
111 |
112 | quat = normalize_vector(quaternion).contiguous()
113 |
114 | qw = quat[...,0].contiguous().view(batch, 1)
115 | qx = quat[...,1].contiguous().view(batch, 1)
116 | qy = quat[...,2].contiguous().view(batch, 1)
117 | qz = quat[...,3].contiguous().view(batch, 1)
118 |
119 |     # Unit quaternion rotation matrices computation
120 | xx = qx*qx
121 | yy = qy*qy
122 | zz = qz*qz
123 | xy = qx*qy
124 | xz = qx*qz
125 | yz = qy*qz
126 | xw = qx*qw
127 | yw = qy*qw
128 | zw = qz*qw
129 |
130 | row0 = torch.cat((1-2*yy-2*zz, 2*xy - 2*zw, 2*xz + 2*yw), 1) #batch*3
131 | row1 = torch.cat((2*xy+ 2*zw, 1-2*xx-2*zz, 2*yz-2*xw ), 1) #batch*3
132 | row2 = torch.cat((2*xz-2*yw, 2*yz+2*xw, 1-2*xx-2*yy), 1) #batch*3
133 |
134 | matrix = torch.cat((row0.view(batch, 1, 3), row1.view(batch,1,3), row2.view(batch,1,3)),1) #batch*3*3
135 |
136 | return matrix
137 |
138 | #axisAngle batch*4 angle, x,y,z
139 | def compute_rotation_matrix_from_axisAngle( axisAngle):
140 | batch = axisAngle.shape[0]
141 |
142 |     theta = torch.tanh(axisAngle[:,0])*np.pi #[-pi, pi]
143 | sin = torch.sin(theta*0.5)
144 | axis = normalize_vector(axisAngle[:,1:4]) #batch*3
145 | qw = torch.cos(theta*0.5)
146 | qx = axis[:,0]*sin
147 | qy = axis[:,1]*sin
148 | qz = axis[:,2]*sin
149 |
150 |     # Unit quaternion rotation matrices computation
151 | xx = (qx*qx).view(batch,1)
152 | yy = (qy*qy).view(batch,1)
153 | zz = (qz*qz).view(batch,1)
154 | xy = (qx*qy).view(batch,1)
155 | xz = (qx*qz).view(batch,1)
156 | yz = (qy*qz).view(batch,1)
157 | xw = (qx*qw).view(batch,1)
158 | yw = (qy*qw).view(batch,1)
159 | zw = (qz*qw).view(batch,1)
160 |
161 | row0 = torch.cat((1-2*yy-2*zz, 2*xy - 2*zw, 2*xz + 2*yw), 1) #batch*3
162 | row1 = torch.cat((2*xy+ 2*zw, 1-2*xx-2*zz, 2*yz-2*xw ), 1) #batch*3
163 | row2 = torch.cat((2*xz-2*yw, 2*yz+2*xw, 1-2*xx-2*yy), 1) #batch*3
164 |
165 | matrix = torch.cat((row0.view(batch, 1, 3), row1.view(batch,1,3), row2.view(batch,1,3)),1) #batch*3*3
166 |
167 | return matrix
168 |
169 | #axisAngle batch*3 (x,y,z)*theta
170 | def compute_rotation_matrix_from_Rodriguez( rod):
171 | batch = rod.shape[0]
172 |
173 | axis, theta = normalize_vector(rod, return_mag=True)
174 |
175 | sin = torch.sin(theta)
176 |
177 |
178 | qw = torch.cos(theta)
179 | qx = axis[:,0]*sin
180 | qy = axis[:,1]*sin
181 | qz = axis[:,2]*sin
182 |
183 |     # Unit quaternion rotation matrices computation
184 | xx = (qx*qx).view(batch,1)
185 | yy = (qy*qy).view(batch,1)
186 | zz = (qz*qz).view(batch,1)
187 | xy = (qx*qy).view(batch,1)
188 | xz = (qx*qz).view(batch,1)
189 | yz = (qy*qz).view(batch,1)
190 | xw = (qx*qw).view(batch,1)
191 | yw = (qy*qw).view(batch,1)
192 | zw = (qz*qw).view(batch,1)
193 |
194 | row0 = torch.cat((1-2*yy-2*zz, 2*xy - 2*zw, 2*xz + 2*yw), 1) #batch*3
195 | row1 = torch.cat((2*xy+ 2*zw, 1-2*xx-2*zz, 2*yz-2*xw ), 1) #batch*3
196 | row2 = torch.cat((2*xz-2*yw, 2*yz+2*xw, 1-2*xx-2*yy), 1) #batch*3
197 |
198 | matrix = torch.cat((row0.view(batch, 1, 3), row1.view(batch,1,3), row2.view(batch,1,3)),1) #batch*3*3
199 |
200 | return matrix
201 |
202 | #hopf batch*3 a,b,c
203 | def compute_rotation_matrix_from_hopf( hopf):
204 | batch = hopf.shape[0]
205 |
206 | theta = (torch.tanh(hopf[:,0])+1.0)*np.pi/2.0 #[0, pi]
207 | phi = (torch.tanh(hopf[:,1])+1.0)*np.pi #[0,2pi)
208 | tao = (torch.tanh(hopf[:,2])+1.0)*np.pi #[0,2pi)
209 |
210 | qw = torch.cos(theta/2)*torch.cos(tao/2)
211 | qx = torch.cos(theta/2)*torch.sin(tao/2)
212 | qy = torch.sin(theta/2)*torch.cos(phi+tao/2)
213 | qz = torch.sin(theta/2)*torch.sin(phi+tao/2)
214 |
215 |     # Unit quaternion rotation matrices computation
216 | xx = (qx*qx).view(batch,1)
217 | yy = (qy*qy).view(batch,1)
218 | zz = (qz*qz).view(batch,1)
219 | xy = (qx*qy).view(batch,1)
220 | xz = (qx*qz).view(batch,1)
221 | yz = (qy*qz).view(batch,1)
222 | xw = (qx*qw).view(batch,1)
223 | yw = (qy*qw).view(batch,1)
224 | zw = (qz*qw).view(batch,1)
225 |
226 | row0 = torch.cat((1-2*yy-2*zz, 2*xy - 2*zw, 2*xz + 2*yw), 1) #batch*3
227 | row1 = torch.cat((2*xy+ 2*zw, 1-2*xx-2*zz, 2*yz-2*xw ), 1) #batch*3
228 | row2 = torch.cat((2*xz-2*yw, 2*yz+2*xw, 1-2*xx-2*yy), 1) #batch*3
229 |
230 | matrix = torch.cat((row0.view(batch, 1, 3), row1.view(batch,1,3), row2.view(batch,1,3)),1) #batch*3*3
231 |
232 | return matrix
233 |
234 |
235 | #euler batch*3
236 | #output cuda batch*3*3 matrices in the rotation order of XZ'Y'' (intrinsic) or YZX (extrinsic)
237 | def compute_rotation_matrix_from_euler(euler):
238 | batch=euler.shape[0]
239 |
240 | c1=torch.cos(euler[:,0]).view(batch,1)#batch*1
241 | s1=torch.sin(euler[:,0]).view(batch,1)#batch*1
242 | c2=torch.cos(euler[:,2]).view(batch,1)#batch*1
243 | s2=torch.sin(euler[:,2]).view(batch,1)#batch*1
244 | c3=torch.cos(euler[:,1]).view(batch,1)#batch*1
245 | s3=torch.sin(euler[:,1]).view(batch,1)#batch*1
246 |
247 | row1=torch.cat((c2*c3, -s2, c2*s3 ), 1).view(-1,1,3) #batch*1*3
248 | row2=torch.cat((c1*s2*c3+s1*s3, c1*c2, c1*s2*s3-s1*c3), 1).view(-1,1,3) #batch*1*3
249 | row3=torch.cat((s1*s2*c3-c1*s3, s1*c2, s1*s2*s3+c1*c3), 1).view(-1,1,3) #batch*1*3
250 |
251 | matrix = torch.cat((row1, row2, row3), 1) #batch*3*3
252 |
253 |
254 | return matrix
255 |
256 |
257 | #euler_sin_cos batch*6
258 | #output cuda batch*3*3 matrices in the rotation order of XZ'Y'' (intrinsic) or YZX (extrinsic)
259 | def compute_rotation_matrix_from_euler_sin_cos(euler_sin_cos):
260 | batch=euler_sin_cos.shape[0]
261 |
262 | s1 = euler_sin_cos[:,0].view(batch,1)
263 | c1 = euler_sin_cos[:,1].view(batch,1)
264 | s2 = euler_sin_cos[:,2].view(batch,1)
265 | c2 = euler_sin_cos[:,3].view(batch,1)
266 | s3 = euler_sin_cos[:,4].view(batch,1)
267 | c3 = euler_sin_cos[:,5].view(batch,1)
268 |
269 |
270 | row1=torch.cat((c2*c3, -s2, c2*s3 ), 1).view(-1,1,3) #batch*1*3
271 | row2=torch.cat((c1*s2*c3+s1*s3, c1*c2, c1*s2*s3-s1*c3), 1).view(-1,1,3) #batch*1*3
272 | row3=torch.cat((s1*s2*c3-c1*s3, s1*c2, s1*s2*s3+c1*c3), 1).view(-1,1,3) #batch*1*3
273 |
274 | matrix = torch.cat((row1, row2, row3), 1) #batch*3*3
275 |
276 |
277 | return matrix
278 |
279 |
280 | #matrices batch*3*3
281 | #both matrices are orthogonal rotation matrices
282 | #out theta in [0, pi] radians, shape batch
283 | def compute_geodesic_distance_from_two_matrices(m1, m2):
284 | batch=m1.shape[0]
285 | m = torch.bmm(m1, m2.transpose(1,2)) #batch*3*3
286 |
287 | cos = ( m[:,0,0] + m[:,1,1] + m[:,2,2] - 1 )/2
288 | cos = torch.min(cos, torch.autograd.Variable(torch.ones(batch).cuda()) )
289 | cos = torch.max(cos, torch.autograd.Variable(torch.ones(batch).cuda())*-1 )
290 |
291 |
292 | theta = torch.acos(cos)
293 |
294 | #theta = torch.min(theta, 2*np.pi - theta)
295 |
296 |
297 | return theta
298 |
299 |
300 | #matrices batch*3*3
301 | #input matrices are orthogonal rotation matrices
302 | #out theta in [0, pi] radians, shape batch
303 | def compute_angle_from_r_matrices(m):
304 |
305 | batch=m.shape[0]
306 |
307 | cos = ( m[:,0,0] + m[:,1,1] + m[:,2,2] - 1 )/2
308 | cos = torch.min(cos, torch.autograd.Variable(torch.ones(batch).cuda()) )
309 | cos = torch.max(cos, torch.autograd.Variable(torch.ones(batch).cuda())*-1 )
310 |
311 | theta = torch.acos(cos)
312 |
313 | return theta
314 |
315 | def get_sampled_rotation_matrices_by_quat(batch):
316 | #quat = torch.autograd.Variable(torch.rand(batch,4).cuda())
317 | quat = torch.autograd.Variable(torch.randn(batch, 4).cuda())
318 | matrix = compute_rotation_matrix_from_quaternion(quat)
319 | return matrix
320 |
321 | def get_sampled_rotation_matrices_by_hpof(batch):
322 |
323 | theta = torch.autograd.Variable(torch.FloatTensor(np.random.uniform(0,1, batch)*np.pi).cuda()) #[0, pi]
324 | phi = torch.autograd.Variable(torch.FloatTensor(np.random.uniform(0,2,batch)*np.pi).cuda()) #[0,2pi)
325 | tao = torch.autograd.Variable(torch.FloatTensor(np.random.uniform(0,2,batch)*np.pi).cuda()) #[0,2pi)
326 |
327 |
328 | qw = torch.cos(theta/2)*torch.cos(tao/2)
329 | qx = torch.cos(theta/2)*torch.sin(tao/2)
330 | qy = torch.sin(theta/2)*torch.cos(phi+tao/2)
331 | qz = torch.sin(theta/2)*torch.sin(phi+tao/2)
332 |
333 |     # Unit quaternion rotation matrices computation
334 | xx = (qx*qx).view(batch,1)
335 | yy = (qy*qy).view(batch,1)
336 | zz = (qz*qz).view(batch,1)
337 | xy = (qx*qy).view(batch,1)
338 | xz = (qx*qz).view(batch,1)
339 | yz = (qy*qz).view(batch,1)
340 | xw = (qx*qw).view(batch,1)
341 | yw = (qy*qw).view(batch,1)
342 | zw = (qz*qw).view(batch,1)
343 |
344 | row0 = torch.cat((1-2*yy-2*zz, 2*xy - 2*zw, 2*xz + 2*yw), 1) #batch*3
345 | row1 = torch.cat((2*xy+ 2*zw, 1-2*xx-2*zz, 2*yz-2*xw ), 1) #batch*3
346 | row2 = torch.cat((2*xz-2*yw, 2*yz+2*xw, 1-2*xx-2*yy), 1) #batch*3
347 |
348 | matrix = torch.cat((row0.view(batch, 1, 3), row1.view(batch,1,3), row2.view(batch,1,3)),1) #batch*3*3
349 |
350 | return matrix
351 |
352 | #axisAngle batch*4 angle, x,y,z
353 | def get_sampled_rotation_matrices_by_axisAngle( batch, return_quaternion=False):
354 |
355 |     theta = torch.autograd.Variable(torch.FloatTensor(np.random.uniform(-1,1, batch)*np.pi).cuda()) #[-pi, pi]
356 | sin = torch.sin(theta)
357 | axis = torch.autograd.Variable(torch.randn(batch, 3).cuda())
358 | axis = normalize_vector(axis) #batch*3
359 | qw = torch.cos(theta)
360 | qx = axis[:,0]*sin
361 | qy = axis[:,1]*sin
362 | qz = axis[:,2]*sin
363 |
364 | quaternion = torch.cat((qw.view(batch,1), qx.view(batch,1), qy.view(batch,1), qz.view(batch,1)), 1 )
365 |
366 |     # Unit quaternion rotation matrices computation
367 | xx = (qx*qx).view(batch,1)
368 | yy = (qy*qy).view(batch,1)
369 | zz = (qz*qz).view(batch,1)
370 | xy = (qx*qy).view(batch,1)
371 | xz = (qx*qz).view(batch,1)
372 | yz = (qy*qz).view(batch,1)
373 | xw = (qx*qw).view(batch,1)
374 | yw = (qy*qw).view(batch,1)
375 | zw = (qz*qw).view(batch,1)
376 |
377 | row0 = torch.cat((1-2*yy-2*zz, 2*xy - 2*zw, 2*xz + 2*yw), 1) #batch*3
378 | row1 = torch.cat((2*xy+ 2*zw, 1-2*xx-2*zz, 2*yz-2*xw ), 1) #batch*3
379 | row2 = torch.cat((2*xz-2*yw, 2*yz+2*xw, 1-2*xx-2*yy), 1) #batch*3
380 |
381 | matrix = torch.cat((row0.view(batch, 1, 3), row1.view(batch,1,3), row2.view(batch,1,3)),1) #batch*3*3
382 |
383 | if(return_quaternion==True):
384 | return matrix, quaternion
385 | else:
386 | return matrix
387 |
388 |
389 |
390 |
391 |
392 |
393 |
394 |
395 |
396 |
--------------------------------------------------------------------------------
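
Note (not part of the repository file above): compute_rotation_matrix_from_ortho6d turns two raw 3-vectors into an orthonormal frame via normalization and cross products, so any batch*6 input maps to a proper rotation matrix. A minimal sanity-check sketch follows; it is illustrative only, assumes a CUDA device (these helpers call .cuda() internally), and assumes the package is installed from this source tree.

import torch
from structformer.utils.rotation_continuity import (
    compute_rotation_matrix_from_ortho6d,
    compute_geodesic_distance_from_two_matrices,
)

ortho6d = torch.randn(8, 6).cuda()                  # batch of raw 6D predictions
R = compute_rotation_matrix_from_ortho6d(ortho6d)   # batch*3*3

# Proper rotations satisfy R R^T = I and det(R) = 1
identity = torch.eye(3).cuda().expand(8, 3, 3)
print(torch.allclose(torch.bmm(R, R.transpose(1, 2)), identity, atol=1e-5))  # True
print(torch.det(R))                                 # all entries close to 1

# The geodesic distance of a rotation to itself is 0 (up to numerical noise)
print(compute_geodesic_distance_from_two_matrices(R, R))
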