├── .gitignore
├── LICENSE
├── README.md
├── configs
│   ├── equigraspflow_full.yml
│   └── equigraspflow_partial.yml
├── environment.yml
├── images
│   └── generation.gif
├── loaders
│   ├── __init__.py
│   ├── acronym.py
│   └── utils.py
├── losses
│   ├── __init__.py
│   └── mse_loss.py
├── metrics
│   ├── __init__.py
│   └── emd.py
├── models
│   ├── __init__.py
│   ├── equi_grasp_flow.py
│   ├── vn_dgcnn.py
│   ├── vn_layers.py
│   └── vn_vector_fields.py
├── test_full.py
├── test_partial.py
├── train.py
├── trainers
│   ├── __init__.py
│   └── grasp_trainer.py
└── utils
    ├── Lie.py
    ├── average_meter.py
    ├── distributions.py
    ├── logger.py
    ├── mesh.py
    ├── ode_solvers.py
    ├── optimizers.py
    ├── partial_point_cloud.py
    └── visualization.py

/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled
2 | __pycache__/
3 | **/*.pyc
4 | 
5 | # Folders
6 | .vscode/
7 | dataset/
8 | train_results/
9 | pretrained_models/
10 | test_results/
11 | 
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 | 
3 | Copyright (c) 2025 Byeongdo Lim
4 | 
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 | 
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 | 
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # EquiGraspFlow: SE(3)-Equivariant 6-DoF Grasp Pose Generative Flows
2 | 
3 | The official repository for *EquiGraspFlow: SE(3)-Equivariant 6-DoF Grasp Pose Generative Flows* (Byeongdo Lim, Jongmin Kim, Jihwan Kim, Yonghyeon Lee, and Frank C. Park, CoRL 2024).
4 | 
5 | - [Project page](https://equigraspflow.github.io/)
6 | - [Openreview](https://openreview.net/forum?id=5lSkn5v4LK&referrer=%5BAuthor%20Console%5D(%2Fgroup%3Fid%3Drobot-learning.org%2FCoRL%2F2024%2FConference%2FAuthors%23your-submissions))
7 | - [Paper](https://openreview.net/pdf?id=5lSkn5v4LK)
8 | - [Video](https://youtu.be/fxOveMwugo4?si=L1bmYNOMPbCHY1Cr)
9 | - [Poster](https://drive.google.com/file/d/1UTBoNDDT7FzHcXHSrFDA6x4v5hr3-g51/view?usp=sharing)
10 | 
11 | ![Grasp pose generation](images/generation.gif)
12 | 
13 | 
14 | 
15 | 
16 | ## Requirements
17 | 
18 | ### Conda environment
19 | 
20 | You can create a Conda environment using the following command.
21 | You can customize the environment name by modifying the `name` field in the `environment.yml` file.
22 | 
23 | ```bash
24 | conda env create -f environment.yml
25 | ```
26 | 
27 | This will automatically install the required packages, including:
28 | 
29 | - `python==3.10`
30 | - `omegaconf`
31 | - `tensorboardX`
32 | - `pyyaml`
33 | - `numpy==1.26`
34 | - `torch`
35 | - `scipy`
36 | - `tqdm`
37 | - `h5py`
38 | - `open3d==0.16.0`
39 | - `roma`
40 | - `pandas`
41 | - `openpyxl`
42 | 
43 | To activate the environment, use:
44 | 
45 | ```bash
46 | conda activate equigraspflow
47 | ```
48 | 
49 | 
50 | ### Dataset
51 | 
52 | We use the Laptop, Mug, Bowl, and Pencil categories of the ACRONYM dataset [1].
53 | The dataset can be downloaded from [this link](https://drive.google.com/drive/folders/1H1PeUbyxvNtzoWc6Le2pKqOqp2WLSnau?usp=drive_link).
54 | Create a `dataset` directory and place the data in that directory, or customize the path to the dataset by modifying `DATASET_DIR` in `acronym.py` and `utils.py` within the `loaders` directory.
55 | 
56 | 
57 | ## Training
58 | 
59 | ### Train a new model
60 | 
61 | The training script, `train.py`, accepts the following arguments:
62 | 
63 | - `--config`: Path to the training configuration YAML file.
64 | - `--device`: GPU number to use (default: `0`). Use `cpu` to run on CPU.
65 | - `--logdir`: Directory where the results will be saved (default: `train_results`).
66 | - `--run`: Name for the training session (default: `{date}-{time}`).
67 | 
68 | To train EquiGraspFlow using the full point cloud, run:
69 | 
70 | ```bash
71 | python train.py --config configs/equigraspflow_full.yml
72 | ```
73 | 
74 | Alternatively, to train EquiGraspFlow with the partial point cloud, use:
75 | 
76 | ```bash
77 | python train.py --config configs/equigraspflow_partial.yml
78 | ```
79 | 
80 | Note: Training with the partial point cloud cannot be done in headless mode; a display is required.
81 | 
82 | You can change the data augmentation strategy for each data split by modifying the `augmentation` field in the training configuration YAML file.
83 | 
84 | 
85 | ### View training results
86 | 
87 | We log the results of the training process using TensorBoard. You can view the TensorBoard results by running:
88 | 
89 | ```bash
90 | tensorboard --logdir {path} --host {IP_address}
91 | ```
92 | Replace `{path}` with the path to your training results and `{IP_address}` with your IP address.
93 | 
94 | 
95 | ## Pretrained models
96 | 
97 | The pretrained models can be downloaded from [this link](https://drive.google.com/drive/folders/1H-MXRVcTekdEfzXU_suSw7Afi-7o8I39?usp=sharing).
98 | 
99 | 
100 | ## Test
101 | 
102 | ### Run test
103 | 
104 | The test scripts, `test_full.py` and `test_partial.py`, calculate the Earth Mover's Distance [2] between the generated and ground-truth grasp poses and store the visualizations of the generated grasp poses.
105 | They accept the following arguments:
106 | 
107 | - `--train_result_path`: Path to the directory containing training results.
108 | - `--checkpoint`: Model checkpoint to use.
109 | - `--device`: GPU number to use (default: `0`). Use `cpu` to run on CPU.
110 | - `--logdir`: Directory where the results will be saved (default: `test_results`).
111 | - `--run`: Name for the experiment (default: `{date}-{time}`).
112 | 
113 | For example, to test EquiGraspFlow using the full point cloud with the `model_best_val_loss.pkl` checkpoint in the `pretrained_models/equigraspflow_full` directory, use:
114 | 
115 | ```bash
116 | python test_full.py --train_result_path train_results/equigraspflow_full --checkpoint model_best_val_loss.pkl
117 | ```
118 | 
119 | Alternatively, to test EquiGraspFlow using the partial point cloud with the `model_best_val_loss.pkl` checkpoint in the `pretrained_models/equigraspflow_partial` directory, use:
120 | 
121 | ```bash
122 | python test_partial.py --train_result_path train_results/equigraspflow_partial --checkpoint model_best_val_loss.pkl
123 | ```
124 | 
125 | 
126 | ### Display visualizations
127 | 
128 | The visualizations of the generated grasp poses are stored in `visualizations.json` within the test results directory.
129 | To display these visualizations, use the following code:
130 | 
131 | ```python
132 | import plotly.io as pio
133 | 
134 | pio.from_json(open('{path}/visualizations.json', 'r').read()).show()
135 | ```
136 | 
137 | Replace `{path}` with your test results directory.
138 | 
139 | 
140 | ## References
141 | 
142 | [1] C. Eppner, A. Mousavian, and D. Fox. ACRONYM: A large-scale grasp dataset based on simulation. ICRA 2021. [[paper](https://ieeexplore.ieee.org/stamp/stamp.jsp?arnumber=9560844&casa_token=VAlWdJNx458AAAAA:z3KlV9ALMjYG34RNbCVmUPEPlFkS6b7NIty76glWYuMbn3XwXpTtmrTV2PnmzhrGr_5QN_jQpg&tag=1)]
143 | 
144 | [2] A. Tanaka. Discriminator optimal transport. NeurIPS 2019. [[paper](https://proceedings.neurips.cc/paper/2019/file/8abfe8ac9ec214d68541fcb888c0b4c3-Paper.pdf)]
145 | 
146 | 
147 | ## Citation
148 | If you find this repository useful in your research, please cite:
149 | 
150 | ```text
151 | @inproceedings{lim2024equigraspflow,
152 |   title={EquiGraspFlow: SE(3)-Equivariant 6-DoF Grasp Pose Generative Flows},
153 |   author={Lim, Byeongdo and Kim, Jongmin and Kim, Jihwan and Lee, Yonghyeon and Park, Frank C},
154 |   booktitle={8th Annual Conference on Robot Learning},
155 |   year={2024}
156 | }
157 | ```
158 | 
--------------------------------------------------------------------------------
/configs/equigraspflow_full.yml:
--------------------------------------------------------------------------------
1 | data:
2 |   train:
3 |     dataset:
4 |       name: full
5 |       obj_types: [Laptop, Mug, Bowl, Pencil]
6 |       augmentation: SO3
7 |       scale: 8
8 |     batch_size: 4
9 |     num_workers: 8
10 |   val:
11 |     dataset:
12 |       name: full
13 |       obj_types: [Laptop, Mug, Bowl, Pencil]
14 |       augmentation: SO3
15 |       scale: 8
16 |       num_rots: 3
17 |     batch_size: 4
18 |     num_workers: 8
19 |   test:
20 |     dataset:
21 |       name: full
22 |       obj_types: [Laptop, Mug, Bowl, Pencil]
23 |       augmentation: SO3
24 |       scale: 8
25 |       num_rots: 3
26 |     batch_size: 4
27 |     num_workers: 8
28 | model:
29 |   name: equigraspflow
30 |   p_uncond: 0.2
31 |   guidance: 2.0
32 |   init_dist:
33 |     name: SO3_uniform_R3_normal
34 |   encoder:
35 |     name: vn_dgcnn_enc
36 |     num_neighbors: 40
37 |     dims: [1, 21, 21, 42, 85, 170, 341]
38 |     use_bn: False
39 |   vector_field:
40 |     name: vn_vf
41 |     dims: [346, 256, 256, 128, 128, 128, 2]
42 |     use_bn: False
43 |   ode_solver:
44 |     name: SE3_RK_mk
45 |     num_steps: 20
46 | losses:
47 |   - name: mse
48 | optimizer:
49 |   name: adam
50 |   lr: 0.0001
51 |   weight_decay: 1.0e-6
52 | metrics:
53 |   - name: emd
54 |     type: SE3
55 | trainer:
56 |   name: grasp_full
57 |   criteria:
58 |     - name: emd
59 |       better: lower
60 |   num_epochs: 40000
61 |   print_interval: 100
62 |   val_interval: 10000
63 |   eval_interval: 100000
64 |   vis_interval: 100000
65 | 
save_interval: 2000000 66 | -------------------------------------------------------------------------------- /configs/equigraspflow_partial.yml: -------------------------------------------------------------------------------- 1 | data: 2 | train: 3 | dataset: 4 | name: partial 5 | obj_types: [Laptop, Mug, Bowl, Pencil] 6 | augmentation: SO3 7 | scale: 8 8 | batch_size: 4 9 | num_workers: 8 10 | val: 11 | dataset: 12 | name: partial 13 | obj_types: [Laptop, Mug, Bowl, Pencil] 14 | augmentation: SO3 15 | scale: 8 16 | num_rots: 3 17 | num_views: 3 18 | batch_size: 4 19 | num_workers: 8 20 | test: 21 | dataset: 22 | name: partial 23 | obj_types: [Laptop, Mug, Bowl, Pencil] 24 | augmentation: SO3 25 | scale: 8 26 | num_rots: 3 27 | num_views: 3 28 | batch_size: 4 29 | num_workers: 8 30 | model: 31 | name: equigraspflow 32 | p_uncond: 0.2 33 | guidance: 1.5 34 | init_dist: 35 | name: SO3_uniform_R3_normal 36 | encoder: 37 | name: vn_dgcnn_enc 38 | num_neighbors: 40 39 | dims: [1, 21, 21, 42, 85, 170, 341] 40 | use_bn: False 41 | vector_field: 42 | name: vn_vf 43 | dims: [346, 256, 256, 128, 128, 128, 2] 44 | use_bn: False 45 | ode_solver: 46 | name: SE3_RK_mk 47 | num_steps: 20 48 | losses: 49 | - name: mse 50 | optimizer: 51 | name: adam 52 | lr: 0.0001 53 | weight_decay: 1.0e-6 54 | metrics: 55 | - name: emd 56 | type: SE3 57 | trainer: 58 | name: grasp_partial 59 | criteria: 60 | - name: emd 61 | better: lower 62 | num_epochs: 40000 63 | print_interval: 100 64 | val_interval: 10000 65 | eval_interval: 100000 66 | vis_interval: 100000 67 | save_interval: 2000000 68 | -------------------------------------------------------------------------------- /environment.yml: -------------------------------------------------------------------------------- 1 | name: equigraspflow 2 | channels: 3 | - defaults 4 | dependencies: 5 | - pip=24.2 6 | - python=3.10 7 | - pip: 8 | - omegaconf 9 | - tensorboardX 10 | - pyyaml 11 | - numpy==1.26 12 | - torch 13 | - scipy 14 | - tqdm 15 | - h5py 16 | - open3d==0.16.0 17 | - roma 18 | - pandas 19 | - openpyxl 20 | -------------------------------------------------------------------------------- /images/generation.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bdlim99/EquiGraspFlow/f921425f27d0f80cc250a288b0b9cedbb9f61b41/images/generation.gif -------------------------------------------------------------------------------- /loaders/__init__.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from loaders.acronym import AcronymFullPointCloud, AcronymPartialPointCloud 4 | 5 | 6 | def get_dataloader(split, cfg_dataloader): 7 | cfg_dataloader.dataset.split = split 8 | 9 | dataset = get_dataset(cfg_dataloader.dataset) 10 | 11 | dataloader = torch.utils.data.DataLoader( 12 | dataset, 13 | batch_size=cfg_dataloader.batch_size, 14 | shuffle=cfg_dataloader.get('shuffle', True), 15 | num_workers=cfg_dataloader.get('num_workers', 8), 16 | collate_fn=collate_fn 17 | ) 18 | 19 | return dataloader 20 | 21 | 22 | def get_dataset(cfg_dataset): 23 | name = cfg_dataset.pop('name') 24 | 25 | if name == 'full': 26 | dataset = AcronymFullPointCloud(**cfg_dataset) 27 | elif name == 'partial': 28 | dataset = AcronymPartialPointCloud(**cfg_dataset) 29 | else: 30 | raise NotImplementedError(f"Dataset {name} not implemented.") 31 | 32 | return dataset 33 | 34 | 35 | def collate_fn(batch_original): 36 | batch_collated = {} 37 | 38 | for key in 
batch_original[0].keys(): 39 | if key == 'Ts_grasp': 40 | batch_collated[key] = [sample[key] for sample in batch_original] 41 | else: 42 | batch_collated[key] = torch.stack([sample[key] for sample in batch_original]) 43 | 44 | return batch_collated 45 | -------------------------------------------------------------------------------- /loaders/acronym.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | from scipy.spatial.transform import Rotation 4 | import os 5 | from tqdm import tqdm 6 | import h5py 7 | from copy import deepcopy 8 | 9 | from loaders.utils import load_grasp_poses, load_mesh 10 | from utils.Lie import super_fibonacci_spiral, get_fibonacci_sphere 11 | from utils.partial_point_cloud import get_partial_point_clouds 12 | 13 | 14 | DATASET_DIR = 'dataset' 15 | NUM_GRASPS = 100 16 | 17 | 18 | class AcronymFullPointCloud(torch.utils.data.Dataset): 19 | def __init__(self, split, obj_types, augmentation, scale, num_pts=1024, num_rots=1): 20 | # Initialize 21 | self.len_dataset = 0 22 | self.split = split 23 | self.num_pts = num_pts 24 | self.augmentation = augmentation 25 | self.scale = scale 26 | self.num_rots = num_rots 27 | self.obj_types = obj_types 28 | 29 | # Initialize maximum number of objects 30 | self.max_num_objs = 0 31 | 32 | # Get evenly distributed rotations for validation and test splits 33 | if split in ['val', 'test']: 34 | if augmentation == 'None': 35 | self.Rs = np.expand_dims(np.eye(3), axis=0) 36 | elif augmentation == 'z': 37 | degree = (np.arange(num_rots) / num_rots) * (2 * np.pi) 38 | self.Rs = Rotation.from_rotvec(degree * np.array([0, 0, 1])).as_matrix() 39 | elif augmentation == 'SO3': 40 | self.Rs = super_fibonacci_spiral(num_rots) 41 | else: 42 | raise ValueError("Choose augmentation from ['None', 'z', 'SO3'].") 43 | else: 44 | assert num_rots == 1, "Number of rotations must be 1 in train set." 
45 | 46 | self.Rs = np.expand_dims(np.eye(3), axis=0) 47 | 48 | Ts = np.tile(np.eye(4), (num_rots, 1, 1)) 49 | Ts[:, :3, :3] = self.Rs 50 | 51 | # Initialize data indices 52 | data_idxs_types = [] 53 | 54 | # Initialize lists 55 | self.mesh_list_types = [] 56 | self.Ts_grasp_list_types = [] 57 | 58 | self.pc_list_types = [] 59 | 60 | self.obj_idxs_types = [] 61 | 62 | for obj_type in tqdm(obj_types, desc="Iterating object types ...", leave=False): 63 | # Get data filenames 64 | filenames = sorted(os.listdir(os.path.join(DATASET_DIR, 'grasps', obj_type))) 65 | 66 | # Get object indices for the split 67 | obj_idxs = np.load(os.path.join(DATASET_DIR, 'splits', obj_type, f'idxs_{split}.npy')) 68 | 69 | # Initialize data indices 70 | data_idxs_objs = [] 71 | 72 | # Initialize lists 73 | mesh_list_objs = [] 74 | Ts_grasp_list_objs = [] 75 | 76 | pc_list_objs = [] 77 | 78 | obj_idxs_objs = [] 79 | 80 | for obj_idx in tqdm(obj_idxs, desc="Iterating objects ...", leave=False): 81 | # Get data filename 82 | filename = filenames[obj_idx] 83 | 84 | # Load data 85 | data = h5py.File(os.path.join(DATASET_DIR, 'grasps', obj_type, filename)) 86 | 87 | # Load grasp poses 88 | Ts_grasp = load_grasp_poses(data) 89 | 90 | # Continue if grasp poses are not enough 91 | if len(Ts_grasp) < NUM_GRASPS: 92 | continue 93 | else: 94 | obj_idxs_objs += [obj_idx] 95 | 96 | # Load mesh 97 | mesh = load_mesh(data) 98 | 99 | # Scale 100 | mesh.scale(scale, center=(0, 0, 0)) 101 | Ts_grasp[:, :3, 3] *= scale 102 | 103 | # Sample point cloud 104 | pc = np.asarray(mesh.sample_points_uniformly(num_pts).points).T 105 | 106 | # Translate to the center of the point cloud 107 | center = pc.mean(axis=1) 108 | mesh.translate(-center) 109 | pc -= np.expand_dims(center, axis=1) 110 | Ts_grasp[:, :3, 3] -= center 111 | 112 | # Rotate mesh 113 | mesh_list_rots = [] 114 | 115 | for R in tqdm(self.Rs, desc="Iterating rotations ...", leave=False): 116 | mesh_rot = deepcopy(mesh) 117 | mesh_rot.rotate(R, center=(0, 0, 0)) 118 | 119 | mesh_list_rots += [mesh_rot] 120 | 121 | # Rotate the other data 122 | pc_rots = self.Rs @ pc 123 | Ts_grasp_rots = np.einsum('rij,njk->rnik', Ts, Ts_grasp) 124 | 125 | # Fill data indices 126 | data_idxs_objs += [list(range(self.len_dataset, self.len_dataset+num_rots))] 127 | 128 | # Append data 129 | mesh_list_objs += [mesh_list_rots] 130 | Ts_grasp_list_objs += [Ts_grasp_rots] 131 | 132 | pc_list_objs += [pc_rots] 133 | 134 | # Increase number of data 135 | self.len_dataset += num_rots 136 | 137 | # Append data 138 | data_idxs_types += [data_idxs_objs] 139 | 140 | self.mesh_list_types += [mesh_list_objs] 141 | self.Ts_grasp_list_types += [Ts_grasp_list_objs] 142 | 143 | self.pc_list_types += [pc_list_objs] 144 | 145 | self.obj_idxs_types += [obj_idxs_objs] 146 | 147 | # Update maximum number of objects 148 | if len(obj_idxs_objs) > self.max_num_objs: 149 | self.max_num_objs = len(obj_idxs_objs) 150 | 151 | # Convert data indices from lists to numpy array 152 | self.data_idxs = np.full((len(obj_types), self.max_num_objs, num_rots), -1) 153 | 154 | for i, data_idxs_objs in enumerate(data_idxs_types): 155 | self.data_idxs[i, :len(data_idxs_objs)] = data_idxs_objs 156 | 157 | # Setup scene indices 158 | self.scene_idxs = self.data_idxs 159 | self.num_scenes = self.scene_idxs.max() + 1 160 | 161 | def __len__(self): 162 | return self.len_dataset 163 | 164 | def __getitem__(self, idx): 165 | # Get type, object, and rotation indices 166 | idx_type, idx_obj, idx_rot = np.where(self.data_idxs==idx) 167 | 168 | 
idx_type = idx_type.item() 169 | idx_obj = idx_obj.item() 170 | idx_rot = idx_rot.item() 171 | 172 | # Load grasp poses 173 | Ts_grasp = self.Ts_grasp_list_types[idx_type][idx_obj][idx_rot].copy() 174 | 175 | if self.split == 'train': 176 | # Load mesh 177 | mesh = self.mesh_list_types[idx_type][idx_obj][idx_rot] 178 | 179 | # Sample point cloud 180 | pc = np.asarray(mesh.sample_points_uniformly(self.num_pts).points).T 181 | 182 | # Translate to the point cloud center 183 | center = pc.mean(axis=1) 184 | pc -= np.expand_dims(center, axis=1) 185 | Ts_grasp[:, :3, 3] -= center 186 | 187 | # Rotate data 188 | if self.augmentation != 'None': 189 | if self.augmentation == 'z': 190 | # Randomly rotate around z-axis 191 | degree = np.random.rand() * 2 * np.pi 192 | R = Rotation.from_rotvec(degree * np.array([0, 0, 1])).as_matrix() 193 | elif self.augmentation == 'SO3': 194 | # Randomly rotate 195 | R = Rotation.random().as_matrix() 196 | else: 197 | raise ValueError("Choose augmentation from ['None', 'z', 'SO3'].") 198 | 199 | T = np.eye(4) 200 | T[:3, :3] = R 201 | 202 | pc = R @ pc 203 | Ts_grasp = T @ Ts_grasp 204 | else: 205 | # Load point cloud 206 | pc = self.pc_list_types[idx_type][idx_obj][idx_rot] 207 | 208 | return {'pc': torch.Tensor(pc), 'Ts_grasp': torch.Tensor(Ts_grasp)} 209 | 210 | 211 | class AcronymPartialPointCloud(torch.utils.data.Dataset): 212 | def __init__(self, split, obj_types, augmentation, scale, num_pts=512, num_rots=1, num_views=1): 213 | # Initialize 214 | self.len_dataset = 0 215 | self.split = split 216 | self.num_pts = num_pts 217 | self.augmentation = augmentation 218 | self.scale = scale 219 | self.num_rots = num_rots 220 | self.num_views = num_views 221 | self.obj_types = obj_types 222 | 223 | # Initialize maximum number of objects 224 | self.max_num_objs = 0 225 | 226 | # Get rotations 227 | if split in ['val', 'test']: 228 | if augmentation == 'None': 229 | self.Rs = np.expand_dims(np.eye(3), axis=0) 230 | elif augmentation == 'z': 231 | degree = (np.arange(num_rots) / num_rots) * (2 * np.pi) 232 | self.Rs = Rotation.from_rotvec(degree * np.array([0, 0, 1])).as_matrix() 233 | elif augmentation == 'SO3': 234 | self.Rs = super_fibonacci_spiral(num_rots) 235 | else: 236 | assert augmentation == 'None', "Choose augmentation from ['None', 'z', 'SO3']." 237 | else: 238 | assert num_rots == 1, "Number of rotations must be 1 in train set." 239 | 240 | self.Rs = np.expand_dims(np.eye(3), axis=0) 241 | 242 | Ts = np.tile(np.eye(4), (num_rots, 1, 1)) 243 | Ts[:, :3, :3] = self.Rs 244 | 245 | # Get viewpoint vector 246 | if split in ['val', 'test']: 247 | view_vecs = get_fibonacci_sphere(num_views) 248 | else: 249 | assert num_views == 1, "Number of viewpoint vector must be 1 in train set." 
250 | 251 | view_vecs = np.array([[0, 0, 1]]) 252 | 253 | # Initialize data indices 254 | data_idxs_types = [] 255 | 256 | # Initialize lists 257 | self.mesh_list_types = [] 258 | self.Ts_grasp_list_types = [] 259 | 260 | self.partial_pc_list_types = [] 261 | 262 | self.obj_idxs_types = [] 263 | 264 | for obj_type in tqdm(obj_types, desc="Iterating object types ...", leave=False): 265 | # Get data filenames 266 | filenames = sorted(os.listdir(os.path.join(DATASET_DIR, 'grasps', obj_type))) 267 | 268 | # Get object indices for the split 269 | obj_idxs = np.load(os.path.join(DATASET_DIR, 'splits', obj_type, f'idxs_{split}.npy')) 270 | 271 | # Initialize data indices 272 | data_idxs_objs = [] 273 | 274 | # Initialize lists 275 | mesh_list_objs = [] 276 | Ts_grasp_list_objs = [] 277 | 278 | partial_pc_list_objs = [] 279 | 280 | obj_idxs_objs = [] 281 | 282 | for obj_idx in tqdm(obj_idxs, desc="Iterating objects ...", leave=False): 283 | # Get data filename 284 | filename = filenames[obj_idx] 285 | 286 | # Load data 287 | data = h5py.File(os.path.join(DATASET_DIR, 'grasps', obj_type, filename)) 288 | 289 | # Load grasp poses 290 | Ts_grasp = load_grasp_poses(data) 291 | 292 | # Continue if grasp poses are not enough 293 | if len(Ts_grasp) < NUM_GRASPS: 294 | continue 295 | else: 296 | obj_idxs_objs += [obj_idx] 297 | 298 | # Load mesh 299 | mesh = load_mesh(data) 300 | 301 | # Translate to the center of the mesh 302 | center = mesh.get_center() 303 | mesh.translate(-center) 304 | Ts_grasp[:, :3, 3] -= center 305 | 306 | # Scale 307 | mesh.scale(scale, center=(0, 0, 0)) 308 | Ts_grasp[:, :3, 3] *= scale 309 | 310 | # Initialize data indices 311 | data_idxs_rots = [] 312 | 313 | # Initialize lists 314 | mesh_list_rots = [] 315 | partial_pc_list_rots = [] 316 | 317 | for R in tqdm(self.Rs, desc="Iterating rotations ...", leave=False): 318 | # Rotate mesh 319 | mesh_rot = deepcopy(mesh) 320 | mesh_rot.rotate(R, center=(0, 0, 0)) 321 | 322 | # Sample partial point clouds 323 | partial_pc_views = get_partial_point_clouds(mesh_rot, view_vecs, num_pts, use_tqdm=True).transpose(0, 2, 1) 324 | 325 | # Initialize mesh list 326 | mesh_list_views = [] 327 | 328 | for partial_pc in partial_pc_views: 329 | # Translate mesh to the center of the partial point cloud 330 | mesh_view = deepcopy(mesh_rot) 331 | mesh_view.translate(-partial_pc.mean(axis=1)) 332 | 333 | # Append mesh 334 | mesh_list_views += [mesh_view] 335 | 336 | # Fill data indices 337 | data_idxs_rots += [list(range(self.len_dataset, self.len_dataset+num_views))] 338 | 339 | # Append data 340 | mesh_list_rots += [mesh_list_views] 341 | partial_pc_list_rots += [partial_pc_views] 342 | 343 | # Increase number of data 344 | self.len_dataset += num_views 345 | 346 | # Stack partial point clouds 347 | partial_pc_rots = np.stack(partial_pc_list_rots) 348 | 349 | # Rotate grasp poses 350 | Ts_grasp_rots = np.einsum('rij,njk->rnik', Ts, Ts_grasp) 351 | 352 | # Translate to the center of the partial point clouds 353 | center_rots = partial_pc_rots.mean(axis=3) 354 | 355 | Ts_grasp_rots = np.expand_dims(Ts_grasp_rots, axis=1).repeat(num_views, axis=1) 356 | 357 | partial_pc_rots -= np.expand_dims(center_rots, axis=3) 358 | Ts_grasp_rots[:, :, :, :3, 3] -= np.expand_dims(center_rots, axis=2) 359 | 360 | # Append data 361 | data_idxs_objs += [data_idxs_rots] 362 | 363 | mesh_list_objs += [mesh_list_rots] 364 | Ts_grasp_list_objs += [Ts_grasp_rots] 365 | 366 | partial_pc_list_objs += [partial_pc_rots] 367 | 368 | # Append data 369 | data_idxs_types += 
[data_idxs_objs] 370 | 371 | self.mesh_list_types += [mesh_list_objs] 372 | self.Ts_grasp_list_types += [Ts_grasp_list_objs] 373 | 374 | self.partial_pc_list_types += [partial_pc_list_objs] 375 | 376 | self.obj_idxs_types += [obj_idxs_objs] 377 | 378 | # Update maximum number of objects 379 | if len(obj_idxs_objs) > self.max_num_objs: 380 | self.max_num_objs = len(obj_idxs_objs) 381 | 382 | # Convert data indices from lists to numpy array 383 | self.data_idxs = np.full((len(obj_types), self.max_num_objs, num_rots, num_views), -1) 384 | 385 | for i, data_idxs_objs in enumerate(data_idxs_types): 386 | self.data_idxs[i, :len(data_idxs_objs)] = data_idxs_objs 387 | 388 | # Setup scene indices 389 | self.scene_idxs = self.data_idxs 390 | self.num_scenes = self.scene_idxs.max() + 1 391 | 392 | def __len__(self): 393 | return self.len_dataset 394 | 395 | def __getitem__(self, idx): 396 | # Get type, object, and rotation indices 397 | idx_type, idx_obj, idx_rot, idx_view = np.where(self.data_idxs==idx) 398 | 399 | idx_type = idx_type.item() 400 | idx_obj = idx_obj.item() 401 | idx_rot = idx_rot.item() 402 | idx_view = idx_view.item() 403 | 404 | # Load grasp poses 405 | Ts_grasp = self.Ts_grasp_list_types[idx_type][idx_obj][idx_rot][idx_view].copy() 406 | 407 | if self.split == 'train': 408 | # Load mesh 409 | mesh = deepcopy(self.mesh_list_types[idx_type][idx_obj][idx_rot][idx_view]) 410 | 411 | # Get random rotation 412 | if self.augmentation == 'None': 413 | R = np.eye(3) 414 | elif self.augmentation == 'z': 415 | degree = np.random.rand() * 2 * np.pi 416 | R = Rotation.from_rotvec(degree * np.array([0, 0, 1])).as_matrix() 417 | elif self.augmentation == 'SO3': 418 | R = Rotation.random().as_matrix() 419 | else: 420 | raise ValueError("Choose augmentation from ['None', 'z', 'SO3'].") 421 | 422 | T = np.eye(4) 423 | T[:3, :3] = R 424 | 425 | # Rotate mesh 426 | mesh.rotate(R, center=(0, 0, 0)) 427 | 428 | # Sample partial point cloud 429 | while True: 430 | try: 431 | view_vecs = -1 + 2 * np.random.rand(1, 3) 432 | view_vecs = view_vecs / np.linalg.norm(view_vecs) 433 | 434 | partial_pc = get_partial_point_clouds(mesh, view_vecs, self.num_pts)[0].T 435 | 436 | break 437 | except: 438 | pass 439 | 440 | # Rotate grasp poses 441 | Ts_grasp = T @ Ts_grasp 442 | 443 | # Translate to the center of the partial point cloud 444 | center = partial_pc.mean(axis=1) 445 | partial_pc -= np.expand_dims(center, axis=1) 446 | Ts_grasp[:, :3, 3] -= center 447 | else: 448 | # Load point cloud 449 | partial_pc = self.partial_pc_list_types[idx_type][idx_obj][idx_rot][idx_view] 450 | 451 | return {'pc': torch.Tensor(partial_pc), 'Ts_grasp': torch.Tensor(Ts_grasp)} 452 | -------------------------------------------------------------------------------- /loaders/utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import open3d as o3d 3 | import pickle 4 | import numpy as np 5 | 6 | 7 | DATASET_DIR = 'dataset' 8 | 9 | 10 | def load_grasp_poses(data): 11 | grasps = data['grasps/transforms'][()] 12 | success = data['grasps/qualities/flex/object_in_gripper'][()] 13 | 14 | grasps_good = grasps[success==1] 15 | 16 | return grasps_good 17 | 18 | 19 | def load_mesh(data): 20 | mesh_path = data['object/file'][()].decode('utf-8') 21 | mesh_scale = data['object/scale'][()] 22 | 23 | mesh = o3d.io.read_triangle_mesh(os.path.join(DATASET_DIR, mesh_path)) 24 | 25 | mesh.scale(mesh_scale, center=(0, 0, 0)) 26 | 27 | return mesh 28 | 
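A minimal usage sketch (editorial addition, not a file in the repository) showing how the two helpers above can be combined to inspect a single ACRONYM grasp file; the filename below is a hypothetical placeholder, and the `dataset` layout follows the README's Dataset section:

```python
# Hypothetical example: load one ACRONYM grasp file with the loaders/utils.py helpers.
# Replace the filename with an actual file under dataset/grasps/<obj_type>/.
import os

import h5py

from loaders.utils import load_grasp_poses, load_mesh

path = os.path.join('dataset', 'grasps', 'Mug', 'example_mug.h5')  # hypothetical filename

with h5py.File(path, 'r') as data:
    Ts_grasp = load_grasp_poses(data)  # (N, 4, 4) poses of grasps labeled successful
    mesh = load_mesh(data)             # open3d TriangleMesh, scaled by 'object/scale'

print(Ts_grasp.shape, mesh.get_center())
```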
-------------------------------------------------------------------------------- /losses/__init__.py: -------------------------------------------------------------------------------- 1 | from losses.mse_loss import MSELoss 2 | 3 | 4 | def get_losses(cfg_losses): 5 | losses = {} 6 | 7 | for cfg_loss in cfg_losses: 8 | name = cfg_loss.name 9 | 10 | losses[name] = get_loss(cfg_loss) 11 | 12 | return losses 13 | 14 | 15 | def get_loss(cfg_loss): 16 | name = cfg_loss.pop('name') 17 | 18 | if name == 'mse': 19 | loss = MSELoss(**cfg_loss) 20 | else: 21 | raise NotImplementedError(f"Loss {name} is not implemented.") 22 | 23 | return loss 24 | -------------------------------------------------------------------------------- /losses/mse_loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | class MSELoss: 5 | def __init__(self, weight=1, reduction='mean'): 6 | self.weight = weight 7 | 8 | self.mse_loss = torch.nn.MSELoss(reduction=reduction) 9 | 10 | def __call__(self, pred, target): 11 | loss = self.mse_loss(pred, target) 12 | 13 | return loss 14 | -------------------------------------------------------------------------------- /metrics/__init__.py: -------------------------------------------------------------------------------- 1 | from metrics.emd import EMDCalculator 2 | 3 | 4 | def get_metrics(cfg_metrics): 5 | metrics = {} 6 | 7 | for cfg_metric in cfg_metrics: 8 | name = cfg_metric.name 9 | 10 | metrics[name] = get_metric(cfg_metric) 11 | 12 | return metrics 13 | 14 | 15 | def get_metric(cfg_metric): 16 | name = cfg_metric.pop('name') 17 | 18 | if name == 'emd': 19 | metric = EMDCalculator(**cfg_metric) 20 | else: 21 | raise NotImplementedError(f"Metric {name} is not implemented.") 22 | 23 | return metric 24 | -------------------------------------------------------------------------------- /metrics/emd.py: -------------------------------------------------------------------------------- 1 | from scipy.optimize import linear_sum_assignment 2 | 3 | from utils.Lie import SE3_geodesic_dist 4 | 5 | 6 | class EMDCalculator: 7 | def __init__(self, type): 8 | self.type = type 9 | 10 | def calculate_distance(self, x, y): 11 | if self.type == 'SE3': 12 | T_x = x.view(-1, 4, 4) 13 | T_y = y.view(-1, 4, 4) 14 | 15 | return SE3_geodesic_dist(T_x, T_y).view(x.shape[:2]) 16 | elif self.type == 'L2': 17 | return ((x - y) ** 2).sum(dim=3).sqrt() 18 | else: 19 | raise NotImplementedError(f"Type {self.type} is not implemented. Choose type between 'SE3' and 'L2'.") 20 | 21 | def __call__(self, source, target): 22 | assert len(source) == len(target), f"The number of samples in source {len(source)} must be equal to the number of samples in target {len(target)}." 
23 | 24 | source = source.unsqueeze(1).repeat(1, len(target), 1, 1) 25 | target = target.unsqueeze(0).repeat(len(source), 1, 1, 1) 26 | 27 | distance = self.calculate_distance(source, target).cpu().numpy() 28 | 29 | idxs_row, idxs_col = linear_sum_assignment(distance) 30 | 31 | emd = distance[idxs_row, idxs_col].mean() 32 | 33 | return emd -------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from utils.distributions import get_dist 4 | from models.vn_dgcnn import VNDGCNNEncoder 5 | from models.vn_vector_fields import VNVectorFields 6 | from utils.ode_solvers import get_ode_solver 7 | from models.equi_grasp_flow import EquiGraspFlow 8 | 9 | 10 | def get_model(cfg_model): 11 | name = cfg_model.pop('name') 12 | checkpoint = cfg_model.get('checkpoint', None) 13 | 14 | if name == 'equigraspflow': 15 | model = _get_equigraspflow(cfg_model) 16 | else: 17 | raise NotImplementedError(f"Model {name} is not implemented.") 18 | 19 | if checkpoint is not None: 20 | checkpoint = torch.load(checkpoint, map_location='cpu') 21 | 22 | if 'model_state' in checkpoint: 23 | model.load_state_dict(checkpoint['model_state']) 24 | 25 | return model 26 | 27 | 28 | def _get_equigraspflow(cfg): 29 | p_uncond = cfg.pop('p_uncond') 30 | guidance = cfg.pop('guidance') 31 | 32 | init_dist = get_dist(cfg.pop('init_dist')) 33 | encoder = get_net(cfg.pop('encoder')) 34 | vector_field = get_net(cfg.pop('vector_field')) 35 | ode_solver = get_ode_solver(cfg.pop('ode_solver')) 36 | 37 | model = EquiGraspFlow(p_uncond, guidance, init_dist, encoder, vector_field, ode_solver) 38 | 39 | return model 40 | 41 | 42 | def get_net(cfg_net): 43 | name = cfg_net.pop('name') 44 | 45 | if name == 'vn_dgcnn_enc': 46 | net = _get_vn_dgcnn_enc(cfg_net) 47 | elif name == 'vn_vf': 48 | net = _get_vn_vf(cfg_net) 49 | else: 50 | raise NotImplementedError(f"Network {name} is not implemented.") 51 | 52 | return net 53 | 54 | 55 | def _get_vn_dgcnn_enc(cfg): 56 | net = VNDGCNNEncoder(**cfg) 57 | 58 | return net 59 | 60 | 61 | def _get_vn_vf(cfg): 62 | net = VNVectorFields(**cfg) 63 | 64 | return net 65 | -------------------------------------------------------------------------------- /models/equi_grasp_flow.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from copy import deepcopy 3 | 4 | from utils.Lie import inv_SO3, log_SO3, exp_so3, bracket_so3 5 | 6 | 7 | class EquiGraspFlow(torch.nn.Module): 8 | def __init__(self, p_uncond, guidance, init_dist, encoder, vector_field, ode_solver): 9 | super().__init__() 10 | 11 | self.p_uncond = p_uncond 12 | self.guidance = guidance 13 | 14 | self.init_dist = init_dist 15 | self.encoder = encoder 16 | self.vector_field = vector_field 17 | self.ode_solver = ode_solver 18 | 19 | def step(self, data, losses, split, optimizer=None): 20 | # Get data 21 | pc = data['pc'] 22 | x_1 = data['Ts_grasp'] 23 | 24 | # Get number of grasp poses in each batch and combine batched data 25 | nums_grasps = torch.tensor([len(Ts_grasp) for Ts_grasp in x_1], device=data['pc'].device) 26 | 27 | x_1 = torch.cat(x_1, dim=0) 28 | 29 | # Sample t and x_0 30 | t = torch.rand(len(x_1), 1).to(x_1.device) 31 | x_0 = self.init_dist(len(x_1), x_1.device) 32 | 33 | # Get x_t and u_t 34 | x_t, u_t = get_traj(x_0, x_1, t) 35 | 36 | # Forward 37 | v_t = self(pc, t, x_t, nums_grasps) 38 | 39 | # Calculate loss 40 | loss_mse = 
losses['mse'](v_t, u_t) 41 | 42 | loss = losses['mse'].weight * loss_mse 43 | 44 | # Backward 45 | if optimizer is not None: 46 | loss.backward() 47 | optimizer.step() 48 | 49 | # Archive results 50 | results = { 51 | f'scalar/{split}/loss': loss.item(), 52 | } 53 | 54 | return results 55 | 56 | def forward(self, pc, t, x_t, nums_grasps): 57 | z = torch.zeros((len(pc), self.encoder.dims[-1], 3), device=pc.device) 58 | 59 | # Encode point cloud 60 | z = self.encoder(pc) 61 | 62 | # Repeat feature 63 | z = z.repeat_interleave(nums_grasps, dim=0) 64 | 65 | # Null condition 66 | mask_uncond = torch.bernoulli(torch.Tensor([self.p_uncond] * len(z))).to(bool) 67 | 68 | z[mask_uncond] = torch.zeros_like(z[mask_uncond]) 69 | 70 | # Get vector 71 | v_t = self.vector_field(z, t, x_t) 72 | 73 | return v_t 74 | 75 | @torch.no_grad() 76 | def sample(self, pc, nums_grasps): 77 | # Sample initial samples 78 | x_0 = self.init_dist(sum(nums_grasps), pc.device) 79 | self.X0SAMPLED = deepcopy(x_0) 80 | 81 | # Encode point cloud 82 | z = self.encoder(pc) 83 | 84 | # Repeat feature 85 | z = z.repeat_interleave(nums_grasps, dim=0) 86 | 87 | # Push-forward initial samples 88 | x_1_hat = self.ode_solver(z, x_0, self.guided_vector_field)[:, -1] 89 | 90 | # Batch x_1_hat 91 | x_1_hat = x_1_hat.split(nums_grasps.tolist()) 92 | 93 | return x_1_hat 94 | 95 | def guided_vector_field(self, z, t, x_t): 96 | v_t = (1 - self.guidance) * self.vector_field(torch.zeros_like(z), t, x_t) + self.guidance * self.vector_field(z, t, x_t) 97 | 98 | return v_t 99 | 100 | 101 | def get_traj(x_0, x_1, t): 102 | # Get rotations 103 | R_0 = x_0[:, :3, :3] 104 | R_1 = x_1[:, :3, :3] 105 | 106 | # Get translations 107 | p_0 = x_0[:, :3, 3] 108 | p_1 = x_1[:, :3, 3] 109 | 110 | # Get x_t 111 | x_t = torch.eye(4).repeat(len(x_1), 1, 1).to(x_1) 112 | x_t[:, :3, :3] = (R_0 @ exp_so3(t.unsqueeze(2) * log_SO3(inv_SO3(R_0) @ R_1))) 113 | x_t[:, :3, 3] = p_0 + t * (p_1 - p_0) 114 | 115 | # Get u_t 116 | u_t = torch.zeros(len(x_1), 6).to(x_1) 117 | u_t[:, :3] = bracket_so3(log_SO3(inv_SO3(R_0) @ R_1)) 118 | u_t[:, :3] = torch.einsum('bij,bj->bi', R_0, u_t[:, :3]) # Convert w_b to w_s 119 | u_t[:, 3:] = p_1 - p_0 120 | 121 | return x_t, u_t 122 | -------------------------------------------------------------------------------- /models/vn_dgcnn.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from models.vn_layers import VNLinearLeakyReLU, knn 4 | 5 | 6 | class VNDGCNNEncoder(torch.nn.Module): 7 | def __init__(self, num_neighbors, dims=[1, 21, 21, 42, 85, 341], use_bn=False): 8 | super().__init__() 9 | 10 | self.num_neighbors = num_neighbors 11 | self.dims = dims 12 | 13 | layers = [] 14 | 15 | for dim_in, dim_out in zip(dims[:-2], dims[1:-1]): 16 | layers += [VNLinearLeakyReLU(2 * dim_in, dim_out, use_bn=use_bn)] 17 | 18 | layers += [VNLinearLeakyReLU(sum(dims[1:-1]), dims[-1], dim=4, share_nonlinearity=True, use_bn=use_bn)] 19 | 20 | self.layers = torch.nn.ModuleList(layers) 21 | 22 | def forward(self, x): 23 | x = x.unsqueeze(1) 24 | 25 | x_list = [] 26 | 27 | for layer in self.layers[:-1]: 28 | x = get_graph_feature(x, k=self.num_neighbors) 29 | x = layer(x) 30 | x = x.mean(dim=-1) 31 | 32 | x_list += [x] 33 | 34 | x = torch.cat(x_list, dim=1) 35 | 36 | x = self.layers[-1](x) 37 | x = x.mean(dim=-1) 38 | 39 | return x 40 | 41 | 42 | def get_graph_feature(x, k=20): 43 | batch_size = x.shape[0] 44 | num_pts = x.shape[3] 45 | 46 | x = x.view(batch_size, -1, num_pts) 47 | 48 | idx = knn(x, 
k=k) 49 | idx_base = torch.arange(0, batch_size, device=idx.device).unsqueeze(1).unsqueeze(2) * num_pts 50 | idx = idx + idx_base 51 | idx = idx.view(-1) 52 | 53 | num_dims = x.shape[1] // 3 54 | 55 | x = x.transpose(2, 1).contiguous() 56 | feature = x.view(batch_size*num_pts, -1)[idx] 57 | feature = feature.view(batch_size, num_pts, k, num_dims, 3) 58 | x = x.view(batch_size, num_pts, 1, num_dims, 3).repeat(1, 1, k, 1, 1) 59 | 60 | feature = torch.cat((feature-x, x), dim=3).permute(0, 3, 4, 1, 2).contiguous() 61 | 62 | return feature -------------------------------------------------------------------------------- /models/vn_layers.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | EPS = 1e-6 5 | 6 | 7 | class VNLinear(torch.nn.Module): 8 | def __init__(self, in_channels, out_channels): 9 | super().__init__() 10 | 11 | self.map_to_feat = torch.nn.Linear(in_channels, out_channels, bias=False) 12 | 13 | def forward(self, x): 14 | ''' 15 | x: point features of shape [B, N_feat, 3, N_samples, ...] 16 | ''' 17 | x_out = self.map_to_feat(x.transpose(1,-1)).transpose(1,-1) 18 | 19 | return x_out 20 | 21 | 22 | class VNBatchNorm(torch.nn.Module): 23 | def __init__(self, num_features, dim): 24 | super().__init__() 25 | 26 | if dim == 3 or dim == 4: 27 | self.bn = torch.nn.BatchNorm1d(num_features) 28 | elif dim == 5: 29 | self.bn = torch.nn.BatchNorm2d(num_features) 30 | 31 | def forward(self, x): 32 | ''' 33 | x: point features of shape [B, N_feat, 3, N_samples, ...] 34 | ''' 35 | # norm = torch.sqrt((x*x).sum(2)) 36 | norm = torch.norm(x, dim=2) + EPS 37 | norm_bn = self.bn(norm) 38 | norm = norm.unsqueeze(2) 39 | norm_bn = norm_bn.unsqueeze(2) 40 | x = x / norm * norm_bn 41 | 42 | return x 43 | 44 | 45 | class VNLinearLeakyReLU(torch.nn.Module): 46 | def __init__(self, in_channels, out_channels, dim=5, share_nonlinearity=False, use_bn=True, negative_slope=0.2): 47 | super().__init__() 48 | 49 | self.negative_slope = negative_slope 50 | self.use_bn = use_bn 51 | 52 | # Linear 53 | self.map_to_feat = torch.nn.Linear(in_channels, out_channels, bias=False) 54 | 55 | # BatchNorm 56 | if use_bn: 57 | self.bn = VNBatchNorm(out_channels, dim=dim) 58 | 59 | # LeakyReLU 60 | if share_nonlinearity: 61 | self.map_to_dir = torch.nn.Linear(in_channels, 1, bias=False) 62 | else: 63 | self.map_to_dir = torch.nn.Linear(in_channels, out_channels, bias=False) 64 | 65 | def forward(self, x): 66 | ''' 67 | x: point features of shape [B, N_feat, 3, N_samples, ...] 
68 | ''' 69 | # Linear 70 | p = self.map_to_feat(x.transpose(1,-1)).transpose(1,-1) 71 | 72 | # BatchNorm 73 | if self.use_bn: 74 | p = self.bn(p) 75 | 76 | # LeakyReLU 77 | d = self.map_to_dir(x.transpose(1,-1)).transpose(1,-1) 78 | dotprod = (p*d).sum(2, keepdims=True) 79 | mask = (dotprod >= 0).float() 80 | d_norm_sq = (d*d).sum(2, keepdims=True) 81 | x_out = self.negative_slope * p + (1-self.negative_slope) * (mask*p + (1-mask)*(p-(dotprod/(d_norm_sq+EPS))*d)) 82 | 83 | return x_out 84 | 85 | 86 | def knn(x, k): 87 | pairwise_distance = (x.unsqueeze(-1) - x.unsqueeze(-2)).norm(dim=1) ** 2 88 | 89 | idx = pairwise_distance.topk(k, dim=-1, largest=False)[1] # (batch_size, num_pts, k) 90 | 91 | return idx -------------------------------------------------------------------------------- /models/vn_vector_fields.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from models.vn_layers import VNLinearLeakyReLU, VNLinear 4 | 5 | 6 | class VNVectorFields(torch.nn.Module): 7 | def __init__(self, dims, use_bn): 8 | super().__init__() 9 | 10 | # Setup lifting layer 11 | self.lifting_layer = VNLinear(dims[0] - 1, 1) 12 | 13 | # Setup VN-MLP 14 | layers = [] 15 | 16 | for i in range(len(dims)-2): 17 | layers += [VNLinearLeakyReLU(dims[i], dims[i+1], dim=4, use_bn=use_bn)] 18 | 19 | layers += [VNLinear(dims[-2], dims[-1])] 20 | 21 | self.layers = torch.nn.Sequential(*layers) 22 | 23 | def forward(self, z, t, x_t): 24 | # Construct scalar-list and vector-list 25 | s = t.unsqueeze(1) 26 | v = torch.cat((z, x_t[:, :3].transpose(1, 2)), dim=1) 27 | 28 | # Lift scalar-list to vector-list 29 | trans = self.lifting_layer(v) 30 | v_s = s @ trans 31 | 32 | # Concatenate 33 | v = torch.cat((v, v_s), dim=1) 34 | 35 | # Forward VN-MLP 36 | out = self.layers(v).contiguous() 37 | 38 | # Convert two 3-dim vectors to one 6-dim vector 39 | out = out.view(-1, 6) 40 | 41 | return out 42 | -------------------------------------------------------------------------------- /test_full.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | from datetime import datetime 3 | import os 4 | from omegaconf import OmegaConf 5 | import logging 6 | import yaml 7 | import random 8 | import numpy as np 9 | import torch 10 | import pandas as pd 11 | import plotly.graph_objects as go 12 | 13 | from loaders import get_dataloader 14 | from models import get_model 15 | from metrics import get_metrics 16 | from utils.visualization import PlotlySubplotsVisualizer 17 | 18 | 19 | NUM_GRASPS = 100 20 | 21 | 22 | def main(args, cfg): 23 | seed = cfg.get('seed', 1) 24 | random.seed(seed) 25 | np.random.seed(seed) 26 | torch.manual_seed(seed) 27 | torch.cuda.manual_seed(seed) 28 | torch.set_num_threads(8) 29 | torch.backends.cudnn.deterministic = True 30 | 31 | # Setup testloader 32 | test_loader = get_dataloader('test', cfg.data.test) 33 | 34 | # Setup model 35 | model = get_model(cfg.model).to(cfg.device) 36 | 37 | # Setup metrics 38 | metrics = get_metrics(cfg.metrics) 39 | 40 | # Setup plotly visualizer 41 | visualizer = PlotlySubplotsVisualizer(rows=1, cols=test_loader.dataset.num_rots) 42 | 43 | # Start test 44 | results = test(model, test_loader, metrics, cfg.device, visualizer) 45 | 46 | # Print results 47 | print_results(test_loader, results) 48 | 49 | # Write xlsx 50 | log_results(args.logdir, test_loader, results) 51 | 52 | # Save plotly figure 53 | save_figure(args.logdir, visualizer) 54 | 55 | 56 | def test(model, 
test_loader, metrics, device, visualizer): 57 | # Initialize 58 | model.eval() 59 | 60 | # Get dataset 61 | obj_types = test_loader.dataset.obj_types 62 | obj_idxs_types = test_loader.dataset.obj_idxs_types 63 | pc_list_types = test_loader.dataset.pc_list_types 64 | mesh_list_types = test_loader.dataset.mesh_list_types 65 | Ts_grasp_list_types = test_loader.dataset.Ts_grasp_list_types 66 | 67 | # Get scale, maximum number of objects and number of rotations 68 | scale = test_loader.dataset.scale 69 | max_num_objs = test_loader.dataset.max_num_objs 70 | num_rots = test_loader.dataset.num_rots 71 | 72 | # Setup metric result arrays 73 | results = {key: np.full((len(pc_list_types), max_num_objs, num_rots), np.nan) for key in list(metrics.keys())} 74 | 75 | # Setup labels for button in plotly figure 76 | visualizer.labels = [] 77 | 78 | # Iterate 79 | for i, (obj_type, obj_idxs_objs, pc_list_objs, mesh_list_objs, Ts_grasp_list_objs) in enumerate(zip(obj_types, obj_idxs_types, pc_list_types, mesh_list_types, Ts_grasp_list_types)): 80 | for j, (obj_idx, pc_rots, mesh_list_rots, Ts_grasp_rots_target) in enumerate(zip(obj_idxs_objs, pc_list_objs, mesh_list_objs, Ts_grasp_list_objs)): 81 | # Setup input 82 | pc_rots = torch.Tensor(pc_rots).to(device) 83 | Ts_grasp_rots_target = torch.Tensor(Ts_grasp_rots_target).to(device) 84 | nums_grasps = torch.tensor([len(Ts_grasp_target) for Ts_grasp_target in Ts_grasp_rots_target], device=pc_rots.device) 85 | 86 | # Sample grasp poses 87 | Ts_grasp_rots_pred = model.sample(pc_rots, nums_grasps) 88 | 89 | # Compute metrics 90 | for k, (mesh, Ts_grasp_pred, Ts_grasp_target) in enumerate(zip(mesh_list_rots, Ts_grasp_rots_pred, Ts_grasp_rots_target)): 91 | # Setup message 92 | msg = f"object type: {obj_type}, object index: {obj_idx}, rotation index: {k}, " 93 | 94 | # Rescale mesh and grasp poses 95 | mesh.scale(1/scale, center=(0, 0, 0)) 96 | Ts_grasp_pred[:, :3, 3] /= scale 97 | Ts_grasp_target[:, :3, 3] /= scale 98 | 99 | for key, metric in metrics.items(): 100 | # Compute metrics 101 | result = metric(Ts_grasp_pred, Ts_grasp_target) 102 | 103 | # Add result to message 104 | msg += f"{key}: {result:.4f}, " 105 | 106 | # Fill array 107 | results[key][i, j, k] = result 108 | 109 | # Print result message 110 | print(msg) 111 | logging.info(msg) 112 | 113 | # Get indices for sampling grasp poses for visualization 114 | idxs = torch.randperm(len(Ts_grasp_pred))[:NUM_GRASPS] 115 | 116 | # Add mesh and gripper to visualizer 117 | visualizer.add_mesh(mesh, row=1, col=k+1) 118 | visualizer.add_grippers(Ts_grasp_pred[idxs], color='grey', row=1, col=k+1) 119 | 120 | visualizer.labels += [f'{obj_type}_{obj_idx}'] 121 | 122 | return results 123 | 124 | 125 | def print_results(test_loader, results): 126 | # Get object types and object ids 127 | obj_types = test_loader.dataset.obj_types 128 | 129 | # Print results 130 | for idx_type, obj_type in enumerate(obj_types): 131 | msg = f"object type: {obj_type}" 132 | 133 | for key in results.keys(): 134 | msg += f", {key}: {np.nanmean(results[key][idx_type]):.4f}" 135 | 136 | print(msg) 137 | logging.info(msg) 138 | 139 | 140 | def log_results(logdir, test_loader, results): 141 | # Get object types and object ids 142 | obj_types = test_loader.dataset.obj_types 143 | obj_idxs_types = test_loader.dataset.obj_idxs_types 144 | 145 | # Write xlsx 146 | for key, result in results.items(): 147 | with pd.ExcelWriter(os.path.join(logdir, f'{key}.xlsx')) as w: 148 | for obj_type, obj_idxs_objs, result_type in zip(obj_types, obj_idxs_types, 
result): 149 | df = pd.DataFrame(result_type[:len(obj_idxs_objs)], index=obj_idxs_objs) 150 | df.to_excel(w, sheet_name=obj_type) 151 | 152 | 153 | def save_figure(logdir, visualizer): 154 | # Get number of traces and number of subplots 155 | num_traces = len(visualizer.fig.data) 156 | num_subplots = visualizer.num_subplots 157 | 158 | # Make only the first scene visible 159 | for idx_trace in range(num_subplots*(1+NUM_GRASPS), num_traces): 160 | visualizer.fig.update_traces(visible=False, selector=idx_trace) 161 | 162 | # Make buttons list 163 | buttons = [] 164 | 165 | for idx_scene, label in enumerate(visualizer.labels): 166 | # Initialize visibility list 167 | visibility = num_traces * [False] 168 | 169 | # Make only the selected scene visible 170 | for idx_trace in range(num_subplots*(1+NUM_GRASPS)*idx_scene, num_subplots*(1+NUM_GRASPS)*(idx_scene+1)): 171 | visibility[idx_trace] = True 172 | 173 | # Make and append button 174 | button = dict(label=label, method='restyle', args=[{'visible': visibility}]) 175 | 176 | buttons += [button] 177 | 178 | # Update buttons 179 | visualizer.fig.update_layout(updatemenus=[go.layout.Updatemenu(active=0, buttons=buttons)]) 180 | 181 | # Save figure 182 | visualizer.fig.write_json(os.path.join(logdir, 'visualizations.json')) 183 | 184 | 185 | if __name__ == '__main__': 186 | # Parse arguments 187 | parser = argparse.ArgumentParser() 188 | 189 | parser.add_argument('--train_result_path', type=str) 190 | parser.add_argument('--checkpoint', type=str) 191 | parser.add_argument('--device', default=0) 192 | parser.add_argument('--logdir', default='test_results') 193 | parser.add_argument('--run', type=str, default=datetime.now().strftime('%Y%m%d-%H%M')) 194 | 195 | args = parser.parse_args() 196 | 197 | # Load config 198 | config_filename = [file for file in os.listdir(args.train_result_path) if file.endswith('.yml')][0] 199 | 200 | cfg = OmegaConf.load(os.path.join(args.train_result_path, config_filename)) 201 | 202 | # Setup checkpoint 203 | cfg.model.checkpoint = os.path.join(args.train_result_path, args.checkpoint) 204 | 205 | # Setup device 206 | if args.device == 'cpu': 207 | cfg.device = 'cpu' 208 | else: 209 | cfg.device = f'cuda:{args.device}' 210 | 211 | # Setup logdir 212 | config_basename = os.path.splitext(config_filename)[0] 213 | 214 | args.logdir = os.path.join(args.logdir, config_basename, args.run) 215 | 216 | os.makedirs(args.logdir, exist_ok=True) 217 | 218 | # Setup logging 219 | logging.basicConfig( 220 | filename=os.path.join(args.logdir, 'logging.log'), 221 | format='%(asctime)s [%(levelname)s] %(message)s', 222 | datefmt='%Y/%m/%d %I:%M:%S %p', 223 | level=logging.DEBUG 224 | ) 225 | 226 | # Print result directory 227 | print(f"Result directory: {args.logdir}") 228 | logging.info(f"Result directory: {args.logdir}") 229 | 230 | # Save config 231 | config_path = os.path.join(args.logdir, config_filename) 232 | yaml.dump(yaml.safe_load(OmegaConf.to_yaml(cfg)), open(config_path, 'w')) 233 | 234 | print(f"Config saved as {config_path}") 235 | logging.info(f"Config saved as {config_path}") 236 | 237 | main(args, cfg) 238 | -------------------------------------------------------------------------------- /test_partial.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | from datetime import datetime 3 | import os 4 | from omegaconf import OmegaConf 5 | import logging 6 | import yaml 7 | import random 8 | import numpy as np 9 | import torch 10 | import plotly.graph_objects as go 11 | 
12 | from loaders import get_dataloader 13 | from models import get_model 14 | from metrics import get_metrics 15 | from utils.visualization import PlotlySubplotsVisualizer 16 | 17 | 18 | NUM_GRASPS = 25 19 | 20 | 21 | def main(args, cfg): 22 | seed = cfg.get('seed', 1) 23 | random.seed(seed) 24 | np.random.seed(seed) 25 | torch.manual_seed(seed) 26 | torch.cuda.manual_seed(seed) 27 | torch.set_num_threads(8) 28 | torch.backends.cudnn.deterministic = True 29 | 30 | # Setup testloader 31 | test_loader = get_dataloader('test', cfg.data.test) 32 | 33 | # Setup model 34 | model = get_model(cfg.model).to(cfg.device) 35 | 36 | # Setup metrics 37 | metrics = get_metrics(cfg.metrics) 38 | 39 | # Setup plotly visualizer 40 | visualizer = PlotlySubplotsVisualizer(rows=test_loader.dataset.num_rots, cols=test_loader.dataset.num_views) 41 | visualizer.fig.update_layout(height=2700) 42 | 43 | # Start test 44 | results = test(args, model, test_loader, metrics, cfg.device, visualizer) 45 | 46 | # Print results 47 | print_results(test_loader, results) 48 | 49 | # Save plotly figure 50 | save_figure(args.logdir, visualizer) 51 | 52 | 53 | def test(args, model, test_loader, metrics, device, visualizer): 54 | # Initialize 55 | model.eval() 56 | 57 | # Get arguments 58 | logdir = args.logdir 59 | 60 | # Get dataset 61 | obj_types = test_loader.dataset.obj_types 62 | obj_idxs_types = test_loader.dataset.obj_idxs_types 63 | partial_pc_list_types = test_loader.dataset.partial_pc_list_types 64 | mesh_list_types = test_loader.dataset.mesh_list_types 65 | Ts_grasp_list_types = test_loader.dataset.Ts_grasp_list_types 66 | 67 | # Get scale, maximum number of objects and number of rotations 68 | scale = test_loader.dataset.scale if hasattr(test_loader.dataset, 'scale') else 1 69 | max_num_objs = test_loader.dataset.max_num_objs 70 | num_rots = test_loader.dataset.num_rots 71 | num_views = test_loader.dataset.num_views 72 | 73 | # Setup metric result arrays 74 | results = {key: np.full((len(partial_pc_list_types), max_num_objs, num_rots, num_views), np.nan) for key in list(metrics.keys())} 75 | 76 | # Setup labels for button in plotly figure 77 | visualizer.labels = [] 78 | 79 | # Iterate 80 | for i, (obj_type, obj_idxs_objs, partial_pc_list_objs, Ts_grasp_list_objs, mesh_list_objs) in enumerate(zip(obj_types, obj_idxs_types, partial_pc_list_types, Ts_grasp_list_types, mesh_list_types)): 81 | for j, (obj_idx, partial_pc_rots, Ts_grasp_rots_target, mesh_list_rots) in enumerate(zip(obj_idxs_objs, partial_pc_list_objs, Ts_grasp_list_objs, mesh_list_objs)): 82 | # Setup input 83 | Ts_grasp_rots_target = torch.Tensor(Ts_grasp_rots_target).to(device) 84 | 85 | for k, (partial_pc_views, Ts_grasp_views_target, mesh_list_views) in enumerate(zip(partial_pc_rots, Ts_grasp_rots_target, mesh_list_rots)): 86 | # Setup input 87 | partial_pc_views = torch.Tensor(partial_pc_views).to(device) 88 | nums_grasps = torch.tensor([Ts_grasp_views_target.shape[1]]*len(partial_pc_views), device=partial_pc_views.device) 89 | 90 | # Sample grasp poses 91 | Ts_grasp_views_pred = model.sample(partial_pc_views, nums_grasps) 92 | 93 | # Compute metrics 94 | for l, (partial_pc, Ts_grasp_pred, Ts_grasp_target, mesh) in enumerate(zip(partial_pc_views, Ts_grasp_views_pred, Ts_grasp_views_target, mesh_list_views)): 95 | # Setup message 96 | msg = f"object type: {obj_type}, object index: {obj_idx}, rotation index: {k}, viewpoint index: {l}, " 97 | 98 | # Rescale point cloud and grasp poses 99 | partial_pc /= scale 100 | Ts_grasp_pred[:, :3, 3] /= scale 101 | 
Ts_grasp_target[:, :3, 3] /= scale 102 | mesh.scale(1/scale, center=(0, 0, 0)) 103 | 104 | for key, metric in metrics.items(): 105 | # Compute metrics 106 | result = metric(Ts_grasp_pred, Ts_grasp_target) 107 | 108 | # Add result to message 109 | msg += f"{key}: {result:.4f}, " 110 | 111 | # Fill array 112 | results[key][i, j, k, l] = result 113 | 114 | # Print result message 115 | print(msg) 116 | logging.info(msg) 117 | 118 | # Get indices for sampling grasp poses for simulation 119 | idxs = torch.randperm(len(Ts_grasp_pred))[:NUM_GRASPS] 120 | 121 | # Add mesh, partial point cloud, and gripper to visualizer 122 | visualizer.add_mesh(mesh, row=k+1, col=l+1) 123 | 124 | visualizer.add_pc(partial_pc.cpu().numpy().T, row=k+1, col=l+1) 125 | visualizer.add_grippers(Ts_grasp_pred[idxs], color='grey', row=k+1, col=l+1) 126 | 127 | visualizer.labels += [f'{obj_type}_{obj_idx}'] 128 | 129 | return results 130 | 131 | 132 | def print_results(test_loader, results): 133 | # Get object types and object ids 134 | obj_types = test_loader.dataset.obj_types 135 | 136 | # Print results 137 | for idx_type, obj_type in enumerate(obj_types): 138 | msg = f"object type: {obj_type}" 139 | 140 | for key in results.keys(): 141 | msg += f", {key}: {np.nanmean(results[key][idx_type]):.4f}" 142 | 143 | print(msg) 144 | logging.info(msg) 145 | 146 | 147 | def save_figure(logdir, visualizer): 148 | # Get number of traces and number of subplots 149 | num_traces = len(visualizer.fig.data) 150 | num_subplots = visualizer.num_subplots 151 | 152 | # Make only the first scene visible 153 | for idx_trace in range(num_subplots*(2+NUM_GRASPS), num_traces): 154 | visualizer.fig.update_traces(visible=False, selector=idx_trace) 155 | 156 | # Make buttons list 157 | buttons = [] 158 | 159 | for idx_scene, label in enumerate(visualizer.labels): 160 | # Initialize visibility list 161 | visibility = num_traces * [False] 162 | 163 | # Make only the selected scene visible 164 | for idx_trace in range(num_subplots*(2+NUM_GRASPS)*idx_scene, num_subplots*(2+NUM_GRASPS)*(idx_scene+1)): 165 | visibility[idx_trace] = True 166 | 167 | # Make and append button 168 | button = dict(label=label, method='restyle', args=[{'visible': visibility}]) 169 | 170 | buttons += [button] 171 | 172 | # Update buttons 173 | visualizer.fig.update_layout(updatemenus=[go.layout.Updatemenu(active=0, buttons=buttons)]) 174 | 175 | # Save figure 176 | visualizer.fig.write_json(os.path.join(logdir, 'visualizations.json')) 177 | 178 | 179 | if __name__ == '__main__': 180 | # Parse arguments 181 | parser = argparse.ArgumentParser() 182 | 183 | parser.add_argument('--train_result_path', type=str) 184 | parser.add_argument('--checkpoint', type=str) 185 | parser.add_argument('--device', default=0) 186 | parser.add_argument('--logdir', default='test_results') 187 | parser.add_argument('--run', type=str, default=datetime.now().strftime('%Y%m%d-%H%M')) 188 | 189 | args = parser.parse_args() 190 | 191 | # Load config 192 | config_filename = [file for file in os.listdir(args.train_result_path) if file.endswith('.yml')][0] 193 | 194 | cfg = OmegaConf.load(os.path.join(args.train_result_path, config_filename)) 195 | 196 | # Setup checkpoint 197 | cfg.model.checkpoint = os.path.join(args.train_result_path, args.checkpoint) 198 | 199 | # Setup device 200 | if args.device == 'cpu': 201 | cfg.device = 'cpu' 202 | else: 203 | cfg.device = f'cuda:{args.device}' 204 | 205 | # Setup logdir 206 | config_basename = os.path.splitext(config_filename)[0] 207 | 208 | args.logdir = 
os.path.join(args.logdir, config_basename, args.run) 209 | 210 | os.makedirs(args.logdir, exist_ok=True) 211 | 212 | # Setup logging 213 | logging.basicConfig( 214 | filename=os.path.join(args.logdir, 'logging.log'), 215 | format='%(asctime)s [%(levelname)s] %(message)s', 216 | datefmt='%Y/%m/%d %I:%M:%S %p', 217 | level=logging.DEBUG 218 | ) 219 | 220 | # Print result directory 221 | print(f"Result directory: {args.logdir}") 222 | logging.info(f"Result directory: {args.logdir}") 223 | 224 | # Save config 225 | config_path = os.path.join(args.logdir, config_filename) 226 | yaml.dump(yaml.safe_load(OmegaConf.to_yaml(cfg)), open(config_path, 'w')) 227 | 228 | print(f"Config saved as {config_path}") 229 | logging.info(f"Config saved as {config_path}") 230 | 231 | main(args, cfg) 232 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | from datetime import datetime 3 | from omegaconf import OmegaConf 4 | import os 5 | from tensorboardX import SummaryWriter 6 | import logging 7 | import yaml 8 | import random 9 | import numpy as np 10 | import torch 11 | 12 | from loaders import get_dataloader 13 | from models import get_model 14 | from losses import get_losses 15 | from utils.optimizers import get_optimizer 16 | from metrics import get_metrics 17 | from utils.logger import Logger 18 | from trainers import get_trainer 19 | 20 | 21 | def main(cfg, writer): 22 | # Setup seed 23 | seed = cfg.get('seed', 1) 24 | random.seed(seed) 25 | np.random.seed(seed) 26 | torch.manual_seed(seed) 27 | torch.cuda.manual_seed(seed) 28 | torch.set_num_threads(8) 29 | torch.backends.cudnn.deterministic = True 30 | 31 | # Setup dataloader 32 | dataloaders = {} 33 | 34 | for split in ['train', 'val']: 35 | dataloaders[split] = get_dataloader(split, cfg.data[split]) 36 | 37 | # Setup model 38 | model = get_model(cfg.model).to(cfg.device) 39 | 40 | # Setup losses 41 | losses = get_losses(cfg.losses) 42 | 43 | # Setup optimizer 44 | optimizer = get_optimizer(cfg.optimizer, model.parameters()) 45 | 46 | # Setup metrics 47 | metrics = get_metrics(cfg.metrics) 48 | 49 | # Setup logger 50 | logger = Logger(writer) 51 | 52 | # Setup trainer 53 | trainer = get_trainer(cfg.trainer, cfg.device, dataloaders, model, losses, optimizer, metrics, logger) 54 | 55 | # Start learning 56 | trainer.run() 57 | 58 | 59 | if __name__ == '__main__': 60 | # Parse arguments 61 | parser = argparse.ArgumentParser() 62 | 63 | parser.add_argument('--config', type=str) 64 | parser.add_argument('--device', default=0) 65 | parser.add_argument('--logdir', default='train_results') 66 | parser.add_argument('--run', type=str, default=datetime.now().strftime('%Y%m%d-%H%M')) 67 | 68 | args = parser.parse_args() 69 | 70 | # Load and print config 71 | cfg = OmegaConf.load(args.config) 72 | print(OmegaConf.to_yaml(cfg)) 73 | 74 | # Setup device 75 | if args.device == 'cpu': 76 | cfg.device = 'cpu' 77 | else: 78 | cfg.device = f'cuda:{args.device}' 79 | 80 | # Setup logdir 81 | config_filename = os.path.basename(args.config) 82 | config_basename = os.path.splitext(config_filename)[0] 83 | 84 | logdir = os.path.join(args.logdir, config_basename, args.run) 85 | 86 | # Setup tensorboard writer 87 | writer = SummaryWriter(logdir) 88 | 89 | # Setup logging 90 | logging.basicConfig( 91 | filename=os.path.join(logdir, 'logging.log'), 92 | format='%(asctime)s [%(levelname)s] %(message)s', 93 | datefmt='%Y/%m/%d %I:%M:%S %p', 
94 | level=logging.DEBUG 95 | ) 96 | 97 | # Print logdir 98 | print(f"Result directory: {logdir}") 99 | logging.info(f"Result directory: {logdir}") 100 | 101 | # Save config 102 | config_path = os.path.join(logdir, config_filename) 103 | yaml.dump(yaml.safe_load(OmegaConf.to_yaml(cfg)), open(config_path, 'w')) 104 | 105 | print(f"Config saved as {config_path}") 106 | logging.info(f"Config saved as {config_path}") 107 | 108 | main(cfg, writer) 109 | -------------------------------------------------------------------------------- /trainers/__init__.py: -------------------------------------------------------------------------------- 1 | from trainers.grasp_trainer import GraspPoseGeneratorTrainer, PartialGraspPoseGeneratorTrainer 2 | 3 | 4 | def get_trainer(cfg_trainer, device, dataloaders, model, losses, optimizer, metrics, logger): 5 | name = cfg_trainer.pop('name') 6 | 7 | if name == 'grasp_full': 8 | trainer = GraspPoseGeneratorTrainer(cfg_trainer, device, dataloaders, model, losses, optimizer, metrics, logger) 9 | elif name == 'grasp_partial': 10 | trainer = PartialGraspPoseGeneratorTrainer(cfg_trainer, device, dataloaders, model, losses, optimizer, metrics, logger) 11 | else: 12 | raise NotImplementedError(f"Trainer {name} is not implemented.") 13 | 14 | return trainer 15 | -------------------------------------------------------------------------------- /trainers/grasp_trainer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import time 3 | import logging 4 | from tqdm import tqdm 5 | import numpy as np 6 | from copy import deepcopy 7 | import os 8 | 9 | from utils.average_meter import AverageMeter 10 | from utils.mesh import generate_grasp_scene_list, meshes_to_numpy 11 | 12 | 13 | NUM_GRASPS = 100 14 | 15 | 16 | class GraspPoseGeneratorTrainer: 17 | def __init__(self, cfg, device, dataloaders, model, losses, optimizer, metrics, logger): 18 | self.cfg = cfg 19 | self.train_loader = dataloaders['train'] 20 | self.val_loader = dataloaders['val'] 21 | self.device = device 22 | self.model = model 23 | self.losses = losses 24 | self.optimizer = optimizer 25 | self.logger = logger 26 | self.metrics = metrics 27 | 28 | # Get logdir 29 | self.logdir = self.logger.writer.file_writer.get_logdir() 30 | 31 | # Setup meters 32 | self.setup_meters() 33 | 34 | # Initialize performance dictionary 35 | self.setup_performance_dict() 36 | 37 | def setup_meters(self): 38 | # Setup time meter 39 | self.time_meter = AverageMeter() 40 | 41 | # Setup scalar meters for train 42 | for data in self.train_loader: 43 | break 44 | 45 | for key, val in data.items(): 46 | if type(val) == torch.Tensor: 47 | data[key] = val.to(self.device) 48 | elif type(val) == list: 49 | data[key] = [v.to(self.device) for v in val] 50 | 51 | with torch.no_grad(): 52 | results_train = self.model.step(data, self.losses, 'train') 53 | results_val = self.model.step(data, self.losses, 'val') 54 | 55 | self.train_meters = {key: AverageMeter() for key in results_train.keys() if 'scalar' in key} 56 | self.val_meters = {key: AverageMeter() for key in results_val.keys() if 'scalar' in key} 57 | 58 | # Setup metric meters 59 | self.metric_meters = {key: AverageMeter() for key in self.metrics.keys()} 60 | 61 | def setup_performance_dict(self): 62 | self.performances = {'val_loss': torch.inf} 63 | 64 | for criterion in self.cfg.criteria: 65 | assert criterion.name in self.metrics.keys(), f"Criterion {criterion.name} not in metrics keys {self.metrics.keys()}." 
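# Note (assumed from the surrounding code, not stated elsewhere in the repo): each entry of
# cfg.criteria is expected to expose a 'name' matching one of the configured metric keys and a
# 'better' flag ('higher' or 'lower') that tells the trainer how to track the best checkpoint;
# an EMD-style distance metric would typically use better: 'lower'. The metric key names
# themselves come from the metrics section of the training YAML, not from this file.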
66 | 67 | if criterion.better == 'higher': 68 | self.performances[criterion.name] = 0 69 | elif criterion.better == 'lower': 70 | self.performances[criterion.name] = torch.inf 71 | else: 72 | raise ValueError(f"Criterion better with {criterion.better} value is not supported. Choose 'higher' or 'lower'.") 73 | 74 | def run(self): 75 | # Initialize 76 | iter = 0 77 | 78 | # Start learning 79 | for epoch in range(1, self.cfg.num_epochs+1): 80 | for data in self.train_loader: 81 | iter += 1 82 | 83 | # Training 84 | results_train = self.train(data) 85 | 86 | # Print 87 | if iter % self.cfg.print_interval == 0: 88 | self.print(results_train, epoch, iter) 89 | 90 | # Validation 91 | if iter % self.cfg.val_interval == 0: 92 | self.validate(epoch, iter) 93 | 94 | # Evaluation 95 | if iter % self.cfg.eval_interval == 0: 96 | self.evaluate(epoch, iter) 97 | 98 | # Visualization 99 | if iter % self.cfg.vis_interval == 0: 100 | self.visualize(epoch, iter) 101 | 102 | # Save 103 | if iter % self.cfg.save_interval == 0: 104 | self.save(epoch, iter) 105 | 106 | def train(self, data): 107 | # Initialize 108 | self.model.train() 109 | self.optimizer.zero_grad() 110 | 111 | # Setup input 112 | for key, val in data.items(): 113 | if type(val) == torch.Tensor: 114 | data[key] = val.to(self.device) 115 | elif type(val) == list: 116 | data[key] = [v.to(self.device) for v in val] 117 | 118 | # Step 119 | time_start = time.time() 120 | 121 | results = self.model.step(data, self.losses, 'train', self.optimizer) 122 | 123 | time_end = time.time() 124 | 125 | # Update time meter 126 | self.time_meter.update(time_end - time_start) 127 | 128 | # Update train meters 129 | for key, meter in self.train_meters.items(): 130 | meter.update(results[key], n=len(data['pc'])) 131 | 132 | return results 133 | 134 | def print(self, results, epoch, iter): 135 | # Get averaged train results 136 | for key, meter in self.train_meters.items(): 137 | results[key] = meter.avg 138 | 139 | # Log averaged train results 140 | self.logger.log(results, iter) 141 | 142 | # Print averaged train results 143 | msg = f"[ Training ] epoch: {epoch}/{self.cfg.num_epochs}, iter: {iter}, " 144 | msg += ", ".join([f"{key.split('/')[-1]}: {meter.avg:.4f}" for key, meter in self.train_meters.items()]) 145 | msg += f", elapsed time: {self.time_meter.sum:.4f}" 146 | 147 | print(msg) 148 | logging.info(msg) 149 | 150 | # Reset time meter and train meters 151 | self.time_meter.reset() 152 | 153 | for key, meter in self.train_meters.items(): 154 | meter.reset() 155 | 156 | def validate(self, epoch, iter): 157 | # Initialize 158 | self.model.eval() 159 | 160 | time_start = time.time() 161 | 162 | with torch.no_grad(): 163 | for data in tqdm(self.val_loader, desc="Validating ...", leave=False): 164 | # Setup input 165 | for key, val in data.items(): 166 | if type(val) == torch.Tensor: 167 | data[key] = val.to(self.device) 168 | elif type(val) == list: 169 | data[key] = [v.to(self.device) for v in val] 170 | 171 | # Step 172 | results = self.model.step(data, self.losses, 'val') 173 | 174 | # Update validation meters 175 | for key, meter in self.val_meters.items(): 176 | meter.update(results[key], n=len(data['pc'])) 177 | 178 | time_end = time.time() 179 | 180 | # Get averaged validation results 181 | for key, meter in self.val_meters.items(): 182 | results[key] = meter.avg 183 | 184 | # Log averaged validation results 185 | self.logger.log(results, iter) 186 | 187 | # Print averaged validation results 188 | msg = f"[ Validation ] epoch: 
{epoch}/{self.cfg.num_epochs}, iter: {iter}, " 189 | msg += ", ".join([f"{key.split('/')[-1]}: {meter.avg:.4f}" for key, meter in self.val_meters.items()]) 190 | msg += f", elapsed time: {time_end-time_start:.4f}" 191 | 192 | print(msg) 193 | logging.info(msg) 194 | 195 | # Determine best validation loss 196 | val_loss = self.val_meters['scalar/val/loss'].avg 197 | 198 | if val_loss < self.performances['val_loss']: 199 | # Save model 200 | self.save(epoch, iter, criterion='val_loss', data={'val_loss': val_loss}) 201 | 202 | # Update best validation loss 203 | self.performances['val_loss'] = val_loss 204 | 205 | # Reset meters 206 | for key, meter in self.val_meters.items(): 207 | meter.reset() 208 | 209 | def evaluate(self, epoch, iter): 210 | # Initialize 211 | self.model.eval() 212 | 213 | # Get dataset and scale 214 | pc_list_types = self.val_loader.dataset.pc_list_types 215 | Ts_grasp_list_types = self.val_loader.dataset.Ts_grasp_list_types 216 | mesh_list_types = deepcopy(self.val_loader.dataset.mesh_list_types) 217 | 218 | scale = self.val_loader.dataset.scale if hasattr(self.val_loader.dataset, 'scale') else 1 219 | 220 | time_start = time.time() 221 | 222 | # Iterate object types 223 | for pc_list_objs, Ts_grasp_list_objs, mesh_list_objs in zip(tqdm(pc_list_types, desc="Evaluating for object types ...", leave=False), Ts_grasp_list_types, mesh_list_types): 224 | # Setup metric meters for objects 225 | metric_meters_objs = {key: AverageMeter() for key in self.metrics.keys()} 226 | 227 | # Iterate objects 228 | for pc_rots, Ts_grasp_rots_target, mesh_list_rots in zip(tqdm(pc_list_objs, desc="Evaluating for objects ...", leave=False), Ts_grasp_list_objs, mesh_list_objs): 229 | # Setup metric meters for rotations 230 | metric_meters_rots = {key: AverageMeter() for key in self.metrics.keys()} 231 | 232 | # Setup input 233 | pc_rots = torch.Tensor(pc_rots).to(self.device) 234 | Ts_grasp_rots_target = torch.Tensor(Ts_grasp_rots_target).to(self.device) 235 | nums_grasps = torch.tensor([len(Ts_grasp_target) for Ts_grasp_target in Ts_grasp_rots_target], device=pc_rots.device) 236 | 237 | # Generate grasp poses 238 | Ts_grasp_rots_pred = self.model.sample(pc_rots, nums_grasps) 239 | 240 | # Rescale grasp poses and mesh 241 | for Ts_grasp_pred, Ts_grasp_target, mesh in zip(Ts_grasp_rots_pred, Ts_grasp_rots_target, mesh_list_rots): 242 | Ts_grasp_pred[:, :3, 3] /= scale 243 | Ts_grasp_target[:, :3, 3] /= scale 244 | mesh.scale(1/scale, center=(0, 0, 0)) 245 | 246 | # Compute metrics for rotations 247 | for Ts_grasp_pred, Ts_grasp_target, mesh in zip(Ts_grasp_rots_pred, Ts_grasp_rots_target, mesh_list_rots): 248 | for key, metric in self.metrics.items(): 249 | if key == 'collision_rate': 250 | # Get indices for sampling grasp poses for simulation 251 | assert NUM_GRASPS <= len(Ts_grasp_pred), f"Number of grasps for simulation ({NUM_GRASPS}) must be less than or equal to the number of grasps predicted ({len(Ts_grasp_pred)})." 
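# The collision rate is estimated on a random subset of NUM_GRASPS (100, defined at the top of
# this file) predicted poses rather than on every generated grasp, so the model must return at
# least 100 grasps per object during evaluation; the remaining metrics use the full prediction set.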
252 | 253 | idxs = torch.randperm(len(Ts_grasp_pred))[:NUM_GRASPS] 254 | 255 | metric_meters_rots[key].update(metric(mesh, Ts_grasp_pred[idxs])) 256 | else: 257 | metric_meters_rots[key].update(metric(Ts_grasp_pred, Ts_grasp_target)) 258 | 259 | # Compute metrics for objects 260 | for key, meter in metric_meters_objs.items(): 261 | meter.update(metric_meters_rots[key].avg) 262 | 263 | # Compute metrics for object types 264 | for key, meter in self.metric_meters.items(): 265 | meter.update(metric_meters_objs[key].avg) 266 | 267 | time_end = time.time() 268 | 269 | # Get averaged evaluation results 270 | results = {} 271 | 272 | for key, meter in self.metric_meters.items(): 273 | results[f'scalar/metrics/{key}'] = meter.avg 274 | 275 | # Log averaged evaluation results 276 | self.logger.log(results, iter) 277 | 278 | # Print averaged evaluation results 279 | msg = f"[ Evaluation ] epoch: {epoch}/{self.cfg.num_epochs}, iter: {iter}, " 280 | msg += ", ".join([f"{key}: {meter.avg:.4f}" for key, meter in self.metric_meters.items()]) 281 | msg += f", elapsed time: {time_end-time_start:.4f}" 282 | 283 | print(msg) 284 | logging.info(msg) 285 | 286 | # Save model if best evaluation performance 287 | for criterion in self.cfg.criteria: 288 | # Determine best performance 289 | performance = self.metric_meters[criterion.name].avg 290 | 291 | if criterion.better == 'higher' and performance > self.performances[criterion.name]: 292 | best = True 293 | elif criterion.better == 'lower' and performance < self.performances[criterion.name]: 294 | best = True 295 | else: 296 | best = False 297 | 298 | if best: 299 | # Save model 300 | self.save(epoch, iter, criterion=criterion.name, data={criterion.name: performance}) 301 | 302 | # Update best validation loss 303 | self.performances[criterion.name] = performance 304 | 305 | # Reset metric meters 306 | for key, meter in self.metric_meters.items(): 307 | meter.reset() 308 | 309 | def visualize(self, epoch, iter): 310 | # Initialize 311 | self.model.eval() 312 | 313 | time_start = time.time() 314 | 315 | mesh_list = [] 316 | pc_list = [] 317 | Ts_grasp_pred_list = [] 318 | Ts_grasp_target_list = [] 319 | 320 | # Get random data indices 321 | idxs = np.random.choice(self.val_loader.dataset.num_scenes, size=3, replace=False) 322 | 323 | # Get scale 324 | scale = self.val_loader.dataset.scale if hasattr(self.val_loader.dataset, 'scale') else 1 325 | 326 | for idx in idxs: 327 | idx_type, idx_obj, idx_rot = np.where(self.val_loader.dataset.scene_idxs==idx) 328 | 329 | idx_type = idx_type.item() 330 | idx_obj = idx_obj.item() 331 | idx_rot = idx_rot.item() 332 | 333 | # Get input 334 | mesh = deepcopy(self.val_loader.dataset.mesh_list_types[idx_type][idx_obj][idx_rot]) 335 | pc = self.val_loader.dataset.pc_list_types[idx_type][idx_obj][idx_rot] 336 | Ts_grasp_target = self.val_loader.dataset.Ts_grasp_list_types[idx_type][idx_obj][idx_rot] 337 | 338 | # Sample ground-truth grasp poses 339 | idxs_grasp = np.random.choice(len(Ts_grasp_target), size=10, replace=False) 340 | Ts_grasp_target = Ts_grasp_target[idxs_grasp] 341 | 342 | # Append data to list 343 | mesh_list += [mesh] 344 | pc_list += [torch.Tensor(pc)] 345 | Ts_grasp_target_list += [Ts_grasp_target] 346 | 347 | # Setup input 348 | pc = torch.stack(pc_list).to(self.device) 349 | nums_grasps = torch.tensor([10, 10, 10], device=self.device) 350 | 351 | # Generate grasp poses 352 | Ts_grasp_pred_list = self.model.sample(pc, nums_grasps) 353 | Ts_grasp_pred_list = [Ts_grasp_pred.cpu().numpy() for Ts_grasp_pred in 
Ts_grasp_pred_list] 354 | 355 | # Rescale mesh and grasp poses 356 | for mesh, Ts_grasp_pred, Ts_grasp_target in zip(mesh_list, Ts_grasp_pred_list, Ts_grasp_target_list): 357 | mesh.scale(1/scale, center=(0, 0, 0)) 358 | Ts_grasp_pred[:, :3, 3] /= scale 359 | Ts_grasp_target[:, :3, 3] /= scale 360 | 361 | # Generate scene 362 | scene_list_pred = generate_grasp_scene_list(mesh_list, Ts_grasp_pred_list) 363 | scene_list_target = generate_grasp_scene_list(mesh_list, Ts_grasp_target_list) 364 | 365 | # Get vertices, triangles and colors 366 | vertices_pred, triangles_pred, colors_pred = meshes_to_numpy(scene_list_pred) 367 | vertices_target, triangles_target, colors_target = meshes_to_numpy(scene_list_target) 368 | 369 | time_end = time.time() 370 | 371 | # Get visualization results 372 | results = { 373 | 'mesh/pred': {'vertices': vertices_pred, 'colors': colors_pred, 'faces': triangles_pred}, 374 | 'mesh/target': {'vertices': vertices_target, 'colors': colors_target, 'faces': triangles_target} 375 | } 376 | 377 | # Log visualization results 378 | self.logger.log(results, iter) 379 | 380 | # Print visualization status 381 | msg = f"[Visualization] epoch: {epoch}/{self.cfg.num_epochs}, iter: {iter}" 382 | msg += f", elapsed time: {time_end-time_start:.4f}" 383 | 384 | print(msg) 385 | logging.info(msg) 386 | 387 | def save(self, epoch, iter, criterion=None, data={}): 388 | # Set save name 389 | if criterion is None: 390 | save_name = f'model_iter_{iter}.pkl' 391 | else: 392 | save_name = f'model_best_{criterion}.pkl' 393 | 394 | # Construct object to save 395 | object = { 396 | 'epoch': epoch, 397 | 'iter': iter, 398 | 'model_state': self.model.state_dict(), 399 | 'optimizer': self.optimizer.state_dict(), 400 | } 401 | object.update(data) 402 | 403 | # Save object 404 | save_path = os.path.join(self.logdir, save_name) 405 | 406 | torch.save(object, save_path) 407 | 408 | # Print save status 409 | string = f"[ Save ] epoch: {epoch}/{self.cfg.num_epochs}, iter: {iter}, save {save_name}" 410 | 411 | if criterion is not None: 412 | string += f", {criterion}: {data[criterion]:.6f} / best_{criterion}: {self.performances[criterion]:.6f}" 413 | 414 | print(string) 415 | logging.info(string) 416 | 417 | 418 | class PartialGraspPoseGeneratorTrainer(GraspPoseGeneratorTrainer): 419 | def evaluate(self, epoch, iter): 420 | # Initialize 421 | self.model.eval() 422 | 423 | # Get dataset and scale 424 | partial_pc_list_types = self.val_loader.dataset.partial_pc_list_types 425 | Ts_grasp_list_types = self.val_loader.dataset.Ts_grasp_list_types 426 | mesh_list_types = deepcopy(self.val_loader.dataset.mesh_list_types) 427 | 428 | scale = self.val_loader.dataset.scale if hasattr(self.val_loader.dataset, 'scale') else 1 429 | 430 | time_start = time.time() 431 | 432 | # Iterate object types 433 | for partial_pc_list_objs, Ts_grasp_list_objs, mesh_list_objs in zip(tqdm(partial_pc_list_types, desc="Evaluating for object types ...", leave=False), Ts_grasp_list_types, mesh_list_types): 434 | # Setup metric meters for objects 435 | metric_meters_objs = {key: AverageMeter() for key in self.metrics.keys()} 436 | 437 | # Iterate objects 438 | for partial_pc_rots, Ts_grasp_rots_target, mesh_list_rots in zip(tqdm(partial_pc_list_objs, desc="Evaluating for objects ...", leave=False), Ts_grasp_list_objs, mesh_list_objs): 439 | # Setup metric meters for rotations 440 | metric_meters_rots = {key: AverageMeter() for key in self.metrics.keys()} 441 | 442 | # Setup input 443 | Ts_grasp_rots_target = 
torch.Tensor(Ts_grasp_rots_target).to(self.device) 444 | 445 | # Iterate rotations 446 | for partial_pc_views, Ts_grasp_views_target, mesh_list_views in zip(tqdm(partial_pc_rots, desc="Evaluating for rotations ...", leave=False), Ts_grasp_rots_target, mesh_list_rots): 447 | # Setup metric meters for viewpoints 448 | metric_meters_views = {key: AverageMeter() for key in self.metrics.keys()} 449 | 450 | # Setup input 451 | partial_pc_views = torch.Tensor(partial_pc_views).to(self.device) 452 | nums_grasps = torch.tensor([Ts_grasp_views_target.shape[1]]*len(partial_pc_views), device=partial_pc_views.device) 453 | 454 | # Generate grasp poses 455 | Ts_grasp_views_pred = self.model.sample(partial_pc_views, nums_grasps) 456 | 457 | for Ts_grasp_pred, Ts_grasp_target, mesh in zip(Ts_grasp_views_pred, Ts_grasp_views_target, mesh_list_views): 458 | # Rescale grasp poses and mesh 459 | Ts_grasp_pred[:, :3, 3] /= scale 460 | Ts_grasp_target[:, :3, 3] /= scale 461 | mesh.scale(1/scale, center=(0, 0, 0)) 462 | 463 | # Compute metrics for viewpoints 464 | for key, metric in self.metrics.items(): 465 | if key == 'collision_rate': 466 | # Get indices for sampling grasp poses for simulation 467 | assert NUM_GRASPS <= len(Ts_grasp_pred), f"Number of grasps for simulation ({NUM_GRASPS}) must be less than or equal to the number of grasps predicted ({len(Ts_grasp_pred)})." 468 | 469 | idxs = torch.randperm(len(Ts_grasp_pred))[:NUM_GRASPS] 470 | 471 | metric_meters_views[key].update(metric(mesh, Ts_grasp_pred[idxs])) 472 | else: 473 | metric_meters_views[key].update(metric(Ts_grasp_pred, Ts_grasp_target)) 474 | 475 | # Compute metrics for rotations 476 | for key, meter in metric_meters_objs.items(): 477 | meter.update(metric_meters_views[key].avg) 478 | 479 | # Compute metrics for objects 480 | for key, meter in metric_meters_objs.items(): 481 | meter.update(metric_meters_rots[key].avg) 482 | 483 | # Compute metrics for object types 484 | for key, meter in self.metric_meters.items(): 485 | meter.update(metric_meters_objs[key].avg) 486 | 487 | time_end = time.time() 488 | 489 | # Get averaged evaluation results 490 | results = {} 491 | 492 | for key, meter in self.metric_meters.items(): 493 | results[f'scalar/metrics/{key}'] = meter.avg 494 | 495 | # Log averaged evaluation results 496 | self.logger.log(results, iter) 497 | 498 | # Print averaged evaluation results 499 | msg = f"[ Evaluation ] epoch: {epoch}/{self.cfg.num_epochs}, iter: {iter}, " 500 | msg += ", ".join([f"{key}: {meter.avg:.4f}" for key, meter in self.metric_meters.items()]) 501 | msg += f", elapsed time: {time_end-time_start:.4f}" 502 | 503 | print(msg) 504 | logging.info(msg) 505 | 506 | # Save model if best evaluation performance 507 | for criterion in self.cfg.criteria: 508 | # Determine best performance 509 | performance = self.metric_meters[criterion.name].avg 510 | 511 | if criterion.better == 'higher' and performance > self.performances[criterion.name]: 512 | best = True 513 | elif criterion.better == 'lower' and performance < self.performances[criterion.name]: 514 | best = True 515 | else: 516 | best = False 517 | 518 | if best: 519 | # Save model 520 | self.save(epoch, iter, criterion=criterion.name, data={criterion.name: performance}) 521 | 522 | # Update best validation loss 523 | self.performances[criterion.name] = performance 524 | 525 | # Reset metric meters 526 | for key, meter in self.metric_meters.items(): 527 | meter.reset() 528 | 529 | def visualize(self, epoch, iter): 530 | # Initialize 531 | self.model.eval() 532 | 533 | 
time_start = time.time() 534 | 535 | mesh_list = [] 536 | partial_pc_list = [] 537 | Ts_grasp_pred_list = [] 538 | Ts_grasp_target_list = [] 539 | 540 | # Get random data indices 541 | idxs = np.random.choice(self.val_loader.dataset.num_scenes, size=3, replace=False) 542 | 543 | # Get scale 544 | scale = self.val_loader.dataset.scale if hasattr(self.val_loader.dataset, 'scale') else 1 545 | 546 | for idx in idxs: 547 | idx_type, idx_obj, idx_rot, idx_view = np.where(self.val_loader.dataset.scene_idxs==idx) 548 | 549 | idx_type = idx_type.item() 550 | idx_obj = idx_obj.item() 551 | idx_rot = idx_rot.item() 552 | idx_view = idx_view.item() 553 | 554 | # Get input 555 | mesh = deepcopy(self.val_loader.dataset.mesh_list_types[idx_type][idx_obj][idx_rot][idx_view]) 556 | partial_pc = self.val_loader.dataset.partial_pc_list_types[idx_type][idx_obj][idx_rot][idx_view] 557 | Ts_grasp_target = self.val_loader.dataset.Ts_grasp_list_types[idx_type][idx_obj][idx_rot][idx_view] 558 | 559 | # Sample ground-truth grasp poses 560 | idxs_grasp = np.random.choice(len(Ts_grasp_target), size=10, replace=False) 561 | Ts_grasp_target = Ts_grasp_target[idxs_grasp] 562 | 563 | # Append data to list 564 | mesh_list += [mesh] 565 | partial_pc_list += [torch.Tensor(partial_pc)] 566 | Ts_grasp_target_list += [Ts_grasp_target] 567 | 568 | # Setup input 569 | partial_pc = torch.stack(partial_pc_list).to(self.device) 570 | nums_grasps = torch.tensor([10, 10, 10], device=self.device) 571 | 572 | # Generate grasp poses 573 | Ts_grasp_pred_list = self.model.sample(partial_pc, nums_grasps) 574 | Ts_grasp_pred_list = [Ts_grasp_pred.cpu().numpy() for Ts_grasp_pred in Ts_grasp_pred_list] 575 | 576 | # Rescale mesh and grasp poses 577 | for mesh, Ts_grasp_pred, Ts_grasp_target in zip(mesh_list, Ts_grasp_pred_list, Ts_grasp_target_list): 578 | mesh.scale(1/scale, center=(0, 0, 0)) 579 | Ts_grasp_pred[:, :3, 3] /= scale 580 | Ts_grasp_target[:, :3, 3] /= scale 581 | 582 | # Generate scene 583 | scene_list_pred = generate_grasp_scene_list(mesh_list, Ts_grasp_pred_list) 584 | scene_list_target = generate_grasp_scene_list(mesh_list, Ts_grasp_target_list) 585 | 586 | # Get vertices, triangles and colors 587 | vertices_pred, triangles_pred, colors_pred = meshes_to_numpy(scene_list_pred) 588 | vertices_target, triangles_target, colors_target = meshes_to_numpy(scene_list_target) 589 | 590 | time_end = time.time() 591 | 592 | # Get visualization results 593 | results = { 594 | 'mesh/pred': {'vertices': vertices_pred, 'colors': colors_pred, 'faces': triangles_pred}, 595 | 'mesh/target': {'vertices': vertices_target, 'colors': colors_target, 'faces': triangles_target} 596 | } 597 | 598 | # Log visualization results 599 | self.logger.log(results, iter) 600 | 601 | # Print visualization status 602 | msg = f"[Visualization] epoch: {epoch}/{self.cfg.num_epochs}, iter: {iter}" 603 | msg += f", elapsed time: {time_end-time_start:.4f}" 604 | 605 | print(msg) 606 | logging.info(msg) 607 | -------------------------------------------------------------------------------- /utils/Lie.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | from scipy.spatial.transform import Rotation 4 | 5 | 6 | EPS = 1e-4 7 | 8 | 9 | def is_SO3(R): 10 | test_0 = torch.allclose(R @ R.transpose(1, 2), torch.eye(3).repeat(len(R), 1, 1).to(R), atol=EPS) 11 | test_1 = torch.allclose(R.transpose(1, 2) @ R, torch.eye(3).repeat(len(R), 1, 1).to(R), atol=EPS) 12 | 13 | test = test_0 and test_1 14 | 15 | 
return test 16 | 17 | 18 | def is_SE3(T): 19 | test_0 = is_SO3(T[:, :3, :3]) 20 | test_1 = torch.equal(T[:, 3, :3], torch.zeros_like(T[:, 3, :3])) 21 | test_2 = torch.equal(T[:, 3, 3], torch.ones_like(T[:, 3, 3])) 22 | 23 | test = test_0 and test_1 and test_2 24 | 25 | return test 26 | 27 | 28 | def inv_SO3(R): 29 | assert R.shape[1:] == (3, 3), f"inv_SO3: input must be of shape (N, 3, 3). Current shape: {tuple(R.shape)}" 30 | assert is_SO3(R), "inv_SO3: input must be SO(3) matrices" 31 | 32 | inv_R = R.transpose(1, 2) 33 | 34 | return inv_R 35 | 36 | 37 | def inv_SE3(T): 38 | assert T.shape[1:] == (4, 4), f"inv_SE3: input must be of shape (N, 4, 4). Current shape: {tuple(T.shape)}" 39 | assert is_SE3(T), "inv_SE3: input must be SE(3) matrices" 40 | 41 | R = T[:, :3, :3] 42 | p = T[:, :3, 3] 43 | 44 | inv_T = torch.eye(4).repeat(len(T), 1, 1).to(T) 45 | inv_T[:, :3, :3] = inv_SO3(R) 46 | inv_T[:, :3, 3] = - torch.einsum('nij,nj->ni', inv_SO3(R), p) 47 | 48 | return inv_T 49 | 50 | 51 | def bracket_so3(w): 52 | # vector -> matrix 53 | if w.shape[1:] == (3,): 54 | zeros = w.new_zeros(len(w)) 55 | 56 | out = torch.stack([ 57 | torch.stack([zeros, -w[:, 2], w[:, 1]], dim=1), 58 | torch.stack([w[:, 2], zeros, -w[:, 0]], dim=1), 59 | torch.stack([-w[:, 1], w[:, 0], zeros], dim=1) 60 | ], dim=1) 61 | 62 | # matrix -> vector 63 | elif w.shape[1:] == (3, 3): 64 | out = torch.stack([w[:, 2, 1], w[:, 0, 2], w[:, 1, 0]], dim=1) 65 | 66 | else: 67 | raise f"bracket_so3: input must be of shape (N, 3) or (N, 3, 3). Current shape: {tuple(w.shape)}" 68 | 69 | return out 70 | 71 | 72 | def bracket_se3(S): 73 | # vector -> matrix 74 | if S.shape[1:] == (6,): 75 | w_mat = bracket_so3(S[:, :3]) 76 | 77 | out = torch.cat(( 78 | torch.cat((w_mat, S[:, 3:].unsqueeze(2)), dim=2), 79 | S.new_zeros(len(S), 1, 4) 80 | ), dim=1) 81 | 82 | # matrix -> vector 83 | elif S.shape[1:] == (4, 4): 84 | w_vec = bracket_so3(S[:, :3, :3]) 85 | 86 | out = torch.cat((w_vec, S[:, :3, 3]), dim=1) 87 | 88 | else: 89 | raise f"bracket_se: input must be of shape (N, 6) or (N, 4, 4). Current shape: {tuple(S.shape)}" 90 | 91 | return out 92 | 93 | 94 | def log_SO3(R): 95 | # return logSO3(R) 96 | n = R.shape[0] 97 | assert R.shape == (n, 3, 3), f"log_SO3: input must be of shape (N, 3, 3). 
Current shape: {tuple(R.shape)}" 98 | assert is_SO3(R), "log_SO3: input must be SO(3) matrices" 99 | 100 | tr_R = torch.diagonal(R, dim1=1, dim2=2).sum(1) 101 | w_mat = torch.zeros_like(R) 102 | theta = torch.acos(torch.clamp((tr_R - 1) / 2, -1 + EPS, 1 - EPS)) 103 | 104 | is_regular = (tr_R + 1 > EPS) 105 | is_singular = (tr_R + 1 <= EPS) 106 | 107 | theta = theta.unsqueeze(1).unsqueeze(2) 108 | 109 | w_mat_regular = (1 / (2 * torch.sin(theta[is_regular]) + EPS)) * (R[is_regular] - R[is_regular].transpose(1, 2)) * theta[is_regular] 110 | 111 | w_mat_singular = (R[is_singular] - torch.eye(3).to(R)) / 2 112 | 113 | w_vec_singular = torch.sqrt(torch.diagonal(w_mat_singular, dim1=1, dim2=2) + 1) 114 | w_vec_singular[torch.isnan(w_vec_singular)] = 0 115 | 116 | w_1 = w_vec_singular[:, 0] 117 | w_2 = w_vec_singular[:, 1] * (torch.sign(w_mat_singular[:, 0, 1]) + (w_1 == 0)) 118 | w_3 = w_vec_singular[:, 2] * torch.sign(4 * torch.sign(w_mat_singular[:, 0, 2]) + 2 * (w_1 == 0) * torch.sign(w_mat_singular[:, 1, 2]) + 1 * (w_1 == 0) * (w_2 == 0)) 119 | 120 | w_vec_singular = torch.stack([w_1, w_2, w_3], dim=1) 121 | 122 | w_mat[is_regular] = w_mat_regular 123 | w_mat[is_singular] = bracket_so3(w_vec_singular) * torch.pi 124 | 125 | return w_mat 126 | 127 | 128 | def log_SE3(T): 129 | assert T.shape[1:] == (4, 4), f"log_SE3: input must be of shape (N, 4, 4). Current shape: {tuple(T.shape)}" 130 | assert is_SE3(T), "log_SE3: input must be SE(3) matrices" 131 | 132 | R = T[:, :3, :3] 133 | p = T[:, :3, 3] 134 | 135 | tr_R = torch.diagonal(R, dim1=1, dim2=2).sum(1) 136 | theta = torch.acos(torch.clamp((tr_R - 1) / 2, -1 + EPS, 1 - EPS)).unsqueeze(1).unsqueeze(2) 137 | 138 | w_mat = log_SO3(R) 139 | w_mat_hat = w_mat / (theta + EPS) 140 | 141 | inv_G = torch.eye(3).repeat(len(T), 1, 1).to(T) - (theta / 2) * w_mat_hat + (1 - (theta / (2 * torch.tan(theta / 2) + EPS))) * w_mat_hat @ w_mat_hat 142 | 143 | S = torch.zeros_like(T) 144 | S[:, :3, :3] = w_mat 145 | S[:, :3, 3] = torch.einsum('nij,nj->ni', inv_G, p) 146 | 147 | return S 148 | 149 | 150 | def exp_so3(w_vec): 151 | if w_vec.shape[1:] == (3, 3): 152 | w_vec = bracket_so3(w_vec) 153 | elif w_vec.shape[1:] != (3,): 154 | raise f"exp_so3: input must be of shape (N, 3) or (N, 3, 3). Current shape: {tuple(w_vec.shape)}" 155 | 156 | R = torch.eye(3).repeat(len(w_vec), 1, 1).to(w_vec) 157 | 158 | theta = w_vec.norm(dim=1) 159 | 160 | is_regular = theta > EPS 161 | 162 | w_vec_regular = w_vec[is_regular] 163 | theta_regular = theta[is_regular] 164 | 165 | theta_regular = theta_regular.unsqueeze(1) 166 | 167 | w_mat_hat_regular = bracket_so3(w_vec_regular / theta_regular) 168 | 169 | theta_regular = theta_regular.unsqueeze(2) 170 | 171 | R[is_regular] = torch.eye(3).repeat(len(w_vec_regular), 1, 1).to(w_vec_regular) + torch.sin(theta_regular) * w_mat_hat_regular + (1 - torch.cos(theta_regular)) * w_mat_hat_regular @ w_mat_hat_regular 172 | 173 | return R 174 | 175 | 176 | def exp_se3(S): 177 | if S.shape[1:] == (4, 4): 178 | S = bracket_se3(S) 179 | elif S.shape[1:] != (6,): 180 | raise f"exp_se3: input must be of shape (N, 6) or (N, 4, 4). 
Current shape: {tuple(S.shape)}" 181 | 182 | w_vec = S[:, :3] 183 | p = S[:, 3:] 184 | 185 | T = torch.eye(4).repeat(len(S), 1, 1).to(S) 186 | 187 | theta = w_vec.norm(dim=1) 188 | 189 | is_regular = theta > EPS 190 | is_singular = theta <= EPS 191 | 192 | w_vec_regular = w_vec[is_regular] 193 | theta_regular = theta[is_regular] 194 | 195 | theta_regular = theta_regular.unsqueeze(1) 196 | 197 | w_mat_hat_regular = bracket_so3(w_vec_regular / theta_regular) 198 | 199 | theta_regular = theta_regular.unsqueeze(2) 200 | 201 | G = theta_regular * torch.eye(3).repeat(len(S), 1, 1).to(S) + (1 - torch.cos(theta_regular)) * w_mat_hat_regular + (theta_regular - torch.cos(theta_regular)) * w_mat_hat_regular @ w_mat_hat_regular 202 | 203 | T[is_regular, :3, :3] = exp_so3(w_vec_regular) 204 | T[is_regular, :3, 3] = torch.einsum('nij,nj->ni', G, p) 205 | 206 | T[is_singular, :3, :3] = torch.eye(3).repeat(is_singular.sum(), 1, 1) 207 | T[is_singular, :3, 3] = p 208 | 209 | return T 210 | 211 | 212 | def large_adjoint(T): 213 | assert T.shape[1:] == (4, 4), f"large_adjoint: input must be of shape (N, 4, 4). Current shape: {tuple(T.shape)}" 214 | assert is_SE3(T), "large_adjoint: input must be SE(3) matrices" 215 | 216 | R = T[:, :3, :3] 217 | p = T[:, :3, 3] 218 | 219 | large_adj = T.new_zeros(len(T), 6, 6) 220 | large_adj[:, :3, :3] = R 221 | large_adj[:, 3:, :3] = bracket_so3(p) @ R 222 | large_adj[:, 3:, 3:] = R 223 | 224 | return large_adj 225 | 226 | 227 | def small_adjoint(S): 228 | if S.shape[1:] == (4, 4): 229 | w_mat = S[:, :3, :3] 230 | v_mat = bracket_so3(S[:, :3, 3]) 231 | elif S.shape[1:] == (6,): 232 | w_mat = bracket_so3(S[:, :3]) 233 | v_mat = bracket_so3(S[:, 3:]) 234 | else: 235 | raise f"small_adj: input must be of shape (N, 6) or (N, 4, 4). Current shape: {tuple(S.shape)}" 236 | 237 | small_adj = S.new_zeros(len(S), 6, 6) 238 | small_adj[:, :3, :3] = w_mat 239 | small_adj[:, 3:, :3] = v_mat 240 | small_adj[:, 3:, 3:] = w_mat 241 | 242 | return small_adj 243 | 244 | 245 | def Lie_bracket(u, v): 246 | if u.shape[1:] == (3,): 247 | u = bracket_so3(u) 248 | elif u.shape[1:] == (6,): 249 | u = bracket_se3(u) 250 | 251 | if v.shape[1:] == (3,): 252 | v = bracket_so3(v) 253 | elif v.shape[1:] == (6,): 254 | v = bracket_se3(v) 255 | 256 | return u @ v - v @ u 257 | 258 | 259 | def is_quat(quat): 260 | test = torch.allclose(quat.norm(dim=1), quat.new_ones(len(quat))) 261 | 262 | return test 263 | 264 | 265 | def super_fibonacci_spiral(num_Rs): 266 | phi = 1.414213562304880242096980 # sqrt(2) 267 | psi = 1.533751168755204288118041 268 | 269 | s = np.arange(num_Rs) + 1 / 2 270 | 271 | t = s / num_Rs 272 | d = 2 * np.pi * s 273 | 274 | r = np.sqrt(t) 275 | R = np.sqrt(1 - t) 276 | 277 | alpha = d / phi 278 | beta = d / psi 279 | 280 | quats = np.stack([r * np.sin(alpha), r * np.cos(alpha), R * np.sin(beta), R * np.cos(beta)], axis=1) 281 | 282 | Rs = Rotation.from_quat(quats).as_matrix() 283 | 284 | return Rs 285 | 286 | 287 | def SE3_geodesic_dist(T_1, T_2): 288 | assert len(T_1) == len(T_2), f"SE3_geodesic_dist: inputs must have the same batch_size. 
Current shapes: T_1 - {tuple(T_1.shape)}, T_2 - {tuple(T_2.shape)}" 289 | assert is_SE3(T_1) and is_SE3(T_2), "SE3_geodesic_dist: inputs must be SE(3) matrices" 290 | 291 | R_1 = T_1[:, :3, :3] 292 | R_2 = T_2[:, :3, :3] 293 | p_1 = T_1[:, :3, 3] 294 | p_2 = T_2[:, :3, 3] 295 | 296 | delta_R = bracket_so3(log_SO3(torch.einsum('bij,bjk->bik', inv_SO3(R_1), R_2))) 297 | delta_p = p_1 - p_2 298 | 299 | dist = (delta_R ** 2 + delta_p ** 2).sum(1).sqrt() 300 | 301 | return dist 302 | 303 | 304 | def get_fibonacci_sphere(num_points): 305 | points = [] 306 | 307 | phi = np.pi * (np.sqrt(5.) - 1.) # golden angle in radians 308 | 309 | for i in range(num_points): 310 | y = 1 - (i / float(num_points - 1)) * 2 # y goes from 1 to -1 311 | radius = np.sqrt(1 - y * y) # radius at y 312 | 313 | theta = phi * i # golden angle increment 314 | 315 | x = np.cos(theta) * radius 316 | z = np.sin(theta) * radius 317 | 318 | points += [np.array([x, y, z])] 319 | 320 | points = np.stack(points) 321 | 322 | return points 323 | -------------------------------------------------------------------------------- /utils/average_meter.py: -------------------------------------------------------------------------------- 1 | class AverageMeter: 2 | def __init__(self): 3 | self.reset() 4 | 5 | def reset(self): 6 | self.avg = 0 7 | self.sum = 0 8 | self.count = 0 9 | 10 | def update(self, val, n=1): 11 | self.sum += val * n 12 | self.count += n 13 | self.avg = self.sum /self.count 14 | -------------------------------------------------------------------------------- /utils/distributions.py: -------------------------------------------------------------------------------- 1 | import roma 2 | import torch 3 | 4 | 5 | def get_dist(cfg): 6 | name = cfg.pop('name') 7 | 8 | if name == 'SO3_uniform_R3_normal': 9 | dist_fn = SO3_uniform_R3_normal 10 | elif name == 'SO3_uniform_R3_spherical': 11 | dist_fn = SO3_uniform_R3_spherical 12 | elif name == 'SO3_centripetal_R3_normal': 13 | dist_fn = SO3_centripetal_R3_normal 14 | elif name == 'SO3_centripetal_R3_spherical': 15 | dist_fn = SO3_centripetal_R3_spherical 16 | else: 17 | raise NotImplementedError(f"Distribution {name} is not implemented.") 18 | 19 | return dist_fn 20 | 21 | 22 | def SO3_uniform_R3_normal(num_samples, device): 23 | R = roma.random_rotmat(num_samples).to(device) 24 | 25 | p = torch.randn(num_samples, 3).to(device) 26 | 27 | T = torch.eye(4).repeat(num_samples, 1, 1).to(device) 28 | T[:, :3, :3] = R 29 | T[:, :3, 3] = p 30 | 31 | return T 32 | 33 | 34 | def SO3_uniform_R3_spherical(num_samples, device): 35 | R = roma.random_rotmat(num_samples).to(device) 36 | 37 | p = torch.randn(num_samples, 3).to(device) 38 | p /= p.norm(dim=-1, keepdim=True) 39 | 40 | T = torch.eye(4).repeat(num_samples, 1, 1).to(device) 41 | T[:, :3, :3] = R 42 | T[:, :3, 3] = p 43 | 44 | return T 45 | 46 | 47 | def SO3_centripetal_R3_normal(num_samples, device): 48 | R = roma.random_rotmat(num_samples).to(device) 49 | 50 | p = - (0.112 * 5 + torch.randn(num_samples, 1).to(device).abs()) * R[:, :, 2] 51 | 52 | T = torch.eye(4).repeat(num_samples, 1, 1).to(device) 53 | T[:, :3, :3] = R 54 | T[:, :3, 3] = p 55 | 56 | return T 57 | 58 | 59 | def SO3_centripetal_R3_spherical(num_samples, device): 60 | R = roma.random_rotmat(num_samples).to(device) 61 | 62 | p = - R[:, :, 2] 63 | 64 | T = torch.eye(4).repeat(num_samples, 1, 1).to(device) 65 | T[:, :3, :3] = R 66 | T[:, :3, 3] = p 67 | 68 | return T 69 | -------------------------------------------------------------------------------- /utils/logger.py: 
-------------------------------------------------------------------------------- 1 | class Logger: 2 | def __init__(self, writer): 3 | self.writer = writer 4 | 5 | def log(self, results, iter): 6 | for key, val in results.items(): 7 | if 'scalar' in key: 8 | self.writer.add_scalar(key.replace('scalar/', ''), val, iter) 9 | 10 | elif 'image' in key and 'images' not in key: 11 | self.writer.add_image(key.replace('image/', ''), val, iter) 12 | 13 | elif 'images' in key: 14 | self.writer.add_images(key.replace('images/', ''), val, iter) 15 | 16 | elif 'mesh' in key: 17 | self.writer.add_mesh(key.replace('mesh/', ''), vertices=val['vertices'], colors=val['colors'], faces=val['faces'], global_step=iter) 18 | -------------------------------------------------------------------------------- /utils/mesh.py: -------------------------------------------------------------------------------- 1 | from copy import deepcopy 2 | import open3d as o3d 3 | import numpy as np 4 | 5 | 6 | def generate_grasp_scene_list(mesh_list, Ts_grasp_list): 7 | scene_list = [] 8 | 9 | for mesh, Ts_grasp in zip(mesh_list, Ts_grasp_list): 10 | scene = deepcopy(mesh) 11 | 12 | for T_grasp in Ts_grasp: 13 | mesh_base_1 = o3d.geometry.TriangleMesh.create_cylinder(radius=0.002, height=0.066, resolution=6, split=1) 14 | T_base_1 = np.eye(4) 15 | T_base_1[:3, 3] = [0, 0, 0.033] 16 | mesh_base_1.transform(T_base_1) 17 | 18 | mesh_base_2 = o3d.geometry.TriangleMesh.create_cylinder(radius=0.002, height=0.082, resolution=6, split=1) 19 | T_base_2 = np.eye(4) 20 | T_base_2[:3, :3] = mesh_base_2.get_rotation_matrix_from_xyz([0, np.pi/2, 0]) 21 | T_base_2[:3, 3] = [0, 0, 0.066] 22 | mesh_base_2.transform(T_base_2) 23 | 24 | mesh_left_finger = o3d.geometry.TriangleMesh.create_cylinder(radius=0.002, height=0.046, resolution=6, split=1) 25 | T_left_finger = np.eye(4) 26 | T_left_finger[:3, 3] = [-0.041, 0, 0.089] 27 | mesh_left_finger.transform(T_left_finger) 28 | 29 | mesh_right_finger = o3d.geometry.TriangleMesh.create_cylinder(radius=0.002, height=0.046, resolution=6, split=1) 30 | T_right_finger = np.eye(4) 31 | T_right_finger[:3, 3] = [0.041, 0, 0.089] 32 | mesh_right_finger.transform(T_right_finger) 33 | 34 | mesh_gripper = mesh_base_1 + mesh_base_2 + mesh_left_finger + mesh_right_finger 35 | mesh_gripper.transform(T_grasp) 36 | 37 | scene += mesh_gripper 38 | 39 | scene.compute_vertex_normals() 40 | scene.paint_uniform_color([0.5, 0.5, 0.5]) 41 | 42 | scene_list += [scene] 43 | 44 | return scene_list 45 | 46 | 47 | def meshes_to_numpy(scenes): 48 | # Initialize 49 | vertices_np = [] 50 | triangles_np = [] 51 | colors_np = [] 52 | 53 | # Get maximum number of vertices and triangles 54 | max_num_vertices = max([len(scene.vertices) for scene in scenes]) 55 | max_num_triangles = max([len(scene.triangles) for scene in scenes]) 56 | 57 | # Match dimension between batches for Tensorboard 58 | for scene in scenes: 59 | diff_num_vertices = max_num_vertices - len(scene.vertices) 60 | diff_num_triangles = max_num_triangles - len(scene.triangles) 61 | 62 | vertices_np += [np.concatenate((np.asarray(scene.vertices), np.zeros((diff_num_vertices, 3))), axis=0)] 63 | triangles_np += [np.concatenate((np.asarray(scene.triangles), np.zeros((diff_num_triangles, 3))), axis=0)] 64 | colors_np += [np.concatenate((255 * np.asarray(scene.vertex_colors), np.zeros((diff_num_vertices, 3))), axis=0)] 65 | 66 | # Stack to single numpy array 67 | vertices_np = np.stack(vertices_np, axis=0) 68 | triangles_np = np.stack(triangles_np, axis=0) 69 | colors_np = 
np.stack(colors_np, axis=0) 70 | 71 | return vertices_np, triangles_np, colors_np 72 | -------------------------------------------------------------------------------- /utils/ode_solvers.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from copy import deepcopy 3 | 4 | from utils.Lie import bracket_so3, exp_so3, Lie_bracket 5 | 6 | 7 | def get_ode_solver(cfg): 8 | name = cfg.pop('name') 9 | 10 | if name == 'SE3_Euler': 11 | solver = SE3_Euler(**cfg) 12 | elif name == 'SE3_RK_mk': 13 | solver = SE3_RK4_MK(**cfg) 14 | else: 15 | raise NotImplementedError(f"ODE solver {name} is not implemented.") 16 | 17 | return solver 18 | 19 | 20 | class SE3_Euler: 21 | def __init__(self, num_steps): 22 | self.t = torch.linspace(0, 1, num_steps + 1) 23 | 24 | @torch.no_grad() 25 | def __call__(self, z, x_0, func): 26 | # Initialize 27 | t = self.t.to(z.device) 28 | dt = t[1:] - t[:-1] 29 | traj = x_0.new_zeros(x_0.shape[0:1] + t.shape + x_0.shape[1:]) 30 | traj[:, 0] = x_0 31 | 32 | for n in range(len(t)-1): 33 | # Get n-th values 34 | x_n = traj[:, n].contiguous() 35 | t_n = t[n].repeat(len(x_0), 1) 36 | h = dt[n].repeat(len(x_0), 1) 37 | 38 | ##### Stage 1 ##### 39 | # Set function input 40 | x_hat = deepcopy(x_n) 41 | 42 | # Get vector (V_s) 43 | V_1 = func(z, t_n, x_hat) 44 | w_1 = V_1[:, :3] 45 | v_1 = V_1[:, 3:] 46 | 47 | # Change w_s to w_b and transform to matrix 48 | w_1 = torch.einsum('bji,bj->bi', x_hat[:, :3, :3], w_1) 49 | w_1 = bracket_so3(w_1) 50 | 51 | ##### Update ##### 52 | traj[:, n+1] = deepcopy(x_n) 53 | traj[:, n+1, :3, :3] @= exp_so3(h.unsqueeze(-1) * w_1) 54 | traj[:, n+1, :3, 3] += h * v_1 55 | 56 | return traj 57 | 58 | 59 | class SE3_RK4_MK: 60 | def __init__(self, num_steps): 61 | self.t = torch.linspace(0, 1, num_steps + 1) 62 | 63 | @torch.no_grad() 64 | def __call__(self, z, x_0, func): 65 | # Initialize 66 | t = self.t.to(z.device) 67 | dt = t[1:] - t[:-1] 68 | traj = x_0.new_zeros(x_0.shape[0:1] + t.shape + x_0.shape[1:]) 69 | traj[:, 0] = x_0 70 | 71 | for n in range(len(t)-1): 72 | # Get n-th values 73 | x_n = traj[:, n].contiguous() 74 | t_n = t[n].repeat(len(x_0), 1) 75 | h = dt[n].repeat(len(x_0), 1) 76 | 77 | ##### Stage 1 ##### 78 | # Set function input 79 | x_hat_1 = x_n 80 | 81 | # Get vector (V_s) 82 | V_1 = func(z, t_n, x_hat_1) 83 | w_1 = V_1[:, :3] 84 | v_1 = V_1[:, 3:] 85 | 86 | # Change w_s to w_b and transform to matrix 87 | w_1 = torch.einsum('bji,bj->bi', x_hat_1[:, :3, :3], w_1) 88 | w_1 = bracket_so3(w_1) 89 | 90 | # Set I_1 91 | I_1 = w_1 92 | 93 | ##### Stage 2 ##### 94 | u_2 = h.unsqueeze(-1) * (1 / 2) * w_1 95 | u_2 += (h.unsqueeze(-1) / 12) * Lie_bracket(I_1, u_2) 96 | 97 | # Set function input 98 | x_hat_2 = deepcopy(x_n) 99 | x_hat_2[:, :3, :3] @= exp_so3(u_2) 100 | x_hat_2[:, :3, 3] += h * (v_1 / 2) 101 | 102 | # Get vector (V_s) 103 | V_2 = func(z, t_n + (h / 2), x_hat_2) 104 | w_2 = V_2[:, :3] 105 | v_2 = V_2[:, 3:] 106 | 107 | # Change w_s to w_b and transform to matrix 108 | w_2 = torch.einsum('bji,bj->bi', x_hat_2[:, :3, :3], w_2) 109 | w_2 = bracket_so3(w_2) 110 | 111 | ##### Stage 3 ##### 112 | u_3 = h.unsqueeze(-1) * (1 / 2) * w_2 113 | u_3 += (h.unsqueeze(-1) / 12) * Lie_bracket(I_1, u_3) 114 | 115 | # Set function input 116 | x_hat_3 = deepcopy(x_n) 117 | x_hat_3[:, :3, :3] @= exp_so3(u_3) 118 | x_hat_3[:, :3, 3] += h * (v_2 / 2) 119 | 120 | # Get vector (V_s) 121 | V_3 = func(z, t_n + (h / 2), x_hat_3) 122 | w_3 = V_3[:, :3] 123 | v_3 = V_3[:, 3:] 124 | 125 | # Change w_s to w_b 
and transform to matrix 126 | w_3 = torch.einsum('bji,bj->bi', x_hat_3[:, :3, :3], w_3) 127 | w_3 = bracket_so3(w_3) 128 | 129 | ##### Stage 4 ##### 130 | u_4 = h.unsqueeze(-1) * w_3 131 | u_4 += (h.unsqueeze(-1) / 6) * Lie_bracket(I_1, u_4) 132 | 133 | # Set function input 134 | x_hat_4 = deepcopy(x_n) 135 | x_hat_4[:, :3, :3] @= exp_so3(u_4) 136 | x_hat_4[:, :3, 3] += h * v_3 137 | 138 | # Get vector (V_s) 139 | V_4 = func(z, t_n + h, x_hat_4) 140 | w_4 = V_4[:, :3] 141 | v_4 = V_4[:, 3:] 142 | 143 | # Change w_s to w_b and transform to matrix 144 | w_4 = torch.einsum('bji,bj->bi', x_hat_4[:, :3, :3], w_4) 145 | w_4 = bracket_so3(w_4) 146 | 147 | ##### Update ##### 148 | I_2 = (2 * (w_2 - I_1) + 2 * (w_3 - I_1) - (w_4 - I_1)) / h.unsqueeze(-1) 149 | u = h.unsqueeze(-1) * (1 / 6 * w_1 + 1 / 3 * w_2 + 1 / 3 * w_3 + 1 / 6 * w_4) 150 | u += (h.unsqueeze(-1) / 4) * Lie_bracket(I_1, u) + ((h ** 2).unsqueeze(-1) / 24) * Lie_bracket(I_2, u) 151 | 152 | traj[:, n+1] = deepcopy(x_n) 153 | traj[:, n+1, :3, :3] @= exp_so3(u) 154 | traj[:, n+1, :3, 3] += (h / 6) * (v_1 + 2 * v_2 + 2 * v_3 + v_4) 155 | 156 | return traj 157 | -------------------------------------------------------------------------------- /utils/optimizers.py: -------------------------------------------------------------------------------- 1 | from torch.optim import SGD, Adam, ASGD, Adamax, Adadelta, Adagrad, RMSprop 2 | 3 | 4 | def get_optimizer(cfg, model_params): 5 | name = cfg.pop('name') 6 | 7 | optimizer_class = get_optimizer_class(name) 8 | 9 | optimizer = optimizer_class(model_params, **cfg) 10 | 11 | return optimizer 12 | 13 | 14 | def get_optimizer_class(name): 15 | try: 16 | return { 17 | 'sgd': SGD, 18 | 'adam': Adam, 19 | 'asgd': ASGD, 20 | 'adamax': Adamax, 21 | 'adadelta': Adadelta, 22 | 'adagrad': Adagrad, 23 | 'rmsprop': RMSprop, 24 | }[name] 25 | except: 26 | raise NotImplementedError(f"Optimizer {name} is not available.") 27 | -------------------------------------------------------------------------------- /utils/partial_point_cloud.py: -------------------------------------------------------------------------------- 1 | import open3d as o3d 2 | import numpy as np 3 | from tqdm import tqdm 4 | 5 | 6 | def get_partial_point_clouds(mesh, view_vecs, num_points, visible_visualizer=False, use_tqdm=False, check_partial_pc=False): 7 | # Check open3d version 8 | assert o3d.__version__.split('.')[0] == '0' and o3d.__version__.split('.')[1] == '16', \ 9 | f"open3d version must be 0.16, 'ctr.convert_from_pinhole_camera_parameters(camera_params)' doesn't work well in later versions" 10 | 11 | # Set distance from object center to camera 12 | distance = 1.5 * np.linalg.norm(mesh.get_oriented_bounding_box().extent) 13 | 14 | # Set visualizer 15 | vis = o3d.visualization.Visualizer() 16 | vis.create_window(visible=visible_visualizer) 17 | 18 | ctr = vis.get_view_control() 19 | camera_params = ctr.convert_to_pinhole_camera_parameters() 20 | 21 | # Add mesh 22 | vis.add_geometry(mesh) 23 | 24 | # Set camera poses 25 | view_unit_vecs = view_vecs / np.linalg.norm(view_vecs, axis=1, keepdims=True) 26 | 27 | cam_z_s = - view_unit_vecs 28 | 29 | while True: 30 | cam_x_s = -1 + 2 * np.random.rand(len(cam_z_s), 3) 31 | cam_x_s = cam_x_s - np.sum(cam_x_s*cam_z_s, axis=1, keepdims=True) * cam_z_s 32 | 33 | if np.linalg.norm(cam_x_s, axis=1).any() == 0: 34 | continue 35 | else: 36 | cam_x_s /= np.linalg.norm(cam_x_s, axis=1, keepdims=True) 37 | 38 | break 39 | 40 | cam_y_s = np.cross(cam_z_s, cam_x_s) 41 | cam_y_s /= np.linalg.norm(cam_y_s, 
axis=1, keepdims=True) 42 | 43 | cam_Ts = np.tile(np.eye(4), (len(view_vecs), 1, 1)) 44 | cam_Ts[:, :3, :3] = np.stack([cam_x_s, cam_y_s, cam_z_s], axis=2) 45 | cam_Ts[:, :3, 3] = distance * view_unit_vecs 46 | 47 | # Get partial point clouds 48 | partial_pcds = [] 49 | 50 | if use_tqdm: 51 | pbar = tqdm(cam_Ts, desc="Iterating viewpoints ...", leave=False) 52 | else: 53 | pbar = cam_Ts 54 | 55 | for cam_T in pbar: 56 | # Set camera extrinsic parameters 57 | camera_params.extrinsic = np.linalg.inv(cam_T) 58 | 59 | ctr.convert_from_pinhole_camera_parameters(camera_params) 60 | 61 | # Update visualizer 62 | vis.poll_events() 63 | vis.update_renderer() 64 | 65 | # Get partial point cloud 66 | depth = vis.capture_depth_float_buffer() 67 | 68 | partial_pcd = o3d.geometry.PointCloud.create_from_depth_image(depth, camera_params.intrinsic, camera_params.extrinsic) 69 | 70 | # Raise Exception if the number of points in point cloud is less than 'num_points' 71 | if len(np.asarray(partial_pcd.points)) < num_points: 72 | raise Exception("Point cloud has an insufficient number of points. Increase visualizer window width and height.") 73 | 74 | # Downsample point cloud to match the number of points with 'num_points' 75 | else: 76 | voxel_size = 0.5 77 | voxel_size_min = 0 78 | voxel_size_max = 1 79 | 80 | while True: 81 | partial_pcd_tmp = partial_pcd.voxel_down_sample(voxel_size) 82 | 83 | num_points_tmp = len(np.asarray(partial_pcd_tmp.points)) 84 | 85 | if num_points_tmp - num_points >= 0 and num_points_tmp - num_points < 100: 86 | break 87 | else: 88 | if num_points_tmp > num_points: 89 | voxel_size_min = voxel_size 90 | elif num_points_tmp < num_points: 91 | voxel_size_max = voxel_size 92 | 93 | voxel_size = (voxel_size_min + voxel_size_max) / 2 94 | 95 | partial_pcd = partial_pcd_tmp.select_by_index(np.random.choice(num_points_tmp, num_points, replace=False)) 96 | 97 | partial_pcds += [partial_pcd] 98 | 99 | vis.destroy_window() 100 | 101 | # Check obtained partial point cloud with mesh 102 | if check_partial_pc: 103 | for partial_pcd in partial_pcds: 104 | o3d.visualization.draw_geometries([mesh, partial_pcd]) 105 | 106 | # Convert open3d PointCloud to numpy array 107 | partial_pcs = np.stack([np.asarray(partial_pcd.points) for partial_pcd in partial_pcds]) 108 | 109 | return partial_pcs 110 | 111 | 112 | class PartialPointCloudExtractor: 113 | def __init__(self): 114 | # set offscreen rendering 115 | width = 128 116 | height = 128 117 | 118 | self.renderer = o3d.visualization.rendering.OffscreenRenderer(width, height) 119 | 120 | # Set intrinsic parameters 121 | fx = fy = 110.85125168 122 | cx = (width - 1) / 2 123 | cy = (height - 1) / 2 124 | 125 | self.intrinsic = o3d.camera.PinholeCameraIntrinsic(width, height, fx, fy, cx, cy) 126 | 127 | def extract(self, mesh, view_vecs, num_points): 128 | # set distance from object center to camera 129 | distance = np.linalg.norm(mesh.get_oriented_bounding_box().extent) 130 | 131 | # add mesh 132 | self.renderer.scene.add_geometry('mesh', mesh, o3d.visualization.rendering.MaterialRecord()) 133 | 134 | # set camera poses 135 | view_unit_vecs = view_vecs / np.linalg.norm(view_vecs, axis=1, keepdims=True) 136 | 137 | cam_z_s = - view_unit_vecs 138 | 139 | while True: 140 | cam_x_s = -1 + 2 * np.random.rand(len(cam_z_s), 3) 141 | cam_x_s = cam_x_s - np.sum(cam_x_s*cam_z_s, axis=1, keepdims=True) * cam_z_s 142 | 143 | if np.linalg.norm(cam_x_s, axis=1).any() == 0: 144 | continue 145 | else: 146 | cam_x_s /= np.linalg.norm(cam_x_s, axis=1, keepdims=True) 147 | 
148 | break 149 | 150 | cam_y_s = np.cross(cam_z_s, cam_x_s) 151 | cam_y_s /= np.linalg.norm(cam_y_s, axis=1, keepdims=True) 152 | 153 | cam_Ts = np.tile(np.eye(4), (len(view_vecs), 1, 1)) 154 | cam_Ts[:, :3, :3] = np.stack([cam_x_s, cam_y_s, cam_z_s], axis=2) 155 | cam_Ts[:, :3, 3] = distance * view_unit_vecs 156 | 157 | # Get partial point clouds 158 | partial_pcds = [] 159 | 160 | for cam_T in cam_Ts: 161 | # set extrinsic parameters 162 | extrinsic = np.linalg.inv(cam_T) 163 | 164 | # Set camera 165 | self.renderer.setup_camera(self.intrinsic, extrinsic) 166 | 167 | # Get depth image 168 | depth_image = self.renderer.render_to_depth_image(z_in_view_space=True) 169 | 170 | # get partial point cloud 171 | partial_pcd = o3d.geometry.PointCloud.create_from_depth_image(depth_image, self.intrinsic, extrinsic) 172 | 173 | pts = np.asarray(partial_pcd.points) 174 | pts = pts[~np.isnan(pts).any(1)] 175 | 176 | partial_pcd = o3d.geometry.PointCloud(points=o3d.utility.Vector3dVector(pts)) 177 | 178 | # raise Exception if the number of points in point cloud is less than 'num_points' 179 | if len(np.asarray(partial_pcd.points)) < num_points: 180 | raise Exception("Point cloud has an insufficient number of points. Increase visualizer window width and height.") 181 | 182 | # downsample point cloud to match the number of points with 'num_points' 183 | else: 184 | voxel_size = 0.5 185 | voxel_size_min = 0 186 | voxel_size_max = 1 187 | 188 | while True: 189 | partial_pcd_tmp = partial_pcd.voxel_down_sample(voxel_size) 190 | 191 | num_points_tmp = len(np.asarray(partial_pcd_tmp.points)) 192 | 193 | if num_points_tmp - num_points >= 0 and num_points_tmp - num_points < 100: 194 | break 195 | else: 196 | if num_points_tmp > num_points: 197 | voxel_size_min = voxel_size 198 | elif num_points_tmp < num_points: 199 | voxel_size_max = voxel_size 200 | 201 | voxel_size = (voxel_size_min + voxel_size_max) / 2 202 | 203 | partial_pcd = partial_pcd_tmp.select_by_index(np.random.choice(num_points_tmp, num_points, replace=False)) 204 | 205 | partial_pcds += [partial_pcd] 206 | 207 | # convert open3d PointCloud to numpy array 208 | partial_pcs = np.stack([np.asarray(partial_pcd.points) for partial_pcd in partial_pcds]) 209 | 210 | # Delete mesh 211 | self.renderer.scene.remove_geometry('mesh') 212 | 213 | return partial_pcs 214 | -------------------------------------------------------------------------------- /utils/visualization.py: -------------------------------------------------------------------------------- 1 | from plotly.subplots import make_subplots 2 | from plotly import graph_objects as go 3 | import numpy as np 4 | import torch 5 | 6 | 7 | class PlotlySubplotsVisualizer: 8 | def __init__(self, rows, cols): 9 | self.num_subplots = rows * cols 10 | 11 | self.reset(rows, cols) 12 | 13 | def reset(self, rows, cols): 14 | self.fig = make_subplots(rows=rows, cols=cols, specs=[[{'is_3d': True}]*cols]*rows) 15 | self.fig.update_layout(height=900) 16 | 17 | def add_vector(self, x, y, z, u, v, w, row, col, color='black', width=5, sizeref=0.2, showlegend=False): 18 | self.fig.add_trace( 19 | go.Scatter3d(x=[x, x+0.9*u], y=[y, y+0.9*v], z=[z, z+0.9*w], mode='lines', line=dict(color=color, width=width), showlegend=showlegend), 20 | row=row, col=col 21 | ) 22 | self.fig.add_trace( 23 | go.Cone(x=[x+u], y=[y+v], z=[z+w], u=[u], v=[v], w=[w], sizemode='absolute', sizeref=sizeref, anchor='tip', colorscale=[[0, color], [1, color]], showscale=False), 24 | row=row, col=col 25 | ) 26 | 27 | def add_mesh(self, mesh, row, col, 
color='aquamarine', idx='', showlegend=False): 28 | vertices = np.asarray(mesh.vertices) 29 | traingles = np.asarray(mesh.triangles) 30 | 31 | self.fig.add_trace( 32 | go.Mesh3d(x=vertices[:, 0], y=vertices[:, 1], z=vertices[:, 2], i=traingles[:, 0], j=traingles[:, 1], k=traingles[:, 2], color=color, showlegend=showlegend, name='mesh '+str(idx)), 33 | row=row, col=col 34 | ) 35 | 36 | def add_pc(self, pc, row, col, color='lightpink', size=5, idx='', showlegend=False): 37 | self.fig.add_trace( 38 | go.Scatter3d(x=pc[:,0], y=pc[:,1], z=pc[:,2], mode='markers', marker=dict(size=size, color=color), showlegend=showlegend, name='pc '+str(idx)), 39 | row=row, col=col, 40 | ) 41 | 42 | def add_gripper(self, T, row, col, color='violet', width=5, idx='', showlegend=False): 43 | gripper_scatter3d = get_gripper_scatter3d(T, color, width, idx, showlegend) 44 | 45 | self.fig.add_trace(gripper_scatter3d, row=row, col=col) 46 | 47 | def add_grippers(self, Ts, row, col, color='violet', width=5, idx='', showlegend=False): 48 | for T_gripper in Ts: 49 | gripper_scatter3d = get_gripper_scatter3d(T_gripper, color, width, idx, showlegend) 50 | 51 | self.fig.add_trace(gripper_scatter3d, row=row, col=col) 52 | 53 | def add_frame(self, T, row, col, size=1, width=5, sizeref=0.2): 54 | self.add_vector(T[0, 3], T[1, 3], T[2, 3], size * T[0, 0], size * T[1, 0], size * T[2, 0], row, col, color='red', width=width, sizeref=sizeref) 55 | self.add_vector(T[0, 3], T[1, 3], T[2, 3], size * T[0, 1], size * T[1, 1], size * T[2, 1], row, col, color='green', width=width, sizeref=sizeref) 56 | self.add_vector(T[0, 3], T[1, 3], T[2, 3], size * T[0, 2], size * T[1, 2], size * T[2, 2], row, col, color='blue', width=width, sizeref=sizeref) 57 | 58 | 59 | def get_gripper_scatter3d(T, color, width=5, idx='', showlegend=False): 60 | unit1 = 0.066 #* 8 # 0.56 61 | unit2 = 0.041 #* 8 # 0.32 62 | unit3 = 0.046 #* 8 # 0.4 63 | 64 | pbase = torch.Tensor([0, 0, 0, 1]).reshape(1, -1) 65 | pcenter = torch.Tensor([0, 0, unit1, 1]).reshape(1, -1) 66 | pleft = torch.Tensor([unit2, 0, unit1, 1]).reshape(1, -1) 67 | pright = torch.Tensor([-unit2, 0, unit1, 1]).reshape(1, -1) 68 | plefttip = torch.Tensor([unit2, 0, unit1+unit3, 1]).reshape(1, -1) 69 | prighttip = torch.Tensor([-unit2, 0, unit1+unit3, 1]).reshape(1, -1) 70 | 71 | hand = torch.cat([pbase, pcenter, pleft, pright, plefttip, prighttip], dim=0).to(T) 72 | hand = torch.einsum('ij, kj -> ik', T, hand).cpu() 73 | 74 | phandx = [hand[0, 4], hand[0, 2], hand[0, 1], hand[0, 0], hand[0, 1], hand[0, 3], hand[0, 5]] 75 | phandy = [hand[1, 4], hand[1, 2], hand[1, 1], hand[1, 0], hand[1, 1], hand[1, 3], hand[1, 5]] 76 | phandz = [hand[2, 4], hand[2, 2], hand[2, 1], hand[2, 0], hand[2, 1], hand[2, 3], hand[2, 5]] 77 | 78 | gripper_scatter3d = go.Scatter3d(x=phandx, y=phandy, z=phandz, mode='lines', line=dict(color=color, width=width), showlegend=showlegend, name='gripper '+str(idx)) 79 | 80 | return gripper_scatter3d 81 | --------------------------------------------------------------------------------
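As a quick sanity check of the Lie-group utilities defined in `utils/Lie.py`, the following minimal sketch (not part of the repository; it assumes it is run from the repository root inside the activated Conda environment so that `utils.Lie` is importable) round-trips a batch of so(3) vectors through `exp_so3` and `log_SO3`:

```python
import torch

from utils.Lie import exp_so3, log_SO3, bracket_so3, is_SO3

# Batch of so(3) vectors (rotation axis scaled by angle); the small-to-moderate magnitudes keep
# both maps in their regular branch with high probability.
w = 0.5 * torch.randn(16, 3)

# Exponential map to SO(3) and orthogonality check (within the EPS tolerance used in utils/Lie.py).
R = exp_so3(w)
print(is_SO3(R))

# Logarithm map back to skew-symmetric matrices, then to vectors via the bracket operator.
w_back = bracket_so3(log_SO3(R))
print(torch.allclose(w, w_back, atol=1e-2))  # should typically print True
```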