├── .gitignore
├── LICENSE
├── README.md
├── configs
│   ├── equigraspflow_full.yml
│   └── equigraspflow_partial.yml
├── environment.yml
├── images
│   └── generation.gif
├── loaders
│   ├── __init__.py
│   ├── acronym.py
│   └── utils.py
├── losses
│   ├── __init__.py
│   └── mse_loss.py
├── metrics
│   ├── __init__.py
│   └── emd.py
├── models
│   ├── __init__.py
│   ├── equi_grasp_flow.py
│   ├── vn_dgcnn.py
│   ├── vn_layers.py
│   └── vn_vector_fields.py
├── test_full.py
├── test_partial.py
├── train.py
├── trainers
│   ├── __init__.py
│   └── grasp_trainer.py
└── utils
    ├── Lie.py
    ├── average_meter.py
    ├── distributions.py
    ├── logger.py
    ├── mesh.py
    ├── ode_solvers.py
    ├── optimizers.py
    ├── partial_point_cloud.py
    └── visualization.py
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled
2 | __pycache__/
3 | **/*.pyc
4 |
5 | # Folders
6 | .vscode/
7 | dataset/
8 | train_results/
9 | pretrained_models/
10 | test_results/
11 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2025 Byeongdo Lim
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # EquiGraspFlow: SE(3)-Equivariant 6-DoF Grasp Pose Generative Flows
2 |
3 | The official repository for *EquiGraspFlow: SE(3)-Equivariant 6-DoF Grasp Pose Generative Flows* (Byeongdo Lim, Jongmin Kim, Jihwan Kim, Yonghyeon Lee, and Frank C. Park, CoRL 2024).
4 |
5 | - [Project page](https://equigraspflow.github.io/)
6 | - [OpenReview](https://openreview.net/forum?id=5lSkn5v4LK&referrer=%5BAuthor%20Console%5D(%2Fgroup%3Fid%3Drobot-learning.org%2FCoRL%2F2024%2FConference%2FAuthors%23your-submissions))
7 | - [Paper](https://openreview.net/pdf?id=5lSkn5v4LK)
8 | - [Video](https://youtu.be/fxOveMwugo4?si=L1bmYNOMPbCHY1Cr)
9 | - [Poster](https://drive.google.com/file/d/1UTBoNDDT7FzHcXHSrFDA6x4v5hr3-g51/view?usp=sharing)
10 |
11 | 
12 | ![Grasp pose generation](images/generation.gif)
13 | 
14 | 
15 | 
16 | ## Requirements
17 |
18 | ### Conda environment
19 |
20 | You can create a Conda environment using the following command.
21 | You can customize the environment name by modifying the `name` field in the `environment.yml` file.
22 |
23 | ```bash
24 | conda env create -f environment.yml
25 | ```
26 |
27 | This will automatically install the required packages, including:
28 |
29 | - `python==3.10`
30 | - `omegaconf`
31 | - `tensorboardX`
32 | - `pyyaml`
33 | - `numpy==1.26`
34 | - `torch`
35 | - `scipy`
36 | - `tqdm`
37 | - `h5py`
38 | - `open3d==0.16.0`
39 | - `roma`
40 | - `pandas`
41 | - `openpyxl`
42 |
43 | To activate the environment, use:
44 |
45 | ```bash
46 | conda activate equigraspflow
47 | ```
48 |
49 |
50 | ### Dataset
51 |
52 | We use the Laptop, Mug, Bowl, and Pencil categories of the ACRONYM dataset [1].
53 | The dataset can be downloaded from [this link](https://drive.google.com/drive/folders/1H1PeUbyxvNtzoWc6Le2pKqOqp2WLSnau?usp=drive_link).
54 | Create a `dataset` directory and place the data in that directory, or customize the path to the dataset by modifying `DATASET_DIR` in `acronym.py` and `utils.py` within the `loaders` directory.
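
For reference, the dataset root is a module-level constant in both files; a minimal sketch of the change (the `/data/acronym` path below is only an illustrative example):

```python
# loaders/acronym.py and loaders/utils.py
DATASET_DIR = 'dataset'           # default: a `dataset` directory in the repository root
# DATASET_DIR = '/data/acronym'   # example: point this at wherever you stored the data
```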
55 |
56 |
57 | ## Training
58 |
59 | ### Train a new model
60 |
61 | The training script is `train.py`, which accepts the following arguments:
62 |
63 | - `--config`: Path to the training configuration YAML file.
64 | - `--device`: GPU number to use (default: `0`). Use `cpu` to run on CPU.
65 | - `--logdir`: Directory where the results will be saved (default: `train_results`).
66 | - `--run`: Name for the training session (default: `{date}-{time}`).
67 |
68 | To train EquiGraspFlow using the full point cloud, run:
69 |
70 | ```bash
71 | python train.py --config configs/equigraspflow_full.yml
72 | ```
73 |
74 | Alternatively, to train EquiGraspFlow with the partial point cloud, use:
75 |
76 | ```bash
77 | python train.py --config configs/equigraspflow_partial.yml
78 | ```
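
All of the arguments above can be combined. For example, to train on GPU 1 and store the results under a custom run name (`full_run1` is just an illustrative placeholder), you could run:

```bash
python train.py --config configs/equigraspflow_full.yml --device 1 --logdir train_results --run full_run1
```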
79 |
80 | Note: Training with the partial point cloud cannot be done in headless mode; a display is required.
81 |
82 | You can change the data augmentation strategy for each data split by modifying the `augmentation` field in the training configuration YAML file.
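
The loaders accept `None`, `z`, and `SO3` for this field (see `loaders/acronym.py`); a minimal sketch of the relevant part of the config:

```yaml
data:
  train:
    dataset:
      augmentation: SO3   # one of: None (no rotation), z (random z-axis rotation), SO3 (random 3D rotation)
```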
83 |
84 |
85 | ### View training results
86 |
87 | We log the results of the training process using TensorBoard. You can view the TensorBoard results by running:
88 |
89 | ```bash
90 | tensorboard --logdir {path} --host {IP_address}
91 | ```
92 | Replace `path` with the specific path to your training results and `IP_address` with your IP address.
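
For instance, assuming a run stored under `train_results/equigraspflow_full/20250101-1200` (a hypothetical run name) and serving on all network interfaces:

```bash
tensorboard --logdir train_results/equigraspflow_full/20250101-1200 --host 0.0.0.0
```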
93 |
94 |
95 | ## Pretrained models
96 |
97 | The pretrained models can be downloaded from [this link](https://drive.google.com/drive/folders/1H-MXRVcTekdEfzXU_suSw7Afi-7o8I39?usp=sharing).
98 |
99 |
100 | ## Test
101 |
102 | ### Run test
103 |
104 | The test scripts, `test_full.py` and `test_partial.py`, calculate the Earth Mover's Distance [2] between the generated and ground-truth grasp poses and store the visualizations of the generated grasp poses.
105 | Both scripts accept the following arguments:
106 |
107 | - `--train_result_path`: Path to the directory containing training results.
108 | - `--checkpoint`: Model checkpoint to use.
109 | - `--device`: GPU number to use (default: `0`). Use `cpu` to run on CPU.
110 | - `--logdir`: Directory where the results will be saved (default: `test_results`).
111 | - `--run`: Name for the experiment (default: `{date}-{time}`).
112 |
113 | For example, to test EquiGraspFlow using the full point cloud with the `model_best_val_loss.pkl` checkpoint in the `pretrained_models/equigraspflow_full` directory, use:
114 |
115 | ```bash
116 | python test_full.py --train_result_path train_results/equigraspflow_full --checkpoint model_best_val_loss.pkl
117 | ```
118 |
119 | Alternatively, to test EquiGraspFlow using the partial point cloud with the `model_best_val_loss.pkl` checkpoint in the `pretrained_models/equigraspflow_partial` directory, use:
120 |
121 | ```bash
122 | python test_partial.py --train_result_path train_results/equigraspflow_partial --checkpoint model_best_val_loss.pkl
123 | ```
124 |
125 |
126 | ### Display visualizations
127 |
128 | The visualizations of the generated grasp poses are stored in `visualizations.json` within the test results directory.
129 | To display these visualizations, use the following code:
130 |
131 | ```python
132 | import plotly.io as pio
133 |
134 | pio.from_json(open('{path}/visualizations.json', 'r').read()).show()
135 | ```
136 |
137 | Replace `path` with your test results directory.
138 |
139 |
140 | ## References
141 |
142 | [1] C. Eppner, A. Mousavian, and D. Fox. Acronym: A large-scale grasp dataset based on simulation, ICRA 2021. [[paper](https://ieeexplore.ieee.org/stamp/stamp.jsp?arnumber=9560844&casa_token=VAlWdJNx458AAAAA:z3KlV9ALMjYG34RNbCVmUPEPlFkS6b7NIty76glWYuMbn3XwXpTtmrTV2PnmzhrGr_5QN_jQpg&tag=1)]
143 |
144 | [2] A. Tanaka. Discriminator optimal transport. NeurIPS 2019. [[paper](https://proceedings.neurips.cc/paper/2019/file/8abfe8ac9ec214d68541fcb888c0b4c3-Paper.pdf)]
145 |
146 |
147 | ## Citation
148 | If you find this repository useful in your research, please cite:
149 |
150 | ```text
151 | @inproceedings{lim2024equigraspflow,
152 |   title={EquiGraspFlow: SE(3)-Equivariant 6-DoF Grasp Pose Generative Flows},
153 | author={Lim, Byeongdo and Kim, Jongmin and Kim, Jihwan and Lee, Yonghyeon and Park, Frank C},
154 | booktitle={8th Annual Conference on Robot Learning},
155 | year={2024}
156 | }
157 | ```
158 |
--------------------------------------------------------------------------------
/configs/equigraspflow_full.yml:
--------------------------------------------------------------------------------
1 | data:
2 | train:
3 | dataset:
4 | name: full
5 | obj_types: [Laptop, Mug, Bowl, Pencil]
6 | augmentation: SO3
7 | scale: 8
8 | batch_size: 4
9 | num_workers: 8
10 | val:
11 | dataset:
12 | name: full
13 | obj_types: [Laptop, Mug, Bowl, Pencil]
14 | augmentation: SO3
15 | scale: 8
16 | num_rots: 3
17 | batch_size: 4
18 | num_workers: 8
19 | test:
20 | dataset:
21 | name: full
22 | obj_types: [Laptop, Mug, Bowl, Pencil]
23 | augmentation: SO3
24 | scale: 8
25 | num_rots: 3
26 | batch_size: 4
27 | num_workers: 8
28 | model:
29 | name: equigraspflow
30 | p_uncond: 0.2
31 | guidance: 2.0
32 | init_dist:
33 | name: SO3_uniform_R3_normal
34 | encoder:
35 | name: vn_dgcnn_enc
36 | num_neighbors: 40
37 | dims: [1, 21, 21, 42, 85, 170, 341]
38 | use_bn: False
39 | vector_field:
40 | name: vn_vf
41 | dims: [346, 256, 256, 128, 128, 128, 2]
42 | use_bn: False
43 | ode_solver:
44 | name: SE3_RK_mk
45 | num_steps: 20
46 | losses:
47 | - name: mse
48 | optimizer:
49 | name: adam
50 | lr: 0.0001
51 | weight_decay: 1.0e-6
52 | metrics:
53 | - name: emd
54 | type: SE3
55 | trainer:
56 | name: grasp_full
57 | criteria:
58 | - name: emd
59 | better: lower
60 | num_epochs: 40000
61 | print_interval: 100
62 | val_interval: 10000
63 | eval_interval: 100000
64 | vis_interval: 100000
65 | save_interval: 2000000
66 |
--------------------------------------------------------------------------------
/configs/equigraspflow_partial.yml:
--------------------------------------------------------------------------------
1 | data:
2 | train:
3 | dataset:
4 | name: partial
5 | obj_types: [Laptop, Mug, Bowl, Pencil]
6 | augmentation: SO3
7 | scale: 8
8 | batch_size: 4
9 | num_workers: 8
10 | val:
11 | dataset:
12 | name: partial
13 | obj_types: [Laptop, Mug, Bowl, Pencil]
14 | augmentation: SO3
15 | scale: 8
16 | num_rots: 3
17 | num_views: 3
18 | batch_size: 4
19 | num_workers: 8
20 | test:
21 | dataset:
22 | name: partial
23 | obj_types: [Laptop, Mug, Bowl, Pencil]
24 | augmentation: SO3
25 | scale: 8
26 | num_rots: 3
27 | num_views: 3
28 | batch_size: 4
29 | num_workers: 8
30 | model:
31 | name: equigraspflow
32 | p_uncond: 0.2
33 | guidance: 1.5
34 | init_dist:
35 | name: SO3_uniform_R3_normal
36 | encoder:
37 | name: vn_dgcnn_enc
38 | num_neighbors: 40
39 | dims: [1, 21, 21, 42, 85, 170, 341]
40 | use_bn: False
41 | vector_field:
42 | name: vn_vf
43 | dims: [346, 256, 256, 128, 128, 128, 2]
44 | use_bn: False
45 | ode_solver:
46 | name: SE3_RK_mk
47 | num_steps: 20
48 | losses:
49 | - name: mse
50 | optimizer:
51 | name: adam
52 | lr: 0.0001
53 | weight_decay: 1.0e-6
54 | metrics:
55 | - name: emd
56 | type: SE3
57 | trainer:
58 | name: grasp_partial
59 | criteria:
60 | - name: emd
61 | better: lower
62 | num_epochs: 40000
63 | print_interval: 100
64 | val_interval: 10000
65 | eval_interval: 100000
66 | vis_interval: 100000
67 | save_interval: 2000000
68 |
--------------------------------------------------------------------------------
/environment.yml:
--------------------------------------------------------------------------------
1 | name: equigraspflow
2 | channels:
3 | - defaults
4 | dependencies:
5 | - pip=24.2
6 | - python=3.10
7 | - pip:
8 | - omegaconf
9 | - tensorboardX
10 | - pyyaml
11 | - numpy==1.26
12 | - torch
13 | - scipy
14 | - tqdm
15 | - h5py
16 | - open3d==0.16.0
17 | - roma
18 | - pandas
19 | - openpyxl
20 |
--------------------------------------------------------------------------------
/images/generation.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/bdlim99/EquiGraspFlow/f921425f27d0f80cc250a288b0b9cedbb9f61b41/images/generation.gif
--------------------------------------------------------------------------------
/loaders/__init__.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | from loaders.acronym import AcronymFullPointCloud, AcronymPartialPointCloud
4 |
5 |
6 | def get_dataloader(split, cfg_dataloader):
7 | cfg_dataloader.dataset.split = split
8 |
9 | dataset = get_dataset(cfg_dataloader.dataset)
10 |
11 | dataloader = torch.utils.data.DataLoader(
12 | dataset,
13 | batch_size=cfg_dataloader.batch_size,
14 | shuffle=cfg_dataloader.get('shuffle', True),
15 | num_workers=cfg_dataloader.get('num_workers', 8),
16 | collate_fn=collate_fn
17 | )
18 |
19 | return dataloader
20 |
21 |
22 | def get_dataset(cfg_dataset):
23 | name = cfg_dataset.pop('name')
24 |
25 | if name == 'full':
26 | dataset = AcronymFullPointCloud(**cfg_dataset)
27 | elif name == 'partial':
28 | dataset = AcronymPartialPointCloud(**cfg_dataset)
29 | else:
30 | raise NotImplementedError(f"Dataset {name} not implemented.")
31 |
32 | return dataset
33 |
34 |
35 | def collate_fn(batch_original):
36 | batch_collated = {}
37 |
38 | for key in batch_original[0].keys():
39 | if key == 'Ts_grasp':
40 | batch_collated[key] = [sample[key] for sample in batch_original]
41 | else:
42 | batch_collated[key] = torch.stack([sample[key] for sample in batch_original])
43 |
44 | return batch_collated
45 |
--------------------------------------------------------------------------------
/loaders/acronym.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 | from scipy.spatial.transform import Rotation
4 | import os
5 | from tqdm import tqdm
6 | import h5py
7 | from copy import deepcopy
8 |
9 | from loaders.utils import load_grasp_poses, load_mesh
10 | from utils.Lie import super_fibonacci_spiral, get_fibonacci_sphere
11 | from utils.partial_point_cloud import get_partial_point_clouds
12 |
13 |
14 | DATASET_DIR = 'dataset'
15 | NUM_GRASPS = 100
16 |
17 |
18 | class AcronymFullPointCloud(torch.utils.data.Dataset):
19 | def __init__(self, split, obj_types, augmentation, scale, num_pts=1024, num_rots=1):
20 | # Initialize
21 | self.len_dataset = 0
22 | self.split = split
23 | self.num_pts = num_pts
24 | self.augmentation = augmentation
25 | self.scale = scale
26 | self.num_rots = num_rots
27 | self.obj_types = obj_types
28 |
29 | # Initialize maximum number of objects
30 | self.max_num_objs = 0
31 |
32 | # Get evenly distributed rotations for validation and test splits
33 | if split in ['val', 'test']:
34 | if augmentation == 'None':
35 | self.Rs = np.expand_dims(np.eye(3), axis=0)
36 | elif augmentation == 'z':
37 | degree = (np.arange(num_rots) / num_rots) * (2 * np.pi)
38 | self.Rs = Rotation.from_rotvec(degree * np.array([0, 0, 1])).as_matrix()
39 | elif augmentation == 'SO3':
40 | self.Rs = super_fibonacci_spiral(num_rots)
41 | else:
42 | raise ValueError("Choose augmentation from ['None', 'z', 'SO3'].")
43 | else:
44 | assert num_rots == 1, "Number of rotations must be 1 in train set."
45 |
46 | self.Rs = np.expand_dims(np.eye(3), axis=0)
47 |
48 | Ts = np.tile(np.eye(4), (num_rots, 1, 1))
49 | Ts[:, :3, :3] = self.Rs
50 |
51 | # Initialize data indices
52 | data_idxs_types = []
53 |
54 | # Initialize lists
55 | self.mesh_list_types = []
56 | self.Ts_grasp_list_types = []
57 |
58 | self.pc_list_types = []
59 |
60 | self.obj_idxs_types = []
61 |
62 | for obj_type in tqdm(obj_types, desc="Iterating object types ...", leave=False):
63 | # Get data filenames
64 | filenames = sorted(os.listdir(os.path.join(DATASET_DIR, 'grasps', obj_type)))
65 |
66 | # Get object indices for the split
67 | obj_idxs = np.load(os.path.join(DATASET_DIR, 'splits', obj_type, f'idxs_{split}.npy'))
68 |
69 | # Initialize data indices
70 | data_idxs_objs = []
71 |
72 | # Initialize lists
73 | mesh_list_objs = []
74 | Ts_grasp_list_objs = []
75 |
76 | pc_list_objs = []
77 |
78 | obj_idxs_objs = []
79 |
80 | for obj_idx in tqdm(obj_idxs, desc="Iterating objects ...", leave=False):
81 | # Get data filename
82 | filename = filenames[obj_idx]
83 |
84 | # Load data
85 | data = h5py.File(os.path.join(DATASET_DIR, 'grasps', obj_type, filename))
86 |
87 | # Load grasp poses
88 | Ts_grasp = load_grasp_poses(data)
89 |
90 | # Continue if grasp poses are not enough
91 | if len(Ts_grasp) < NUM_GRASPS:
92 | continue
93 | else:
94 | obj_idxs_objs += [obj_idx]
95 |
96 | # Load mesh
97 | mesh = load_mesh(data)
98 |
99 | # Scale
100 | mesh.scale(scale, center=(0, 0, 0))
101 | Ts_grasp[:, :3, 3] *= scale
102 |
103 | # Sample point cloud
104 | pc = np.asarray(mesh.sample_points_uniformly(num_pts).points).T
105 |
106 | # Translate to the center of the point cloud
107 | center = pc.mean(axis=1)
108 | mesh.translate(-center)
109 | pc -= np.expand_dims(center, axis=1)
110 | Ts_grasp[:, :3, 3] -= center
111 |
112 | # Rotate mesh
113 | mesh_list_rots = []
114 |
115 | for R in tqdm(self.Rs, desc="Iterating rotations ...", leave=False):
116 | mesh_rot = deepcopy(mesh)
117 | mesh_rot.rotate(R, center=(0, 0, 0))
118 |
119 | mesh_list_rots += [mesh_rot]
120 |
121 | # Rotate the other data
122 | pc_rots = self.Rs @ pc
123 | Ts_grasp_rots = np.einsum('rij,njk->rnik', Ts, Ts_grasp)
124 |
125 | # Fill data indices
126 | data_idxs_objs += [list(range(self.len_dataset, self.len_dataset+num_rots))]
127 |
128 | # Append data
129 | mesh_list_objs += [mesh_list_rots]
130 | Ts_grasp_list_objs += [Ts_grasp_rots]
131 |
132 | pc_list_objs += [pc_rots]
133 |
134 | # Increase number of data
135 | self.len_dataset += num_rots
136 |
137 | # Append data
138 | data_idxs_types += [data_idxs_objs]
139 |
140 | self.mesh_list_types += [mesh_list_objs]
141 | self.Ts_grasp_list_types += [Ts_grasp_list_objs]
142 |
143 | self.pc_list_types += [pc_list_objs]
144 |
145 | self.obj_idxs_types += [obj_idxs_objs]
146 |
147 | # Update maximum number of objects
148 | if len(obj_idxs_objs) > self.max_num_objs:
149 | self.max_num_objs = len(obj_idxs_objs)
150 |
151 | # Convert data indices from lists to numpy array
152 | self.data_idxs = np.full((len(obj_types), self.max_num_objs, num_rots), -1)
153 |
154 | for i, data_idxs_objs in enumerate(data_idxs_types):
155 | self.data_idxs[i, :len(data_idxs_objs)] = data_idxs_objs
156 |
157 | # Setup scene indices
158 | self.scene_idxs = self.data_idxs
159 | self.num_scenes = self.scene_idxs.max() + 1
160 |
161 | def __len__(self):
162 | return self.len_dataset
163 |
164 | def __getitem__(self, idx):
165 | # Get type, object, and rotation indices
166 | idx_type, idx_obj, idx_rot = np.where(self.data_idxs==idx)
167 |
168 | idx_type = idx_type.item()
169 | idx_obj = idx_obj.item()
170 | idx_rot = idx_rot.item()
171 |
172 | # Load grasp poses
173 | Ts_grasp = self.Ts_grasp_list_types[idx_type][idx_obj][idx_rot].copy()
174 |
175 | if self.split == 'train':
176 | # Load mesh
177 | mesh = self.mesh_list_types[idx_type][idx_obj][idx_rot]
178 |
179 | # Sample point cloud
180 | pc = np.asarray(mesh.sample_points_uniformly(self.num_pts).points).T
181 |
182 | # Translate to the point cloud center
183 | center = pc.mean(axis=1)
184 | pc -= np.expand_dims(center, axis=1)
185 | Ts_grasp[:, :3, 3] -= center
186 |
187 | # Rotate data
188 | if self.augmentation != 'None':
189 | if self.augmentation == 'z':
190 | # Randomly rotate around z-axis
191 | degree = np.random.rand() * 2 * np.pi
192 | R = Rotation.from_rotvec(degree * np.array([0, 0, 1])).as_matrix()
193 | elif self.augmentation == 'SO3':
194 | # Randomly rotate
195 | R = Rotation.random().as_matrix()
196 | else:
197 | raise ValueError("Choose augmentation from ['None', 'z', 'SO3'].")
198 |
199 | T = np.eye(4)
200 | T[:3, :3] = R
201 |
202 | pc = R @ pc
203 | Ts_grasp = T @ Ts_grasp
204 | else:
205 | # Load point cloud
206 | pc = self.pc_list_types[idx_type][idx_obj][idx_rot]
207 |
208 | return {'pc': torch.Tensor(pc), 'Ts_grasp': torch.Tensor(Ts_grasp)}
209 |
210 |
211 | class AcronymPartialPointCloud(torch.utils.data.Dataset):
212 | def __init__(self, split, obj_types, augmentation, scale, num_pts=512, num_rots=1, num_views=1):
213 | # Initialize
214 | self.len_dataset = 0
215 | self.split = split
216 | self.num_pts = num_pts
217 | self.augmentation = augmentation
218 | self.scale = scale
219 | self.num_rots = num_rots
220 | self.num_views = num_views
221 | self.obj_types = obj_types
222 |
223 | # Initialize maximum number of objects
224 | self.max_num_objs = 0
225 |
226 | # Get rotations
227 | if split in ['val', 'test']:
228 | if augmentation == 'None':
229 | self.Rs = np.expand_dims(np.eye(3), axis=0)
230 | elif augmentation == 'z':
231 | degree = (np.arange(num_rots) / num_rots) * (2 * np.pi)
232 | self.Rs = Rotation.from_rotvec(degree * np.array([0, 0, 1])).as_matrix()
233 | elif augmentation == 'SO3':
234 | self.Rs = super_fibonacci_spiral(num_rots)
235 | else:
236 |                 raise ValueError("Choose augmentation from ['None', 'z', 'SO3'].")
237 | else:
238 | assert num_rots == 1, "Number of rotations must be 1 in train set."
239 |
240 | self.Rs = np.expand_dims(np.eye(3), axis=0)
241 |
242 | Ts = np.tile(np.eye(4), (num_rots, 1, 1))
243 | Ts[:, :3, :3] = self.Rs
244 |
245 | # Get viewpoint vector
246 | if split in ['val', 'test']:
247 | view_vecs = get_fibonacci_sphere(num_views)
248 | else:
249 |             assert num_views == 1, "Number of viewpoint vectors must be 1 in the train set."
250 |
251 | view_vecs = np.array([[0, 0, 1]])
252 |
253 | # Initialize data indices
254 | data_idxs_types = []
255 |
256 | # Initialize lists
257 | self.mesh_list_types = []
258 | self.Ts_grasp_list_types = []
259 |
260 | self.partial_pc_list_types = []
261 |
262 | self.obj_idxs_types = []
263 |
264 | for obj_type in tqdm(obj_types, desc="Iterating object types ...", leave=False):
265 | # Get data filenames
266 | filenames = sorted(os.listdir(os.path.join(DATASET_DIR, 'grasps', obj_type)))
267 |
268 | # Get object indices for the split
269 | obj_idxs = np.load(os.path.join(DATASET_DIR, 'splits', obj_type, f'idxs_{split}.npy'))
270 |
271 | # Initialize data indices
272 | data_idxs_objs = []
273 |
274 | # Initialize lists
275 | mesh_list_objs = []
276 | Ts_grasp_list_objs = []
277 |
278 | partial_pc_list_objs = []
279 |
280 | obj_idxs_objs = []
281 |
282 | for obj_idx in tqdm(obj_idxs, desc="Iterating objects ...", leave=False):
283 | # Get data filename
284 | filename = filenames[obj_idx]
285 |
286 | # Load data
287 | data = h5py.File(os.path.join(DATASET_DIR, 'grasps', obj_type, filename))
288 |
289 | # Load grasp poses
290 | Ts_grasp = load_grasp_poses(data)
291 |
292 | # Continue if grasp poses are not enough
293 | if len(Ts_grasp) < NUM_GRASPS:
294 | continue
295 | else:
296 | obj_idxs_objs += [obj_idx]
297 |
298 | # Load mesh
299 | mesh = load_mesh(data)
300 |
301 | # Translate to the center of the mesh
302 | center = mesh.get_center()
303 | mesh.translate(-center)
304 | Ts_grasp[:, :3, 3] -= center
305 |
306 | # Scale
307 | mesh.scale(scale, center=(0, 0, 0))
308 | Ts_grasp[:, :3, 3] *= scale
309 |
310 | # Initialize data indices
311 | data_idxs_rots = []
312 |
313 | # Initialize lists
314 | mesh_list_rots = []
315 | partial_pc_list_rots = []
316 |
317 | for R in tqdm(self.Rs, desc="Iterating rotations ...", leave=False):
318 | # Rotate mesh
319 | mesh_rot = deepcopy(mesh)
320 | mesh_rot.rotate(R, center=(0, 0, 0))
321 |
322 | # Sample partial point clouds
323 | partial_pc_views = get_partial_point_clouds(mesh_rot, view_vecs, num_pts, use_tqdm=True).transpose(0, 2, 1)
324 |
325 | # Initialize mesh list
326 | mesh_list_views = []
327 |
328 | for partial_pc in partial_pc_views:
329 | # Translate mesh to the center of the partial point cloud
330 | mesh_view = deepcopy(mesh_rot)
331 | mesh_view.translate(-partial_pc.mean(axis=1))
332 |
333 | # Append mesh
334 | mesh_list_views += [mesh_view]
335 |
336 | # Fill data indices
337 | data_idxs_rots += [list(range(self.len_dataset, self.len_dataset+num_views))]
338 |
339 | # Append data
340 | mesh_list_rots += [mesh_list_views]
341 | partial_pc_list_rots += [partial_pc_views]
342 |
343 | # Increase number of data
344 | self.len_dataset += num_views
345 |
346 | # Stack partial point clouds
347 | partial_pc_rots = np.stack(partial_pc_list_rots)
348 |
349 | # Rotate grasp poses
350 | Ts_grasp_rots = np.einsum('rij,njk->rnik', Ts, Ts_grasp)
351 |
352 | # Translate to the center of the partial point clouds
353 | center_rots = partial_pc_rots.mean(axis=3)
354 |
355 | Ts_grasp_rots = np.expand_dims(Ts_grasp_rots, axis=1).repeat(num_views, axis=1)
356 |
357 | partial_pc_rots -= np.expand_dims(center_rots, axis=3)
358 | Ts_grasp_rots[:, :, :, :3, 3] -= np.expand_dims(center_rots, axis=2)
359 |
360 | # Append data
361 | data_idxs_objs += [data_idxs_rots]
362 |
363 | mesh_list_objs += [mesh_list_rots]
364 | Ts_grasp_list_objs += [Ts_grasp_rots]
365 |
366 | partial_pc_list_objs += [partial_pc_rots]
367 |
368 | # Append data
369 | data_idxs_types += [data_idxs_objs]
370 |
371 | self.mesh_list_types += [mesh_list_objs]
372 | self.Ts_grasp_list_types += [Ts_grasp_list_objs]
373 |
374 | self.partial_pc_list_types += [partial_pc_list_objs]
375 |
376 | self.obj_idxs_types += [obj_idxs_objs]
377 |
378 | # Update maximum number of objects
379 | if len(obj_idxs_objs) > self.max_num_objs:
380 | self.max_num_objs = len(obj_idxs_objs)
381 |
382 | # Convert data indices from lists to numpy array
383 | self.data_idxs = np.full((len(obj_types), self.max_num_objs, num_rots, num_views), -1)
384 |
385 | for i, data_idxs_objs in enumerate(data_idxs_types):
386 | self.data_idxs[i, :len(data_idxs_objs)] = data_idxs_objs
387 |
388 | # Setup scene indices
389 | self.scene_idxs = self.data_idxs
390 | self.num_scenes = self.scene_idxs.max() + 1
391 |
392 | def __len__(self):
393 | return self.len_dataset
394 |
395 | def __getitem__(self, idx):
396 |         # Get type, object, rotation, and viewpoint indices
397 | idx_type, idx_obj, idx_rot, idx_view = np.where(self.data_idxs==idx)
398 |
399 | idx_type = idx_type.item()
400 | idx_obj = idx_obj.item()
401 | idx_rot = idx_rot.item()
402 | idx_view = idx_view.item()
403 |
404 | # Load grasp poses
405 | Ts_grasp = self.Ts_grasp_list_types[idx_type][idx_obj][idx_rot][idx_view].copy()
406 |
407 | if self.split == 'train':
408 | # Load mesh
409 | mesh = deepcopy(self.mesh_list_types[idx_type][idx_obj][idx_rot][idx_view])
410 |
411 | # Get random rotation
412 | if self.augmentation == 'None':
413 | R = np.eye(3)
414 | elif self.augmentation == 'z':
415 | degree = np.random.rand() * 2 * np.pi
416 | R = Rotation.from_rotvec(degree * np.array([0, 0, 1])).as_matrix()
417 | elif self.augmentation == 'SO3':
418 | R = Rotation.random().as_matrix()
419 | else:
420 | raise ValueError("Choose augmentation from ['None', 'z', 'SO3'].")
421 |
422 | T = np.eye(4)
423 | T[:3, :3] = R
424 |
425 | # Rotate mesh
426 | mesh.rotate(R, center=(0, 0, 0))
427 |
428 | # Sample partial point cloud
429 | while True:
430 | try:
431 | view_vecs = -1 + 2 * np.random.rand(1, 3)
432 | view_vecs = view_vecs / np.linalg.norm(view_vecs)
433 |
434 | partial_pc = get_partial_point_clouds(mesh, view_vecs, self.num_pts)[0].T
435 |
436 | break
437 |                 except Exception:  # Resample the viewpoint if partial point cloud extraction fails
438 | pass
439 |
440 | # Rotate grasp poses
441 | Ts_grasp = T @ Ts_grasp
442 |
443 | # Translate to the center of the partial point cloud
444 | center = partial_pc.mean(axis=1)
445 | partial_pc -= np.expand_dims(center, axis=1)
446 | Ts_grasp[:, :3, 3] -= center
447 | else:
448 | # Load point cloud
449 | partial_pc = self.partial_pc_list_types[idx_type][idx_obj][idx_rot][idx_view]
450 |
451 | return {'pc': torch.Tensor(partial_pc), 'Ts_grasp': torch.Tensor(Ts_grasp)}
452 |
--------------------------------------------------------------------------------
/loaders/utils.py:
--------------------------------------------------------------------------------
1 | import os
2 | import open3d as o3d
3 | import pickle
4 | import numpy as np
5 |
6 |
7 | DATASET_DIR = 'dataset'
8 |
9 |
10 | def load_grasp_poses(data):
11 | grasps = data['grasps/transforms'][()]
12 | success = data['grasps/qualities/flex/object_in_gripper'][()]
13 |
14 | grasps_good = grasps[success==1]
15 |
16 | return grasps_good
17 |
18 |
19 | def load_mesh(data):
20 | mesh_path = data['object/file'][()].decode('utf-8')
21 | mesh_scale = data['object/scale'][()]
22 |
23 | mesh = o3d.io.read_triangle_mesh(os.path.join(DATASET_DIR, mesh_path))
24 |
25 | mesh.scale(mesh_scale, center=(0, 0, 0))
26 |
27 | return mesh
28 |
--------------------------------------------------------------------------------
/losses/__init__.py:
--------------------------------------------------------------------------------
1 | from losses.mse_loss import MSELoss
2 |
3 |
4 | def get_losses(cfg_losses):
5 | losses = {}
6 |
7 | for cfg_loss in cfg_losses:
8 | name = cfg_loss.name
9 |
10 | losses[name] = get_loss(cfg_loss)
11 |
12 | return losses
13 |
14 |
15 | def get_loss(cfg_loss):
16 | name = cfg_loss.pop('name')
17 |
18 | if name == 'mse':
19 | loss = MSELoss(**cfg_loss)
20 | else:
21 | raise NotImplementedError(f"Loss {name} is not implemented.")
22 |
23 | return loss
24 |
--------------------------------------------------------------------------------
/losses/mse_loss.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 |
4 | class MSELoss:
5 | def __init__(self, weight=1, reduction='mean'):
6 | self.weight = weight
7 |
8 | self.mse_loss = torch.nn.MSELoss(reduction=reduction)
9 |
10 | def __call__(self, pred, target):
11 | loss = self.mse_loss(pred, target)
12 |
13 | return loss
14 |
--------------------------------------------------------------------------------
/metrics/__init__.py:
--------------------------------------------------------------------------------
1 | from metrics.emd import EMDCalculator
2 |
3 |
4 | def get_metrics(cfg_metrics):
5 | metrics = {}
6 |
7 | for cfg_metric in cfg_metrics:
8 | name = cfg_metric.name
9 |
10 | metrics[name] = get_metric(cfg_metric)
11 |
12 | return metrics
13 |
14 |
15 | def get_metric(cfg_metric):
16 | name = cfg_metric.pop('name')
17 |
18 | if name == 'emd':
19 | metric = EMDCalculator(**cfg_metric)
20 | else:
21 | raise NotImplementedError(f"Metric {name} is not implemented.")
22 |
23 | return metric
24 |
--------------------------------------------------------------------------------
/metrics/emd.py:
--------------------------------------------------------------------------------
1 | from scipy.optimize import linear_sum_assignment
2 |
3 | from utils.Lie import SE3_geodesic_dist
4 |
5 |
6 | class EMDCalculator:
7 | def __init__(self, type):
8 | self.type = type
9 |
10 | def calculate_distance(self, x, y):
11 | if self.type == 'SE3':
12 | T_x = x.view(-1, 4, 4)
13 | T_y = y.view(-1, 4, 4)
14 |
15 | return SE3_geodesic_dist(T_x, T_y).view(x.shape[:2])
16 | elif self.type == 'L2':
17 | return ((x - y) ** 2).sum(dim=3).sqrt()
18 | else:
19 | raise NotImplementedError(f"Type {self.type} is not implemented. Choose type between 'SE3' and 'L2'.")
20 |
21 | def __call__(self, source, target):
22 | assert len(source) == len(target), f"The number of samples in source {len(source)} must be equal to the number of samples in target {len(target)}."
23 |
24 | source = source.unsqueeze(1).repeat(1, len(target), 1, 1)
25 | target = target.unsqueeze(0).repeat(len(source), 1, 1, 1)
26 |
27 | distance = self.calculate_distance(source, target).cpu().numpy()
28 |
29 | idxs_row, idxs_col = linear_sum_assignment(distance)
30 |
31 | emd = distance[idxs_row, idxs_col].mean()
32 |
33 | return emd
--------------------------------------------------------------------------------
/models/__init__.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | from utils.distributions import get_dist
4 | from models.vn_dgcnn import VNDGCNNEncoder
5 | from models.vn_vector_fields import VNVectorFields
6 | from utils.ode_solvers import get_ode_solver
7 | from models.equi_grasp_flow import EquiGraspFlow
8 |
9 |
10 | def get_model(cfg_model):
11 | name = cfg_model.pop('name')
12 | checkpoint = cfg_model.get('checkpoint', None)
13 |
14 | if name == 'equigraspflow':
15 | model = _get_equigraspflow(cfg_model)
16 | else:
17 | raise NotImplementedError(f"Model {name} is not implemented.")
18 |
19 | if checkpoint is not None:
20 | checkpoint = torch.load(checkpoint, map_location='cpu')
21 |
22 | if 'model_state' in checkpoint:
23 | model.load_state_dict(checkpoint['model_state'])
24 |
25 | return model
26 |
27 |
28 | def _get_equigraspflow(cfg):
29 | p_uncond = cfg.pop('p_uncond')
30 | guidance = cfg.pop('guidance')
31 |
32 | init_dist = get_dist(cfg.pop('init_dist'))
33 | encoder = get_net(cfg.pop('encoder'))
34 | vector_field = get_net(cfg.pop('vector_field'))
35 | ode_solver = get_ode_solver(cfg.pop('ode_solver'))
36 |
37 | model = EquiGraspFlow(p_uncond, guidance, init_dist, encoder, vector_field, ode_solver)
38 |
39 | return model
40 |
41 |
42 | def get_net(cfg_net):
43 | name = cfg_net.pop('name')
44 |
45 | if name == 'vn_dgcnn_enc':
46 | net = _get_vn_dgcnn_enc(cfg_net)
47 | elif name == 'vn_vf':
48 | net = _get_vn_vf(cfg_net)
49 | else:
50 | raise NotImplementedError(f"Network {name} is not implemented.")
51 |
52 | return net
53 |
54 |
55 | def _get_vn_dgcnn_enc(cfg):
56 | net = VNDGCNNEncoder(**cfg)
57 |
58 | return net
59 |
60 |
61 | def _get_vn_vf(cfg):
62 | net = VNVectorFields(**cfg)
63 |
64 | return net
65 |
--------------------------------------------------------------------------------
/models/equi_grasp_flow.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from copy import deepcopy
3 |
4 | from utils.Lie import inv_SO3, log_SO3, exp_so3, bracket_so3
5 |
6 |
7 | class EquiGraspFlow(torch.nn.Module):
8 | def __init__(self, p_uncond, guidance, init_dist, encoder, vector_field, ode_solver):
9 | super().__init__()
10 |
11 | self.p_uncond = p_uncond
12 | self.guidance = guidance
13 |
14 | self.init_dist = init_dist
15 | self.encoder = encoder
16 | self.vector_field = vector_field
17 | self.ode_solver = ode_solver
18 |
19 | def step(self, data, losses, split, optimizer=None):
20 | # Get data
21 | pc = data['pc']
22 | x_1 = data['Ts_grasp']
23 |
24 | # Get number of grasp poses in each batch and combine batched data
25 | nums_grasps = torch.tensor([len(Ts_grasp) for Ts_grasp in x_1], device=data['pc'].device)
26 |
27 | x_1 = torch.cat(x_1, dim=0)
28 |
29 | # Sample t and x_0
30 | t = torch.rand(len(x_1), 1).to(x_1.device)
31 | x_0 = self.init_dist(len(x_1), x_1.device)
32 |
33 | # Get x_t and u_t
34 | x_t, u_t = get_traj(x_0, x_1, t)
35 |
36 | # Forward
37 | v_t = self(pc, t, x_t, nums_grasps)
38 |
39 | # Calculate loss
40 | loss_mse = losses['mse'](v_t, u_t)
41 |
42 | loss = losses['mse'].weight * loss_mse
43 |
44 | # Backward
45 | if optimizer is not None:
46 | loss.backward()
47 | optimizer.step()
48 |
49 | # Archive results
50 | results = {
51 | f'scalar/{split}/loss': loss.item(),
52 | }
53 |
54 | return results
55 |
56 | def forward(self, pc, t, x_t, nums_grasps):
57 | z = torch.zeros((len(pc), self.encoder.dims[-1], 3), device=pc.device)
58 |
59 | # Encode point cloud
60 | z = self.encoder(pc)
61 |
62 | # Repeat feature
63 | z = z.repeat_interleave(nums_grasps, dim=0)
64 |
65 | # Null condition
66 | mask_uncond = torch.bernoulli(torch.Tensor([self.p_uncond] * len(z))).to(bool)
67 |
68 | z[mask_uncond] = torch.zeros_like(z[mask_uncond])
69 |
70 | # Get vector
71 | v_t = self.vector_field(z, t, x_t)
72 |
73 | return v_t
74 |
75 | @torch.no_grad()
76 | def sample(self, pc, nums_grasps):
77 | # Sample initial samples
78 | x_0 = self.init_dist(sum(nums_grasps), pc.device)
79 | self.X0SAMPLED = deepcopy(x_0)
80 |
81 | # Encode point cloud
82 | z = self.encoder(pc)
83 |
84 | # Repeat feature
85 | z = z.repeat_interleave(nums_grasps, dim=0)
86 |
87 | # Push-forward initial samples
88 | x_1_hat = self.ode_solver(z, x_0, self.guided_vector_field)[:, -1]
89 |
90 | # Batch x_1_hat
91 | x_1_hat = x_1_hat.split(nums_grasps.tolist())
92 |
93 | return x_1_hat
94 |
95 | def guided_vector_field(self, z, t, x_t):
96 | v_t = (1 - self.guidance) * self.vector_field(torch.zeros_like(z), t, x_t) + self.guidance * self.vector_field(z, t, x_t)
97 |
98 | return v_t
99 |
100 |
101 | def get_traj(x_0, x_1, t):
102 | # Get rotations
103 | R_0 = x_0[:, :3, :3]
104 | R_1 = x_1[:, :3, :3]
105 |
106 | # Get translations
107 | p_0 = x_0[:, :3, 3]
108 | p_1 = x_1[:, :3, 3]
109 |
110 | # Get x_t
111 | x_t = torch.eye(4).repeat(len(x_1), 1, 1).to(x_1)
112 | x_t[:, :3, :3] = (R_0 @ exp_so3(t.unsqueeze(2) * log_SO3(inv_SO3(R_0) @ R_1)))
113 | x_t[:, :3, 3] = p_0 + t * (p_1 - p_0)
114 |
115 | # Get u_t
116 | u_t = torch.zeros(len(x_1), 6).to(x_1)
117 | u_t[:, :3] = bracket_so3(log_SO3(inv_SO3(R_0) @ R_1))
118 | u_t[:, :3] = torch.einsum('bij,bj->bi', R_0, u_t[:, :3]) # Convert w_b to w_s
119 | u_t[:, 3:] = p_1 - p_0
120 |
121 | return x_t, u_t
122 |
--------------------------------------------------------------------------------
/models/vn_dgcnn.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | from models.vn_layers import VNLinearLeakyReLU, knn
4 |
5 |
6 | class VNDGCNNEncoder(torch.nn.Module):
7 | def __init__(self, num_neighbors, dims=[1, 21, 21, 42, 85, 341], use_bn=False):
8 | super().__init__()
9 |
10 | self.num_neighbors = num_neighbors
11 | self.dims = dims
12 |
13 | layers = []
14 |
15 | for dim_in, dim_out in zip(dims[:-2], dims[1:-1]):
16 | layers += [VNLinearLeakyReLU(2 * dim_in, dim_out, use_bn=use_bn)]
17 |
18 | layers += [VNLinearLeakyReLU(sum(dims[1:-1]), dims[-1], dim=4, share_nonlinearity=True, use_bn=use_bn)]
19 |
20 | self.layers = torch.nn.ModuleList(layers)
21 |
22 | def forward(self, x):
23 | x = x.unsqueeze(1)
24 |
25 | x_list = []
26 |
27 | for layer in self.layers[:-1]:
28 | x = get_graph_feature(x, k=self.num_neighbors)
29 | x = layer(x)
30 | x = x.mean(dim=-1)
31 |
32 | x_list += [x]
33 |
34 | x = torch.cat(x_list, dim=1)
35 |
36 | x = self.layers[-1](x)
37 | x = x.mean(dim=-1)
38 |
39 | return x
40 |
41 |
42 | def get_graph_feature(x, k=20):
43 | batch_size = x.shape[0]
44 | num_pts = x.shape[3]
45 |
46 | x = x.view(batch_size, -1, num_pts)
47 |
48 | idx = knn(x, k=k)
49 | idx_base = torch.arange(0, batch_size, device=idx.device).unsqueeze(1).unsqueeze(2) * num_pts
50 | idx = idx + idx_base
51 | idx = idx.view(-1)
52 |
53 | num_dims = x.shape[1] // 3
54 |
55 | x = x.transpose(2, 1).contiguous()
56 | feature = x.view(batch_size*num_pts, -1)[idx]
57 | feature = feature.view(batch_size, num_pts, k, num_dims, 3)
58 | x = x.view(batch_size, num_pts, 1, num_dims, 3).repeat(1, 1, k, 1, 1)
59 |
60 | feature = torch.cat((feature-x, x), dim=3).permute(0, 3, 4, 1, 2).contiguous()
61 |
62 | return feature
--------------------------------------------------------------------------------
/models/vn_layers.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 |
4 | EPS = 1e-6
5 |
6 |
7 | class VNLinear(torch.nn.Module):
8 | def __init__(self, in_channels, out_channels):
9 | super().__init__()
10 |
11 | self.map_to_feat = torch.nn.Linear(in_channels, out_channels, bias=False)
12 |
13 | def forward(self, x):
14 | '''
15 | x: point features of shape [B, N_feat, 3, N_samples, ...]
16 | '''
17 | x_out = self.map_to_feat(x.transpose(1,-1)).transpose(1,-1)
18 |
19 | return x_out
20 |
21 |
22 | class VNBatchNorm(torch.nn.Module):
23 | def __init__(self, num_features, dim):
24 | super().__init__()
25 |
26 | if dim == 3 or dim == 4:
27 | self.bn = torch.nn.BatchNorm1d(num_features)
28 | elif dim == 5:
29 | self.bn = torch.nn.BatchNorm2d(num_features)
30 |
31 | def forward(self, x):
32 | '''
33 | x: point features of shape [B, N_feat, 3, N_samples, ...]
34 | '''
35 | # norm = torch.sqrt((x*x).sum(2))
36 | norm = torch.norm(x, dim=2) + EPS
37 | norm_bn = self.bn(norm)
38 | norm = norm.unsqueeze(2)
39 | norm_bn = norm_bn.unsqueeze(2)
40 | x = x / norm * norm_bn
41 |
42 | return x
43 |
44 |
45 | class VNLinearLeakyReLU(torch.nn.Module):
46 | def __init__(self, in_channels, out_channels, dim=5, share_nonlinearity=False, use_bn=True, negative_slope=0.2):
47 | super().__init__()
48 |
49 | self.negative_slope = negative_slope
50 | self.use_bn = use_bn
51 |
52 | # Linear
53 | self.map_to_feat = torch.nn.Linear(in_channels, out_channels, bias=False)
54 |
55 | # BatchNorm
56 | if use_bn:
57 | self.bn = VNBatchNorm(out_channels, dim=dim)
58 |
59 | # LeakyReLU
60 | if share_nonlinearity:
61 | self.map_to_dir = torch.nn.Linear(in_channels, 1, bias=False)
62 | else:
63 | self.map_to_dir = torch.nn.Linear(in_channels, out_channels, bias=False)
64 |
65 | def forward(self, x):
66 | '''
67 | x: point features of shape [B, N_feat, 3, N_samples, ...]
68 | '''
69 | # Linear
70 | p = self.map_to_feat(x.transpose(1,-1)).transpose(1,-1)
71 |
72 | # BatchNorm
73 | if self.use_bn:
74 | p = self.bn(p)
75 |
76 | # LeakyReLU
77 | d = self.map_to_dir(x.transpose(1,-1)).transpose(1,-1)
78 | dotprod = (p*d).sum(2, keepdims=True)
79 | mask = (dotprod >= 0).float()
80 | d_norm_sq = (d*d).sum(2, keepdims=True)
81 | x_out = self.negative_slope * p + (1-self.negative_slope) * (mask*p + (1-mask)*(p-(dotprod/(d_norm_sq+EPS))*d))
82 |
83 | return x_out
84 |
85 |
86 | def knn(x, k):
87 | pairwise_distance = (x.unsqueeze(-1) - x.unsqueeze(-2)).norm(dim=1) ** 2
88 |
89 | idx = pairwise_distance.topk(k, dim=-1, largest=False)[1] # (batch_size, num_pts, k)
90 |
91 | return idx
--------------------------------------------------------------------------------
/models/vn_vector_fields.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | from models.vn_layers import VNLinearLeakyReLU, VNLinear
4 |
5 |
6 | class VNVectorFields(torch.nn.Module):
7 | def __init__(self, dims, use_bn):
8 | super().__init__()
9 |
10 | # Setup lifting layer
11 | self.lifting_layer = VNLinear(dims[0] - 1, 1)
12 |
13 | # Setup VN-MLP
14 | layers = []
15 |
16 | for i in range(len(dims)-2):
17 | layers += [VNLinearLeakyReLU(dims[i], dims[i+1], dim=4, use_bn=use_bn)]
18 |
19 | layers += [VNLinear(dims[-2], dims[-1])]
20 |
21 | self.layers = torch.nn.Sequential(*layers)
22 |
23 | def forward(self, z, t, x_t):
24 | # Construct scalar-list and vector-list
25 | s = t.unsqueeze(1)
26 | v = torch.cat((z, x_t[:, :3].transpose(1, 2)), dim=1)
27 |
28 | # Lift scalar-list to vector-list
29 | trans = self.lifting_layer(v)
30 | v_s = s @ trans
31 |
32 | # Concatenate
33 | v = torch.cat((v, v_s), dim=1)
34 |
35 | # Forward VN-MLP
36 | out = self.layers(v).contiguous()
37 |
38 | # Convert two 3-dim vectors to one 6-dim vector
39 | out = out.view(-1, 6)
40 |
41 | return out
42 |
--------------------------------------------------------------------------------
/test_full.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | from datetime import datetime
3 | import os
4 | from omegaconf import OmegaConf
5 | import logging
6 | import yaml
7 | import random
8 | import numpy as np
9 | import torch
10 | import pandas as pd
11 | import plotly.graph_objects as go
12 |
13 | from loaders import get_dataloader
14 | from models import get_model
15 | from metrics import get_metrics
16 | from utils.visualization import PlotlySubplotsVisualizer
17 |
18 |
19 | NUM_GRASPS = 100
20 |
21 |
22 | def main(args, cfg):
23 | seed = cfg.get('seed', 1)
24 | random.seed(seed)
25 | np.random.seed(seed)
26 | torch.manual_seed(seed)
27 | torch.cuda.manual_seed(seed)
28 | torch.set_num_threads(8)
29 | torch.backends.cudnn.deterministic = True
30 |
31 | # Setup testloader
32 | test_loader = get_dataloader('test', cfg.data.test)
33 |
34 | # Setup model
35 | model = get_model(cfg.model).to(cfg.device)
36 |
37 | # Setup metrics
38 | metrics = get_metrics(cfg.metrics)
39 |
40 | # Setup plotly visualizer
41 | visualizer = PlotlySubplotsVisualizer(rows=1, cols=test_loader.dataset.num_rots)
42 |
43 | # Start test
44 | results = test(model, test_loader, metrics, cfg.device, visualizer)
45 |
46 | # Print results
47 | print_results(test_loader, results)
48 |
49 | # Write xlsx
50 | log_results(args.logdir, test_loader, results)
51 |
52 | # Save plotly figure
53 | save_figure(args.logdir, visualizer)
54 |
55 |
56 | def test(model, test_loader, metrics, device, visualizer):
57 | # Initialize
58 | model.eval()
59 |
60 | # Get dataset
61 | obj_types = test_loader.dataset.obj_types
62 | obj_idxs_types = test_loader.dataset.obj_idxs_types
63 | pc_list_types = test_loader.dataset.pc_list_types
64 | mesh_list_types = test_loader.dataset.mesh_list_types
65 | Ts_grasp_list_types = test_loader.dataset.Ts_grasp_list_types
66 |
67 | # Get scale, maximum number of objects and number of rotations
68 | scale = test_loader.dataset.scale
69 | max_num_objs = test_loader.dataset.max_num_objs
70 | num_rots = test_loader.dataset.num_rots
71 |
72 | # Setup metric result arrays
73 | results = {key: np.full((len(pc_list_types), max_num_objs, num_rots), np.nan) for key in list(metrics.keys())}
74 |
75 | # Setup labels for button in plotly figure
76 | visualizer.labels = []
77 |
78 | # Iterate
79 | for i, (obj_type, obj_idxs_objs, pc_list_objs, mesh_list_objs, Ts_grasp_list_objs) in enumerate(zip(obj_types, obj_idxs_types, pc_list_types, mesh_list_types, Ts_grasp_list_types)):
80 | for j, (obj_idx, pc_rots, mesh_list_rots, Ts_grasp_rots_target) in enumerate(zip(obj_idxs_objs, pc_list_objs, mesh_list_objs, Ts_grasp_list_objs)):
81 | # Setup input
82 | pc_rots = torch.Tensor(pc_rots).to(device)
83 | Ts_grasp_rots_target = torch.Tensor(Ts_grasp_rots_target).to(device)
84 | nums_grasps = torch.tensor([len(Ts_grasp_target) for Ts_grasp_target in Ts_grasp_rots_target], device=pc_rots.device)
85 |
86 | # Sample grasp poses
87 | Ts_grasp_rots_pred = model.sample(pc_rots, nums_grasps)
88 |
89 | # Compute metrics
90 | for k, (mesh, Ts_grasp_pred, Ts_grasp_target) in enumerate(zip(mesh_list_rots, Ts_grasp_rots_pred, Ts_grasp_rots_target)):
91 | # Setup message
92 | msg = f"object type: {obj_type}, object index: {obj_idx}, rotation index: {k}, "
93 |
94 | # Rescale mesh and grasp poses
95 | mesh.scale(1/scale, center=(0, 0, 0))
96 | Ts_grasp_pred[:, :3, 3] /= scale
97 | Ts_grasp_target[:, :3, 3] /= scale
98 |
99 | for key, metric in metrics.items():
100 | # Compute metrics
101 | result = metric(Ts_grasp_pred, Ts_grasp_target)
102 |
103 | # Add result to message
104 | msg += f"{key}: {result:.4f}, "
105 |
106 | # Fill array
107 | results[key][i, j, k] = result
108 |
109 | # Print result message
110 | print(msg)
111 | logging.info(msg)
112 |
113 | # Get indices for sampling grasp poses for visualization
114 | idxs = torch.randperm(len(Ts_grasp_pred))[:NUM_GRASPS]
115 |
116 | # Add mesh and gripper to visualizer
117 | visualizer.add_mesh(mesh, row=1, col=k+1)
118 | visualizer.add_grippers(Ts_grasp_pred[idxs], color='grey', row=1, col=k+1)
119 |
120 | visualizer.labels += [f'{obj_type}_{obj_idx}']
121 |
122 | return results
123 |
124 |
125 | def print_results(test_loader, results):
126 | # Get object types and object ids
127 | obj_types = test_loader.dataset.obj_types
128 |
129 | # Print results
130 | for idx_type, obj_type in enumerate(obj_types):
131 | msg = f"object type: {obj_type}"
132 |
133 | for key in results.keys():
134 | msg += f", {key}: {np.nanmean(results[key][idx_type]):.4f}"
135 |
136 | print(msg)
137 | logging.info(msg)
138 |
139 |
140 | def log_results(logdir, test_loader, results):
141 | # Get object types and object ids
142 | obj_types = test_loader.dataset.obj_types
143 | obj_idxs_types = test_loader.dataset.obj_idxs_types
144 |
145 | # Write xlsx
146 | for key, result in results.items():
147 | with pd.ExcelWriter(os.path.join(logdir, f'{key}.xlsx')) as w:
148 | for obj_type, obj_idxs_objs, result_type in zip(obj_types, obj_idxs_types, result):
149 | df = pd.DataFrame(result_type[:len(obj_idxs_objs)], index=obj_idxs_objs)
150 | df.to_excel(w, sheet_name=obj_type)
151 |
152 |
153 | def save_figure(logdir, visualizer):
154 | # Get number of traces and number of subplots
155 | num_traces = len(visualizer.fig.data)
156 | num_subplots = visualizer.num_subplots
157 |
158 | # Make only the first scene visible
159 | for idx_trace in range(num_subplots*(1+NUM_GRASPS), num_traces):
160 | visualizer.fig.update_traces(visible=False, selector=idx_trace)
161 |
162 | # Make buttons list
163 | buttons = []
164 |
165 | for idx_scene, label in enumerate(visualizer.labels):
166 | # Initialize visibility list
167 | visibility = num_traces * [False]
168 |
169 | # Make only the selected scene visible
170 | for idx_trace in range(num_subplots*(1+NUM_GRASPS)*idx_scene, num_subplots*(1+NUM_GRASPS)*(idx_scene+1)):
171 | visibility[idx_trace] = True
172 |
173 | # Make and append button
174 | button = dict(label=label, method='restyle', args=[{'visible': visibility}])
175 |
176 | buttons += [button]
177 |
178 | # Update buttons
179 | visualizer.fig.update_layout(updatemenus=[go.layout.Updatemenu(active=0, buttons=buttons)])
180 |
181 | # Save figure
182 | visualizer.fig.write_json(os.path.join(logdir, 'visualizations.json'))
183 |
184 |
185 | if __name__ == '__main__':
186 | # Parse arguments
187 | parser = argparse.ArgumentParser()
188 |
189 | parser.add_argument('--train_result_path', type=str)
190 | parser.add_argument('--checkpoint', type=str)
191 | parser.add_argument('--device', default=0)
192 | parser.add_argument('--logdir', default='test_results')
193 | parser.add_argument('--run', type=str, default=datetime.now().strftime('%Y%m%d-%H%M'))
194 |
195 | args = parser.parse_args()
196 |
197 | # Load config
198 | config_filename = [file for file in os.listdir(args.train_result_path) if file.endswith('.yml')][0]
199 |
200 | cfg = OmegaConf.load(os.path.join(args.train_result_path, config_filename))
201 |
202 | # Setup checkpoint
203 | cfg.model.checkpoint = os.path.join(args.train_result_path, args.checkpoint)
204 |
205 | # Setup device
206 | if args.device == 'cpu':
207 | cfg.device = 'cpu'
208 | else:
209 | cfg.device = f'cuda:{args.device}'
210 |
211 | # Setup logdir
212 | config_basename = os.path.splitext(config_filename)[0]
213 |
214 | args.logdir = os.path.join(args.logdir, config_basename, args.run)
215 |
216 | os.makedirs(args.logdir, exist_ok=True)
217 |
218 | # Setup logging
219 | logging.basicConfig(
220 | filename=os.path.join(args.logdir, 'logging.log'),
221 | format='%(asctime)s [%(levelname)s] %(message)s',
222 | datefmt='%Y/%m/%d %I:%M:%S %p',
223 | level=logging.DEBUG
224 | )
225 |
226 | # Print result directory
227 | print(f"Result directory: {args.logdir}")
228 | logging.info(f"Result directory: {args.logdir}")
229 |
230 | # Save config
231 | config_path = os.path.join(args.logdir, config_filename)
232 | yaml.dump(yaml.safe_load(OmegaConf.to_yaml(cfg)), open(config_path, 'w'))
233 |
234 | print(f"Config saved as {config_path}")
235 | logging.info(f"Config saved as {config_path}")
236 |
237 | main(args, cfg)
238 |
--------------------------------------------------------------------------------
/test_partial.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | from datetime import datetime
3 | import os
4 | from omegaconf import OmegaConf
5 | import logging
6 | import yaml
7 | import random
8 | import numpy as np
9 | import torch
10 | import plotly.graph_objects as go
11 |
12 | from loaders import get_dataloader
13 | from models import get_model
14 | from metrics import get_metrics
15 | from utils.visualization import PlotlySubplotsVisualizer
16 |
17 |
18 | NUM_GRASPS = 25
19 |
20 |
21 | def main(args, cfg):
22 | seed = cfg.get('seed', 1)
23 | random.seed(seed)
24 | np.random.seed(seed)
25 | torch.manual_seed(seed)
26 | torch.cuda.manual_seed(seed)
27 | torch.set_num_threads(8)
28 | torch.backends.cudnn.deterministic = True
29 |
30 | # Setup testloader
31 | test_loader = get_dataloader('test', cfg.data.test)
32 |
33 | # Setup model
34 | model = get_model(cfg.model).to(cfg.device)
35 |
36 | # Setup metrics
37 | metrics = get_metrics(cfg.metrics)
38 |
39 | # Setup plotly visualizer
40 | visualizer = PlotlySubplotsVisualizer(rows=test_loader.dataset.num_rots, cols=test_loader.dataset.num_views)
41 | visualizer.fig.update_layout(height=2700)
42 |
43 | # Start test
44 | results = test(args, model, test_loader, metrics, cfg.device, visualizer)
45 |
46 | # Print results
47 | print_results(test_loader, results)
48 |
49 | # Save plotly figure
50 | save_figure(args.logdir, visualizer)
51 |
52 |
53 | def test(args, model, test_loader, metrics, device, visualizer):
54 | # Initialize
55 | model.eval()
56 |
57 | # Get arguments
58 | logdir = args.logdir
59 |
60 | # Get dataset
61 | obj_types = test_loader.dataset.obj_types
62 | obj_idxs_types = test_loader.dataset.obj_idxs_types
63 | partial_pc_list_types = test_loader.dataset.partial_pc_list_types
64 | mesh_list_types = test_loader.dataset.mesh_list_types
65 | Ts_grasp_list_types = test_loader.dataset.Ts_grasp_list_types
66 |
67 | # Get scale, maximum number of objects and number of rotations
68 | scale = test_loader.dataset.scale if hasattr(test_loader.dataset, 'scale') else 1
69 | max_num_objs = test_loader.dataset.max_num_objs
70 | num_rots = test_loader.dataset.num_rots
71 | num_views = test_loader.dataset.num_views
72 |
73 | # Setup metric result arrays
74 | results = {key: np.full((len(partial_pc_list_types), max_num_objs, num_rots, num_views), np.nan) for key in list(metrics.keys())}
75 |
76 | # Setup labels for button in plotly figure
77 | visualizer.labels = []
78 |
79 | # Iterate
80 | for i, (obj_type, obj_idxs_objs, partial_pc_list_objs, Ts_grasp_list_objs, mesh_list_objs) in enumerate(zip(obj_types, obj_idxs_types, partial_pc_list_types, Ts_grasp_list_types, mesh_list_types)):
81 | for j, (obj_idx, partial_pc_rots, Ts_grasp_rots_target, mesh_list_rots) in enumerate(zip(obj_idxs_objs, partial_pc_list_objs, Ts_grasp_list_objs, mesh_list_objs)):
82 | # Setup input
83 | Ts_grasp_rots_target = torch.Tensor(Ts_grasp_rots_target).to(device)
84 |
85 | for k, (partial_pc_views, Ts_grasp_views_target, mesh_list_views) in enumerate(zip(partial_pc_rots, Ts_grasp_rots_target, mesh_list_rots)):
86 | # Setup input
87 | partial_pc_views = torch.Tensor(partial_pc_views).to(device)
88 | nums_grasps = torch.tensor([Ts_grasp_views_target.shape[1]]*len(partial_pc_views), device=partial_pc_views.device)
89 |
90 | # Sample grasp poses
91 | Ts_grasp_views_pred = model.sample(partial_pc_views, nums_grasps)
92 |
93 | # Compute metrics
94 | for l, (partial_pc, Ts_grasp_pred, Ts_grasp_target, mesh) in enumerate(zip(partial_pc_views, Ts_grasp_views_pred, Ts_grasp_views_target, mesh_list_views)):
95 | # Setup message
96 | msg = f"object type: {obj_type}, object index: {obj_idx}, rotation index: {k}, viewpoint index: {l}, "
97 |
98 | # Rescale point cloud and grasp poses
99 | partial_pc /= scale
100 | Ts_grasp_pred[:, :3, 3] /= scale
101 | Ts_grasp_target[:, :3, 3] /= scale
102 | mesh.scale(1/scale, center=(0, 0, 0))
103 |
104 | for key, metric in metrics.items():
105 | # Compute metrics
106 | result = metric(Ts_grasp_pred, Ts_grasp_target)
107 |
108 | # Add result to message
109 | msg += f"{key}: {result:.4f}, "
110 |
111 | # Fill array
112 | results[key][i, j, k, l] = result
113 |
114 | # Print result message
115 | print(msg)
116 | logging.info(msg)
117 |
118 | # Get indices for sampling grasp poses for simulation
119 | idxs = torch.randperm(len(Ts_grasp_pred))[:NUM_GRASPS]
120 |
121 | # Add mesh, partial point cloud, and gripper to visualizer
122 | visualizer.add_mesh(mesh, row=k+1, col=l+1)
123 |
124 | visualizer.add_pc(partial_pc.cpu().numpy().T, row=k+1, col=l+1)
125 | visualizer.add_grippers(Ts_grasp_pred[idxs], color='grey', row=k+1, col=l+1)
126 |
127 | visualizer.labels += [f'{obj_type}_{obj_idx}']
128 |
129 | return results
130 |
131 |
132 | def print_results(test_loader, results):
133 | # Get object types and object ids
134 | obj_types = test_loader.dataset.obj_types
135 |
136 | # Print results
137 | for idx_type, obj_type in enumerate(obj_types):
138 | msg = f"object type: {obj_type}"
139 |
140 | for key in results.keys():
141 | msg += f", {key}: {np.nanmean(results[key][idx_type]):.4f}"
142 |
143 | print(msg)
144 | logging.info(msg)
145 |
146 |
147 | def save_figure(logdir, visualizer):
148 | # Get number of traces and number of subplots
149 | num_traces = len(visualizer.fig.data)
150 | num_subplots = visualizer.num_subplots
151 |
152 | # Make only the first scene visible
153 | for idx_trace in range(num_subplots*(2+NUM_GRASPS), num_traces):
154 | visualizer.fig.update_traces(visible=False, selector=idx_trace)
155 |
156 | # Make buttons list
157 | buttons = []
158 |
159 | for idx_scene, label in enumerate(visualizer.labels):
160 | # Initialize visibility list
161 | visibility = num_traces * [False]
162 |
163 | # Make only the selected scene visible
164 | for idx_trace in range(num_subplots*(2+NUM_GRASPS)*idx_scene, num_subplots*(2+NUM_GRASPS)*(idx_scene+1)):
165 | visibility[idx_trace] = True
166 |
167 | # Make and append button
168 | button = dict(label=label, method='restyle', args=[{'visible': visibility}])
169 |
170 | buttons += [button]
171 |
172 | # Update buttons
173 | visualizer.fig.update_layout(updatemenus=[go.layout.Updatemenu(active=0, buttons=buttons)])
174 |
175 | # Save figure
176 | visualizer.fig.write_json(os.path.join(logdir, 'visualizations.json'))
177 |
178 |
179 | if __name__ == '__main__':
180 | # Parse arguments
181 | parser = argparse.ArgumentParser()
182 |
183 | parser.add_argument('--train_result_path', type=str)
184 | parser.add_argument('--checkpoint', type=str)
185 | parser.add_argument('--device', default=0)
186 | parser.add_argument('--logdir', default='test_results')
187 | parser.add_argument('--run', type=str, default=datetime.now().strftime('%Y%m%d-%H%M'))
188 |
189 | args = parser.parse_args()
190 |
191 | # Load config
192 | config_filename = [file for file in os.listdir(args.train_result_path) if file.endswith('.yml')][0]
193 |
194 | cfg = OmegaConf.load(os.path.join(args.train_result_path, config_filename))
195 |
196 | # Setup checkpoint
197 | cfg.model.checkpoint = os.path.join(args.train_result_path, args.checkpoint)
198 |
199 | # Setup device
200 | if args.device == 'cpu':
201 | cfg.device = 'cpu'
202 | else:
203 | cfg.device = f'cuda:{args.device}'
204 |
205 | # Setup logdir
206 | config_basename = os.path.splitext(config_filename)[0]
207 |
208 | args.logdir = os.path.join(args.logdir, config_basename, args.run)
209 |
210 | os.makedirs(args.logdir, exist_ok=True)
211 |
212 | # Setup logging
213 | logging.basicConfig(
214 | filename=os.path.join(args.logdir, 'logging.log'),
215 | format='%(asctime)s [%(levelname)s] %(message)s',
216 | datefmt='%Y/%m/%d %I:%M:%S %p',
217 | level=logging.DEBUG
218 | )
219 |
220 | # Print result directory
221 | print(f"Result directory: {args.logdir}")
222 | logging.info(f"Result directory: {args.logdir}")
223 |
224 | # Save config
225 | config_path = os.path.join(args.logdir, config_filename)
226 | yaml.dump(yaml.safe_load(OmegaConf.to_yaml(cfg)), open(config_path, 'w'))
227 |
228 | print(f"Config saved as {config_path}")
229 | logging.info(f"Config saved as {config_path}")
230 |
231 | main(args, cfg)
232 |
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | from datetime import datetime
3 | from omegaconf import OmegaConf
4 | import os
5 | from tensorboardX import SummaryWriter
6 | import logging
7 | import yaml
8 | import random
9 | import numpy as np
10 | import torch
11 |
12 | from loaders import get_dataloader
13 | from models import get_model
14 | from losses import get_losses
15 | from utils.optimizers import get_optimizer
16 | from metrics import get_metrics
17 | from utils.logger import Logger
18 | from trainers import get_trainer
19 |
20 |
21 | def main(cfg, writer):
22 | # Setup seed
23 | seed = cfg.get('seed', 1)
24 | random.seed(seed)
25 | np.random.seed(seed)
26 | torch.manual_seed(seed)
27 | torch.cuda.manual_seed(seed)
28 | torch.set_num_threads(8)
29 | torch.backends.cudnn.deterministic = True
30 |
31 | # Setup dataloader
32 | dataloaders = {}
33 |
34 | for split in ['train', 'val']:
35 | dataloaders[split] = get_dataloader(split, cfg.data[split])
36 |
37 | # Setup model
38 | model = get_model(cfg.model).to(cfg.device)
39 |
40 | # Setup losses
41 | losses = get_losses(cfg.losses)
42 |
43 | # Setup optimizer
44 | optimizer = get_optimizer(cfg.optimizer, model.parameters())
45 |
46 | # Setup metrics
47 | metrics = get_metrics(cfg.metrics)
48 |
49 | # Setup logger
50 | logger = Logger(writer)
51 |
52 | # Setup trainer
53 | trainer = get_trainer(cfg.trainer, cfg.device, dataloaders, model, losses, optimizer, metrics, logger)
54 |
55 | # Start learning
56 | trainer.run()
57 |
58 |
59 | if __name__ == '__main__':
60 | # Parse arguments
61 | parser = argparse.ArgumentParser()
62 |
63 | parser.add_argument('--config', type=str)
64 | parser.add_argument('--device', default=0)
65 | parser.add_argument('--logdir', default='train_results')
66 | parser.add_argument('--run', type=str, default=datetime.now().strftime('%Y%m%d-%H%M'))
67 |
68 | args = parser.parse_args()
69 |
70 | # Load and print config
71 | cfg = OmegaConf.load(args.config)
72 | print(OmegaConf.to_yaml(cfg))
73 |
74 | # Setup device
75 | if args.device == 'cpu':
76 | cfg.device = 'cpu'
77 | else:
78 | cfg.device = f'cuda:{args.device}'
79 |
80 | # Setup logdir
81 | config_filename = os.path.basename(args.config)
82 | config_basename = os.path.splitext(config_filename)[0]
83 |
84 | logdir = os.path.join(args.logdir, config_basename, args.run)
85 |
86 | # Setup tensorboard writer
87 | writer = SummaryWriter(logdir)
88 |
89 | # Setup logging
90 | logging.basicConfig(
91 | filename=os.path.join(logdir, 'logging.log'),
92 | format='%(asctime)s [%(levelname)s] %(message)s',
93 | datefmt='%Y/%m/%d %I:%M:%S %p',
94 | level=logging.DEBUG
95 | )
96 |
97 | # Print logdir
98 | print(f"Result directory: {logdir}")
99 | logging.info(f"Result directory: {logdir}")
100 |
101 | # Save config
102 | config_path = os.path.join(logdir, config_filename)
103 | yaml.dump(yaml.safe_load(OmegaConf.to_yaml(cfg)), open(config_path, 'w'))
104 |
105 | print(f"Config saved as {config_path}")
106 | logging.info(f"Config saved as {config_path}")
107 |
108 | main(cfg, writer)
109 |
--------------------------------------------------------------------------------
/trainers/__init__.py:
--------------------------------------------------------------------------------
1 | from trainers.grasp_trainer import GraspPoseGeneratorTrainer, PartialGraspPoseGeneratorTrainer
2 |
3 |
4 | def get_trainer(cfg_trainer, device, dataloaders, model, losses, optimizer, metrics, logger):
5 | name = cfg_trainer.pop('name')
6 |
7 | if name == 'grasp_full':
8 | trainer = GraspPoseGeneratorTrainer(cfg_trainer, device, dataloaders, model, losses, optimizer, metrics, logger)
9 | elif name == 'grasp_partial':
10 | trainer = PartialGraspPoseGeneratorTrainer(cfg_trainer, device, dataloaders, model, losses, optimizer, metrics, logger)
11 | else:
12 | raise NotImplementedError(f"Trainer {name} is not implemented.")
13 |
14 | return trainer
15 |
--------------------------------------------------------------------------------
/trainers/grasp_trainer.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import time
3 | import logging
4 | from tqdm import tqdm
5 | import numpy as np
6 | from copy import deepcopy
7 | import os
8 |
9 | from utils.average_meter import AverageMeter
10 | from utils.mesh import generate_grasp_scene_list, meshes_to_numpy
11 |
12 |
13 | NUM_GRASPS = 100
14 |
15 |
16 | class GraspPoseGeneratorTrainer:
17 | def __init__(self, cfg, device, dataloaders, model, losses, optimizer, metrics, logger):
18 | self.cfg = cfg
19 | self.train_loader = dataloaders['train']
20 | self.val_loader = dataloaders['val']
21 | self.device = device
22 | self.model = model
23 | self.losses = losses
24 | self.optimizer = optimizer
25 | self.logger = logger
26 | self.metrics = metrics
27 |
28 | # Get logdir
29 | self.logdir = self.logger.writer.file_writer.get_logdir()
30 |
31 | # Setup meters
32 | self.setup_meters()
33 |
34 | # Initialize performance dictionary
35 | self.setup_performance_dict()
36 |
37 | def setup_meters(self):
38 | # Setup time meter
39 | self.time_meter = AverageMeter()
40 |
41 | # Setup scalar meters for train
42 | for data in self.train_loader:
43 | break
44 |
45 | for key, val in data.items():
46 | if type(val) == torch.Tensor:
47 | data[key] = val.to(self.device)
48 | elif type(val) == list:
49 | data[key] = [v.to(self.device) for v in val]
50 |
51 | with torch.no_grad():
52 | results_train = self.model.step(data, self.losses, 'train')
53 | results_val = self.model.step(data, self.losses, 'val')
54 |
55 | self.train_meters = {key: AverageMeter() for key in results_train.keys() if 'scalar' in key}
56 | self.val_meters = {key: AverageMeter() for key in results_val.keys() if 'scalar' in key}
57 |
58 | # Setup metric meters
59 | self.metric_meters = {key: AverageMeter() for key in self.metrics.keys()}
60 |
61 | def setup_performance_dict(self):
62 | self.performances = {'val_loss': torch.inf}
63 |
64 | for criterion in self.cfg.criteria:
65 | assert criterion.name in self.metrics.keys(), f"Criterion {criterion.name} not in metrics keys {self.metrics.keys()}."
66 |
67 | if criterion.better == 'higher':
68 | self.performances[criterion.name] = 0
69 | elif criterion.better == 'lower':
70 | self.performances[criterion.name] = torch.inf
71 | else:
72 | raise ValueError(f"Criterion better with {criterion.better} value is not supported. Choose 'higher' or 'lower'.")
73 |
74 | def run(self):
75 | # Initialize
76 | iter = 0
77 |
78 | # Start learning
79 | for epoch in range(1, self.cfg.num_epochs+1):
80 | for data in self.train_loader:
81 | iter += 1
82 |
83 | # Training
84 | results_train = self.train(data)
85 |
86 | # Print
87 | if iter % self.cfg.print_interval == 0:
88 | self.print(results_train, epoch, iter)
89 |
90 | # Validation
91 | if iter % self.cfg.val_interval == 0:
92 | self.validate(epoch, iter)
93 |
94 | # Evaluation
95 | if iter % self.cfg.eval_interval == 0:
96 | self.evaluate(epoch, iter)
97 |
98 | # Visualization
99 | if iter % self.cfg.vis_interval == 0:
100 | self.visualize(epoch, iter)
101 |
102 | # Save
103 | if iter % self.cfg.save_interval == 0:
104 | self.save(epoch, iter)
105 |
106 | def train(self, data):
107 | # Initialize
108 | self.model.train()
109 | self.optimizer.zero_grad()
110 |
111 | # Setup input
112 | for key, val in data.items():
113 | if type(val) == torch.Tensor:
114 | data[key] = val.to(self.device)
115 | elif type(val) == list:
116 | data[key] = [v.to(self.device) for v in val]
117 |
118 | # Step
119 | time_start = time.time()
120 |
121 | results = self.model.step(data, self.losses, 'train', self.optimizer)
122 |
123 | time_end = time.time()
124 |
125 | # Update time meter
126 | self.time_meter.update(time_end - time_start)
127 |
128 | # Update train meters
129 | for key, meter in self.train_meters.items():
130 | meter.update(results[key], n=len(data['pc']))
131 |
132 | return results
133 |
134 | def print(self, results, epoch, iter):
135 | # Get averaged train results
136 | for key, meter in self.train_meters.items():
137 | results[key] = meter.avg
138 |
139 | # Log averaged train results
140 | self.logger.log(results, iter)
141 |
142 | # Print averaged train results
143 | msg = f"[ Training ] epoch: {epoch}/{self.cfg.num_epochs}, iter: {iter}, "
144 | msg += ", ".join([f"{key.split('/')[-1]}: {meter.avg:.4f}" for key, meter in self.train_meters.items()])
145 | msg += f", elapsed time: {self.time_meter.sum:.4f}"
146 |
147 | print(msg)
148 | logging.info(msg)
149 |
150 | # Reset time meter and train meters
151 | self.time_meter.reset()
152 |
153 | for key, meter in self.train_meters.items():
154 | meter.reset()
155 |
156 | def validate(self, epoch, iter):
157 | # Initialize
158 | self.model.eval()
159 |
160 | time_start = time.time()
161 |
162 | with torch.no_grad():
163 | for data in tqdm(self.val_loader, desc="Validating ...", leave=False):
164 | # Setup input
165 | for key, val in data.items():
166 | if type(val) == torch.Tensor:
167 | data[key] = val.to(self.device)
168 | elif type(val) == list:
169 | data[key] = [v.to(self.device) for v in val]
170 |
171 | # Step
172 | results = self.model.step(data, self.losses, 'val')
173 |
174 | # Update validation meters
175 | for key, meter in self.val_meters.items():
176 | meter.update(results[key], n=len(data['pc']))
177 |
178 | time_end = time.time()
179 |
180 | # Get averaged validation results
181 | for key, meter in self.val_meters.items():
182 | results[key] = meter.avg
183 |
184 | # Log averaged validation results
185 | self.logger.log(results, iter)
186 |
187 | # Print averaged validation results
188 | msg = f"[ Validation ] epoch: {epoch}/{self.cfg.num_epochs}, iter: {iter}, "
189 | msg += ", ".join([f"{key.split('/')[-1]}: {meter.avg:.4f}" for key, meter in self.val_meters.items()])
190 | msg += f", elapsed time: {time_end-time_start:.4f}"
191 |
192 | print(msg)
193 | logging.info(msg)
194 |
195 | # Determine best validation loss
196 | val_loss = self.val_meters['scalar/val/loss'].avg
197 |
198 | if val_loss < self.performances['val_loss']:
199 | # Save model
200 | self.save(epoch, iter, criterion='val_loss', data={'val_loss': val_loss})
201 |
202 | # Update best validation loss
203 | self.performances['val_loss'] = val_loss
204 |
205 | # Reset meters
206 | for key, meter in self.val_meters.items():
207 | meter.reset()
208 |
209 | def evaluate(self, epoch, iter):
210 | # Initialize
211 | self.model.eval()
212 |
213 | # Get dataset and scale
214 | pc_list_types = self.val_loader.dataset.pc_list_types
215 | Ts_grasp_list_types = self.val_loader.dataset.Ts_grasp_list_types
216 | mesh_list_types = deepcopy(self.val_loader.dataset.mesh_list_types)
217 |
218 | scale = self.val_loader.dataset.scale if hasattr(self.val_loader.dataset, 'scale') else 1
219 |
220 | time_start = time.time()
221 |
222 | # Iterate object types
223 | for pc_list_objs, Ts_grasp_list_objs, mesh_list_objs in zip(tqdm(pc_list_types, desc="Evaluating for object types ...", leave=False), Ts_grasp_list_types, mesh_list_types):
224 | # Setup metric meters for objects
225 | metric_meters_objs = {key: AverageMeter() for key in self.metrics.keys()}
226 |
227 | # Iterate objects
228 | for pc_rots, Ts_grasp_rots_target, mesh_list_rots in zip(tqdm(pc_list_objs, desc="Evaluating for objects ...", leave=False), Ts_grasp_list_objs, mesh_list_objs):
229 | # Setup metric meters for rotations
230 | metric_meters_rots = {key: AverageMeter() for key in self.metrics.keys()}
231 |
232 | # Setup input
233 | pc_rots = torch.Tensor(pc_rots).to(self.device)
234 | Ts_grasp_rots_target = torch.Tensor(Ts_grasp_rots_target).to(self.device)
235 | nums_grasps = torch.tensor([len(Ts_grasp_target) for Ts_grasp_target in Ts_grasp_rots_target], device=pc_rots.device)
236 |
237 | # Generate grasp poses
238 | Ts_grasp_rots_pred = self.model.sample(pc_rots, nums_grasps)
239 |
240 | # Rescale grasp poses and mesh
241 | for Ts_grasp_pred, Ts_grasp_target, mesh in zip(Ts_grasp_rots_pred, Ts_grasp_rots_target, mesh_list_rots):
242 | Ts_grasp_pred[:, :3, 3] /= scale
243 | Ts_grasp_target[:, :3, 3] /= scale
244 | mesh.scale(1/scale, center=(0, 0, 0))
245 |
246 | # Compute metrics for rotations
247 | for Ts_grasp_pred, Ts_grasp_target, mesh in zip(Ts_grasp_rots_pred, Ts_grasp_rots_target, mesh_list_rots):
248 | for key, metric in self.metrics.items():
249 | if key == 'collision_rate':
250 | # Get indices for sampling grasp poses for simulation
251 | assert NUM_GRASPS <= len(Ts_grasp_pred), f"Number of grasps for simulation ({NUM_GRASPS}) must be less than or equal to the number of grasps predicted ({len(Ts_grasp_pred)})."
252 |
253 | idxs = torch.randperm(len(Ts_grasp_pred))[:NUM_GRASPS]
254 |
255 | metric_meters_rots[key].update(metric(mesh, Ts_grasp_pred[idxs]))
256 | else:
257 | metric_meters_rots[key].update(metric(Ts_grasp_pred, Ts_grasp_target))
258 |
259 | # Compute metrics for objects
260 | for key, meter in metric_meters_objs.items():
261 | meter.update(metric_meters_rots[key].avg)
262 |
263 | # Compute metrics for object types
264 | for key, meter in self.metric_meters.items():
265 | meter.update(metric_meters_objs[key].avg)
266 |
267 | time_end = time.time()
268 |
269 | # Get averaged evaluation results
270 | results = {}
271 |
272 | for key, meter in self.metric_meters.items():
273 | results[f'scalar/metrics/{key}'] = meter.avg
274 |
275 | # Log averaged evaluation results
276 | self.logger.log(results, iter)
277 |
278 | # Print averaged evaluation results
279 | msg = f"[ Evaluation ] epoch: {epoch}/{self.cfg.num_epochs}, iter: {iter}, "
280 | msg += ", ".join([f"{key}: {meter.avg:.4f}" for key, meter in self.metric_meters.items()])
281 | msg += f", elapsed time: {time_end-time_start:.4f}"
282 |
283 | print(msg)
284 | logging.info(msg)
285 |
286 | # Save model if best evaluation performance
287 | for criterion in self.cfg.criteria:
288 | # Determine best performance
289 | performance = self.metric_meters[criterion.name].avg
290 |
291 | if criterion.better == 'higher' and performance > self.performances[criterion.name]:
292 | best = True
293 | elif criterion.better == 'lower' and performance < self.performances[criterion.name]:
294 | best = True
295 | else:
296 | best = False
297 |
298 | if best:
299 | # Save model
300 | self.save(epoch, iter, criterion=criterion.name, data={criterion.name: performance})
301 |
302 |                 # Update best performance
303 | self.performances[criterion.name] = performance
304 |
305 | # Reset metric meters
306 | for key, meter in self.metric_meters.items():
307 | meter.reset()
308 |
309 | def visualize(self, epoch, iter):
310 | # Initialize
311 | self.model.eval()
312 |
313 | time_start = time.time()
314 |
315 | mesh_list = []
316 | pc_list = []
317 | Ts_grasp_pred_list = []
318 | Ts_grasp_target_list = []
319 |
320 | # Get random data indices
321 | idxs = np.random.choice(self.val_loader.dataset.num_scenes, size=3, replace=False)
322 |
323 | # Get scale
324 | scale = self.val_loader.dataset.scale if hasattr(self.val_loader.dataset, 'scale') else 1
325 |
326 | for idx in idxs:
327 | idx_type, idx_obj, idx_rot = np.where(self.val_loader.dataset.scene_idxs==idx)
328 |
329 | idx_type = idx_type.item()
330 | idx_obj = idx_obj.item()
331 | idx_rot = idx_rot.item()
332 |
333 | # Get input
334 | mesh = deepcopy(self.val_loader.dataset.mesh_list_types[idx_type][idx_obj][idx_rot])
335 | pc = self.val_loader.dataset.pc_list_types[idx_type][idx_obj][idx_rot]
336 | Ts_grasp_target = self.val_loader.dataset.Ts_grasp_list_types[idx_type][idx_obj][idx_rot]
337 |
338 | # Sample ground-truth grasp poses
339 | idxs_grasp = np.random.choice(len(Ts_grasp_target), size=10, replace=False)
340 | Ts_grasp_target = Ts_grasp_target[idxs_grasp]
341 |
342 | # Append data to list
343 | mesh_list += [mesh]
344 | pc_list += [torch.Tensor(pc)]
345 | Ts_grasp_target_list += [Ts_grasp_target]
346 |
347 | # Setup input
348 | pc = torch.stack(pc_list).to(self.device)
349 | nums_grasps = torch.tensor([10, 10, 10], device=self.device)
350 |
351 | # Generate grasp poses
352 | Ts_grasp_pred_list = self.model.sample(pc, nums_grasps)
353 | Ts_grasp_pred_list = [Ts_grasp_pred.cpu().numpy() for Ts_grasp_pred in Ts_grasp_pred_list]
354 |
355 | # Rescale mesh and grasp poses
356 | for mesh, Ts_grasp_pred, Ts_grasp_target in zip(mesh_list, Ts_grasp_pred_list, Ts_grasp_target_list):
357 | mesh.scale(1/scale, center=(0, 0, 0))
358 | Ts_grasp_pred[:, :3, 3] /= scale
359 | Ts_grasp_target[:, :3, 3] /= scale
360 |
361 | # Generate scene
362 | scene_list_pred = generate_grasp_scene_list(mesh_list, Ts_grasp_pred_list)
363 | scene_list_target = generate_grasp_scene_list(mesh_list, Ts_grasp_target_list)
364 |
365 | # Get vertices, triangles and colors
366 | vertices_pred, triangles_pred, colors_pred = meshes_to_numpy(scene_list_pred)
367 | vertices_target, triangles_target, colors_target = meshes_to_numpy(scene_list_target)
368 |
369 | time_end = time.time()
370 |
371 | # Get visualization results
372 | results = {
373 | 'mesh/pred': {'vertices': vertices_pred, 'colors': colors_pred, 'faces': triangles_pred},
374 | 'mesh/target': {'vertices': vertices_target, 'colors': colors_target, 'faces': triangles_target}
375 | }
376 |
377 | # Log visualization results
378 | self.logger.log(results, iter)
379 |
380 | # Print visualization status
381 | msg = f"[Visualization] epoch: {epoch}/{self.cfg.num_epochs}, iter: {iter}"
382 | msg += f", elapsed time: {time_end-time_start:.4f}"
383 |
384 | print(msg)
385 | logging.info(msg)
386 |
387 | def save(self, epoch, iter, criterion=None, data={}):
388 | # Set save name
389 | if criterion is None:
390 | save_name = f'model_iter_{iter}.pkl'
391 | else:
392 | save_name = f'model_best_{criterion}.pkl'
393 |
394 | # Construct object to save
395 | object = {
396 | 'epoch': epoch,
397 | 'iter': iter,
398 | 'model_state': self.model.state_dict(),
399 | 'optimizer': self.optimizer.state_dict(),
400 | }
401 | object.update(data)
402 |
403 | # Save object
404 | save_path = os.path.join(self.logdir, save_name)
405 |
406 | torch.save(object, save_path)
407 |
408 | # Print save status
409 | string = f"[ Save ] epoch: {epoch}/{self.cfg.num_epochs}, iter: {iter}, save {save_name}"
410 |
411 | if criterion is not None:
412 | string += f", {criterion}: {data[criterion]:.6f} / best_{criterion}: {self.performances[criterion]:.6f}"
413 |
414 | print(string)
415 | logging.info(string)
416 |
417 |
418 | class PartialGraspPoseGeneratorTrainer(GraspPoseGeneratorTrainer):
419 | def evaluate(self, epoch, iter):
420 | # Initialize
421 | self.model.eval()
422 |
423 | # Get dataset and scale
424 | partial_pc_list_types = self.val_loader.dataset.partial_pc_list_types
425 | Ts_grasp_list_types = self.val_loader.dataset.Ts_grasp_list_types
426 | mesh_list_types = deepcopy(self.val_loader.dataset.mesh_list_types)
427 |
428 | scale = self.val_loader.dataset.scale if hasattr(self.val_loader.dataset, 'scale') else 1
429 |
430 | time_start = time.time()
431 |
432 | # Iterate object types
433 | for partial_pc_list_objs, Ts_grasp_list_objs, mesh_list_objs in zip(tqdm(partial_pc_list_types, desc="Evaluating for object types ...", leave=False), Ts_grasp_list_types, mesh_list_types):
434 | # Setup metric meters for objects
435 | metric_meters_objs = {key: AverageMeter() for key in self.metrics.keys()}
436 |
437 | # Iterate objects
438 | for partial_pc_rots, Ts_grasp_rots_target, mesh_list_rots in zip(tqdm(partial_pc_list_objs, desc="Evaluating for objects ...", leave=False), Ts_grasp_list_objs, mesh_list_objs):
439 | # Setup metric meters for rotations
440 | metric_meters_rots = {key: AverageMeter() for key in self.metrics.keys()}
441 |
442 | # Setup input
443 | Ts_grasp_rots_target = torch.Tensor(Ts_grasp_rots_target).to(self.device)
444 |
445 | # Iterate rotations
446 | for partial_pc_views, Ts_grasp_views_target, mesh_list_views in zip(tqdm(partial_pc_rots, desc="Evaluating for rotations ...", leave=False), Ts_grasp_rots_target, mesh_list_rots):
447 | # Setup metric meters for viewpoints
448 | metric_meters_views = {key: AverageMeter() for key in self.metrics.keys()}
449 |
450 | # Setup input
451 | partial_pc_views = torch.Tensor(partial_pc_views).to(self.device)
452 | nums_grasps = torch.tensor([Ts_grasp_views_target.shape[1]]*len(partial_pc_views), device=partial_pc_views.device)
453 |
454 | # Generate grasp poses
455 | Ts_grasp_views_pred = self.model.sample(partial_pc_views, nums_grasps)
456 |
457 | for Ts_grasp_pred, Ts_grasp_target, mesh in zip(Ts_grasp_views_pred, Ts_grasp_views_target, mesh_list_views):
458 | # Rescale grasp poses and mesh
459 | Ts_grasp_pred[:, :3, 3] /= scale
460 | Ts_grasp_target[:, :3, 3] /= scale
461 | mesh.scale(1/scale, center=(0, 0, 0))
462 |
463 | # Compute metrics for viewpoints
464 | for key, metric in self.metrics.items():
465 | if key == 'collision_rate':
466 | # Get indices for sampling grasp poses for simulation
467 | assert NUM_GRASPS <= len(Ts_grasp_pred), f"Number of grasps for simulation ({NUM_GRASPS}) must be less than or equal to the number of grasps predicted ({len(Ts_grasp_pred)})."
468 |
469 | idxs = torch.randperm(len(Ts_grasp_pred))[:NUM_GRASPS]
470 |
471 | metric_meters_views[key].update(metric(mesh, Ts_grasp_pred[idxs]))
472 | else:
473 | metric_meters_views[key].update(metric(Ts_grasp_pred, Ts_grasp_target))
474 |
475 | # Compute metrics for rotations
476 |                 for key, meter in metric_meters_rots.items():
477 | meter.update(metric_meters_views[key].avg)
478 |
479 | # Compute metrics for objects
480 | for key, meter in metric_meters_objs.items():
481 | meter.update(metric_meters_rots[key].avg)
482 |
483 | # Compute metrics for object types
484 | for key, meter in self.metric_meters.items():
485 | meter.update(metric_meters_objs[key].avg)
486 |
487 | time_end = time.time()
488 |
489 | # Get averaged evaluation results
490 | results = {}
491 |
492 | for key, meter in self.metric_meters.items():
493 | results[f'scalar/metrics/{key}'] = meter.avg
494 |
495 | # Log averaged evaluation results
496 | self.logger.log(results, iter)
497 |
498 | # Print averaged evaluation results
499 | msg = f"[ Evaluation ] epoch: {epoch}/{self.cfg.num_epochs}, iter: {iter}, "
500 | msg += ", ".join([f"{key}: {meter.avg:.4f}" for key, meter in self.metric_meters.items()])
501 | msg += f", elapsed time: {time_end-time_start:.4f}"
502 |
503 | print(msg)
504 | logging.info(msg)
505 |
506 | # Save model if best evaluation performance
507 | for criterion in self.cfg.criteria:
508 | # Determine best performance
509 | performance = self.metric_meters[criterion.name].avg
510 |
511 | if criterion.better == 'higher' and performance > self.performances[criterion.name]:
512 | best = True
513 | elif criterion.better == 'lower' and performance < self.performances[criterion.name]:
514 | best = True
515 | else:
516 | best = False
517 |
518 | if best:
519 | # Save model
520 | self.save(epoch, iter, criterion=criterion.name, data={criterion.name: performance})
521 |
522 |                 # Update best performance
523 | self.performances[criterion.name] = performance
524 |
525 | # Reset metric meters
526 | for key, meter in self.metric_meters.items():
527 | meter.reset()
528 |
529 | def visualize(self, epoch, iter):
530 | # Initialize
531 | self.model.eval()
532 |
533 | time_start = time.time()
534 |
535 | mesh_list = []
536 | partial_pc_list = []
537 | Ts_grasp_pred_list = []
538 | Ts_grasp_target_list = []
539 |
540 | # Get random data indices
541 | idxs = np.random.choice(self.val_loader.dataset.num_scenes, size=3, replace=False)
542 |
543 | # Get scale
544 | scale = self.val_loader.dataset.scale if hasattr(self.val_loader.dataset, 'scale') else 1
545 |
546 | for idx in idxs:
547 | idx_type, idx_obj, idx_rot, idx_view = np.where(self.val_loader.dataset.scene_idxs==idx)
548 |
549 | idx_type = idx_type.item()
550 | idx_obj = idx_obj.item()
551 | idx_rot = idx_rot.item()
552 | idx_view = idx_view.item()
553 |
554 | # Get input
555 | mesh = deepcopy(self.val_loader.dataset.mesh_list_types[idx_type][idx_obj][idx_rot][idx_view])
556 | partial_pc = self.val_loader.dataset.partial_pc_list_types[idx_type][idx_obj][idx_rot][idx_view]
557 | Ts_grasp_target = self.val_loader.dataset.Ts_grasp_list_types[idx_type][idx_obj][idx_rot][idx_view]
558 |
559 | # Sample ground-truth grasp poses
560 | idxs_grasp = np.random.choice(len(Ts_grasp_target), size=10, replace=False)
561 | Ts_grasp_target = Ts_grasp_target[idxs_grasp]
562 |
563 | # Append data to list
564 | mesh_list += [mesh]
565 | partial_pc_list += [torch.Tensor(partial_pc)]
566 | Ts_grasp_target_list += [Ts_grasp_target]
567 |
568 | # Setup input
569 | partial_pc = torch.stack(partial_pc_list).to(self.device)
570 | nums_grasps = torch.tensor([10, 10, 10], device=self.device)
571 |
572 | # Generate grasp poses
573 | Ts_grasp_pred_list = self.model.sample(partial_pc, nums_grasps)
574 | Ts_grasp_pred_list = [Ts_grasp_pred.cpu().numpy() for Ts_grasp_pred in Ts_grasp_pred_list]
575 |
576 | # Rescale mesh and grasp poses
577 | for mesh, Ts_grasp_pred, Ts_grasp_target in zip(mesh_list, Ts_grasp_pred_list, Ts_grasp_target_list):
578 | mesh.scale(1/scale, center=(0, 0, 0))
579 | Ts_grasp_pred[:, :3, 3] /= scale
580 | Ts_grasp_target[:, :3, 3] /= scale
581 |
582 | # Generate scene
583 | scene_list_pred = generate_grasp_scene_list(mesh_list, Ts_grasp_pred_list)
584 | scene_list_target = generate_grasp_scene_list(mesh_list, Ts_grasp_target_list)
585 |
586 | # Get vertices, triangles and colors
587 | vertices_pred, triangles_pred, colors_pred = meshes_to_numpy(scene_list_pred)
588 | vertices_target, triangles_target, colors_target = meshes_to_numpy(scene_list_target)
589 |
590 | time_end = time.time()
591 |
592 | # Get visualization results
593 | results = {
594 | 'mesh/pred': {'vertices': vertices_pred, 'colors': colors_pred, 'faces': triangles_pred},
595 | 'mesh/target': {'vertices': vertices_target, 'colors': colors_target, 'faces': triangles_target}
596 | }
597 |
598 | # Log visualization results
599 | self.logger.log(results, iter)
600 |
601 | # Print visualization status
602 | msg = f"[Visualization] epoch: {epoch}/{self.cfg.num_epochs}, iter: {iter}"
603 | msg += f", elapsed time: {time_end-time_start:.4f}"
604 |
605 | print(msg)
606 | logging.info(msg)
607 |
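608 |
609 | # Minimal trainer-config sketch (illustrative): these are the fields the trainers
610 | # above read from 'cfg'. The interval values are placeholders, and 'emd' is assumed
611 | # here only as an example metric key; each criterion name must match a key returned
612 | # by get_metrics.
613 | if __name__ == '__main__':
614 |     from omegaconf import OmegaConf
615 |
616 |     cfg_trainer = OmegaConf.create({
617 |         'name': 'grasp_full',          # popped by trainers.get_trainer
618 |         'num_epochs': 100,
619 |         'print_interval': 10,
620 |         'val_interval': 100,
621 |         'eval_interval': 500,
622 |         'vis_interval': 500,
623 |         'save_interval': 1000,
624 |         'criteria': [{'name': 'emd', 'better': 'lower'}],
625 |     })
626 |
627 |     print(OmegaConf.to_yaml(cfg_trainer))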
--------------------------------------------------------------------------------
/utils/Lie.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 | from scipy.spatial.transform import Rotation
4 |
5 |
6 | EPS = 1e-4
7 |
8 |
9 | def is_SO3(R):
10 | test_0 = torch.allclose(R @ R.transpose(1, 2), torch.eye(3).repeat(len(R), 1, 1).to(R), atol=EPS)
11 | test_1 = torch.allclose(R.transpose(1, 2) @ R, torch.eye(3).repeat(len(R), 1, 1).to(R), atol=EPS)
12 |
13 | test = test_0 and test_1
14 |
15 | return test
16 |
17 |
18 | def is_SE3(T):
19 | test_0 = is_SO3(T[:, :3, :3])
20 | test_1 = torch.equal(T[:, 3, :3], torch.zeros_like(T[:, 3, :3]))
21 | test_2 = torch.equal(T[:, 3, 3], torch.ones_like(T[:, 3, 3]))
22 |
23 | test = test_0 and test_1 and test_2
24 |
25 | return test
26 |
27 |
28 | def inv_SO3(R):
29 | assert R.shape[1:] == (3, 3), f"inv_SO3: input must be of shape (N, 3, 3). Current shape: {tuple(R.shape)}"
30 | assert is_SO3(R), "inv_SO3: input must be SO(3) matrices"
31 |
32 | inv_R = R.transpose(1, 2)
33 |
34 | return inv_R
35 |
36 |
37 | def inv_SE3(T):
38 | assert T.shape[1:] == (4, 4), f"inv_SE3: input must be of shape (N, 4, 4). Current shape: {tuple(T.shape)}"
39 | assert is_SE3(T), "inv_SE3: input must be SE(3) matrices"
40 |
41 | R = T[:, :3, :3]
42 | p = T[:, :3, 3]
43 |
44 | inv_T = torch.eye(4).repeat(len(T), 1, 1).to(T)
45 | inv_T[:, :3, :3] = inv_SO3(R)
46 | inv_T[:, :3, 3] = - torch.einsum('nij,nj->ni', inv_SO3(R), p)
47 |
48 | return inv_T
49 |
50 |
51 | def bracket_so3(w):
52 | # vector -> matrix
53 | if w.shape[1:] == (3,):
54 | zeros = w.new_zeros(len(w))
55 |
56 | out = torch.stack([
57 | torch.stack([zeros, -w[:, 2], w[:, 1]], dim=1),
58 | torch.stack([w[:, 2], zeros, -w[:, 0]], dim=1),
59 | torch.stack([-w[:, 1], w[:, 0], zeros], dim=1)
60 | ], dim=1)
61 |
62 | # matrix -> vector
63 | elif w.shape[1:] == (3, 3):
64 | out = torch.stack([w[:, 2, 1], w[:, 0, 2], w[:, 1, 0]], dim=1)
65 |
66 | else:
67 |         raise ValueError(f"bracket_so3: input must be of shape (N, 3) or (N, 3, 3). Current shape: {tuple(w.shape)}")
68 |
69 | return out
70 |
71 |
72 | def bracket_se3(S):
73 | # vector -> matrix
74 | if S.shape[1:] == (6,):
75 | w_mat = bracket_so3(S[:, :3])
76 |
77 | out = torch.cat((
78 | torch.cat((w_mat, S[:, 3:].unsqueeze(2)), dim=2),
79 | S.new_zeros(len(S), 1, 4)
80 | ), dim=1)
81 |
82 | # matrix -> vector
83 | elif S.shape[1:] == (4, 4):
84 | w_vec = bracket_so3(S[:, :3, :3])
85 |
86 | out = torch.cat((w_vec, S[:, :3, 3]), dim=1)
87 |
88 | else:
89 |         raise ValueError(f"bracket_se3: input must be of shape (N, 6) or (N, 4, 4). Current shape: {tuple(S.shape)}")
90 |
91 | return out
92 |
93 |
94 | def log_SO3(R):
95 | # return logSO3(R)
96 | n = R.shape[0]
97 | assert R.shape == (n, 3, 3), f"log_SO3: input must be of shape (N, 3, 3). Current shape: {tuple(R.shape)}"
98 | assert is_SO3(R), "log_SO3: input must be SO(3) matrices"
99 |
100 | tr_R = torch.diagonal(R, dim1=1, dim2=2).sum(1)
101 | w_mat = torch.zeros_like(R)
102 | theta = torch.acos(torch.clamp((tr_R - 1) / 2, -1 + EPS, 1 - EPS))
103 |
104 | is_regular = (tr_R + 1 > EPS)
105 | is_singular = (tr_R + 1 <= EPS)
106 |
107 | theta = theta.unsqueeze(1).unsqueeze(2)
108 |
109 | w_mat_regular = (1 / (2 * torch.sin(theta[is_regular]) + EPS)) * (R[is_regular] - R[is_regular].transpose(1, 2)) * theta[is_regular]
110 |
111 | w_mat_singular = (R[is_singular] - torch.eye(3).to(R)) / 2
112 |
113 | w_vec_singular = torch.sqrt(torch.diagonal(w_mat_singular, dim1=1, dim2=2) + 1)
114 | w_vec_singular[torch.isnan(w_vec_singular)] = 0
115 |
116 | w_1 = w_vec_singular[:, 0]
117 | w_2 = w_vec_singular[:, 1] * (torch.sign(w_mat_singular[:, 0, 1]) + (w_1 == 0))
118 | w_3 = w_vec_singular[:, 2] * torch.sign(4 * torch.sign(w_mat_singular[:, 0, 2]) + 2 * (w_1 == 0) * torch.sign(w_mat_singular[:, 1, 2]) + 1 * (w_1 == 0) * (w_2 == 0))
119 |
120 | w_vec_singular = torch.stack([w_1, w_2, w_3], dim=1)
121 |
122 | w_mat[is_regular] = w_mat_regular
123 | w_mat[is_singular] = bracket_so3(w_vec_singular) * torch.pi
124 |
125 | return w_mat
126 |
127 |
128 | def log_SE3(T):
129 | assert T.shape[1:] == (4, 4), f"log_SE3: input must be of shape (N, 4, 4). Current shape: {tuple(T.shape)}"
130 | assert is_SE3(T), "log_SE3: input must be SE(3) matrices"
131 |
132 | R = T[:, :3, :3]
133 | p = T[:, :3, 3]
134 |
135 | tr_R = torch.diagonal(R, dim1=1, dim2=2).sum(1)
136 | theta = torch.acos(torch.clamp((tr_R - 1) / 2, -1 + EPS, 1 - EPS)).unsqueeze(1).unsqueeze(2)
137 |
138 | w_mat = log_SO3(R)
139 | w_mat_hat = w_mat / (theta + EPS)
140 |
141 | inv_G = torch.eye(3).repeat(len(T), 1, 1).to(T) - (theta / 2) * w_mat_hat + (1 - (theta / (2 * torch.tan(theta / 2) + EPS))) * w_mat_hat @ w_mat_hat
142 |
143 | S = torch.zeros_like(T)
144 | S[:, :3, :3] = w_mat
145 | S[:, :3, 3] = torch.einsum('nij,nj->ni', inv_G, p)
146 |
147 | return S
148 |
149 |
150 | def exp_so3(w_vec):
151 | if w_vec.shape[1:] == (3, 3):
152 | w_vec = bracket_so3(w_vec)
153 | elif w_vec.shape[1:] != (3,):
154 |         raise ValueError(f"exp_so3: input must be of shape (N, 3) or (N, 3, 3). Current shape: {tuple(w_vec.shape)}")
155 |
156 | R = torch.eye(3).repeat(len(w_vec), 1, 1).to(w_vec)
157 |
158 | theta = w_vec.norm(dim=1)
159 |
160 | is_regular = theta > EPS
161 |
162 | w_vec_regular = w_vec[is_regular]
163 | theta_regular = theta[is_regular]
164 |
165 | theta_regular = theta_regular.unsqueeze(1)
166 |
167 | w_mat_hat_regular = bracket_so3(w_vec_regular / theta_regular)
168 |
169 | theta_regular = theta_regular.unsqueeze(2)
170 |
171 | R[is_regular] = torch.eye(3).repeat(len(w_vec_regular), 1, 1).to(w_vec_regular) + torch.sin(theta_regular) * w_mat_hat_regular + (1 - torch.cos(theta_regular)) * w_mat_hat_regular @ w_mat_hat_regular
172 |
173 | return R
174 |
175 |
176 | def exp_se3(S):
177 | if S.shape[1:] == (4, 4):
178 | S = bracket_se3(S)
179 | elif S.shape[1:] != (6,):
180 |         raise ValueError(f"exp_se3: input must be of shape (N, 6) or (N, 4, 4). Current shape: {tuple(S.shape)}")
181 |
182 | w_vec = S[:, :3]
183 | p = S[:, 3:]
184 |
185 | T = torch.eye(4).repeat(len(S), 1, 1).to(S)
186 |
187 | theta = w_vec.norm(dim=1)
188 |
189 | is_regular = theta > EPS
190 | is_singular = theta <= EPS
191 |
192 | w_vec_regular = w_vec[is_regular]
193 | theta_regular = theta[is_regular]
194 |
195 | theta_regular = theta_regular.unsqueeze(1)
196 |
197 | w_mat_hat_regular = bracket_so3(w_vec_regular / theta_regular)
198 |
199 | theta_regular = theta_regular.unsqueeze(2)
200 |
201 |     G = theta_regular * torch.eye(3).repeat(len(w_vec_regular), 1, 1).to(S) + (1 - torch.cos(theta_regular)) * w_mat_hat_regular + (theta_regular - torch.sin(theta_regular)) * w_mat_hat_regular @ w_mat_hat_regular
202 |
203 |     T[is_regular, :3, :3] = exp_so3(w_vec_regular)
204 |     T[is_regular, :3, 3] = torch.einsum('nij,nj->ni', G, p[is_regular])
205 |
206 |     T[is_singular, :3, :3] = torch.eye(3).repeat(int(is_singular.sum()), 1, 1).to(S)
207 |     T[is_singular, :3, 3] = p[is_singular]
208 |
209 | return T
210 |
211 |
212 | def large_adjoint(T):
213 | assert T.shape[1:] == (4, 4), f"large_adjoint: input must be of shape (N, 4, 4). Current shape: {tuple(T.shape)}"
214 | assert is_SE3(T), "large_adjoint: input must be SE(3) matrices"
215 |
216 | R = T[:, :3, :3]
217 | p = T[:, :3, 3]
218 |
219 | large_adj = T.new_zeros(len(T), 6, 6)
220 | large_adj[:, :3, :3] = R
221 | large_adj[:, 3:, :3] = bracket_so3(p) @ R
222 | large_adj[:, 3:, 3:] = R
223 |
224 | return large_adj
225 |
226 |
227 | def small_adjoint(S):
228 | if S.shape[1:] == (4, 4):
229 | w_mat = S[:, :3, :3]
230 | v_mat = bracket_so3(S[:, :3, 3])
231 | elif S.shape[1:] == (6,):
232 | w_mat = bracket_so3(S[:, :3])
233 | v_mat = bracket_so3(S[:, 3:])
234 | else:
235 |         raise ValueError(f"small_adjoint: input must be of shape (N, 6) or (N, 4, 4). Current shape: {tuple(S.shape)}")
236 |
237 | small_adj = S.new_zeros(len(S), 6, 6)
238 | small_adj[:, :3, :3] = w_mat
239 | small_adj[:, 3:, :3] = v_mat
240 | small_adj[:, 3:, 3:] = w_mat
241 |
242 | return small_adj
243 |
244 |
245 | def Lie_bracket(u, v):
246 | if u.shape[1:] == (3,):
247 | u = bracket_so3(u)
248 | elif u.shape[1:] == (6,):
249 | u = bracket_se3(u)
250 |
251 | if v.shape[1:] == (3,):
252 | v = bracket_so3(v)
253 | elif v.shape[1:] == (6,):
254 | v = bracket_se3(v)
255 |
256 | return u @ v - v @ u
257 |
258 |
259 | def is_quat(quat):
260 | test = torch.allclose(quat.norm(dim=1), quat.new_ones(len(quat)))
261 |
262 | return test
263 |
264 |
265 | def super_fibonacci_spiral(num_Rs):
266 | phi = 1.414213562304880242096980 # sqrt(2)
267 |     psi = 1.533751168755204288118041 # positive real root of psi^4 = psi + 4
268 |
269 | s = np.arange(num_Rs) + 1 / 2
270 |
271 | t = s / num_Rs
272 | d = 2 * np.pi * s
273 |
274 | r = np.sqrt(t)
275 | R = np.sqrt(1 - t)
276 |
277 | alpha = d / phi
278 | beta = d / psi
279 |
280 | quats = np.stack([r * np.sin(alpha), r * np.cos(alpha), R * np.sin(beta), R * np.cos(beta)], axis=1)
281 |
282 | Rs = Rotation.from_quat(quats).as_matrix()
283 |
284 | return Rs
285 |
286 |
287 | def SE3_geodesic_dist(T_1, T_2):
288 | assert len(T_1) == len(T_2), f"SE3_geodesic_dist: inputs must have the same batch_size. Current shapes: T_1 - {tuple(T_1.shape)}, T_2 - {tuple(T_2.shape)}"
289 | assert is_SE3(T_1) and is_SE3(T_2), "SE3_geodesic_dist: inputs must be SE(3) matrices"
290 |
291 | R_1 = T_1[:, :3, :3]
292 | R_2 = T_2[:, :3, :3]
293 | p_1 = T_1[:, :3, 3]
294 | p_2 = T_2[:, :3, 3]
295 |
296 | delta_R = bracket_so3(log_SO3(torch.einsum('bij,bjk->bik', inv_SO3(R_1), R_2)))
297 | delta_p = p_1 - p_2
298 |
299 | dist = (delta_R ** 2 + delta_p ** 2).sum(1).sqrt()
300 |
301 | return dist
302 |
303 |
304 | def get_fibonacci_sphere(num_points):
305 | points = []
306 |
307 | phi = np.pi * (np.sqrt(5.) - 1.) # golden angle in radians
308 |
309 | for i in range(num_points):
310 | y = 1 - (i / float(num_points - 1)) * 2 # y goes from 1 to -1
311 | radius = np.sqrt(1 - y * y) # radius at y
312 |
313 | theta = phi * i # golden angle increment
314 |
315 | x = np.cos(theta) * radius
316 | z = np.sin(theta) * radius
317 |
318 | points += [np.array([x, y, z])]
319 |
320 | points = np.stack(points)
321 |
322 | return points
323 |
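324 |
325 | # Minimal usage sketch (illustrative): exp_so3 maps batched rotation vectors to
326 | # rotation matrices and log_SO3 maps them back; shapes follow the (N, ...) batch
327 | # convention used throughout this module.
328 | if __name__ == '__main__':
329 |     w = torch.tensor([[0.3, -0.2, 0.1]])           # (1, 3) rotation vector
330 |     R = exp_so3(w)                                 # (1, 3, 3) rotation matrix
331 |     w_rec = bracket_so3(log_SO3(R))                # back to a (1, 3) vector
332 |     print(torch.allclose(w, w_rec, atol=1e-3))     # expected: True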
--------------------------------------------------------------------------------
/utils/average_meter.py:
--------------------------------------------------------------------------------
1 | class AverageMeter:
2 | def __init__(self):
3 | self.reset()
4 |
5 | def reset(self):
6 | self.avg = 0
7 | self.sum = 0
8 | self.count = 0
9 |
10 | def update(self, val, n=1):
11 | self.sum += val * n
12 | self.count += n
13 |         self.avg = self.sum / self.count
14 |
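15 |
16 | # Minimal usage sketch (illustrative): 'n' weights each update, so the meter tracks
17 | # a running weighted average (e.g. batch-size-weighted losses in the trainers).
18 | if __name__ == '__main__':
19 |     meter = AverageMeter()
20 |     meter.update(1.0, n=2)
21 |     meter.update(4.0, n=1)
22 |     print(meter.avg)   # 2.0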
--------------------------------------------------------------------------------
/utils/distributions.py:
--------------------------------------------------------------------------------
1 | import roma
2 | import torch
3 |
4 |
5 | def get_dist(cfg):
6 | name = cfg.pop('name')
7 |
8 | if name == 'SO3_uniform_R3_normal':
9 | dist_fn = SO3_uniform_R3_normal
10 | elif name == 'SO3_uniform_R3_spherical':
11 | dist_fn = SO3_uniform_R3_spherical
12 | elif name == 'SO3_centripetal_R3_normal':
13 | dist_fn = SO3_centripetal_R3_normal
14 | elif name == 'SO3_centripetal_R3_spherical':
15 | dist_fn = SO3_centripetal_R3_spherical
16 | else:
17 | raise NotImplementedError(f"Distribution {name} is not implemented.")
18 |
19 | return dist_fn
20 |
21 |
22 | def SO3_uniform_R3_normal(num_samples, device):
23 | R = roma.random_rotmat(num_samples).to(device)
24 |
25 | p = torch.randn(num_samples, 3).to(device)
26 |
27 | T = torch.eye(4).repeat(num_samples, 1, 1).to(device)
28 | T[:, :3, :3] = R
29 | T[:, :3, 3] = p
30 |
31 | return T
32 |
33 |
34 | def SO3_uniform_R3_spherical(num_samples, device):
35 | R = roma.random_rotmat(num_samples).to(device)
36 |
37 | p = torch.randn(num_samples, 3).to(device)
38 | p /= p.norm(dim=-1, keepdim=True)
39 |
40 | T = torch.eye(4).repeat(num_samples, 1, 1).to(device)
41 | T[:, :3, :3] = R
42 | T[:, :3, 3] = p
43 |
44 | return T
45 |
46 |
47 | def SO3_centripetal_R3_normal(num_samples, device):
48 | R = roma.random_rotmat(num_samples).to(device)
49 |
50 | p = - (0.112 * 5 + torch.randn(num_samples, 1).to(device).abs()) * R[:, :, 2]
51 |
52 | T = torch.eye(4).repeat(num_samples, 1, 1).to(device)
53 | T[:, :3, :3] = R
54 | T[:, :3, 3] = p
55 |
56 | return T
57 |
58 |
59 | def SO3_centripetal_R3_spherical(num_samples, device):
60 | R = roma.random_rotmat(num_samples).to(device)
61 |
62 | p = - R[:, :, 2]
63 |
64 | T = torch.eye(4).repeat(num_samples, 1, 1).to(device)
65 | T[:, :3, :3] = R
66 | T[:, :3, 3] = p
67 |
68 | return T
69 |
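70 |
71 | # Minimal usage sketch (illustrative): each sampler above returns an (N, 4, 4) batch
72 | # of homogeneous transforms used as initial poses for the flow.
73 | if __name__ == '__main__':
74 |     T_0 = SO3_uniform_R3_normal(num_samples=4, device='cpu')
75 |     print(T_0.shape)   # torch.Size([4, 4, 4])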
--------------------------------------------------------------------------------
/utils/logger.py:
--------------------------------------------------------------------------------
1 | class Logger:
2 | def __init__(self, writer):
3 | self.writer = writer
4 |
5 | def log(self, results, iter):
6 | for key, val in results.items():
7 | if 'scalar' in key:
8 | self.writer.add_scalar(key.replace('scalar/', ''), val, iter)
9 |
10 | elif 'image' in key and 'images' not in key:
11 | self.writer.add_image(key.replace('image/', ''), val, iter)
12 |
13 | elif 'images' in key:
14 | self.writer.add_images(key.replace('images/', ''), val, iter)
15 |
16 | elif 'mesh' in key:
17 | self.writer.add_mesh(key.replace('mesh/', ''), vertices=val['vertices'], colors=val['colors'], faces=val['faces'], global_step=iter)
18 |
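19 |
20 | # Minimal usage sketch (illustrative): result keys are routed by their prefix, e.g.
21 | # 'scalar/train/loss' is written with add_scalar under the tag 'train/loss'.
22 | # 'logger_demo' below is just a scratch directory for the example.
23 | if __name__ == '__main__':
24 |     from tensorboardX import SummaryWriter
25 |
26 |     logger = Logger(SummaryWriter('logger_demo'))
27 |     logger.log({'scalar/train/loss': 0.5}, iter=1)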
--------------------------------------------------------------------------------
/utils/mesh.py:
--------------------------------------------------------------------------------
1 | from copy import deepcopy
2 | import open3d as o3d
3 | import numpy as np
4 |
5 |
6 | def generate_grasp_scene_list(mesh_list, Ts_grasp_list):
7 | scene_list = []
8 |
9 | for mesh, Ts_grasp in zip(mesh_list, Ts_grasp_list):
10 | scene = deepcopy(mesh)
11 |
12 | for T_grasp in Ts_grasp:
13 | mesh_base_1 = o3d.geometry.TriangleMesh.create_cylinder(radius=0.002, height=0.066, resolution=6, split=1)
14 | T_base_1 = np.eye(4)
15 | T_base_1[:3, 3] = [0, 0, 0.033]
16 | mesh_base_1.transform(T_base_1)
17 |
18 | mesh_base_2 = o3d.geometry.TriangleMesh.create_cylinder(radius=0.002, height=0.082, resolution=6, split=1)
19 | T_base_2 = np.eye(4)
20 | T_base_2[:3, :3] = mesh_base_2.get_rotation_matrix_from_xyz([0, np.pi/2, 0])
21 | T_base_2[:3, 3] = [0, 0, 0.066]
22 | mesh_base_2.transform(T_base_2)
23 |
24 | mesh_left_finger = o3d.geometry.TriangleMesh.create_cylinder(radius=0.002, height=0.046, resolution=6, split=1)
25 | T_left_finger = np.eye(4)
26 | T_left_finger[:3, 3] = [-0.041, 0, 0.089]
27 | mesh_left_finger.transform(T_left_finger)
28 |
29 | mesh_right_finger = o3d.geometry.TriangleMesh.create_cylinder(radius=0.002, height=0.046, resolution=6, split=1)
30 | T_right_finger = np.eye(4)
31 | T_right_finger[:3, 3] = [0.041, 0, 0.089]
32 | mesh_right_finger.transform(T_right_finger)
33 |
34 | mesh_gripper = mesh_base_1 + mesh_base_2 + mesh_left_finger + mesh_right_finger
35 | mesh_gripper.transform(T_grasp)
36 |
37 | scene += mesh_gripper
38 |
39 | scene.compute_vertex_normals()
40 | scene.paint_uniform_color([0.5, 0.5, 0.5])
41 |
42 | scene_list += [scene]
43 |
44 | return scene_list
45 |
46 |
47 | def meshes_to_numpy(scenes):
48 | # Initialize
49 | vertices_np = []
50 | triangles_np = []
51 | colors_np = []
52 |
53 | # Get maximum number of vertices and triangles
54 | max_num_vertices = max([len(scene.vertices) for scene in scenes])
55 | max_num_triangles = max([len(scene.triangles) for scene in scenes])
56 |
57 | # Match dimension between batches for Tensorboard
58 | for scene in scenes:
59 | diff_num_vertices = max_num_vertices - len(scene.vertices)
60 | diff_num_triangles = max_num_triangles - len(scene.triangles)
61 |
62 | vertices_np += [np.concatenate((np.asarray(scene.vertices), np.zeros((diff_num_vertices, 3))), axis=0)]
63 | triangles_np += [np.concatenate((np.asarray(scene.triangles), np.zeros((diff_num_triangles, 3))), axis=0)]
64 | colors_np += [np.concatenate((255 * np.asarray(scene.vertex_colors), np.zeros((diff_num_vertices, 3))), axis=0)]
65 |
66 | # Stack to single numpy array
67 | vertices_np = np.stack(vertices_np, axis=0)
68 | triangles_np = np.stack(triangles_np, axis=0)
69 | colors_np = np.stack(colors_np, axis=0)
70 |
71 | return vertices_np, triangles_np, colors_np
72 |
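73 |
74 | # Minimal usage sketch (illustrative): compose a grasp scene from one object mesh and
75 | # a single identity grasp pose, then flatten it to padded numpy arrays for TensorBoard.
76 | if __name__ == '__main__':
77 |     obj_mesh = o3d.geometry.TriangleMesh.create_box(width=0.04, height=0.04, depth=0.04)
78 |
79 |     Ts_grasp = np.eye(4)[None]    # (1, 4, 4): one grasp at the object frame
80 |     scene_list = generate_grasp_scene_list([obj_mesh], [Ts_grasp])
81 |
82 |     vertices, triangles, colors = meshes_to_numpy(scene_list)
83 |     print(vertices.shape, triangles.shape, colors.shape)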
--------------------------------------------------------------------------------
/utils/ode_solvers.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from copy import deepcopy
3 |
4 | from utils.Lie import bracket_so3, exp_so3, Lie_bracket
5 |
6 |
7 | def get_ode_solver(cfg):
8 | name = cfg.pop('name')
9 |
10 | if name == 'SE3_Euler':
11 | solver = SE3_Euler(**cfg)
12 | elif name == 'SE3_RK_mk':
13 | solver = SE3_RK4_MK(**cfg)
14 | else:
15 | raise NotImplementedError(f"ODE solver {name} is not implemented.")
16 |
17 | return solver
18 |
19 |
20 | class SE3_Euler:
21 | def __init__(self, num_steps):
22 | self.t = torch.linspace(0, 1, num_steps + 1)
23 |
24 | @torch.no_grad()
25 | def __call__(self, z, x_0, func):
26 | # Initialize
27 | t = self.t.to(z.device)
28 | dt = t[1:] - t[:-1]
29 | traj = x_0.new_zeros(x_0.shape[0:1] + t.shape + x_0.shape[1:])
30 | traj[:, 0] = x_0
31 |
32 | for n in range(len(t)-1):
33 | # Get n-th values
34 | x_n = traj[:, n].contiguous()
35 | t_n = t[n].repeat(len(x_0), 1)
36 | h = dt[n].repeat(len(x_0), 1)
37 |
38 | ##### Stage 1 #####
39 | # Set function input
40 | x_hat = deepcopy(x_n)
41 |
42 | # Get vector (V_s)
43 | V_1 = func(z, t_n, x_hat)
44 | w_1 = V_1[:, :3]
45 | v_1 = V_1[:, 3:]
46 |
47 | # Change w_s to w_b and transform to matrix
48 | w_1 = torch.einsum('bji,bj->bi', x_hat[:, :3, :3], w_1)
49 | w_1 = bracket_so3(w_1)
50 |
51 | ##### Update #####
52 | traj[:, n+1] = deepcopy(x_n)
53 | traj[:, n+1, :3, :3] @= exp_so3(h.unsqueeze(-1) * w_1)
54 | traj[:, n+1, :3, 3] += h * v_1
55 |
56 | return traj
57 |
58 |
59 | class SE3_RK4_MK:
60 | def __init__(self, num_steps):
61 | self.t = torch.linspace(0, 1, num_steps + 1)
62 |
63 | @torch.no_grad()
64 | def __call__(self, z, x_0, func):
65 | # Initialize
66 | t = self.t.to(z.device)
67 | dt = t[1:] - t[:-1]
68 | traj = x_0.new_zeros(x_0.shape[0:1] + t.shape + x_0.shape[1:])
69 | traj[:, 0] = x_0
70 |
71 | for n in range(len(t)-1):
72 | # Get n-th values
73 | x_n = traj[:, n].contiguous()
74 | t_n = t[n].repeat(len(x_0), 1)
75 | h = dt[n].repeat(len(x_0), 1)
76 |
77 | ##### Stage 1 #####
78 | # Set function input
79 | x_hat_1 = x_n
80 |
81 | # Get vector (V_s)
82 | V_1 = func(z, t_n, x_hat_1)
83 | w_1 = V_1[:, :3]
84 | v_1 = V_1[:, 3:]
85 |
86 | # Change w_s to w_b and transform to matrix
87 | w_1 = torch.einsum('bji,bj->bi', x_hat_1[:, :3, :3], w_1)
88 | w_1 = bracket_so3(w_1)
89 |
90 | # Set I_1
91 | I_1 = w_1
92 |
93 | ##### Stage 2 #####
94 | u_2 = h.unsqueeze(-1) * (1 / 2) * w_1
95 | u_2 += (h.unsqueeze(-1) / 12) * Lie_bracket(I_1, u_2)
96 |
97 | # Set function input
98 | x_hat_2 = deepcopy(x_n)
99 | x_hat_2[:, :3, :3] @= exp_so3(u_2)
100 | x_hat_2[:, :3, 3] += h * (v_1 / 2)
101 |
102 | # Get vector (V_s)
103 | V_2 = func(z, t_n + (h / 2), x_hat_2)
104 | w_2 = V_2[:, :3]
105 | v_2 = V_2[:, 3:]
106 |
107 | # Change w_s to w_b and transform to matrix
108 | w_2 = torch.einsum('bji,bj->bi', x_hat_2[:, :3, :3], w_2)
109 | w_2 = bracket_so3(w_2)
110 |
111 | ##### Stage 3 #####
112 | u_3 = h.unsqueeze(-1) * (1 / 2) * w_2
113 | u_3 += (h.unsqueeze(-1) / 12) * Lie_bracket(I_1, u_3)
114 |
115 | # Set function input
116 | x_hat_3 = deepcopy(x_n)
117 | x_hat_3[:, :3, :3] @= exp_so3(u_3)
118 | x_hat_3[:, :3, 3] += h * (v_2 / 2)
119 |
120 | # Get vector (V_s)
121 | V_3 = func(z, t_n + (h / 2), x_hat_3)
122 | w_3 = V_3[:, :3]
123 | v_3 = V_3[:, 3:]
124 |
125 | # Change w_s to w_b and transform to matrix
126 | w_3 = torch.einsum('bji,bj->bi', x_hat_3[:, :3, :3], w_3)
127 | w_3 = bracket_so3(w_3)
128 |
129 | ##### Stage 4 #####
130 | u_4 = h.unsqueeze(-1) * w_3
131 | u_4 += (h.unsqueeze(-1) / 6) * Lie_bracket(I_1, u_4)
132 |
133 | # Set function input
134 | x_hat_4 = deepcopy(x_n)
135 | x_hat_4[:, :3, :3] @= exp_so3(u_4)
136 | x_hat_4[:, :3, 3] += h * v_3
137 |
138 | # Get vector (V_s)
139 | V_4 = func(z, t_n + h, x_hat_4)
140 | w_4 = V_4[:, :3]
141 | v_4 = V_4[:, 3:]
142 |
143 | # Change w_s to w_b and transform to matrix
144 | w_4 = torch.einsum('bji,bj->bi', x_hat_4[:, :3, :3], w_4)
145 | w_4 = bracket_so3(w_4)
146 |
147 | ##### Update #####
148 | I_2 = (2 * (w_2 - I_1) + 2 * (w_3 - I_1) - (w_4 - I_1)) / h.unsqueeze(-1)
149 | u = h.unsqueeze(-1) * (1 / 6 * w_1 + 1 / 3 * w_2 + 1 / 3 * w_3 + 1 / 6 * w_4)
150 | u += (h.unsqueeze(-1) / 4) * Lie_bracket(I_1, u) + ((h ** 2).unsqueeze(-1) / 24) * Lie_bracket(I_2, u)
151 |
152 | traj[:, n+1] = deepcopy(x_n)
153 | traj[:, n+1, :3, :3] @= exp_so3(u)
154 | traj[:, n+1, :3, 3] += (h / 6) * (v_1 + 2 * v_2 + 2 * v_3 + v_4)
155 |
156 | return traj
157 |
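158 |
159 | # Minimal usage sketch (illustrative): integrate a constant spatial twist with the
160 | # SE(3) Euler solver. 'func' must map (z, t, x) to an (N, 6) velocity [w, v]; the
161 | # dummy field below ignores its inputs and returns the same twist for every pose.
162 | if __name__ == '__main__':
163 |     def constant_twist(z, t, x):
164 |         return torch.tensor([[0., 0., 1., 0.1, 0., 0.]]).repeat(len(x), 1)
165 |
166 |     solver = SE3_Euler(num_steps=10)
167 |
168 |     x_0 = torch.eye(4).unsqueeze(0)    # (1, 4, 4) identity initial pose
169 |     z = torch.zeros(1, 1)              # dummy conditioning code
170 |
171 |     traj = solver(z, x_0, constant_twist)
172 |     print(traj.shape)                  # torch.Size([1, 11, 4, 4])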
--------------------------------------------------------------------------------
/utils/optimizers.py:
--------------------------------------------------------------------------------
1 | from torch.optim import SGD, Adam, ASGD, Adamax, Adadelta, Adagrad, RMSprop
2 |
3 |
4 | def get_optimizer(cfg, model_params):
5 | name = cfg.pop('name')
6 |
7 | optimizer_class = get_optimizer_class(name)
8 |
9 | optimizer = optimizer_class(model_params, **cfg)
10 |
11 | return optimizer
12 |
13 |
14 | def get_optimizer_class(name):
15 | try:
16 | return {
17 | 'sgd': SGD,
18 | 'adam': Adam,
19 | 'asgd': ASGD,
20 | 'adamax': Adamax,
21 | 'adadelta': Adadelta,
22 | 'adagrad': Adagrad,
23 | 'rmsprop': RMSprop,
24 | }[name]
25 |     except KeyError:
26 | raise NotImplementedError(f"Optimizer {name} is not available.")
27 |
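28 |
29 | # Minimal usage sketch (illustrative): the optimizer config must carry a 'name' key
30 | # from the table above; the remaining keys are forwarded as constructor kwargs.
31 | if __name__ == '__main__':
32 |     import torch
33 |     from omegaconf import OmegaConf
34 |
35 |     cfg = OmegaConf.create({'name': 'adam', 'lr': 1.0e-4})
36 |     params = [torch.nn.Parameter(torch.zeros(3))]
37 |
38 |     optimizer = get_optimizer(cfg, params)
39 |     print(type(optimizer).__name__)   # Adam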
--------------------------------------------------------------------------------
/utils/partial_point_cloud.py:
--------------------------------------------------------------------------------
1 | import open3d as o3d
2 | import numpy as np
3 | from tqdm import tqdm
4 |
5 |
6 | def get_partial_point_clouds(mesh, view_vecs, num_points, visible_visualizer=False, use_tqdm=False, check_partial_pc=False):
7 | # Check open3d version
8 | assert o3d.__version__.split('.')[0] == '0' and o3d.__version__.split('.')[1] == '16', \
9 |         "open3d version must be 0.16; 'ctr.convert_from_pinhole_camera_parameters(camera_params)' does not work correctly in later versions"
10 |
11 | # Set distance from object center to camera
12 | distance = 1.5 * np.linalg.norm(mesh.get_oriented_bounding_box().extent)
13 |
14 | # Set visualizer
15 | vis = o3d.visualization.Visualizer()
16 | vis.create_window(visible=visible_visualizer)
17 |
18 | ctr = vis.get_view_control()
19 | camera_params = ctr.convert_to_pinhole_camera_parameters()
20 |
21 | # Add mesh
22 | vis.add_geometry(mesh)
23 |
24 | # Set camera poses
25 | view_unit_vecs = view_vecs / np.linalg.norm(view_vecs, axis=1, keepdims=True)
26 |
27 | cam_z_s = - view_unit_vecs
28 |
29 | while True:
30 | cam_x_s = -1 + 2 * np.random.rand(len(cam_z_s), 3)
31 | cam_x_s = cam_x_s - np.sum(cam_x_s*cam_z_s, axis=1, keepdims=True) * cam_z_s
32 |
33 |         if (np.linalg.norm(cam_x_s, axis=1) == 0).any():
34 | continue
35 | else:
36 | cam_x_s /= np.linalg.norm(cam_x_s, axis=1, keepdims=True)
37 |
38 | break
39 |
40 | cam_y_s = np.cross(cam_z_s, cam_x_s)
41 | cam_y_s /= np.linalg.norm(cam_y_s, axis=1, keepdims=True)
42 |
43 | cam_Ts = np.tile(np.eye(4), (len(view_vecs), 1, 1))
44 | cam_Ts[:, :3, :3] = np.stack([cam_x_s, cam_y_s, cam_z_s], axis=2)
45 | cam_Ts[:, :3, 3] = distance * view_unit_vecs
46 |
47 | # Get partial point clouds
48 | partial_pcds = []
49 |
50 | if use_tqdm:
51 | pbar = tqdm(cam_Ts, desc="Iterating viewpoints ...", leave=False)
52 | else:
53 | pbar = cam_Ts
54 |
55 | for cam_T in pbar:
56 | # Set camera extrinsic parameters
57 | camera_params.extrinsic = np.linalg.inv(cam_T)
58 |
59 | ctr.convert_from_pinhole_camera_parameters(camera_params)
60 |
61 | # Update visualizer
62 | vis.poll_events()
63 | vis.update_renderer()
64 |
65 | # Get partial point cloud
66 | depth = vis.capture_depth_float_buffer()
67 |
68 | partial_pcd = o3d.geometry.PointCloud.create_from_depth_image(depth, camera_params.intrinsic, camera_params.extrinsic)
69 |
70 | # Raise Exception if the number of points in point cloud is less than 'num_points'
71 | if len(np.asarray(partial_pcd.points)) < num_points:
72 | raise Exception("Point cloud has an insufficient number of points. Increase visualizer window width and height.")
73 |
74 | # Downsample point cloud to match the number of points with 'num_points'
75 | else:
76 | voxel_size = 0.5
77 | voxel_size_min = 0
78 | voxel_size_max = 1
79 |
80 | while True:
81 | partial_pcd_tmp = partial_pcd.voxel_down_sample(voxel_size)
82 |
83 | num_points_tmp = len(np.asarray(partial_pcd_tmp.points))
84 |
85 | if num_points_tmp - num_points >= 0 and num_points_tmp - num_points < 100:
86 | break
87 | else:
88 | if num_points_tmp > num_points:
89 | voxel_size_min = voxel_size
90 | elif num_points_tmp < num_points:
91 | voxel_size_max = voxel_size
92 |
93 | voxel_size = (voxel_size_min + voxel_size_max) / 2
94 |
95 | partial_pcd = partial_pcd_tmp.select_by_index(np.random.choice(num_points_tmp, num_points, replace=False))
96 |
97 | partial_pcds += [partial_pcd]
98 |
99 | vis.destroy_window()
100 |
101 | # Check obtained partial point cloud with mesh
102 | if check_partial_pc:
103 | for partial_pcd in partial_pcds:
104 | o3d.visualization.draw_geometries([mesh, partial_pcd])
105 |
106 | # Convert open3d PointCloud to numpy array
107 | partial_pcs = np.stack([np.asarray(partial_pcd.points) for partial_pcd in partial_pcds])
108 |
109 | return partial_pcs
110 |
111 |
112 | class PartialPointCloudExtractor:
113 | def __init__(self):
114 | # set offscreen rendering
115 | width = 128
116 | height = 128
117 |
118 | self.renderer = o3d.visualization.rendering.OffscreenRenderer(width, height)
119 |
120 | # Set intrinsic parameters
121 | fx = fy = 110.85125168
122 | cx = (width - 1) / 2
123 | cy = (height - 1) / 2
124 |
125 | self.intrinsic = o3d.camera.PinholeCameraIntrinsic(width, height, fx, fy, cx, cy)
126 |
127 | def extract(self, mesh, view_vecs, num_points):
128 | # set distance from object center to camera
129 | distance = np.linalg.norm(mesh.get_oriented_bounding_box().extent)
130 |
131 | # add mesh
132 | self.renderer.scene.add_geometry('mesh', mesh, o3d.visualization.rendering.MaterialRecord())
133 |
134 | # set camera poses
135 | view_unit_vecs = view_vecs / np.linalg.norm(view_vecs, axis=1, keepdims=True)
136 |
137 | cam_z_s = - view_unit_vecs
138 |
139 | while True:
140 | cam_x_s = -1 + 2 * np.random.rand(len(cam_z_s), 3)
141 | cam_x_s = cam_x_s - np.sum(cam_x_s*cam_z_s, axis=1, keepdims=True) * cam_z_s
142 |
143 |             if (np.linalg.norm(cam_x_s, axis=1) == 0).any():
144 | continue
145 | else:
146 | cam_x_s /= np.linalg.norm(cam_x_s, axis=1, keepdims=True)
147 |
148 | break
149 |
150 | cam_y_s = np.cross(cam_z_s, cam_x_s)
151 | cam_y_s /= np.linalg.norm(cam_y_s, axis=1, keepdims=True)
152 |
153 | cam_Ts = np.tile(np.eye(4), (len(view_vecs), 1, 1))
154 | cam_Ts[:, :3, :3] = np.stack([cam_x_s, cam_y_s, cam_z_s], axis=2)
155 | cam_Ts[:, :3, 3] = distance * view_unit_vecs
156 |
157 | # Get partial point clouds
158 | partial_pcds = []
159 |
160 | for cam_T in cam_Ts:
161 | # set extrinsic parameters
162 | extrinsic = np.linalg.inv(cam_T)
163 |
164 | # Set camera
165 | self.renderer.setup_camera(self.intrinsic, extrinsic)
166 |
167 | # Get depth image
168 | depth_image = self.renderer.render_to_depth_image(z_in_view_space=True)
169 |
170 |             # Get partial point cloud
171 |             partial_pcd = o3d.geometry.PointCloud.create_from_depth_image(depth_image, self.intrinsic, extrinsic)
172 |
173 |             pts = np.asarray(partial_pcd.points)
174 |             pts = pts[~np.isnan(pts).any(axis=1)]    # drop points with NaN coordinates
175 |
176 | partial_pcd = o3d.geometry.PointCloud(points=o3d.utility.Vector3dVector(pts))
177 |
178 |             # Raise an exception if the number of points in the point cloud is less than 'num_points'
179 |             if len(np.asarray(partial_pcd.points)) < num_points:
180 |                 raise Exception("Point cloud has an insufficient number of points. Increase the offscreen renderer width and height.")
181 |
182 |             # Otherwise, downsample the point cloud until it holds between 'num_points' and 'num_points' + 99 points, then randomly keep exactly 'num_points'
183 | else:
184 | voxel_size = 0.5
185 | voxel_size_min = 0
186 | voxel_size_max = 1
187 |
188 | while True:
189 | partial_pcd_tmp = partial_pcd.voxel_down_sample(voxel_size)
190 |
191 | num_points_tmp = len(np.asarray(partial_pcd_tmp.points))
192 |
193 |                     if 0 <= num_points_tmp - num_points < 100:    # close enough: at most 99 extra points remain
194 | break
195 | else:
196 | if num_points_tmp > num_points:
197 | voxel_size_min = voxel_size
198 | elif num_points_tmp < num_points:
199 | voxel_size_max = voxel_size
200 |
201 | voxel_size = (voxel_size_min + voxel_size_max) / 2
202 |
203 | partial_pcd = partial_pcd_tmp.select_by_index(np.random.choice(num_points_tmp, num_points, replace=False))
204 |
205 | partial_pcds += [partial_pcd]
206 |
207 |         # Convert open3d PointCloud to numpy array
208 | partial_pcs = np.stack([np.asarray(partial_pcd.points) for partial_pcd in partial_pcds])
209 |
210 | # Delete mesh
211 | self.renderer.scene.remove_geometry('mesh')
212 |
213 | return partial_pcs
214 |
--------------------------------------------------------------------------------
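For orientation, here is a minimal usage sketch of `PartialPointCloudExtractor` (not part of the repository; the mesh path, view directions, and point count are placeholders). Each row of `view_vecs` is a direction from the object toward the camera; one depth image is rendered per view, and each cloud is downsampled by bisecting the voxel size until it holds between `num_points` and `num_points` + 99 points, then randomly subsampled to exactly `num_points` before being stacked.

```python
import numpy as np
import open3d as o3d

from utils.partial_point_cloud import PartialPointCloudExtractor

# Placeholder mesh path -- substitute any object mesh
mesh = o3d.io.read_triangle_mesh('dataset/meshes/example.obj')

# One viewing direction per row (object center -> camera); three views here
view_vecs = np.array([[1.0, 0.0, 0.0],
                      [0.0, 1.0, 0.0],
                      [0.0, 0.0, 1.0]])

extractor = PartialPointCloudExtractor()
partial_pcs = extractor.extract(mesh, view_vecs, num_points=1024)

print(partial_pcs.shape)  # expected: (3, 1024, 3)
```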
/utils/visualization.py:
--------------------------------------------------------------------------------
1 | from plotly.subplots import make_subplots
2 | from plotly import graph_objects as go
3 | import numpy as np
4 | import torch
5 |
6 |
7 | class PlotlySubplotsVisualizer:
8 | def __init__(self, rows, cols):
9 | self.num_subplots = rows * cols
10 |
11 | self.reset(rows, cols)
12 |
13 | def reset(self, rows, cols):
14 | self.fig = make_subplots(rows=rows, cols=cols, specs=[[{'is_3d': True}]*cols]*rows)
15 | self.fig.update_layout(height=900)
16 |
17 | def add_vector(self, x, y, z, u, v, w, row, col, color='black', width=5, sizeref=0.2, showlegend=False):
18 | self.fig.add_trace(
19 | go.Scatter3d(x=[x, x+0.9*u], y=[y, y+0.9*v], z=[z, z+0.9*w], mode='lines', line=dict(color=color, width=width), showlegend=showlegend),
20 | row=row, col=col
21 | )
22 | self.fig.add_trace(
23 | go.Cone(x=[x+u], y=[y+v], z=[z+w], u=[u], v=[v], w=[w], sizemode='absolute', sizeref=sizeref, anchor='tip', colorscale=[[0, color], [1, color]], showscale=False),
24 | row=row, col=col
25 | )
26 |
27 | def add_mesh(self, mesh, row, col, color='aquamarine', idx='', showlegend=False):
28 | vertices = np.asarray(mesh.vertices)
29 |         triangles = np.asarray(mesh.triangles)
30 |
31 |         self.fig.add_trace(
32 |             go.Mesh3d(x=vertices[:, 0], y=vertices[:, 1], z=vertices[:, 2], i=triangles[:, 0], j=triangles[:, 1], k=triangles[:, 2], color=color, showlegend=showlegend, name='mesh '+str(idx)),
33 | row=row, col=col
34 | )
35 |
36 | def add_pc(self, pc, row, col, color='lightpink', size=5, idx='', showlegend=False):
37 | self.fig.add_trace(
38 | go.Scatter3d(x=pc[:,0], y=pc[:,1], z=pc[:,2], mode='markers', marker=dict(size=size, color=color), showlegend=showlegend, name='pc '+str(idx)),
39 | row=row, col=col,
40 | )
41 |
42 | def add_gripper(self, T, row, col, color='violet', width=5, idx='', showlegend=False):
43 | gripper_scatter3d = get_gripper_scatter3d(T, color, width, idx, showlegend)
44 |
45 | self.fig.add_trace(gripper_scatter3d, row=row, col=col)
46 |
47 | def add_grippers(self, Ts, row, col, color='violet', width=5, idx='', showlegend=False):
48 | for T_gripper in Ts:
49 | gripper_scatter3d = get_gripper_scatter3d(T_gripper, color, width, idx, showlegend)
50 |
51 | self.fig.add_trace(gripper_scatter3d, row=row, col=col)
52 |
53 | def add_frame(self, T, row, col, size=1, width=5, sizeref=0.2):
54 | self.add_vector(T[0, 3], T[1, 3], T[2, 3], size * T[0, 0], size * T[1, 0], size * T[2, 0], row, col, color='red', width=width, sizeref=sizeref)
55 | self.add_vector(T[0, 3], T[1, 3], T[2, 3], size * T[0, 1], size * T[1, 1], size * T[2, 1], row, col, color='green', width=width, sizeref=sizeref)
56 | self.add_vector(T[0, 3], T[1, 3], T[2, 3], size * T[0, 2], size * T[1, 2], size * T[2, 2], row, col, color='blue', width=width, sizeref=sizeref)
57 |
58 |
59 | def get_gripper_scatter3d(T, color, width=5, idx='', showlegend=False):
60 | unit1 = 0.066 #* 8 # 0.56
61 | unit2 = 0.041 #* 8 # 0.32
62 | unit3 = 0.046 #* 8 # 0.4
63 |
64 | pbase = torch.Tensor([0, 0, 0, 1]).reshape(1, -1)
65 | pcenter = torch.Tensor([0, 0, unit1, 1]).reshape(1, -1)
66 | pleft = torch.Tensor([unit2, 0, unit1, 1]).reshape(1, -1)
67 | pright = torch.Tensor([-unit2, 0, unit1, 1]).reshape(1, -1)
68 | plefttip = torch.Tensor([unit2, 0, unit1+unit3, 1]).reshape(1, -1)
69 | prighttip = torch.Tensor([-unit2, 0, unit1+unit3, 1]).reshape(1, -1)
70 |
71 | hand = torch.cat([pbase, pcenter, pleft, pright, plefttip, prighttip], dim=0).to(T)
72 | hand = torch.einsum('ij, kj -> ik', T, hand).cpu()
73 |
74 | phandx = [hand[0, 4], hand[0, 2], hand[0, 1], hand[0, 0], hand[0, 1], hand[0, 3], hand[0, 5]]
75 | phandy = [hand[1, 4], hand[1, 2], hand[1, 1], hand[1, 0], hand[1, 1], hand[1, 3], hand[1, 5]]
76 | phandz = [hand[2, 4], hand[2, 2], hand[2, 1], hand[2, 0], hand[2, 1], hand[2, 3], hand[2, 5]]
77 |
78 | gripper_scatter3d = go.Scatter3d(x=phandx, y=phandy, z=phandz, mode='lines', line=dict(color=color, width=width), showlegend=showlegend, name='gripper '+str(idx))
79 |
80 | return gripper_scatter3d
81 |
--------------------------------------------------------------------------------
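Finally, a minimal sketch of how `PlotlySubplotsVisualizer` can be used (not part of the repository; the point cloud, pose, and size values are arbitrary). The class only accumulates traces on the Plotly figure stored in `fig`, so the figure can be shown or saved with standard Plotly calls.

```python
import numpy as np
import torch

from utils.visualization import PlotlySubplotsVisualizer

# 1 x 2 grid of 3D subplots
vis = PlotlySubplotsVisualizer(rows=1, cols=2)

# Random point cloud in the first subplot
vis.add_pc(np.random.rand(512, 3), row=1, col=1)

# Identity gripper pose in the second subplot, drawn together with its coordinate frame
T = torch.eye(4)
vis.add_gripper(T, row=1, col=2)
vis.add_frame(T.numpy(), row=1, col=2, size=0.05, width=5, sizeref=0.01)

# 'fig' is a regular Plotly figure, so standard Plotly I/O applies
vis.fig.show()
```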